You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
44 lines
1.2 KiB
44 lines
1.2 KiB
2 years ago
|
import asyncio
|
||
|
import sys
|
||
|
import typing
|
||
|
|
||
|
if sys.version_info >= (3, 10): # pragma: no cover
|
||
|
from typing import ParamSpec
|
||
|
else: # pragma: no cover
|
||
|
from typing_extensions import ParamSpec
|
||
|
|
||
|
from starlette.concurrency import run_in_threadpool
|
||
|
|
||
|
# Parameter-spec variable shared by BackgroundTask.__init__ and add_task: it ties the
# *args / **kwargs they accept (P.args / P.kwargs) to the signature of `func`, so a
# static type checker can flag arguments that do not match the task callable.
P = ParamSpec("P")
|
||
|
|
||
|
|
||
|
class BackgroundTask:
    """A single deferred call: ``func(*args, **kwargs)``, executed when awaited.

    The callable and its arguments are captured at construction time. Awaiting
    ``task()`` runs the callable. If it is asynchronous, it is awaited on the
    event loop. Otherwise it is handed to ``run_in_threadpool``, so a blocking
    function does not stall the loop.
    """

    def __init__(
        self, func: typing.Callable[P, typing.Any], *args: P.args, **kwargs: P.kwargs
    ) -> None:
        self.func = func
        self.args = args
        self.kwargs = kwargs
        # Decide the dispatch strategy once, up front, instead of on every call.
        self.is_async = self._is_async_callable(func)

    @staticmethod
    def _is_async_callable(obj: typing.Any) -> bool:
        """Return True if calling *obj* yields a coroutine that must be awaited.

        ``asyncio.iscoroutinefunction`` alone returns False for an instance
        whose class defines ``async def __call__``, including when that
        instance is wrapped in ``functools.partial``. Such a callable would be
        sent to the thread pool, which calls it and discards the un-awaited
        coroutine, so the task body would silently never run.
        """
        import functools  # local import keeps this block self-contained

        # Peel off (possibly nested) partials to reach the underlying callable.
        while isinstance(obj, functools.partial):
            obj = obj.func
        if asyncio.iscoroutinefunction(obj):
            return True
        # Look up __call__ on the *type*. For a callable instance this finds the
        # class's (possibly async) __call__. For a class object it finds
        # type.__call__, so passing a class itself is never mistaken for async.
        return asyncio.iscoroutinefunction(getattr(type(obj), "__call__", None))

    async def __call__(self) -> None:
        """Run the wrapped callable with the arguments captured at construction."""
        if self.is_async:
            await self.func(*self.args, **self.kwargs)
        else:
            # Synchronous callables run in a worker thread so they don't block the loop.
            await run_in_threadpool(self.func, *self.args, **self.kwargs)
|
||
|
|
||
|
|
||
|
class BackgroundTasks(BackgroundTask):
    """An ordered queue of ``BackgroundTask`` objects executed one after another.

    This class subclasses ``BackgroundTask`` but deliberately does not call the
    parent initializer. It keeps a list of tasks instead of a single callable;
    presumably the subclassing lets it be passed wherever one task is expected
    (confirm against callers). Awaiting the instance awaits each queued task in
    insertion order. Because nothing catches exceptions, a task that raises
    stops the remaining ones.
    """

    def __init__(self, tasks: typing.Optional[typing.Sequence[BackgroundTask]] = None):
        # Copy into a fresh list so later add_task() appends never touch the
        # caller's sequence. A missing or empty input starts an empty queue.
        self.tasks = [] if not tasks else list(tasks)

    def add_task(
        self, func: typing.Callable[P, typing.Any], *args: P.args, **kwargs: P.kwargs
    ) -> None:
        """Wrap ``func(*args, **kwargs)`` in a BackgroundTask and queue it last."""
        self.tasks.append(BackgroundTask(func, *args, **kwargs))

    async def __call__(self) -> None:
        """Await every queued task sequentially, in the order it was added."""
        for queued in self.tasks:
            await queued()
|