44 lines
1.2 KiB

import asyncio
import sys
import typing
if sys.version_info >= (3, 10): # pragma: no cover
from typing import ParamSpec
else: # pragma: no cover
from typing_extensions import ParamSpec
from starlette.concurrency import run_in_threadpool
P = ParamSpec("P")
class BackgroundTask:
def __init__(
self, func: typing.Callable[P, typing.Any], *args: P.args, **kwargs: P.kwargs
) -> None:
self.func = func
self.args = args
self.kwargs = kwargs
self.is_async = asyncio.iscoroutinefunction(func)
async def __call__(self) -> None:
if self.is_async:
await self.func(*self.args, **self.kwargs)
else:
await run_in_threadpool(self.func, *self.args, **self.kwargs)
class BackgroundTasks(BackgroundTask):
def __init__(self, tasks: typing.Optional[typing.Sequence[BackgroundTask]] = None):
self.tasks = list(tasks) if tasks else []
def add_task(
self, func: typing.Callable[P, typing.Any], *args: P.args, **kwargs: P.kwargs
) -> None:
task = BackgroundTask(func, *args, **kwargs)
self.tasks.append(task)
async def __call__(self) -> None:
for task in self.tasks:
await task()