Compare commits

...

3 Commits

Author SHA1 Message Date
coletdjnz
75d70eb640
linter 2023-11-04 14:38:11 +13:00
coletdjnz
d120356dff
Add websocket mTLS tests 2023-11-04 13:44:46 +13:00
coletdjnz
d17a82bfed
don't refactor yet 2023-11-04 13:29:53 +13:00
10 changed files with 71 additions and 70 deletions

View File

@ -6,3 +6,21 @@ import pytest
from yt_dlp.networking import RequestHandler
from yt_dlp.networking.common import _REQUEST_HANDLERS
from yt_dlp.utils._utils import _YDLLogger as FakeLogger
@pytest.fixture
def handler(request):
RH_KEY = request.param
if inspect.isclass(RH_KEY) and issubclass(RH_KEY, RequestHandler):
handler = RH_KEY
elif RH_KEY in _REQUEST_HANDLERS:
handler = _REQUEST_HANDLERS[RH_KEY]
else:
pytest.skip(f'{RH_KEY} request handler is not available')
return functools.partial(handler, logger=FakeLogger)
def validate_and_send(rh, req):
rh.validate(req)
return rh.send(req)

View File

@ -8,9 +8,6 @@ import pytest
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import asyncio
import functools
import json
import gzip
import http.client
import http.cookiejar
@ -31,7 +28,7 @@ from http.cookiejar import CookieJar
from test.helper import FakeYDL, http_server_port
from yt_dlp.cookies import YoutubeDLCookieJar
from yt_dlp.dependencies import brotli, requests, urllib3, websockets
from yt_dlp.dependencies import brotli, requests, urllib3
from yt_dlp.networking import (
HEADRequest,
PUTRequest,
@ -293,7 +290,7 @@ class TestRequestHandlerBase:
cls.http_server_thread.start()
# HTTPS server
certfn = os.path.join(TEST_DIR, '../testcert.pem')
certfn = os.path.join(TEST_DIR, 'testcert.pem')
cls.https_httpd = http.server.ThreadingHTTPServer(
('127.0.0.1', 0), HTTPTestRequestHandler)
sslctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
@ -694,8 +691,8 @@ class TestClientCertificate:
@classmethod
def setup_class(cls):
certfn = os.path.join(TEST_DIR, '../testcert.pem')
cls.certdir = os.path.join(TEST_DIR, '../testdata', 'certificate')
certfn = os.path.join(TEST_DIR, 'testcert.pem')
cls.certdir = os.path.join(TEST_DIR, 'testdata', 'certificate')
cacertfn = os.path.join(cls.certdir, 'ca.crt')
cls.httpd = http.server.ThreadingHTTPServer(('127.0.0.1', 0), HTTPTestRequestHandler)
sslctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)

View File

@ -1,26 +0,0 @@
import functools
import inspect
import pytest
from yt_dlp.networking import RequestHandler
from yt_dlp.networking.common import _REQUEST_HANDLERS
from yt_dlp.utils._utils import _YDLLogger as FakeLogger
@pytest.fixture
def handler(request):
RH_KEY = request.param
if inspect.isclass(RH_KEY) and issubclass(RH_KEY, RequestHandler):
handler = RH_KEY
elif RH_KEY in _REQUEST_HANDLERS:
handler = _REQUEST_HANDLERS[RH_KEY]
else:
pytest.skip(f'{RH_KEY} request handler is not available')
return functools.partial(handler, logger=FakeLogger)
def validate_and_send(rh, req):
rh.validate(req)
return rh.send(req)

View File

@ -274,6 +274,7 @@ class WebSocketSocksTestProxyContext(SocksProxyTestContext):
ws.close()
return json.loads(socks_info)
CTX_MAP = {
'http': HTTPSocksTestProxyContext,
'ws': WebSocketSocksTestProxyContext,

View File

@ -1,63 +1,34 @@
#!/usr/bin/env python3
import enum
# Allow direct execution
import os
import sys
from queue import Queue
import pytest
from websockets.datastructures import Headers
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import asyncio
import functools
import gzip
import http.client
import http.cookiejar
import http.server
import io
import json
import pathlib
import random
import ssl
import tempfile
import threading
import time
import urllib.error
import urllib.request
import warnings
import zlib
from email.message import Message
from http.cookiejar import CookieJar
import websockets.sync
from test.helper import FakeYDL, http_server_port
from yt_dlp.cookies import YoutubeDLCookieJar
from yt_dlp.dependencies import brotli, requests, urllib3, websockets
from yt_dlp.dependencies import websockets
from yt_dlp.networking import (
HEADRequest,
PUTRequest,
Request,
RequestDirector,
RequestHandler,
Response,
)
from yt_dlp.networking._urllib import UrllibRH
from yt_dlp.networking.exceptions import (
CertificateVerifyError,
HTTPError,
IncompleteRead,
NoSupportingHandlers,
ProxyError,
RequestError,
SSLError,
TransportError,
UnsupportedRequest,
)
from yt_dlp.utils._utils import _YDLLogger as FakeLogger
from yt_dlp.utils.networking import HTTPHeaderDict
from .conftest import validate_and_send
@ -102,12 +73,27 @@ def create_ws_websocket_server():
def create_wss_websocket_server():
certfn = os.path.join(TEST_DIR, '../testcert.pem')
certfn = os.path.join(TEST_DIR, 'testcert.pem')
sslctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
sslctx.load_cert_chain(certfn, None)
return create_websocket_server(ssl_context=sslctx)
MTLS_CERT_DIR = os.path.join(TEST_DIR, 'testdata', 'certificate')
def create_mtls_wss_websocket_server():
certfn = os.path.join(TEST_DIR, 'testcert.pem')
cacertfn = os.path.join(MTLS_CERT_DIR, 'ca.crt')
sslctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
sslctx.verify_mode = ssl.CERT_REQUIRED
sslctx.load_verify_locations(cafile=cacertfn)
sslctx.load_cert_chain(certfn, None)
return create_websocket_server(ssl_context=sslctx)
@pytest.mark.skipif(not websockets, reason='websockets must be installed to test websocket request handlers')
class TestWebsSocketRequestHandlerConformance:
@classmethod
@ -121,6 +107,9 @@ class TestWebsSocketRequestHandlerConformance:
cls.bad_wss_thread, cls.bad_wss_port = create_websocket_server(ssl_context=ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER))
cls.bad_wss_host = f'wss://127.0.0.1:{cls.bad_wss_port}'
cls.mtls_wss_thread, cls.mtls_wss_port = create_mtls_wss_websocket_server()
cls.mtls_wss_base_url = f'wss://127.0.0.1:{cls.mtls_wss_port}'
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
def test_basic_websockets(self, handler):
with handler() as rh:
@ -256,3 +245,29 @@ class TestWebsSocketRequestHandlerConformance:
assert headers['test2'] == 'changed'
assert headers['test3'] == 'test3'
ws.close()
@pytest.mark.parametrize('client_cert', (
{'client_certificate': os.path.join(MTLS_CERT_DIR, 'clientwithkey.crt')},
{
'client_certificate': os.path.join(MTLS_CERT_DIR, 'client.crt'),
'client_certificate_key': os.path.join(MTLS_CERT_DIR, 'client.key'),
},
{
'client_certificate': os.path.join(MTLS_CERT_DIR, 'clientwithencryptedkey.crt'),
'client_certificate_password': 'foobar',
},
{
'client_certificate': os.path.join(MTLS_CERT_DIR, 'client.crt'),
'client_certificate_key': os.path.join(MTLS_CERT_DIR, 'clientencrypted.key'),
'client_certificate_password': 'foobar',
}
))
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
def test_mtls(self, handler, client_cert):
with handler(
# Disable client-side validation of unacceptable self-signed testcert.pem
# The test is of a check on the server side, so unaffected
verify=False,
client_cert=client_cert
) as rh:
validate_and_send(rh, Request(self.mtls_wss_base_url))

View File

@ -13,7 +13,6 @@ from ..networking.exceptions import HTTPError
from ..utils import (
ExtractorError,
OnDemandPagedList,
WebSocketsWrapper,
bug_reports_message,
clean_html,
float_or_none,

View File

@ -156,4 +156,3 @@ class WebsocketsRH(WebSocketRequestHandler):
) from e
except (OSError, TimeoutError, websockets.exceptions.InvalidHandshake) as e:
raise TransportError(cause=e) from e

View File

@ -1,5 +1,3 @@
import asyncio
import atexit
import base64
import binascii
import calendar
@ -54,7 +52,7 @@ from ..compat import (
compat_os_name,
compat_shlex_quote,
)
from ..dependencies import websockets, xattr
from ..dependencies import xattr
__name__ = __name__.rsplit('.', 1)[0] # Pretend to be the parent module