import asyncio
import functools
import inspect
import typing
from urllib.parse import urlencode

from starlette.exceptions import HTTPException
from starlette.requests import HTTPConnection, Request
from starlette.responses import RedirectResponse, Response
from starlette.websockets import WebSocket


def has_required_scope(conn: HTTPConnection, scopes: typing.Sequence[str]) -> bool:
    """Return True when every entry of *scopes* is in ``conn.auth.scopes``.

    An empty *scopes* sequence is trivially satisfied; ``conn.auth`` is only
    read when there is at least one scope to check.
    """
    return all(scope in conn.auth.scopes for scope in scopes)


def _deny_request(
    request: Request, status_code: int, redirect: typing.Optional[str]
) -> Response:
    """Build the response for a request that lacks the required scopes.

    With *redirect* set, return a 303 redirect to ``request.url_for(redirect)``
    carrying the original URL in a ``next`` query parameter. Otherwise raise
    ``HTTPException`` with *status_code*.
    """
    if redirect is None:
        raise HTTPException(status_code=status_code)

    orig_request_qparam = urlencode({"next": str(request.url)})
    next_url = "{redirect_path}?{orig_request}".format(
        redirect_path=request.url_for(redirect),
        orig_request=orig_request_qparam,
    )
    return RedirectResponse(url=next_url, status_code=303)


def requires(
    scopes: typing.Union[str, typing.Sequence[str]],
    status_code: int = 403,
    redirect: typing.Optional[str] = None,
) -> typing.Callable:
    """Decorate an endpoint so it only runs when the connection has *scopes*.

    The decorated function must declare a parameter named ``request`` or
    ``websocket``; that parameter selects how a missing scope is handled:

    * ``websocket``: the websocket is closed and the function is not called.
    * ``request``: a 303 redirect to the route named *redirect* is returned
      when *redirect* is given, otherwise ``HTTPException(status_code)`` is
      raised.

    Args:
        scopes: A single scope name or a sequence of scope names, all of
            which must be present.
        status_code: Status used for the ``HTTPException`` raised when no
            *redirect* is configured.
        redirect: Name passed to ``request.url_for`` to build the redirect
            target.

    Raises:
        Exception: At decoration time, if the function has neither a
            ``request`` nor a ``websocket`` parameter.
    """
    scopes_list = [scopes] if isinstance(scopes, str) else list(scopes)

    def decorator(func: typing.Callable) -> typing.Callable:
        sig = inspect.signature(func)
        # Remember both the name and the position of the connection parameter
        # so it can be found whether it is passed by keyword or positionally.
        for idx, parameter in enumerate(sig.parameters.values()):
            if parameter.name in ("request", "websocket"):
                type_ = parameter.name
                break
        else:
            raise Exception(
                f'No "request" or "websocket" argument on function "{func}"'
            )

        def connection_from(
            args: typing.Tuple[typing.Any, ...], kwargs: typing.Dict[str, typing.Any]
        ) -> typing.Any:
            """Fetch the connection argument by keyword, else by position."""
            return kwargs.get(type_, args[idx] if idx < len(args) else None)

        if type_ == "websocket":
            # Handle websocket functions. (Always async)
            @functools.wraps(func)
            async def websocket_wrapper(
                *args: typing.Any, **kwargs: typing.Any
            ) -> None:
                websocket = connection_from(args, kwargs)
                assert isinstance(websocket, WebSocket)

                if not has_required_scope(websocket, scopes_list):
                    await websocket.close()
                else:
                    await func(*args, **kwargs)

            return websocket_wrapper

        elif asyncio.iscoroutinefunction(func):
            # Handle async request/response functions.
            @functools.wraps(func)
            async def async_wrapper(
                *args: typing.Any, **kwargs: typing.Any
            ) -> Response:
                request = connection_from(args, kwargs)
                assert isinstance(request, Request)

                if not has_required_scope(request, scopes_list):
                    return _deny_request(request, status_code, redirect)
                return await func(*args, **kwargs)

            return async_wrapper

        else:
            # Handle sync request/response functions.
            @functools.wraps(func)
            def sync_wrapper(*args: typing.Any, **kwargs: typing.Any) -> Response:
                request = connection_from(args, kwargs)
                assert isinstance(request, Request)

                if not has_required_scope(request, scopes_list):
                    return _deny_request(request, status_code, redirect)
                return func(*args, **kwargs)

            return sync_wrapper

    return decorator


class AuthenticationError(Exception):
    """Exception type for authentication failures."""

    pass


class AuthenticationBackend:
    """Base class for authentication backends; override ``authenticate``."""

    async def authenticate(
        self, conn: HTTPConnection
    ) -> typing.Optional[typing.Tuple["AuthCredentials", "BaseUser"]]:
        """Authenticate *conn*; the base implementation is not implemented.

        Overrides return either a ``(credentials, user)`` pair or ``None``,
        per the declared return type.
        """
        raise NotImplementedError()  # pragma: no cover


class AuthCredentials:
    """Holds the list of scopes granted to a connection."""

    def __init__(self, scopes: typing.Optional[typing.Sequence[str]] = None):
        """Store a list copy of *scopes*, or an empty list when it is None."""
        self.scopes = [] if scopes is None else list(scopes)


class BaseUser:
    """Interface for user objects; every property must be overridden."""

    @property
    def is_authenticated(self) -> bool:
        """Whether the user is authenticated."""
        raise NotImplementedError()  # pragma: no cover

    @property
    def display_name(self) -> str:
        """Name used to display the user."""
        raise NotImplementedError()  # pragma: no cover

    @property
    def identity(self) -> str:
        """Identity string of the user."""
        raise NotImplementedError()  # pragma: no cover


class SimpleUser(BaseUser):
    """An authenticated user identified only by a username."""

    def __init__(self, username: str) -> None:
        """Store *username*; it is also used as the display name."""
        self.username = username

    @property
    def is_authenticated(self) -> bool:
        """Always True."""
        return True

    @property
    def display_name(self) -> str:
        """Return the stored username."""
        return self.username


class UnauthenticatedUser(BaseUser):
    """A user that is not authenticated and has an empty display name."""

    @property
    def is_authenticated(self) -> bool:
        """Always False."""
        return False

    @property
    def display_name(self) -> str:
        """Always the empty string."""
        return ""