This commit is contained in:
Jonas Zeunert
2024-08-16 21:57:55 +02:00
parent adeb5c5ec7
commit 4309a2d185
1696 changed files with 279655 additions and 0 deletions

View File

@@ -0,0 +1,328 @@
from __future__ import annotations
import socket
import ssl
import threading
from typing import Any, Optional, Sequence, Type
from ..client import ClientProtocol
from ..datastructures import HeadersLike
from ..extensions.base import ClientExtensionFactory
from ..extensions.permessage_deflate import enable_client_permessage_deflate
from ..headers import validate_subprotocols
from ..http import USER_AGENT
from ..http11 import Response
from ..protocol import CONNECTING, OPEN, Event
from ..typing import LoggerLike, Origin, Subprotocol
from ..uri import parse_uri
from .connection import Connection
from .utils import Deadline
__all__ = ["connect", "unix_connect", "ClientConnection"]
class ClientConnection(Connection):
"""
Threaded implementation of a WebSocket client connection.
:class:`ClientConnection` provides :meth:`recv` and :meth:`send` methods for
receiving and sending messages.
It supports iteration to receive messages::
for message in websocket:
process(message)
The iterator exits normally when the connection is closed with close code
1000 (OK) or 1001 (going away) or without a close code. It raises a
:exc:`~websockets.exceptions.ConnectionClosedError` when the connection is
closed with any other code.
Args:
socket: Socket connected to a WebSocket server.
protocol: Sans-I/O connection.
close_timeout: Timeout for closing the connection in seconds.
"""
def __init__(
self,
socket: socket.socket,
protocol: ClientProtocol,
*,
close_timeout: Optional[float] = 10,
) -> None:
self.protocol: ClientProtocol
self.response_rcvd = threading.Event()
super().__init__(
socket,
protocol,
close_timeout=close_timeout,
)
def handshake(
self,
additional_headers: Optional[HeadersLike] = None,
user_agent_header: Optional[str] = USER_AGENT,
timeout: Optional[float] = None,
) -> None:
"""
Perform the opening handshake.
"""
with self.send_context(expected_state=CONNECTING):
self.request = self.protocol.connect()
if additional_headers is not None:
self.request.headers.update(additional_headers)
if user_agent_header is not None:
self.request.headers["User-Agent"] = user_agent_header
self.protocol.send_request(self.request)
if not self.response_rcvd.wait(timeout):
self.close_socket()
self.recv_events_thread.join()
raise TimeoutError("timed out during handshake")
if self.response is None:
self.close_socket()
self.recv_events_thread.join()
raise ConnectionError("connection closed during handshake")
if self.protocol.state is not OPEN:
self.recv_events_thread.join(self.close_timeout)
self.close_socket()
self.recv_events_thread.join()
if self.protocol.handshake_exc is not None:
raise self.protocol.handshake_exc
def process_event(self, event: Event) -> None:
"""
Process one incoming event.
"""
# First event - handshake response.
if self.response is None:
assert isinstance(event, Response)
self.response = event
self.response_rcvd.set()
# Later events - frames.
else:
super().process_event(event)
def recv_events(self) -> None:
"""
Read incoming data from the socket and process events.
"""
try:
super().recv_events()
finally:
# If the connection is closed during the handshake, unblock it.
self.response_rcvd.set()
def connect(
uri: str,
*,
# TCP/TLS — unix and path are only for unix_connect()
sock: Optional[socket.socket] = None,
ssl_context: Optional[ssl.SSLContext] = None,
server_hostname: Optional[str] = None,
unix: bool = False,
path: Optional[str] = None,
# WebSocket
origin: Optional[Origin] = None,
extensions: Optional[Sequence[ClientExtensionFactory]] = None,
subprotocols: Optional[Sequence[Subprotocol]] = None,
additional_headers: Optional[HeadersLike] = None,
user_agent_header: Optional[str] = USER_AGENT,
compression: Optional[str] = "deflate",
# Timeouts
open_timeout: Optional[float] = 10,
close_timeout: Optional[float] = 10,
# Limits
max_size: Optional[int] = 2**20,
# Logging
logger: Optional[LoggerLike] = None,
# Escape hatch for advanced customization
create_connection: Optional[Type[ClientConnection]] = None,
) -> ClientConnection:
"""
Connect to the WebSocket server at ``uri``.
This function returns a :class:`ClientConnection` instance, which you can
use to send and receive messages.
:func:`connect` may be used as a context manager::
async with websockets.sync.client.connect(...) as websocket:
...
The connection is closed automatically when exiting the context.
Args:
uri: URI of the WebSocket server.
sock: Preexisting TCP socket. ``sock`` overrides the host and port
from ``uri``. You may call :func:`socket.create_connection` to
create a suitable TCP socket.
ssl_context: Configuration for enabling TLS on the connection.
server_hostname: Host name for the TLS handshake. ``server_hostname``
overrides the host name from ``uri``.
origin: Value of the ``Origin`` header, for servers that require it.
extensions: List of supported extensions, in order in which they
should be negotiated and run.
subprotocols: List of supported subprotocols, in order of decreasing
preference.
additional_headers (HeadersLike | None): Arbitrary HTTP headers to add
to the handshake request.
user_agent_header: Value of the ``User-Agent`` request header.
It defaults to ``"Python/x.y.z websockets/X.Y"``.
Setting it to :obj:`None` removes the header.
compression: The "permessage-deflate" extension is enabled by default.
Set ``compression`` to :obj:`None` to disable it. See the
:doc:`compression guide <../../topics/compression>` for details.
open_timeout: Timeout for opening the connection in seconds.
:obj:`None` disables the timeout.
close_timeout: Timeout for closing the connection in seconds.
:obj:`None` disables the timeout.
max_size: Maximum size of incoming messages in bytes.
:obj:`None` disables the limit.
logger: Logger for this client.
It defaults to ``logging.getLogger("websockets.client")``.
See the :doc:`logging guide <../../topics/logging>` for details.
create_connection: Factory for the :class:`ClientConnection` managing
the connection. Set it to a wrapper or a subclass to customize
connection handling.
Raises:
InvalidURI: If ``uri`` isn't a valid WebSocket URI.
OSError: If the TCP connection fails.
InvalidHandshake: If the opening handshake fails.
TimeoutError: If the opening handshake times out.
"""
# Process parameters
wsuri = parse_uri(uri)
if not wsuri.secure and ssl_context is not None:
raise TypeError("ssl_context argument is incompatible with a ws:// URI")
if unix:
if path is None and sock is None:
raise TypeError("missing path argument")
elif path is not None and sock is not None:
raise TypeError("path and sock arguments are incompatible")
else:
assert path is None # private argument, only set by unix_connect()
if subprotocols is not None:
validate_subprotocols(subprotocols)
if compression == "deflate":
extensions = enable_client_permessage_deflate(extensions)
elif compression is not None:
raise ValueError(f"unsupported compression: {compression}")
# Calculate timeouts on the TCP, TLS, and WebSocket handshakes.
# The TCP and TLS timeouts must be set on the socket, then removed
# to avoid conflicting with the WebSocket timeout in handshake().
deadline = Deadline(open_timeout)
if create_connection is None:
create_connection = ClientConnection
try:
# Connect socket
if sock is None:
if unix:
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
sock.settimeout(deadline.timeout())
assert path is not None # validated above -- this is for mpypy
sock.connect(path)
else:
sock = socket.create_connection(
(wsuri.host, wsuri.port),
deadline.timeout(),
)
sock.settimeout(None)
# Disable Nagle algorithm
if not unix:
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, True)
# Initialize TLS wrapper and perform TLS handshake
if wsuri.secure:
if ssl_context is None:
ssl_context = ssl.create_default_context()
if server_hostname is None:
server_hostname = wsuri.host
sock.settimeout(deadline.timeout())
sock = ssl_context.wrap_socket(sock, server_hostname=server_hostname)
sock.settimeout(None)
# Initialize WebSocket connection
protocol = ClientProtocol(
wsuri,
origin=origin,
extensions=extensions,
subprotocols=subprotocols,
state=CONNECTING,
max_size=max_size,
logger=logger,
)
# Initialize WebSocket protocol
connection = create_connection(
sock,
protocol,
close_timeout=close_timeout,
)
# On failure, handshake() closes the socket and raises an exception.
connection.handshake(
additional_headers,
user_agent_header,
deadline.timeout(),
)
except Exception:
if sock is not None:
sock.close()
raise
return connection
def unix_connect(
path: Optional[str] = None,
uri: Optional[str] = None,
**kwargs: Any,
) -> ClientConnection:
"""
Connect to a WebSocket server listening on a Unix socket.
This function is identical to :func:`connect`, except for the additional
``path`` argument. It's only available on Unix.
It's mainly useful for debugging servers listening on Unix sockets.
Args:
path: File system path to the Unix socket.
uri: URI of the WebSocket server. ``uri`` defaults to
``ws://localhost/`` or, when a ``ssl_context`` is provided, to
``wss://localhost/``.
"""
if uri is None:
if kwargs.get("ssl_context") is None:
uri = "ws://localhost/"
else:
uri = "wss://localhost/"
return connect(uri=uri, unix=True, path=path, **kwargs)

View File

@@ -0,0 +1,773 @@
from __future__ import annotations
import contextlib
import logging
import random
import socket
import struct
import threading
import uuid
from types import TracebackType
from typing import Any, Dict, Iterable, Iterator, Mapping, Optional, Type, Union
from ..exceptions import ConnectionClosed, ConnectionClosedOK, ProtocolError
from ..frames import DATA_OPCODES, BytesLike, CloseCode, Frame, Opcode, prepare_ctrl
from ..http11 import Request, Response
from ..protocol import CLOSED, OPEN, Event, Protocol, State
from ..typing import Data, LoggerLike, Subprotocol
from .messages import Assembler
from .utils import Deadline
__all__ = ["Connection"]
logger = logging.getLogger(__name__)
class Connection:
"""
Threaded implementation of a WebSocket connection.
:class:`Connection` provides APIs shared between WebSocket servers and
clients.
You shouldn't use it directly. Instead, use
:class:`~websockets.sync.client.ClientConnection` or
:class:`~websockets.sync.server.ServerConnection`.
"""
recv_bufsize = 65536
def __init__(
self,
socket: socket.socket,
protocol: Protocol,
*,
close_timeout: Optional[float] = 10,
) -> None:
self.socket = socket
self.protocol = protocol
self.close_timeout = close_timeout
# Inject reference to this instance in the protocol's logger.
self.protocol.logger = logging.LoggerAdapter(
self.protocol.logger,
{"websocket": self},
)
# Copy attributes from the protocol for convenience.
self.id: uuid.UUID = self.protocol.id
"""Unique identifier of the connection. Useful in logs."""
self.logger: LoggerLike = self.protocol.logger
"""Logger for this connection."""
self.debug = self.protocol.debug
# HTTP handshake request and response.
self.request: Optional[Request] = None
"""Opening handshake request."""
self.response: Optional[Response] = None
"""Opening handshake response."""
# Mutex serializing interactions with the protocol.
self.protocol_mutex = threading.Lock()
# Assembler turning frames into messages and serializing reads.
self.recv_messages = Assembler()
# Whether we are busy sending a fragmented message.
self.send_in_progress = False
# Deadline for the closing handshake.
self.close_deadline: Optional[Deadline] = None
# Mapping of ping IDs to pong waiters, in chronological order.
self.pings: Dict[bytes, threading.Event] = {}
# Receiving events from the socket.
self.recv_events_thread = threading.Thread(target=self.recv_events)
self.recv_events_thread.start()
# Exception raised in recv_events, to be chained to ConnectionClosed
# in the user thread in order to show why the TCP connection dropped.
self.recv_events_exc: Optional[BaseException] = None
# Public attributes
@property
def local_address(self) -> Any:
"""
Local address of the connection.
For IPv4 connections, this is a ``(host, port)`` tuple.
The format of the address depends on the address family.
See :meth:`~socket.socket.getsockname`.
"""
return self.socket.getsockname()
@property
def remote_address(self) -> Any:
"""
Remote address of the connection.
For IPv4 connections, this is a ``(host, port)`` tuple.
The format of the address depends on the address family.
See :meth:`~socket.socket.getpeername`.
"""
return self.socket.getpeername()
@property
def subprotocol(self) -> Optional[Subprotocol]:
"""
Subprotocol negotiated during the opening handshake.
:obj:`None` if no subprotocol was negotiated.
"""
return self.protocol.subprotocol
# Public methods
def __enter__(self) -> Connection:
return self
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_value: Optional[BaseException],
traceback: Optional[TracebackType],
) -> None:
if exc_type is None:
self.close()
else:
self.close(CloseCode.INTERNAL_ERROR)
def __iter__(self) -> Iterator[Data]:
"""
Iterate on incoming messages.
The iterator calls :meth:`recv` and yields messages in an infinite loop.
It exits when the connection is closed normally. It raises a
:exc:`~websockets.exceptions.ConnectionClosedError` exception after a
protocol error or a network failure.
"""
try:
while True:
yield self.recv()
except ConnectionClosedOK:
return
def recv(self, timeout: Optional[float] = None) -> Data:
"""
Receive the next message.
When the connection is closed, :meth:`recv` raises
:exc:`~websockets.exceptions.ConnectionClosed`. Specifically, it raises
:exc:`~websockets.exceptions.ConnectionClosedOK` after a normal closure
and :exc:`~websockets.exceptions.ConnectionClosedError` after a protocol
error or a network failure. This is how you detect the end of the
message stream.
If ``timeout`` is :obj:`None`, block until a message is received. If
``timeout`` is set and no message is received within ``timeout``
seconds, raise :exc:`TimeoutError`. Set ``timeout`` to ``0`` to check if
a message was already received.
If the message is fragmented, wait until all fragments are received,
reassemble them, and return the whole message.
Returns:
A string (:class:`str`) for a Text_ frame or a bytestring
(:class:`bytes`) for a Binary_ frame.
.. _Text: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6
.. _Binary: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6
Raises:
ConnectionClosed: When the connection is closed.
RuntimeError: If two threads call :meth:`recv` or
:meth:`recv_streaming` concurrently.
"""
try:
return self.recv_messages.get(timeout)
except EOFError:
raise self.protocol.close_exc from self.recv_events_exc
except RuntimeError:
raise RuntimeError(
"cannot call recv while another thread "
"is already running recv or recv_streaming"
) from None
def recv_streaming(self) -> Iterator[Data]:
"""
Receive the next message frame by frame.
If the message is fragmented, yield each fragment as it is received.
The iterator must be fully consumed, or else the connection will become
unusable.
:meth:`recv_streaming` raises the same exceptions as :meth:`recv`.
Returns:
An iterator of strings (:class:`str`) for a Text_ frame or
bytestrings (:class:`bytes`) for a Binary_ frame.
.. _Text: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6
.. _Binary: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6
Raises:
ConnectionClosed: When the connection is closed.
RuntimeError: If two threads call :meth:`recv` or
:meth:`recv_streaming` concurrently.
"""
try:
yield from self.recv_messages.get_iter()
except EOFError:
raise self.protocol.close_exc from self.recv_events_exc
except RuntimeError:
raise RuntimeError(
"cannot call recv_streaming while another thread "
"is already running recv or recv_streaming"
) from None
def send(self, message: Union[Data, Iterable[Data]]) -> None:
"""
Send a message.
A string (:class:`str`) is sent as a Text_ frame. A bytestring or
bytes-like object (:class:`bytes`, :class:`bytearray`, or
:class:`memoryview`) is sent as a Binary_ frame.
.. _Text: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6
.. _Binary: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6
:meth:`send` also accepts an iterable of strings, bytestrings, or
bytes-like objects to enable fragmentation_. Each item is treated as a
message fragment and sent in its own frame. All items must be of the
same type, or else :meth:`send` will raise a :exc:`TypeError` and the
connection will be closed.
.. _fragmentation: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.4
:meth:`send` rejects dict-like objects because this is often an error.
(If you really want to send the keys of a dict-like object as fragments,
call its :meth:`~dict.keys` method and pass the result to :meth:`send`.)
When the connection is closed, :meth:`send` raises
:exc:`~websockets.exceptions.ConnectionClosed`. Specifically, it
raises :exc:`~websockets.exceptions.ConnectionClosedOK` after a normal
connection closure and
:exc:`~websockets.exceptions.ConnectionClosedError` after a protocol
error or a network failure.
Args:
message: Message to send.
Raises:
ConnectionClosed: When the connection is closed.
RuntimeError: If a connection is busy sending a fragmented message.
TypeError: If ``message`` doesn't have a supported type.
"""
# Unfragmented message -- this case must be handled first because
# strings and bytes-like objects are iterable.
if isinstance(message, str):
with self.send_context():
if self.send_in_progress:
raise RuntimeError(
"cannot call send while another thread "
"is already running send"
)
self.protocol.send_text(message.encode("utf-8"))
elif isinstance(message, BytesLike):
with self.send_context():
if self.send_in_progress:
raise RuntimeError(
"cannot call send while another thread "
"is already running send"
)
self.protocol.send_binary(message)
# Catch a common mistake -- passing a dict to send().
elif isinstance(message, Mapping):
raise TypeError("data is a dict-like object")
# Fragmented message -- regular iterator.
elif isinstance(message, Iterable):
chunks = iter(message)
try:
chunk = next(chunks)
except StopIteration:
return
try:
# First fragment.
if isinstance(chunk, str):
text = True
with self.send_context():
if self.send_in_progress:
raise RuntimeError(
"cannot call send while another thread "
"is already running send"
)
self.send_in_progress = True
self.protocol.send_text(
chunk.encode("utf-8"),
fin=False,
)
elif isinstance(chunk, BytesLike):
text = False
with self.send_context():
if self.send_in_progress:
raise RuntimeError(
"cannot call send while another thread "
"is already running send"
)
self.send_in_progress = True
self.protocol.send_binary(
chunk,
fin=False,
)
else:
raise TypeError("data iterable must contain bytes or str")
# Other fragments
for chunk in chunks:
if isinstance(chunk, str) and text:
with self.send_context():
assert self.send_in_progress
self.protocol.send_continuation(
chunk.encode("utf-8"),
fin=False,
)
elif isinstance(chunk, BytesLike) and not text:
with self.send_context():
assert self.send_in_progress
self.protocol.send_continuation(
chunk,
fin=False,
)
else:
raise TypeError("data iterable must contain uniform types")
# Final fragment.
with self.send_context():
self.protocol.send_continuation(b"", fin=True)
self.send_in_progress = False
except RuntimeError:
# We didn't start sending a fragmented message.
raise
except Exception:
# We're half-way through a fragmented message and we can't
# complete it. This makes the connection unusable.
with self.send_context():
self.protocol.fail(
CloseCode.INTERNAL_ERROR,
"error in fragmented message",
)
raise
else:
raise TypeError("data must be bytes, str, or iterable")
def close(self, code: int = CloseCode.NORMAL_CLOSURE, reason: str = "") -> None:
"""
Perform the closing handshake.
:meth:`close` waits for the other end to complete the handshake, for the
TCP connection to terminate, and for all incoming messages to be read
with :meth:`recv`.
:meth:`close` is idempotent: it doesn't do anything once the
connection is closed.
Args:
code: WebSocket close code.
reason: WebSocket close reason.
"""
try:
# The context manager takes care of waiting for the TCP connection
# to terminate after calling a method that sends a close frame.
with self.send_context():
if self.send_in_progress:
self.protocol.fail(
CloseCode.INTERNAL_ERROR,
"close during fragmented message",
)
else:
self.protocol.send_close(code, reason)
except ConnectionClosed:
# Ignore ConnectionClosed exceptions raised from send_context().
# They mean that the connection is closed, which was the goal.
pass
def ping(self, data: Optional[Data] = None) -> threading.Event:
"""
Send a Ping_.
.. _Ping: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.5.2
A ping may serve as a keepalive or as a check that the remote endpoint
received all messages up to this point
Args:
data: Payload of the ping. A :class:`str` will be encoded to UTF-8.
If ``data`` is :obj:`None`, the payload is four random bytes.
Returns:
An event that will be set when the corresponding pong is received.
You can ignore it if you don't intend to wait.
::
pong_event = ws.ping()
pong_event.wait() # only if you want to wait for the pong
Raises:
ConnectionClosed: When the connection is closed.
RuntimeError: If another ping was sent with the same data and
the corresponding pong wasn't received yet.
"""
if data is not None:
data = prepare_ctrl(data)
with self.send_context():
# Protect against duplicates if a payload is explicitly set.
if data in self.pings:
raise RuntimeError("already waiting for a pong with the same data")
# Generate a unique random payload otherwise.
while data is None or data in self.pings:
data = struct.pack("!I", random.getrandbits(32))
pong_waiter = threading.Event()
self.pings[data] = pong_waiter
self.protocol.send_ping(data)
return pong_waiter
def pong(self, data: Data = b"") -> None:
"""
Send a Pong_.
.. _Pong: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.5.3
An unsolicited pong may serve as a unidirectional heartbeat.
Args:
data: Payload of the pong. A :class:`str` will be encoded to UTF-8.
Raises:
ConnectionClosed: When the connection is closed.
"""
data = prepare_ctrl(data)
with self.send_context():
self.protocol.send_pong(data)
# Private methods
def process_event(self, event: Event) -> None:
"""
Process one incoming event.
This method is overridden in subclasses to handle the handshake.
"""
assert isinstance(event, Frame)
if event.opcode in DATA_OPCODES:
self.recv_messages.put(event)
if event.opcode is Opcode.PONG:
self.acknowledge_pings(bytes(event.data))
def acknowledge_pings(self, data: bytes) -> None:
"""
Acknowledge pings when receiving a pong.
"""
with self.protocol_mutex:
# Ignore unsolicited pong.
if data not in self.pings:
return
# Sending a pong for only the most recent ping is legal.
# Acknowledge all previous pings too in that case.
ping_id = None
ping_ids = []
for ping_id, ping in self.pings.items():
ping_ids.append(ping_id)
ping.set()
if ping_id == data:
break
else:
raise AssertionError("solicited pong not found in pings")
# Remove acknowledged pings from self.pings.
for ping_id in ping_ids:
del self.pings[ping_id]
def recv_events(self) -> None:
"""
Read incoming data from the socket and process events.
Run this method in a thread as long as the connection is alive.
``recv_events()`` exits immediately when the ``self.socket`` is closed.
"""
try:
while True:
try:
if self.close_deadline is not None:
self.socket.settimeout(self.close_deadline.timeout())
data = self.socket.recv(self.recv_bufsize)
except Exception as exc:
if self.debug:
self.logger.debug("error while receiving data", exc_info=True)
# When the closing handshake is initiated by our side,
# recv() may block until send_context() closes the socket.
# In that case, send_context() already set recv_events_exc.
# Calling set_recv_events_exc() avoids overwriting it.
with self.protocol_mutex:
self.set_recv_events_exc(exc)
break
if data == b"":
break
# Acquire the connection lock.
with self.protocol_mutex:
# Feed incoming data to the connection.
self.protocol.receive_data(data)
# This isn't expected to raise an exception.
events = self.protocol.events_received()
# Write outgoing data to the socket.
try:
self.send_data()
except Exception as exc:
if self.debug:
self.logger.debug("error while sending data", exc_info=True)
# Similarly to the above, avoid overriding an exception
# set by send_context(), in case of a race condition
# i.e. send_context() closes the socket after recv()
# returns above but before send_data() calls send().
self.set_recv_events_exc(exc)
break
if self.protocol.close_expected():
# If the connection is expected to close soon, set the
# close deadline based on the close timeout.
if self.close_deadline is None:
self.close_deadline = Deadline(self.close_timeout)
# Unlock conn_mutex before processing events. Else, the
# application can't send messages in response to events.
# If self.send_data raised an exception, then events are lost.
# Given that automatic responses write small amounts of data,
# this should be uncommon, so we don't handle the edge case.
try:
for event in events:
# This may raise EOFError if the closing handshake
# times out while a message is waiting to be read.
self.process_event(event)
except EOFError:
break
# Breaking out of the while True: ... loop means that we believe
# that the socket doesn't work anymore.
with self.protocol_mutex:
# Feed the end of the data stream to the connection.
self.protocol.receive_eof()
# This isn't expected to generate events.
assert not self.protocol.events_received()
# There is no error handling because send_data() can only write
# the end of the data stream here and it handles errors itself.
self.send_data()
except Exception as exc:
# This branch should never run. It's a safety net in case of bugs.
self.logger.error("unexpected internal error", exc_info=True)
with self.protocol_mutex:
self.set_recv_events_exc(exc)
# We don't know where we crashed. Force protocol state to CLOSED.
self.protocol.state = CLOSED
finally:
# This isn't expected to raise an exception.
self.close_socket()
@contextlib.contextmanager
def send_context(
self,
*,
expected_state: State = OPEN, # CONNECTING during the opening handshake
) -> Iterator[None]:
"""
Create a context for writing to the connection from user code.
On entry, :meth:`send_context` acquires the connection lock and checks
that the connection is open; on exit, it writes outgoing data to the
socket::
with self.send_context():
self.protocol.send_text(message.encode("utf-8"))
When the connection isn't open on entry, when the connection is expected
to close on exit, or when an unexpected error happens, terminating the
connection, :meth:`send_context` waits until the connection is closed
then raises :exc:`~websockets.exceptions.ConnectionClosed`.
"""
# Should we wait until the connection is closed?
wait_for_close = False
# Should we close the socket and raise ConnectionClosed?
raise_close_exc = False
# What exception should we chain ConnectionClosed to?
original_exc: Optional[BaseException] = None
# Acquire the protocol lock.
with self.protocol_mutex:
if self.protocol.state is expected_state:
# Let the caller interact with the protocol.
try:
yield
except (ProtocolError, RuntimeError):
# The protocol state wasn't changed. Exit immediately.
raise
except Exception as exc:
self.logger.error("unexpected internal error", exc_info=True)
# This branch should never run. It's a safety net in case of
# bugs. Since we don't know what happened, we will close the
# connection and raise the exception to the caller.
wait_for_close = False
raise_close_exc = True
original_exc = exc
else:
# Check if the connection is expected to close soon.
if self.protocol.close_expected():
wait_for_close = True
# If the connection is expected to close soon, set the
# close deadline based on the close timeout.
# Since we tested earlier that protocol.state was OPEN
# (or CONNECTING) and we didn't release protocol_mutex,
# it is certain that self.close_deadline is still None.
assert self.close_deadline is None
self.close_deadline = Deadline(self.close_timeout)
# Write outgoing data to the socket.
try:
self.send_data()
except Exception as exc:
if self.debug:
self.logger.debug("error while sending data", exc_info=True)
# While the only expected exception here is OSError,
# other exceptions would be treated identically.
wait_for_close = False
raise_close_exc = True
original_exc = exc
else: # self.protocol.state is not expected_state
# Minor layering violation: we assume that the connection
# will be closing soon if it isn't in the expected state.
wait_for_close = True
raise_close_exc = True
# To avoid a deadlock, release the connection lock by exiting the
# context manager before waiting for recv_events() to terminate.
# If the connection is expected to close soon and the close timeout
# elapses, close the socket to terminate the connection.
if wait_for_close:
if self.close_deadline is None:
timeout = self.close_timeout
else:
# Thread.join() returns immediately if timeout is negative.
timeout = self.close_deadline.timeout(raise_if_elapsed=False)
self.recv_events_thread.join(timeout)
if self.recv_events_thread.is_alive():
# There's no risk to overwrite another error because
# original_exc is never set when wait_for_close is True.
assert original_exc is None
original_exc = TimeoutError("timed out while closing connection")
# Set recv_events_exc before closing the socket in order to get
# proper exception reporting.
raise_close_exc = True
with self.protocol_mutex:
self.set_recv_events_exc(original_exc)
# If an error occurred, close the socket to terminate the connection and
# raise an exception.
if raise_close_exc:
self.close_socket()
self.recv_events_thread.join()
raise self.protocol.close_exc from original_exc
def send_data(self) -> None:
"""
Send outgoing data.
This method requires holding protocol_mutex.
Raises:
OSError: When a socket operations fails.
"""
assert self.protocol_mutex.locked()
for data in self.protocol.data_to_send():
if data:
if self.close_deadline is not None:
self.socket.settimeout(self.close_deadline.timeout())
self.socket.sendall(data)
else:
try:
self.socket.shutdown(socket.SHUT_WR)
except OSError: # socket already closed
pass
def set_recv_events_exc(self, exc: Optional[BaseException]) -> None:
"""
Set recv_events_exc, if not set yet.
This method requires holding protocol_mutex.
"""
assert self.protocol_mutex.locked()
if self.recv_events_exc is None:
self.recv_events_exc = exc
def close_socket(self) -> None:
"""
Shutdown and close socket. Close message assembler.
Calling close_socket() guarantees that recv_events() terminates. Indeed,
recv_events() may block only on socket.recv() or on recv_messages.put().
"""
# shutdown() is required to interrupt recv() on Linux.
try:
self.socket.shutdown(socket.SHUT_RDWR)
except OSError:
pass # socket is already closed
self.socket.close()
self.recv_messages.close()

View File

@@ -0,0 +1,281 @@
from __future__ import annotations
import codecs
import queue
import threading
from typing import Iterator, List, Optional, cast
from ..frames import Frame, Opcode
from ..typing import Data
__all__ = ["Assembler"]
UTF8Decoder = codecs.getincrementaldecoder("utf-8")
class Assembler:
"""
Assemble messages from frames.
"""
def __init__(self) -> None:
# Serialize reads and writes -- except for reads via synchronization
# primitives provided by the threading and queue modules.
self.mutex = threading.Lock()
# We create a latch with two events to ensure proper interleaving of
# writing and reading messages.
# put() sets this event to tell get() that a message can be fetched.
self.message_complete = threading.Event()
# get() sets this event to let put() that the message was fetched.
self.message_fetched = threading.Event()
# This flag prevents concurrent calls to get() by user code.
self.get_in_progress = False
# This flag prevents concurrent calls to put() by library code.
self.put_in_progress = False
# Decoder for text frames, None for binary frames.
self.decoder: Optional[codecs.IncrementalDecoder] = None
# Buffer of frames belonging to the same message.
self.chunks: List[Data] = []
# When switching from "buffering" to "streaming", we use a thread-safe
# queue for transferring frames from the writing thread (library code)
# to the reading thread (user code). We're buffering when chunks_queue
# is None and streaming when it's a SimpleQueue. None is a sentinel
# value marking the end of the stream, superseding message_complete.
# Stream data from frames belonging to the same message.
# Remove quotes around type when dropping Python < 3.9.
self.chunks_queue: Optional["queue.SimpleQueue[Optional[Data]]"] = None
# This flag marks the end of the stream.
self.closed = False
def get(self, timeout: Optional[float] = None) -> Data:
"""
Read the next message.
:meth:`get` returns a single :class:`str` or :class:`bytes`.
If the message is fragmented, :meth:`get` waits until the last frame is
received, then it reassembles the message and returns it. To receive
messages frame by frame, use :meth:`get_iter` instead.
Args:
timeout: If a timeout is provided and elapses before a complete
message is received, :meth:`get` raises :exc:`TimeoutError`.
Raises:
EOFError: If the stream of frames has ended.
RuntimeError: If two threads run :meth:`get` or :meth:``get_iter`
concurrently.
"""
with self.mutex:
if self.closed:
raise EOFError("stream of frames ended")
if self.get_in_progress:
raise RuntimeError("get or get_iter is already running")
self.get_in_progress = True
# If the message_complete event isn't set yet, release the lock to
# allow put() to run and eventually set it.
# Locking with get_in_progress ensures only one thread can get here.
completed = self.message_complete.wait(timeout)
with self.mutex:
self.get_in_progress = False
# Waiting for a complete message timed out.
if not completed:
raise TimeoutError(f"timed out in {timeout:.1f}s")
# get() was unblocked by close() rather than put().
if self.closed:
raise EOFError("stream of frames ended")
assert self.message_complete.is_set()
self.message_complete.clear()
joiner: Data = b"" if self.decoder is None else ""
# mypy cannot figure out that chunks have the proper type.
message: Data = joiner.join(self.chunks) # type: ignore
assert not self.message_fetched.is_set()
self.message_fetched.set()
self.chunks = []
assert self.chunks_queue is None
return message
def get_iter(self) -> Iterator[Data]:
"""
Stream the next message.
Iterating the return value of :meth:`get_iter` yields a :class:`str` or
:class:`bytes` for each frame in the message.
The iterator must be fully consumed before calling :meth:`get_iter` or
:meth:`get` again. Else, :exc:`RuntimeError` is raised.
This method only makes sense for fragmented messages. If messages aren't
fragmented, use :meth:`get` instead.
Raises:
EOFError: If the stream of frames has ended.
RuntimeError: If two threads run :meth:`get` or :meth:``get_iter`
concurrently.
"""
with self.mutex:
if self.closed:
raise EOFError("stream of frames ended")
if self.get_in_progress:
raise RuntimeError("get or get_iter is already running")
chunks = self.chunks
self.chunks = []
self.chunks_queue = cast(
# Remove quotes around type when dropping Python < 3.9.
"queue.SimpleQueue[Optional[Data]]",
queue.SimpleQueue(),
)
# Sending None in chunk_queue supersedes setting message_complete
# when switching to "streaming". If message is already complete
# when the switch happens, put() didn't send None, so we have to.
if self.message_complete.is_set():
self.chunks_queue.put(None)
self.get_in_progress = True
# Locking with get_in_progress ensures only one thread can get here.
yield from chunks
while True:
chunk = self.chunks_queue.get()
if chunk is None:
break
yield chunk
with self.mutex:
self.get_in_progress = False
assert self.message_complete.is_set()
self.message_complete.clear()
# get_iter() was unblocked by close() rather than put().
if self.closed:
raise EOFError("stream of frames ended")
assert not self.message_fetched.is_set()
self.message_fetched.set()
assert self.chunks == []
self.chunks_queue = None
def put(self, frame: Frame) -> None:
"""
Add ``frame`` to the next message.
When ``frame`` is the final frame in a message, :meth:`put` waits until
the message is fetched, either by calling :meth:`get` or by fully
consuming the return value of :meth:`get_iter`.
:meth:`put` assumes that the stream of frames respects the protocol. If
it doesn't, the behavior is undefined.
Raises:
EOFError: If the stream of frames has ended.
RuntimeError: If two threads run :meth:`put` concurrently.
"""
with self.mutex:
if self.closed:
raise EOFError("stream of frames ended")
if self.put_in_progress:
raise RuntimeError("put is already running")
if frame.opcode is Opcode.TEXT:
self.decoder = UTF8Decoder(errors="strict")
elif frame.opcode is Opcode.BINARY:
self.decoder = None
elif frame.opcode is Opcode.CONT:
pass
else:
# Ignore control frames.
return
data: Data
if self.decoder is not None:
data = self.decoder.decode(frame.data, frame.fin)
else:
data = frame.data
if self.chunks_queue is None:
self.chunks.append(data)
else:
self.chunks_queue.put(data)
if not frame.fin:
return
# Message is complete. Wait until it's fetched to return.
assert not self.message_complete.is_set()
self.message_complete.set()
if self.chunks_queue is not None:
self.chunks_queue.put(None)
assert not self.message_fetched.is_set()
self.put_in_progress = True
# Release the lock to allow get() to run and eventually set the event.
self.message_fetched.wait()
with self.mutex:
self.put_in_progress = False
assert self.message_fetched.is_set()
self.message_fetched.clear()
# put() was unblocked by close() rather than get() or get_iter().
if self.closed:
raise EOFError("stream of frames ended")
self.decoder = None
def close(self) -> None:
"""
End the stream of frames.
Callling :meth:`close` concurrently with :meth:`get`, :meth:`get_iter`,
or :meth:`put` is safe. They will raise :exc:`EOFError`.
"""
with self.mutex:
if self.closed:
return
self.closed = True
# Unblock get or get_iter.
if self.get_in_progress:
self.message_complete.set()
if self.chunks_queue is not None:
self.chunks_queue.put(None)
# Unblock put().
if self.put_in_progress:
self.message_fetched.set()

View File

@@ -0,0 +1,530 @@
from __future__ import annotations
import http
import logging
import os
import selectors
import socket
import ssl
import sys
import threading
from types import TracebackType
from typing import Any, Callable, Optional, Sequence, Type
from websockets.frames import CloseCode
from ..extensions.base import ServerExtensionFactory
from ..extensions.permessage_deflate import enable_server_permessage_deflate
from ..headers import validate_subprotocols
from ..http import USER_AGENT
from ..http11 import Request, Response
from ..protocol import CONNECTING, OPEN, Event
from ..server import ServerProtocol
from ..typing import LoggerLike, Origin, Subprotocol
from .connection import Connection
from .utils import Deadline
__all__ = ["serve", "unix_serve", "ServerConnection", "WebSocketServer"]
class ServerConnection(Connection):
"""
Threaded implementation of a WebSocket server connection.
:class:`ServerConnection` provides :meth:`recv` and :meth:`send` methods for
receiving and sending messages.
It supports iteration to receive messages::
for message in websocket:
process(message)
The iterator exits normally when the connection is closed with close code
1000 (OK) or 1001 (going away) or without a close code. It raises a
:exc:`~websockets.exceptions.ConnectionClosedError` when the connection is
closed with any other code.
Args:
socket: Socket connected to a WebSocket client.
protocol: Sans-I/O connection.
close_timeout: Timeout for closing the connection in seconds.
"""
def __init__(
self,
socket: socket.socket,
protocol: ServerProtocol,
*,
close_timeout: Optional[float] = 10,
) -> None:
self.protocol: ServerProtocol
self.request_rcvd = threading.Event()
super().__init__(
socket,
protocol,
close_timeout=close_timeout,
)
def handshake(
self,
process_request: Optional[
Callable[
[ServerConnection, Request],
Optional[Response],
]
] = None,
process_response: Optional[
Callable[
[ServerConnection, Request, Response],
Optional[Response],
]
] = None,
server_header: Optional[str] = USER_AGENT,
timeout: Optional[float] = None,
) -> None:
"""
Perform the opening handshake.
"""
if not self.request_rcvd.wait(timeout):
self.close_socket()
self.recv_events_thread.join()
raise TimeoutError("timed out during handshake")
if self.request is None:
self.close_socket()
self.recv_events_thread.join()
raise ConnectionError("connection closed during handshake")
with self.send_context(expected_state=CONNECTING):
self.response = None
if process_request is not None:
try:
self.response = process_request(self, self.request)
except Exception as exc:
self.protocol.handshake_exc = exc
self.logger.error("opening handshake failed", exc_info=True)
self.response = self.protocol.reject(
http.HTTPStatus.INTERNAL_SERVER_ERROR,
(
"Failed to open a WebSocket connection.\n"
"See server log for more information.\n"
),
)
if self.response is None:
self.response = self.protocol.accept(self.request)
if server_header is not None:
self.response.headers["Server"] = server_header
if process_response is not None:
try:
response = process_response(self, self.request, self.response)
except Exception as exc:
self.protocol.handshake_exc = exc
self.logger.error("opening handshake failed", exc_info=True)
self.response = self.protocol.reject(
http.HTTPStatus.INTERNAL_SERVER_ERROR,
(
"Failed to open a WebSocket connection.\n"
"See server log for more information.\n"
),
)
else:
if response is not None:
self.response = response
self.protocol.send_response(self.response)
if self.protocol.state is not OPEN:
self.recv_events_thread.join(self.close_timeout)
self.close_socket()
self.recv_events_thread.join()
if self.protocol.handshake_exc is not None:
raise self.protocol.handshake_exc
def process_event(self, event: Event) -> None:
"""
Process one incoming event.
"""
# First event - handshake request.
if self.request is None:
assert isinstance(event, Request)
self.request = event
self.request_rcvd.set()
# Later events - frames.
else:
super().process_event(event)
def recv_events(self) -> None:
"""
Read incoming data from the socket and process events.
"""
try:
super().recv_events()
finally:
# If the connection is closed during the handshake, unblock it.
self.request_rcvd.set()
class WebSocketServer:
"""
WebSocket server returned by :func:`serve`.
This class mirrors the API of :class:`~socketserver.BaseServer`, notably the
:meth:`~socketserver.BaseServer.serve_forever` and
:meth:`~socketserver.BaseServer.shutdown` methods, as well as the context
manager protocol.
Args:
socket: Server socket listening for new connections.
handler: Handler for one connection. Receives the socket and address
returned by :meth:`~socket.socket.accept`.
logger: Logger for this server.
"""
def __init__(
self,
socket: socket.socket,
handler: Callable[[socket.socket, Any], None],
logger: Optional[LoggerLike] = None,
):
self.socket = socket
self.handler = handler
if logger is None:
logger = logging.getLogger("websockets.server")
self.logger = logger
if sys.platform != "win32":
self.shutdown_watcher, self.shutdown_notifier = os.pipe()
def serve_forever(self) -> None:
"""
See :meth:`socketserver.BaseServer.serve_forever`.
This method doesn't return. Calling :meth:`shutdown` from another thread
stops the server.
Typical use::
with serve(...) as server:
server.serve_forever()
"""
poller = selectors.DefaultSelector()
poller.register(self.socket, selectors.EVENT_READ)
if sys.platform != "win32":
poller.register(self.shutdown_watcher, selectors.EVENT_READ)
while True:
poller.select()
try:
# If the socket is closed, this will raise an exception and exit
# the loop. So we don't need to check the return value of select().
sock, addr = self.socket.accept()
except OSError:
break
thread = threading.Thread(target=self.handler, args=(sock, addr))
thread.start()
def shutdown(self) -> None:
"""
See :meth:`socketserver.BaseServer.shutdown`.
"""
self.socket.close()
if sys.platform != "win32":
os.write(self.shutdown_notifier, b"x")
def fileno(self) -> int:
"""
See :meth:`socketserver.BaseServer.fileno`.
"""
return self.socket.fileno()
def __enter__(self) -> WebSocketServer:
return self
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_value: Optional[BaseException],
traceback: Optional[TracebackType],
) -> None:
self.shutdown()
def serve(
handler: Callable[[ServerConnection], None],
host: Optional[str] = None,
port: Optional[int] = None,
*,
# TCP/TLS — unix and path are only for unix_serve()
sock: Optional[socket.socket] = None,
ssl_context: Optional[ssl.SSLContext] = None,
unix: bool = False,
path: Optional[str] = None,
# WebSocket
origins: Optional[Sequence[Optional[Origin]]] = None,
extensions: Optional[Sequence[ServerExtensionFactory]] = None,
subprotocols: Optional[Sequence[Subprotocol]] = None,
select_subprotocol: Optional[
Callable[
[ServerConnection, Sequence[Subprotocol]],
Optional[Subprotocol],
]
] = None,
process_request: Optional[
Callable[
[ServerConnection, Request],
Optional[Response],
]
] = None,
process_response: Optional[
Callable[
[ServerConnection, Request, Response],
Optional[Response],
]
] = None,
server_header: Optional[str] = USER_AGENT,
compression: Optional[str] = "deflate",
# Timeouts
open_timeout: Optional[float] = 10,
close_timeout: Optional[float] = 10,
# Limits
max_size: Optional[int] = 2**20,
# Logging
logger: Optional[LoggerLike] = None,
# Escape hatch for advanced customization
create_connection: Optional[Type[ServerConnection]] = None,
) -> WebSocketServer:
"""
Create a WebSocket server listening on ``host`` and ``port``.
Whenever a client connects, the server creates a :class:`ServerConnection`,
performs the opening handshake, and delegates to the ``handler``.
The handler receives a :class:`ServerConnection` instance, which you can use
to send and receive messages.
Once the handler completes, either normally or with an exception, the server
performs the closing handshake and closes the connection.
:class:`WebSocketServer` mirrors the API of
:class:`~socketserver.BaseServer`. Treat it as a context manager to ensure
that it will be closed and call the :meth:`~WebSocketServer.serve_forever`
method to serve requests::
def handler(websocket):
...
with websockets.sync.server.serve(handler, ...) as server:
server.serve_forever()
Args:
handler: Connection handler. It receives the WebSocket connection,
which is a :class:`ServerConnection`, in argument.
host: Network interfaces the server binds to.
See :func:`~socket.create_server` for details.
port: TCP port the server listens on.
See :func:`~socket.create_server` for details.
sock: Preexisting TCP socket. ``sock`` replaces ``host`` and ``port``.
You may call :func:`socket.create_server` to create a suitable TCP
socket.
ssl_context: Configuration for enabling TLS on the connection.
origins: Acceptable values of the ``Origin`` header, for defending
against Cross-Site WebSocket Hijacking attacks. Include :obj:`None`
in the list if the lack of an origin is acceptable.
extensions: List of supported extensions, in order in which they
should be negotiated and run.
subprotocols: List of supported subprotocols, in order of decreasing
preference.
select_subprotocol: Callback for selecting a subprotocol among
those supported by the client and the server. It receives a
:class:`ServerConnection` (not a
:class:`~websockets.server.ServerProtocol`!) instance and a list of
subprotocols offered by the client. Other than the first argument,
it has the same behavior as the
:meth:`ServerProtocol.select_subprotocol
<websockets.server.ServerProtocol.select_subprotocol>` method.
process_request: Intercept the request during the opening handshake.
Return an HTTP response to force the response or :obj:`None` to
continue normally. When you force an HTTP 101 Continue response,
the handshake is successful. Else, the connection is aborted.
process_response: Intercept the response during the opening handshake.
Return an HTTP response to force the response or :obj:`None` to
continue normally. When you force an HTTP 101 Continue response,
the handshake is successful. Else, the connection is aborted.
server_header: Value of the ``Server`` response header.
It defaults to ``"Python/x.y.z websockets/X.Y"``. Setting it to
:obj:`None` removes the header.
compression: The "permessage-deflate" extension is enabled by default.
Set ``compression`` to :obj:`None` to disable it. See the
:doc:`compression guide <../../topics/compression>` for details.
open_timeout: Timeout for opening connections in seconds.
:obj:`None` disables the timeout.
close_timeout: Timeout for closing connections in seconds.
:obj:`None` disables the timeout.
max_size: Maximum size of incoming messages in bytes.
:obj:`None` disables the limit.
logger: Logger for this server.
It defaults to ``logging.getLogger("websockets.server")``. See the
:doc:`logging guide <../../topics/logging>` for details.
create_connection: Factory for the :class:`ServerConnection` managing
the connection. Set it to a wrapper or a subclass to customize
connection handling.
"""
# Process parameters
if subprotocols is not None:
validate_subprotocols(subprotocols)
if compression == "deflate":
extensions = enable_server_permessage_deflate(extensions)
elif compression is not None:
raise ValueError(f"unsupported compression: {compression}")
if create_connection is None:
create_connection = ServerConnection
# Bind socket and listen
if sock is None:
if unix:
if path is None:
raise TypeError("missing path argument")
sock = socket.create_server(path, family=socket.AF_UNIX)
else:
sock = socket.create_server((host, port))
else:
if path is not None:
raise TypeError("path and sock arguments are incompatible")
# Initialize TLS wrapper
if ssl_context is not None:
sock = ssl_context.wrap_socket(
sock,
server_side=True,
# Delay TLS handshake until after we set a timeout on the socket.
do_handshake_on_connect=False,
)
# Define request handler
def conn_handler(sock: socket.socket, addr: Any) -> None:
# Calculate timeouts on the TLS and WebSocket handshakes.
# The TLS timeout must be set on the socket, then removed
# to avoid conflicting with the WebSocket timeout in handshake().
deadline = Deadline(open_timeout)
try:
# Disable Nagle algorithm
if not unix:
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, True)
# Perform TLS handshake
if ssl_context is not None:
sock.settimeout(deadline.timeout())
assert isinstance(sock, ssl.SSLSocket) # mypy cannot figure this out
sock.do_handshake()
sock.settimeout(None)
# Create a closure so that select_subprotocol has access to self.
protocol_select_subprotocol: Optional[
Callable[
[ServerProtocol, Sequence[Subprotocol]],
Optional[Subprotocol],
]
] = None
if select_subprotocol is not None:
def protocol_select_subprotocol(
protocol: ServerProtocol,
subprotocols: Sequence[Subprotocol],
) -> Optional[Subprotocol]:
# mypy doesn't know that select_subprotocol is immutable.
assert select_subprotocol is not None
# Ensure this function is only used in the intended context.
assert protocol is connection.protocol
return select_subprotocol(connection, subprotocols)
# Initialize WebSocket connection
protocol = ServerProtocol(
origins=origins,
extensions=extensions,
subprotocols=subprotocols,
select_subprotocol=protocol_select_subprotocol,
state=CONNECTING,
max_size=max_size,
logger=logger,
)
# Initialize WebSocket protocol
assert create_connection is not None # help mypy
connection = create_connection(
sock,
protocol,
close_timeout=close_timeout,
)
# On failure, handshake() closes the socket, raises an exception, and
# logs it.
connection.handshake(
process_request,
process_response,
server_header,
deadline.timeout(),
)
except Exception:
sock.close()
return
try:
handler(connection)
except Exception:
protocol.logger.error("connection handler failed", exc_info=True)
connection.close(CloseCode.INTERNAL_ERROR)
else:
connection.close()
# Initialize server
return WebSocketServer(sock, conn_handler, logger)
def unix_serve(
handler: Callable[[ServerConnection], Any],
path: Optional[str] = None,
**kwargs: Any,
) -> WebSocketServer:
"""
Create a WebSocket server listening on a Unix socket.
This function is identical to :func:`serve`, except the ``host`` and
``port`` arguments are replaced by ``path``. It's only available on Unix.
It's useful for deploying a server behind a reverse proxy such as nginx.
Args:
handler: Connection handler. It receives the WebSocket connection,
which is a :class:`ServerConnection`, in argument.
path: File system path to the Unix socket.
"""
return serve(handler, path=path, unix=True, **kwargs)

View File

@@ -0,0 +1,46 @@
from __future__ import annotations
import time
from typing import Optional
__all__ = ["Deadline"]
class Deadline:
"""
Manage timeouts across multiple steps.
Args:
timeout: Time available in seconds or :obj:`None` if there is no limit.
"""
def __init__(self, timeout: Optional[float]) -> None:
self.deadline: Optional[float]
if timeout is None:
self.deadline = None
else:
self.deadline = time.monotonic() + timeout
def timeout(self, *, raise_if_elapsed: bool = True) -> Optional[float]:
"""
Calculate a timeout from a deadline.
Args:
raise_if_elapsed (bool): Whether to raise :exc:`TimeoutError`
if the deadline lapsed.
Raises:
TimeoutError: If the deadline lapsed.
Returns:
Time left in seconds or :obj:`None` if there is no limit.
"""
if self.deadline is None:
return None
timeout = self.deadline - time.monotonic()
if raise_if_elapsed and timeout <= 0:
raise TimeoutError("timed out")
return timeout