import copy
import inspect
import typing

from starlette.requests import Request
from starlette.responses import Response
from starlette.routing import BaseRoute, Mount, Route

try:
    import yaml
except ImportError:  # pragma: nocover
    yaml = None  # type: ignore


class OpenAPIResponse(Response):
    """Response that serialises a schema dictionary to YAML."""

    media_type = "application/vnd.oai.openapi"

    def render(self, content: typing.Any) -> bytes:
        """
        Dump `content` as block-style YAML and return it UTF-8 encoded.

        Raises `AssertionError` if `pyyaml` is not installed or if `content`
        is not a dictionary.
        """
        assert yaml is not None, "`pyyaml` must be installed to use OpenAPIResponse."
        assert isinstance(
            content, dict
        ), "The schema passed to OpenAPIResponse should be a dictionary."
        return yaml.dump(content, default_flow_style=False).encode("utf-8")


class EndpointInfo(typing.NamedTuple):
    """One (path, HTTP method, callable) combination found in the routes."""

    # Route path, prefixed with the path of any enclosing mount.
    path: str
    # Lower-case HTTP method name, eg. "get".
    http_method: str
    # Callable whose docstring is parsed for the schema.
    func: typing.Callable


class BaseSchemaGenerator:
    """
    Base class for schema generators.

    Subclasses implement `get_schema`; this class provides route discovery,
    docstring parsing and a helper that builds the schema response.
    """

    def get_schema(self, routes: typing.List[BaseRoute]) -> dict:
        """Return the schema for `routes`. Must be overridden by subclasses."""
        raise NotImplementedError()  # pragma: no cover

    def get_endpoints(
        self, routes: typing.List[BaseRoute]
    ) -> typing.List[EndpointInfo]:
        """
        Given the routes, returns a list with the following information:

        - path
            eg: /users/
        - http_method
            one of 'get', 'post', 'put', 'patch', 'delete', 'options'
        - func
            method ready to extract the docstring

        Mounts are walked recursively and their path is prepended to the path
        of every endpoint found beneath them. Routes that are not `Route`
        instances, or that have `include_in_schema` disabled, are skipped.
        """
        endpoints_info: typing.List[EndpointInfo] = []

        for route in routes:
            if isinstance(route, Mount):
                # Use a separate name so the `routes` argument being iterated
                # is not rebound part-way through the loop.
                sub_routes = route.routes or []
                endpoints_info.extend(
                    EndpointInfo(
                        path="".join((route.path, sub_endpoint.path)),
                        http_method=sub_endpoint.http_method,
                        func=sub_endpoint.func,
                    )
                    for sub_endpoint in self.get_endpoints(sub_routes)
                )

            elif not isinstance(route, Route) or not route.include_in_schema:
                continue

            elif inspect.isfunction(route.endpoint) or inspect.ismethod(route.endpoint):
                # Function endpoint: one entry per declared method, except HEAD.
                for method in route.methods or ["GET"]:
                    if method == "HEAD":
                        continue
                    endpoints_info.append(
                        EndpointInfo(route.path, method.lower(), route.endpoint)
                    )

            else:
                # Any other endpoint object: one entry per HTTP-method-named
                # attribute that it actually defines.
                for method in ["get", "post", "put", "patch", "delete", "options"]:
                    if not hasattr(route.endpoint, method):
                        continue
                    func = getattr(route.endpoint, method)
                    endpoints_info.append(
                        EndpointInfo(route.path, method.lower(), func)
                    )

        return endpoints_info

    def parse_docstring(self, func_or_method: typing.Callable) -> dict:
        """
        Given a function, parse the docstring as YAML and return a dictionary of info.

        Returns an empty dictionary when there is no docstring or when the
        parsed YAML is not a mapping. Raises `AssertionError` if a docstring
        is present but `pyyaml` is not installed.
        """
        docstring = func_or_method.__doc__
        if not docstring:
            return {}

        assert yaml is not None, "`pyyaml` must be installed to use parse_docstring."

        # We support having regular docstrings before the schema
        # definition. Here we return just the schema part from
        # the docstring.
        docstring = docstring.split("---")[-1]

        parsed = yaml.safe_load(docstring)

        if not isinstance(parsed, dict):
            # A regular docstring (not yaml formatted) can return
            # a simple string here, which wouldn't follow the schema.
            return {}

        return parsed

    def OpenAPIResponse(self, request: Request) -> Response:
        """Build the schema from `request.app.routes` and wrap it in a response."""
        routes = request.app.routes
        schema = self.get_schema(routes=routes)
        # Inside the method body this name resolves to the module-level
        # response class, not to this method.
        return OpenAPIResponse(schema)


class SchemaGenerator(BaseSchemaGenerator):
    """Schema generator that merges endpoint docstrings into a base schema."""

    def __init__(self, base_schema: dict) -> None:
        """Store `base_schema` as the template for every generated schema."""
        self.base_schema = base_schema

    def get_schema(self, routes: typing.List[BaseRoute]) -> dict:
        """
        Return a new schema: the base schema plus a "paths" entry for every
        endpoint whose docstring parses to a non-empty dictionary.

        The base schema is deep-copied first, so nested dictionaries such as
        a pre-existing "paths" mapping are never modified by this call.
        """
        # A shallow copy would share the nested "paths" dictionaries with
        # `self.base_schema`, and the assignments below would then write
        # endpoint entries into the caller's own data.
        schema = copy.deepcopy(dict(self.base_schema))
        paths = schema.setdefault("paths", {})

        for endpoint in self.get_endpoints(routes):
            parsed = self.parse_docstring(endpoint.func)

            if not parsed:
                continue

            if endpoint.path not in paths:
                paths[endpoint.path] = {}

            paths[endpoint.path][endpoint.http_method] = parsed

        return schema