Compare commits

..

No commits in common. "9e05210991eebfafcd86912574bbd7dacd94fe55" and "a4310ef2d7a247ccbcc9e638c2decf8c42cfb760" have entirely different histories.

9 changed files with 155 additions and 404 deletions

View File

@ -1,8 +1,8 @@
mutagen
pycryptodomex
websockets
brotli; platform_python_implementation=='CPython'
brotlicffi; platform_python_implementation!='CPython'
certifi
requests>=2.31.0,<3
urllib3>=1.26.17,<3
websockets>=12.0

View File

@ -6,3 +6,16 @@ import pytest
from yt_dlp.networking import RequestHandler
from yt_dlp.networking.common import _REQUEST_HANDLERS
from yt_dlp.utils._utils import _YDLLogger as FakeLogger
@pytest.fixture
def handler(request):
RH_KEY = request.param
if inspect.isclass(RH_KEY) and issubclass(RH_KEY, RequestHandler):
handler = RH_KEY
elif RH_KEY in _REQUEST_HANDLERS:
handler = _REQUEST_HANDLERS[RH_KEY]
else:
pytest.skip(f'{RH_KEY} request handler is not available')
return functools.partial(handler, logger=FakeLogger)

View File

@ -55,8 +55,6 @@ from yt_dlp.networking.exceptions import (
from yt_dlp.utils._utils import _YDLLogger as FakeLogger
from yt_dlp.utils.networking import HTTPHeaderDict
from .conftest import validate_and_send
TEST_DIR = os.path.dirname(os.path.abspath(__file__))
@ -280,6 +278,11 @@ class HTTPTestRequestHandler(http.server.BaseHTTPRequestHandler):
self._headers_buffer.append(f'{keyword}: {value}\r\n'.encode())
def validate_and_send(rh, req):
rh.validate(req)
return rh.send(req)
class TestRequestHandlerBase:
@classmethod
def setup_class(cls):
@ -293,7 +296,7 @@ class TestRequestHandlerBase:
cls.http_server_thread.start()
# HTTPS server
certfn = os.path.join(TEST_DIR, '../testcert.pem')
certfn = os.path.join(TEST_DIR, 'testcert.pem')
cls.https_httpd = http.server.ThreadingHTTPServer(
('127.0.0.1', 0), HTTPTestRequestHandler)
sslctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
@ -372,7 +375,6 @@ class TestHTTPRequestHandler(TestRequestHandlerBase):
@pytest.mark.parametrize('handler', ['Urllib', 'Requests'], indirect=True)
def test_raise_http_error(self, handler):
with handler() as rh:
# TODO Return HTTP status code url
for bad_status in (400, 500, 599, 302):
with pytest.raises(HTTPError):
validate_and_send(rh, Request('http://127.0.0.1:%d/gen_%d' % (self.http_port, bad_status)))
@ -694,8 +696,8 @@ class TestClientCertificate:
@classmethod
def setup_class(cls):
certfn = os.path.join(TEST_DIR, '../testcert.pem')
cls.certdir = os.path.join(TEST_DIR, '../testdata', 'certificate')
certfn = os.path.join(TEST_DIR, 'testcert.pem')
cls.certdir = os.path.join(TEST_DIR, 'testdata', 'certificate')
cacertfn = os.path.join(cls.certdir, 'ca.crt')
cls.httpd = http.server.ThreadingHTTPServer(('127.0.0.1', 0), HTTPTestRequestHandler)
sslctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
@ -746,6 +748,82 @@ class TestClientCertificate:
})
@pytest.mark.skipif(not websockets, reason='websockets must be installed to test websocket request handlers')
class TestWebsockets:
@classmethod
def setup_class(cls):
import websockets.server
async def echo(websocket):
async for message in websocket:
if message == b'headers':
await websocket.send(json.dumps(dict(websocket.request_headers)))
elif message == 'source_address':
await websocket.send(websocket.remote_address[0])
else:
await websocket.send(message)
def run():
async def main():
async with websockets.server.serve(echo, "localhost", 8765):
await asyncio.Future()
asyncio.run(main())
cls.ws_server_thread = threading.Thread(target=run)
cls.ws_server_thread.daemon = True
cls.ws_server_thread.start()
time.sleep(1) # wait for server to start
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
def test_send_recv(self, handler):
with handler() as rh:
ws = validate_and_send(rh, Request('ws://127.0.0.1:8765'))
ws.send(b'foo')
assert ws.recv() == b'foo'
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
@pytest.mark.parametrize('params,extensions', [
({'timeout': 0.00001}, {}),
({}, {'timeout': 0.00001}),
])
def test_timeout(self, handler, params, extensions):
with handler(**params) as rh:
with pytest.raises(TransportError):
validate_and_send(rh, Request('ws://127.0.0.1:8765', extensions=extensions))
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
def test_cookies(self, handler):
cookiejar = YoutubeDLCookieJar()
cookiejar.set_cookie(http.cookiejar.Cookie(
version=0, name='test', value='ytdlp', port=None, port_specified=False,
domain='127.0.0.1', domain_specified=True, domain_initial_dot=False, path='/',
path_specified=True, secure=False, expires=None, discard=False, comment=None,
comment_url=None, rest={}))
with handler(cookiejar=cookiejar) as rh:
res = validate_and_send(rh, Request('ws://127.0.0.1:8765'))
res.send(b'headers')
assert json.loads(res.recv())['cookie'] == 'test=ytdlp'
with handler() as rh:
res = validate_and_send(rh, Request('ws://127.0.0.1:8765'))
res.send(b'headers')
assert 'cookie' not in json.loads(res.recv())
res = validate_and_send(rh, Request('ws://127.0.0.1:8765', extensions={'cookiejar': cookiejar}))
res.send(b'headers')
assert json.loads(res.recv())['cookie'] == 'test=ytdlp'
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
def test_source_address(self, handler):
source_address = f'127.0.0.{random.randint(5, 255)}'
with handler(source_address=source_address) as rh:
res = validate_and_send(
rh, Request(f'ws://127.0.0.1:8765/source_address'))
res.send('source_address')
assert source_address == res.recv()
class TestUrllibRequestHandler(TestRequestHandlerBase):
@pytest.mark.parametrize('handler', ['Urllib'], indirect=True)
def test_file_urls(self, handler):

View File

@ -1,26 +0,0 @@
import functools
import inspect
import pytest
from yt_dlp.networking import RequestHandler
from yt_dlp.networking.common import _REQUEST_HANDLERS
from yt_dlp.utils._utils import _YDLLogger as FakeLogger
@pytest.fixture
def handler(request):
RH_KEY = request.param
if inspect.isclass(RH_KEY) and issubclass(RH_KEY, RequestHandler):
handler = RH_KEY
elif RH_KEY in _REQUEST_HANDLERS:
handler = _REQUEST_HANDLERS[RH_KEY]
else:
pytest.skip(f'{RH_KEY} request handler is not available')
return functools.partial(handler, logger=FakeLogger)
def validate_and_send(rh, req):
rh.validate(req)
return rh.send(req)

View File

@ -1,258 +0,0 @@
#!/usr/bin/env python3
import enum
# Allow direct execution
import os
import sys
from queue import Queue
import pytest
from websockets.datastructures import Headers
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import asyncio
import functools
import gzip
import http.client
import http.cookiejar
import http.server
import io
import json
import pathlib
import random
import ssl
import tempfile
import threading
import time
import urllib.error
import urllib.request
import warnings
import zlib
from email.message import Message
from http.cookiejar import CookieJar
import websockets.sync
from test.helper import FakeYDL, http_server_port
from yt_dlp.cookies import YoutubeDLCookieJar
from yt_dlp.dependencies import brotli, requests, urllib3, websockets
from yt_dlp.networking import (
HEADRequest,
PUTRequest,
Request,
RequestDirector,
RequestHandler,
Response,
)
from yt_dlp.networking._urllib import UrllibRH
from yt_dlp.networking.exceptions import (
CertificateVerifyError,
HTTPError,
IncompleteRead,
NoSupportingHandlers,
ProxyError,
RequestError,
SSLError,
TransportError,
UnsupportedRequest,
)
from yt_dlp.utils._utils import _YDLLogger as FakeLogger
from yt_dlp.utils.networking import HTTPHeaderDict
from .conftest import validate_and_send
TEST_DIR = os.path.dirname(os.path.abspath(__file__))
def websocket_handler(websocket):
for message in websocket:
if message == 'headers':
return websocket.send(json.dumps(dict(websocket.request.headers)))
elif message == 'path':
return websocket.send(websocket.request.path)
elif message == 'source_address':
return websocket.send(websocket.remote_address[0])
else:
return websocket.send(message)
def process_request(self, request):
if request.path.startswith('/gen_'):
status = http.HTTPStatus(int(request.path[5:]))
if 300 <= status.value <= 300:
return websockets.http11.Response(
status.value, status.phrase, websockets.datastructures.Headers([('Location', '/')]), b'')
return self.protocol.reject(status.value, status.phrase)
return self.protocol.accept(request)
def create_websocket_server(**ws_kwargs):
import websockets.sync.server
wsd = websockets.sync.server.serve(websocket_handler, '127.0.0.1', 0, process_request=process_request, **ws_kwargs)
ws_port = wsd.socket.getsockname()[1]
ws_server_thread = threading.Thread(target=wsd.serve_forever)
ws_server_thread.daemon = True
ws_server_thread.start()
return ws_server_thread, ws_port
def create_ws_websocket_server():
return create_websocket_server()
def create_wss_websocket_server():
certfn = os.path.join(TEST_DIR, '../testcert.pem')
sslctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
sslctx.load_cert_chain(certfn, None)
return create_websocket_server(ssl_context=sslctx)
@pytest.mark.skipif(not websockets, reason='websockets must be installed to test websocket request handlers')
class TestWebsSocketRequestHandlerConformance:
@classmethod
def setup_class(cls):
cls.ws_thread, cls.ws_port = create_ws_websocket_server()
cls.ws_base_url = f'ws://127.0.0.1:{cls.ws_port}'
cls.wss_thread, cls.wss_port = create_wss_websocket_server()
cls.wss_base_url = f'wss://127.0.0.1:{cls.wss_port}'
cls.bad_wss_thread, cls.bad_wss_port = create_websocket_server(ssl_context=ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER))
cls.bad_wss_host = f'wss://127.0.0.1:{cls.bad_wss_port}'
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
def test_basic_websockets(self, handler):
with handler() as rh:
ws = validate_and_send(rh, Request(self.ws_base_url))
assert 'upgrade' in ws.headers
assert ws.status == 101
ws.send('foo')
assert ws.recv() == 'foo'
ws.close()
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
def test_verify_cert(self, handler):
with handler() as rh:
with pytest.raises(CertificateVerifyError):
validate_and_send(rh, Request(self.wss_base_url))
with handler(verify=False) as rh:
ws = validate_and_send(rh, Request(self.wss_base_url))
assert ws.status == 101
ws.close()
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
def test_ssl_error(self, handler):
with handler(verify=False) as rh:
with pytest.raises(SSLError, match='sslv3 alert handshake failure') as exc_info:
validate_and_send(rh, Request(self.bad_wss_host))
assert not issubclass(exc_info.type, CertificateVerifyError)
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
@pytest.mark.parametrize('path,expected', [
# Unicode characters should be encoded with uppercase percent-encoding
('/中文', '/%E4%B8%AD%E6%96%87'),
# don't normalize existing percent encodings
('/%c7%9f', '/%c7%9f'),
])
def test_percent_encode(self, handler, path, expected):
with handler() as rh:
ws = validate_and_send(rh, Request(f'{self.ws_base_url}{path}'))
ws.send('path')
assert ws.recv() == expected
assert ws.status == 101
ws.close()
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
def test_remove_dot_segments(self, handler):
with handler() as rh:
# This isn't a comprehensive test,
# but it should be enough to check whether the handler is removing dot segments
ws = validate_and_send(rh, Request(f'{self.ws_base_url}/a/b/./../../test'))
assert ws.status == 101
ws.send('path')
assert ws.recv() == '/test'
ws.close()
# We are restricted to known HTTP status codes in http.HTTPStatus
# Redirects are not supported for websockets
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
@pytest.mark.parametrize('status', (200, 204, 301, 302, 303, 400, 500, 511))
def test_raise_http_error(self, handler, status):
with handler() as rh:
with pytest.raises(HTTPError) as exc_info:
validate_and_send(rh, Request(f'{self.ws_base_url}/gen_{status}'))
assert exc_info.value.status == status
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
@pytest.mark.parametrize('params,extensions', [
({'timeout': 0.00001}, {}),
({}, {'timeout': 0.00001}),
])
def test_timeout(self, handler, params, extensions):
with handler(**params) as rh:
with pytest.raises(TransportError):
validate_and_send(rh, Request(self.ws_base_url, extensions=extensions))
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
def test_cookies(self, handler):
cookiejar = YoutubeDLCookieJar()
cookiejar.set_cookie(http.cookiejar.Cookie(
version=0, name='test', value='ytdlp', port=None, port_specified=False,
domain='127.0.0.1', domain_specified=True, domain_initial_dot=False, path='/',
path_specified=True, secure=False, expires=None, discard=False, comment=None,
comment_url=None, rest={}))
with handler(cookiejar=cookiejar) as rh:
ws = validate_and_send(rh, Request(self.ws_base_url))
ws.send('headers')
assert json.loads(ws.recv())['cookie'] == 'test=ytdlp'
ws.close()
with handler() as rh:
ws = validate_and_send(rh, Request(self.ws_base_url))
ws.send('headers')
assert 'cookie' not in json.loads(ws.recv())
ws = validate_and_send(rh, Request(self.ws_base_url, extensions={'cookiejar': cookiejar}))
ws.send('headers')
assert json.loads(ws.recv())['cookie'] == 'test=ytdlp'
ws.close()
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
def test_source_address(self, handler):
source_address = f'127.0.0.{random.randint(5, 255)}'
with handler(source_address=source_address) as rh:
ws = validate_and_send(rh, Request(self.ws_base_url))
ws.send('source_address')
assert source_address == ws.recv()
ws.close()
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
def test_response_url(self, handler):
with handler() as rh:
url = f'{self.ws_base_url}/something'
ws = validate_and_send(rh, Request(url))
assert ws.url == url
ws.close()
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
def test_request_headers(self, handler):
with handler(headers=HTTPHeaderDict({'test1': 'test', 'test2': 'test2'})) as rh:
# Global Headers
ws = validate_and_send(rh, Request(self.ws_base_url))
ws.send('headers')
headers = HTTPHeaderDict(json.loads(ws.recv()))
assert headers['test1'] == 'test'
ws.close()
# Per request headers, merged with global
ws = validate_and_send(rh, Request(
self.ws_base_url, headers={'test2': 'changed', 'test3': 'test3'}))
ws.send('headers')
headers = HTTPHeaderDict(json.loads(ws.recv()))
assert headers['test1'] == 'test'
assert headers['test2'] == 'changed'
assert headers['test3'] == 'test3'
ws.close()

View File

@ -210,16 +210,6 @@ class SocksHTTPTestRequestHandler(http.server.BaseHTTPRequestHandler, SocksTestR
self.wfile.write(payload.encode())
class SocksWebSocketTestRequestHandler(SocksTestRequestHandler):
def handle(self):
import websockets.sync.server
protocol = websockets.ServerProtocol()
connection = websockets.sync.server.ServerConnection(socket=self.request, protocol=protocol, close_timeout=0)
connection.handshake()
connection.send(json.dumps(self.socks_info))
connection.close()
@contextlib.contextmanager
def socks_server(socks_server_class, request_handler, bind_ip=None, **socks_server_kwargs):
server = server_thread = None
@ -262,21 +252,8 @@ class HTTPSocksTestProxyContext(SocksProxyTestContext):
return json.loads(handler.send(request).read().decode())
class WebSocketSocksTestProxyContext(SocksProxyTestContext):
REQUEST_HANDLER_CLASS = SocksWebSocketTestRequestHandler
def socks_info_request(self, handler, target_domain=None, target_port=None, **req_kwargs):
request = Request(f'ws://{target_domain or "127.0.0.1"}:{target_port or "40000"}', **req_kwargs)
handler.validate(request)
ws = handler.send(request)
ws.send('socks_info')
socks_info = ws.recv()
ws.close()
return json.loads(socks_info)
CTX_MAP = {
'http': HTTPSocksTestProxyContext,
'ws': WebSocketSocksTestProxyContext,
}
@ -286,7 +263,7 @@ def ctx(request):
class TestSocks4Proxy:
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True)
def test_socks4_no_auth(self, handler, ctx):
with handler() as rh:
with ctx.socks_server(Socks4ProxyHandler) as server_address:
@ -294,7 +271,7 @@ class TestSocks4Proxy:
rh, proxies={'all': f'socks4://{server_address}'})
assert response['version'] == 4
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True)
def test_socks4_auth(self, handler, ctx):
with handler() as rh:
with ctx.socks_server(Socks4ProxyHandler, user_id='user') as server_address:
@ -304,7 +281,7 @@ class TestSocks4Proxy:
rh, proxies={'all': f'socks4://user:@{server_address}'})
assert response['version'] == 4
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True)
def test_socks4a_ipv4_target(self, handler, ctx):
with ctx.socks_server(Socks4ProxyHandler) as server_address:
with handler(proxies={'all': f'socks4a://{server_address}'}) as rh:
@ -312,7 +289,7 @@ class TestSocks4Proxy:
assert response['version'] == 4
assert (response['ipv4_address'] == '127.0.0.1') != (response['domain_address'] == '127.0.0.1')
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True)
def test_socks4a_domain_target(self, handler, ctx):
with ctx.socks_server(Socks4ProxyHandler) as server_address:
with handler(proxies={'all': f'socks4a://{server_address}'}) as rh:
@ -321,7 +298,7 @@ class TestSocks4Proxy:
assert response['ipv4_address'] is None
assert response['domain_address'] == 'localhost'
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True)
def test_ipv4_client_source_address(self, handler, ctx):
with ctx.socks_server(Socks4ProxyHandler) as server_address:
source_address = f'127.0.0.{random.randint(5, 255)}'
@ -331,7 +308,7 @@ class TestSocks4Proxy:
assert response['client_address'][0] == source_address
assert response['version'] == 4
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True)
@pytest.mark.parametrize('reply_code', [
Socks4CD.REQUEST_REJECTED_OR_FAILED,
Socks4CD.REQUEST_REJECTED_CANNOT_CONNECT_TO_IDENTD,
@ -343,7 +320,7 @@ class TestSocks4Proxy:
with pytest.raises(ProxyError):
ctx.socks_info_request(rh)
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True)
def test_ipv6_socks4_proxy(self, handler, ctx):
with ctx.socks_server(Socks4ProxyHandler, bind_ip='::1') as server_address:
with handler(proxies={'all': f'socks4://{server_address}'}) as rh:
@ -352,7 +329,7 @@ class TestSocks4Proxy:
assert response['ipv4_address'] == '127.0.0.1'
assert response['version'] == 4
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True)
def test_timeout(self, handler, ctx):
with ctx.socks_server(Socks4ProxyHandler, sleep=2) as server_address:
with handler(proxies={'all': f'socks4://{server_address}'}, timeout=0.5) as rh:
@ -362,7 +339,7 @@ class TestSocks4Proxy:
class TestSocks5Proxy:
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True)
def test_socks5_no_auth(self, handler, ctx):
with ctx.socks_server(Socks5ProxyHandler) as server_address:
with handler(proxies={'all': f'socks5://{server_address}'}) as rh:
@ -370,7 +347,7 @@ class TestSocks5Proxy:
assert response['auth_methods'] == [0x0]
assert response['version'] == 5
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True)
def test_socks5_user_pass(self, handler, ctx):
with ctx.socks_server(Socks5ProxyHandler, auth=('test', 'testpass')) as server_address:
with handler() as rh:
@ -383,7 +360,7 @@ class TestSocks5Proxy:
assert response['auth_methods'] == [Socks5Auth.AUTH_NONE, Socks5Auth.AUTH_USER_PASS]
assert response['version'] == 5
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True)
def test_socks5_ipv4_target(self, handler, ctx):
with ctx.socks_server(Socks5ProxyHandler) as server_address:
with handler(proxies={'all': f'socks5://{server_address}'}) as rh:
@ -391,7 +368,7 @@ class TestSocks5Proxy:
assert response['ipv4_address'] == '127.0.0.1'
assert response['version'] == 5
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True)
def test_socks5_domain_target(self, handler, ctx):
with ctx.socks_server(Socks5ProxyHandler) as server_address:
with handler(proxies={'all': f'socks5://{server_address}'}) as rh:
@ -399,7 +376,7 @@ class TestSocks5Proxy:
assert (response['ipv4_address'] == '127.0.0.1') != (response['ipv6_address'] == '::1')
assert response['version'] == 5
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True)
def test_socks5h_domain_target(self, handler, ctx):
with ctx.socks_server(Socks5ProxyHandler) as server_address:
with handler(proxies={'all': f'socks5h://{server_address}'}) as rh:
@ -408,7 +385,7 @@ class TestSocks5Proxy:
assert response['domain_address'] == 'localhost'
assert response['version'] == 5
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True)
def test_socks5h_ip_target(self, handler, ctx):
with ctx.socks_server(Socks5ProxyHandler) as server_address:
with handler(proxies={'all': f'socks5h://{server_address}'}) as rh:
@ -417,7 +394,7 @@ class TestSocks5Proxy:
assert response['domain_address'] is None
assert response['version'] == 5
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True)
def test_socks5_ipv6_destination(self, handler, ctx):
with ctx.socks_server(Socks5ProxyHandler) as server_address:
with handler(proxies={'all': f'socks5://{server_address}'}) as rh:
@ -425,7 +402,7 @@ class TestSocks5Proxy:
assert response['ipv6_address'] == '::1'
assert response['version'] == 5
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True)
def test_ipv6_socks5_proxy(self, handler, ctx):
with ctx.socks_server(Socks5ProxyHandler, bind_ip='::1') as server_address:
with handler(proxies={'all': f'socks5://{server_address}'}) as rh:
@ -436,7 +413,7 @@ class TestSocks5Proxy:
# XXX: is there any feasible way of testing IPv6 source addresses?
# Same would go for non-proxy source_address test...
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True)
def test_ipv4_client_source_address(self, handler, ctx):
with ctx.socks_server(Socks5ProxyHandler) as server_address:
source_address = f'127.0.0.{random.randint(5, 255)}'
@ -445,7 +422,7 @@ class TestSocks5Proxy:
assert response['client_address'][0] == source_address
assert response['version'] == 5
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True)
@pytest.mark.parametrize('reply_code', [
Socks5Reply.GENERAL_FAILURE,
Socks5Reply.CONNECTION_NOT_ALLOWED,
@ -462,7 +439,7 @@ class TestSocks5Proxy:
with pytest.raises(ProxyError):
ctx.socks_info_request(rh)
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Websockets', 'ws')], indirect=True)
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http')], indirect=True)
def test_timeout(self, handler, ctx):
with ctx.socks_server(Socks5ProxyHandler, sleep=2) as server_address:
with handler(proxies={'all': f'socks5://{server_address}'}, timeout=1) as rh:

View File

@ -1,36 +1,27 @@
# Request handler for https://github.com/python-websockets/websockets
from __future__ import annotations
import asyncio
import atexit
import io
import logging
import ssl
import urllib.parse
import sys
from ._helper import create_connection
from ._helper import create_connection, select_proxy, make_socks_proxy_opts, create_socks_proxy_socket
from .common import Response, register_rh, Features
from .exceptions import (
CertificateVerifyError,
HTTPError,
RequestError,
SSLError,
TransportError, ProxyError,
)
from .websocket import WebSocketRequestHandler, WebSocketResponse
from ..compat import functools
from websockets.uri import parse_uri
from .common import register_rh
from .exceptions import TransportError, RequestError
from .websocket import WebSocketResponse, WebSocketRequestHandler
from ..dependencies import websockets
from ..utils import int_or_none
from ..socks import ProxyError as SocksProxyError
if not websockets:
raise ImportError('websockets is not installed')
import websockets.version
websockets_version = tuple(map(int_or_none, websockets.version.version.split('.')))
if websockets_version < (12, 0):
raise ImportError('Only websockets>=12.0 is supported')
import websockets.sync.client
from websockets.uri import parse_uri
from websockets.exceptions import InvalidHandshake, InvalidURI, ConnectionClosed
class WebsocketsResponseAdapter(WebSocketResponse):
@ -53,10 +44,8 @@ class WebsocketsResponseAdapter(WebSocketResponse):
# https://websockets.readthedocs.io/en/stable/reference/sync/client.html#websockets.sync.client.ClientConnection.send
try:
return self.wsw.send(*args)
except (websockets.exceptions.ConnectionClosed, RuntimeError, TimeoutError) as e:
except (ConnectionClosed, RuntimeError) as e:
raise TransportError(cause=e) from e
except SocksProxyError as e:
raise ProxyError(cause=e) from e
except TypeError as e:
raise RequestError(cause=e) from e
@ -64,22 +53,13 @@ class WebsocketsResponseAdapter(WebSocketResponse):
# https://websockets.readthedocs.io/en/stable/reference/sync/client.html#websockets.sync.client.ClientConnection.recv
try:
return self.wsw.recv(*args)
except SocksProxyError as e:
raise ProxyError(cause=e) from e
except (websockets.exceptions.ConnectionClosed, RuntimeError, TimeoutError) as e:
except (ConnectionClosed, RuntimeError) as e:
raise TransportError(cause=e) from e
@register_rh
class WebsocketsRH(WebSocketRequestHandler):
"""
Websockets request handler
https://websockets.readthedocs.io
https://github.com/python-websockets/websockets
"""
_SUPPORTED_URL_SCHEMES = ('wss', 'ws')
_SUPPORTED_PROXY_SCHEMES = ('socks4', 'socks4a', 'socks5', 'socks5h')
_SUPPORTED_FEATURES = (Features.ALL_PROXY, Features.NO_PROXY)
RH_NAME = 'websockets'
def __init__(self, *args, **kwargs):
@ -98,6 +78,22 @@ class WebsocketsRH(WebSocketRequestHandler):
extensions.pop('cookiejar', None)
def _send(self, request):
"""
https://websockets.readthedocs.io/en/stable/reference/sync/client.html
TODO:
- Cookie Support
- Test Exception Mapping
- Timeout handling for closing?
- WS Pinging
- KeyboardInterrupt doesn't seem to kill websockets
"""
ws_kwargs = {}
if urllib.parse.urlparse(request.url).scheme == 'wss':
ws_kwargs['ssl_context'] = self._make_sslcontext()
source_address = self.source_address
if source_address is not None:
ws_kwargs['source_address'] = source_address
timeout = float(request.extensions.get('timeout') or self.timeout)
headers = self._merge_headers(request.headers)
if 'cookie' not in headers:
@ -107,53 +103,24 @@ class WebsocketsRH(WebSocketRequestHandler):
headers['cookie'] = cookie_header
wsuri = parse_uri(request.url)
create_conn_kwargs = {
'source_address': (self.source_address, 0) if self.source_address else None,
'timeout': timeout
}
proxy = select_proxy(request.url, request.proxies or self.proxies or {})
sock = create_connection(
(wsuri.host, wsuri.port),
source_address=(self.source_address, 0) if self.source_address else None,
timeout=timeout
)
try:
if proxy:
socks_proxy_options = make_socks_proxy_opts(proxy)
sock = create_connection(
address=(socks_proxy_options['addr'], socks_proxy_options['port']),
_create_socket_func=functools.partial(
create_socks_proxy_socket, (wsuri.host, wsuri.port), socks_proxy_options),
**create_conn_kwargs
)
else:
sock = create_connection(
address=(wsuri.host, wsuri.port),
**create_conn_kwargs
)
conn = websockets.sync.client.connect(
sock=sock,
uri=request.url,
additional_headers=headers,
open_timeout=timeout,
user_agent_header=None,
ssl_context=self._make_sslcontext() if wsuri.secure else None,
)
return WebsocketsResponseAdapter(conn, url=request.url)
# Exceptions as per https://websockets.readthedocs.io/en/stable/reference/sync/client.html
except SocksProxyError as e:
raise ProxyError(cause=e) from e
except websockets.exceptions.InvalidURI as e:
except InvalidURI as e:
raise RequestError(cause=e) from e
except ssl.SSLCertVerificationError as e:
raise CertificateVerifyError(cause=e) from e
except ssl.SSLError as e:
raise SSLError(cause=e) from e
except websockets.exceptions.InvalidStatus as e:
raise HTTPError(
Response(
fp=io.BytesIO(e.response.body),
url=request.url,
headers=e.response.headers,
status=e.response.status_code,
reason=e.response.reason_phrase),
) from e
except (OSError, TimeoutError, websockets.exceptions.InvalidHandshake) as e:
except (OSError, TimeoutError, InvalidHandshake) as e:
raise TransportError(cause=e) from e