You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
ORPA-pyOpenRPA/Resources/WPy64-3720/python-3.7.2.amd64/Lib/site-packages/starlette/authentication.py

152 lines
5.0 KiB

import asyncio
import functools
import inspect
import typing
from urllib.parse import urlencode
from starlette.exceptions import HTTPException
from starlette.requests import HTTPConnection, Request
from starlette.responses import RedirectResponse, Response
from starlette.websockets import WebSocket
def has_required_scope(conn: HTTPConnection, scopes: typing.Sequence[str]) -> bool:
    """Report whether ``conn`` carries every scope listed in ``scopes``.

    Membership is checked against ``conn.auth.scopes``. Checking stops at the
    first missing scope. An empty ``scopes`` sequence is always satisfied and
    does not touch ``conn.auth``.
    """
    return all(required in conn.auth.scopes for required in scopes)
def _unauthorized_response(
    request: "Request", status_code: int, redirect: typing.Optional[str]
) -> "Response":
    """Build the reply for an HTTP request that lacks the required scopes.

    If ``redirect`` is not None, it is resolved with ``request.url_for`` and a
    303 redirect to it is returned. The original request URL is appended as a
    ``next`` query parameter. Otherwise an ``HTTPException`` with
    ``status_code`` is raised.
    """
    if redirect is not None:
        orig_request_qparam = urlencode({"next": str(request.url)})
        next_url = "{redirect_path}?{orig_request}".format(
            redirect_path=request.url_for(redirect),
            orig_request=orig_request_qparam,
        )
        return RedirectResponse(url=next_url, status_code=303)
    raise HTTPException(status_code=status_code)


def requires(
    scopes: typing.Union[str, typing.Sequence[str]],
    status_code: int = 403,
    redirect: typing.Optional[str] = None,
) -> typing.Callable:
    """Return a decorator that only runs an endpoint when all ``scopes`` are held.

    The decorated function must have a parameter named ``request`` or
    ``websocket``. Its position in the signature is used to find the
    connection object when it is passed positionally rather than by keyword.

    Args:
        scopes: A single scope name or a sequence of scope names. Every one
            must be present in the connection's ``auth.scopes``.
        status_code: Status code of the ``HTTPException`` raised for an HTTP
            request missing a scope. Not used when ``redirect`` is set.
        redirect: Optional name passed to ``request.url_for``. When set, an
            HTTP request missing a scope gets a 303 redirect there instead of
            an exception.

    Returns:
        A decorator. For websocket endpoints, a connection missing a scope is
        closed and the endpoint is not called.

    Raises:
        Exception: At decoration time, if the function has no ``request`` or
            ``websocket`` parameter.
    """
    # A bare string is one scope name, not a sequence of characters.
    scopes_list = [scopes] if isinstance(scopes, str) else list(scopes)

    def decorator(func: typing.Callable) -> typing.Callable:
        sig = inspect.signature(func)
        for idx, parameter in enumerate(sig.parameters.values()):
            if parameter.name in ("request", "websocket"):
                type_ = parameter.name
                break
        else:
            raise Exception(
                f'No "request" or "websocket" argument on function "{func}"'
            )

        if type_ == "websocket":
            # Handle websocket functions. (Always async)
            @functools.wraps(func)
            async def websocket_wrapper(
                *args: typing.Any, **kwargs: typing.Any
            ) -> None:
                # Prefer the keyword argument; otherwise use the positional
                # slot located above.
                websocket = kwargs.get(
                    "websocket", args[idx] if idx < len(args) else None
                )
                assert isinstance(websocket, WebSocket)
                if not has_required_scope(websocket, scopes_list):
                    await websocket.close()
                else:
                    await func(*args, **kwargs)

            return websocket_wrapper

        elif asyncio.iscoroutinefunction(func):
            # Handle async request/response functions.
            @functools.wraps(func)
            async def async_wrapper(
                *args: typing.Any, **kwargs: typing.Any
            ) -> Response:
                request = kwargs.get("request", args[idx] if idx < len(args) else None)
                assert isinstance(request, Request)
                if not has_required_scope(request, scopes_list):
                    return _unauthorized_response(request, status_code, redirect)
                return await func(*args, **kwargs)

            return async_wrapper

        else:
            # Handle sync request/response functions.
            @functools.wraps(func)
            def sync_wrapper(*args: typing.Any, **kwargs: typing.Any) -> Response:
                request = kwargs.get("request", args[idx] if idx < len(args) else None)
                assert isinstance(request, Request)
                if not has_required_scope(request, scopes_list):
                    return _unauthorized_response(request, status_code, redirect)
                return func(*args, **kwargs)

            return sync_wrapper

    return decorator
class AuthenticationError(Exception):
    """Exception type for authentication problems.

    Nothing in this module raises it. Presumably it is raised by
    ``AuthenticationBackend`` implementations; confirm against callers.
    """
class AuthenticationBackend:
    """Base class for authentication backends; subclasses supply ``authenticate``."""

    async def authenticate(
        self, conn: HTTPConnection
    ) -> typing.Optional[typing.Tuple["AuthCredentials", "BaseUser"]]:
        """Authenticate the connection ``conn``.

        Per the annotation, an implementation returns either ``None`` or an
        ``(AuthCredentials, BaseUser)`` pair.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError  # pragma: no cover
class AuthCredentials:
    """Hold the list of scope names attached to a set of credentials.

    Args:
        scopes: ``None`` (no scopes), a single scope name, or a sequence of
            scope names. The input is always copied into a new list.
    """

    def __init__(self, scopes: typing.Optional[typing.Sequence[str]] = None):
        if scopes is None:
            self.scopes = []
        elif isinstance(scopes, str):
            # A ``str`` satisfies ``Sequence[str]``, but list("admin") would
            # split it into characters. Treat it as one scope, as ``requires``
            # does.
            self.scopes = [scopes]
        else:
            self.scopes = list(scopes)
class BaseUser:
    """Base user interface. Each property raises until a subclass overrides it."""

    @property
    def is_authenticated(self) -> bool:
        """Whether this user is authenticated; must be overridden."""
        raise NotImplementedError  # pragma: no cover

    @property
    def display_name(self) -> str:
        """Display name for this user; must be overridden."""
        raise NotImplementedError  # pragma: no cover

    @property
    def identity(self) -> str:
        """Identity string for this user; must be overridden."""
        raise NotImplementedError  # pragma: no cover
class SimpleUser(BaseUser):
    """A user that always reports as authenticated and is named by ``username``.

    ``identity`` is not overridden, so reading it still raises
    ``NotImplementedError`` from ``BaseUser``.
    """

    def __init__(self, username: str) -> None:
        self.username = username

    @property
    def is_authenticated(self) -> bool:
        """Always ``True``."""
        return True

    @property
    def display_name(self) -> str:
        """The ``username`` given at construction."""
        return self.username
class UnauthenticatedUser(BaseUser):
    """A user whose ``is_authenticated`` is ``False`` and whose display name is empty.

    ``identity`` is not overridden, so reading it still raises
    ``NotImplementedError`` from ``BaseUser``.
    """

    @property
    def is_authenticated(self) -> bool:
        """Always ``False``."""
        return False

    @property
    def display_name(self) -> str:
        """Always the empty string."""
        return ""