490 lines
18 KiB

"""
wsproto/handshake
~~~~~~~~~~~~~~~~~~
An implementation of WebSocket handshakes.
"""
from collections import deque
from typing import (
cast,
Deque,
Dict,
Generator,
Iterable,
List,
Optional,
Sequence,
Union,
)
import h11
from .connection import Connection, ConnectionState, ConnectionType
from .events import AcceptConnection, Event, RejectConnection, RejectData, Request
from .extensions import Extension
from .typing import Headers
from .utilities import (
generate_accept_token,
generate_nonce,
LocalProtocolError,
normed_header_dict,
RemoteProtocolError,
split_comma_header,
)
# RFC6455, Section 4.2.1/6 - Reading the Client's Opening Handshake
WEBSOCKET_VERSION = b"13"
class H11Handshake:
"""A Handshake implementation for HTTP/1.1 connections."""
def __init__(self, connection_type: ConnectionType) -> None:
self.client = connection_type is ConnectionType.CLIENT
self._state = ConnectionState.CONNECTING
if self.client:
self._h11_connection = h11.Connection(h11.CLIENT)
else:
self._h11_connection = h11.Connection(h11.SERVER)
self._connection: Optional[Connection] = None
self._events: Deque[Event] = deque()
self._initiating_request: Optional[Request] = None
self._nonce: Optional[bytes] = None
@property
def state(self) -> ConnectionState:
return self._state
@property
def connection(self) -> Optional[Connection]:
"""Return the established connection.
This will either return the connection or raise a
LocalProtocolError if the connection has not yet been
established.
:rtype: h11.Connection
"""
return self._connection
def initiate_upgrade_connection(self, headers: Headers, path: str) -> None:
"""Initiate an upgrade connection.
This should be used if the request has already be received and
parsed.
:param list headers: HTTP headers represented as a list of 2-tuples.
:param str path: A URL path.
"""
if self.client:
raise LocalProtocolError(
"Cannot initiate an upgrade connection when acting as the client"
)
upgrade_request = h11.Request(method=b"GET", target=path, headers=headers)
h11_client = h11.Connection(h11.CLIENT)
self.receive_data(h11_client.send(upgrade_request))
def send(self, event: Event) -> bytes:
"""Send an event to the remote.
This will return the bytes to send based on the event or raise
a LocalProtocolError if the event is not valid given the
state.
:returns: Data to send to the WebSocket peer.
:rtype: bytes
"""
data = b""
if isinstance(event, Request):
data += self._initiate_connection(event)
elif isinstance(event, AcceptConnection):
data += self._accept(event)
elif isinstance(event, RejectConnection):
data += self._reject(event)
elif isinstance(event, RejectData):
data += self._send_reject_data(event)
else:
raise LocalProtocolError(
f"Event {event} cannot be sent during the handshake"
)
return data
def receive_data(self, data: Optional[bytes]) -> None:
"""Receive data from the remote.
A list of events that the remote peer triggered by sending
this data can be retrieved with :meth:`events`.
:param bytes data: Data received from the WebSocket peer.
"""
self._h11_connection.receive_data(data or b"")
while True:
try:
event = self._h11_connection.next_event()
except h11.RemoteProtocolError:
raise RemoteProtocolError(
"Bad HTTP message", event_hint=RejectConnection()
)
if (
isinstance(event, h11.ConnectionClosed)
or event is h11.NEED_DATA
or event is h11.PAUSED
):
break
if self.client:
if isinstance(event, h11.InformationalResponse):
if event.status_code == 101:
self._events.append(self._establish_client_connection(event))
else:
self._events.append(
RejectConnection(
headers=list(event.headers),
status_code=event.status_code,
has_body=False,
)
)
self._state = ConnectionState.CLOSED
elif isinstance(event, h11.Response):
self._state = ConnectionState.REJECTING
self._events.append(
RejectConnection(
headers=list(event.headers),
status_code=event.status_code,
has_body=True,
)
)
elif isinstance(event, h11.Data):
self._events.append(
RejectData(data=event.data, body_finished=False)
)
elif isinstance(event, h11.EndOfMessage):
self._events.append(RejectData(data=b"", body_finished=True))
self._state = ConnectionState.CLOSED
else:
if isinstance(event, h11.Request):
self._events.append(self._process_connection_request(event))
def events(self) -> Generator[Event, None, None]:
"""Return a generator that provides any events that have been generated
by protocol activity.
:returns: a generator that yields H11 events.
"""
while self._events:
yield self._events.popleft()
# Server mode methods
def _process_connection_request( # noqa: MC0001
self, event: h11.Request
) -> Request:
if event.method != b"GET":
raise RemoteProtocolError(
"Request method must be GET", event_hint=RejectConnection()
)
connection_tokens = None
extensions: List[str] = []
host = None
key = None
subprotocols: List[str] = []
upgrade = b""
version = None
headers: Headers = []
for name, value in event.headers:
name = name.lower()
if name == b"connection":
connection_tokens = split_comma_header(value)
elif name == b"host":
host = value.decode("ascii")
continue # Skip appending to headers
elif name == b"sec-websocket-extensions":
extensions = split_comma_header(value)
continue # Skip appending to headers
elif name == b"sec-websocket-key":
key = value
elif name == b"sec-websocket-protocol":
subprotocols = split_comma_header(value)
continue # Skip appending to headers
elif name == b"sec-websocket-version":
version = value
elif name == b"upgrade":
upgrade = value
headers.append((name, value))
if connection_tokens is None or not any(
token.lower() == "upgrade" for token in connection_tokens
):
raise RemoteProtocolError(
"Missing header, 'Connection: Upgrade'", event_hint=RejectConnection()
)
if version != WEBSOCKET_VERSION:
raise RemoteProtocolError(
"Missing header, 'Sec-WebSocket-Version'",
event_hint=RejectConnection(
headers=[(b"Sec-WebSocket-Version", WEBSOCKET_VERSION)],
status_code=426 if version else 400,
),
)
if key is None:
raise RemoteProtocolError(
"Missing header, 'Sec-WebSocket-Key'", event_hint=RejectConnection()
)
if upgrade.lower() != b"websocket":
raise RemoteProtocolError(
"Missing header, 'Upgrade: WebSocket'", event_hint=RejectConnection()
)
if host is None:
raise RemoteProtocolError(
"Missing header, 'Host'", event_hint=RejectConnection()
)
self._initiating_request = Request(
extensions=extensions,
extra_headers=headers,
host=host,
subprotocols=subprotocols,
target=event.target.decode("ascii"),
)
return self._initiating_request
def _accept(self, event: AcceptConnection) -> bytes:
# _accept is always called after _process_connection_request.
assert self._initiating_request is not None
request_headers = normed_header_dict(self._initiating_request.extra_headers)
nonce = request_headers[b"sec-websocket-key"]
accept_token = generate_accept_token(nonce)
headers = [
(b"Upgrade", b"WebSocket"),
(b"Connection", b"Upgrade"),
(b"Sec-WebSocket-Accept", accept_token),
]
if event.subprotocol is not None:
if event.subprotocol not in self._initiating_request.subprotocols:
raise LocalProtocolError(f"unexpected subprotocol {event.subprotocol}")
headers.append(
(b"Sec-WebSocket-Protocol", event.subprotocol.encode("ascii"))
)
if event.extensions:
accepts = server_extensions_handshake(
cast(Sequence[str], self._initiating_request.extensions),
event.extensions,
)
if accepts:
headers.append((b"Sec-WebSocket-Extensions", accepts))
response = h11.InformationalResponse(
status_code=101, headers=headers + event.extra_headers
)
self._connection = Connection(
ConnectionType.CLIENT if self.client else ConnectionType.SERVER,
event.extensions,
)
self._state = ConnectionState.OPEN
return self._h11_connection.send(response) or b""
def _reject(self, event: RejectConnection) -> bytes:
if self.state != ConnectionState.CONNECTING:
raise LocalProtocolError(
"Connection cannot be rejected in state %s" % self.state
)
headers = list(event.headers)
if not event.has_body:
headers.append((b"content-length", b"0"))
response = h11.Response(status_code=event.status_code, headers=headers)
data = self._h11_connection.send(response) or b""
self._state = ConnectionState.REJECTING
if not event.has_body:
data += self._h11_connection.send(h11.EndOfMessage()) or b""
self._state = ConnectionState.CLOSED
return data
def _send_reject_data(self, event: RejectData) -> bytes:
if self.state != ConnectionState.REJECTING:
raise LocalProtocolError(
f"Cannot send rejection data in state {self.state}"
)
data = self._h11_connection.send(h11.Data(data=event.data)) or b""
if event.body_finished:
data += self._h11_connection.send(h11.EndOfMessage()) or b""
self._state = ConnectionState.CLOSED
return data
# Client mode methods
def _initiate_connection(self, request: Request) -> bytes:
self._initiating_request = request
self._nonce = generate_nonce()
headers = [
(b"Host", request.host.encode("ascii")),
(b"Upgrade", b"WebSocket"),
(b"Connection", b"Upgrade"),
(b"Sec-WebSocket-Key", self._nonce),
(b"Sec-WebSocket-Version", WEBSOCKET_VERSION),
]
if request.subprotocols:
headers.append(
(
b"Sec-WebSocket-Protocol",
(", ".join(request.subprotocols)).encode("ascii"),
)
)
if request.extensions:
offers: Dict[str, Union[str, bool]] = {}
for e in request.extensions:
assert isinstance(e, Extension)
offers[e.name] = e.offer()
extensions = []
for name, params in offers.items():
bname = name.encode("ascii")
if isinstance(params, bool):
if params:
extensions.append(bname)
else:
extensions.append(b"%s; %s" % (bname, params.encode("ascii")))
if extensions:
headers.append((b"Sec-WebSocket-Extensions", b", ".join(extensions)))
upgrade = h11.Request(
method=b"GET",
target=request.target.encode("ascii"),
headers=headers + request.extra_headers,
)
return self._h11_connection.send(upgrade) or b""
def _establish_client_connection(
self, event: h11.InformationalResponse
) -> AcceptConnection: # noqa: MC0001
# _establish_client_connection is always called after _initiate_connection.
assert self._initiating_request is not None
assert self._nonce is not None
accept = None
connection_tokens = None
accepts: List[str] = []
subprotocol = None
upgrade = b""
headers: Headers = []
for name, value in event.headers:
name = name.lower()
if name == b"connection":
connection_tokens = split_comma_header(value)
continue # Skip appending to headers
elif name == b"sec-websocket-extensions":
accepts = split_comma_header(value)
continue # Skip appending to headers
elif name == b"sec-websocket-accept":
accept = value
continue # Skip appending to headers
elif name == b"sec-websocket-protocol":
subprotocol = value.decode("ascii")
continue # Skip appending to headers
elif name == b"upgrade":
upgrade = value
continue # Skip appending to headers
headers.append((name, value))
if connection_tokens is None or not any(
token.lower() == "upgrade" for token in connection_tokens
):
raise RemoteProtocolError(
"Missing header, 'Connection: Upgrade'", event_hint=RejectConnection()
)
if upgrade.lower() != b"websocket":
raise RemoteProtocolError(
"Missing header, 'Upgrade: WebSocket'", event_hint=RejectConnection()
)
accept_token = generate_accept_token(self._nonce)
if accept != accept_token:
raise RemoteProtocolError("Bad accept token", event_hint=RejectConnection())
if subprotocol is not None:
if subprotocol not in self._initiating_request.subprotocols:
raise RemoteProtocolError(
f"unrecognized subprotocol {subprotocol}",
event_hint=RejectConnection(),
)
extensions = client_extensions_handshake(
accepts, cast(Sequence[Extension], self._initiating_request.extensions)
)
self._connection = Connection(
ConnectionType.CLIENT if self.client else ConnectionType.SERVER,
extensions,
self._h11_connection.trailing_data[0],
)
self._state = ConnectionState.OPEN
return AcceptConnection(
extensions=extensions, extra_headers=headers, subprotocol=subprotocol
)
def __repr__(self) -> str:
return "{}(client={}, state={})".format(
self.__class__.__name__, self.client, self.state
)
def server_extensions_handshake(
requested: Iterable[str], supported: List[Extension]
) -> Optional[bytes]:
"""Agree on the extensions to use returning an appropriate header value.
This returns None if there are no agreed extensions
"""
accepts: Dict[str, Union[bool, bytes]] = {}
for offer in requested:
name = offer.split(";", 1)[0].strip()
for extension in supported:
if extension.name == name:
accept = extension.accept(offer)
if isinstance(accept, bool):
if accept:
accepts[extension.name] = True
elif accept is not None:
accepts[extension.name] = accept.encode("ascii")
if accepts:
extensions: List[bytes] = []
for name, params in accepts.items():
name_bytes = name.encode("ascii")
if isinstance(params, bool):
assert params
extensions.append(name_bytes)
else:
if params == b"":
extensions.append(b"%s" % (name_bytes))
else:
extensions.append(b"%s; %s" % (name_bytes, params))
return b", ".join(extensions)
return None
def client_extensions_handshake(
accepted: Iterable[str], supported: Sequence[Extension]
) -> List[Extension]:
# This raises RemoteProtocolError is the accepted extension is not
# supported.
extensions = []
for accept in accepted:
name = accept.split(";", 1)[0].strip()
for extension in supported:
if extension.name == name:
extension.finalize(accept)
extensions.append(extension)
break
else:
raise RemoteProtocolError(
f"unrecognized extension {name}", event_hint=RejectConnection()
)
return extensions