Add Code
10 binary files not shown.
@@ -0,0 +1,265 @@
# From https://github.com/aio-libs/async-timeout/blob/master/async_timeout/__init__.py
# Licensed under the Apache License (Apache-2.0)

import asyncio
import enum
import sys
import warnings
from types import TracebackType
from typing import Optional, Type


# From https://github.com/python/typing_extensions/blob/main/src/typing_extensions.py
# Licensed under the Python Software Foundation License (PSF-2.0)

if sys.version_info >= (3, 11):
    from typing import final
else:
    # @final exists in 3.8+, but we backport it for all versions
    # before 3.11 to keep support for the __final__ attribute.
    # See https://bugs.python.org/issue46342
    def final(f):
        """This decorator can be used to indicate to type checkers that
        the decorated method cannot be overridden, and decorated class
        cannot be subclassed. For example:

            class Base:
                @final
                def done(self) -> None:
                    ...
            class Sub(Base):
                def done(self) -> None:  # Error reported by type checker
                    ...
            @final
            class Leaf:
                ...
            class Other(Leaf):  # Error reported by type checker
                ...

        There is no runtime checking of these properties. The decorator
        sets the ``__final__`` attribute to ``True`` on the decorated object
        to allow runtime introspection.
        """
        try:
            f.__final__ = True
        except (AttributeError, TypeError):
            # Skip the attribute silently if it is not writable.
            # AttributeError happens if the object has __slots__ or a
            # read-only property, TypeError if it's a builtin class.
            pass
        return f


# End https://github.com/aio-libs/async-timeout/blob/master/async_timeout/__init__.py

__version__ = "4.0.2"


__all__ = ("timeout", "timeout_at", "Timeout")


def timeout(delay: Optional[float]) -> "Timeout":
    """timeout context manager.

    Useful in cases when you want to apply timeout logic around block
    of code or in cases when asyncio.wait_for is not suitable. For example:

    >>> async with timeout(0.001):
    ...     async with aiohttp.get('https://github.com') as r:
    ...         await r.text()


    delay - value in seconds or None to disable timeout logic
    """
    loop = asyncio.get_running_loop()
    if delay is not None:
        deadline = loop.time() + delay  # type: Optional[float]
    else:
        deadline = None
    return Timeout(deadline, loop)


def timeout_at(deadline: Optional[float]) -> "Timeout":
    """Schedule the timeout at absolute time.

    deadline argument points on the time in the same clock system
    as loop.time().

    Please note: it is not POSIX time but a time with
    undefined starting base, e.g. the time of the system power on.

    >>> async with timeout_at(loop.time() + 10):
    ...     async with aiohttp.get('https://github.com') as r:
    ...         await r.text()


    """
    loop = asyncio.get_running_loop()
    return Timeout(deadline, loop)


class _State(enum.Enum):
    INIT = "INIT"
    ENTER = "ENTER"
    TIMEOUT = "TIMEOUT"
    EXIT = "EXIT"


@final
class Timeout:
    # Internal class, please don't instantiate it directly
    # Use timeout() and timeout_at() public factories instead.
    #
    # Implementation note: `async with timeout()` is preferred
    # over `with timeout()`.
    # While technically the Timeout class implementation
    # doesn't need to be async at all,
    # the `async with` statement explicitly points that
    # the context manager should be used from async function context.
    #
    # This design allows to avoid many silly misusages.
    #
    # TimeoutError is raised immediately when scheduled
    # if the deadline is passed.
    # The purpose is to time out as soon as possible
    # without waiting for the next await expression.

    __slots__ = ("_deadline", "_loop", "_state", "_timeout_handler")

    def __init__(
        self, deadline: Optional[float], loop: asyncio.AbstractEventLoop
    ) -> None:
        self._loop = loop
        self._state = _State.INIT

        self._timeout_handler = None  # type: Optional[asyncio.Handle]
        if deadline is None:
            self._deadline = None  # type: Optional[float]
        else:
            self.update(deadline)

    def __enter__(self) -> "Timeout":
        warnings.warn(
            "with timeout() is deprecated, use async with timeout() instead",
            DeprecationWarning,
            stacklevel=2,
        )
        self._do_enter()
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> Optional[bool]:
        self._do_exit(exc_type)
        return None

    async def __aenter__(self) -> "Timeout":
        self._do_enter()
        return self

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> Optional[bool]:
        self._do_exit(exc_type)
        return None

    @property
    def expired(self) -> bool:
        """Is timeout expired during execution?"""
        return self._state == _State.TIMEOUT

    @property
    def deadline(self) -> Optional[float]:
        return self._deadline

    def reject(self) -> None:
        """Reject scheduled timeout if any."""
        # cancel is maybe better name but
        # task.cancel() raises CancelledError in asyncio world.
        if self._state not in (_State.INIT, _State.ENTER):
            raise RuntimeError(f"invalid state {self._state.value}")
        self._reject()

    def _reject(self) -> None:
        if self._timeout_handler is not None:
            self._timeout_handler.cancel()
            self._timeout_handler = None

    def shift(self, delay: float) -> None:
        """Advance timeout on delay seconds.

        The delay can be negative.

        Raise RuntimeError if shift is called when deadline is not scheduled
        """
        deadline = self._deadline
        if deadline is None:
            raise RuntimeError("cannot shift timeout if deadline is not scheduled")
        self.update(deadline + delay)

    def update(self, deadline: float) -> None:
        """Set deadline to absolute value.

        deadline argument points on the time in the same clock system
        as loop.time().

        If new deadline is in the past the timeout is raised immediately.

        Please note: it is not POSIX time but a time with
        undefined starting base, e.g. the time of the system power on.
        """
        if self._state == _State.EXIT:
            raise RuntimeError("cannot reschedule after exit from context manager")
        if self._state == _State.TIMEOUT:
            raise RuntimeError("cannot reschedule expired timeout")
        if self._timeout_handler is not None:
            self._timeout_handler.cancel()
        self._deadline = deadline
        if self._state != _State.INIT:
            self._reschedule()

    def _reschedule(self) -> None:
        assert self._state == _State.ENTER
        deadline = self._deadline
        if deadline is None:
            return

        now = self._loop.time()
        if self._timeout_handler is not None:
            self._timeout_handler.cancel()

        task = asyncio.current_task()
        if deadline <= now:
            self._timeout_handler = self._loop.call_soon(self._on_timeout, task)
        else:
            self._timeout_handler = self._loop.call_at(deadline, self._on_timeout, task)

    def _do_enter(self) -> None:
        if self._state != _State.INIT:
            raise RuntimeError(f"invalid state {self._state.value}")
        self._state = _State.ENTER
        self._reschedule()

    def _do_exit(self, exc_type: Optional[Type[BaseException]]) -> None:
        if exc_type is asyncio.CancelledError and self._state == _State.TIMEOUT:
            self._timeout_handler = None
            raise asyncio.TimeoutError
        # timeout has not expired
        self._state = _State.EXIT
        self._reject()
        return None

    def _on_timeout(self, task: "asyncio.Task[None]") -> None:
        task.cancel()
        self._state = _State.TIMEOUT
        # drop the reference early
        self._timeout_handler = None


# End https://github.com/aio-libs/async-timeout/blob/master/async_timeout/__init__.py
@@ -0,0 +1,184 @@
from __future__ import annotations

import functools
import hmac
import http
from typing import Any, Awaitable, Callable, Iterable, Optional, Tuple, Union, cast

from ..datastructures import Headers
from ..exceptions import InvalidHeader
from ..headers import build_www_authenticate_basic, parse_authorization_basic
from .server import HTTPResponse, WebSocketServerProtocol


__all__ = ["BasicAuthWebSocketServerProtocol", "basic_auth_protocol_factory"]

Credentials = Tuple[str, str]


def is_credentials(value: Any) -> bool:
    try:
        username, password = value
    except (TypeError, ValueError):
        return False
    else:
        return isinstance(username, str) and isinstance(password, str)


class BasicAuthWebSocketServerProtocol(WebSocketServerProtocol):
    """
    WebSocket server protocol that enforces HTTP Basic Auth.

    """

    realm: str = ""
    """
    Scope of protection.

    If provided, it should contain only ASCII characters because the
    encoding of non-ASCII characters is undefined.
    """

    username: Optional[str] = None
    """Username of the authenticated user."""

    def __init__(
        self,
        *args: Any,
        realm: Optional[str] = None,
        check_credentials: Optional[Callable[[str, str], Awaitable[bool]]] = None,
        **kwargs: Any,
    ) -> None:
        if realm is not None:
            self.realm = realm  # shadow class attribute
        self._check_credentials = check_credentials
        super().__init__(*args, **kwargs)

    async def check_credentials(self, username: str, password: str) -> bool:
        """
        Check whether credentials are authorized.

        This coroutine may be overridden in a subclass, for example to
        authenticate against a database or an external service.

        Args:
            username: HTTP Basic Auth username.
            password: HTTP Basic Auth password.

        Returns:
            bool: :obj:`True` if the handshake should continue;
            :obj:`False` if it should fail with an HTTP 401 error.

        """
        if self._check_credentials is not None:
            return await self._check_credentials(username, password)

        return False

    async def process_request(
        self,
        path: str,
        request_headers: Headers,
    ) -> Optional[HTTPResponse]:
        """
        Check HTTP Basic Auth and return an HTTP 401 response if needed.

        """
        try:
            authorization = request_headers["Authorization"]
        except KeyError:
            return (
                http.HTTPStatus.UNAUTHORIZED,
                [("WWW-Authenticate", build_www_authenticate_basic(self.realm))],
                b"Missing credentials\n",
            )

        try:
            username, password = parse_authorization_basic(authorization)
        except InvalidHeader:
            return (
                http.HTTPStatus.UNAUTHORIZED,
                [("WWW-Authenticate", build_www_authenticate_basic(self.realm))],
                b"Unsupported credentials\n",
            )

        if not await self.check_credentials(username, password):
            return (
                http.HTTPStatus.UNAUTHORIZED,
                [("WWW-Authenticate", build_www_authenticate_basic(self.realm))],
                b"Invalid credentials\n",
            )

        self.username = username

        return await super().process_request(path, request_headers)


def basic_auth_protocol_factory(
    realm: Optional[str] = None,
    credentials: Optional[Union[Credentials, Iterable[Credentials]]] = None,
    check_credentials: Optional[Callable[[str, str], Awaitable[bool]]] = None,
    create_protocol: Optional[Callable[..., BasicAuthWebSocketServerProtocol]] = None,
) -> Callable[..., BasicAuthWebSocketServerProtocol]:
    """
    Protocol factory that enforces HTTP Basic Auth.

    :func:`basic_auth_protocol_factory` is designed to integrate with
    :func:`~websockets.server.serve` like this::

        websockets.serve(
            ...,
            create_protocol=websockets.basic_auth_protocol_factory(
                realm="my dev server",
                credentials=("hello", "iloveyou"),
            )
        )

    Args:
        realm: Scope of protection. It should contain only ASCII characters
            because the encoding of non-ASCII characters is undefined.
            Refer to section 2.2 of :rfc:`7235` for details.
        credentials: Hard coded authorized credentials. It can be a
            ``(username, password)`` pair or a list of such pairs.
        check_credentials: Coroutine that verifies credentials.
            It receives ``username`` and ``password`` arguments
            and returns a :class:`bool`. One of ``credentials`` or
            ``check_credentials`` must be provided but not both.
        create_protocol: Factory that creates the protocol. By default, this
            is :class:`BasicAuthWebSocketServerProtocol`. It can be replaced
            by a subclass.

    Raises:
        TypeError: If the ``credentials`` or ``check_credentials`` argument is
            wrong.

    """
    if (credentials is None) == (check_credentials is None):
        raise TypeError("provide either credentials or check_credentials")

    if credentials is not None:
        if is_credentials(credentials):
            credentials_list = [cast(Credentials, credentials)]
        elif isinstance(credentials, Iterable):
            credentials_list = list(credentials)
            if not all(is_credentials(item) for item in credentials_list):
                raise TypeError(f"invalid credentials argument: {credentials}")
        else:
            raise TypeError(f"invalid credentials argument: {credentials}")

        credentials_dict = dict(credentials_list)

        async def check_credentials(username: str, password: str) -> bool:
            try:
                expected_password = credentials_dict[username]
            except KeyError:
                return False
            return hmac.compare_digest(expected_password, password)

    if create_protocol is None:
        create_protocol = BasicAuthWebSocketServerProtocol

    return functools.partial(
        create_protocol,
        realm=realm,
        check_credentials=check_credentials,
    )
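A minimal usage sketch for the factory above, assuming a coroutine-backed credential check; the `handler` coroutine and the `verify_user` call are illustrative placeholders, not part of websockets:

    import asyncio
    import websockets

    async def check_credentials(username: str, password: str) -> bool:
        # Placeholder: defer to an application-defined verification coroutine.
        return await verify_user(username, password)

    async def main() -> None:
        async with websockets.serve(
            handler,  # placeholder connection handler
            "localhost",
            8765,
            create_protocol=websockets.basic_auth_protocol_factory(
                realm="my dev server",
                check_credentials=check_credentials,
            ),
        ):
            await asyncio.Future()  # run until cancelled

Passing `credentials=("hello", "iloveyou")` instead of `check_credentials`, as in the docstring, hard-codes a single authorized pair.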
@@ -0,0 +1,705 @@
from __future__ import annotations

import asyncio
import functools
import logging
import random
import urllib.parse
import warnings
from types import TracebackType
from typing import (
    Any,
    AsyncIterator,
    Callable,
    Generator,
    List,
    Optional,
    Sequence,
    Tuple,
    Type,
    cast,
)

from ..datastructures import Headers, HeadersLike
from ..exceptions import (
    InvalidHandshake,
    InvalidHeader,
    InvalidMessage,
    InvalidStatusCode,
    NegotiationError,
    RedirectHandshake,
    SecurityError,
)
from ..extensions import ClientExtensionFactory, Extension
from ..extensions.permessage_deflate import enable_client_permessage_deflate
from ..headers import (
    build_authorization_basic,
    build_extension,
    build_host,
    build_subprotocol,
    parse_extension,
    parse_subprotocol,
    validate_subprotocols,
)
from ..http import USER_AGENT
from ..typing import ExtensionHeader, LoggerLike, Origin, Subprotocol
from ..uri import WebSocketURI, parse_uri
from .compatibility import asyncio_timeout
from .handshake import build_request, check_response
from .http import read_response
from .protocol import WebSocketCommonProtocol


__all__ = ["connect", "unix_connect", "WebSocketClientProtocol"]


class WebSocketClientProtocol(WebSocketCommonProtocol):
    """
    WebSocket client connection.

    :class:`WebSocketClientProtocol` provides :meth:`recv` and :meth:`send`
    coroutines for receiving and sending messages.

    It supports asynchronous iteration to receive incoming messages::

        async for message in websocket:
            await process(message)

    The iterator exits normally when the connection is closed with close code
    1000 (OK) or 1001 (going away) or without a close code. It raises
    a :exc:`~websockets.exceptions.ConnectionClosedError` when the connection
    is closed with any other code.

    See :func:`connect` for the documentation of ``logger``, ``origin``,
    ``extensions``, ``subprotocols``, ``extra_headers``, and
    ``user_agent_header``.

    See :class:`~websockets.legacy.protocol.WebSocketCommonProtocol` for the
    documentation of ``ping_interval``, ``ping_timeout``, ``close_timeout``,
    ``max_size``, ``max_queue``, ``read_limit``, and ``write_limit``.

    """

    is_client = True
    side = "client"

    def __init__(
        self,
        *,
        logger: Optional[LoggerLike] = None,
        origin: Optional[Origin] = None,
        extensions: Optional[Sequence[ClientExtensionFactory]] = None,
        subprotocols: Optional[Sequence[Subprotocol]] = None,
        extra_headers: Optional[HeadersLike] = None,
        user_agent_header: Optional[str] = USER_AGENT,
        **kwargs: Any,
    ) -> None:
        if logger is None:
            logger = logging.getLogger("websockets.client")
        super().__init__(logger=logger, **kwargs)
        self.origin = origin
        self.available_extensions = extensions
        self.available_subprotocols = subprotocols
        self.extra_headers = extra_headers
        self.user_agent_header = user_agent_header

    def write_http_request(self, path: str, headers: Headers) -> None:
        """
        Write request line and headers to the HTTP request.

        """
        self.path = path
        self.request_headers = headers

        if self.debug:
            self.logger.debug("> GET %s HTTP/1.1", path)
            for key, value in headers.raw_items():
                self.logger.debug("> %s: %s", key, value)

        # Since the path and headers only contain ASCII characters,
        # we can keep this simple.
        request = f"GET {path} HTTP/1.1\r\n"
        request += str(headers)

        self.transport.write(request.encode())

    async def read_http_response(self) -> Tuple[int, Headers]:
        """
        Read status line and headers from the HTTP response.

        If the response contains a body, it may be read from ``self.reader``
        after this coroutine returns.

        Raises:
            InvalidMessage: If the HTTP message is malformed or isn't an
                HTTP/1.1 GET response.

        """
        try:
            status_code, reason, headers = await read_response(self.reader)
        except Exception as exc:
            raise InvalidMessage("did not receive a valid HTTP response") from exc

        if self.debug:
            self.logger.debug("< HTTP/1.1 %d %s", status_code, reason)
            for key, value in headers.raw_items():
                self.logger.debug("< %s: %s", key, value)

        self.response_headers = headers

        return status_code, self.response_headers

    @staticmethod
    def process_extensions(
        headers: Headers,
        available_extensions: Optional[Sequence[ClientExtensionFactory]],
    ) -> List[Extension]:
        """
        Handle the Sec-WebSocket-Extensions HTTP response header.

        Check that each extension is supported, as well as its parameters.

        Return the list of accepted extensions.

        Raise :exc:`~websockets.exceptions.InvalidHandshake` to abort the
        connection.

        :rfc:`6455` leaves the rules up to the specification of each
        extension.

        To provide this level of flexibility, for each extension accepted by
        the server, we check for a match with each extension available in the
        client configuration. If no match is found, an exception is raised.

        If several variants of the same extension are accepted by the server,
        it may be configured several times, which won't make sense in general.
        Extensions must implement their own requirements. For this purpose,
        the list of previously accepted extensions is provided.

        Other requirements, for example related to mandatory extensions or the
        order of extensions, may be implemented by overriding this method.

        """
        accepted_extensions: List[Extension] = []

        header_values = headers.get_all("Sec-WebSocket-Extensions")

        if header_values:
            if available_extensions is None:
                raise InvalidHandshake("no extensions supported")

            parsed_header_values: List[ExtensionHeader] = sum(
                [parse_extension(header_value) for header_value in header_values], []
            )

            for name, response_params in parsed_header_values:
                for extension_factory in available_extensions:
                    # Skip non-matching extensions based on their name.
                    if extension_factory.name != name:
                        continue

                    # Skip non-matching extensions based on their params.
                    try:
                        extension = extension_factory.process_response_params(
                            response_params, accepted_extensions
                        )
                    except NegotiationError:
                        continue

                    # Add matching extension to the final list.
                    accepted_extensions.append(extension)

                    # Break out of the loop once we have a match.
                    break

                # If we didn't break from the loop, no extension in our list
                # matched what the server sent. Fail the connection.
                else:
                    raise NegotiationError(
                        f"Unsupported extension: "
                        f"name = {name}, params = {response_params}"
                    )

        return accepted_extensions

    @staticmethod
    def process_subprotocol(
        headers: Headers, available_subprotocols: Optional[Sequence[Subprotocol]]
    ) -> Optional[Subprotocol]:
        """
        Handle the Sec-WebSocket-Protocol HTTP response header.

        Check that it contains exactly one supported subprotocol.

        Return the selected subprotocol.

        """
        subprotocol: Optional[Subprotocol] = None

        header_values = headers.get_all("Sec-WebSocket-Protocol")

        if header_values:
            if available_subprotocols is None:
                raise InvalidHandshake("no subprotocols supported")

            parsed_header_values: Sequence[Subprotocol] = sum(
                [parse_subprotocol(header_value) for header_value in header_values], []
            )

            if len(parsed_header_values) > 1:
                subprotocols = ", ".join(parsed_header_values)
                raise InvalidHandshake(f"multiple subprotocols: {subprotocols}")

            subprotocol = parsed_header_values[0]

            if subprotocol not in available_subprotocols:
                raise NegotiationError(f"unsupported subprotocol: {subprotocol}")

        return subprotocol

    async def handshake(
        self,
        wsuri: WebSocketURI,
        origin: Optional[Origin] = None,
        available_extensions: Optional[Sequence[ClientExtensionFactory]] = None,
        available_subprotocols: Optional[Sequence[Subprotocol]] = None,
        extra_headers: Optional[HeadersLike] = None,
    ) -> None:
        """
        Perform the client side of the opening handshake.

        Args:
            wsuri: URI of the WebSocket server.
            origin: Value of the ``Origin`` header.
            extensions: List of supported extensions, in order in which they
                should be negotiated and run.
            subprotocols: List of supported subprotocols, in order of
                decreasing preference.
            extra_headers: Arbitrary HTTP headers to add to the handshake
                request.

        Raises:
            InvalidHandshake: If the handshake fails.

        """
        request_headers = Headers()

        request_headers["Host"] = build_host(wsuri.host, wsuri.port, wsuri.secure)

        if wsuri.user_info:
            request_headers["Authorization"] = build_authorization_basic(
                *wsuri.user_info
            )

        if origin is not None:
            request_headers["Origin"] = origin

        key = build_request(request_headers)

        if available_extensions is not None:
            extensions_header = build_extension(
                [
                    (extension_factory.name, extension_factory.get_request_params())
                    for extension_factory in available_extensions
                ]
            )
            request_headers["Sec-WebSocket-Extensions"] = extensions_header

        if available_subprotocols is not None:
            protocol_header = build_subprotocol(available_subprotocols)
            request_headers["Sec-WebSocket-Protocol"] = protocol_header

        if self.extra_headers is not None:
            request_headers.update(self.extra_headers)

        if self.user_agent_header is not None:
            request_headers.setdefault("User-Agent", self.user_agent_header)

        self.write_http_request(wsuri.resource_name, request_headers)

        status_code, response_headers = await self.read_http_response()
        if status_code in (301, 302, 303, 307, 308):
            if "Location" not in response_headers:
                raise InvalidHeader("Location")
            raise RedirectHandshake(response_headers["Location"])
        elif status_code != 101:
            raise InvalidStatusCode(status_code, response_headers)

        check_response(response_headers, key)

        self.extensions = self.process_extensions(
            response_headers, available_extensions
        )

        self.subprotocol = self.process_subprotocol(
            response_headers, available_subprotocols
        )

        self.connection_open()


class Connect:
    """
    Connect to the WebSocket server at ``uri``.

    Awaiting :func:`connect` yields a :class:`WebSocketClientProtocol` which
    can then be used to send and receive messages.

    :func:`connect` can be used as an asynchronous context manager::

        async with websockets.connect(...) as websocket:
            ...

    The connection is closed automatically when exiting the context.

    :func:`connect` can be used as an infinite asynchronous iterator to
    reconnect automatically on errors::

        async for websocket in websockets.connect(...):
            try:
                ...
            except websockets.ConnectionClosed:
                continue

    The connection is closed automatically after each iteration of the loop.

    If an error occurs while establishing the connection, :func:`connect`
    retries with exponential backoff. The backoff delay starts at three
    seconds and increases up to one minute.

    If an error occurs in the body of the loop, you can handle the exception
    and :func:`connect` will reconnect with the next iteration; or you can
    let the exception bubble up and break out of the loop. This lets you
    decide which errors trigger a reconnection and which errors are fatal.

    Args:
        uri: URI of the WebSocket server.
        create_protocol: Factory for the :class:`asyncio.Protocol` managing
            the connection. It defaults to :class:`WebSocketClientProtocol`.
            Set it to a wrapper or a subclass to customize connection handling.
        logger: Logger for this client.
            It defaults to ``logging.getLogger("websockets.client")``.
            See the :doc:`logging guide <../../topics/logging>` for details.
        compression: The "permessage-deflate" extension is enabled by default.
            Set ``compression`` to :obj:`None` to disable it. See the
            :doc:`compression guide <../../topics/compression>` for details.
        origin: Value of the ``Origin`` header, for servers that require it.
        extensions: List of supported extensions, in order in which they
            should be negotiated and run.
        subprotocols: List of supported subprotocols, in order of decreasing
            preference.
        extra_headers: Arbitrary HTTP headers to add to the handshake request.
        user_agent_header: Value of the ``User-Agent`` request header.
            It defaults to ``"Python/x.y.z websockets/X.Y"``.
            Setting it to :obj:`None` removes the header.
        open_timeout: Timeout for opening the connection in seconds.
            :obj:`None` disables the timeout.

    See :class:`~websockets.legacy.protocol.WebSocketCommonProtocol` for the
    documentation of ``ping_interval``, ``ping_timeout``, ``close_timeout``,
    ``max_size``, ``max_queue``, ``read_limit``, and ``write_limit``.

    Any other keyword arguments are passed to the event loop's
    :meth:`~asyncio.loop.create_connection` method.

    For example:

    * You can set ``ssl`` to a :class:`~ssl.SSLContext` to enforce TLS
      settings. When connecting to a ``wss://`` URI, if ``ssl`` isn't
      provided, a TLS context is created
      with :func:`~ssl.create_default_context`.

    * You can set ``host`` and ``port`` to connect to a different host and
      port from those found in ``uri``. This only changes the destination of
      the TCP connection. The host name from ``uri`` is still used in the TLS
      handshake for secure connections and in the ``Host`` header.

    Raises:
        InvalidURI: If ``uri`` isn't a valid WebSocket URI.
        OSError: If the TCP connection fails.
        InvalidHandshake: If the opening handshake fails.
        ~asyncio.TimeoutError: If the opening handshake times out.

    """

    MAX_REDIRECTS_ALLOWED = 10

    def __init__(
        self,
        uri: str,
        *,
        create_protocol: Optional[Callable[..., WebSocketClientProtocol]] = None,
        logger: Optional[LoggerLike] = None,
        compression: Optional[str] = "deflate",
        origin: Optional[Origin] = None,
        extensions: Optional[Sequence[ClientExtensionFactory]] = None,
        subprotocols: Optional[Sequence[Subprotocol]] = None,
        extra_headers: Optional[HeadersLike] = None,
        user_agent_header: Optional[str] = USER_AGENT,
        open_timeout: Optional[float] = 10,
        ping_interval: Optional[float] = 20,
        ping_timeout: Optional[float] = 20,
        close_timeout: Optional[float] = None,
        max_size: Optional[int] = 2**20,
        max_queue: Optional[int] = 2**5,
        read_limit: int = 2**16,
        write_limit: int = 2**16,
        **kwargs: Any,
    ) -> None:
        # Backwards compatibility: close_timeout used to be called timeout.
        timeout: Optional[float] = kwargs.pop("timeout", None)
        if timeout is None:
            timeout = 10
        else:
            warnings.warn("rename timeout to close_timeout", DeprecationWarning)
        # If both are specified, timeout is ignored.
        if close_timeout is None:
            close_timeout = timeout

        # Backwards compatibility: create_protocol used to be called klass.
        klass: Optional[Type[WebSocketClientProtocol]] = kwargs.pop("klass", None)
        if klass is None:
            klass = WebSocketClientProtocol
        else:
            warnings.warn("rename klass to create_protocol", DeprecationWarning)
        # If both are specified, klass is ignored.
        if create_protocol is None:
            create_protocol = klass

        # Backwards compatibility: recv() used to return None on closed connections
        legacy_recv: bool = kwargs.pop("legacy_recv", False)

        # Backwards compatibility: the loop parameter used to be supported.
        _loop: Optional[asyncio.AbstractEventLoop] = kwargs.pop("loop", None)
        if _loop is None:
            loop = asyncio.get_event_loop()
        else:
            loop = _loop
            warnings.warn("remove loop argument", DeprecationWarning)

        wsuri = parse_uri(uri)
        if wsuri.secure:
            kwargs.setdefault("ssl", True)
        elif kwargs.get("ssl") is not None:
            raise ValueError(
                "connect() received a ssl argument for a ws:// URI, "
                "use a wss:// URI to enable TLS"
            )

        if compression == "deflate":
            extensions = enable_client_permessage_deflate(extensions)
        elif compression is not None:
            raise ValueError(f"unsupported compression: {compression}")

        if subprotocols is not None:
            validate_subprotocols(subprotocols)

        factory = functools.partial(
            create_protocol,
            logger=logger,
            origin=origin,
            extensions=extensions,
            subprotocols=subprotocols,
            extra_headers=extra_headers,
            user_agent_header=user_agent_header,
            ping_interval=ping_interval,
            ping_timeout=ping_timeout,
            close_timeout=close_timeout,
            max_size=max_size,
            max_queue=max_queue,
            read_limit=read_limit,
            write_limit=write_limit,
            host=wsuri.host,
            port=wsuri.port,
            secure=wsuri.secure,
            legacy_recv=legacy_recv,
            loop=_loop,
        )

        if kwargs.pop("unix", False):
            path: Optional[str] = kwargs.pop("path", None)
            create_connection = functools.partial(
                loop.create_unix_connection, factory, path, **kwargs
            )
        else:
            host: Optional[str]
            port: Optional[int]
            if kwargs.get("sock") is None:
                host, port = wsuri.host, wsuri.port
            else:
                # If sock is given, host and port shouldn't be specified.
                host, port = None, None
                if kwargs.get("ssl"):
                    kwargs.setdefault("server_hostname", wsuri.host)
            # If host and port are given, override values from the URI.
            host = kwargs.pop("host", host)
            port = kwargs.pop("port", port)
            create_connection = functools.partial(
                loop.create_connection, factory, host, port, **kwargs
            )

        self.open_timeout = open_timeout
        if logger is None:
            logger = logging.getLogger("websockets.client")
        self.logger = logger

        # This is a coroutine function.
        self._create_connection = create_connection
        self._uri = uri
        self._wsuri = wsuri

    def handle_redirect(self, uri: str) -> None:
        # Update the state of this instance to connect to a new URI.
        old_uri = self._uri
        old_wsuri = self._wsuri
        new_uri = urllib.parse.urljoin(old_uri, uri)
        new_wsuri = parse_uri(new_uri)

        # Forbid TLS downgrade.
        if old_wsuri.secure and not new_wsuri.secure:
            raise SecurityError("redirect from WSS to WS")

        same_origin = (
            old_wsuri.host == new_wsuri.host and old_wsuri.port == new_wsuri.port
        )

        # Rewrite the host and port arguments for cross-origin redirects.
        # This preserves connection overrides with the host and port
        # arguments if the redirect points to the same host and port.
        if not same_origin:
            # Replace the host and port argument passed to the protocol factory.
            factory = self._create_connection.args[0]
            factory = functools.partial(
                factory.func,
                *factory.args,
                **dict(factory.keywords, host=new_wsuri.host, port=new_wsuri.port),
            )
            # Replace the host and port argument passed to create_connection.
            self._create_connection = functools.partial(
                self._create_connection.func,
                *(factory, new_wsuri.host, new_wsuri.port),
                **self._create_connection.keywords,
            )

        # Set the new WebSocket URI. This suffices for same-origin redirects.
        self._uri = new_uri
        self._wsuri = new_wsuri

    # async for ... in connect(...):

    BACKOFF_MIN = 1.92
    BACKOFF_MAX = 60.0
    BACKOFF_FACTOR = 1.618
    BACKOFF_INITIAL = 5

    async def __aiter__(self) -> AsyncIterator[WebSocketClientProtocol]:
        backoff_delay = self.BACKOFF_MIN
        while True:
            try:
                async with self as protocol:
                    yield protocol
            except Exception:
                # Add a random initial delay between 0 and 5 seconds.
                # See 7.2.3. Recovering from Abnormal Closure in RFC 6455.
                if backoff_delay == self.BACKOFF_MIN:
                    initial_delay = random.random() * self.BACKOFF_INITIAL
                    self.logger.info(
                        "! connect failed; reconnecting in %.1f seconds",
                        initial_delay,
                        exc_info=True,
                    )
                    await asyncio.sleep(initial_delay)
                else:
                    self.logger.info(
                        "! connect failed again; retrying in %d seconds",
                        int(backoff_delay),
                        exc_info=True,
                    )
                    await asyncio.sleep(int(backoff_delay))
                # Increase delay with truncated exponential backoff.
                backoff_delay = backoff_delay * self.BACKOFF_FACTOR
                backoff_delay = min(backoff_delay, self.BACKOFF_MAX)
                continue
            else:
                # Connection succeeded - reset backoff delay
                backoff_delay = self.BACKOFF_MIN

    # async with connect(...) as ...:

    async def __aenter__(self) -> WebSocketClientProtocol:
        return await self

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_value: Optional[BaseException],
        traceback: Optional[TracebackType],
    ) -> None:
        await self.protocol.close()

    # ... = await connect(...)

    def __await__(self) -> Generator[Any, None, WebSocketClientProtocol]:
        # Create a suitable iterator by calling __await__ on a coroutine.
        return self.__await_impl_timeout__().__await__()

    async def __await_impl_timeout__(self) -> WebSocketClientProtocol:
        async with asyncio_timeout(self.open_timeout):
            return await self.__await_impl__()

    async def __await_impl__(self) -> WebSocketClientProtocol:
        for redirects in range(self.MAX_REDIRECTS_ALLOWED):
            _transport, _protocol = await self._create_connection()
            protocol = cast(WebSocketClientProtocol, _protocol)
            try:
                await protocol.handshake(
                    self._wsuri,
                    origin=protocol.origin,
                    available_extensions=protocol.available_extensions,
                    available_subprotocols=protocol.available_subprotocols,
                    extra_headers=protocol.extra_headers,
                )
            except RedirectHandshake as exc:
                protocol.fail_connection()
                await protocol.wait_closed()
                self.handle_redirect(exc.uri)
            # Avoid leaking a connected socket when the handshake fails.
            except (Exception, asyncio.CancelledError):
                protocol.fail_connection()
                await protocol.wait_closed()
                raise
            else:
                self.protocol = protocol
                return protocol
        else:
            raise SecurityError("too many redirects")

    # ... = yield from connect(...) - remove when dropping Python < 3.10

    __iter__ = __await__


connect = Connect


def unix_connect(
    path: Optional[str] = None,
    uri: str = "ws://localhost/",
    **kwargs: Any,
) -> Connect:
    """
    Similar to :func:`connect`, but for connecting to a Unix socket.

    This function builds upon the event loop's
    :meth:`~asyncio.loop.create_unix_connection` method.

    It is only available on Unix.

    It's mainly useful for debugging servers listening on Unix sockets.

    Args:
        path: File system path to the Unix socket.
        uri: URI of the WebSocket server; the host is used in the TLS
            handshake for secure connections and in the ``Host`` header.

    """
    return connect(uri=uri, path=path, unix=True, **kwargs)
@@ -0,0 +1,12 @@
from __future__ import annotations

import sys


__all__ = ["asyncio_timeout"]


if sys.version_info[:2] >= (3, 11):
    from asyncio import timeout as asyncio_timeout  # noqa: F401
else:
    from .async_timeout import timeout as asyncio_timeout  # noqa: F401
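A short sketch of how the shim above is consumed, assuming the module is importable as websockets.legacy.compatibility (consistent with the site-packages paths elsewhere in this diff) and that `fetch()` is some awaitable defined by the caller; on Python 3.11+ the name resolves to asyncio.timeout, otherwise to the vendored backport, and the call site is identical either way:

    from websockets.legacy.compatibility import asyncio_timeout

    async def bounded_fetch() -> bytes:
        # Cancel the enclosed await and raise a TimeoutError after 5 seconds.
        async with asyncio_timeout(5):
            return await fetch()  # placeholder coroutine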
@@ -0,0 +1,176 @@
from __future__ import annotations

import struct
from typing import Any, Awaitable, Callable, NamedTuple, Optional, Sequence, Tuple

from .. import extensions, frames
from ..exceptions import PayloadTooBig, ProtocolError


try:
    from ..speedups import apply_mask
except ImportError:
    from ..utils import apply_mask


class Frame(NamedTuple):
    fin: bool
    opcode: frames.Opcode
    data: bytes
    rsv1: bool = False
    rsv2: bool = False
    rsv3: bool = False

    @property
    def new_frame(self) -> frames.Frame:
        return frames.Frame(
            self.opcode,
            self.data,
            self.fin,
            self.rsv1,
            self.rsv2,
            self.rsv3,
        )

    def __str__(self) -> str:
        return str(self.new_frame)

    def check(self) -> None:
        return self.new_frame.check()

    @classmethod
    async def read(
        cls,
        reader: Callable[[int], Awaitable[bytes]],
        *,
        mask: bool,
        max_size: Optional[int] = None,
        extensions: Optional[Sequence[extensions.Extension]] = None,
    ) -> Frame:
        """
        Read a WebSocket frame.

        Args:
            reader: Coroutine that reads exactly the requested number of
                bytes, unless the end of file is reached.
            mask: Whether the frame should be masked i.e. whether the read
                happens on the server side.
            max_size: Maximum payload size in bytes.
            extensions: List of extensions, applied in reverse order.

        Raises:
            PayloadTooBig: If the frame exceeds ``max_size``.
            ProtocolError: If the frame contains incorrect values.

        """

        # Read the header.
        data = await reader(2)
        head1, head2 = struct.unpack("!BB", data)

        # While not Pythonic, this is marginally faster than calling bool().
        fin = True if head1 & 0b10000000 else False
        rsv1 = True if head1 & 0b01000000 else False
        rsv2 = True if head1 & 0b00100000 else False
        rsv3 = True if head1 & 0b00010000 else False

        try:
            opcode = frames.Opcode(head1 & 0b00001111)
        except ValueError as exc:
            raise ProtocolError("invalid opcode") from exc

        if (True if head2 & 0b10000000 else False) != mask:
            raise ProtocolError("incorrect masking")

        length = head2 & 0b01111111
        if length == 126:
            data = await reader(2)
            (length,) = struct.unpack("!H", data)
        elif length == 127:
            data = await reader(8)
            (length,) = struct.unpack("!Q", data)
        if max_size is not None and length > max_size:
            raise PayloadTooBig(f"over size limit ({length} > {max_size} bytes)")
        if mask:
            mask_bits = await reader(4)

        # Read the data.
        data = await reader(length)
        if mask:
            data = apply_mask(data, mask_bits)

        new_frame = frames.Frame(opcode, data, fin, rsv1, rsv2, rsv3)

        if extensions is None:
            extensions = []
        for extension in reversed(extensions):
            new_frame = extension.decode(new_frame, max_size=max_size)

        new_frame.check()

        return cls(
            new_frame.fin,
            new_frame.opcode,
            new_frame.data,
            new_frame.rsv1,
            new_frame.rsv2,
            new_frame.rsv3,
        )

    def write(
        self,
        write: Callable[[bytes], Any],
        *,
        mask: bool,
        extensions: Optional[Sequence[extensions.Extension]] = None,
    ) -> None:
        """
        Write a WebSocket frame.

        Args:
            frame: Frame to write.
            write: Function that writes bytes.
            mask: Whether the frame should be masked i.e. whether the write
                happens on the client side.
            extensions: List of extensions, applied in order.

        Raises:
            ProtocolError: If the frame contains incorrect values.

        """
        # The frame is written in a single call to write in order to prevent
        # TCP fragmentation. See #68 for details. This also makes it safe to
        # send frames concurrently from multiple coroutines.
        write(self.new_frame.serialize(mask=mask, extensions=extensions))


# Backwards compatibility with previously documented public APIs
from ..frames import (  # noqa: E402, F401, I001
    Close,
    prepare_ctrl as encode_data,
    prepare_data,
)


def parse_close(data: bytes) -> Tuple[int, str]:
    """
    Parse the payload from a close frame.

    Returns:
        Close code and reason.

    Raises:
        ProtocolError: If data is ill-formed.
        UnicodeDecodeError: If the reason isn't valid UTF-8.

    """
    close = Close.parse(data)
    return close.code, close.reason


def serialize_close(code: int, reason: str) -> bytes:
    """
    Serialize the payload for a close frame.

    """
    return Close(code, reason).serialize()
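A client-side sketch of the Frame helpers above, assuming `reader` is an asyncio.StreamReader and `transport` an asyncio transport for an already established connection (both names are illustrative, not part of the module):

    # A client reads server frames unmasked and masks the frames it sends.
    incoming = await Frame.read(reader.readexactly, mask=False, max_size=2**20)
    Frame(True, frames.Opcode.TEXT, b"hello").write(transport.write, mask=True)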
@@ -0,0 +1,165 @@
from __future__ import annotations

import base64
import binascii
from typing import List

from ..datastructures import Headers, MultipleValuesError
from ..exceptions import InvalidHeader, InvalidHeaderValue, InvalidUpgrade
from ..headers import parse_connection, parse_upgrade
from ..typing import ConnectionOption, UpgradeProtocol
from ..utils import accept_key as accept, generate_key


__all__ = ["build_request", "check_request", "build_response", "check_response"]


def build_request(headers: Headers) -> str:
    """
    Build a handshake request to send to the server.

    Update request headers passed in argument.

    Args:
        headers: Handshake request headers.

    Returns:
        str: ``key`` that must be passed to :func:`check_response`.

    """
    key = generate_key()
    headers["Upgrade"] = "websocket"
    headers["Connection"] = "Upgrade"
    headers["Sec-WebSocket-Key"] = key
    headers["Sec-WebSocket-Version"] = "13"
    return key


def check_request(headers: Headers) -> str:
    """
    Check a handshake request received from the client.

    This function doesn't verify that the request is an HTTP/1.1 or higher GET
    request and doesn't perform ``Host`` and ``Origin`` checks. These controls
    are usually performed earlier in the HTTP request handling code. They're
    the responsibility of the caller.

    Args:
        headers: Handshake request headers.

    Returns:
        str: ``key`` that must be passed to :func:`build_response`.

    Raises:
        InvalidHandshake: If the handshake request is invalid.
            Then, the server must return a 400 Bad Request error.

    """
    connection: List[ConnectionOption] = sum(
        [parse_connection(value) for value in headers.get_all("Connection")], []
    )

    if not any(value.lower() == "upgrade" for value in connection):
        raise InvalidUpgrade("Connection", ", ".join(connection))

    upgrade: List[UpgradeProtocol] = sum(
        [parse_upgrade(value) for value in headers.get_all("Upgrade")], []
    )

    # For compatibility with non-strict implementations, ignore case when
    # checking the Upgrade header. The RFC always uses "websocket", except
    # in section 11.2. (IANA registration) where it uses "WebSocket".
    if not (len(upgrade) == 1 and upgrade[0].lower() == "websocket"):
        raise InvalidUpgrade("Upgrade", ", ".join(upgrade))

    try:
        s_w_key = headers["Sec-WebSocket-Key"]
    except KeyError as exc:
        raise InvalidHeader("Sec-WebSocket-Key") from exc
    except MultipleValuesError as exc:
        raise InvalidHeader(
            "Sec-WebSocket-Key", "more than one Sec-WebSocket-Key header found"
        ) from exc

    try:
        raw_key = base64.b64decode(s_w_key.encode(), validate=True)
    except binascii.Error as exc:
        raise InvalidHeaderValue("Sec-WebSocket-Key", s_w_key) from exc
    if len(raw_key) != 16:
        raise InvalidHeaderValue("Sec-WebSocket-Key", s_w_key)

    try:
        s_w_version = headers["Sec-WebSocket-Version"]
    except KeyError as exc:
        raise InvalidHeader("Sec-WebSocket-Version") from exc
    except MultipleValuesError as exc:
        raise InvalidHeader(
            "Sec-WebSocket-Version", "more than one Sec-WebSocket-Version header found"
        ) from exc

    if s_w_version != "13":
        raise InvalidHeaderValue("Sec-WebSocket-Version", s_w_version)

    return s_w_key


def build_response(headers: Headers, key: str) -> None:
    """
    Build a handshake response to send to the client.

    Update response headers passed in argument.

    Args:
        headers: Handshake response headers.
        key: Returned by :func:`check_request`.

    """
    headers["Upgrade"] = "websocket"
    headers["Connection"] = "Upgrade"
    headers["Sec-WebSocket-Accept"] = accept(key)


def check_response(headers: Headers, key: str) -> None:
    """
    Check a handshake response received from the server.

    This function doesn't verify that the response is an HTTP/1.1 or higher
    response with a 101 status code. These controls are the responsibility of
    the caller.

    Args:
        headers: Handshake response headers.
        key: Returned by :func:`build_request`.

    Raises:
        InvalidHandshake: If the handshake response is invalid.

    """
    connection: List[ConnectionOption] = sum(
        [parse_connection(value) for value in headers.get_all("Connection")], []
    )

    if not any(value.lower() == "upgrade" for value in connection):
        raise InvalidUpgrade("Connection", " ".join(connection))

    upgrade: List[UpgradeProtocol] = sum(
        [parse_upgrade(value) for value in headers.get_all("Upgrade")], []
    )

    # For compatibility with non-strict implementations, ignore case when
    # checking the Upgrade header. The RFC always uses "websocket", except
    # in section 11.2. (IANA registration) where it uses "WebSocket".
    if not (len(upgrade) == 1 and upgrade[0].lower() == "websocket"):
        raise InvalidUpgrade("Upgrade", ", ".join(upgrade))

    try:
        s_w_accept = headers["Sec-WebSocket-Accept"]
    except KeyError as exc:
        raise InvalidHeader("Sec-WebSocket-Accept") from exc
    except MultipleValuesError as exc:
        raise InvalidHeader(
            "Sec-WebSocket-Accept", "more than one Sec-WebSocket-Accept header found"
        ) from exc

    if s_w_accept != accept(key):
        raise InvalidHeaderValue("Sec-WebSocket-Accept", s_w_accept)
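A round-trip sketch of the four helpers above, exercising both sides of the opening handshake in one place (in real use the two Headers objects travel over HTTP); it assumes Headers and the helpers are imported as in the module above:

    request = Headers()
    key = build_request(request)              # client: sets Upgrade/Key/Version headers

    response = Headers()
    build_response(response, check_request(request))   # server: echoes the accepted key

    check_response(response, key)             # client: verifies Sec-WebSocket-Accept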
@@ -0,0 +1,201 @@
from __future__ import annotations

import asyncio
import re
from typing import Tuple

from ..datastructures import Headers
from ..exceptions import SecurityError


__all__ = ["read_request", "read_response"]

MAX_HEADERS = 128
MAX_LINE = 8192


def d(value: bytes) -> str:
    """
    Decode a bytestring for interpolating into an error message.

    """
    return value.decode(errors="backslashreplace")


# See https://www.rfc-editor.org/rfc/rfc7230.html#appendix-B.

# Regex for validating header names.

_token_re = re.compile(rb"[-!#$%&\'*+.^_`|~0-9a-zA-Z]+")

# Regex for validating header values.

# We don't attempt to support obsolete line folding.

# Include HTAB (\x09), SP (\x20), VCHAR (\x21-\x7e), obs-text (\x80-\xff).

# The ABNF is complicated because it attempts to express that optional
# whitespace is ignored. We strip whitespace and don't revalidate that.

# See also https://www.rfc-editor.org/errata_search.php?rfc=7230&eid=4189

_value_re = re.compile(rb"[\x09\x20-\x7e\x80-\xff]*")


async def read_request(stream: asyncio.StreamReader) -> Tuple[str, Headers]:
    """
    Read an HTTP/1.1 GET request and return ``(path, headers)``.

    ``path`` isn't URL-decoded or validated in any way.

    ``path`` and ``headers`` are expected to contain only ASCII characters.
    Other characters are represented with surrogate escapes.

    :func:`read_request` doesn't attempt to read the request body because
    WebSocket handshake requests don't have one. If the request contains a
    body, it may be read from ``stream`` after this coroutine returns.

    Args:
        stream: Input to read the request from.

    Raises:
        EOFError: If the connection is closed without a full HTTP request.
        SecurityError: If the request exceeds a security limit.
        ValueError: If the request isn't well formatted.

    """
    # https://www.rfc-editor.org/rfc/rfc7230.html#section-3.1.1

    # Parsing is simple because fixed values are expected for method and
    # version and because path isn't checked. Since WebSocket software tends
    # to implement HTTP/1.1 strictly, there's little need for lenient parsing.

    try:
        request_line = await read_line(stream)
    except EOFError as exc:
        raise EOFError("connection closed while reading HTTP request line") from exc

    try:
        method, raw_path, version = request_line.split(b" ", 2)
    except ValueError:  # not enough values to unpack (expected 3, got 1-2)
        raise ValueError(f"invalid HTTP request line: {d(request_line)}") from None

    if method != b"GET":
        raise ValueError(f"unsupported HTTP method: {d(method)}")
    if version != b"HTTP/1.1":
        raise ValueError(f"unsupported HTTP version: {d(version)}")
    path = raw_path.decode("ascii", "surrogateescape")

    headers = await read_headers(stream)

    return path, headers


async def read_response(stream: asyncio.StreamReader) -> Tuple[int, str, Headers]:
    """
    Read an HTTP/1.1 response and return ``(status_code, reason, headers)``.

    ``reason`` and ``headers`` are expected to contain only ASCII characters.
    Other characters are represented with surrogate escapes.

    :func:`read_response` doesn't attempt to read the response body because
    WebSocket handshake responses don't have one. If the response contains a
    body, it may be read from ``stream`` after this coroutine returns.

    Args:
        stream: Input to read the response from.

    Raises:
        EOFError: If the connection is closed without a full HTTP response.
        SecurityError: If the response exceeds a security limit.
        ValueError: If the response isn't well formatted.

    """
    # https://www.rfc-editor.org/rfc/rfc7230.html#section-3.1.2

    # As in read_request, parsing is simple because a fixed value is expected
    # for version, status_code is a 3-digit number, and reason can be ignored.

    try:
        status_line = await read_line(stream)
    except EOFError as exc:
        raise EOFError("connection closed while reading HTTP status line") from exc

    try:
        version, raw_status_code, raw_reason = status_line.split(b" ", 2)
    except ValueError:  # not enough values to unpack (expected 3, got 1-2)
        raise ValueError(f"invalid HTTP status line: {d(status_line)}") from None

    if version != b"HTTP/1.1":
        raise ValueError(f"unsupported HTTP version: {d(version)}")
    try:
        status_code = int(raw_status_code)
    except ValueError:  # invalid literal for int() with base 10
        raise ValueError(f"invalid HTTP status code: {d(raw_status_code)}") from None
    if not 100 <= status_code < 1000:
        raise ValueError(f"unsupported HTTP status code: {d(raw_status_code)}")
    if not _value_re.fullmatch(raw_reason):
        raise ValueError(f"invalid HTTP reason phrase: {d(raw_reason)}")
    reason = raw_reason.decode()

    headers = await read_headers(stream)

    return status_code, reason, headers


async def read_headers(stream: asyncio.StreamReader) -> Headers:
    """
    Read HTTP headers from ``stream``.

    Non-ASCII characters are represented with surrogate escapes.

    """
    # https://www.rfc-editor.org/rfc/rfc7230.html#section-3.2

    # We don't attempt to support obsolete line folding.

    headers = Headers()
    for _ in range(MAX_HEADERS + 1):
        try:
            line = await read_line(stream)
        except EOFError as exc:
            raise EOFError("connection closed while reading HTTP headers") from exc
        if line == b"":
            break

        try:
            raw_name, raw_value = line.split(b":", 1)
        except ValueError:  # not enough values to unpack (expected 2, got 1)
            raise ValueError(f"invalid HTTP header line: {d(line)}") from None
        if not _token_re.fullmatch(raw_name):
            raise ValueError(f"invalid HTTP header name: {d(raw_name)}")
        raw_value = raw_value.strip(b" \t")
        if not _value_re.fullmatch(raw_value):
            raise ValueError(f"invalid HTTP header value: {d(raw_value)}")

        name = raw_name.decode("ascii")  # guaranteed to be ASCII at this point
        value = raw_value.decode("ascii", "surrogateescape")
        headers[name] = value

    else:
        raise SecurityError("too many HTTP headers")

    return headers


async def read_line(stream: asyncio.StreamReader) -> bytes:
    """
    Read a single line from ``stream``.

    CRLF is stripped from the return value.

    """
    # Security: this is bounded by the StreamReader's limit (default = 32 KiB).
    line = await stream.readline()
    # Security: this guarantees header values are small (hard-coded = 8 KiB)
    if len(line) > MAX_LINE:
        raise SecurityError("line too long")
    # Not mandatory but safe - https://www.rfc-editor.org/rfc/rfc7230.html#section-3.5
    if not line.endswith(b"\r\n"):
        raise EOFError("line without CRLF")
    return line[:-2]
code/.venv/lib/python3.12/site-packages/websockets/legacy/server.py (new file, 1185 lines): diff suppressed because it is too large.