291 lines
9.4 KiB
291 lines
9.4 KiB
import json
|
|
import typing
|
|
from collections.abc import Mapping
|
|
from http import cookies as http_cookies
|
|
|
|
import anyio
|
|
|
|
from starlette.datastructures import URL, Address, FormData, Headers, QueryParams, State
|
|
from starlette.formparsers import FormParser, MultiPartParser
|
|
from starlette.types import Message, Receive, Scope, Send
|
|
|
|
try:
|
|
from multipart.multipart import parse_options_header
|
|
except ImportError: # pragma: nocover
|
|
parse_options_header = None
|
|
|
|
|
|
if typing.TYPE_CHECKING:
|
|
from starlette.routing import Router
|
|
|
|
|
|
SERVER_PUSH_HEADERS_TO_COPY = {
|
|
"accept",
|
|
"accept-encoding",
|
|
"accept-language",
|
|
"cache-control",
|
|
"user-agent",
|
|
}
|
|
|
|
|
|
def cookie_parser(cookie_string: str) -> typing.Dict[str, str]:
|
|
"""
|
|
This function parses a ``Cookie`` HTTP header into a dict of key/value pairs.
|
|
|
|
It attempts to mimic browser cookie parsing behavior: browsers and web servers
|
|
frequently disregard the spec (RFC 6265) when setting and reading cookies,
|
|
so we attempt to suit the common scenarios here.
|
|
|
|
This function has been adapted from Django 3.1.0.
|
|
Note: we are explicitly _NOT_ using `SimpleCookie.load` because it is based
|
|
on an outdated spec and will fail on lots of input we want to support
|
|
"""
|
|
cookie_dict: typing.Dict[str, str] = {}
|
|
for chunk in cookie_string.split(";"):
|
|
if "=" in chunk:
|
|
key, val = chunk.split("=", 1)
|
|
else:
|
|
# Assume an empty name per
|
|
# https://bugzilla.mozilla.org/show_bug.cgi?id=169091
|
|
key, val = "", chunk
|
|
key, val = key.strip(), val.strip()
|
|
if key or val:
|
|
# unquote using Python's algorithm.
|
|
cookie_dict[key] = http_cookies._unquote(val)
|
|
return cookie_dict
|
|
|
|
|
|
class ClientDisconnect(Exception):
|
|
pass
|
|
|
|
|
|
class HTTPConnection(Mapping):
|
|
"""
|
|
A base class for incoming HTTP connections, that is used to provide
|
|
any functionality that is common to both `Request` and `WebSocket`.
|
|
"""
|
|
|
|
def __init__(self, scope: Scope, receive: typing.Optional[Receive] = None) -> None:
|
|
assert scope["type"] in ("http", "websocket")
|
|
self.scope = scope
|
|
|
|
def __getitem__(self, key: str) -> typing.Any:
|
|
return self.scope[key]
|
|
|
|
def __iter__(self) -> typing.Iterator[str]:
|
|
return iter(self.scope)
|
|
|
|
def __len__(self) -> int:
|
|
return len(self.scope)
|
|
|
|
# Don't use the `abc.Mapping.__eq__` implementation.
|
|
# Connection instances should never be considered equal
|
|
# unless `self is other`.
|
|
__eq__ = object.__eq__
|
|
__hash__ = object.__hash__
|
|
|
|
@property
|
|
def app(self) -> typing.Any:
|
|
return self.scope["app"]
|
|
|
|
@property
|
|
def url(self) -> URL:
|
|
if not hasattr(self, "_url"):
|
|
self._url = URL(scope=self.scope)
|
|
return self._url
|
|
|
|
@property
|
|
def base_url(self) -> URL:
|
|
if not hasattr(self, "_base_url"):
|
|
base_url_scope = dict(self.scope)
|
|
base_url_scope["path"] = "/"
|
|
base_url_scope["query_string"] = b""
|
|
base_url_scope["root_path"] = base_url_scope.get(
|
|
"app_root_path", base_url_scope.get("root_path", "")
|
|
)
|
|
self._base_url = URL(scope=base_url_scope)
|
|
return self._base_url
|
|
|
|
@property
|
|
def headers(self) -> Headers:
|
|
if not hasattr(self, "_headers"):
|
|
self._headers = Headers(scope=self.scope)
|
|
return self._headers
|
|
|
|
@property
|
|
def query_params(self) -> QueryParams:
|
|
if not hasattr(self, "_query_params"):
|
|
self._query_params = QueryParams(self.scope["query_string"])
|
|
return self._query_params
|
|
|
|
@property
|
|
def path_params(self) -> typing.Dict[str, typing.Any]:
|
|
return self.scope.get("path_params", {})
|
|
|
|
@property
|
|
def cookies(self) -> typing.Dict[str, str]:
|
|
if not hasattr(self, "_cookies"):
|
|
cookies: typing.Dict[str, str] = {}
|
|
cookie_header = self.headers.get("cookie")
|
|
|
|
if cookie_header:
|
|
cookies = cookie_parser(cookie_header)
|
|
self._cookies = cookies
|
|
return self._cookies
|
|
|
|
@property
|
|
def client(self) -> typing.Optional[Address]:
|
|
# client is a 2 item tuple of (host, port), None or missing
|
|
host_port = self.scope.get("client")
|
|
if host_port is not None:
|
|
return Address(*host_port)
|
|
return None
|
|
|
|
@property
|
|
def session(self) -> dict:
|
|
assert (
|
|
"session" in self.scope
|
|
), "SessionMiddleware must be installed to access request.session"
|
|
return self.scope["session"]
|
|
|
|
@property
|
|
def auth(self) -> typing.Any:
|
|
assert (
|
|
"auth" in self.scope
|
|
), "AuthenticationMiddleware must be installed to access request.auth"
|
|
return self.scope["auth"]
|
|
|
|
@property
|
|
def user(self) -> typing.Any:
|
|
assert (
|
|
"user" in self.scope
|
|
), "AuthenticationMiddleware must be installed to access request.user"
|
|
return self.scope["user"]
|
|
|
|
@property
|
|
def state(self) -> State:
|
|
if not hasattr(self, "_state"):
|
|
# Ensure 'state' has an empty dict if it's not already populated.
|
|
self.scope.setdefault("state", {})
|
|
# Create a state instance with a reference to the dict in which it should
|
|
# store info
|
|
self._state = State(self.scope["state"])
|
|
return self._state
|
|
|
|
def url_for(self, name: str, **path_params: typing.Any) -> str:
|
|
router: Router = self.scope["router"]
|
|
url_path = router.url_path_for(name, **path_params)
|
|
return url_path.make_absolute_url(base_url=self.base_url)
|
|
|
|
|
|
async def empty_receive() -> typing.NoReturn:
|
|
raise RuntimeError("Receive channel has not been made available")
|
|
|
|
|
|
async def empty_send(message: Message) -> typing.NoReturn:
|
|
raise RuntimeError("Send channel has not been made available")
|
|
|
|
|
|
class Request(HTTPConnection):
|
|
def __init__(
|
|
self, scope: Scope, receive: Receive = empty_receive, send: Send = empty_send
|
|
):
|
|
super().__init__(scope)
|
|
assert scope["type"] == "http"
|
|
self._receive = receive
|
|
self._send = send
|
|
self._stream_consumed = False
|
|
self._is_disconnected = False
|
|
|
|
@property
|
|
def method(self) -> str:
|
|
return self.scope["method"]
|
|
|
|
@property
|
|
def receive(self) -> Receive:
|
|
return self._receive
|
|
|
|
async def stream(self) -> typing.AsyncGenerator[bytes, None]:
|
|
if hasattr(self, "_body"):
|
|
yield self._body
|
|
yield b""
|
|
return
|
|
|
|
if self._stream_consumed:
|
|
raise RuntimeError("Stream consumed")
|
|
|
|
self._stream_consumed = True
|
|
while True:
|
|
message = await self._receive()
|
|
if message["type"] == "http.request":
|
|
body = message.get("body", b"")
|
|
if body:
|
|
yield body
|
|
if not message.get("more_body", False):
|
|
break
|
|
elif message["type"] == "http.disconnect":
|
|
self._is_disconnected = True
|
|
raise ClientDisconnect()
|
|
yield b""
|
|
|
|
async def body(self) -> bytes:
|
|
if not hasattr(self, "_body"):
|
|
chunks = []
|
|
async for chunk in self.stream():
|
|
chunks.append(chunk)
|
|
self._body = b"".join(chunks)
|
|
return self._body
|
|
|
|
async def json(self) -> typing.Any:
|
|
if not hasattr(self, "_json"):
|
|
body = await self.body()
|
|
self._json = json.loads(body)
|
|
return self._json
|
|
|
|
async def form(self) -> FormData:
|
|
if not hasattr(self, "_form"):
|
|
assert (
|
|
parse_options_header is not None
|
|
), "The `python-multipart` library must be installed to use form parsing."
|
|
content_type_header = self.headers.get("Content-Type")
|
|
content_type, options = parse_options_header(content_type_header)
|
|
if content_type == b"multipart/form-data":
|
|
multipart_parser = MultiPartParser(self.headers, self.stream())
|
|
self._form = await multipart_parser.parse()
|
|
elif content_type == b"application/x-www-form-urlencoded":
|
|
form_parser = FormParser(self.headers, self.stream())
|
|
self._form = await form_parser.parse()
|
|
else:
|
|
self._form = FormData()
|
|
return self._form
|
|
|
|
async def close(self) -> None:
|
|
if hasattr(self, "_form"):
|
|
await self._form.close()
|
|
|
|
async def is_disconnected(self) -> bool:
|
|
if not self._is_disconnected:
|
|
message: Message = {}
|
|
|
|
# If message isn't immediately available, move on
|
|
with anyio.CancelScope() as cs:
|
|
cs.cancel()
|
|
message = await self._receive()
|
|
|
|
if message.get("type") == "http.disconnect":
|
|
self._is_disconnected = True
|
|
|
|
return self._is_disconnected
|
|
|
|
async def send_push_promise(self, path: str) -> None:
|
|
if "http.response.push" in self.scope.get("extensions", {}):
|
|
raw_headers = []
|
|
for name in SERVER_PUSH_HEADERS_TO_COPY:
|
|
for value in self.headers.getlist(name):
|
|
raw_headers.append(
|
|
(name.encode("latin-1"), value.encode("latin-1"))
|
|
)
|
|
await self._send(
|
|
{"type": "http.response.push", "path": path, "headers": raw_headers}
|
|
)
|