mirror of
https://github.com/yt-dlp/yt-dlp.git
synced 2024-09-21 09:51:25 +02:00
Compare commits
No commits in common. "9e05210991eebfafcd86912574bbd7dacd94fe55" and "a4310ef2d7a247ccbcc9e638c2decf8c42cfb760" have entirely different histories.
9e05210991
...
a4310ef2d7
|
@ -1,8 +1,8 @@
|
||||||
mutagen
|
mutagen
|
||||||
pycryptodomex
|
pycryptodomex
|
||||||
|
websockets
|
||||||
brotli; platform_python_implementation=='CPython'
|
brotli; platform_python_implementation=='CPython'
|
||||||
brotlicffi; platform_python_implementation!='CPython'
|
brotlicffi; platform_python_implementation!='CPython'
|
||||||
certifi
|
certifi
|
||||||
requests>=2.31.0,<3
|
requests>=2.31.0,<3
|
||||||
urllib3>=1.26.17,<3
|
urllib3>=1.26.17,<3
|
||||||
websockets>=12.0
|
|
|
@ -6,3 +6,16 @@ import pytest
|
||||||
from yt_dlp.networking import RequestHandler
|
from yt_dlp.networking import RequestHandler
|
||||||
from yt_dlp.networking.common import _REQUEST_HANDLERS
|
from yt_dlp.networking.common import _REQUEST_HANDLERS
|
||||||
from yt_dlp.utils._utils import _YDLLogger as FakeLogger
|
from yt_dlp.utils._utils import _YDLLogger as FakeLogger
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def handler(request):
|
||||||
|
RH_KEY = request.param
|
||||||
|
if inspect.isclass(RH_KEY) and issubclass(RH_KEY, RequestHandler):
|
||||||
|
handler = RH_KEY
|
||||||
|
elif RH_KEY in _REQUEST_HANDLERS:
|
||||||
|
handler = _REQUEST_HANDLERS[RH_KEY]
|
||||||
|
else:
|
||||||
|
pytest.skip(f'{RH_KEY} request handler is not available')
|
||||||
|
|
||||||
|
return functools.partial(handler, logger=FakeLogger)
|
||||||
|
|
|
@ -55,8 +55,6 @@ from yt_dlp.networking.exceptions import (
|
||||||
from yt_dlp.utils._utils import _YDLLogger as FakeLogger
|
from yt_dlp.utils._utils import _YDLLogger as FakeLogger
|
||||||
from yt_dlp.utils.networking import HTTPHeaderDict
|
from yt_dlp.utils.networking import HTTPHeaderDict
|
||||||
|
|
||||||
from .conftest import validate_and_send
|
|
||||||
|
|
||||||
TEST_DIR = os.path.dirname(os.path.abspath(__file__))
|
TEST_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
|
||||||
|
|
||||||
|
@ -280,6 +278,11 @@ class HTTPTestRequestHandler(http.server.BaseHTTPRequestHandler):
|
||||||
self._headers_buffer.append(f'{keyword}: {value}\r\n'.encode())
|
self._headers_buffer.append(f'{keyword}: {value}\r\n'.encode())
|
||||||
|
|
||||||
|
|
||||||
|
def validate_and_send(rh, req):
|
||||||
|
rh.validate(req)
|
||||||
|
return rh.send(req)
|
||||||
|
|
||||||
|
|
||||||
class TestRequestHandlerBase:
|
class TestRequestHandlerBase:
|
||||||
@classmethod
|
@classmethod
|
||||||
def setup_class(cls):
|
def setup_class(cls):
|
||||||
|
@ -293,7 +296,7 @@ class TestRequestHandlerBase:
|
||||||
cls.http_server_thread.start()
|
cls.http_server_thread.start()
|
||||||
|
|
||||||
# HTTPS server
|
# HTTPS server
|
||||||
certfn = os.path.join(TEST_DIR, '../testcert.pem')
|
certfn = os.path.join(TEST_DIR, 'testcert.pem')
|
||||||
cls.https_httpd = http.server.ThreadingHTTPServer(
|
cls.https_httpd = http.server.ThreadingHTTPServer(
|
||||||
('127.0.0.1', 0), HTTPTestRequestHandler)
|
('127.0.0.1', 0), HTTPTestRequestHandler)
|
||||||
sslctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
|
sslctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
|
||||||
|
@ -372,7 +375,6 @@ class TestHTTPRequestHandler(TestRequestHandlerBase):
|
||||||
@pytest.mark.parametrize('handler', ['Urllib', 'Requests'], indirect=True)
|
@pytest.mark.parametrize('handler', ['Urllib', 'Requests'], indirect=True)
|
||||||
def test_raise_http_error(self, handler):
|
def test_raise_http_error(self, handler):
|
||||||
with handler() as rh:
|
with handler() as rh:
|
||||||
# TODO Return HTTP status code url
|
|
||||||
for bad_status in (400, 500, 599, 302):
|
for bad_status in (400, 500, 599, 302):
|
||||||
with pytest.raises(HTTPError):
|
with pytest.raises(HTTPError):
|
||||||
validate_and_send(rh, Request('http://127.0.0.1:%d/gen_%d' % (self.http_port, bad_status)))
|
validate_and_send(rh, Request('http://127.0.0.1:%d/gen_%d' % (self.http_port, bad_status)))
|
||||||
|
@ -694,8 +696,8 @@ class TestClientCertificate:
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def setup_class(cls):
|
def setup_class(cls):
|
||||||
certfn = os.path.join(TEST_DIR, '../testcert.pem')
|
certfn = os.path.join(TEST_DIR, 'testcert.pem')
|
||||||
cls.certdir = os.path.join(TEST_DIR, '../testdata', 'certificate')
|
cls.certdir = os.path.join(TEST_DIR, 'testdata', 'certificate')
|
||||||
cacertfn = os.path.join(cls.certdir, 'ca.crt')
|
cacertfn = os.path.join(cls.certdir, 'ca.crt')
|
||||||
cls.httpd = http.server.ThreadingHTTPServer(('127.0.0.1', 0), HTTPTestRequestHandler)
|
cls.httpd = http.server.ThreadingHTTPServer(('127.0.0.1', 0), HTTPTestRequestHandler)
|
||||||
sslctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
|
sslctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
|
||||||
|
@ -746,6 +748,82 @@ class TestClientCertificate:
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(not websockets, reason='websockets must be installed to test websocket request handlers')
|
||||||
|
class TestWebsockets:
|
||||||
|
@classmethod
|
||||||
|
def setup_class(cls):
|
||||||
|
import websockets.server
|
||||||
|
|
||||||
|
async def echo(websocket):
|
||||||
|
async for message in websocket:
|
||||||
|
if message == b'headers':
|
||||||
|
await websocket.send(json.dumps(dict(websocket.request_headers)))
|
||||||
|
elif message == 'source_address':
|
||||||
|
await websocket.send(websocket.remote_address[0])
|
||||||
|
else:
|
||||||
|
await websocket.send(message)
|
||||||
|
|
||||||
|
def run():
|
||||||
|
async def main():
|
||||||
|
async with websockets.server.serve(echo, "localhost", 8765):
|
||||||
|
await asyncio.Future()
|
||||||
|
asyncio.run(main())
|
||||||
|
|
||||||
|
cls.ws_server_thread = threading.Thread(target=run)
|
||||||
|
cls.ws_server_thread.daemon = True
|
||||||
|
cls.ws_server_thread.start()
|
||||||
|
time.sleep(1) # wait for server to start
|
||||||
|
|
||||||
|
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
|
||||||
|
def test_send_recv(self, handler):
|
||||||
|
with handler() as rh:
|
||||||
|
ws = validate_and_send(rh, Request('ws://127.0.0.1:8765'))
|
||||||
|
ws.send(b'foo')
|
||||||
|
assert ws.recv() == b'foo'
|
||||||
|
|
||||||
|
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
|
||||||
|
@pytest.mark.parametrize('params,extensions', [
|
||||||
|
({'timeout': 0.00001}, {}),
|
||||||
|
({}, {'timeout': 0.00001}),
|
||||||
|
])
|
||||||
|
def test_timeout(self, handler, params, extensions):
|
||||||
|
with handler(**params) as rh:
|
||||||
|
with pytest.raises(TransportError):
|
||||||
|
validate_and_send(rh, Request('ws://127.0.0.1:8765', extensions=extensions))
|
||||||
|
|
||||||
|
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
|
||||||
|
def test_cookies(self, handler):
|
||||||
|
cookiejar = YoutubeDLCookieJar()
|
||||||
|
cookiejar.set_cookie(http.cookiejar.Cookie(
|
||||||
|
version=0, name='test', value='ytdlp', port=None, port_specified=False,
|
||||||
|
domain='127.0.0.1', domain_specified=True, domain_initial_dot=False, path='/',
|
||||||
|
path_specified=True, secure=False, expires=None, discard=False, comment=None,
|
||||||
|
comment_url=None, rest={}))
|
||||||
|
|
||||||
|
with handler(cookiejar=cookiejar) as rh:
|
||||||
|
res = validate_and_send(rh, Request('ws://127.0.0.1:8765'))
|
||||||
|
res.send(b'headers')
|
||||||
|
assert json.loads(res.recv())['cookie'] == 'test=ytdlp'
|
||||||
|
|
||||||
|
with handler() as rh:
|
||||||
|
res = validate_and_send(rh, Request('ws://127.0.0.1:8765'))
|
||||||
|
res.send(b'headers')
|
||||||
|
assert 'cookie' not in json.loads(res.recv())
|
||||||
|
|
||||||
|
res = validate_and_send(rh, Request('ws://127.0.0.1:8765', extensions={'cookiejar': cookiejar}))
|
||||||
|
res.send(b'headers')
|
||||||
|
assert json.loads(res.recv())['cookie'] == 'test=ytdlp'
|
||||||
|
|
||||||
|
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
|
||||||
|
def test_source_address(self, handler):
|
||||||
|
source_address = f'127.0.0.{random.randint(5, 255)}'
|
||||||
|
with handler(source_address=source_address) as rh:
|
||||||
|
res = validate_and_send(
|
||||||
|
rh, Request(f'ws://127.0.0.1:8765/source_address'))
|
||||||
|
res.send('source_address')
|
||||||
|
assert source_address == res.recv()
|
||||||
|
|
||||||
|
|
||||||
class TestUrllibRequestHandler(TestRequestHandlerBase):
|
class TestUrllibRequestHandler(TestRequestHandlerBase):
|
||||||
@pytest.mark.parametrize('handler', ['Urllib'], indirect=True)
|
@pytest.mark.parametrize('handler', ['Urllib'], indirect=True)
|
||||||
def test_file_urls(self, handler):
|
def test_file_urls(self, handler):
|
|
@ -1,26 +0,0 @@
|
||||||
import functools
|
|
||||||
import inspect
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from yt_dlp.networking import RequestHandler
|
|
||||||
from yt_dlp.networking.common import _REQUEST_HANDLERS
|
|
||||||
from yt_dlp.utils._utils import _YDLLogger as FakeLogger
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def handler(request):
|
|
||||||
RH_KEY = request.param
|
|
||||||
if inspect.isclass(RH_KEY) and issubclass(RH_KEY, RequestHandler):
|
|
||||||
handler = RH_KEY
|
|
||||||
elif RH_KEY in _REQUEST_HANDLERS:
|
|
||||||
handler = _REQUEST_HANDLERS[RH_KEY]
|
|
||||||
else:
|
|
||||||
pytest.skip(f'{RH_KEY} request handler is not available')
|
|
||||||
|
|
||||||
return functools.partial(handler, logger=FakeLogger)
|
|
||||||
|
|
||||||
|
|
||||||
def validate_and_send(rh, req):
|
|
||||||
rh.validate(req)
|
|
||||||
return rh.send(req)
|
|
|
@ -1,258 +0,0 @@
|
||||||
#!/usr/bin/env python3
|
|
||||||
import enum
|
|
||||||
|
|
||||||
# Allow direct execution
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
from queue import Queue
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
from websockets.datastructures import Headers
|
|
||||||
|
|
||||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import functools
|
|
||||||
import gzip
|
|
||||||
import http.client
|
|
||||||
import http.cookiejar
|
|
||||||
import http.server
|
|
||||||
import io
|
|
||||||
import json
|
|
||||||
import pathlib
|
|
||||||
import random
|
|
||||||
import ssl
|
|
||||||
import tempfile
|
|
||||||
import threading
|
|
||||||
import time
|
|
||||||
import urllib.error
|
|
||||||
import urllib.request
|
|
||||||
import warnings
|
|
||||||
import zlib
|
|
||||||
from email.message import Message
|
|
||||||
from http.cookiejar import CookieJar
|
|
||||||
|
|
||||||
import websockets.sync
|
|
||||||
|
|
||||||
from test.helper import FakeYDL, http_server_port
|
|
||||||
from yt_dlp.cookies import YoutubeDLCookieJar
|
|
||||||
from yt_dlp.dependencies import brotli, requests, urllib3, websockets
|
|
||||||
from yt_dlp.networking import (
|
|
||||||
HEADRequest,
|
|
||||||
PUTRequest,
|
|
||||||
Request,
|
|
||||||
RequestDirector,
|
|
||||||
RequestHandler,
|
|
||||||
Response,
|
|
||||||
)
|
|
||||||
from yt_dlp.networking._urllib import UrllibRH
|
|
||||||
from yt_dlp.networking.exceptions import (
|
|
||||||
CertificateVerifyError,
|
|
||||||
HTTPError,
|
|
||||||
IncompleteRead,
|
|
||||||
NoSupportingHandlers,
|
|
||||||
ProxyError,
|
|
||||||
RequestError,
|
|
||||||
SSLError,
|
|
||||||
TransportError,
|
|
||||||
UnsupportedRequest,
|
|
||||||
)
|
|
||||||
from yt_dlp.utils._utils import _YDLLogger as FakeLogger
|
|
||||||
from yt_dlp.utils.networking import HTTPHeaderDict
|
|
||||||
|
|
||||||
from .conftest import validate_and_send
|
|
||||||
|
|
||||||
TEST_DIR = os.path.dirname(os.path.abspath(__file__))
|
|
||||||
|
|
||||||
|
|
||||||
def websocket_handler(websocket):
|
|
||||||
for message in websocket:
|
|
||||||
if message == 'headers':
|
|
||||||
return websocket.send(json.dumps(dict(websocket.request.headers)))
|
|
||||||
elif message == 'path':
|
|
||||||
return websocket.send(websocket.request.path)
|
|
||||||
elif message == 'source_address':
|
|
||||||
return websocket.send(websocket.remote_address[0])
|
|
||||||
else:
|
|
||||||
return websocket.send(message)
|
|
||||||
|
|
||||||
|
|
||||||
def process_request(self, request):
|
|
||||||
if request.path.startswith('/gen_'):
|
|
||||||
status = http.HTTPStatus(int(request.path[5:]))
|
|
||||||
if 300 <= status.value <= 300:
|
|
||||||
return websockets.http11.Response(
|
|
||||||
status.value, status.phrase, websockets.datastructures.Headers([('Location', '/')]), b'')
|
|
||||||
return self.protocol.reject(status.value, status.phrase)
|
|
||||||
return self.protocol.accept(request)
|
|
||||||
|
|
||||||
|
|
||||||
def create_websocket_server(**ws_kwargs):
|
|
||||||
import websockets.sync.server
|
|
||||||
wsd = websockets.sync.server.serve(websocket_handler, '127.0.0.1', 0, process_request=process_request, **ws_kwargs)
|
|
||||||
ws_port = wsd.socket.getsockname()[1]
|
|
||||||
ws_server_thread = threading.Thread(target=wsd.serve_forever)
|
|
||||||
ws_server_thread.daemon = True
|
|
||||||
ws_server_thread.start()
|
|
||||||
return ws_server_thread, ws_port
|
|
||||||
|
|
||||||
|
|
||||||
def create_ws_websocket_server():
|
|
||||||
return create_websocket_server()
|
|
||||||
|
|
||||||
|
|
||||||
def create_wss_websocket_server():
|
|
||||||
certfn = os.path.join(TEST_DIR, '../testcert.pem')
|
|
||||||
sslctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
|
|
||||||
sslctx.load_cert_chain(certfn, None)
|
|
||||||
return create_websocket_server(ssl_context=sslctx)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(not websockets, reason='websockets must be installed to test websocket request handlers')
|
|
||||||
class TestWebsSocketRequestHandlerConformance:
|
|
||||||
@classmethod
|
|
||||||
def setup_class(cls):
|
|
||||||
cls.ws_thread, cls.ws_port = create_ws_websocket_server()
|
|
||||||
cls.ws_base_url = f'ws://127.0.0.1:{cls.ws_port}'
|
|
||||||
|
|
||||||
cls.wss_thread, cls.wss_port = create_wss_websocket_server()
|
|
||||||
cls.wss_base_url = f'wss://127.0.0.1:{cls.wss_port}'
|
|
||||||
|
|
||||||
cls.bad_wss_thread, cls.bad_wss_port = create_websocket_server(ssl_context=ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER))
|
|
||||||
cls.bad_wss_host = f'wss://127.0.0.1:{cls.bad_wss_port}'
|
|
||||||
|
|
||||||
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
|
|
||||||
def test_basic_websockets(self, handler):
|
|
||||||
with handler() as rh:
|
|
||||||
ws = validate_and_send(rh, Request(self.ws_base_url))
|
|
||||||
assert 'upgrade' in ws.headers
|
|
||||||
assert ws.status == 101
|
|
||||||
ws.send('foo')
|
|
||||||
assert ws.recv() == 'foo'
|
|
||||||
ws.close()
|
|
||||||
|
|
||||||
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
|
|
||||||
def test_verify_cert(self, handler):
|
|
||||||
with handler() as rh:
|
|
||||||
with pytest.raises(CertificateVerifyError):
|
|
||||||
validate_and_send(rh, Request(self.wss_base_url))
|
|
||||||
|
|
||||||
with handler(verify=False) as rh:
|
|
||||||
ws = validate_and_send(rh, Request(self.wss_base_url))
|
|
||||||
assert ws.status == 101
|
|
||||||
ws.close()
|
|
||||||
|
|
||||||
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
|
|
||||||
def test_ssl_error(self, handler):
|
|
||||||
with handler(verify=False) as rh:
|
|
||||||
with pytest.raises(SSLError, match='sslv3 alert handshake failure') as exc_info:
|
|
||||||
validate_and_send(rh, Request(self.bad_wss_host))
|
|
||||||
assert not issubclass(exc_info.type, CertificateVerifyError)
|
|
||||||
|
|
||||||
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
|
|
||||||
@pytest.mark.parametrize('path,expected', [
|
|
||||||
# Unicode characters should be encoded with uppercase percent-encoding
|
|
||||||
('/中文', '/%E4%B8%AD%E6%96%87'),
|
|
||||||
# don't normalize existing percent encodings
|
|
||||||
('/%c7%9f', '/%c7%9f'),
|
|
||||||
])
|
|
||||||
def test_percent_encode(self, handler, path, expected):
|
|
||||||
with handler() as rh:
|
|
||||||
ws = validate_and_send(rh, Request(f'{self.ws_base_url}{path}'))
|
|
||||||
ws.send('path')
|
|
||||||
assert ws.recv() == expected
|
|
||||||
assert ws.status == 101
|
|
||||||
ws.close()
|
|
||||||
|
|
||||||
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
|
|
||||||
def test_remove_dot_segments(self, handler):
|
|
||||||
with handler() as rh:
|
|
||||||
# This isn't a comprehensive test,
|
|
||||||
# but it should be enough to check whether the handler is removing dot segments
|
|
||||||
ws = validate_and_send(rh, Request(f'{self.ws_base_url}/a/b/./../../test'))
|
|
||||||
assert ws.status == 101
|
|
||||||
ws.send('path')
|
|
||||||
assert ws.recv() == '/test'
|
|
||||||
ws.close()
|
|
||||||
|
|
||||||
# We are restricted to known HTTP status codes in http.HTTPStatus
|
|
||||||
# Redirects are not supported for websockets
|
|
||||||
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
|
|
||||||
@pytest.mark.parametrize('status', (200, 204, 301, 302, 303, 400, 500, 511))
|
|
||||||
def test_raise_http_error(self, handler, status):
|
|
||||||
with handler() as rh:
|
|
||||||
with pytest.raises(HTTPError) as exc_info:
|
|
||||||
validate_and_send(rh, Request(f'{self.ws_base_url}/gen_{status}'))
|
|
||||||
assert exc_info.value.status == status
|
|
||||||
|
|
||||||
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
|
|
||||||
@pytest.mark.parametrize('params,extensions', [
|
|
||||||
({'timeout': 0.00001}, {}),
|
|
||||||
({}, {'timeout': 0.00001}),
|
|
||||||
])
|
|
||||||
def test_timeout(self, handler, params, extensions):
|
|
||||||
with handler(**params) as rh:
|
|
||||||
with pytest.raises(TransportError):
|
|
||||||
validate_and_send(rh, Request(self.ws_base_url, extensions=extensions))
|
|
||||||
|
|
||||||
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
|
|
||||||
def test_cookies(self, handler):
|
|
||||||
cookiejar = YoutubeDLCookieJar()
|
|
||||||
cookiejar.set_cookie(http.cookiejar.Cookie(
|
|
||||||
version=0, name='test', value='ytdlp', port=None, port_specified=False,
|
|
||||||
domain='127.0.0.1', domain_specified=True, domain_initial_dot=False, path='/',
|
|
||||||
path_specified=True, secure=False, expires=None, discard=False, comment=None,
|
|
||||||
comment_url=None, rest={}))
|
|
||||||
|
|
||||||
with handler(cookiejar=cookiejar) as rh:
|
|
||||||
ws = validate_and_send(rh, Request(self.ws_base_url))
|
|
||||||
ws.send('headers')
|
|
||||||
assert json.loads(ws.recv())['cookie'] == 'test=ytdlp'
|
|
||||||
ws.close()
|
|
||||||
|
|
||||||
with handler() as rh:
|
|
||||||
ws = validate_and_send(rh, Request(self.ws_base_url))
|
|
||||||
ws.send('headers')
|
|
||||||
assert 'cookie' not in json.loads(ws.recv())
|
|
||||||
|
|
||||||
ws = validate_and_send(rh, Request(self.ws_base_url, extensions={'cookiejar': cookiejar}))
|
|
||||||
ws.send('headers')
|
|
||||||
assert json.loads(ws.recv())['cookie'] == 'test=ytdlp'
|
|
||||||
ws.close()
|
|
||||||
|
|
||||||
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
|
|
||||||
def test_source_address(self, handler):
|
|
||||||
source_address = f'127.0.0.{random.randint(5, 255)}'
|
|
||||||
with handler(source_address=source_address) as rh:
|
|
||||||
ws = validate_and_send(rh, Request(self.ws_base_url))
|
|
||||||
ws.send('source_address')
|
|
||||||
assert source_address == ws.recv()
|
|
||||||
ws.close()
|
|
||||||
|
|
||||||
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
|
|
||||||
def test_response_url(self, handler):
|
|
||||||
with handler() as rh:
|
|
||||||
url = f'{self.ws_base_url}/something'
|
|
||||||
ws = validate_and_send(rh, Request(url))
|
|
||||||
assert ws.url == url
|
|
||||||
ws.close()
|
|
||||||
|
|
||||||
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
|
|
||||||
def test_request_headers(self, handler):
|
|
||||||
with handler(headers=HTTPHeaderDict({'test1': 'test', 'test2': 'test2'})) as rh:
|
|
||||||
# Global Headers
|
|
||||||
ws = validate_and_send(rh, Request(self.ws_base_url))
|
|
||||||
ws.send('headers')
|
|
||||||
headers = HTTPHeaderDict(json.loads(ws.recv()))
|
|
||||||
assert headers['test1'] == 'test'
|
|
||||||
ws.close()
|
|
||||||
|
|
||||||
# Per request headers, merged with global
|
|
||||||
ws = validate_and_send(rh, Request(
|
|
||||||
self.ws_base_url, headers={'test2': 'changed', 'test3': 'test3'}))
|
|
||||||
ws.send('headers')
|
|
||||||
headers = HTTPHeaderDict(json.loads(ws.recv()))
|
|
||||||
assert headers['test1'] == 'test'
|
|
||||||
assert headers['test2'] == 'changed'
|
|
||||||
assert headers['test3'] == 'test3'
|
|
||||||
ws.close()
|
|
|
@ -210,16 +210,6 @@ class SocksHTTPTestRequestHandler(http.server.BaseHTTPRequestHandler, SocksTestR
|
||||||
self.wfile.write(payload.encode())
|
self.wfile.write(payload.encode())
|
||||||
|
|
||||||
|
|
||||||
class SocksWebSocketTestRequestHandler(SocksTestRequestHandler):
|
|
||||||
def handle(self):
|
|
||||||
import websockets.sync.server
|
|
||||||
protocol = websockets.ServerProtocol()
|
|
||||||
connection = websockets.sync.server.ServerConnection(socket=self.request, protocol=protocol, close_timeout=0)
|
|
||||||
connection.handshake()
|
|
||||||
connection.send(json.dumps(self.socks_info))
|
|
||||||
connection.close()
|
|
||||||
|
|
||||||
|
|
||||||
@contextlib.contextmanager
|
@contextlib.contextmanager
|
||||||
def socks_server(socks_server_class, request_handler, bind_ip=None, **socks_server_kwargs):
|
def socks_server(socks_server_class, request_handler, bind_ip=None, **socks_server_kwargs):
|
||||||
server = server_thread = None
|
server = server_thread = None
|
||||||
|
@ -262,21 +252,8 @@ class HTTPSocksTestProxyContext(SocksProxyTestContext):
|
||||||
return json.loads(handler.send(request).read().decode())
|
return json.loads(handler.send(request).read().decode())
|
||||||
|
|
||||||
|
|
||||||
class WebSocketSocksTestProxyContext(SocksProxyTestContext):
|
|
||||||
REQUEST_HANDLER_CLASS = SocksWebSocketTestRequestHandler
|
|
||||||
|
|
||||||
def socks_info_request(self, handler, target_domain=None, target_port=None, **req_kwargs):
|
|
||||||
request = Request(f'ws://{target_domain or "127.0.0.1"}:{target_port or "40000"}', **req_kwargs)
|
|
||||||
handler.validate(request)
|
|
||||||
ws = handler.send(request)
|
|
||||||
ws.send('socks_info')
|
|
||||||
socks_info = ws.recv()
|
|
||||||
ws.close()
|
|
||||||
return json.loads(socks_info)
|
|
||||||
|
|
||||||
CTX_MAP = {
|
CTX_MAP = {
|
||||||
'http': HTTPSocksTestProxyContext,
|
'http': HTTPSocksTestProxyContext,
|
||||||
'ws': WebSocketSocksTestProxyContext,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -286,7 +263,7 @@ def ctx(request):
|
||||||
|
|
||||||
|
|
||||||
class TestSocks4Proxy:
|
class TestSocks4Proxy:
|
||||||
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
|
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True)
|
||||||
def test_socks4_no_auth(self, handler, ctx):
|
def test_socks4_no_auth(self, handler, ctx):
|
||||||
with handler() as rh:
|
with handler() as rh:
|
||||||
with ctx.socks_server(Socks4ProxyHandler) as server_address:
|
with ctx.socks_server(Socks4ProxyHandler) as server_address:
|
||||||
|
@ -294,7 +271,7 @@ class TestSocks4Proxy:
|
||||||
rh, proxies={'all': f'socks4://{server_address}'})
|
rh, proxies={'all': f'socks4://{server_address}'})
|
||||||
assert response['version'] == 4
|
assert response['version'] == 4
|
||||||
|
|
||||||
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
|
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True)
|
||||||
def test_socks4_auth(self, handler, ctx):
|
def test_socks4_auth(self, handler, ctx):
|
||||||
with handler() as rh:
|
with handler() as rh:
|
||||||
with ctx.socks_server(Socks4ProxyHandler, user_id='user') as server_address:
|
with ctx.socks_server(Socks4ProxyHandler, user_id='user') as server_address:
|
||||||
|
@ -304,7 +281,7 @@ class TestSocks4Proxy:
|
||||||
rh, proxies={'all': f'socks4://user:@{server_address}'})
|
rh, proxies={'all': f'socks4://user:@{server_address}'})
|
||||||
assert response['version'] == 4
|
assert response['version'] == 4
|
||||||
|
|
||||||
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
|
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True)
|
||||||
def test_socks4a_ipv4_target(self, handler, ctx):
|
def test_socks4a_ipv4_target(self, handler, ctx):
|
||||||
with ctx.socks_server(Socks4ProxyHandler) as server_address:
|
with ctx.socks_server(Socks4ProxyHandler) as server_address:
|
||||||
with handler(proxies={'all': f'socks4a://{server_address}'}) as rh:
|
with handler(proxies={'all': f'socks4a://{server_address}'}) as rh:
|
||||||
|
@ -312,7 +289,7 @@ class TestSocks4Proxy:
|
||||||
assert response['version'] == 4
|
assert response['version'] == 4
|
||||||
assert (response['ipv4_address'] == '127.0.0.1') != (response['domain_address'] == '127.0.0.1')
|
assert (response['ipv4_address'] == '127.0.0.1') != (response['domain_address'] == '127.0.0.1')
|
||||||
|
|
||||||
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
|
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True)
|
||||||
def test_socks4a_domain_target(self, handler, ctx):
|
def test_socks4a_domain_target(self, handler, ctx):
|
||||||
with ctx.socks_server(Socks4ProxyHandler) as server_address:
|
with ctx.socks_server(Socks4ProxyHandler) as server_address:
|
||||||
with handler(proxies={'all': f'socks4a://{server_address}'}) as rh:
|
with handler(proxies={'all': f'socks4a://{server_address}'}) as rh:
|
||||||
|
@ -321,7 +298,7 @@ class TestSocks4Proxy:
|
||||||
assert response['ipv4_address'] is None
|
assert response['ipv4_address'] is None
|
||||||
assert response['domain_address'] == 'localhost'
|
assert response['domain_address'] == 'localhost'
|
||||||
|
|
||||||
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
|
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True)
|
||||||
def test_ipv4_client_source_address(self, handler, ctx):
|
def test_ipv4_client_source_address(self, handler, ctx):
|
||||||
with ctx.socks_server(Socks4ProxyHandler) as server_address:
|
with ctx.socks_server(Socks4ProxyHandler) as server_address:
|
||||||
source_address = f'127.0.0.{random.randint(5, 255)}'
|
source_address = f'127.0.0.{random.randint(5, 255)}'
|
||||||
|
@ -331,7 +308,7 @@ class TestSocks4Proxy:
|
||||||
assert response['client_address'][0] == source_address
|
assert response['client_address'][0] == source_address
|
||||||
assert response['version'] == 4
|
assert response['version'] == 4
|
||||||
|
|
||||||
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
|
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True)
|
||||||
@pytest.mark.parametrize('reply_code', [
|
@pytest.mark.parametrize('reply_code', [
|
||||||
Socks4CD.REQUEST_REJECTED_OR_FAILED,
|
Socks4CD.REQUEST_REJECTED_OR_FAILED,
|
||||||
Socks4CD.REQUEST_REJECTED_CANNOT_CONNECT_TO_IDENTD,
|
Socks4CD.REQUEST_REJECTED_CANNOT_CONNECT_TO_IDENTD,
|
||||||
|
@ -343,7 +320,7 @@ class TestSocks4Proxy:
|
||||||
with pytest.raises(ProxyError):
|
with pytest.raises(ProxyError):
|
||||||
ctx.socks_info_request(rh)
|
ctx.socks_info_request(rh)
|
||||||
|
|
||||||
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
|
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True)
|
||||||
def test_ipv6_socks4_proxy(self, handler, ctx):
|
def test_ipv6_socks4_proxy(self, handler, ctx):
|
||||||
with ctx.socks_server(Socks4ProxyHandler, bind_ip='::1') as server_address:
|
with ctx.socks_server(Socks4ProxyHandler, bind_ip='::1') as server_address:
|
||||||
with handler(proxies={'all': f'socks4://{server_address}'}) as rh:
|
with handler(proxies={'all': f'socks4://{server_address}'}) as rh:
|
||||||
|
@ -352,7 +329,7 @@ class TestSocks4Proxy:
|
||||||
assert response['ipv4_address'] == '127.0.0.1'
|
assert response['ipv4_address'] == '127.0.0.1'
|
||||||
assert response['version'] == 4
|
assert response['version'] == 4
|
||||||
|
|
||||||
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
|
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True)
|
||||||
def test_timeout(self, handler, ctx):
|
def test_timeout(self, handler, ctx):
|
||||||
with ctx.socks_server(Socks4ProxyHandler, sleep=2) as server_address:
|
with ctx.socks_server(Socks4ProxyHandler, sleep=2) as server_address:
|
||||||
with handler(proxies={'all': f'socks4://{server_address}'}, timeout=0.5) as rh:
|
with handler(proxies={'all': f'socks4://{server_address}'}, timeout=0.5) as rh:
|
||||||
|
@ -362,7 +339,7 @@ class TestSocks4Proxy:
|
||||||
|
|
||||||
class TestSocks5Proxy:
|
class TestSocks5Proxy:
|
||||||
|
|
||||||
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
|
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True)
|
||||||
def test_socks5_no_auth(self, handler, ctx):
|
def test_socks5_no_auth(self, handler, ctx):
|
||||||
with ctx.socks_server(Socks5ProxyHandler) as server_address:
|
with ctx.socks_server(Socks5ProxyHandler) as server_address:
|
||||||
with handler(proxies={'all': f'socks5://{server_address}'}) as rh:
|
with handler(proxies={'all': f'socks5://{server_address}'}) as rh:
|
||||||
|
@ -370,7 +347,7 @@ class TestSocks5Proxy:
|
||||||
assert response['auth_methods'] == [0x0]
|
assert response['auth_methods'] == [0x0]
|
||||||
assert response['version'] == 5
|
assert response['version'] == 5
|
||||||
|
|
||||||
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
|
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True)
|
||||||
def test_socks5_user_pass(self, handler, ctx):
|
def test_socks5_user_pass(self, handler, ctx):
|
||||||
with ctx.socks_server(Socks5ProxyHandler, auth=('test', 'testpass')) as server_address:
|
with ctx.socks_server(Socks5ProxyHandler, auth=('test', 'testpass')) as server_address:
|
||||||
with handler() as rh:
|
with handler() as rh:
|
||||||
|
@ -383,7 +360,7 @@ class TestSocks5Proxy:
|
||||||
assert response['auth_methods'] == [Socks5Auth.AUTH_NONE, Socks5Auth.AUTH_USER_PASS]
|
assert response['auth_methods'] == [Socks5Auth.AUTH_NONE, Socks5Auth.AUTH_USER_PASS]
|
||||||
assert response['version'] == 5
|
assert response['version'] == 5
|
||||||
|
|
||||||
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
|
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True)
|
||||||
def test_socks5_ipv4_target(self, handler, ctx):
|
def test_socks5_ipv4_target(self, handler, ctx):
|
||||||
with ctx.socks_server(Socks5ProxyHandler) as server_address:
|
with ctx.socks_server(Socks5ProxyHandler) as server_address:
|
||||||
with handler(proxies={'all': f'socks5://{server_address}'}) as rh:
|
with handler(proxies={'all': f'socks5://{server_address}'}) as rh:
|
||||||
|
@ -391,7 +368,7 @@ class TestSocks5Proxy:
|
||||||
assert response['ipv4_address'] == '127.0.0.1'
|
assert response['ipv4_address'] == '127.0.0.1'
|
||||||
assert response['version'] == 5
|
assert response['version'] == 5
|
||||||
|
|
||||||
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
|
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True)
|
||||||
def test_socks5_domain_target(self, handler, ctx):
|
def test_socks5_domain_target(self, handler, ctx):
|
||||||
with ctx.socks_server(Socks5ProxyHandler) as server_address:
|
with ctx.socks_server(Socks5ProxyHandler) as server_address:
|
||||||
with handler(proxies={'all': f'socks5://{server_address}'}) as rh:
|
with handler(proxies={'all': f'socks5://{server_address}'}) as rh:
|
||||||
|
@ -399,7 +376,7 @@ class TestSocks5Proxy:
|
||||||
assert (response['ipv4_address'] == '127.0.0.1') != (response['ipv6_address'] == '::1')
|
assert (response['ipv4_address'] == '127.0.0.1') != (response['ipv6_address'] == '::1')
|
||||||
assert response['version'] == 5
|
assert response['version'] == 5
|
||||||
|
|
||||||
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
|
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True)
|
||||||
def test_socks5h_domain_target(self, handler, ctx):
|
def test_socks5h_domain_target(self, handler, ctx):
|
||||||
with ctx.socks_server(Socks5ProxyHandler) as server_address:
|
with ctx.socks_server(Socks5ProxyHandler) as server_address:
|
||||||
with handler(proxies={'all': f'socks5h://{server_address}'}) as rh:
|
with handler(proxies={'all': f'socks5h://{server_address}'}) as rh:
|
||||||
|
@ -408,7 +385,7 @@ class TestSocks5Proxy:
|
||||||
assert response['domain_address'] == 'localhost'
|
assert response['domain_address'] == 'localhost'
|
||||||
assert response['version'] == 5
|
assert response['version'] == 5
|
||||||
|
|
||||||
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
|
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True)
|
||||||
def test_socks5h_ip_target(self, handler, ctx):
|
def test_socks5h_ip_target(self, handler, ctx):
|
||||||
with ctx.socks_server(Socks5ProxyHandler) as server_address:
|
with ctx.socks_server(Socks5ProxyHandler) as server_address:
|
||||||
with handler(proxies={'all': f'socks5h://{server_address}'}) as rh:
|
with handler(proxies={'all': f'socks5h://{server_address}'}) as rh:
|
||||||
|
@ -417,7 +394,7 @@ class TestSocks5Proxy:
|
||||||
assert response['domain_address'] is None
|
assert response['domain_address'] is None
|
||||||
assert response['version'] == 5
|
assert response['version'] == 5
|
||||||
|
|
||||||
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
|
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True)
|
||||||
def test_socks5_ipv6_destination(self, handler, ctx):
|
def test_socks5_ipv6_destination(self, handler, ctx):
|
||||||
with ctx.socks_server(Socks5ProxyHandler) as server_address:
|
with ctx.socks_server(Socks5ProxyHandler) as server_address:
|
||||||
with handler(proxies={'all': f'socks5://{server_address}'}) as rh:
|
with handler(proxies={'all': f'socks5://{server_address}'}) as rh:
|
||||||
|
@ -425,7 +402,7 @@ class TestSocks5Proxy:
|
||||||
assert response['ipv6_address'] == '::1'
|
assert response['ipv6_address'] == '::1'
|
||||||
assert response['version'] == 5
|
assert response['version'] == 5
|
||||||
|
|
||||||
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
|
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True)
|
||||||
def test_ipv6_socks5_proxy(self, handler, ctx):
|
def test_ipv6_socks5_proxy(self, handler, ctx):
|
||||||
with ctx.socks_server(Socks5ProxyHandler, bind_ip='::1') as server_address:
|
with ctx.socks_server(Socks5ProxyHandler, bind_ip='::1') as server_address:
|
||||||
with handler(proxies={'all': f'socks5://{server_address}'}) as rh:
|
with handler(proxies={'all': f'socks5://{server_address}'}) as rh:
|
||||||
|
@ -436,7 +413,7 @@ class TestSocks5Proxy:
|
||||||
|
|
||||||
# XXX: is there any feasible way of testing IPv6 source addresses?
|
# XXX: is there any feasible way of testing IPv6 source addresses?
|
||||||
# Same would go for non-proxy source_address test...
|
# Same would go for non-proxy source_address test...
|
||||||
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
|
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True)
|
||||||
def test_ipv4_client_source_address(self, handler, ctx):
|
def test_ipv4_client_source_address(self, handler, ctx):
|
||||||
with ctx.socks_server(Socks5ProxyHandler) as server_address:
|
with ctx.socks_server(Socks5ProxyHandler) as server_address:
|
||||||
source_address = f'127.0.0.{random.randint(5, 255)}'
|
source_address = f'127.0.0.{random.randint(5, 255)}'
|
||||||
|
@ -445,7 +422,7 @@ class TestSocks5Proxy:
|
||||||
assert response['client_address'][0] == source_address
|
assert response['client_address'][0] == source_address
|
||||||
assert response['version'] == 5
|
assert response['version'] == 5
|
||||||
|
|
||||||
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
|
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True)
|
||||||
@pytest.mark.parametrize('reply_code', [
|
@pytest.mark.parametrize('reply_code', [
|
||||||
Socks5Reply.GENERAL_FAILURE,
|
Socks5Reply.GENERAL_FAILURE,
|
||||||
Socks5Reply.CONNECTION_NOT_ALLOWED,
|
Socks5Reply.CONNECTION_NOT_ALLOWED,
|
||||||
|
@ -462,7 +439,7 @@ class TestSocks5Proxy:
|
||||||
with pytest.raises(ProxyError):
|
with pytest.raises(ProxyError):
|
||||||
ctx.socks_info_request(rh)
|
ctx.socks_info_request(rh)
|
||||||
|
|
||||||
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Websockets', 'ws')], indirect=True)
|
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http')], indirect=True)
|
||||||
def test_timeout(self, handler, ctx):
|
def test_timeout(self, handler, ctx):
|
||||||
with ctx.socks_server(Socks5ProxyHandler, sleep=2) as server_address:
|
with ctx.socks_server(Socks5ProxyHandler, sleep=2) as server_address:
|
||||||
with handler(proxies={'all': f'socks5://{server_address}'}, timeout=1) as rh:
|
with handler(proxies={'all': f'socks5://{server_address}'}, timeout=1) as rh:
|
|
@ -1,36 +1,27 @@
|
||||||
|
# Request handler for https://github.com/python-websockets/websockets
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import atexit
|
||||||
import io
|
import io
|
||||||
import logging
|
import logging
|
||||||
import ssl
|
import urllib.parse
|
||||||
import sys
|
import sys
|
||||||
|
from ._helper import create_connection
|
||||||
|
|
||||||
from ._helper import create_connection, select_proxy, make_socks_proxy_opts, create_socks_proxy_socket
|
from websockets.uri import parse_uri
|
||||||
from .common import Response, register_rh, Features
|
|
||||||
from .exceptions import (
|
from .common import register_rh
|
||||||
CertificateVerifyError,
|
from .exceptions import TransportError, RequestError
|
||||||
HTTPError,
|
from .websocket import WebSocketResponse, WebSocketRequestHandler
|
||||||
RequestError,
|
|
||||||
SSLError,
|
|
||||||
TransportError, ProxyError,
|
|
||||||
)
|
|
||||||
from .websocket import WebSocketRequestHandler, WebSocketResponse
|
|
||||||
from ..compat import functools
|
|
||||||
from ..dependencies import websockets
|
from ..dependencies import websockets
|
||||||
from ..utils import int_or_none
|
|
||||||
from ..socks import ProxyError as SocksProxyError
|
|
||||||
|
|
||||||
if not websockets:
|
if not websockets:
|
||||||
raise ImportError('websockets is not installed')
|
raise ImportError('websockets is not installed')
|
||||||
|
|
||||||
import websockets.version
|
|
||||||
|
|
||||||
websockets_version = tuple(map(int_or_none, websockets.version.version.split('.')))
|
|
||||||
if websockets_version < (12, 0):
|
|
||||||
raise ImportError('Only websockets>=12.0 is supported')
|
|
||||||
|
|
||||||
import websockets.sync.client
|
import websockets.sync.client
|
||||||
from websockets.uri import parse_uri
|
from websockets.exceptions import InvalidHandshake, InvalidURI, ConnectionClosed
|
||||||
|
|
||||||
|
|
||||||
class WebsocketsResponseAdapter(WebSocketResponse):
|
class WebsocketsResponseAdapter(WebSocketResponse):
|
||||||
|
@ -53,10 +44,8 @@ class WebsocketsResponseAdapter(WebSocketResponse):
|
||||||
# https://websockets.readthedocs.io/en/stable/reference/sync/client.html#websockets.sync.client.ClientConnection.send
|
# https://websockets.readthedocs.io/en/stable/reference/sync/client.html#websockets.sync.client.ClientConnection.send
|
||||||
try:
|
try:
|
||||||
return self.wsw.send(*args)
|
return self.wsw.send(*args)
|
||||||
except (websockets.exceptions.ConnectionClosed, RuntimeError, TimeoutError) as e:
|
except (ConnectionClosed, RuntimeError) as e:
|
||||||
raise TransportError(cause=e) from e
|
raise TransportError(cause=e) from e
|
||||||
except SocksProxyError as e:
|
|
||||||
raise ProxyError(cause=e) from e
|
|
||||||
except TypeError as e:
|
except TypeError as e:
|
||||||
raise RequestError(cause=e) from e
|
raise RequestError(cause=e) from e
|
||||||
|
|
||||||
|
@ -64,22 +53,13 @@ class WebsocketsResponseAdapter(WebSocketResponse):
|
||||||
# https://websockets.readthedocs.io/en/stable/reference/sync/client.html#websockets.sync.client.ClientConnection.recv
|
# https://websockets.readthedocs.io/en/stable/reference/sync/client.html#websockets.sync.client.ClientConnection.recv
|
||||||
try:
|
try:
|
||||||
return self.wsw.recv(*args)
|
return self.wsw.recv(*args)
|
||||||
except SocksProxyError as e:
|
except (ConnectionClosed, RuntimeError) as e:
|
||||||
raise ProxyError(cause=e) from e
|
|
||||||
except (websockets.exceptions.ConnectionClosed, RuntimeError, TimeoutError) as e:
|
|
||||||
raise TransportError(cause=e) from e
|
raise TransportError(cause=e) from e
|
||||||
|
|
||||||
|
|
||||||
@register_rh
|
@register_rh
|
||||||
class WebsocketsRH(WebSocketRequestHandler):
|
class WebsocketsRH(WebSocketRequestHandler):
|
||||||
"""
|
|
||||||
Websockets request handler
|
|
||||||
https://websockets.readthedocs.io
|
|
||||||
https://github.com/python-websockets/websockets
|
|
||||||
"""
|
|
||||||
_SUPPORTED_URL_SCHEMES = ('wss', 'ws')
|
_SUPPORTED_URL_SCHEMES = ('wss', 'ws')
|
||||||
_SUPPORTED_PROXY_SCHEMES = ('socks4', 'socks4a', 'socks5', 'socks5h')
|
|
||||||
_SUPPORTED_FEATURES = (Features.ALL_PROXY, Features.NO_PROXY)
|
|
||||||
RH_NAME = 'websockets'
|
RH_NAME = 'websockets'
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
|
@ -98,6 +78,22 @@ class WebsocketsRH(WebSocketRequestHandler):
|
||||||
extensions.pop('cookiejar', None)
|
extensions.pop('cookiejar', None)
|
||||||
|
|
||||||
def _send(self, request):
|
def _send(self, request):
|
||||||
|
"""
|
||||||
|
https://websockets.readthedocs.io/en/stable/reference/sync/client.html
|
||||||
|
TODO:
|
||||||
|
- Cookie Support
|
||||||
|
- Test Exception Mapping
|
||||||
|
- Timeout handling for closing?
|
||||||
|
- WS Pinging
|
||||||
|
- KeyboardInterrupt doesn't seem to kill websockets
|
||||||
|
"""
|
||||||
|
ws_kwargs = {}
|
||||||
|
if urllib.parse.urlparse(request.url).scheme == 'wss':
|
||||||
|
ws_kwargs['ssl_context'] = self._make_sslcontext()
|
||||||
|
|
||||||
|
source_address = self.source_address
|
||||||
|
if source_address is not None:
|
||||||
|
ws_kwargs['source_address'] = source_address
|
||||||
timeout = float(request.extensions.get('timeout') or self.timeout)
|
timeout = float(request.extensions.get('timeout') or self.timeout)
|
||||||
headers = self._merge_headers(request.headers)
|
headers = self._merge_headers(request.headers)
|
||||||
if 'cookie' not in headers:
|
if 'cookie' not in headers:
|
||||||
|
@ -107,53 +103,24 @@ class WebsocketsRH(WebSocketRequestHandler):
|
||||||
headers['cookie'] = cookie_header
|
headers['cookie'] = cookie_header
|
||||||
|
|
||||||
wsuri = parse_uri(request.url)
|
wsuri = parse_uri(request.url)
|
||||||
create_conn_kwargs = {
|
sock = create_connection(
|
||||||
'source_address': (self.source_address, 0) if self.source_address else None,
|
(wsuri.host, wsuri.port),
|
||||||
'timeout': timeout
|
source_address=(self.source_address, 0) if self.source_address else None,
|
||||||
}
|
timeout=timeout
|
||||||
proxy = select_proxy(request.url, request.proxies or self.proxies or {})
|
)
|
||||||
try:
|
try:
|
||||||
if proxy:
|
|
||||||
socks_proxy_options = make_socks_proxy_opts(proxy)
|
|
||||||
sock = create_connection(
|
|
||||||
address=(socks_proxy_options['addr'], socks_proxy_options['port']),
|
|
||||||
_create_socket_func=functools.partial(
|
|
||||||
create_socks_proxy_socket, (wsuri.host, wsuri.port), socks_proxy_options),
|
|
||||||
**create_conn_kwargs
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
sock = create_connection(
|
|
||||||
address=(wsuri.host, wsuri.port),
|
|
||||||
**create_conn_kwargs
|
|
||||||
)
|
|
||||||
conn = websockets.sync.client.connect(
|
conn = websockets.sync.client.connect(
|
||||||
sock=sock,
|
sock=sock,
|
||||||
uri=request.url,
|
uri=request.url,
|
||||||
additional_headers=headers,
|
additional_headers=headers,
|
||||||
open_timeout=timeout,
|
open_timeout=timeout,
|
||||||
user_agent_header=None,
|
user_agent_header=None,
|
||||||
ssl_context=self._make_sslcontext() if wsuri.secure else None,
|
|
||||||
)
|
)
|
||||||
return WebsocketsResponseAdapter(conn, url=request.url)
|
return WebsocketsResponseAdapter(conn, url=request.url)
|
||||||
|
|
||||||
# Exceptions as per https://websockets.readthedocs.io/en/stable/reference/sync/client.html
|
# Exceptions as per https://websockets.readthedocs.io/en/stable/reference/sync/client.html
|
||||||
except SocksProxyError as e:
|
except InvalidURI as e:
|
||||||
raise ProxyError(cause=e) from e
|
|
||||||
except websockets.exceptions.InvalidURI as e:
|
|
||||||
raise RequestError(cause=e) from e
|
raise RequestError(cause=e) from e
|
||||||
except ssl.SSLCertVerificationError as e:
|
except (OSError, TimeoutError, InvalidHandshake) as e:
|
||||||
raise CertificateVerifyError(cause=e) from e
|
|
||||||
except ssl.SSLError as e:
|
|
||||||
raise SSLError(cause=e) from e
|
|
||||||
except websockets.exceptions.InvalidStatus as e:
|
|
||||||
raise HTTPError(
|
|
||||||
Response(
|
|
||||||
fp=io.BytesIO(e.response.body),
|
|
||||||
url=request.url,
|
|
||||||
headers=e.response.headers,
|
|
||||||
status=e.response.status_code,
|
|
||||||
reason=e.response.reason_phrase),
|
|
||||||
) from e
|
|
||||||
except (OSError, TimeoutError, websockets.exceptions.InvalidHandshake) as e:
|
|
||||||
raise TransportError(cause=e) from e
|
raise TransportError(cause=e) from e
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user