Compare commits

..

No commits in common. "75d70eb640eba4aaa2b4eceab23687a5adec2510" and "9e05210991eebfafcd86912574bbd7dacd94fe55" have entirely different histories.

10 changed files with 70 additions and 71 deletions

View File

@ -6,21 +6,3 @@ import pytest
from yt_dlp.networking import RequestHandler from yt_dlp.networking import RequestHandler
from yt_dlp.networking.common import _REQUEST_HANDLERS from yt_dlp.networking.common import _REQUEST_HANDLERS
from yt_dlp.utils._utils import _YDLLogger as FakeLogger from yt_dlp.utils._utils import _YDLLogger as FakeLogger
@pytest.fixture
def handler(request):
RH_KEY = request.param
if inspect.isclass(RH_KEY) and issubclass(RH_KEY, RequestHandler):
handler = RH_KEY
elif RH_KEY in _REQUEST_HANDLERS:
handler = _REQUEST_HANDLERS[RH_KEY]
else:
pytest.skip(f'{RH_KEY} request handler is not available')
return functools.partial(handler, logger=FakeLogger)
def validate_and_send(rh, req):
rh.validate(req)
return rh.send(req)

View File

View File

@ -0,0 +1,26 @@
import functools
import inspect
import pytest
from yt_dlp.networking import RequestHandler
from yt_dlp.networking.common import _REQUEST_HANDLERS
from yt_dlp.utils._utils import _YDLLogger as FakeLogger
@pytest.fixture
def handler(request):
RH_KEY = request.param
if inspect.isclass(RH_KEY) and issubclass(RH_KEY, RequestHandler):
handler = RH_KEY
elif RH_KEY in _REQUEST_HANDLERS:
handler = _REQUEST_HANDLERS[RH_KEY]
else:
pytest.skip(f'{RH_KEY} request handler is not available')
return functools.partial(handler, logger=FakeLogger)
def validate_and_send(rh, req):
rh.validate(req)
return rh.send(req)

View File

@ -8,6 +8,9 @@ import pytest
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import asyncio
import functools
import json
import gzip import gzip
import http.client import http.client
import http.cookiejar import http.cookiejar
@ -28,7 +31,7 @@ from http.cookiejar import CookieJar
from test.helper import FakeYDL, http_server_port from test.helper import FakeYDL, http_server_port
from yt_dlp.cookies import YoutubeDLCookieJar from yt_dlp.cookies import YoutubeDLCookieJar
from yt_dlp.dependencies import brotli, requests, urllib3 from yt_dlp.dependencies import brotli, requests, urllib3, websockets
from yt_dlp.networking import ( from yt_dlp.networking import (
HEADRequest, HEADRequest,
PUTRequest, PUTRequest,
@ -290,7 +293,7 @@ class TestRequestHandlerBase:
cls.http_server_thread.start() cls.http_server_thread.start()
# HTTPS server # HTTPS server
certfn = os.path.join(TEST_DIR, 'testcert.pem') certfn = os.path.join(TEST_DIR, '../testcert.pem')
cls.https_httpd = http.server.ThreadingHTTPServer( cls.https_httpd = http.server.ThreadingHTTPServer(
('127.0.0.1', 0), HTTPTestRequestHandler) ('127.0.0.1', 0), HTTPTestRequestHandler)
sslctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) sslctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
@ -691,8 +694,8 @@ class TestClientCertificate:
@classmethod @classmethod
def setup_class(cls): def setup_class(cls):
certfn = os.path.join(TEST_DIR, 'testcert.pem') certfn = os.path.join(TEST_DIR, '../testcert.pem')
cls.certdir = os.path.join(TEST_DIR, 'testdata', 'certificate') cls.certdir = os.path.join(TEST_DIR, '../testdata', 'certificate')
cacertfn = os.path.join(cls.certdir, 'ca.crt') cacertfn = os.path.join(cls.certdir, 'ca.crt')
cls.httpd = http.server.ThreadingHTTPServer(('127.0.0.1', 0), HTTPTestRequestHandler) cls.httpd = http.server.ThreadingHTTPServer(('127.0.0.1', 0), HTTPTestRequestHandler)
sslctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) sslctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)

View File

@ -274,7 +274,6 @@ class WebSocketSocksTestProxyContext(SocksProxyTestContext):
ws.close() ws.close()
return json.loads(socks_info) return json.loads(socks_info)
CTX_MAP = { CTX_MAP = {
'http': HTTPSocksTestProxyContext, 'http': HTTPSocksTestProxyContext,
'ws': WebSocketSocksTestProxyContext, 'ws': WebSocketSocksTestProxyContext,

View File

@ -1,34 +1,63 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
import enum
# Allow direct execution # Allow direct execution
import os import os
import sys import sys
from queue import Queue
import pytest import pytest
from websockets.datastructures import Headers
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import asyncio
import functools
import gzip
import http.client import http.client
import http.cookiejar import http.cookiejar
import http.server import http.server
import io
import json import json
import pathlib
import random import random
import ssl import ssl
import tempfile
import threading import threading
import time
import urllib.error
import urllib.request
import warnings
import zlib
from email.message import Message
from http.cookiejar import CookieJar
import websockets.sync import websockets.sync
from test.helper import FakeYDL, http_server_port
from yt_dlp.cookies import YoutubeDLCookieJar from yt_dlp.cookies import YoutubeDLCookieJar
from yt_dlp.dependencies import websockets from yt_dlp.dependencies import brotli, requests, urllib3, websockets
from yt_dlp.networking import ( from yt_dlp.networking import (
HEADRequest,
PUTRequest,
Request, Request,
RequestDirector,
RequestHandler,
Response,
) )
from yt_dlp.networking._urllib import UrllibRH
from yt_dlp.networking.exceptions import ( from yt_dlp.networking.exceptions import (
CertificateVerifyError, CertificateVerifyError,
HTTPError, HTTPError,
IncompleteRead,
NoSupportingHandlers,
ProxyError,
RequestError,
SSLError, SSLError,
TransportError, TransportError,
UnsupportedRequest,
) )
from yt_dlp.utils._utils import _YDLLogger as FakeLogger
from yt_dlp.utils.networking import HTTPHeaderDict from yt_dlp.utils.networking import HTTPHeaderDict
from .conftest import validate_and_send from .conftest import validate_and_send
@ -73,27 +102,12 @@ def create_ws_websocket_server():
def create_wss_websocket_server(): def create_wss_websocket_server():
certfn = os.path.join(TEST_DIR, 'testcert.pem') certfn = os.path.join(TEST_DIR, '../testcert.pem')
sslctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) sslctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
sslctx.load_cert_chain(certfn, None) sslctx.load_cert_chain(certfn, None)
return create_websocket_server(ssl_context=sslctx) return create_websocket_server(ssl_context=sslctx)
MTLS_CERT_DIR = os.path.join(TEST_DIR, 'testdata', 'certificate')
def create_mtls_wss_websocket_server():
certfn = os.path.join(TEST_DIR, 'testcert.pem')
cacertfn = os.path.join(MTLS_CERT_DIR, 'ca.crt')
sslctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
sslctx.verify_mode = ssl.CERT_REQUIRED
sslctx.load_verify_locations(cafile=cacertfn)
sslctx.load_cert_chain(certfn, None)
return create_websocket_server(ssl_context=sslctx)
@pytest.mark.skipif(not websockets, reason='websockets must be installed to test websocket request handlers') @pytest.mark.skipif(not websockets, reason='websockets must be installed to test websocket request handlers')
class TestWebsSocketRequestHandlerConformance: class TestWebsSocketRequestHandlerConformance:
@classmethod @classmethod
@ -107,9 +121,6 @@ class TestWebsSocketRequestHandlerConformance:
cls.bad_wss_thread, cls.bad_wss_port = create_websocket_server(ssl_context=ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)) cls.bad_wss_thread, cls.bad_wss_port = create_websocket_server(ssl_context=ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER))
cls.bad_wss_host = f'wss://127.0.0.1:{cls.bad_wss_port}' cls.bad_wss_host = f'wss://127.0.0.1:{cls.bad_wss_port}'
cls.mtls_wss_thread, cls.mtls_wss_port = create_mtls_wss_websocket_server()
cls.mtls_wss_base_url = f'wss://127.0.0.1:{cls.mtls_wss_port}'
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True) @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
def test_basic_websockets(self, handler): def test_basic_websockets(self, handler):
with handler() as rh: with handler() as rh:
@ -245,29 +256,3 @@ class TestWebsSocketRequestHandlerConformance:
assert headers['test2'] == 'changed' assert headers['test2'] == 'changed'
assert headers['test3'] == 'test3' assert headers['test3'] == 'test3'
ws.close() ws.close()
@pytest.mark.parametrize('client_cert', (
{'client_certificate': os.path.join(MTLS_CERT_DIR, 'clientwithkey.crt')},
{
'client_certificate': os.path.join(MTLS_CERT_DIR, 'client.crt'),
'client_certificate_key': os.path.join(MTLS_CERT_DIR, 'client.key'),
},
{
'client_certificate': os.path.join(MTLS_CERT_DIR, 'clientwithencryptedkey.crt'),
'client_certificate_password': 'foobar',
},
{
'client_certificate': os.path.join(MTLS_CERT_DIR, 'client.crt'),
'client_certificate_key': os.path.join(MTLS_CERT_DIR, 'clientencrypted.key'),
'client_certificate_password': 'foobar',
}
))
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
def test_mtls(self, handler, client_cert):
with handler(
# Disable client-side validation of unacceptable self-signed testcert.pem
# The test is of a check on the server side, so unaffected
verify=False,
client_cert=client_cert
) as rh:
validate_and_send(rh, Request(self.mtls_wss_base_url))

View File

@ -13,6 +13,7 @@ from ..networking.exceptions import HTTPError
from ..utils import ( from ..utils import (
ExtractorError, ExtractorError,
OnDemandPagedList, OnDemandPagedList,
WebSocketsWrapper,
bug_reports_message, bug_reports_message,
clean_html, clean_html,
float_or_none, float_or_none,

View File

@ -156,3 +156,4 @@ class WebsocketsRH(WebSocketRequestHandler):
) from e ) from e
except (OSError, TimeoutError, websockets.exceptions.InvalidHandshake) as e: except (OSError, TimeoutError, websockets.exceptions.InvalidHandshake) as e:
raise TransportError(cause=e) from e raise TransportError(cause=e) from e

View File

@ -1,3 +1,5 @@
import asyncio
import atexit
import base64 import base64
import binascii import binascii
import calendar import calendar
@ -52,7 +54,7 @@ from ..compat import (
compat_os_name, compat_os_name,
compat_shlex_quote, compat_shlex_quote,
) )
from ..dependencies import xattr from ..dependencies import websockets, xattr
__name__ = __name__.rsplit('.', 1)[0] # Pretend to be the parent module __name__ = __name__.rsplit('.', 1)[0] # Pretend to be the parent module