You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
33 lines
1.1 KiB
33 lines
1.1 KiB
2 years ago
|
import sys
|
||
|
from typing import AsyncGenerator, ContextManager, TypeVar
|
||
|
|
||
|
from starlette.concurrency import iterate_in_threadpool as iterate_in_threadpool # noqa
|
||
|
from starlette.concurrency import run_in_threadpool as run_in_threadpool # noqa
|
||
|
from starlette.concurrency import ( # noqa
|
||
|
run_until_first_complete as run_until_first_complete,
|
||
|
)
|
||
|
|
||
|
if sys.version_info >= (3, 7):
|
||
|
from contextlib import AsyncExitStack as AsyncExitStack
|
||
|
from contextlib import asynccontextmanager as asynccontextmanager
|
||
|
else:
|
||
|
from contextlib2 import AsyncExitStack as AsyncExitStack # noqa
|
||
|
from contextlib2 import asynccontextmanager as asynccontextmanager # noqa
|
||
|
|
||
|
|
||
|
# Type variable tying the value produced by a wrapped context manager
# (``ContextManager[_T]``) to the value yielded by
# ``contextmanager_in_threadpool`` (``AsyncGenerator[_T, None]``).
_T = TypeVar("_T")
|
||
|
|
||
|
|
||
|
@asynccontextmanager
async def contextmanager_in_threadpool(
    cm: ContextManager[_T],
) -> AsyncGenerator[_T, None]:
    """Adapt a synchronous context manager for use with ``async with``.

    Both ``cm.__enter__`` and ``cm.__exit__`` are dispatched through
    ``run_in_threadpool`` instead of being called directly in the coroutine.

    Args:
        cm: The synchronous context manager to wrap.

    Yields:
        The value returned by ``cm.__enter__()``.

    Raises:
        Exception: Whatever ``cm.__enter__`` raises (``cm.__exit__`` is not
            called in that case), or any exception raised inside the
            ``async with`` body that ``cm.__exit__`` does not suppress.

    Note:
        Only ``Exception`` subclasses are routed to ``cm.__exit__``; other
        ``BaseException`` types propagate without ``cm.__exit__`` being
        invoked.
    """
    # Enter *outside* the ``try``: mirroring the ``with`` statement, a context
    # manager whose ``__enter__`` failed must not have ``__exit__`` called.
    entered = await run_in_threadpool(cm.__enter__)
    try:
        yield entered
    except Exception as e:
        # Hand over the real traceback so the context manager can inspect it
        # and so managers that re-attach it to the exception do not erase it.
        suppressed = bool(
            await run_in_threadpool(cm.__exit__, type(e), e, e.__traceback__)
        )
        if not suppressed:
            # Bare ``raise`` re-raises the same exception object without
            # adding an extra frame to its traceback.
            raise
    else:
        await run_in_threadpool(cm.__exit__, None, None, None)
|