Compare commits

...

3 Commits

Author SHA1 Message Date
coletdjnz
9e05210991
Add socks proxy support 2023-10-28 15:09:59 +13:00
coletdjnz
3f3e38c72b
ping to 12.0 and above 2023-10-28 12:59:45 +13:00
coletdjnz
d4e47cef7e
tests 2023-10-28 12:40:40 +13:00
9 changed files with 404 additions and 155 deletions

View File

@ -1,8 +1,8 @@
mutagen mutagen
pycryptodomex pycryptodomex
websockets
brotli; platform_python_implementation=='CPython' brotli; platform_python_implementation=='CPython'
brotlicffi; platform_python_implementation!='CPython' brotlicffi; platform_python_implementation!='CPython'
certifi certifi
requests>=2.31.0,<3 requests>=2.31.0,<3
urllib3>=1.26.17,<3 urllib3>=1.26.17,<3
websockets>=12.0

View File

@ -6,16 +6,3 @@ import pytest
from yt_dlp.networking import RequestHandler from yt_dlp.networking import RequestHandler
from yt_dlp.networking.common import _REQUEST_HANDLERS from yt_dlp.networking.common import _REQUEST_HANDLERS
from yt_dlp.utils._utils import _YDLLogger as FakeLogger from yt_dlp.utils._utils import _YDLLogger as FakeLogger
@pytest.fixture
def handler(request):
RH_KEY = request.param
if inspect.isclass(RH_KEY) and issubclass(RH_KEY, RequestHandler):
handler = RH_KEY
elif RH_KEY in _REQUEST_HANDLERS:
handler = _REQUEST_HANDLERS[RH_KEY]
else:
pytest.skip(f'{RH_KEY} request handler is not available')
return functools.partial(handler, logger=FakeLogger)

View File

View File

@ -0,0 +1,26 @@
import functools
import inspect
import pytest
from yt_dlp.networking import RequestHandler
from yt_dlp.networking.common import _REQUEST_HANDLERS
from yt_dlp.utils._utils import _YDLLogger as FakeLogger
@pytest.fixture
def handler(request):
RH_KEY = request.param
if inspect.isclass(RH_KEY) and issubclass(RH_KEY, RequestHandler):
handler = RH_KEY
elif RH_KEY in _REQUEST_HANDLERS:
handler = _REQUEST_HANDLERS[RH_KEY]
else:
pytest.skip(f'{RH_KEY} request handler is not available')
return functools.partial(handler, logger=FakeLogger)
def validate_and_send(rh, req):
rh.validate(req)
return rh.send(req)

View File

@ -55,6 +55,8 @@ from yt_dlp.networking.exceptions import (
from yt_dlp.utils._utils import _YDLLogger as FakeLogger from yt_dlp.utils._utils import _YDLLogger as FakeLogger
from yt_dlp.utils.networking import HTTPHeaderDict from yt_dlp.utils.networking import HTTPHeaderDict
from .conftest import validate_and_send
TEST_DIR = os.path.dirname(os.path.abspath(__file__)) TEST_DIR = os.path.dirname(os.path.abspath(__file__))
@ -278,11 +280,6 @@ class HTTPTestRequestHandler(http.server.BaseHTTPRequestHandler):
self._headers_buffer.append(f'{keyword}: {value}\r\n'.encode()) self._headers_buffer.append(f'{keyword}: {value}\r\n'.encode())
def validate_and_send(rh, req):
rh.validate(req)
return rh.send(req)
class TestRequestHandlerBase: class TestRequestHandlerBase:
@classmethod @classmethod
def setup_class(cls): def setup_class(cls):
@ -296,7 +293,7 @@ class TestRequestHandlerBase:
cls.http_server_thread.start() cls.http_server_thread.start()
# HTTPS server # HTTPS server
certfn = os.path.join(TEST_DIR, 'testcert.pem') certfn = os.path.join(TEST_DIR, '../testcert.pem')
cls.https_httpd = http.server.ThreadingHTTPServer( cls.https_httpd = http.server.ThreadingHTTPServer(
('127.0.0.1', 0), HTTPTestRequestHandler) ('127.0.0.1', 0), HTTPTestRequestHandler)
sslctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) sslctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
@ -375,6 +372,7 @@ class TestHTTPRequestHandler(TestRequestHandlerBase):
@pytest.mark.parametrize('handler', ['Urllib', 'Requests'], indirect=True) @pytest.mark.parametrize('handler', ['Urllib', 'Requests'], indirect=True)
def test_raise_http_error(self, handler): def test_raise_http_error(self, handler):
with handler() as rh: with handler() as rh:
# TODO Return HTTP status code url
for bad_status in (400, 500, 599, 302): for bad_status in (400, 500, 599, 302):
with pytest.raises(HTTPError): with pytest.raises(HTTPError):
validate_and_send(rh, Request('http://127.0.0.1:%d/gen_%d' % (self.http_port, bad_status))) validate_and_send(rh, Request('http://127.0.0.1:%d/gen_%d' % (self.http_port, bad_status)))
@ -696,8 +694,8 @@ class TestClientCertificate:
@classmethod @classmethod
def setup_class(cls): def setup_class(cls):
certfn = os.path.join(TEST_DIR, 'testcert.pem') certfn = os.path.join(TEST_DIR, '../testcert.pem')
cls.certdir = os.path.join(TEST_DIR, 'testdata', 'certificate') cls.certdir = os.path.join(TEST_DIR, '../testdata', 'certificate')
cacertfn = os.path.join(cls.certdir, 'ca.crt') cacertfn = os.path.join(cls.certdir, 'ca.crt')
cls.httpd = http.server.ThreadingHTTPServer(('127.0.0.1', 0), HTTPTestRequestHandler) cls.httpd = http.server.ThreadingHTTPServer(('127.0.0.1', 0), HTTPTestRequestHandler)
sslctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) sslctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
@ -748,82 +746,6 @@ class TestClientCertificate:
}) })
@pytest.mark.skipif(not websockets, reason='websockets must be installed to test websocket request handlers')
class TestWebsockets:
@classmethod
def setup_class(cls):
import websockets.server
async def echo(websocket):
async for message in websocket:
if message == b'headers':
await websocket.send(json.dumps(dict(websocket.request_headers)))
elif message == 'source_address':
await websocket.send(websocket.remote_address[0])
else:
await websocket.send(message)
def run():
async def main():
async with websockets.server.serve(echo, "localhost", 8765):
await asyncio.Future()
asyncio.run(main())
cls.ws_server_thread = threading.Thread(target=run)
cls.ws_server_thread.daemon = True
cls.ws_server_thread.start()
time.sleep(1) # wait for server to start
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
def test_send_recv(self, handler):
with handler() as rh:
ws = validate_and_send(rh, Request('ws://127.0.0.1:8765'))
ws.send(b'foo')
assert ws.recv() == b'foo'
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
@pytest.mark.parametrize('params,extensions', [
({'timeout': 0.00001}, {}),
({}, {'timeout': 0.00001}),
])
def test_timeout(self, handler, params, extensions):
with handler(**params) as rh:
with pytest.raises(TransportError):
validate_and_send(rh, Request('ws://127.0.0.1:8765', extensions=extensions))
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
def test_cookies(self, handler):
cookiejar = YoutubeDLCookieJar()
cookiejar.set_cookie(http.cookiejar.Cookie(
version=0, name='test', value='ytdlp', port=None, port_specified=False,
domain='127.0.0.1', domain_specified=True, domain_initial_dot=False, path='/',
path_specified=True, secure=False, expires=None, discard=False, comment=None,
comment_url=None, rest={}))
with handler(cookiejar=cookiejar) as rh:
res = validate_and_send(rh, Request('ws://127.0.0.1:8765'))
res.send(b'headers')
assert json.loads(res.recv())['cookie'] == 'test=ytdlp'
with handler() as rh:
res = validate_and_send(rh, Request('ws://127.0.0.1:8765'))
res.send(b'headers')
assert 'cookie' not in json.loads(res.recv())
res = validate_and_send(rh, Request('ws://127.0.0.1:8765', extensions={'cookiejar': cookiejar}))
res.send(b'headers')
assert json.loads(res.recv())['cookie'] == 'test=ytdlp'
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
def test_source_address(self, handler):
source_address = f'127.0.0.{random.randint(5, 255)}'
with handler(source_address=source_address) as rh:
res = validate_and_send(
rh, Request(f'ws://127.0.0.1:8765/source_address'))
res.send('source_address')
assert source_address == res.recv()
class TestUrllibRequestHandler(TestRequestHandlerBase): class TestUrllibRequestHandler(TestRequestHandlerBase):
@pytest.mark.parametrize('handler', ['Urllib'], indirect=True) @pytest.mark.parametrize('handler', ['Urllib'], indirect=True)
def test_file_urls(self, handler): def test_file_urls(self, handler):

View File

@ -210,6 +210,16 @@ class SocksHTTPTestRequestHandler(http.server.BaseHTTPRequestHandler, SocksTestR
self.wfile.write(payload.encode()) self.wfile.write(payload.encode())
class SocksWebSocketTestRequestHandler(SocksTestRequestHandler):
def handle(self):
import websockets.sync.server
protocol = websockets.ServerProtocol()
connection = websockets.sync.server.ServerConnection(socket=self.request, protocol=protocol, close_timeout=0)
connection.handshake()
connection.send(json.dumps(self.socks_info))
connection.close()
@contextlib.contextmanager @contextlib.contextmanager
def socks_server(socks_server_class, request_handler, bind_ip=None, **socks_server_kwargs): def socks_server(socks_server_class, request_handler, bind_ip=None, **socks_server_kwargs):
server = server_thread = None server = server_thread = None
@ -252,8 +262,21 @@ class HTTPSocksTestProxyContext(SocksProxyTestContext):
return json.loads(handler.send(request).read().decode()) return json.loads(handler.send(request).read().decode())
class WebSocketSocksTestProxyContext(SocksProxyTestContext):
REQUEST_HANDLER_CLASS = SocksWebSocketTestRequestHandler
def socks_info_request(self, handler, target_domain=None, target_port=None, **req_kwargs):
request = Request(f'ws://{target_domain or "127.0.0.1"}:{target_port or "40000"}', **req_kwargs)
handler.validate(request)
ws = handler.send(request)
ws.send('socks_info')
socks_info = ws.recv()
ws.close()
return json.loads(socks_info)
CTX_MAP = { CTX_MAP = {
'http': HTTPSocksTestProxyContext, 'http': HTTPSocksTestProxyContext,
'ws': WebSocketSocksTestProxyContext,
} }
@ -263,7 +286,7 @@ def ctx(request):
class TestSocks4Proxy: class TestSocks4Proxy:
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True) @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
def test_socks4_no_auth(self, handler, ctx): def test_socks4_no_auth(self, handler, ctx):
with handler() as rh: with handler() as rh:
with ctx.socks_server(Socks4ProxyHandler) as server_address: with ctx.socks_server(Socks4ProxyHandler) as server_address:
@ -271,7 +294,7 @@ class TestSocks4Proxy:
rh, proxies={'all': f'socks4://{server_address}'}) rh, proxies={'all': f'socks4://{server_address}'})
assert response['version'] == 4 assert response['version'] == 4
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True) @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
def test_socks4_auth(self, handler, ctx): def test_socks4_auth(self, handler, ctx):
with handler() as rh: with handler() as rh:
with ctx.socks_server(Socks4ProxyHandler, user_id='user') as server_address: with ctx.socks_server(Socks4ProxyHandler, user_id='user') as server_address:
@ -281,7 +304,7 @@ class TestSocks4Proxy:
rh, proxies={'all': f'socks4://user:@{server_address}'}) rh, proxies={'all': f'socks4://user:@{server_address}'})
assert response['version'] == 4 assert response['version'] == 4
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True) @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
def test_socks4a_ipv4_target(self, handler, ctx): def test_socks4a_ipv4_target(self, handler, ctx):
with ctx.socks_server(Socks4ProxyHandler) as server_address: with ctx.socks_server(Socks4ProxyHandler) as server_address:
with handler(proxies={'all': f'socks4a://{server_address}'}) as rh: with handler(proxies={'all': f'socks4a://{server_address}'}) as rh:
@ -289,7 +312,7 @@ class TestSocks4Proxy:
assert response['version'] == 4 assert response['version'] == 4
assert (response['ipv4_address'] == '127.0.0.1') != (response['domain_address'] == '127.0.0.1') assert (response['ipv4_address'] == '127.0.0.1') != (response['domain_address'] == '127.0.0.1')
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True) @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
def test_socks4a_domain_target(self, handler, ctx): def test_socks4a_domain_target(self, handler, ctx):
with ctx.socks_server(Socks4ProxyHandler) as server_address: with ctx.socks_server(Socks4ProxyHandler) as server_address:
with handler(proxies={'all': f'socks4a://{server_address}'}) as rh: with handler(proxies={'all': f'socks4a://{server_address}'}) as rh:
@ -298,7 +321,7 @@ class TestSocks4Proxy:
assert response['ipv4_address'] is None assert response['ipv4_address'] is None
assert response['domain_address'] == 'localhost' assert response['domain_address'] == 'localhost'
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True) @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
def test_ipv4_client_source_address(self, handler, ctx): def test_ipv4_client_source_address(self, handler, ctx):
with ctx.socks_server(Socks4ProxyHandler) as server_address: with ctx.socks_server(Socks4ProxyHandler) as server_address:
source_address = f'127.0.0.{random.randint(5, 255)}' source_address = f'127.0.0.{random.randint(5, 255)}'
@ -308,7 +331,7 @@ class TestSocks4Proxy:
assert response['client_address'][0] == source_address assert response['client_address'][0] == source_address
assert response['version'] == 4 assert response['version'] == 4
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True) @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
@pytest.mark.parametrize('reply_code', [ @pytest.mark.parametrize('reply_code', [
Socks4CD.REQUEST_REJECTED_OR_FAILED, Socks4CD.REQUEST_REJECTED_OR_FAILED,
Socks4CD.REQUEST_REJECTED_CANNOT_CONNECT_TO_IDENTD, Socks4CD.REQUEST_REJECTED_CANNOT_CONNECT_TO_IDENTD,
@ -320,7 +343,7 @@ class TestSocks4Proxy:
with pytest.raises(ProxyError): with pytest.raises(ProxyError):
ctx.socks_info_request(rh) ctx.socks_info_request(rh)
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True) @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
def test_ipv6_socks4_proxy(self, handler, ctx): def test_ipv6_socks4_proxy(self, handler, ctx):
with ctx.socks_server(Socks4ProxyHandler, bind_ip='::1') as server_address: with ctx.socks_server(Socks4ProxyHandler, bind_ip='::1') as server_address:
with handler(proxies={'all': f'socks4://{server_address}'}) as rh: with handler(proxies={'all': f'socks4://{server_address}'}) as rh:
@ -329,7 +352,7 @@ class TestSocks4Proxy:
assert response['ipv4_address'] == '127.0.0.1' assert response['ipv4_address'] == '127.0.0.1'
assert response['version'] == 4 assert response['version'] == 4
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True) @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
def test_timeout(self, handler, ctx): def test_timeout(self, handler, ctx):
with ctx.socks_server(Socks4ProxyHandler, sleep=2) as server_address: with ctx.socks_server(Socks4ProxyHandler, sleep=2) as server_address:
with handler(proxies={'all': f'socks4://{server_address}'}, timeout=0.5) as rh: with handler(proxies={'all': f'socks4://{server_address}'}, timeout=0.5) as rh:
@ -339,7 +362,7 @@ class TestSocks4Proxy:
class TestSocks5Proxy: class TestSocks5Proxy:
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True) @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
def test_socks5_no_auth(self, handler, ctx): def test_socks5_no_auth(self, handler, ctx):
with ctx.socks_server(Socks5ProxyHandler) as server_address: with ctx.socks_server(Socks5ProxyHandler) as server_address:
with handler(proxies={'all': f'socks5://{server_address}'}) as rh: with handler(proxies={'all': f'socks5://{server_address}'}) as rh:
@ -347,7 +370,7 @@ class TestSocks5Proxy:
assert response['auth_methods'] == [0x0] assert response['auth_methods'] == [0x0]
assert response['version'] == 5 assert response['version'] == 5
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True) @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
def test_socks5_user_pass(self, handler, ctx): def test_socks5_user_pass(self, handler, ctx):
with ctx.socks_server(Socks5ProxyHandler, auth=('test', 'testpass')) as server_address: with ctx.socks_server(Socks5ProxyHandler, auth=('test', 'testpass')) as server_address:
with handler() as rh: with handler() as rh:
@ -360,7 +383,7 @@ class TestSocks5Proxy:
assert response['auth_methods'] == [Socks5Auth.AUTH_NONE, Socks5Auth.AUTH_USER_PASS] assert response['auth_methods'] == [Socks5Auth.AUTH_NONE, Socks5Auth.AUTH_USER_PASS]
assert response['version'] == 5 assert response['version'] == 5
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True) @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
def test_socks5_ipv4_target(self, handler, ctx): def test_socks5_ipv4_target(self, handler, ctx):
with ctx.socks_server(Socks5ProxyHandler) as server_address: with ctx.socks_server(Socks5ProxyHandler) as server_address:
with handler(proxies={'all': f'socks5://{server_address}'}) as rh: with handler(proxies={'all': f'socks5://{server_address}'}) as rh:
@ -368,7 +391,7 @@ class TestSocks5Proxy:
assert response['ipv4_address'] == '127.0.0.1' assert response['ipv4_address'] == '127.0.0.1'
assert response['version'] == 5 assert response['version'] == 5
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True) @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
def test_socks5_domain_target(self, handler, ctx): def test_socks5_domain_target(self, handler, ctx):
with ctx.socks_server(Socks5ProxyHandler) as server_address: with ctx.socks_server(Socks5ProxyHandler) as server_address:
with handler(proxies={'all': f'socks5://{server_address}'}) as rh: with handler(proxies={'all': f'socks5://{server_address}'}) as rh:
@ -376,7 +399,7 @@ class TestSocks5Proxy:
assert (response['ipv4_address'] == '127.0.0.1') != (response['ipv6_address'] == '::1') assert (response['ipv4_address'] == '127.0.0.1') != (response['ipv6_address'] == '::1')
assert response['version'] == 5 assert response['version'] == 5
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True) @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
def test_socks5h_domain_target(self, handler, ctx): def test_socks5h_domain_target(self, handler, ctx):
with ctx.socks_server(Socks5ProxyHandler) as server_address: with ctx.socks_server(Socks5ProxyHandler) as server_address:
with handler(proxies={'all': f'socks5h://{server_address}'}) as rh: with handler(proxies={'all': f'socks5h://{server_address}'}) as rh:
@ -385,7 +408,7 @@ class TestSocks5Proxy:
assert response['domain_address'] == 'localhost' assert response['domain_address'] == 'localhost'
assert response['version'] == 5 assert response['version'] == 5
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True) @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
def test_socks5h_ip_target(self, handler, ctx): def test_socks5h_ip_target(self, handler, ctx):
with ctx.socks_server(Socks5ProxyHandler) as server_address: with ctx.socks_server(Socks5ProxyHandler) as server_address:
with handler(proxies={'all': f'socks5h://{server_address}'}) as rh: with handler(proxies={'all': f'socks5h://{server_address}'}) as rh:
@ -394,7 +417,7 @@ class TestSocks5Proxy:
assert response['domain_address'] is None assert response['domain_address'] is None
assert response['version'] == 5 assert response['version'] == 5
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True) @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
def test_socks5_ipv6_destination(self, handler, ctx): def test_socks5_ipv6_destination(self, handler, ctx):
with ctx.socks_server(Socks5ProxyHandler) as server_address: with ctx.socks_server(Socks5ProxyHandler) as server_address:
with handler(proxies={'all': f'socks5://{server_address}'}) as rh: with handler(proxies={'all': f'socks5://{server_address}'}) as rh:
@ -402,7 +425,7 @@ class TestSocks5Proxy:
assert response['ipv6_address'] == '::1' assert response['ipv6_address'] == '::1'
assert response['version'] == 5 assert response['version'] == 5
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True) @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
def test_ipv6_socks5_proxy(self, handler, ctx): def test_ipv6_socks5_proxy(self, handler, ctx):
with ctx.socks_server(Socks5ProxyHandler, bind_ip='::1') as server_address: with ctx.socks_server(Socks5ProxyHandler, bind_ip='::1') as server_address:
with handler(proxies={'all': f'socks5://{server_address}'}) as rh: with handler(proxies={'all': f'socks5://{server_address}'}) as rh:
@ -413,7 +436,7 @@ class TestSocks5Proxy:
# XXX: is there any feasible way of testing IPv6 source addresses? # XXX: is there any feasible way of testing IPv6 source addresses?
# Same would go for non-proxy source_address test... # Same would go for non-proxy source_address test...
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True) @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
def test_ipv4_client_source_address(self, handler, ctx): def test_ipv4_client_source_address(self, handler, ctx):
with ctx.socks_server(Socks5ProxyHandler) as server_address: with ctx.socks_server(Socks5ProxyHandler) as server_address:
source_address = f'127.0.0.{random.randint(5, 255)}' source_address = f'127.0.0.{random.randint(5, 255)}'
@ -422,7 +445,7 @@ class TestSocks5Proxy:
assert response['client_address'][0] == source_address assert response['client_address'][0] == source_address
assert response['version'] == 5 assert response['version'] == 5
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True) @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
@pytest.mark.parametrize('reply_code', [ @pytest.mark.parametrize('reply_code', [
Socks5Reply.GENERAL_FAILURE, Socks5Reply.GENERAL_FAILURE,
Socks5Reply.CONNECTION_NOT_ALLOWED, Socks5Reply.CONNECTION_NOT_ALLOWED,
@ -439,7 +462,7 @@ class TestSocks5Proxy:
with pytest.raises(ProxyError): with pytest.raises(ProxyError):
ctx.socks_info_request(rh) ctx.socks_info_request(rh)
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http')], indirect=True) @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Websockets', 'ws')], indirect=True)
def test_timeout(self, handler, ctx): def test_timeout(self, handler, ctx):
with ctx.socks_server(Socks5ProxyHandler, sleep=2) as server_address: with ctx.socks_server(Socks5ProxyHandler, sleep=2) as server_address:
with handler(proxies={'all': f'socks5://{server_address}'}, timeout=1) as rh: with handler(proxies={'all': f'socks5://{server_address}'}, timeout=1) as rh:

View File

@ -0,0 +1,258 @@
#!/usr/bin/env python3
import enum
# Allow direct execution
import os
import sys
from queue import Queue
import pytest
from websockets.datastructures import Headers
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import asyncio
import functools
import gzip
import http.client
import http.cookiejar
import http.server
import io
import json
import pathlib
import random
import ssl
import tempfile
import threading
import time
import urllib.error
import urllib.request
import warnings
import zlib
from email.message import Message
from http.cookiejar import CookieJar
import websockets.sync
from test.helper import FakeYDL, http_server_port
from yt_dlp.cookies import YoutubeDLCookieJar
from yt_dlp.dependencies import brotli, requests, urllib3, websockets
from yt_dlp.networking import (
HEADRequest,
PUTRequest,
Request,
RequestDirector,
RequestHandler,
Response,
)
from yt_dlp.networking._urllib import UrllibRH
from yt_dlp.networking.exceptions import (
CertificateVerifyError,
HTTPError,
IncompleteRead,
NoSupportingHandlers,
ProxyError,
RequestError,
SSLError,
TransportError,
UnsupportedRequest,
)
from yt_dlp.utils._utils import _YDLLogger as FakeLogger
from yt_dlp.utils.networking import HTTPHeaderDict
from .conftest import validate_and_send
TEST_DIR = os.path.dirname(os.path.abspath(__file__))
def websocket_handler(websocket):
for message in websocket:
if message == 'headers':
return websocket.send(json.dumps(dict(websocket.request.headers)))
elif message == 'path':
return websocket.send(websocket.request.path)
elif message == 'source_address':
return websocket.send(websocket.remote_address[0])
else:
return websocket.send(message)
def process_request(self, request):
if request.path.startswith('/gen_'):
status = http.HTTPStatus(int(request.path[5:]))
if 300 <= status.value <= 300:
return websockets.http11.Response(
status.value, status.phrase, websockets.datastructures.Headers([('Location', '/')]), b'')
return self.protocol.reject(status.value, status.phrase)
return self.protocol.accept(request)
def create_websocket_server(**ws_kwargs):
import websockets.sync.server
wsd = websockets.sync.server.serve(websocket_handler, '127.0.0.1', 0, process_request=process_request, **ws_kwargs)
ws_port = wsd.socket.getsockname()[1]
ws_server_thread = threading.Thread(target=wsd.serve_forever)
ws_server_thread.daemon = True
ws_server_thread.start()
return ws_server_thread, ws_port
def create_ws_websocket_server():
return create_websocket_server()
def create_wss_websocket_server():
certfn = os.path.join(TEST_DIR, '../testcert.pem')
sslctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
sslctx.load_cert_chain(certfn, None)
return create_websocket_server(ssl_context=sslctx)
@pytest.mark.skipif(not websockets, reason='websockets must be installed to test websocket request handlers')
class TestWebsSocketRequestHandlerConformance:
@classmethod
def setup_class(cls):
cls.ws_thread, cls.ws_port = create_ws_websocket_server()
cls.ws_base_url = f'ws://127.0.0.1:{cls.ws_port}'
cls.wss_thread, cls.wss_port = create_wss_websocket_server()
cls.wss_base_url = f'wss://127.0.0.1:{cls.wss_port}'
cls.bad_wss_thread, cls.bad_wss_port = create_websocket_server(ssl_context=ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER))
cls.bad_wss_host = f'wss://127.0.0.1:{cls.bad_wss_port}'
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
def test_basic_websockets(self, handler):
with handler() as rh:
ws = validate_and_send(rh, Request(self.ws_base_url))
assert 'upgrade' in ws.headers
assert ws.status == 101
ws.send('foo')
assert ws.recv() == 'foo'
ws.close()
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
def test_verify_cert(self, handler):
with handler() as rh:
with pytest.raises(CertificateVerifyError):
validate_and_send(rh, Request(self.wss_base_url))
with handler(verify=False) as rh:
ws = validate_and_send(rh, Request(self.wss_base_url))
assert ws.status == 101
ws.close()
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
def test_ssl_error(self, handler):
with handler(verify=False) as rh:
with pytest.raises(SSLError, match='sslv3 alert handshake failure') as exc_info:
validate_and_send(rh, Request(self.bad_wss_host))
assert not issubclass(exc_info.type, CertificateVerifyError)
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
@pytest.mark.parametrize('path,expected', [
# Unicode characters should be encoded with uppercase percent-encoding
('/中文', '/%E4%B8%AD%E6%96%87'),
# don't normalize existing percent encodings
('/%c7%9f', '/%c7%9f'),
])
def test_percent_encode(self, handler, path, expected):
with handler() as rh:
ws = validate_and_send(rh, Request(f'{self.ws_base_url}{path}'))
ws.send('path')
assert ws.recv() == expected
assert ws.status == 101
ws.close()
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
def test_remove_dot_segments(self, handler):
with handler() as rh:
# This isn't a comprehensive test,
# but it should be enough to check whether the handler is removing dot segments
ws = validate_and_send(rh, Request(f'{self.ws_base_url}/a/b/./../../test'))
assert ws.status == 101
ws.send('path')
assert ws.recv() == '/test'
ws.close()
# We are restricted to known HTTP status codes in http.HTTPStatus
# Redirects are not supported for websockets
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
@pytest.mark.parametrize('status', (200, 204, 301, 302, 303, 400, 500, 511))
def test_raise_http_error(self, handler, status):
with handler() as rh:
with pytest.raises(HTTPError) as exc_info:
validate_and_send(rh, Request(f'{self.ws_base_url}/gen_{status}'))
assert exc_info.value.status == status
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
@pytest.mark.parametrize('params,extensions', [
({'timeout': 0.00001}, {}),
({}, {'timeout': 0.00001}),
])
def test_timeout(self, handler, params, extensions):
with handler(**params) as rh:
with pytest.raises(TransportError):
validate_and_send(rh, Request(self.ws_base_url, extensions=extensions))
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
def test_cookies(self, handler):
cookiejar = YoutubeDLCookieJar()
cookiejar.set_cookie(http.cookiejar.Cookie(
version=0, name='test', value='ytdlp', port=None, port_specified=False,
domain='127.0.0.1', domain_specified=True, domain_initial_dot=False, path='/',
path_specified=True, secure=False, expires=None, discard=False, comment=None,
comment_url=None, rest={}))
with handler(cookiejar=cookiejar) as rh:
ws = validate_and_send(rh, Request(self.ws_base_url))
ws.send('headers')
assert json.loads(ws.recv())['cookie'] == 'test=ytdlp'
ws.close()
with handler() as rh:
ws = validate_and_send(rh, Request(self.ws_base_url))
ws.send('headers')
assert 'cookie' not in json.loads(ws.recv())
ws = validate_and_send(rh, Request(self.ws_base_url, extensions={'cookiejar': cookiejar}))
ws.send('headers')
assert json.loads(ws.recv())['cookie'] == 'test=ytdlp'
ws.close()
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
def test_source_address(self, handler):
source_address = f'127.0.0.{random.randint(5, 255)}'
with handler(source_address=source_address) as rh:
ws = validate_and_send(rh, Request(self.ws_base_url))
ws.send('source_address')
assert source_address == ws.recv()
ws.close()
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
def test_response_url(self, handler):
with handler() as rh:
url = f'{self.ws_base_url}/something'
ws = validate_and_send(rh, Request(url))
assert ws.url == url
ws.close()
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
def test_request_headers(self, handler):
with handler(headers=HTTPHeaderDict({'test1': 'test', 'test2': 'test2'})) as rh:
# Global Headers
ws = validate_and_send(rh, Request(self.ws_base_url))
ws.send('headers')
headers = HTTPHeaderDict(json.loads(ws.recv()))
assert headers['test1'] == 'test'
ws.close()
# Per request headers, merged with global
ws = validate_and_send(rh, Request(
self.ws_base_url, headers={'test2': 'changed', 'test3': 'test3'}))
ws.send('headers')
headers = HTTPHeaderDict(json.loads(ws.recv()))
assert headers['test1'] == 'test'
assert headers['test2'] == 'changed'
assert headers['test3'] == 'test3'
ws.close()

View File

@ -1,27 +1,36 @@
# Request handler for https://github.com/python-websockets/websockets
from __future__ import annotations from __future__ import annotations
import asyncio
import atexit
import io import io
import logging import logging
import urllib.parse import ssl
import sys import sys
from ._helper import create_connection
from websockets.uri import parse_uri from ._helper import create_connection, select_proxy, make_socks_proxy_opts, create_socks_proxy_socket
from .common import Response, register_rh, Features
from .common import register_rh from .exceptions import (
from .exceptions import TransportError, RequestError CertificateVerifyError,
from .websocket import WebSocketResponse, WebSocketRequestHandler HTTPError,
RequestError,
SSLError,
TransportError, ProxyError,
)
from .websocket import WebSocketRequestHandler, WebSocketResponse
from ..compat import functools
from ..dependencies import websockets from ..dependencies import websockets
from ..utils import int_or_none
from ..socks import ProxyError as SocksProxyError
if not websockets: if not websockets:
raise ImportError('websockets is not installed') raise ImportError('websockets is not installed')
import websockets.version
websockets_version = tuple(map(int_or_none, websockets.version.version.split('.')))
if websockets_version < (12, 0):
raise ImportError('Only websockets>=12.0 is supported')
import websockets.sync.client import websockets.sync.client
from websockets.exceptions import InvalidHandshake, InvalidURI, ConnectionClosed from websockets.uri import parse_uri
class WebsocketsResponseAdapter(WebSocketResponse): class WebsocketsResponseAdapter(WebSocketResponse):
@ -44,8 +53,10 @@ class WebsocketsResponseAdapter(WebSocketResponse):
# https://websockets.readthedocs.io/en/stable/reference/sync/client.html#websockets.sync.client.ClientConnection.send # https://websockets.readthedocs.io/en/stable/reference/sync/client.html#websockets.sync.client.ClientConnection.send
try: try:
return self.wsw.send(*args) return self.wsw.send(*args)
except (ConnectionClosed, RuntimeError) as e: except (websockets.exceptions.ConnectionClosed, RuntimeError, TimeoutError) as e:
raise TransportError(cause=e) from e raise TransportError(cause=e) from e
except SocksProxyError as e:
raise ProxyError(cause=e) from e
except TypeError as e: except TypeError as e:
raise RequestError(cause=e) from e raise RequestError(cause=e) from e
@ -53,13 +64,22 @@ class WebsocketsResponseAdapter(WebSocketResponse):
# https://websockets.readthedocs.io/en/stable/reference/sync/client.html#websockets.sync.client.ClientConnection.recv # https://websockets.readthedocs.io/en/stable/reference/sync/client.html#websockets.sync.client.ClientConnection.recv
try: try:
return self.wsw.recv(*args) return self.wsw.recv(*args)
except (ConnectionClosed, RuntimeError) as e: except SocksProxyError as e:
raise ProxyError(cause=e) from e
except (websockets.exceptions.ConnectionClosed, RuntimeError, TimeoutError) as e:
raise TransportError(cause=e) from e raise TransportError(cause=e) from e
@register_rh @register_rh
class WebsocketsRH(WebSocketRequestHandler): class WebsocketsRH(WebSocketRequestHandler):
"""
Websockets request handler
https://websockets.readthedocs.io
https://github.com/python-websockets/websockets
"""
_SUPPORTED_URL_SCHEMES = ('wss', 'ws') _SUPPORTED_URL_SCHEMES = ('wss', 'ws')
_SUPPORTED_PROXY_SCHEMES = ('socks4', 'socks4a', 'socks5', 'socks5h')
_SUPPORTED_FEATURES = (Features.ALL_PROXY, Features.NO_PROXY)
RH_NAME = 'websockets' RH_NAME = 'websockets'
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
@ -78,22 +98,6 @@ class WebsocketsRH(WebSocketRequestHandler):
extensions.pop('cookiejar', None) extensions.pop('cookiejar', None)
def _send(self, request): def _send(self, request):
"""
https://websockets.readthedocs.io/en/stable/reference/sync/client.html
TODO:
- Cookie Support
- Test Exception Mapping
- Timeout handling for closing?
- WS Pinging
- KeyboardInterrupt doesn't seem to kill websockets
"""
ws_kwargs = {}
if urllib.parse.urlparse(request.url).scheme == 'wss':
ws_kwargs['ssl_context'] = self._make_sslcontext()
source_address = self.source_address
if source_address is not None:
ws_kwargs['source_address'] = source_address
timeout = float(request.extensions.get('timeout') or self.timeout) timeout = float(request.extensions.get('timeout') or self.timeout)
headers = self._merge_headers(request.headers) headers = self._merge_headers(request.headers)
if 'cookie' not in headers: if 'cookie' not in headers:
@ -103,24 +107,53 @@ class WebsocketsRH(WebSocketRequestHandler):
headers['cookie'] = cookie_header headers['cookie'] = cookie_header
wsuri = parse_uri(request.url) wsuri = parse_uri(request.url)
sock = create_connection( create_conn_kwargs = {
(wsuri.host, wsuri.port), 'source_address': (self.source_address, 0) if self.source_address else None,
source_address=(self.source_address, 0) if self.source_address else None, 'timeout': timeout
timeout=timeout }
) proxy = select_proxy(request.url, request.proxies or self.proxies or {})
try: try:
if proxy:
socks_proxy_options = make_socks_proxy_opts(proxy)
sock = create_connection(
address=(socks_proxy_options['addr'], socks_proxy_options['port']),
_create_socket_func=functools.partial(
create_socks_proxy_socket, (wsuri.host, wsuri.port), socks_proxy_options),
**create_conn_kwargs
)
else:
sock = create_connection(
address=(wsuri.host, wsuri.port),
**create_conn_kwargs
)
conn = websockets.sync.client.connect( conn = websockets.sync.client.connect(
sock=sock, sock=sock,
uri=request.url, uri=request.url,
additional_headers=headers, additional_headers=headers,
open_timeout=timeout, open_timeout=timeout,
user_agent_header=None, user_agent_header=None,
ssl_context=self._make_sslcontext() if wsuri.secure else None,
) )
return WebsocketsResponseAdapter(conn, url=request.url) return WebsocketsResponseAdapter(conn, url=request.url)
# Exceptions as per https://websockets.readthedocs.io/en/stable/reference/sync/client.html # Exceptions as per https://websockets.readthedocs.io/en/stable/reference/sync/client.html
except InvalidURI as e: except SocksProxyError as e:
raise ProxyError(cause=e) from e
except websockets.exceptions.InvalidURI as e:
raise RequestError(cause=e) from e raise RequestError(cause=e) from e
except (OSError, TimeoutError, InvalidHandshake) as e: except ssl.SSLCertVerificationError as e:
raise CertificateVerifyError(cause=e) from e
except ssl.SSLError as e:
raise SSLError(cause=e) from e
except websockets.exceptions.InvalidStatus as e:
raise HTTPError(
Response(
fp=io.BytesIO(e.response.body),
url=request.url,
headers=e.response.headers,
status=e.response.status_code,
reason=e.response.reason_phrase),
) from e
except (OSError, TimeoutError, websockets.exceptions.InvalidHandshake) as e:
raise TransportError(cause=e) from e raise TransportError(cause=e) from e