Compare commits

..

No commits in common. "6c1b87efd7b18057c043b38fcf6f119612f503ce" and "75d70eb640eba4aaa2b4eceab23687a5adec2510" have entirely different histories.

9 changed files with 55 additions and 181 deletions

View File

@ -52,7 +52,7 @@ from yt_dlp.networking.exceptions import (
from yt_dlp.utils._utils import _YDLLogger as FakeLogger from yt_dlp.utils._utils import _YDLLogger as FakeLogger
from yt_dlp.utils.networking import HTTPHeaderDict from yt_dlp.utils.networking import HTTPHeaderDict
from test.conftest import validate_and_send from .conftest import validate_and_send
TEST_DIR = os.path.dirname(os.path.abspath(__file__)) TEST_DIR = os.path.dirname(os.path.abspath(__file__))
@ -369,6 +369,7 @@ class TestHTTPRequestHandler(TestRequestHandlerBase):
@pytest.mark.parametrize('handler', ['Urllib', 'Requests'], indirect=True) @pytest.mark.parametrize('handler', ['Urllib', 'Requests'], indirect=True)
def test_raise_http_error(self, handler): def test_raise_http_error(self, handler):
with handler() as rh: with handler() as rh:
# TODO Return HTTP status code url
for bad_status in (400, 500, 599, 302): for bad_status in (400, 500, 599, 302):
with pytest.raises(HTTPError): with pytest.raises(HTTPError):
validate_and_send(rh, Request('http://127.0.0.1:%d/gen_%d' % (self.http_port, bad_status))) validate_and_send(rh, Request('http://127.0.0.1:%d/gen_%d' % (self.http_port, bad_status)))
@ -869,9 +870,8 @@ class TestRequestsRequestHandler(TestRequestHandlerBase):
]) ])
@pytest.mark.parametrize('handler', ['Requests'], indirect=True) @pytest.mark.parametrize('handler', ['Requests'], indirect=True)
def test_response_error_mapping(self, handler, monkeypatch, raised, expected, match): def test_response_error_mapping(self, handler, monkeypatch, raised, expected, match):
from requests.models import Response as RequestsResponse
from urllib3.response import HTTPResponse as Urllib3Response from urllib3.response import HTTPResponse as Urllib3Response
from requests.models import Response as RequestsResponse
from yt_dlp.networking._requests import RequestsResponseAdapter from yt_dlp.networking._requests import RequestsResponseAdapter
requests_res = RequestsResponse() requests_res = RequestsResponse()
requests_res.raw = Urllib3Response(body=b'', status=200) requests_res.raw = Urllib3Response(body=b'', status=200)
@ -996,7 +996,7 @@ class TestRequestHandlerValidation:
({'cookiejar': 'notacookiejar'}, False), ({'cookiejar': 'notacookiejar'}, False),
({'somerandom': 'test'}, False), # but any extension is allowed through ({'somerandom': 'test'}, False), # but any extension is allowed through
]), ]),
('Websockets', 'ws', [ ('Websockets', 'ws', [ # TODO
({'cookiejar': YoutubeDLCookieJar()}, False), ({'cookiejar': YoutubeDLCookieJar()}, False),
({'timeout': 2}, False), ({'timeout': 2}, False),
]), ]),
@ -1084,22 +1084,6 @@ class FakeRHYDL(FakeYDL):
self._request_director = self.build_request_director([FakeRH]) self._request_director = self.build_request_director([FakeRH])
class AllUnsupportedRHYDL(FakeYDL):
def __init__(self, *args, **kwargs):
class UnsupportedRH(RequestHandler):
def _send(self, request: Request):
pass
_SUPPORTED_FEATURES = ()
_SUPPORTED_PROXY_SCHEMES = ()
_SUPPORTED_URL_SCHEMES = ()
super().__init__(*args, **kwargs)
self._request_director = self.build_request_director([UnsupportedRH])
class TestRequestDirector: class TestRequestDirector:
def test_handler_operations(self): def test_handler_operations(self):
@ -1259,12 +1243,6 @@ class TestYoutubeDLNetworking:
with pytest.raises(RequestError, match=r'file:// URLs are disabled by default'): with pytest.raises(RequestError, match=r'file:// URLs are disabled by default'):
ydl.urlopen('file://') ydl.urlopen('file://')
@pytest.mark.parametrize('scheme', (['ws', 'wss']))
def test_websocket_unavailable_error(self, scheme):
with AllUnsupportedRHYDL() as ydl:
with pytest.raises(RequestError, match=r'This request requires WebSocket support'):
ydl.urlopen(f'{scheme}://')
def test_legacy_server_connect_error(self): def test_legacy_server_connect_error(self):
with FakeRHYDL() as ydl: with FakeRHYDL() as ydl:
for error in ('UNSAFE_LEGACY_RENEGOTIATION_DISABLED', 'SSLV3_ALERT_HANDSHAKE_FAILURE'): for error in ('UNSAFE_LEGACY_RENEGOTIATION_DISABLED', 'SSLV3_ALERT_HANDSHAKE_FAILURE'):

View File

@ -16,40 +16,36 @@ import random
import ssl import ssl
import threading import threading
from yt_dlp import socks import websockets.sync
from yt_dlp.cookies import YoutubeDLCookieJar from yt_dlp.cookies import YoutubeDLCookieJar
from yt_dlp.dependencies import websockets from yt_dlp.dependencies import websockets
from yt_dlp.networking import Request from yt_dlp.networking import (
Request,
)
from yt_dlp.networking.exceptions import ( from yt_dlp.networking.exceptions import (
CertificateVerifyError, CertificateVerifyError,
HTTPError, HTTPError,
ProxyError,
RequestError,
SSLError, SSLError,
TransportError, TransportError,
) )
from yt_dlp.utils.networking import HTTPHeaderDict from yt_dlp.utils.networking import HTTPHeaderDict
from test.conftest import validate_and_send from .conftest import validate_and_send
TEST_DIR = os.path.dirname(os.path.abspath(__file__)) TEST_DIR = os.path.dirname(os.path.abspath(__file__))
def websocket_handler(websocket): def websocket_handler(websocket):
for message in websocket: for message in websocket:
if isinstance(message, bytes): if message == 'headers':
if message == b'bytes': return websocket.send(json.dumps(dict(websocket.request.headers)))
return websocket.send('2') elif message == 'path':
elif isinstance(message, str): return websocket.send(websocket.request.path)
if message == 'headers': elif message == 'source_address':
return websocket.send(json.dumps(dict(websocket.request.headers))) return websocket.send(websocket.remote_address[0])
elif message == 'path': else:
return websocket.send(websocket.request.path) return websocket.send(message)
elif message == 'source_address':
return websocket.send(websocket.remote_address[0])
elif message == 'str':
return websocket.send('1')
return websocket.send(message)
def process_request(self, request): def process_request(self, request):
@ -124,16 +120,6 @@ class TestWebsSocketRequestHandlerConformance:
assert ws.recv() == 'foo' assert ws.recv() == 'foo'
ws.close() ws.close()
# https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6
@pytest.mark.parametrize('msg,opcode', [('str', 1), (b'bytes', 2)])
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
def test_send_types(self, handler, msg, opcode):
with handler() as rh:
ws = validate_and_send(rh, Request(self.ws_base_url))
ws.send(msg)
assert int(ws.recv()) == opcode
ws.close()
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True) @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
def test_verify_cert(self, handler): def test_verify_cert(self, handler):
with handler() as rh: with handler() as rh:
@ -285,95 +271,3 @@ class TestWebsSocketRequestHandlerConformance:
client_cert=client_cert client_cert=client_cert
) as rh: ) as rh:
validate_and_send(rh, Request(self.mtls_wss_base_url)) validate_and_send(rh, Request(self.mtls_wss_base_url))
def create_fake_ws_connection(raised):
import websockets.sync.client
class FakeWsConnection(websockets.sync.client.ClientConnection):
def __init__(self, *args, **kwargs):
class FakeResponse:
body = b''
headers = {}
status_code = 101
reason_phrase = 'test'
self.response = FakeResponse()
def send(self, *args, **kwargs):
raise raised()
def recv(self, *args, **kwargs):
raise raised()
def close(self, *args, **kwargs):
return
return FakeWsConnection()
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
class TestWebsocketsRequestHandler:
@pytest.mark.parametrize('raised,expected', [
# https://websockets.readthedocs.io/en/stable/reference/exceptions.html
(lambda: websockets.exceptions.InvalidURI(msg='test', uri='test://'), RequestError),
# Requires a response object. Should be covered by HTTP error tests.
# (lambda: websockets.exceptions.InvalidStatus(), TransportError),
(lambda: websockets.exceptions.InvalidHandshake(), TransportError),
# These are subclasses of InvalidHandshake
(lambda: websockets.exceptions.InvalidHeader(name='test'), TransportError),
(lambda: websockets.exceptions.NegotiationError(), TransportError),
# Catch-all
(lambda: websockets.exceptions.WebSocketException(), TransportError),
(lambda: TimeoutError(), TransportError),
# These may be raised by our create_connection implementation, which should also be caught
(lambda: OSError(), TransportError),
(lambda: ssl.SSLError(), SSLError),
(lambda: ssl.SSLCertVerificationError(), CertificateVerifyError),
(lambda: socks.ProxyError(), ProxyError),
])
def test_request_error_mapping(self, handler, monkeypatch, raised, expected):
import websockets.sync.client
import yt_dlp.networking._websockets
with handler() as rh:
def fake_connect(*args, **kwargs):
raise raised()
monkeypatch.setattr(yt_dlp.networking._websockets, 'create_connection', lambda *args, **kwargs: None)
monkeypatch.setattr(websockets.sync.client, 'connect', fake_connect)
with pytest.raises(expected) as exc_info:
rh.send(Request('ws://fake-url'))
assert exc_info.type is expected
@pytest.mark.parametrize('raised,expected,match', [
# https://websockets.readthedocs.io/en/stable/reference/sync/client.html#websockets.sync.client.ClientConnection.send
(lambda: websockets.exceptions.ConnectionClosed(None, None), TransportError, None),
(lambda: RuntimeError(), TransportError, None),
(lambda: TimeoutError(), TransportError, None),
(lambda: TypeError(), RequestError, None),
(lambda: socks.ProxyError(), ProxyError, None),
# Catch-all
(lambda: websockets.exceptions.WebSocketException(), TransportError, None),
])
def test_ws_send_error_mapping(self, handler, monkeypatch, raised, expected, match):
from yt_dlp.networking._websockets import WebsocketsResponseAdapter
ws = WebsocketsResponseAdapter(create_fake_ws_connection(raised), url='ws://fake-url')
with pytest.raises(expected, match=match) as exc_info:
ws.send('test')
assert exc_info.type is expected
@pytest.mark.parametrize('raised,expected,match', [
# https://websockets.readthedocs.io/en/stable/reference/sync/client.html#websockets.sync.client.ClientConnection.recv
(lambda: websockets.exceptions.ConnectionClosed(None, None), TransportError, None),
(lambda: RuntimeError(), TransportError, None),
(lambda: TimeoutError(), TransportError, None),
(lambda: socks.ProxyError(), ProxyError, None),
# Catch-all
(lambda: websockets.exceptions.WebSocketException(), TransportError, None),
])
def test_ws_recv_error_mapping(self, handler, monkeypatch, raised, expected, match):
from yt_dlp.networking._websockets import WebsocketsResponseAdapter
ws = WebsocketsResponseAdapter(create_fake_ws_connection(raised), url='ws://fake-url')
with pytest.raises(expected, match=match) as exc_info:
ws.recv()
assert exc_info.type is expected

View File

@ -4051,7 +4051,6 @@ class YoutubeDL:
return self._request_director.send(req) return self._request_director.send(req)
except NoSupportingHandlers as e: except NoSupportingHandlers as e:
for ue in e.unsupported_errors: for ue in e.unsupported_errors:
# FIXME: This depends on the order of errors.
if not (ue.handler and ue.msg): if not (ue.handler and ue.msg):
continue continue
if ue.handler.RH_KEY == 'Urllib' and 'unsupported url scheme: "file"' in ue.msg.lower(): if ue.handler.RH_KEY == 'Urllib' and 'unsupported url scheme: "file"' in ue.msg.lower():
@ -4061,15 +4060,6 @@ class YoutubeDL:
if 'unsupported proxy type: "https"' in ue.msg.lower(): if 'unsupported proxy type: "https"' in ue.msg.lower():
raise RequestError( raise RequestError(
'To use an HTTPS proxy for this request, one of the following dependencies needs to be installed: requests') 'To use an HTTPS proxy for this request, one of the following dependencies needs to be installed: requests')
elif (
re.match(r'unsupported url scheme: "wss?"', ue.msg.lower())
and not set(self._request_director.handlers.keys()).intersection({'websockets'})
):
raise RequestError(
'This request requires WebSocket support. '
'Ensure one of the following dependencies are installed: websockets',
cause=ue) from ue
raise raise
except SSLError as e: except SSLError as e:
if 'UNSAFE_LEGACY_RENEGOTIATION_DISABLED' in str(e): if 'UNSAFE_LEGACY_RENEGOTIATION_DISABLED' in str(e):

View File

@ -6,7 +6,7 @@ from . import get_suitable_downloader
from .common import FileDownloader from .common import FileDownloader
from .external import FFmpegFD from .external import FFmpegFD
from ..networking import Request from ..networking import Request
from ..utils import DownloadError, str_or_none, try_get from ..utils import DownloadError, WebSocketsWrapper, str_or_none, try_get
class NiconicoDmcFD(FileDownloader): class NiconicoDmcFD(FileDownloader):
@ -64,6 +64,7 @@ class NiconicoLiveFD(FileDownloader):
ws_url = info_dict['url'] ws_url = info_dict['url']
ws_extractor = info_dict['ws'] ws_extractor = info_dict['ws']
ws_origin_host = info_dict['origin'] ws_origin_host = info_dict['origin']
cookies = info_dict.get('cookies')
live_quality = info_dict.get('live_quality', 'high') live_quality = info_dict.get('live_quality', 'high')
live_latency = info_dict.get('live_latency', 'high') live_latency = info_dict.get('live_latency', 'high')
dl = FFmpegFD(self.ydl, self.params or {}) dl = FFmpegFD(self.ydl, self.params or {})
@ -75,7 +76,12 @@ class NiconicoLiveFD(FileDownloader):
def communicate_ws(reconnect): def communicate_ws(reconnect):
if reconnect: if reconnect:
ws = self.ydl.urlopen(Request(ws_url, headers={'Origin': f'https://{ws_origin_host}'})) ws = WebSocketsWrapper(ws_url, {
'Cookies': str_or_none(cookies) or '',
'Origin': f'https://{ws_origin_host}',
'Accept': '*/*',
'User-Agent': self.params['http_headers']['User-Agent'],
})
if self.ydl.params.get('verbose', False): if self.ydl.params.get('verbose', False):
self.to_screen('[debug] Sending startWatching request') self.to_screen('[debug] Sending startWatching request')
ws.send(json.dumps({ ws.send(json.dumps({

View File

@ -196,7 +196,10 @@ class FC2LiveIE(InfoExtractor):
playlist_data = None playlist_data = None
ws = self._request_webpage(Request(ws_url, headers={ ws = self._request_webpage(Request(ws_url, headers={
'Cookie': str(self._get_cookies('https://live.fc2.com/'))[12:],
'Origin': 'https://live.fc2.com', 'Origin': 'https://live.fc2.com',
'Accept': '*/*',
'User-Agent': self.get_param('http_headers')['User-Agent'],
}), video_id, note='Fetching HLS playlist info via WebSocket') }), video_id, note='Fetching HLS playlist info via WebSocket')
self.write_debug('Sending HLS server request') self.write_debug('Sending HLS server request')

View File

@ -947,13 +947,17 @@ class NiconicoLiveIE(InfoExtractor):
}) })
hostname = remove_start(urlparse(urlh.url).hostname, 'sp.') hostname = remove_start(urlparse(urlh.url).hostname, 'sp.')
cookies = try_get(urlh.url, self._downloader._calc_cookies)
latency = try_get(self._configuration_arg('latency'), lambda x: x[0]) latency = try_get(self._configuration_arg('latency'), lambda x: x[0])
if latency not in self._KNOWN_LATENCY: if latency not in self._KNOWN_LATENCY:
latency = 'high' latency = 'high'
ws = self._request_webpage( ws = self._request_webpage(Request(ws_url, headers={
Request(ws_url, headers={'Origin': f'https://{hostname}'}), 'Cookies': str_or_none(cookies) or '',
video_id=video_id, note='Connecting to WebSocket server') 'Origin': f'https://{hostname}',
'Accept': '*/*',
'User-Agent': self.get_param('http_headers')['User-Agent'],
}), video_id=video_id, note='Connecting to WebSocket server')
self.write_debug('[debug] Sending HLS server request') self.write_debug('[debug] Sending HLS server request')
ws.send(json.dumps({ ws.send(json.dumps({
@ -1027,6 +1031,7 @@ class NiconicoLiveIE(InfoExtractor):
'protocol': 'niconico_live', 'protocol': 'niconico_live',
'ws': ws, 'ws': ws,
'video_id': video_id, 'video_id': video_id,
'cookies': cookies,
'live_latency': latency, 'live_latency': latency,
'origin': hostname, 'origin': hostname,
}) })

View File

@ -37,36 +37,36 @@ class WebsocketsResponseAdapter(WebSocketResponse):
def __init__(self, wsw: websockets.sync.client.ClientConnection, url): def __init__(self, wsw: websockets.sync.client.ClientConnection, url):
super().__init__( super().__init__(
fp=io.BytesIO(wsw.response.body or b''), fp=io.BytesIO(wsw.response.body or b''), # TODO: test
url=url, url=url,
headers=wsw.response.headers, headers=wsw.response.headers, # TODO: test multiple headers (may need to use raw_items())
status=wsw.response.status_code, status=wsw.response.status_code,
reason=wsw.response.reason_phrase, reason=wsw.response.reason_phrase,
) )
self.wsw = wsw self.wsw = wsw
def close(self): def close(self, status=None):
self.wsw.close() self.wsw.close()
super().close() super().close()
def send(self, message): def send(self, *args):
# https://websockets.readthedocs.io/en/stable/reference/sync/client.html#websockets.sync.client.ClientConnection.send # https://websockets.readthedocs.io/en/stable/reference/sync/client.html#websockets.sync.client.ClientConnection.send
try: try:
return self.wsw.send(message) return self.wsw.send(*args)
except (websockets.exceptions.WebSocketException, RuntimeError, TimeoutError) as e: except (websockets.exceptions.ConnectionClosed, RuntimeError, TimeoutError) as e:
raise TransportError(cause=e) from e raise TransportError(cause=e) from e
except SocksProxyError as e: except SocksProxyError as e:
raise ProxyError(cause=e) from e raise ProxyError(cause=e) from e
except TypeError as e: except TypeError as e:
raise RequestError(cause=e) from e raise RequestError(cause=e) from e
def recv(self): def recv(self, *args):
# https://websockets.readthedocs.io/en/stable/reference/sync/client.html#websockets.sync.client.ClientConnection.recv # https://websockets.readthedocs.io/en/stable/reference/sync/client.html#websockets.sync.client.ClientConnection.recv
try: try:
return self.wsw.recv() return self.wsw.recv(*args)
except SocksProxyError as e: except SocksProxyError as e:
raise ProxyError(cause=e) from e raise ProxyError(cause=e) from e
except (websockets.exceptions.WebSocketException, RuntimeError, TimeoutError) as e: except (websockets.exceptions.ConnectionClosed, RuntimeError, TimeoutError) as e:
raise TransportError(cause=e) from e raise TransportError(cause=e) from e
@ -133,7 +133,6 @@ class WebsocketsRH(WebSocketRequestHandler):
open_timeout=timeout, open_timeout=timeout,
user_agent_header=None, user_agent_header=None,
ssl_context=self._make_sslcontext() if wsuri.secure else None, ssl_context=self._make_sslcontext() if wsuri.secure else None,
close_timeout=0, # not ideal, but prevents yt-dlp hanging
) )
return WebsocketsResponseAdapter(conn, url=request.url) return WebsocketsResponseAdapter(conn, url=request.url)
@ -155,5 +154,5 @@ class WebsocketsRH(WebSocketRequestHandler):
status=e.response.status_code, status=e.response.status_code,
reason=e.response.reason_phrase), reason=e.response.reason_phrase),
) from e ) from e
except (OSError, TimeoutError, websockets.exceptions.WebSocketException) as e: except (OSError, TimeoutError, websockets.exceptions.InvalidHandshake) as e:
raise TransportError(cause=e) from e raise TransportError(cause=e) from e

View File

@ -1,23 +1,21 @@
from __future__ import annotations
import abc import abc
from .common import Response, RequestHandler from .common import Response, RequestHandler
from .exceptions import TransportError
class WebSocketResponse(Response): class WebSocketResponse(Response):
def send(self, message: bytes | str): def send(self, *args):
"""
Send a message to the server.
@param message: The message to send. A string (str) is sent as a text frame, bytes is sent as a binary frame.
"""
raise NotImplementedError raise NotImplementedError
def recv(self): def recv(self, *args):
raise NotImplementedError raise NotImplementedError
class WebSocketException(TransportError):
pass
class WebSocketRequestHandler(RequestHandler, abc.ABC): class WebSocketRequestHandler(RequestHandler, abc.ABC):
pass pass

View File

@ -30,6 +30,7 @@ from ..networking._urllib import ( # noqa: F401
) )
from ..networking.exceptions import HTTPError, network_exceptions # noqa: F401 from ..networking.exceptions import HTTPError, network_exceptions # noqa: F401
has_certifi = bool(certifi) has_certifi = bool(certifi)
has_websockets = bool(websockets) has_websockets = bool(websockets)