Compare commits

...

3 Commits

Author SHA1 Message Date
coletdjnz
75d70eb640
linter 2023-11-04 14:38:11 +13:00
coletdjnz
d120356dff
Add websocket mTLS tests 2023-11-04 13:44:46 +13:00
coletdjnz
d17a82bfed
don't refactor yet 2023-11-04 13:29:53 +13:00
10 changed files with 71 additions and 70 deletions

View File

@ -6,3 +6,21 @@ import pytest
from yt_dlp.networking import RequestHandler from yt_dlp.networking import RequestHandler
from yt_dlp.networking.common import _REQUEST_HANDLERS from yt_dlp.networking.common import _REQUEST_HANDLERS
from yt_dlp.utils._utils import _YDLLogger as FakeLogger from yt_dlp.utils._utils import _YDLLogger as FakeLogger
@pytest.fixture
def handler(request):
RH_KEY = request.param
if inspect.isclass(RH_KEY) and issubclass(RH_KEY, RequestHandler):
handler = RH_KEY
elif RH_KEY in _REQUEST_HANDLERS:
handler = _REQUEST_HANDLERS[RH_KEY]
else:
pytest.skip(f'{RH_KEY} request handler is not available')
return functools.partial(handler, logger=FakeLogger)
def validate_and_send(rh, req):
rh.validate(req)
return rh.send(req)

View File

@ -8,9 +8,6 @@ import pytest
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import asyncio
import functools
import json
import gzip import gzip
import http.client import http.client
import http.cookiejar import http.cookiejar
@ -31,7 +28,7 @@ from http.cookiejar import CookieJar
from test.helper import FakeYDL, http_server_port from test.helper import FakeYDL, http_server_port
from yt_dlp.cookies import YoutubeDLCookieJar from yt_dlp.cookies import YoutubeDLCookieJar
from yt_dlp.dependencies import brotli, requests, urllib3, websockets from yt_dlp.dependencies import brotli, requests, urllib3
from yt_dlp.networking import ( from yt_dlp.networking import (
HEADRequest, HEADRequest,
PUTRequest, PUTRequest,
@ -293,7 +290,7 @@ class TestRequestHandlerBase:
cls.http_server_thread.start() cls.http_server_thread.start()
# HTTPS server # HTTPS server
certfn = os.path.join(TEST_DIR, '../testcert.pem') certfn = os.path.join(TEST_DIR, 'testcert.pem')
cls.https_httpd = http.server.ThreadingHTTPServer( cls.https_httpd = http.server.ThreadingHTTPServer(
('127.0.0.1', 0), HTTPTestRequestHandler) ('127.0.0.1', 0), HTTPTestRequestHandler)
sslctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) sslctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
@ -694,8 +691,8 @@ class TestClientCertificate:
@classmethod @classmethod
def setup_class(cls): def setup_class(cls):
certfn = os.path.join(TEST_DIR, '../testcert.pem') certfn = os.path.join(TEST_DIR, 'testcert.pem')
cls.certdir = os.path.join(TEST_DIR, '../testdata', 'certificate') cls.certdir = os.path.join(TEST_DIR, 'testdata', 'certificate')
cacertfn = os.path.join(cls.certdir, 'ca.crt') cacertfn = os.path.join(cls.certdir, 'ca.crt')
cls.httpd = http.server.ThreadingHTTPServer(('127.0.0.1', 0), HTTPTestRequestHandler) cls.httpd = http.server.ThreadingHTTPServer(('127.0.0.1', 0), HTTPTestRequestHandler)
sslctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) sslctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)

View File

@ -1,26 +0,0 @@
import functools
import inspect
import pytest
from yt_dlp.networking import RequestHandler
from yt_dlp.networking.common import _REQUEST_HANDLERS
from yt_dlp.utils._utils import _YDLLogger as FakeLogger
@pytest.fixture
def handler(request):
RH_KEY = request.param
if inspect.isclass(RH_KEY) and issubclass(RH_KEY, RequestHandler):
handler = RH_KEY
elif RH_KEY in _REQUEST_HANDLERS:
handler = _REQUEST_HANDLERS[RH_KEY]
else:
pytest.skip(f'{RH_KEY} request handler is not available')
return functools.partial(handler, logger=FakeLogger)
def validate_and_send(rh, req):
rh.validate(req)
return rh.send(req)

View File

@ -274,6 +274,7 @@ class WebSocketSocksTestProxyContext(SocksProxyTestContext):
ws.close() ws.close()
return json.loads(socks_info) return json.loads(socks_info)
CTX_MAP = { CTX_MAP = {
'http': HTTPSocksTestProxyContext, 'http': HTTPSocksTestProxyContext,
'ws': WebSocketSocksTestProxyContext, 'ws': WebSocketSocksTestProxyContext,

View File

@ -1,63 +1,34 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
import enum
# Allow direct execution # Allow direct execution
import os import os
import sys import sys
from queue import Queue
import pytest import pytest
from websockets.datastructures import Headers
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import asyncio
import functools
import gzip
import http.client import http.client
import http.cookiejar import http.cookiejar
import http.server import http.server
import io
import json import json
import pathlib
import random import random
import ssl import ssl
import tempfile
import threading import threading
import time
import urllib.error
import urllib.request
import warnings
import zlib
from email.message import Message
from http.cookiejar import CookieJar
import websockets.sync import websockets.sync
from test.helper import FakeYDL, http_server_port
from yt_dlp.cookies import YoutubeDLCookieJar from yt_dlp.cookies import YoutubeDLCookieJar
from yt_dlp.dependencies import brotli, requests, urllib3, websockets from yt_dlp.dependencies import websockets
from yt_dlp.networking import ( from yt_dlp.networking import (
HEADRequest,
PUTRequest,
Request, Request,
RequestDirector,
RequestHandler,
Response,
) )
from yt_dlp.networking._urllib import UrllibRH
from yt_dlp.networking.exceptions import ( from yt_dlp.networking.exceptions import (
CertificateVerifyError, CertificateVerifyError,
HTTPError, HTTPError,
IncompleteRead,
NoSupportingHandlers,
ProxyError,
RequestError,
SSLError, SSLError,
TransportError, TransportError,
UnsupportedRequest,
) )
from yt_dlp.utils._utils import _YDLLogger as FakeLogger
from yt_dlp.utils.networking import HTTPHeaderDict from yt_dlp.utils.networking import HTTPHeaderDict
from .conftest import validate_and_send from .conftest import validate_and_send
@ -102,12 +73,27 @@ def create_ws_websocket_server():
def create_wss_websocket_server(): def create_wss_websocket_server():
certfn = os.path.join(TEST_DIR, '../testcert.pem') certfn = os.path.join(TEST_DIR, 'testcert.pem')
sslctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) sslctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
sslctx.load_cert_chain(certfn, None) sslctx.load_cert_chain(certfn, None)
return create_websocket_server(ssl_context=sslctx) return create_websocket_server(ssl_context=sslctx)
MTLS_CERT_DIR = os.path.join(TEST_DIR, 'testdata', 'certificate')
def create_mtls_wss_websocket_server():
certfn = os.path.join(TEST_DIR, 'testcert.pem')
cacertfn = os.path.join(MTLS_CERT_DIR, 'ca.crt')
sslctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
sslctx.verify_mode = ssl.CERT_REQUIRED
sslctx.load_verify_locations(cafile=cacertfn)
sslctx.load_cert_chain(certfn, None)
return create_websocket_server(ssl_context=sslctx)
@pytest.mark.skipif(not websockets, reason='websockets must be installed to test websocket request handlers') @pytest.mark.skipif(not websockets, reason='websockets must be installed to test websocket request handlers')
class TestWebsSocketRequestHandlerConformance: class TestWebsSocketRequestHandlerConformance:
@classmethod @classmethod
@ -121,6 +107,9 @@ class TestWebsSocketRequestHandlerConformance:
cls.bad_wss_thread, cls.bad_wss_port = create_websocket_server(ssl_context=ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)) cls.bad_wss_thread, cls.bad_wss_port = create_websocket_server(ssl_context=ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER))
cls.bad_wss_host = f'wss://127.0.0.1:{cls.bad_wss_port}' cls.bad_wss_host = f'wss://127.0.0.1:{cls.bad_wss_port}'
cls.mtls_wss_thread, cls.mtls_wss_port = create_mtls_wss_websocket_server()
cls.mtls_wss_base_url = f'wss://127.0.0.1:{cls.mtls_wss_port}'
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True) @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
def test_basic_websockets(self, handler): def test_basic_websockets(self, handler):
with handler() as rh: with handler() as rh:
@ -256,3 +245,29 @@ class TestWebsSocketRequestHandlerConformance:
assert headers['test2'] == 'changed' assert headers['test2'] == 'changed'
assert headers['test3'] == 'test3' assert headers['test3'] == 'test3'
ws.close() ws.close()
@pytest.mark.parametrize('client_cert', (
{'client_certificate': os.path.join(MTLS_CERT_DIR, 'clientwithkey.crt')},
{
'client_certificate': os.path.join(MTLS_CERT_DIR, 'client.crt'),
'client_certificate_key': os.path.join(MTLS_CERT_DIR, 'client.key'),
},
{
'client_certificate': os.path.join(MTLS_CERT_DIR, 'clientwithencryptedkey.crt'),
'client_certificate_password': 'foobar',
},
{
'client_certificate': os.path.join(MTLS_CERT_DIR, 'client.crt'),
'client_certificate_key': os.path.join(MTLS_CERT_DIR, 'clientencrypted.key'),
'client_certificate_password': 'foobar',
}
))
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
def test_mtls(self, handler, client_cert):
with handler(
# Disable client-side validation of unacceptable self-signed testcert.pem
# The test is of a check on the server side, so unaffected
verify=False,
client_cert=client_cert
) as rh:
validate_and_send(rh, Request(self.mtls_wss_base_url))

View File

@ -13,7 +13,6 @@ from ..networking.exceptions import HTTPError
from ..utils import ( from ..utils import (
ExtractorError, ExtractorError,
OnDemandPagedList, OnDemandPagedList,
WebSocketsWrapper,
bug_reports_message, bug_reports_message,
clean_html, clean_html,
float_or_none, float_or_none,

View File

@ -153,7 +153,6 @@ class WebsocketsRH(WebSocketRequestHandler):
headers=e.response.headers, headers=e.response.headers,
status=e.response.status_code, status=e.response.status_code,
reason=e.response.reason_phrase), reason=e.response.reason_phrase),
) from e ) from e
except (OSError, TimeoutError, websockets.exceptions.InvalidHandshake) as e: except (OSError, TimeoutError, websockets.exceptions.InvalidHandshake) as e:
raise TransportError(cause=e) from e raise TransportError(cause=e) from e

View File

@ -1,5 +1,3 @@
import asyncio
import atexit
import base64 import base64
import binascii import binascii
import calendar import calendar
@ -54,7 +52,7 @@ from ..compat import (
compat_os_name, compat_os_name,
compat_shlex_quote, compat_shlex_quote,
) )
from ..dependencies import websockets, xattr from ..dependencies import xattr
__name__ = __name__.rsplit('.', 1)[0] # Pretend to be the parent module __name__ = __name__.rsplit('.', 1)[0] # Pretend to be the parent module