171 lines
4.6 KiB
171 lines
4.6 KiB
2 years ago
|
import enum
|
||
|
import sys
|
||
|
from dataclasses import dataclass
|
||
|
from typing import Any, Dict, Generic, Set, TypeVar, Union, overload
|
||
|
from weakref import WeakKeyDictionary
|
||
|
|
||
|
from ._core._eventloop import get_asynclib
|
||
|
|
||
|
# ``typing.Literal`` only exists on Python 3.8+; older interpreters get it from the
# ``typing_extensions`` backport instead.
if sys.version_info >= (3, 8):
    from typing import Literal
else:
    from typing_extensions import Literal
|
||
|
|
||
|
# T: type of the value held by a run variable.
T = TypeVar("T")
# D: type of the fallback value passed to ``RunVar.get()``.
D = TypeVar("D")
|
||
|
|
||
|
|
||
|
async def checkpoint() -> None:
    """
    Check for cancellation and allow the scheduler to switch to another task.

    Equivalent to (but more efficient than)::

        await checkpoint_if_cancelled()
        await cancel_shielded_checkpoint()

    .. versionadded:: 3.0

    """
    # Delegates to the backend implementation selected by get_asynclib().
    await get_asynclib().checkpoint()
|
||
|
|
||
|
|
||
|
async def checkpoint_if_cancelled() -> None:
    """
    Enter a checkpoint if the enclosing cancel scope has been cancelled.

    This does not allow the scheduler to switch to a different task.

    .. versionadded:: 3.0

    """
    # Delegates to the backend implementation selected by get_asynclib().
    await get_asynclib().checkpoint_if_cancelled()
|
||
|
|
||
|
|
||
|
async def cancel_shielded_checkpoint() -> None:
    """
    Allow the scheduler to switch to another task but without checking for cancellation.

    Equivalent to (but potentially more efficient than)::

        with CancelScope(shield=True):
            await checkpoint()

    .. versionadded:: 3.0

    """
    # Delegates to the backend implementation selected by get_asynclib().
    await get_asynclib().cancel_shielded_checkpoint()
|
||
|
|
||
|
|
||
|
def current_token() -> object:
    """
    Return a backend specific token object that can be used to get back to the event
    loop.

    """
    return get_asynclib().current_token()
|
||
|
|
||
|
|
||
|
# Per-event-loop storage for run variables: maps an event loop token (or a weak
# referable wrapper around it) to that loop's ``{variable name: value}`` dict.
# The annotation is a string because WeakKeyDictionary may not be subscriptable
# at runtime on the older interpreters this module still supports.
_run_vars: "WeakKeyDictionary[Any, Dict[str, Any]]" = WeakKeyDictionary()
# NOTE(review): nothing visible in this module reads or writes this dict (RunVar
# keeps its own class-level set of wrappers) -- confirm before removing.
_token_wrappers: Dict[Any, "_TokenWrapper"] = {}
|
||
|
|
||
|
|
||
|
@dataclass(frozen=True)
class _TokenWrapper:
    """
    Hashable, weak referable stand-in for an event loop token that cannot itself be
    weakly referenced.

    Being a frozen dataclass, two wrappers around equal tokens compare and hash equal,
    so a freshly created wrapper finds the entry stored under an earlier one.
    """

    # "__weakref__" must be listed explicitly: a slotted class is otherwise not weak
    # referable and could not be used as a WeakKeyDictionary key.
    __slots__ = "_token", "__weakref__"
    _token: object
|
||
|
|
||
|
|
||
|
class _NoValueSet(enum.Enum):
    """Single-member enum whose member is the "no value" sentinel for run variables."""

    # An enum member (rather than a bare ``object()``) can be named in ``Literal[...]``.
    NO_VALUE_SET = enum.auto()
|
||
|
|
||
|
|
||
|
class RunvarToken(Generic[T]):
    """
    Records the state a :class:`RunVar` had before a ``set()`` call.

    Pass it to :meth:`RunVar.reset` to restore that state; a token can be used only
    once.
    """

    __slots__ = "_var", "_value", "_redeemed"

    def __init__(
        self, var: "RunVar[T]", value: Union[T, Literal[_NoValueSet.NO_VALUE_SET]]
    ) -> None:
        """
        :param var: the run variable this token belongs to
        :param value: the variable's previous value, or the ``NO_VALUE_SET`` sentinel
            if it had none
        """
        self._var = var
        self._value: Union[T, Literal[_NoValueSet.NO_VALUE_SET]] = value
        # Flipped to True by RunVar.reset() so the token cannot be reused.
        self._redeemed = False
|
||
|
|
||
|
|
||
|
class RunVar(Generic[T]):
    """
    Like a :class:`~contextvars.ContextVar`, except scoped to the running event loop.

    Values are stored per event loop token (see :func:`current_token`) under the
    variable's name.
    """

    __slots__ = "_name", "_default"

    NO_VALUE_SET: Literal[_NoValueSet.NO_VALUE_SET] = _NoValueSet.NO_VALUE_SET

    # Strong references to the wrappers created for tokens that are not weak
    # referable; shared by all RunVar instances.
    _token_wrappers: Set[_TokenWrapper] = set()

    def __init__(
        self,
        name: str,
        default: Union[T, Literal[_NoValueSet.NO_VALUE_SET]] = NO_VALUE_SET,
    ):
        """
        :param name: key under which the value is stored for each event loop
        :param default: value returned by :meth:`get` when nothing has been set and
            no per-call default is given
        """
        self._name = name
        self._default = default

    @property
    def _current_vars(self) -> Dict[str, T]:
        """Return the variable dict of the current event loop, creating it if needed."""
        token = current_token()
        while True:
            try:
                return _run_vars[token]
            except TypeError:
                # Happens when token isn't weak referable (TrioToken).
                # This workaround does mean that some memory will leak on Trio until
                # the problem is fixed on their end.
                token = _TokenWrapper(token)
                self._token_wrappers.add(token)
            except KeyError:
                # First access for this event loop: start with an empty dict.
                run_vars = _run_vars[token] = {}
                return run_vars

    @overload
    def get(self, default: D) -> Union[T, D]:
        ...

    @overload
    def get(self) -> T:
        ...

    def get(
        self, default: Union[D, Literal[_NoValueSet.NO_VALUE_SET]] = NO_VALUE_SET
    ) -> Union[T, D]:
        """
        Return the variable's value for the current event loop.

        If no value has been set, fall back to ``default`` if given, then to the
        default passed to the constructor.

        :raises LookupError: if there is no value and neither default is available
        """
        try:
            return self._current_vars[self._name]
        except KeyError:
            if default is not RunVar.NO_VALUE_SET:
                return default
            elif self._default is not RunVar.NO_VALUE_SET:
                return self._default

        raise LookupError(
            f'Run variable "{self._name}" has no value and no default set'
        )

    def set(self, value: T) -> RunvarToken[T]:
        """
        Set the variable's value for the current event loop.

        :return: a token that can be passed to :meth:`reset` to restore the previous
            state
        """
        current_vars = self._current_vars
        token = RunvarToken(self, current_vars.get(self._name, RunVar.NO_VALUE_SET))
        current_vars[self._name] = value
        return token

    def reset(self, token: RunvarToken[T]) -> None:
        """
        Restore the state the variable had before the ``set()`` call that produced
        ``token``.

        :raises ValueError: if the token belongs to another variable or has already
            been used
        """
        if token._var is not self:
            raise ValueError("This token does not belong to this RunVar")

        if token._redeemed:
            raise ValueError("This token has already been used")

        if token._value is _NoValueSet.NO_VALUE_SET:
            # The variable had no value before; removing an already missing entry is
            # not an error.
            self._current_vars.pop(self._name, None)
        else:
            self._current_vars[self._name] = token._value

        token._redeemed = True

    def __repr__(self) -> str:
        return f"<RunVar name={self._name!r}>"
|