490 lines
18 KiB
490 lines
18 KiB
"""
|
|
wsproto/handshake
|
|
~~~~~~~~~~~~~~~~~~
|
|
|
|
An implementation of WebSocket handshakes.
|
|
"""
|
|
from collections import deque
|
|
from typing import (
|
|
cast,
|
|
Deque,
|
|
Dict,
|
|
Generator,
|
|
Iterable,
|
|
List,
|
|
Optional,
|
|
Sequence,
|
|
Union,
|
|
)
|
|
|
|
import h11
|
|
|
|
from .connection import Connection, ConnectionState, ConnectionType
|
|
from .events import AcceptConnection, Event, RejectConnection, RejectData, Request
|
|
from .extensions import Extension
|
|
from .typing import Headers
|
|
from .utilities import (
|
|
generate_accept_token,
|
|
generate_nonce,
|
|
LocalProtocolError,
|
|
normed_header_dict,
|
|
RemoteProtocolError,
|
|
split_comma_header,
|
|
)
|
|
|
|
# RFC6455, Section 4.2.1/6 - Reading the Client's Opening Handshake
# The only Sec-WebSocket-Version this module speaks: a client always sends it,
# and a server rejects any other value (400 if absent, 426 if different),
# advertising this value in the rejection's Sec-WebSocket-Version header.
WEBSOCKET_VERSION = b"13"
|
|
|
|
|
|
class H11Handshake:
    """A Handshake implementation for HTTP/1.1 connections.

    The HTTP side of the handshake is driven by an internal
    :class:`h11.Connection`. Once the handshake succeeds, :attr:`connection`
    holds the :class:`Connection` to use for WebSocket traffic.
    """

    def __init__(self, connection_type: ConnectionType) -> None:
        self.client = connection_type is ConnectionType.CLIENT
        self._state = ConnectionState.CONNECTING

        if self.client:
            self._h11_connection = h11.Connection(h11.CLIENT)
        else:
            self._h11_connection = h11.Connection(h11.SERVER)

        self._connection: Optional[Connection] = None
        self._events: Deque[Event] = deque()
        # Client: the Request we sent. Server: the Request we parsed.
        self._initiating_request: Optional[Request] = None
        # Client only: the Sec-WebSocket-Key we sent; used to check the
        # server's Sec-WebSocket-Accept.
        self._nonce: Optional[bytes] = None

    @property
    def state(self) -> ConnectionState:
        """The current handshake state."""
        return self._state

    @property
    def connection(self) -> Optional[Connection]:
        """Return the established connection.

        This is ``None`` until the handshake has completed successfully,
        i.e. until an :class:`AcceptConnection` has been sent (server) or a
        101 response has been received and validated (client).

        :rtype: Optional[Connection]
        """
        return self._connection

    def initiate_upgrade_connection(self, headers: Headers, path: str) -> None:
        """Initiate an upgrade connection.

        This should be used if the request has already been received and
        parsed elsewhere. The request is re-serialised as a GET and fed to
        :meth:`receive_data`, so the resulting :class:`Request` event (or
        error) is produced exactly as if it had arrived over the wire.

        :param list headers: HTTP headers represented as a list of 2-tuples.
        :param str path: A URL path.
        :raises LocalProtocolError: if acting as the client.
        """
        if self.client:
            raise LocalProtocolError(
                "Cannot initiate an upgrade connection when acting as the client"
            )
        upgrade_request = h11.Request(method=b"GET", target=path, headers=headers)
        h11_client = h11.Connection(h11.CLIENT)
        self.receive_data(h11_client.send(upgrade_request))

    def send(self, event: Event) -> bytes:
        """Send an event to the remote.

        This will return the bytes to send based on the event or raise
        a LocalProtocolError if the event is not valid given the
        state.

        :returns: Data to send to the WebSocket peer.
        :rtype: bytes
        """
        data = b""
        if isinstance(event, Request):
            data += self._initiate_connection(event)
        elif isinstance(event, AcceptConnection):
            data += self._accept(event)
        elif isinstance(event, RejectConnection):
            data += self._reject(event)
        elif isinstance(event, RejectData):
            data += self._send_reject_data(event)
        else:
            raise LocalProtocolError(
                f"Event {event} cannot be sent during the handshake"
            )
        return data

    def receive_data(self, data: Optional[bytes]) -> None:
        """Receive data from the remote.

        A list of events that the remote peer triggered by sending
        this data can be retrieved with :meth:`events`.

        :param bytes data: Data received from the WebSocket peer. ``None`` is
            forwarded to h11 as ``b""`` (which h11 treats as end-of-stream).
        :raises RemoteProtocolError: if the peer sent a malformed HTTP
            message or an invalid handshake.
        """
        self._h11_connection.receive_data(data or b"")
        while True:
            try:
                event = self._h11_connection.next_event()
            except h11.RemoteProtocolError as exc:
                raise RemoteProtocolError(
                    "Bad HTTP message", event_hint=RejectConnection()
                ) from exc
            if (
                isinstance(event, h11.ConnectionClosed)
                or event is h11.NEED_DATA
                or event is h11.PAUSED
            ):
                break

            if self.client:
                if isinstance(event, h11.InformationalResponse):
                    if event.status_code == 101:
                        self._events.append(self._establish_client_connection(event))
                    else:
                        self._events.append(
                            RejectConnection(
                                headers=list(event.headers),
                                status_code=event.status_code,
                                has_body=False,
                            )
                        )
                        self._state = ConnectionState.CLOSED
                elif isinstance(event, h11.Response):
                    self._state = ConnectionState.REJECTING
                    self._events.append(
                        RejectConnection(
                            headers=list(event.headers),
                            status_code=event.status_code,
                            has_body=True,
                        )
                    )
                elif isinstance(event, h11.Data):
                    self._events.append(
                        RejectData(data=event.data, body_finished=False)
                    )
                elif isinstance(event, h11.EndOfMessage):
                    self._events.append(RejectData(data=b"", body_finished=True))
                    self._state = ConnectionState.CLOSED
            else:
                if isinstance(event, h11.Request):
                    self._events.append(self._process_connection_request(event))

    def events(self) -> Generator[Event, None, None]:
        """Return a generator that provides any events that have been generated
        by protocol activity.

        Events are removed from the internal queue as they are yielded.

        :returns: a generator that yields wsproto events.
        """
        while self._events:
            yield self._events.popleft()

    # Server mode methods

    def _process_connection_request(  # noqa: MC0001
        self, event: h11.Request
    ) -> Request:
        """Validate a client's opening handshake and build a Request event.

        Repeated ``Connection``, ``Sec-WebSocket-Extensions`` and
        ``Sec-WebSocket-Protocol`` header lines are merged, since a
        comma-separated header may be split across several lines.

        :raises RemoteProtocolError: if the request is not a valid WebSocket
            upgrade request.
        """
        if event.method != b"GET":
            raise RemoteProtocolError(
                "Request method must be GET", event_hint=RejectConnection()
            )
        connection_tokens: Optional[List[str]] = None
        extensions: List[str] = []
        host = None
        key = None
        subprotocols: List[str] = []
        upgrade = b""
        version = None
        headers: Headers = []
        for name, value in event.headers:
            name = name.lower()
            if name == b"connection":
                if connection_tokens is None:
                    connection_tokens = []
                connection_tokens.extend(split_comma_header(value))
            elif name == b"host":
                try:
                    host = value.decode("ascii")
                except UnicodeDecodeError as exc:
                    # h11 admits obs-text bytes in header values; report a
                    # protocol error rather than leaking a decode error.
                    raise RemoteProtocolError(
                        "Invalid header, 'Host'", event_hint=RejectConnection()
                    ) from exc
                continue  # Skip appending to headers
            elif name == b"sec-websocket-extensions":
                extensions.extend(split_comma_header(value))
                continue  # Skip appending to headers
            elif name == b"sec-websocket-key":
                key = value
            elif name == b"sec-websocket-protocol":
                subprotocols.extend(split_comma_header(value))
                continue  # Skip appending to headers
            elif name == b"sec-websocket-version":
                version = value
            elif name == b"upgrade":
                upgrade = value
            headers.append((name, value))
        if connection_tokens is None or not any(
            token.lower() == "upgrade" for token in connection_tokens
        ):
            raise RemoteProtocolError(
                "Missing header, 'Connection: Upgrade'", event_hint=RejectConnection()
            )
        if version != WEBSOCKET_VERSION:
            raise RemoteProtocolError(
                "Missing header, 'Sec-WebSocket-Version'",
                event_hint=RejectConnection(
                    headers=[(b"Sec-WebSocket-Version", WEBSOCKET_VERSION)],
                    status_code=426 if version else 400,
                ),
            )
        if key is None:
            raise RemoteProtocolError(
                "Missing header, 'Sec-WebSocket-Key'", event_hint=RejectConnection()
            )
        if upgrade.lower() != b"websocket":
            raise RemoteProtocolError(
                "Missing header, 'Upgrade: WebSocket'", event_hint=RejectConnection()
            )
        if host is None:
            raise RemoteProtocolError(
                "Missing header, 'Host'", event_hint=RejectConnection()
            )

        self._initiating_request = Request(
            extensions=extensions,
            extra_headers=headers,
            host=host,
            subprotocols=subprotocols,
            target=event.target.decode("ascii"),
        )
        return self._initiating_request

    def _accept(self, event: AcceptConnection) -> bytes:
        """Build the 101 response accepting the received request.

        :raises LocalProtocolError: if acting as the client, if no request
            has been received yet, if the handshake is no longer in the
            CONNECTING state, or if the subprotocol was not offered.
        """
        if self.client:
            raise LocalProtocolError(
                "Cannot accept a connection when acting as the client"
            )
        if self._initiating_request is None:
            raise LocalProtocolError(
                "Cannot accept a connection before a request has been received"
            )
        if self.state != ConnectionState.CONNECTING:
            raise LocalProtocolError(
                f"Connection cannot be accepted in state {self.state}"
            )
        request = self._initiating_request
        request_headers = normed_header_dict(request.extra_headers)

        # _process_connection_request kept Sec-WebSocket-Key in extra_headers
        # and refused requests without it.
        nonce = request_headers[b"sec-websocket-key"]
        accept_token = generate_accept_token(nonce)

        headers = [
            (b"Upgrade", b"WebSocket"),
            (b"Connection", b"Upgrade"),
            (b"Sec-WebSocket-Accept", accept_token),
        ]

        if event.subprotocol is not None:
            if event.subprotocol not in request.subprotocols:
                raise LocalProtocolError(f"unexpected subprotocol {event.subprotocol}")
            headers.append(
                (b"Sec-WebSocket-Protocol", event.subprotocol.encode("ascii"))
            )

        if event.extensions:
            accepts = server_extensions_handshake(
                cast(Sequence[str], request.extensions),
                event.extensions,
            )
            if accepts:
                headers.append((b"Sec-WebSocket-Extensions", accepts))

        response = h11.InformationalResponse(
            status_code=101, headers=headers + event.extra_headers
        )
        self._connection = Connection(ConnectionType.SERVER, event.extensions)
        self._state = ConnectionState.OPEN
        return self._h11_connection.send(response) or b""

    def _reject(self, event: RejectConnection) -> bytes:
        """Build an HTTP rejection response.

        Without a body, the response is complete (Content-Length: 0) and the
        state moves to CLOSED. Otherwise it moves to REJECTING and the body
        follows as RejectData events.
        """
        if self.state != ConnectionState.CONNECTING:
            raise LocalProtocolError(
                f"Connection cannot be rejected in state {self.state}"
            )

        headers = list(event.headers)
        if not event.has_body:
            headers.append((b"content-length", b"0"))
        response = h11.Response(status_code=event.status_code, headers=headers)
        data = self._h11_connection.send(response) or b""
        self._state = ConnectionState.REJECTING
        if not event.has_body:
            data += self._h11_connection.send(h11.EndOfMessage()) or b""
            self._state = ConnectionState.CLOSED
        return data

    def _send_reject_data(self, event: RejectData) -> bytes:
        """Send part of a rejection body; finish the message if requested."""
        if self.state != ConnectionState.REJECTING:
            raise LocalProtocolError(
                f"Cannot send rejection data in state {self.state}"
            )

        data = self._h11_connection.send(h11.Data(data=event.data)) or b""
        if event.body_finished:
            data += self._h11_connection.send(h11.EndOfMessage()) or b""
            self._state = ConnectionState.CLOSED
        return data

    # Client mode methods

    def _initiate_connection(self, request: Request) -> bytes:
        """Build the client's opening handshake for ``request``.

        :raises LocalProtocolError: if acting as the server or if a request
            has already been sent.
        """
        # Validate before touching _initiating_request/_nonce, which _accept
        # and _establish_client_connection rely on.
        if not self.client:
            raise LocalProtocolError(
                "Cannot initiate a connection when acting as the server"
            )
        if self._initiating_request is not None:
            raise LocalProtocolError("Connection has already been initiated")

        nonce = generate_nonce()
        headers = [
            (b"Host", request.host.encode("ascii")),
            (b"Upgrade", b"WebSocket"),
            (b"Connection", b"Upgrade"),
            (b"Sec-WebSocket-Key", nonce),
            (b"Sec-WebSocket-Version", WEBSOCKET_VERSION),
        ]

        if request.subprotocols:
            headers.append(
                (
                    b"Sec-WebSocket-Protocol",
                    (", ".join(request.subprotocols)).encode("ascii"),
                )
            )

        if request.extensions:
            offers: Dict[str, Union[str, bool]] = {}
            for e in request.extensions:
                assert isinstance(e, Extension)
                offers[e.name] = e.offer()
            extensions = []
            for name, params in offers.items():
                bname = name.encode("ascii")
                if isinstance(params, bool):
                    if params:
                        extensions.append(bname)
                elif params:
                    extensions.append(b"%s; %s" % (bname, params.encode("ascii")))
                else:
                    # An empty parameter string means "offer with no
                    # parameters": emit the bare name, not a dangling "; ".
                    extensions.append(bname)
            if extensions:
                headers.append((b"Sec-WebSocket-Extensions", b", ".join(extensions)))

        upgrade = h11.Request(
            method=b"GET",
            target=request.target.encode("ascii"),
            headers=headers + request.extra_headers,
        )
        data = self._h11_connection.send(upgrade) or b""
        self._initiating_request = request
        self._nonce = nonce
        return data

    def _establish_client_connection(
        self, event: h11.InformationalResponse
    ) -> AcceptConnection:  # noqa: MC0001
        """Validate the server's 101 response and open the connection.

        :raises RemoteProtocolError: if the response is not a valid
            acceptance of the request we sent.
        """
        # _establish_client_connection is always called after _initiate_connection.
        assert self._initiating_request is not None
        assert self._nonce is not None

        accept = None
        connection_tokens: Optional[List[str]] = None
        accepts: List[str] = []
        subprotocol = None
        upgrade = b""
        headers: Headers = []
        for name, value in event.headers:
            name = name.lower()
            if name == b"connection":
                if connection_tokens is None:
                    connection_tokens = []
                connection_tokens.extend(split_comma_header(value))
                continue  # Skip appending to headers
            elif name == b"sec-websocket-extensions":
                accepts.extend(split_comma_header(value))
                continue  # Skip appending to headers
            elif name == b"sec-websocket-accept":
                accept = value
                continue  # Skip appending to headers
            elif name == b"sec-websocket-protocol":
                try:
                    subprotocol = value.decode("ascii")
                except UnicodeDecodeError as exc:
                    raise RemoteProtocolError(
                        "Invalid header, 'Sec-WebSocket-Protocol'",
                        event_hint=RejectConnection(),
                    ) from exc
                continue  # Skip appending to headers
            elif name == b"upgrade":
                upgrade = value
                continue  # Skip appending to headers
            headers.append((name, value))

        if connection_tokens is None or not any(
            token.lower() == "upgrade" for token in connection_tokens
        ):
            raise RemoteProtocolError(
                "Missing header, 'Connection: Upgrade'", event_hint=RejectConnection()
            )
        if upgrade.lower() != b"websocket":
            raise RemoteProtocolError(
                "Missing header, 'Upgrade: WebSocket'", event_hint=RejectConnection()
            )
        accept_token = generate_accept_token(self._nonce)
        if accept != accept_token:
            raise RemoteProtocolError("Bad accept token", event_hint=RejectConnection())
        if subprotocol is not None:
            if subprotocol not in self._initiating_request.subprotocols:
                raise RemoteProtocolError(
                    f"unrecognized subprotocol {subprotocol}",
                    event_hint=RejectConnection(),
                )
        extensions = client_extensions_handshake(
            accepts, cast(Sequence[Extension], self._initiating_request.extensions)
        )

        # Only reached from the client branch of receive_data. Bytes h11
        # received after the 101 response already belong to the WebSocket
        # stream, so hand them to the new connection.
        self._connection = Connection(
            ConnectionType.CLIENT,
            extensions,
            self._h11_connection.trailing_data[0],
        )
        self._state = ConnectionState.OPEN
        return AcceptConnection(
            extensions=extensions, extra_headers=headers, subprotocol=subprotocol
        )

    def __repr__(self) -> str:
        return "{}(client={}, state={})".format(
            self.__class__.__name__, self.client, self.state
        )
|
|
|
|
|
|
def server_extensions_handshake(
    requested: Iterable[str], supported: "List[Extension]"
) -> Optional[bytes]:
    """Agree on the extensions to use returning an appropriate header value.

    ``requested`` holds the client's offers in the order they were sent; a
    client may send several (fallback) offers for the same extension. Each
    supported extension is asked, via ``Extension.accept``, about the offers
    naming it in order. The first offer it accepts wins, and later offers
    for that extension are not passed to ``accept`` again.

    This returns None if there are no agreed extensions.

    :param requested: Individual offers, e.g. ``"permessage-deflate; x=1"``.
    :param supported: Extensions the server is willing to use.
    :returns: The ``Sec-WebSocket-Extensions`` response value, or ``None``.
    """
    accepts: Dict[str, Union[bool, bytes]] = {}
    for offer in requested:
        name = offer.split(";", 1)[0].strip()
        if name in accepts:
            # Already agreed on an earlier (preferred) offer; calling accept()
            # again would overwrite both the extension's negotiated state and
            # the parameters sent back to the client.
            continue
        for extension in supported:
            if extension.name == name:
                accept = extension.accept(offer)
                if isinstance(accept, bool):
                    if accept:
                        accepts[extension.name] = True
                elif accept is not None:
                    accepts[extension.name] = accept.encode("ascii")

    if not accepts:
        return None

    extensions: List[bytes] = []
    for ext_name, params in accepts.items():
        name_bytes = ext_name.encode("ascii")
        if isinstance(params, bool):
            assert params
            extensions.append(name_bytes)
        elif params == b"":
            extensions.append(name_bytes)
        else:
            extensions.append(b"%s; %s" % (name_bytes, params))
    return b", ".join(extensions)
|
|
|
|
|
|
def client_extensions_handshake(
    accepted: Iterable[str], supported: "Sequence[Extension]"
) -> "List[Extension]":
    """Finalize the extensions the server accepted and return them in order.

    :param accepted: Individual entries of the server's
        ``Sec-WebSocket-Extensions`` response, e.g. ``"permessage-deflate; x=1"``.
    :param supported: The extensions the client offered.
    :returns: The matching extensions, each finalized with the server's
        parameters, in the order the server listed them.
    :raises RemoteProtocolError: if the server accepted an extension that is
        not supported, or named the same extension more than once.
    """
    extensions = []
    seen_names = set()
    for accept in accepted:
        name = accept.split(";", 1)[0].strip()
        if name in seen_names:
            # Finalizing twice and listing the same extension object twice
            # would run it twice over every frame.
            raise RemoteProtocolError(
                f"duplicate extension {name}", event_hint=RejectConnection()
            )
        for extension in supported:
            if extension.name == name:
                extension.finalize(accept)
                extensions.append(extension)
                seen_names.add(name)
                break
        else:
            raise RemoteProtocolError(
                f"unrecognized extension {name}", event_hint=RejectConnection()
            )
    return extensions
|