You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
112 lines
3.8 KiB
112 lines
3.8 KiB
import asyncio
|
|
import http
|
|
import typing
|
|
|
|
from starlette.concurrency import run_in_threadpool
|
|
from starlette.requests import Request
|
|
from starlette.responses import PlainTextResponse, Response
|
|
from starlette.types import ASGIApp, Message, Receive, Scope, Send
|
|
|
|
|
|
class HTTPException(Exception):
    """Exception carrying an HTTP status code, a detail message and headers.

    ``ExceptionMiddleware`` in this module registers a default handler for
    this class that renders it as an HTTP response.

    Args:
        status_code: HTTP status code for the response.
        detail: Human-readable message. Defaults to the standard reason
            phrase for ``status_code`` (e.g. ``"Not Found"`` for 404).
        headers: Optional extra response headers.

    Raises:
        ValueError: If ``detail`` is omitted and ``status_code`` is not a
            status known to :class:`http.HTTPStatus`.
    """

    def __init__(
        self,
        status_code: int,
        detail: typing.Optional[str] = None,
        headers: typing.Optional[dict] = None,
    ) -> None:
        if detail is None:
            detail = http.HTTPStatus(status_code).phrase
        # Populate ``args`` explicitly. Otherwise ``BaseException`` records only
        # the *positional* constructor arguments, so an instance created with
        # keywords has empty ``args``: str() is "" and unpickling fails because
        # the class is re-instantiated without ``status_code``.
        super().__init__(status_code, detail)
        self.status_code = status_code
        self.detail = detail
        self.headers = headers

    def __str__(self) -> str:
        # Consistent, log-friendly text regardless of how the instance was built.
        return f"{self.status_code}: {self.detail}"

    def __repr__(self) -> str:
        class_name = self.__class__.__name__
        return f"{class_name}(status_code={self.status_code!r}, detail={self.detail!r})"
|
|
|
|
|
|
class ExceptionMiddleware:
    """ASGI middleware that converts exceptions raised by ``app`` into responses.

    Handlers are kept in two tables:

    * status-code handlers, consulted only for :class:`HTTPException`
      instances and keyed by ``exc.status_code``;
    * exception-class handlers, matched by walking the raised exception's
      MRO, so a handler registered for a base class also covers subclasses.

    :class:`HTTPException` is handled by :meth:`http_exception` unless a more
    specific handler is registered. Exceptions with no matching handler are
    re-raised unchanged. Scopes whose type is not ``"http"`` bypass the
    middleware entirely.
    """

    def __init__(
        self,
        app: ASGIApp,
        handlers: typing.Optional[
            typing.Mapping[typing.Any, typing.Callable[[Request, Exception], Response]]
        ] = None,
        debug: bool = False,
    ) -> None:
        """Wrap ``app`` and register any initial ``handlers``.

        Args:
            app: The wrapped ASGI application.
            handlers: Optional mapping from a status code (int) or exception
                class to a handler; each entry goes through
                :meth:`add_exception_handler`.
            debug: Stored on the instance; not read by any code in this class.
        """
        self.app = app
        self.debug = debug  # TODO: We ought to handle 404 cases if debug is set.
        self._status_handlers: typing.Dict[int, typing.Callable] = {}
        self._exception_handlers: typing.Dict[
            typing.Type[Exception], typing.Callable
        ] = {HTTPException: self.http_exception}
        if handlers is not None:
            for key, value in handlers.items():
                self.add_exception_handler(key, value)

    def add_exception_handler(
        self,
        exc_class_or_status_code: typing.Union[int, typing.Type[Exception]],
        handler: typing.Callable[[Request, Exception], Response],
    ) -> None:
        """Register ``handler`` for a status code or an exception class.

        An int key registers a status-code handler (used only for
        :class:`HTTPException`); otherwise the key must be an ``Exception``
        subclass. A later registration for the same key replaces the earlier one.

        Raises:
            AssertionError: If the key is a class that is not an ``Exception``
                subclass (only when assertions are enabled).
            TypeError: If the key is neither an int nor a class (from
                ``issubclass``).
        """
        if isinstance(exc_class_or_status_code, int):
            self._status_handlers[exc_class_or_status_code] = handler
        else:
            assert issubclass(exc_class_or_status_code, Exception)
            self._exception_handlers[exc_class_or_status_code] = handler

    def _lookup_exception_handler(
        self, exc: Exception
    ) -> typing.Optional[typing.Callable]:
        """Return the handler for the most specific class in ``exc``'s MRO, or None."""
        for cls in type(exc).__mro__:
            if cls in self._exception_handlers:
                return self._exception_handlers[cls]
        return None

    @staticmethod
    def _is_async_callable(obj: typing.Any) -> bool:
        """Return True if calling ``obj`` produces a coroutine.

        Unlike a bare ``asyncio.iscoroutinefunction`` check, this also
        recognises ``functools.partial`` wrappers around async functions and
        callable objects that define ``async def __call__``.
        """
        import functools

        while isinstance(obj, functools.partial):
            obj = obj.func
        return asyncio.iscoroutinefunction(obj) or (
            callable(obj) and asyncio.iscoroutinefunction(obj.__call__)
        )

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """ASGI entry point: run ``app`` and convert handled exceptions to responses.

        Raises:
            Exception: Any exception from ``app`` that has no registered handler,
                re-raised unchanged.
            RuntimeError: If a handler exists but the response had already
                started, so a replacement response cannot be sent.
        """
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        response_started = False

        async def sender(message: Message) -> None:
            # Track whether headers were already sent, so we know whether an
            # error response can still be emitted.
            nonlocal response_started
            if message["type"] == "http.response.start":
                response_started = True
            await send(message)

        try:
            await self.app(scope, receive, sender)
        except Exception as exc:
            handler = None

            if isinstance(exc, HTTPException):
                handler = self._status_handlers.get(exc.status_code)

            if handler is None:
                handler = self._lookup_exception_handler(exc)

            if handler is None:
                # Bare ``raise`` keeps the original traceback untouched.
                raise

            if response_started:
                msg = "Caught handled exception, but response already started."
                raise RuntimeError(msg) from exc

            request = Request(scope, receive=receive)
            if self._is_async_callable(handler):
                response = await handler(request, exc)
            else:
                # Sync handlers run off the event loop.
                response = await run_in_threadpool(handler, request, exc)
            await response(scope, receive, sender)

    def http_exception(self, request: Request, exc: HTTPException) -> Response:
        """Default handler for :class:`HTTPException`.

        For 204 and 304 (statuses that must not carry a body) this returns an
        empty ``Response``. Otherwise it returns a ``PlainTextResponse`` whose
        body is ``exc.detail``. Both use ``exc.status_code`` and ``exc.headers``.
        """
        if exc.status_code in {204, 304}:
            return Response(status_code=exc.status_code, headers=exc.headers)
        return PlainTextResponse(
            exc.detail, status_code=exc.status_code, headers=exc.headers
        )
|