import typing
from os import PathLike

from starlette.background import BackgroundTask
from starlette.responses import Response
from starlette.types import Receive, Scope, Send

try:
    import jinja2

    # @contextfunction was renamed to @pass_context in Jinja 3.0, and was removed in 3.1
    # hence we try to get pass_context (most installs will be >=3.1)
    # and fall back to contextfunction,
    # adding a type ignore for mypy to let us access an attribute that may not exist
    if hasattr(jinja2, "pass_context"):
        pass_context = jinja2.pass_context
    else:  # pragma: nocover
        pass_context = jinja2.contextfunction  # type: ignore[attr-defined]
except ImportError:  # pragma: nocover
    jinja2 = None  # type: ignore


class _TemplateResponse(Response):
    """
    Response whose body is the result of rendering ``template`` with ``context``.

    The template and context are kept on the instance so that ``__call__`` can
    forward them in an ``http.response.template`` message when the request
    lists that extension.
    """

    media_type = "text/html"

    def __init__(
        self,
        template: typing.Any,
        context: dict,
        status_code: int = 200,
        headers: typing.Optional[typing.Mapping[str, str]] = None,
        media_type: typing.Optional[str] = None,
        background: typing.Optional[BackgroundTask] = None,
    ):
        """
        Render ``template`` immediately and initialise the underlying response.

        ``template`` must expose ``render(context)``; its return value becomes
        the response content. The remaining arguments are passed through to
        ``Response.__init__`` unchanged.
        """
        self.template = template
        self.context = context
        # Rendering happens here, at construction time, not when the response
        # is sent.
        content = template.render(context)
        super().__init__(content, status_code, headers, media_type, background)

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """
        Send the response, preceded by template details when supported.

        If the ``"request"`` entry of the context has an ``"extensions"``
        mapping containing ``"http.response.template"``, a message carrying
        the template and context is sent before the regular response.
        """
        request = self.context.get("request", {})
        extensions = request.get("extensions", {})
        if "http.response.template" in extensions:
            await send(
                {
                    "type": "http.response.template",
                    "template": self.template,
                    "context": self.context,
                }
            )
        await super().__call__(scope, receive, send)


class Jinja2Templates:
    """
    Thin wrapper around a ``jinja2.Environment`` that builds template responses.

    Usage:

    templates = Jinja2Templates("templates")

    return templates.TemplateResponse("index.html", {"request": request})
    """

    def __init__(
        self, directory: typing.Union[str, PathLike], **env_options: typing.Any
    ) -> None:
        """
        Create the Jinja2 environment for templates found in ``directory``.

        ``env_options`` are forwarded to ``jinja2.Environment``.

        Raises ``AssertionError`` if jinja2 could not be imported.
        """
        # Raised explicitly rather than via ``assert`` so the check is not
        # stripped under ``python -O``; the exception type and message are
        # the same as before.
        if jinja2 is None:  # pragma: nocover
            raise AssertionError("jinja2 must be installed to use Jinja2Templates")
        self.env = self._create_env(directory, **env_options)

    def _create_env(
        self, directory: typing.Union[str, PathLike], **env_options: typing.Any
    ) -> "jinja2.Environment":
        """
        Build the environment and register the ``url_for`` template global.

        ``loader`` defaults to a ``FileSystemLoader`` on ``directory`` and
        ``autoescape`` defaults to ``True``; values supplied by the caller in
        ``env_options`` take precedence over both defaults.
        """

        @pass_context
        def url_for(context: dict, name: str, **path_params: typing.Any) -> str:
            # Delegates to the request stored in the template context.
            request = context["request"]
            return request.url_for(name, **path_params)

        loader = jinja2.FileSystemLoader(directory)
        env_options.setdefault("loader", loader)
        env_options.setdefault("autoescape", True)

        env = jinja2.Environment(**env_options)
        env.globals["url_for"] = url_for
        return env

    def get_template(self, name: str) -> "jinja2.Template":
        """Return the template called ``name`` from the environment."""
        return self.env.get_template(name)

    def TemplateResponse(
        self,
        name: str,
        context: dict,
        status_code: int = 200,
        headers: typing.Optional[typing.Mapping[str, str]] = None,
        media_type: typing.Optional[str] = None,
        background: typing.Optional[BackgroundTask] = None,
    ) -> _TemplateResponse:
        """
        Load template ``name`` and return it rendered as a response.

        Raises ``ValueError`` if ``context`` has no ``"request"`` key. The
        remaining arguments are forwarded to ``_TemplateResponse``.
        """
        if "request" not in context:
            raise ValueError('context must include a "request" key')
        template = self.get_template(name)
        return _TemplateResponse(
            template,
            context,
            status_code=status_code,
            headers=headers,
            media_type=media_type,
            background=background,
        )