parent
0258e89565
commit
2b06170730
@ -1,4 +1,5 @@
|
||||
cd %~dp0
|
||||
chcp 65001
|
||||
cd /d "%~dp0"
|
||||
taskkill /im "orpa-agent.exe" /F /fi "username eq %username%"
|
||||
copy /Y ..\Resources\WPy64-3720\python-3.7.2.amd64\pythonw.exe ..\Resources\WPy64-3720\python-3.7.2.amd64\orpa-agent.exe
|
||||
.\..\Resources\WPy64-3720\python-3.7.2.amd64\orpa-agent.exe "config.py"
|
||||
|
@ -1,4 +1,5 @@
|
||||
cd %~dp0
|
||||
chcp 65001
|
||||
cd /d "%~dp0"
|
||||
taskkill /im "orpa-agent.exe" /F /fi "username eq %username%"
|
||||
copy /Y ..\Resources\WPy64-3720\python-3.7.2.amd64\python.exe ..\Resources\WPy64-3720\python-3.7.2.amd64\orpa-agent.exe
|
||||
.\..\Resources\WPy64-3720\python-3.7.2.amd64\orpa-agent.exe "config.py"
|
||||
|
@ -1,4 +1,5 @@
|
||||
cd %~dp0
|
||||
chcp 65001
|
||||
cd /d "%~dp0"
|
||||
taskkill /im "orpa-orc.exe" /F /fi "username eq %username%"
|
||||
copy /Y ..\Resources\WPy64-3720\python-3.7.2.amd64\python.exe ..\Resources\WPy64-3720\python-3.7.2.amd64\orpa-orc.exe
|
||||
.\..\Resources\WPy64-3720\python-3.7.2.amd64\orpa-orc.exe "config.py"
|
||||
|
@ -0,0 +1 @@
|
||||
pip
|
@ -0,0 +1,20 @@
|
||||
The MIT License (MIT)
|
||||
|
||||
Copyright (c) 2018 Alex Grönholm
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy of
|
||||
this software and associated documentation files (the "Software"), to deal in
|
||||
the Software without restriction, including without limitation the rights to
|
||||
use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
|
||||
the Software, and to permit persons to whom the Software is furnished to do so,
|
||||
subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
|
||||
FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
|
||||
COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
|
||||
IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
|
||||
CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
@ -0,0 +1,82 @@
|
||||
anyio-3.6.1.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4
|
||||
anyio-3.6.1.dist-info/LICENSE,sha256=U2GsncWPLvX9LpsJxoKXwX8ElQkJu8gCO9uC6s8iwrA,1081
|
||||
anyio-3.6.1.dist-info/METADATA,sha256=cXu3CLppFqT_rBl4Eo2HOb0J1zm6Ltu_tMFuzjuQnew,4654
|
||||
anyio-3.6.1.dist-info/RECORD,,
|
||||
anyio-3.6.1.dist-info/WHEEL,sha256=G16H4A3IeoQmnOrYV4ueZGKSjhipXx8zc8nu9FGlvMA,92
|
||||
anyio-3.6.1.dist-info/entry_points.txt,sha256=_d6Yu6uiaZmNe0CydowirE9Cmg7zUL2g08tQpoS3Qvc,39
|
||||
anyio-3.6.1.dist-info/top_level.txt,sha256=QglSMiWX8_5dpoVAEIHdEYzvqFMdSYWmCj6tYw2ITkQ,6
|
||||
anyio/__init__.py,sha256=M2R8dk6L5gL5lXHArzpSfEn2oH5jMyUKhzyrkRiv2AM,4037
|
||||
anyio/__pycache__/__init__.cpython-37.pyc,,
|
||||
anyio/__pycache__/from_thread.cpython-37.pyc,,
|
||||
anyio/__pycache__/lowlevel.cpython-37.pyc,,
|
||||
anyio/__pycache__/pytest_plugin.cpython-37.pyc,,
|
||||
anyio/__pycache__/to_process.cpython-37.pyc,,
|
||||
anyio/__pycache__/to_thread.cpython-37.pyc,,
|
||||
anyio/_backends/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
||||
anyio/_backends/__pycache__/__init__.cpython-37.pyc,,
|
||||
anyio/_backends/__pycache__/_asyncio.cpython-37.pyc,,
|
||||
anyio/_backends/__pycache__/_trio.cpython-37.pyc,,
|
||||
anyio/_backends/_asyncio.py,sha256=ZJDvRwfS4wv9WWcqWledNJyl8hx8A8-m9-gSKAJ6nBM,69238
|
||||
anyio/_backends/_trio.py,sha256=CebCaqr8Szi6uCnUzwtBRLfUitR5OnDT_wfH-KiqvBQ,29696
|
||||
anyio/_core/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
||||
anyio/_core/__pycache__/__init__.cpython-37.pyc,,
|
||||
anyio/_core/__pycache__/_compat.cpython-37.pyc,,
|
||||
anyio/_core/__pycache__/_eventloop.cpython-37.pyc,,
|
||||
anyio/_core/__pycache__/_exceptions.cpython-37.pyc,,
|
||||
anyio/_core/__pycache__/_fileio.cpython-37.pyc,,
|
||||
anyio/_core/__pycache__/_resources.cpython-37.pyc,,
|
||||
anyio/_core/__pycache__/_signals.cpython-37.pyc,,
|
||||
anyio/_core/__pycache__/_sockets.cpython-37.pyc,,
|
||||
anyio/_core/__pycache__/_streams.cpython-37.pyc,,
|
||||
anyio/_core/__pycache__/_subprocesses.cpython-37.pyc,,
|
||||
anyio/_core/__pycache__/_synchronization.cpython-37.pyc,,
|
||||
anyio/_core/__pycache__/_tasks.cpython-37.pyc,,
|
||||
anyio/_core/__pycache__/_testing.cpython-37.pyc,,
|
||||
anyio/_core/__pycache__/_typedattr.cpython-37.pyc,,
|
||||
anyio/_core/_compat.py,sha256=X99W70r-O-JLdkKNtbddcIY5H2Nyg3Nk34oUYE9WZRs,5790
|
||||
anyio/_core/_eventloop.py,sha256=DRn_hy679LtsJFsPX7dXjDv72bLtSFkTnWY9WVVfgCQ,4108
|
||||
anyio/_core/_exceptions.py,sha256=1wqraNldZroYkoyB0HZStAruz_7yDCBaW-4zYwsKj8s,2904
|
||||
anyio/_core/_fileio.py,sha256=au82uZXZX4fia8EoZq_E-JDwZFKe6ZtI0J6IkxK8FmQ,18298
|
||||
anyio/_core/_resources.py,sha256=M_uN-90N8eSsWuvo-0xluWU_OG2BTyccAgsQ7XtHxzs,399
|
||||
anyio/_core/_signals.py,sha256=D4btJN527tAADspKBeNKaCds-ZcEZJP8LWM_MjVuQRA,827
|
||||
anyio/_core/_sockets.py,sha256=fW_Cbg6kfw4xgYuVuWbcWrAYspOcDSEjwxVATMzf2fo,19820
|
||||
anyio/_core/_streams.py,sha256=gjT5xChJ1OoV8nNinljSv1yW4nqUS-QzZzIydQz3exQ,1494
|
||||
anyio/_core/_subprocesses.py,sha256=pcchMI2OII0QSjiVxRiTEz4M0B7TlQPzGurfCuka-xc,5049
|
||||
anyio/_core/_synchronization.py,sha256=xOOG4hF9783N6E2IcD3YKiukguA5bPrj6BodDsKNaJY,16822
|
||||
anyio/_core/_tasks.py,sha256=ebGLjHvwL6I9aGyPwvCig1drebSVYFzvY3pnN3TsB4o,5273
|
||||
anyio/_core/_testing.py,sha256=VZka_yebIhJ6mJ6Vo_ilO3Nbz53ieqg0WBijwciMwdY,2196
|
||||
anyio/_core/_typedattr.py,sha256=k5-wBvMlDlKHIpn18INVnXAlGwI3CrAvPmWoceHjnOQ,2534
|
||||
anyio/abc/__init__.py,sha256=hMa47CMs5O1twC2bBcSbzwX-3Q08BAgAPTRekQobb3E,2123
|
||||
anyio/abc/__pycache__/__init__.cpython-37.pyc,,
|
||||
anyio/abc/__pycache__/_resources.cpython-37.pyc,,
|
||||
anyio/abc/__pycache__/_sockets.cpython-37.pyc,,
|
||||
anyio/abc/__pycache__/_streams.cpython-37.pyc,,
|
||||
anyio/abc/__pycache__/_subprocesses.cpython-37.pyc,,
|
||||
anyio/abc/__pycache__/_tasks.cpython-37.pyc,,
|
||||
anyio/abc/__pycache__/_testing.cpython-37.pyc,,
|
||||
anyio/abc/_resources.py,sha256=js737mWPG6IW0fH8W4Tz9eNWLztse7dKxEC61z934Vk,752
|
||||
anyio/abc/_sockets.py,sha256=i1VdcJTLAuRlYeZoL6s5RBSWbX62Cu6ln5YZBL2YrWk,5754
|
||||
anyio/abc/_streams.py,sha256=0g70fhKAzbnK0KKmWwRgwmKdApBwduAcVj4TpjSzjzU,6501
|
||||
anyio/abc/_subprocesses.py,sha256=iREP_YQ91it88lDU4XIcI3HZ9HUvV5UmjQk_sSPonrw,2071
|
||||
anyio/abc/_tasks.py,sha256=mQQd1DANqpySKyehVVPdMfi_UEG49zZUJpt5blunOjg,3119
|
||||
anyio/abc/_testing.py,sha256=ifKCUPzcQdHAEGO-weu2GQvzjMQPPIWO24mQ0z6zkdU,1928
|
||||
anyio/from_thread.py,sha256=nSq6mafYMqwxKmzdJyISg8cp-AyBj9rxZPMt_b7klSM,16497
|
||||
anyio/lowlevel.py,sha256=W4ydshns7f86YuSESFc2igTf46AWMXnGPQGsY_Esl2E,4679
|
||||
anyio/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
||||
anyio/pytest_plugin.py,sha256=kWj2B8BJehePJd1sztRBmJBRh8O4hk1oGSYQRlX5Gr8,5134
|
||||
anyio/streams/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
||||
anyio/streams/__pycache__/__init__.cpython-37.pyc,,
|
||||
anyio/streams/__pycache__/buffered.cpython-37.pyc,,
|
||||
anyio/streams/__pycache__/file.cpython-37.pyc,,
|
||||
anyio/streams/__pycache__/memory.cpython-37.pyc,,
|
||||
anyio/streams/__pycache__/stapled.cpython-37.pyc,,
|
||||
anyio/streams/__pycache__/text.cpython-37.pyc,,
|
||||
anyio/streams/__pycache__/tls.cpython-37.pyc,,
|
||||
anyio/streams/buffered.py,sha256=FegOSO4Xcxa5SaDfU1A3ZkTTxaPrv6G435Y_giZ8k44,4437
|
||||
anyio/streams/file.py,sha256=pujJ-m6BX-gOLnVoZwkE5kh-YDs5Vx9eJFVkvliQ0S4,4353
|
||||
anyio/streams/memory.py,sha256=3RGeZoevoGIgBWfD2_X1cqxIPOz-BqQkRf6lUcOnBYc,9209
|
||||
anyio/streams/stapled.py,sha256=0E0V15v8M5GVelpHe5RT0S33tQ9hGe4ZCXo_KJEjtt4,4258
|
||||
anyio/streams/text.py,sha256=WRFyjsRpBjQKdCmR4ZuzYTEAJqGx2s5oTJmGI1C6Ng0,5014
|
||||
anyio/streams/tls.py,sha256=-WXGsMV14XHXAxc38WpBvGusjuY7e449g4UCEHIlnWw,12040
|
||||
anyio/to_process.py,sha256=hu0ES3HJC-VEjcdPJMzAzjyTaekaCNToO3coj3jvnus,9247
|
||||
anyio/to_thread.py,sha256=VeMQoo8Va2zz0WFk2p123QikDpqk2wYZGw20COC3wqw,2124
|
@ -0,0 +1,5 @@
|
||||
Wheel-Version: 1.0
|
||||
Generator: bdist_wheel (0.37.1)
|
||||
Root-Is-Purelib: true
|
||||
Tag: py3-none-any
|
||||
|
@ -0,0 +1,2 @@
|
||||
[pytest11]
|
||||
anyio = anyio.pytest_plugin
|
@ -0,0 +1 @@
|
||||
anyio
|
@ -0,0 +1,167 @@
|
||||
__all__ = (
|
||||
"maybe_async",
|
||||
"maybe_async_cm",
|
||||
"run",
|
||||
"sleep",
|
||||
"sleep_forever",
|
||||
"sleep_until",
|
||||
"current_time",
|
||||
"get_all_backends",
|
||||
"get_cancelled_exc_class",
|
||||
"BrokenResourceError",
|
||||
"BrokenWorkerProcess",
|
||||
"BusyResourceError",
|
||||
"ClosedResourceError",
|
||||
"DelimiterNotFound",
|
||||
"EndOfStream",
|
||||
"ExceptionGroup",
|
||||
"IncompleteRead",
|
||||
"TypedAttributeLookupError",
|
||||
"WouldBlock",
|
||||
"AsyncFile",
|
||||
"Path",
|
||||
"open_file",
|
||||
"wrap_file",
|
||||
"aclose_forcefully",
|
||||
"open_signal_receiver",
|
||||
"connect_tcp",
|
||||
"connect_unix",
|
||||
"create_tcp_listener",
|
||||
"create_unix_listener",
|
||||
"create_udp_socket",
|
||||
"create_connected_udp_socket",
|
||||
"getaddrinfo",
|
||||
"getnameinfo",
|
||||
"wait_socket_readable",
|
||||
"wait_socket_writable",
|
||||
"create_memory_object_stream",
|
||||
"run_process",
|
||||
"open_process",
|
||||
"create_lock",
|
||||
"CapacityLimiter",
|
||||
"CapacityLimiterStatistics",
|
||||
"Condition",
|
||||
"ConditionStatistics",
|
||||
"Event",
|
||||
"EventStatistics",
|
||||
"Lock",
|
||||
"LockStatistics",
|
||||
"Semaphore",
|
||||
"SemaphoreStatistics",
|
||||
"create_condition",
|
||||
"create_event",
|
||||
"create_semaphore",
|
||||
"create_capacity_limiter",
|
||||
"open_cancel_scope",
|
||||
"fail_after",
|
||||
"move_on_after",
|
||||
"current_effective_deadline",
|
||||
"TASK_STATUS_IGNORED",
|
||||
"CancelScope",
|
||||
"create_task_group",
|
||||
"TaskInfo",
|
||||
"get_current_task",
|
||||
"get_running_tasks",
|
||||
"wait_all_tasks_blocked",
|
||||
"run_sync_in_worker_thread",
|
||||
"run_async_from_thread",
|
||||
"run_sync_from_thread",
|
||||
"current_default_worker_thread_limiter",
|
||||
"create_blocking_portal",
|
||||
"start_blocking_portal",
|
||||
"typed_attribute",
|
||||
"TypedAttributeSet",
|
||||
"TypedAttributeProvider",
|
||||
)
|
||||
|
||||
from typing import Any
|
||||
|
||||
from ._core._compat import maybe_async, maybe_async_cm
|
||||
from ._core._eventloop import (
|
||||
current_time,
|
||||
get_all_backends,
|
||||
get_cancelled_exc_class,
|
||||
run,
|
||||
sleep,
|
||||
sleep_forever,
|
||||
sleep_until,
|
||||
)
|
||||
from ._core._exceptions import (
|
||||
BrokenResourceError,
|
||||
BrokenWorkerProcess,
|
||||
BusyResourceError,
|
||||
ClosedResourceError,
|
||||
DelimiterNotFound,
|
||||
EndOfStream,
|
||||
ExceptionGroup,
|
||||
IncompleteRead,
|
||||
TypedAttributeLookupError,
|
||||
WouldBlock,
|
||||
)
|
||||
from ._core._fileio import AsyncFile, Path, open_file, wrap_file
|
||||
from ._core._resources import aclose_forcefully
|
||||
from ._core._signals import open_signal_receiver
|
||||
from ._core._sockets import (
|
||||
connect_tcp,
|
||||
connect_unix,
|
||||
create_connected_udp_socket,
|
||||
create_tcp_listener,
|
||||
create_udp_socket,
|
||||
create_unix_listener,
|
||||
getaddrinfo,
|
||||
getnameinfo,
|
||||
wait_socket_readable,
|
||||
wait_socket_writable,
|
||||
)
|
||||
from ._core._streams import create_memory_object_stream
|
||||
from ._core._subprocesses import open_process, run_process
|
||||
from ._core._synchronization import (
|
||||
CapacityLimiter,
|
||||
CapacityLimiterStatistics,
|
||||
Condition,
|
||||
ConditionStatistics,
|
||||
Event,
|
||||
EventStatistics,
|
||||
Lock,
|
||||
LockStatistics,
|
||||
Semaphore,
|
||||
SemaphoreStatistics,
|
||||
create_capacity_limiter,
|
||||
create_condition,
|
||||
create_event,
|
||||
create_lock,
|
||||
create_semaphore,
|
||||
)
|
||||
from ._core._tasks import (
|
||||
TASK_STATUS_IGNORED,
|
||||
CancelScope,
|
||||
create_task_group,
|
||||
current_effective_deadline,
|
||||
fail_after,
|
||||
move_on_after,
|
||||
open_cancel_scope,
|
||||
)
|
||||
from ._core._testing import (
|
||||
TaskInfo,
|
||||
get_current_task,
|
||||
get_running_tasks,
|
||||
wait_all_tasks_blocked,
|
||||
)
|
||||
from ._core._typedattr import TypedAttributeProvider, TypedAttributeSet, typed_attribute
|
||||
|
||||
# Re-exported here, for backwards compatibility
|
||||
# isort: off
|
||||
from .to_thread import current_default_worker_thread_limiter, run_sync_in_worker_thread
|
||||
from .from_thread import (
|
||||
create_blocking_portal,
|
||||
run_async_from_thread,
|
||||
run_sync_from_thread,
|
||||
start_blocking_portal,
|
||||
)
|
||||
|
||||
# Re-export imports so they look like they live directly in this package
|
||||
key: str
|
||||
value: Any
|
||||
for key, value in list(locals().items()):
|
||||
if getattr(value, "__module__", "").startswith("anyio."):
|
||||
value.__module__ = __name__
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,988 @@
|
||||
import array
|
||||
import math
|
||||
import socket
|
||||
from concurrent.futures import Future
|
||||
from contextvars import copy_context
|
||||
from dataclasses import dataclass
|
||||
from functools import partial
|
||||
from io import IOBase
|
||||
from os import PathLike
|
||||
from signal import Signals
|
||||
from types import TracebackType
|
||||
from typing import (
|
||||
IO,
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
AsyncGenerator,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Collection,
|
||||
ContextManager,
|
||||
Coroutine,
|
||||
Deque,
|
||||
Dict,
|
||||
Generic,
|
||||
Iterable,
|
||||
List,
|
||||
Mapping,
|
||||
NoReturn,
|
||||
Optional,
|
||||
Sequence,
|
||||
Set,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
import sniffio
|
||||
import trio.from_thread
|
||||
from outcome import Error, Outcome, Value
|
||||
from trio.socket import SocketType as TrioSocketType
|
||||
from trio.to_thread import run_sync
|
||||
|
||||
from .. import CapacityLimiterStatistics, EventStatistics, TaskInfo, abc
|
||||
from .._core._compat import DeprecatedAsyncContextManager, DeprecatedAwaitable, T
|
||||
from .._core._eventloop import claim_worker_thread
|
||||
from .._core._exceptions import (
|
||||
BrokenResourceError,
|
||||
BusyResourceError,
|
||||
ClosedResourceError,
|
||||
EndOfStream,
|
||||
)
|
||||
from .._core._exceptions import ExceptionGroup as BaseExceptionGroup
|
||||
from .._core._sockets import convert_ipv6_sockaddr
|
||||
from .._core._synchronization import CapacityLimiter as BaseCapacityLimiter
|
||||
from .._core._synchronization import Event as BaseEvent
|
||||
from .._core._synchronization import ResourceGuard
|
||||
from .._core._tasks import CancelScope as BaseCancelScope
|
||||
from ..abc import IPSockAddrType, UDPPacketType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from trio_typing import TaskStatus
|
||||
|
||||
try:
|
||||
from trio import lowlevel as trio_lowlevel
|
||||
except ImportError:
|
||||
from trio import hazmat as trio_lowlevel # type: ignore[no-redef]
|
||||
from trio.hazmat import wait_readable, wait_writable
|
||||
else:
|
||||
from trio.lowlevel import wait_readable, wait_writable
|
||||
|
||||
try:
|
||||
trio_open_process = trio_lowlevel.open_process # type: ignore[attr-defined]
|
||||
except AttributeError:
|
||||
from trio import open_process as trio_open_process
|
||||
|
||||
T_Retval = TypeVar("T_Retval")
|
||||
T_SockAddr = TypeVar("T_SockAddr", str, IPSockAddrType)
|
||||
|
||||
|
||||
#
|
||||
# Event loop
|
||||
#
|
||||
|
||||
run = trio.run
|
||||
current_token = trio.lowlevel.current_trio_token
|
||||
RunVar = trio.lowlevel.RunVar
|
||||
|
||||
|
||||
#
|
||||
# Miscellaneous
|
||||
#
|
||||
|
||||
sleep = trio.sleep
|
||||
|
||||
|
||||
#
|
||||
# Timeouts and cancellation
|
||||
#
|
||||
|
||||
|
||||
class CancelScope(BaseCancelScope):
|
||||
def __new__(
|
||||
cls, original: Optional[trio.CancelScope] = None, **kwargs: object
|
||||
) -> "CancelScope":
|
||||
return object.__new__(cls)
|
||||
|
||||
def __init__(
|
||||
self, original: Optional[trio.CancelScope] = None, **kwargs: Any
|
||||
) -> None:
|
||||
self.__original = original or trio.CancelScope(**kwargs)
|
||||
|
||||
def __enter__(self) -> "CancelScope":
|
||||
self.__original.__enter__()
|
||||
return self
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: Optional[Type[BaseException]],
|
||||
exc_val: Optional[BaseException],
|
||||
exc_tb: Optional[TracebackType],
|
||||
) -> Optional[bool]:
|
||||
return self.__original.__exit__(exc_type, exc_val, exc_tb)
|
||||
|
||||
def cancel(self) -> DeprecatedAwaitable:
|
||||
self.__original.cancel()
|
||||
return DeprecatedAwaitable(self.cancel)
|
||||
|
||||
@property
|
||||
def deadline(self) -> float:
|
||||
return self.__original.deadline
|
||||
|
||||
@deadline.setter
|
||||
def deadline(self, value: float) -> None:
|
||||
self.__original.deadline = value
|
||||
|
||||
@property
|
||||
def cancel_called(self) -> bool:
|
||||
return self.__original.cancel_called
|
||||
|
||||
@property
|
||||
def shield(self) -> bool:
|
||||
return self.__original.shield
|
||||
|
||||
@shield.setter
|
||||
def shield(self, value: bool) -> None:
|
||||
self.__original.shield = value
|
||||
|
||||
|
||||
CancelledError = trio.Cancelled
|
||||
checkpoint = trio.lowlevel.checkpoint
|
||||
checkpoint_if_cancelled = trio.lowlevel.checkpoint_if_cancelled
|
||||
cancel_shielded_checkpoint = trio.lowlevel.cancel_shielded_checkpoint
|
||||
current_effective_deadline = trio.current_effective_deadline
|
||||
current_time = trio.current_time
|
||||
|
||||
|
||||
#
|
||||
# Task groups
|
||||
#
|
||||
|
||||
|
||||
class ExceptionGroup(BaseExceptionGroup, trio.MultiError):
|
||||
pass
|
||||
|
||||
|
||||
class TaskGroup(abc.TaskGroup):
|
||||
def __init__(self) -> None:
|
||||
self._active = False
|
||||
self._nursery_manager = trio.open_nursery()
|
||||
self.cancel_scope = None # type: ignore[assignment]
|
||||
|
||||
async def __aenter__(self) -> "TaskGroup":
|
||||
self._active = True
|
||||
self._nursery = await self._nursery_manager.__aenter__()
|
||||
self.cancel_scope = CancelScope(self._nursery.cancel_scope)
|
||||
return self
|
||||
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: Optional[Type[BaseException]],
|
||||
exc_val: Optional[BaseException],
|
||||
exc_tb: Optional[TracebackType],
|
||||
) -> Optional[bool]:
|
||||
try:
|
||||
return await self._nursery_manager.__aexit__(exc_type, exc_val, exc_tb)
|
||||
except trio.MultiError as exc:
|
||||
raise ExceptionGroup(exc.exceptions) from None
|
||||
finally:
|
||||
self._active = False
|
||||
|
||||
def start_soon(self, func: Callable, *args: object, name: object = None) -> None:
|
||||
if not self._active:
|
||||
raise RuntimeError(
|
||||
"This task group is not active; no new tasks can be started."
|
||||
)
|
||||
|
||||
self._nursery.start_soon(func, *args, name=name)
|
||||
|
||||
async def start(
|
||||
self, func: Callable[..., Coroutine], *args: object, name: object = None
|
||||
) -> object:
|
||||
if not self._active:
|
||||
raise RuntimeError(
|
||||
"This task group is not active; no new tasks can be started."
|
||||
)
|
||||
|
||||
return await self._nursery.start(func, *args, name=name)
|
||||
|
||||
|
||||
#
|
||||
# Threads
|
||||
#
|
||||
|
||||
|
||||
async def run_sync_in_worker_thread(
|
||||
func: Callable[..., T_Retval],
|
||||
*args: object,
|
||||
cancellable: bool = False,
|
||||
limiter: Optional[trio.CapacityLimiter] = None,
|
||||
) -> T_Retval:
|
||||
def wrapper() -> T_Retval:
|
||||
with claim_worker_thread("trio"):
|
||||
return func(*args)
|
||||
|
||||
# TODO: remove explicit context copying when trio 0.20 is the minimum requirement
|
||||
context = copy_context()
|
||||
context.run(sniffio.current_async_library_cvar.set, None)
|
||||
return await run_sync(
|
||||
context.run, wrapper, cancellable=cancellable, limiter=limiter
|
||||
)
|
||||
|
||||
|
||||
# TODO: remove this workaround when trio 0.20 is the minimum requirement
|
||||
def run_async_from_thread(
|
||||
fn: Callable[..., Awaitable[T_Retval]], *args: Any
|
||||
) -> T_Retval:
|
||||
async def wrapper() -> T_Retval:
|
||||
retval: T_Retval
|
||||
|
||||
async def inner() -> None:
|
||||
nonlocal retval
|
||||
__tracebackhide__ = True
|
||||
retval = await fn(*args)
|
||||
|
||||
async with trio.open_nursery() as n:
|
||||
context.run(n.start_soon, inner)
|
||||
|
||||
__tracebackhide__ = True
|
||||
return retval
|
||||
|
||||
context = copy_context()
|
||||
context.run(sniffio.current_async_library_cvar.set, "trio")
|
||||
return trio.from_thread.run(wrapper)
|
||||
|
||||
|
||||
def run_sync_from_thread(fn: Callable[..., T_Retval], *args: Any) -> T_Retval:
|
||||
# TODO: remove explicit context copying when trio 0.20 is the minimum requirement
|
||||
retval = trio.from_thread.run_sync(copy_context().run, fn, *args)
|
||||
return cast(T_Retval, retval)
|
||||
|
||||
|
||||
class BlockingPortal(abc.BlockingPortal):
|
||||
def __new__(cls) -> "BlockingPortal":
|
||||
return object.__new__(cls)
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self._token = trio.lowlevel.current_trio_token()
|
||||
|
||||
def _spawn_task_from_thread(
|
||||
self,
|
||||
func: Callable,
|
||||
args: tuple,
|
||||
kwargs: Dict[str, Any],
|
||||
name: object,
|
||||
future: Future,
|
||||
) -> None:
|
||||
context = copy_context()
|
||||
context.run(sniffio.current_async_library_cvar.set, "trio")
|
||||
trio.from_thread.run_sync(
|
||||
context.run,
|
||||
partial(self._task_group.start_soon, name=name),
|
||||
self._call_func,
|
||||
func,
|
||||
args,
|
||||
kwargs,
|
||||
future,
|
||||
trio_token=self._token,
|
||||
)
|
||||
|
||||
|
||||
#
|
||||
# Subprocesses
|
||||
#
|
||||
|
||||
|
||||
@dataclass(eq=False)
|
||||
class ReceiveStreamWrapper(abc.ByteReceiveStream):
|
||||
_stream: trio.abc.ReceiveStream
|
||||
|
||||
async def receive(self, max_bytes: Optional[int] = None) -> bytes:
|
||||
try:
|
||||
data = await self._stream.receive_some(max_bytes)
|
||||
except trio.ClosedResourceError as exc:
|
||||
raise ClosedResourceError from exc.__cause__
|
||||
except trio.BrokenResourceError as exc:
|
||||
raise BrokenResourceError from exc.__cause__
|
||||
|
||||
if data:
|
||||
return data
|
||||
else:
|
||||
raise EndOfStream
|
||||
|
||||
async def aclose(self) -> None:
|
||||
await self._stream.aclose()
|
||||
|
||||
|
||||
@dataclass(eq=False)
|
||||
class SendStreamWrapper(abc.ByteSendStream):
|
||||
_stream: trio.abc.SendStream
|
||||
|
||||
async def send(self, item: bytes) -> None:
|
||||
try:
|
||||
await self._stream.send_all(item)
|
||||
except trio.ClosedResourceError as exc:
|
||||
raise ClosedResourceError from exc.__cause__
|
||||
except trio.BrokenResourceError as exc:
|
||||
raise BrokenResourceError from exc.__cause__
|
||||
|
||||
async def aclose(self) -> None:
|
||||
await self._stream.aclose()
|
||||
|
||||
|
||||
@dataclass(eq=False)
|
||||
class Process(abc.Process):
|
||||
_process: trio.Process
|
||||
_stdin: Optional[abc.ByteSendStream]
|
||||
_stdout: Optional[abc.ByteReceiveStream]
|
||||
_stderr: Optional[abc.ByteReceiveStream]
|
||||
|
||||
async def aclose(self) -> None:
|
||||
if self._stdin:
|
||||
await self._stdin.aclose()
|
||||
if self._stdout:
|
||||
await self._stdout.aclose()
|
||||
if self._stderr:
|
||||
await self._stderr.aclose()
|
||||
|
||||
await self.wait()
|
||||
|
||||
async def wait(self) -> int:
|
||||
return await self._process.wait()
|
||||
|
||||
def terminate(self) -> None:
|
||||
self._process.terminate()
|
||||
|
||||
def kill(self) -> None:
|
||||
self._process.kill()
|
||||
|
||||
def send_signal(self, signal: Signals) -> None:
|
||||
self._process.send_signal(signal)
|
||||
|
||||
@property
|
||||
def pid(self) -> int:
|
||||
return self._process.pid
|
||||
|
||||
@property
|
||||
def returncode(self) -> Optional[int]:
|
||||
return self._process.returncode
|
||||
|
||||
@property
|
||||
def stdin(self) -> Optional[abc.ByteSendStream]:
|
||||
return self._stdin
|
||||
|
||||
@property
|
||||
def stdout(self) -> Optional[abc.ByteReceiveStream]:
|
||||
return self._stdout
|
||||
|
||||
@property
|
||||
def stderr(self) -> Optional[abc.ByteReceiveStream]:
|
||||
return self._stderr
|
||||
|
||||
|
||||
async def open_process(
|
||||
command: Union[str, bytes, Sequence[Union[str, bytes]]],
|
||||
*,
|
||||
shell: bool,
|
||||
stdin: Union[int, IO[Any], None],
|
||||
stdout: Union[int, IO[Any], None],
|
||||
stderr: Union[int, IO[Any], None],
|
||||
cwd: Union[str, bytes, PathLike, None] = None,
|
||||
env: Optional[Mapping[str, str]] = None,
|
||||
start_new_session: bool = False,
|
||||
) -> Process:
|
||||
process = await trio_open_process(
|
||||
command,
|
||||
stdin=stdin,
|
||||
stdout=stdout,
|
||||
stderr=stderr,
|
||||
shell=shell,
|
||||
cwd=cwd,
|
||||
env=env,
|
||||
start_new_session=start_new_session,
|
||||
)
|
||||
stdin_stream = SendStreamWrapper(process.stdin) if process.stdin else None
|
||||
stdout_stream = ReceiveStreamWrapper(process.stdout) if process.stdout else None
|
||||
stderr_stream = ReceiveStreamWrapper(process.stderr) if process.stderr else None
|
||||
return Process(process, stdin_stream, stdout_stream, stderr_stream)
|
||||
|
||||
|
||||
class _ProcessPoolShutdownInstrument(trio.abc.Instrument):
|
||||
def after_run(self) -> None:
|
||||
super().after_run()
|
||||
|
||||
|
||||
current_default_worker_process_limiter: RunVar = RunVar(
|
||||
"current_default_worker_process_limiter"
|
||||
)
|
||||
|
||||
|
||||
async def _shutdown_process_pool(workers: Set[Process]) -> None:
|
||||
process: Process
|
||||
try:
|
||||
await sleep(math.inf)
|
||||
except trio.Cancelled:
|
||||
for process in workers:
|
||||
if process.returncode is None:
|
||||
process.kill()
|
||||
|
||||
with CancelScope(shield=True):
|
||||
for process in workers:
|
||||
await process.aclose()
|
||||
|
||||
|
||||
def setup_process_pool_exit_at_shutdown(workers: Set[Process]) -> None:
|
||||
trio.lowlevel.spawn_system_task(_shutdown_process_pool, workers)
|
||||
|
||||
|
||||
#
|
||||
# Sockets and networking
|
||||
#
|
||||
|
||||
|
||||
class _TrioSocketMixin(Generic[T_SockAddr]):
|
||||
def __init__(self, trio_socket: TrioSocketType) -> None:
|
||||
self._trio_socket = trio_socket
|
||||
self._closed = False
|
||||
|
||||
def _check_closed(self) -> None:
|
||||
if self._closed:
|
||||
raise ClosedResourceError
|
||||
if self._trio_socket.fileno() < 0:
|
||||
raise BrokenResourceError
|
||||
|
||||
@property
|
||||
def _raw_socket(self) -> socket.socket:
|
||||
return self._trio_socket._sock # type: ignore[attr-defined]
|
||||
|
||||
async def aclose(self) -> None:
|
||||
if self._trio_socket.fileno() >= 0:
|
||||
self._closed = True
|
||||
self._trio_socket.close()
|
||||
|
||||
def _convert_socket_error(self, exc: BaseException) -> "NoReturn":
|
||||
if isinstance(exc, trio.ClosedResourceError):
|
||||
raise ClosedResourceError from exc
|
||||
elif self._trio_socket.fileno() < 0 and self._closed:
|
||||
raise ClosedResourceError from None
|
||||
elif isinstance(exc, OSError):
|
||||
raise BrokenResourceError from exc
|
||||
else:
|
||||
raise exc
|
||||
|
||||
|
||||
class SocketStream(_TrioSocketMixin, abc.SocketStream):
|
||||
def __init__(self, trio_socket: TrioSocketType) -> None:
|
||||
super().__init__(trio_socket)
|
||||
self._receive_guard = ResourceGuard("reading from")
|
||||
self._send_guard = ResourceGuard("writing to")
|
||||
|
||||
async def receive(self, max_bytes: int = 65536) -> bytes:
|
||||
with self._receive_guard:
|
||||
try:
|
||||
data = await self._trio_socket.recv(max_bytes)
|
||||
except BaseException as exc:
|
||||
self._convert_socket_error(exc)
|
||||
|
||||
if data:
|
||||
return data
|
||||
else:
|
||||
raise EndOfStream
|
||||
|
||||
async def send(self, item: bytes) -> None:
|
||||
with self._send_guard:
|
||||
view = memoryview(item)
|
||||
while view:
|
||||
try:
|
||||
bytes_sent = await self._trio_socket.send(view)
|
||||
except BaseException as exc:
|
||||
self._convert_socket_error(exc)
|
||||
|
||||
view = view[bytes_sent:]
|
||||
|
||||
async def send_eof(self) -> None:
|
||||
self._trio_socket.shutdown(socket.SHUT_WR)
|
||||
|
||||
|
||||
class UNIXSocketStream(SocketStream, abc.UNIXSocketStream):
|
||||
async def receive_fds(self, msglen: int, maxfds: int) -> Tuple[bytes, List[int]]:
|
||||
if not isinstance(msglen, int) or msglen < 0:
|
||||
raise ValueError("msglen must be a non-negative integer")
|
||||
if not isinstance(maxfds, int) or maxfds < 1:
|
||||
raise ValueError("maxfds must be a positive integer")
|
||||
|
||||
fds = array.array("i")
|
||||
await checkpoint()
|
||||
with self._receive_guard:
|
||||
while True:
|
||||
try:
|
||||
message, ancdata, flags, addr = await self._trio_socket.recvmsg(
|
||||
msglen, socket.CMSG_LEN(maxfds * fds.itemsize)
|
||||
)
|
||||
except BaseException as exc:
|
||||
self._convert_socket_error(exc)
|
||||
else:
|
||||
if not message and not ancdata:
|
||||
raise EndOfStream
|
||||
|
||||
break
|
||||
|
||||
for cmsg_level, cmsg_type, cmsg_data in ancdata:
|
||||
if cmsg_level != socket.SOL_SOCKET or cmsg_type != socket.SCM_RIGHTS:
|
||||
raise RuntimeError(
|
||||
f"Received unexpected ancillary data; message = {message!r}, "
|
||||
f"cmsg_level = {cmsg_level}, cmsg_type = {cmsg_type}"
|
||||
)
|
||||
|
||||
fds.frombytes(cmsg_data[: len(cmsg_data) - (len(cmsg_data) % fds.itemsize)])
|
||||
|
||||
return message, list(fds)
|
||||
|
||||
async def send_fds(
|
||||
self, message: bytes, fds: Collection[Union[int, IOBase]]
|
||||
) -> None:
|
||||
if not message:
|
||||
raise ValueError("message must not be empty")
|
||||
if not fds:
|
||||
raise ValueError("fds must not be empty")
|
||||
|
||||
filenos: List[int] = []
|
||||
for fd in fds:
|
||||
if isinstance(fd, int):
|
||||
filenos.append(fd)
|
||||
elif isinstance(fd, IOBase):
|
||||
filenos.append(fd.fileno())
|
||||
|
||||
fdarray = array.array("i", filenos)
|
||||
await checkpoint()
|
||||
with self._send_guard:
|
||||
while True:
|
||||
try:
|
||||
await self._trio_socket.sendmsg(
|
||||
[message],
|
||||
[
|
||||
(
|
||||
socket.SOL_SOCKET,
|
||||
socket.SCM_RIGHTS, # type: ignore[list-item]
|
||||
fdarray,
|
||||
)
|
||||
],
|
||||
)
|
||||
break
|
||||
except BaseException as exc:
|
||||
self._convert_socket_error(exc)
|
||||
|
||||
|
||||
class TCPSocketListener(_TrioSocketMixin, abc.SocketListener):
|
||||
def __init__(self, raw_socket: socket.socket):
|
||||
super().__init__(trio.socket.from_stdlib_socket(raw_socket))
|
||||
self._accept_guard = ResourceGuard("accepting connections from")
|
||||
|
||||
async def accept(self) -> SocketStream:
|
||||
with self._accept_guard:
|
||||
try:
|
||||
trio_socket, _addr = await self._trio_socket.accept()
|
||||
except BaseException as exc:
|
||||
self._convert_socket_error(exc)
|
||||
|
||||
trio_socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
|
||||
return SocketStream(trio_socket)
|
||||
|
||||
|
||||
class UNIXSocketListener(_TrioSocketMixin, abc.SocketListener):
|
||||
def __init__(self, raw_socket: socket.socket):
|
||||
super().__init__(trio.socket.from_stdlib_socket(raw_socket))
|
||||
self._accept_guard = ResourceGuard("accepting connections from")
|
||||
|
||||
async def accept(self) -> UNIXSocketStream:
|
||||
with self._accept_guard:
|
||||
try:
|
||||
trio_socket, _addr = await self._trio_socket.accept()
|
||||
except BaseException as exc:
|
||||
self._convert_socket_error(exc)
|
||||
|
||||
return UNIXSocketStream(trio_socket)
|
||||
|
||||
|
||||
class UDPSocket(_TrioSocketMixin[IPSockAddrType], abc.UDPSocket):
|
||||
def __init__(self, trio_socket: TrioSocketType) -> None:
|
||||
super().__init__(trio_socket)
|
||||
self._receive_guard = ResourceGuard("reading from")
|
||||
self._send_guard = ResourceGuard("writing to")
|
||||
|
||||
async def receive(self) -> Tuple[bytes, IPSockAddrType]:
|
||||
with self._receive_guard:
|
||||
try:
|
||||
data, addr = await self._trio_socket.recvfrom(65536)
|
||||
return data, convert_ipv6_sockaddr(addr)
|
||||
except BaseException as exc:
|
||||
self._convert_socket_error(exc)
|
||||
|
||||
async def send(self, item: UDPPacketType) -> None:
|
||||
with self._send_guard:
|
||||
try:
|
||||
await self._trio_socket.sendto(*item)
|
||||
except BaseException as exc:
|
||||
self._convert_socket_error(exc)
|
||||
|
||||
|
||||
class ConnectedUDPSocket(_TrioSocketMixin[IPSockAddrType], abc.ConnectedUDPSocket):
|
||||
def __init__(self, trio_socket: TrioSocketType) -> None:
|
||||
super().__init__(trio_socket)
|
||||
self._receive_guard = ResourceGuard("reading from")
|
||||
self._send_guard = ResourceGuard("writing to")
|
||||
|
||||
async def receive(self) -> bytes:
|
||||
with self._receive_guard:
|
||||
try:
|
||||
return await self._trio_socket.recv(65536)
|
||||
except BaseException as exc:
|
||||
self._convert_socket_error(exc)
|
||||
|
||||
async def send(self, item: bytes) -> None:
|
||||
with self._send_guard:
|
||||
try:
|
||||
await self._trio_socket.send(item)
|
||||
except BaseException as exc:
|
||||
self._convert_socket_error(exc)
|
||||
|
||||
|
||||
async def connect_tcp(
|
||||
host: str, port: int, local_address: Optional[IPSockAddrType] = None
|
||||
) -> SocketStream:
|
||||
family = socket.AF_INET6 if ":" in host else socket.AF_INET
|
||||
trio_socket = trio.socket.socket(family)
|
||||
trio_socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
|
||||
if local_address:
|
||||
await trio_socket.bind(local_address)
|
||||
|
||||
try:
|
||||
await trio_socket.connect((host, port))
|
||||
except BaseException:
|
||||
trio_socket.close()
|
||||
raise
|
||||
|
||||
return SocketStream(trio_socket)
|
||||
|
||||
|
||||
async def connect_unix(path: str) -> UNIXSocketStream:
|
||||
trio_socket = trio.socket.socket(socket.AF_UNIX)
|
||||
try:
|
||||
await trio_socket.connect(path)
|
||||
except BaseException:
|
||||
trio_socket.close()
|
||||
raise
|
||||
|
||||
return UNIXSocketStream(trio_socket)
|
||||
|
||||
|
||||
async def create_udp_socket(
|
||||
family: socket.AddressFamily,
|
||||
local_address: Optional[IPSockAddrType],
|
||||
remote_address: Optional[IPSockAddrType],
|
||||
reuse_port: bool,
|
||||
) -> Union[UDPSocket, ConnectedUDPSocket]:
|
||||
trio_socket = trio.socket.socket(family=family, type=socket.SOCK_DGRAM)
|
||||
|
||||
if reuse_port:
|
||||
trio_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
|
||||
|
||||
if local_address:
|
||||
await trio_socket.bind(local_address)
|
||||
|
||||
if remote_address:
|
||||
await trio_socket.connect(remote_address)
|
||||
return ConnectedUDPSocket(trio_socket)
|
||||
else:
|
||||
return UDPSocket(trio_socket)
|
||||
|
||||
|
||||
getaddrinfo = trio.socket.getaddrinfo
|
||||
getnameinfo = trio.socket.getnameinfo
|
||||
|
||||
|
||||
async def wait_socket_readable(sock: socket.socket) -> None:
|
||||
try:
|
||||
await wait_readable(sock)
|
||||
except trio.ClosedResourceError as exc:
|
||||
raise ClosedResourceError().with_traceback(exc.__traceback__) from None
|
||||
except trio.BusyResourceError:
|
||||
raise BusyResourceError("reading from") from None
|
||||
|
||||
|
||||
async def wait_socket_writable(sock: socket.socket) -> None:
|
||||
try:
|
||||
await wait_writable(sock)
|
||||
except trio.ClosedResourceError as exc:
|
||||
raise ClosedResourceError().with_traceback(exc.__traceback__) from None
|
||||
except trio.BusyResourceError:
|
||||
raise BusyResourceError("writing to") from None
|
||||
|
||||
|
||||
#
|
||||
# Synchronization
|
||||
#
|
||||
|
||||
|
||||
class Event(BaseEvent):
|
||||
def __new__(cls) -> "Event":
|
||||
return object.__new__(cls)
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.__original = trio.Event()
|
||||
|
||||
def is_set(self) -> bool:
|
||||
return self.__original.is_set()
|
||||
|
||||
async def wait(self) -> None:
|
||||
return await self.__original.wait()
|
||||
|
||||
def statistics(self) -> EventStatistics:
|
||||
orig_statistics = self.__original.statistics()
|
||||
return EventStatistics(tasks_waiting=orig_statistics.tasks_waiting)
|
||||
|
||||
def set(self) -> DeprecatedAwaitable:
|
||||
self.__original.set()
|
||||
return DeprecatedAwaitable(self.set)
|
||||
|
||||
|
||||
class CapacityLimiter(BaseCapacityLimiter):
|
||||
def __new__(cls, *args: object, **kwargs: object) -> "CapacityLimiter":
|
||||
return object.__new__(cls)
|
||||
|
||||
def __init__(
|
||||
self, *args: Any, original: Optional[trio.CapacityLimiter] = None
|
||||
) -> None:
|
||||
self.__original = original or trio.CapacityLimiter(*args)
|
||||
|
||||
async def __aenter__(self) -> None:
|
||||
return await self.__original.__aenter__()
|
||||
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: Optional[Type[BaseException]],
|
||||
exc_val: Optional[BaseException],
|
||||
exc_tb: Optional[TracebackType],
|
||||
) -> Optional[bool]:
|
||||
return await self.__original.__aexit__(exc_type, exc_val, exc_tb)
|
||||
|
||||
@property
|
||||
def total_tokens(self) -> float:
|
||||
return self.__original.total_tokens
|
||||
|
||||
@total_tokens.setter
|
||||
def total_tokens(self, value: float) -> None:
|
||||
self.__original.total_tokens = value
|
||||
|
||||
@property
|
||||
def borrowed_tokens(self) -> int:
|
||||
return self.__original.borrowed_tokens
|
||||
|
||||
@property
|
||||
def available_tokens(self) -> float:
|
||||
return self.__original.available_tokens
|
||||
|
||||
def acquire_nowait(self) -> DeprecatedAwaitable:
|
||||
self.__original.acquire_nowait()
|
||||
return DeprecatedAwaitable(self.acquire_nowait)
|
||||
|
||||
def acquire_on_behalf_of_nowait(self, borrower: object) -> DeprecatedAwaitable:
|
||||
self.__original.acquire_on_behalf_of_nowait(borrower)
|
||||
return DeprecatedAwaitable(self.acquire_on_behalf_of_nowait)
|
||||
|
||||
async def acquire(self) -> None:
|
||||
await self.__original.acquire()
|
||||
|
||||
async def acquire_on_behalf_of(self, borrower: object) -> None:
|
||||
await self.__original.acquire_on_behalf_of(borrower)
|
||||
|
||||
def release(self) -> None:
|
||||
return self.__original.release()
|
||||
|
||||
def release_on_behalf_of(self, borrower: object) -> None:
|
||||
return self.__original.release_on_behalf_of(borrower)
|
||||
|
||||
def statistics(self) -> CapacityLimiterStatistics:
|
||||
orig = self.__original.statistics()
|
||||
return CapacityLimiterStatistics(
|
||||
borrowed_tokens=orig.borrowed_tokens,
|
||||
total_tokens=orig.total_tokens,
|
||||
borrowers=orig.borrowers,
|
||||
tasks_waiting=orig.tasks_waiting,
|
||||
)
|
||||
|
||||
|
||||
_capacity_limiter_wrapper: RunVar = RunVar("_capacity_limiter_wrapper")
|
||||
|
||||
|
||||
def current_default_thread_limiter() -> CapacityLimiter:
|
||||
try:
|
||||
return _capacity_limiter_wrapper.get()
|
||||
except LookupError:
|
||||
limiter = CapacityLimiter(
|
||||
original=trio.to_thread.current_default_thread_limiter()
|
||||
)
|
||||
_capacity_limiter_wrapper.set(limiter)
|
||||
return limiter
|
||||
|
||||
|
||||
#
|
||||
# Signal handling
|
||||
#
|
||||
|
||||
|
||||
class _SignalReceiver(DeprecatedAsyncContextManager[T]):
|
||||
def __init__(self, cm: ContextManager[T]):
|
||||
self._cm = cm
|
||||
|
||||
def __enter__(self) -> T:
|
||||
return self._cm.__enter__()
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: Optional[Type[BaseException]],
|
||||
exc_val: Optional[BaseException],
|
||||
exc_tb: Optional[TracebackType],
|
||||
) -> Optional[bool]:
|
||||
return self._cm.__exit__(exc_type, exc_val, exc_tb)
|
||||
|
||||
|
||||
def open_signal_receiver(*signals: Signals) -> _SignalReceiver:
|
||||
cm = trio.open_signal_receiver(*signals)
|
||||
return _SignalReceiver(cm)
|
||||
|
||||
|
||||
#
|
||||
# Testing and debugging
|
||||
#
|
||||
|
||||
|
||||
def get_current_task() -> TaskInfo:
|
||||
task = trio_lowlevel.current_task()
|
||||
|
||||
parent_id = None
|
||||
if task.parent_nursery and task.parent_nursery.parent_task:
|
||||
parent_id = id(task.parent_nursery.parent_task)
|
||||
|
||||
return TaskInfo(id(task), parent_id, task.name, task.coro)
|
||||
|
||||
|
||||
def get_running_tasks() -> List[TaskInfo]:
|
||||
root_task = trio_lowlevel.current_root_task()
|
||||
task_infos = [TaskInfo(id(root_task), None, root_task.name, root_task.coro)]
|
||||
nurseries = root_task.child_nurseries
|
||||
while nurseries:
|
||||
new_nurseries: List[trio.Nursery] = []
|
||||
for nursery in nurseries:
|
||||
for task in nursery.child_tasks:
|
||||
task_infos.append(
|
||||
TaskInfo(id(task), id(nursery.parent_task), task.name, task.coro)
|
||||
)
|
||||
new_nurseries.extend(task.child_nurseries)
|
||||
|
||||
nurseries = new_nurseries
|
||||
|
||||
return task_infos
|
||||
|
||||
|
||||
def wait_all_tasks_blocked() -> Awaitable[None]:
|
||||
import trio.testing
|
||||
|
||||
return trio.testing.wait_all_tasks_blocked()
|
||||
|
||||
|
||||
class TestRunner(abc.TestRunner):
|
||||
def __init__(self, **options: Any) -> None:
|
||||
from collections import deque
|
||||
from queue import Queue
|
||||
|
||||
self._call_queue: "Queue[Callable[..., object]]" = Queue()
|
||||
self._result_queue: Deque[Outcome] = deque()
|
||||
self._stop_event: Optional[trio.Event] = None
|
||||
self._nursery: Optional[trio.Nursery] = None
|
||||
self._options = options
|
||||
|
||||
async def _trio_main(self) -> None:
|
||||
self._stop_event = trio.Event()
|
||||
async with trio.open_nursery() as self._nursery:
|
||||
await self._stop_event.wait()
|
||||
|
||||
async def _call_func(
|
||||
self, func: Callable[..., Awaitable[object]], args: tuple, kwargs: dict
|
||||
) -> None:
|
||||
try:
|
||||
retval = await func(*args, **kwargs)
|
||||
except BaseException as exc:
|
||||
self._result_queue.append(Error(exc))
|
||||
else:
|
||||
self._result_queue.append(Value(retval))
|
||||
|
||||
def _main_task_finished(self, outcome: object) -> None:
|
||||
self._nursery = None
|
||||
|
||||
def _get_nursery(self) -> trio.Nursery:
|
||||
if self._nursery is None:
|
||||
trio.lowlevel.start_guest_run(
|
||||
self._trio_main,
|
||||
run_sync_soon_threadsafe=self._call_queue.put,
|
||||
done_callback=self._main_task_finished,
|
||||
**self._options,
|
||||
)
|
||||
while self._nursery is None:
|
||||
self._call_queue.get()()
|
||||
|
||||
return self._nursery
|
||||
|
||||
def _call(
|
||||
self, func: Callable[..., Awaitable[T_Retval]], *args: object, **kwargs: object
|
||||
) -> T_Retval:
|
||||
self._get_nursery().start_soon(self._call_func, func, args, kwargs)
|
||||
while not self._result_queue:
|
||||
self._call_queue.get()()
|
||||
|
||||
outcome = self._result_queue.pop()
|
||||
return outcome.unwrap()
|
||||
|
||||
def close(self) -> None:
|
||||
if self._stop_event:
|
||||
self._stop_event.set()
|
||||
while self._nursery is not None:
|
||||
self._call_queue.get()()
|
||||
|
||||
def run_asyncgen_fixture(
|
||||
self,
|
||||
fixture_func: Callable[..., AsyncGenerator[T_Retval, Any]],
|
||||
kwargs: Dict[str, Any],
|
||||
) -> Iterable[T_Retval]:
|
||||
async def fixture_runner(*, task_status: "TaskStatus") -> None:
|
||||
agen = fixture_func(**kwargs)
|
||||
retval = await agen.asend(None)
|
||||
task_status.started(retval)
|
||||
await teardown_event.wait()
|
||||
try:
|
||||
await agen.asend(None)
|
||||
except StopAsyncIteration:
|
||||
pass
|
||||
else:
|
||||
await agen.aclose()
|
||||
raise RuntimeError("Async generator fixture did not stop")
|
||||
|
||||
teardown_event = trio.Event()
|
||||
fixture_value = self._call(lambda: self._get_nursery().start(fixture_runner))
|
||||
yield fixture_value
|
||||
teardown_event.set()
|
||||
|
||||
def run_fixture(
|
||||
self,
|
||||
fixture_func: Callable[..., Coroutine[Any, Any, T_Retval]],
|
||||
kwargs: Dict[str, Any],
|
||||
) -> T_Retval:
|
||||
return self._call(fixture_func, **kwargs)
|
||||
|
||||
def run_test(
|
||||
self, test_func: Callable[..., Coroutine[Any, Any, Any]], kwargs: Dict[str, Any]
|
||||
) -> None:
|
||||
self._call(test_func, **kwargs)
|
@ -0,0 +1,218 @@
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from contextlib import AbstractContextManager
|
||||
from types import TracebackType
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
AsyncContextManager,
|
||||
Callable,
|
||||
ContextManager,
|
||||
Generator,
|
||||
Generic,
|
||||
Iterable,
|
||||
List,
|
||||
Optional,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
overload,
|
||||
)
|
||||
from warnings import warn
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ._testing import TaskInfo
|
||||
else:
|
||||
TaskInfo = object
|
||||
|
||||
T = TypeVar("T")
|
||||
AnyDeprecatedAwaitable = Union[
|
||||
"DeprecatedAwaitable",
|
||||
"DeprecatedAwaitableFloat",
|
||||
"DeprecatedAwaitableList[T]",
|
||||
TaskInfo,
|
||||
]
|
||||
|
||||
|
||||
@overload
|
||||
async def maybe_async(__obj: TaskInfo) -> TaskInfo:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
async def maybe_async(__obj: "DeprecatedAwaitableFloat") -> float:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
async def maybe_async(__obj: "DeprecatedAwaitableList[T]") -> List[T]:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
async def maybe_async(__obj: "DeprecatedAwaitable") -> None:
|
||||
...
|
||||
|
||||
|
||||
async def maybe_async(
|
||||
__obj: "AnyDeprecatedAwaitable[T]",
|
||||
) -> Union[TaskInfo, float, List[T], None]:
|
||||
"""
|
||||
Await on the given object if necessary.
|
||||
|
||||
This function is intended to bridge the gap between AnyIO 2.x and 3.x where some functions and
|
||||
methods were converted from coroutine functions into regular functions.
|
||||
|
||||
Do **not** try to use this for any other purpose!
|
||||
|
||||
:return: the result of awaiting on the object if coroutine, or the object itself otherwise
|
||||
|
||||
.. versionadded:: 2.2
|
||||
|
||||
"""
|
||||
return __obj._unwrap()
|
||||
|
||||
|
||||
class _ContextManagerWrapper:
|
||||
def __init__(self, cm: ContextManager[T]):
|
||||
self._cm = cm
|
||||
|
||||
async def __aenter__(self) -> T:
|
||||
return self._cm.__enter__()
|
||||
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: Optional[Type[BaseException]],
|
||||
exc_val: Optional[BaseException],
|
||||
exc_tb: Optional[TracebackType],
|
||||
) -> Optional[bool]:
|
||||
return self._cm.__exit__(exc_type, exc_val, exc_tb)
|
||||
|
||||
|
||||
def maybe_async_cm(
|
||||
cm: Union[ContextManager[T], AsyncContextManager[T]]
|
||||
) -> AsyncContextManager[T]:
|
||||
"""
|
||||
Wrap a regular context manager as an async one if necessary.
|
||||
|
||||
This function is intended to bridge the gap between AnyIO 2.x and 3.x where some functions and
|
||||
methods were changed to return regular context managers instead of async ones.
|
||||
|
||||
:param cm: a regular or async context manager
|
||||
:return: an async context manager
|
||||
|
||||
.. versionadded:: 2.2
|
||||
|
||||
"""
|
||||
if not isinstance(cm, AbstractContextManager):
|
||||
raise TypeError("Given object is not an context manager")
|
||||
|
||||
return _ContextManagerWrapper(cm)
|
||||
|
||||
|
||||
def _warn_deprecation(
|
||||
awaitable: "AnyDeprecatedAwaitable[Any]", stacklevel: int = 1
|
||||
) -> None:
|
||||
warn(
|
||||
f'Awaiting on {awaitable._name}() is deprecated. Use "await '
|
||||
f"anyio.maybe_async({awaitable._name}(...)) if you have to support both AnyIO 2.x "
|
||||
f'and 3.x, or just remove the "await" if you are completely migrating to AnyIO 3+.',
|
||||
DeprecationWarning,
|
||||
stacklevel=stacklevel + 1,
|
||||
)
|
||||
|
||||
|
||||
class DeprecatedAwaitable:
|
||||
def __init__(self, func: Callable[..., "DeprecatedAwaitable"]):
|
||||
self._name = f"{func.__module__}.{func.__qualname__}"
|
||||
|
||||
def __await__(self) -> Generator[None, None, None]:
|
||||
_warn_deprecation(self)
|
||||
if False:
|
||||
yield
|
||||
|
||||
def __reduce__(self) -> Tuple[Type[None], Tuple[()]]:
|
||||
return type(None), ()
|
||||
|
||||
def _unwrap(self) -> None:
|
||||
return None
|
||||
|
||||
|
||||
class DeprecatedAwaitableFloat(float):
|
||||
def __new__(
|
||||
cls, x: float, func: Callable[..., "DeprecatedAwaitableFloat"]
|
||||
) -> "DeprecatedAwaitableFloat":
|
||||
return super().__new__(cls, x)
|
||||
|
||||
def __init__(self, x: float, func: Callable[..., "DeprecatedAwaitableFloat"]):
|
||||
self._name = f"{func.__module__}.{func.__qualname__}"
|
||||
|
||||
def __await__(self) -> Generator[None, None, float]:
|
||||
_warn_deprecation(self)
|
||||
if False:
|
||||
yield
|
||||
|
||||
return float(self)
|
||||
|
||||
def __reduce__(self) -> Tuple[Type[float], Tuple[float]]:
|
||||
return float, (float(self),)
|
||||
|
||||
def _unwrap(self) -> float:
|
||||
return float(self)
|
||||
|
||||
|
||||
class DeprecatedAwaitableList(List[T]):
|
||||
def __init__(
|
||||
self,
|
||||
iterable: Iterable[T] = (),
|
||||
*,
|
||||
func: Callable[..., "DeprecatedAwaitableList[T]"],
|
||||
):
|
||||
super().__init__(iterable)
|
||||
self._name = f"{func.__module__}.{func.__qualname__}"
|
||||
|
||||
def __await__(self) -> Generator[None, None, List[T]]:
|
||||
_warn_deprecation(self)
|
||||
if False:
|
||||
yield
|
||||
|
||||
return list(self)
|
||||
|
||||
def __reduce__(self) -> Tuple[Type[List[T]], Tuple[List[T]]]:
|
||||
return list, (list(self),)
|
||||
|
||||
def _unwrap(self) -> List[T]:
|
||||
return list(self)
|
||||
|
||||
|
||||
class DeprecatedAsyncContextManager(Generic[T], metaclass=ABCMeta):
|
||||
@abstractmethod
|
||||
def __enter__(self) -> T:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: Optional[Type[BaseException]],
|
||||
exc_val: Optional[BaseException],
|
||||
exc_tb: Optional[TracebackType],
|
||||
) -> Optional[bool]:
|
||||
pass
|
||||
|
||||
async def __aenter__(self) -> T:
|
||||
warn(
|
||||
f"Using {self.__class__.__name__} as an async context manager has been deprecated. "
|
||||
f'Use "async with anyio.maybe_async_cm(yourcontextmanager) as foo:" if you have to '
|
||||
f'support both AnyIO 2.x and 3.x, or just remove the "async" from "async with" if '
|
||||
f"you are completely migrating to AnyIO 3+.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
return self.__enter__()
|
||||
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: Optional[Type[BaseException]],
|
||||
exc_val: Optional[BaseException],
|
||||
exc_tb: Optional[TracebackType],
|
||||
) -> Optional[bool]:
|
||||
return self.__exit__(exc_type, exc_val, exc_tb)
|
@ -0,0 +1,93 @@
|
||||
from traceback import format_exception
|
||||
from typing import List
|
||||
|
||||
|
||||
class BrokenResourceError(Exception):
|
||||
"""
|
||||
Raised when trying to use a resource that has been rendered unusable due to external causes
|
||||
(e.g. a send stream whose peer has disconnected).
|
||||
"""
|
||||
|
||||
|
||||
class BrokenWorkerProcess(Exception):
|
||||
"""
|
||||
Raised by :func:`run_sync_in_process` if the worker process terminates abruptly or otherwise
|
||||
misbehaves.
|
||||
"""
|
||||
|
||||
|
||||
class BusyResourceError(Exception):
|
||||
"""Raised when two tasks are trying to read from or write to the same resource concurrently."""
|
||||
|
||||
def __init__(self, action: str):
|
||||
super().__init__(f"Another task is already {action} this resource")
|
||||
|
||||
|
||||
class ClosedResourceError(Exception):
|
||||
"""Raised when trying to use a resource that has been closed."""
|
||||
|
||||
|
||||
class DelimiterNotFound(Exception):
|
||||
"""
|
||||
Raised during :meth:`~anyio.streams.buffered.BufferedByteReceiveStream.receive_until` if the
|
||||
maximum number of bytes has been read without the delimiter being found.
|
||||
"""
|
||||
|
||||
def __init__(self, max_bytes: int) -> None:
|
||||
super().__init__(
|
||||
f"The delimiter was not found among the first {max_bytes} bytes"
|
||||
)
|
||||
|
||||
|
||||
class EndOfStream(Exception):
|
||||
"""Raised when trying to read from a stream that has been closed from the other end."""
|
||||
|
||||
|
||||
class ExceptionGroup(BaseException):
|
||||
"""
|
||||
Raised when multiple exceptions have been raised in a task group.
|
||||
|
||||
:var ~typing.Sequence[BaseException] exceptions: the sequence of exceptions raised together
|
||||
"""
|
||||
|
||||
SEPARATOR = "----------------------------\n"
|
||||
|
||||
exceptions: List[BaseException]
|
||||
|
||||
def __str__(self) -> str:
|
||||
tracebacks = [
|
||||
"".join(format_exception(type(exc), exc, exc.__traceback__))
|
||||
for exc in self.exceptions
|
||||
]
|
||||
return (
|
||||
f"{len(self.exceptions)} exceptions were raised in the task group:\n"
|
||||
f"{self.SEPARATOR}{self.SEPARATOR.join(tracebacks)}"
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
exception_reprs = ", ".join(repr(exc) for exc in self.exceptions)
|
||||
return f"<{self.__class__.__name__}: {exception_reprs}>"
|
||||
|
||||
|
||||
class IncompleteRead(Exception):
|
||||
"""
|
||||
Raised during :meth:`~anyio.streams.buffered.BufferedByteReceiveStream.receive_exactly` or
|
||||
:meth:`~anyio.streams.buffered.BufferedByteReceiveStream.receive_until` if the
|
||||
connection is closed before the requested amount of bytes has been read.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__(
|
||||
"The stream was closed before the read operation could be completed"
|
||||
)
|
||||
|
||||
|
||||
class TypedAttributeLookupError(LookupError):
|
||||
"""
|
||||
Raised by :meth:`~anyio.TypedAttributeProvider.extra` when the given typed attribute is not
|
||||
found and no default value has been given.
|
||||
"""
|
||||
|
||||
|
||||
class WouldBlock(Exception):
|
||||
"""Raised by ``X_nowait`` functions if ``X()`` would block."""
|
@ -0,0 +1,607 @@
|
||||
import os
|
||||
import pathlib
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
from functools import partial
|
||||
from os import PathLike
|
||||
from typing import (
|
||||
IO,
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
AnyStr,
|
||||
AsyncIterator,
|
||||
Callable,
|
||||
Generic,
|
||||
Iterable,
|
||||
Iterator,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Union,
|
||||
cast,
|
||||
overload,
|
||||
)
|
||||
|
||||
from .. import to_thread
|
||||
from ..abc import AsyncResource
|
||||
|
||||
if sys.version_info >= (3, 8):
|
||||
from typing import Final
|
||||
else:
|
||||
from typing_extensions import Final
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from _typeshed import OpenBinaryMode, OpenTextMode, ReadableBuffer, WriteableBuffer
|
||||
else:
|
||||
ReadableBuffer = OpenBinaryMode = OpenTextMode = WriteableBuffer = object
|
||||
|
||||
|
||||
class AsyncFile(AsyncResource, Generic[AnyStr]):
|
||||
"""
|
||||
An asynchronous file object.
|
||||
|
||||
This class wraps a standard file object and provides async friendly versions of the following
|
||||
blocking methods (where available on the original file object):
|
||||
|
||||
* read
|
||||
* read1
|
||||
* readline
|
||||
* readlines
|
||||
* readinto
|
||||
* readinto1
|
||||
* write
|
||||
* writelines
|
||||
* truncate
|
||||
* seek
|
||||
* tell
|
||||
* flush
|
||||
|
||||
All other methods are directly passed through.
|
||||
|
||||
This class supports the asynchronous context manager protocol which closes the underlying file
|
||||
at the end of the context block.
|
||||
|
||||
This class also supports asynchronous iteration::
|
||||
|
||||
async with await open_file(...) as f:
|
||||
async for line in f:
|
||||
print(line)
|
||||
"""
|
||||
|
||||
def __init__(self, fp: IO[AnyStr]) -> None:
|
||||
self._fp: Any = fp
|
||||
|
||||
def __getattr__(self, name: str) -> object:
|
||||
return getattr(self._fp, name)
|
||||
|
||||
@property
|
||||
def wrapped(self) -> IO[AnyStr]:
|
||||
"""The wrapped file object."""
|
||||
return self._fp
|
||||
|
||||
async def __aiter__(self) -> AsyncIterator[AnyStr]:
|
||||
while True:
|
||||
line = await self.readline()
|
||||
if line:
|
||||
yield line
|
||||
else:
|
||||
break
|
||||
|
||||
async def aclose(self) -> None:
|
||||
return await to_thread.run_sync(self._fp.close)
|
||||
|
||||
async def read(self, size: int = -1) -> AnyStr:
|
||||
return await to_thread.run_sync(self._fp.read, size)
|
||||
|
||||
async def read1(self: "AsyncFile[bytes]", size: int = -1) -> bytes:
|
||||
return await to_thread.run_sync(self._fp.read1, size)
|
||||
|
||||
async def readline(self) -> AnyStr:
|
||||
return await to_thread.run_sync(self._fp.readline)
|
||||
|
||||
async def readlines(self) -> List[AnyStr]:
|
||||
return await to_thread.run_sync(self._fp.readlines)
|
||||
|
||||
async def readinto(self: "AsyncFile[bytes]", b: WriteableBuffer) -> bytes:
|
||||
return await to_thread.run_sync(self._fp.readinto, b)
|
||||
|
||||
async def readinto1(self: "AsyncFile[bytes]", b: WriteableBuffer) -> bytes:
|
||||
return await to_thread.run_sync(self._fp.readinto1, b)
|
||||
|
||||
@overload
|
||||
async def write(self: "AsyncFile[bytes]", b: ReadableBuffer) -> int:
|
||||
...
|
||||
|
||||
@overload
|
||||
async def write(self: "AsyncFile[str]", b: str) -> int:
|
||||
...
|
||||
|
||||
async def write(self, b: Union[ReadableBuffer, str]) -> int:
|
||||
return await to_thread.run_sync(self._fp.write, b)
|
||||
|
||||
@overload
|
||||
async def writelines(
|
||||
self: "AsyncFile[bytes]", lines: Iterable[ReadableBuffer]
|
||||
) -> None:
|
||||
...
|
||||
|
||||
@overload
|
||||
async def writelines(self: "AsyncFile[str]", lines: Iterable[str]) -> None:
|
||||
...
|
||||
|
||||
async def writelines(
|
||||
self, lines: Union[Iterable[ReadableBuffer], Iterable[str]]
|
||||
) -> None:
|
||||
return await to_thread.run_sync(self._fp.writelines, lines)
|
||||
|
||||
async def truncate(self, size: Optional[int] = None) -> int:
|
||||
return await to_thread.run_sync(self._fp.truncate, size)
|
||||
|
||||
async def seek(self, offset: int, whence: Optional[int] = os.SEEK_SET) -> int:
|
||||
return await to_thread.run_sync(self._fp.seek, offset, whence)
|
||||
|
||||
async def tell(self) -> int:
|
||||
return await to_thread.run_sync(self._fp.tell)
|
||||
|
||||
async def flush(self) -> None:
|
||||
return await to_thread.run_sync(self._fp.flush)
|
||||
|
||||
|
||||
@overload
|
||||
async def open_file(
|
||||
file: Union[str, "PathLike[str]", int],
|
||||
mode: OpenBinaryMode,
|
||||
buffering: int = ...,
|
||||
encoding: Optional[str] = ...,
|
||||
errors: Optional[str] = ...,
|
||||
newline: Optional[str] = ...,
|
||||
closefd: bool = ...,
|
||||
opener: Optional[Callable[[str, int], int]] = ...,
|
||||
) -> AsyncFile[bytes]:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
async def open_file(
|
||||
file: Union[str, "PathLike[str]", int],
|
||||
mode: OpenTextMode = ...,
|
||||
buffering: int = ...,
|
||||
encoding: Optional[str] = ...,
|
||||
errors: Optional[str] = ...,
|
||||
newline: Optional[str] = ...,
|
||||
closefd: bool = ...,
|
||||
opener: Optional[Callable[[str, int], int]] = ...,
|
||||
) -> AsyncFile[str]:
|
||||
...
|
||||
|
||||
|
||||
async def open_file(
|
||||
file: Union[str, "PathLike[str]", int],
|
||||
mode: str = "r",
|
||||
buffering: int = -1,
|
||||
encoding: Optional[str] = None,
|
||||
errors: Optional[str] = None,
|
||||
newline: Optional[str] = None,
|
||||
closefd: bool = True,
|
||||
opener: Optional[Callable[[str, int], int]] = None,
|
||||
) -> AsyncFile[Any]:
|
||||
"""
|
||||
Open a file asynchronously.
|
||||
|
||||
The arguments are exactly the same as for the builtin :func:`open`.
|
||||
|
||||
:return: an asynchronous file object
|
||||
|
||||
"""
|
||||
fp = await to_thread.run_sync(
|
||||
open, file, mode, buffering, encoding, errors, newline, closefd, opener
|
||||
)
|
||||
return AsyncFile(fp)
|
||||
|
||||
|
||||
def wrap_file(file: IO[AnyStr]) -> AsyncFile[AnyStr]:
|
||||
"""
|
||||
Wrap an existing file as an asynchronous file.
|
||||
|
||||
:param file: an existing file-like object
|
||||
:return: an asynchronous file object
|
||||
|
||||
"""
|
||||
return AsyncFile(file)
|
||||
|
||||
|
||||
@dataclass(eq=False)
|
||||
class _PathIterator(AsyncIterator["Path"]):
|
||||
iterator: Iterator["PathLike[str]"]
|
||||
|
||||
async def __anext__(self) -> "Path":
|
||||
nextval = await to_thread.run_sync(next, self.iterator, None, cancellable=True)
|
||||
if nextval is None:
|
||||
raise StopAsyncIteration from None
|
||||
|
||||
return Path(cast("PathLike[str]", nextval))
|
||||
|
||||
|
||||
class Path:
|
||||
"""
|
||||
An asynchronous version of :class:`pathlib.Path`.
|
||||
|
||||
This class cannot be substituted for :class:`pathlib.Path` or :class:`pathlib.PurePath`, but
|
||||
it is compatible with the :class:`os.PathLike` interface.
|
||||
|
||||
It implements the Python 3.10 version of :class:`pathlib.Path` interface, except for the
|
||||
deprecated :meth:`~pathlib.Path.link_to` method.
|
||||
|
||||
Any methods that do disk I/O need to be awaited on. These methods are:
|
||||
|
||||
* :meth:`~pathlib.Path.absolute`
|
||||
* :meth:`~pathlib.Path.chmod`
|
||||
* :meth:`~pathlib.Path.cwd`
|
||||
* :meth:`~pathlib.Path.exists`
|
||||
* :meth:`~pathlib.Path.expanduser`
|
||||
* :meth:`~pathlib.Path.group`
|
||||
* :meth:`~pathlib.Path.hardlink_to`
|
||||
* :meth:`~pathlib.Path.home`
|
||||
* :meth:`~pathlib.Path.is_block_device`
|
||||
* :meth:`~pathlib.Path.is_char_device`
|
||||
* :meth:`~pathlib.Path.is_dir`
|
||||
* :meth:`~pathlib.Path.is_fifo`
|
||||
* :meth:`~pathlib.Path.is_file`
|
||||
* :meth:`~pathlib.Path.is_mount`
|
||||
* :meth:`~pathlib.Path.lchmod`
|
||||
* :meth:`~pathlib.Path.lstat`
|
||||
* :meth:`~pathlib.Path.mkdir`
|
||||
* :meth:`~pathlib.Path.open`
|
||||
* :meth:`~pathlib.Path.owner`
|
||||
* :meth:`~pathlib.Path.read_bytes`
|
||||
* :meth:`~pathlib.Path.read_text`
|
||||
* :meth:`~pathlib.Path.readlink`
|
||||
* :meth:`~pathlib.Path.rename`
|
||||
* :meth:`~pathlib.Path.replace`
|
||||
* :meth:`~pathlib.Path.rmdir`
|
||||
* :meth:`~pathlib.Path.samefile`
|
||||
* :meth:`~pathlib.Path.stat`
|
||||
* :meth:`~pathlib.Path.touch`
|
||||
* :meth:`~pathlib.Path.unlink`
|
||||
* :meth:`~pathlib.Path.write_bytes`
|
||||
* :meth:`~pathlib.Path.write_text`
|
||||
|
||||
Additionally, the following methods return an async iterator yielding :class:`~.Path` objects:
|
||||
|
||||
* :meth:`~pathlib.Path.glob`
|
||||
* :meth:`~pathlib.Path.iterdir`
|
||||
* :meth:`~pathlib.Path.rglob`
|
||||
"""
|
||||
|
||||
__slots__ = "_path", "__weakref__"
|
||||
|
||||
__weakref__: Any
|
||||
|
||||
def __init__(self, *args: Union[str, "PathLike[str]"]) -> None:
|
||||
self._path: Final[pathlib.Path] = pathlib.Path(*args)
|
||||
|
||||
def __fspath__(self) -> str:
|
||||
return self._path.__fspath__()
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self._path.__str__()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}({self.as_posix()!r})"
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
return self._path.__bytes__()
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return self._path.__hash__()
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
target = other._path if isinstance(other, Path) else other
|
||||
return self._path.__eq__(target)
|
||||
|
||||
def __lt__(self, other: "Path") -> bool:
|
||||
target = other._path if isinstance(other, Path) else other
|
||||
return self._path.__lt__(target)
|
||||
|
||||
def __le__(self, other: "Path") -> bool:
|
||||
target = other._path if isinstance(other, Path) else other
|
||||
return self._path.__le__(target)
|
||||
|
||||
def __gt__(self, other: "Path") -> bool:
|
||||
target = other._path if isinstance(other, Path) else other
|
||||
return self._path.__gt__(target)
|
||||
|
||||
def __ge__(self, other: "Path") -> bool:
|
||||
target = other._path if isinstance(other, Path) else other
|
||||
return self._path.__ge__(target)
|
||||
|
||||
def __truediv__(self, other: Any) -> "Path":
|
||||
return Path(self._path / other)
|
||||
|
||||
def __rtruediv__(self, other: Any) -> "Path":
|
||||
return Path(other) / self
|
||||
|
||||
@property
|
||||
def parts(self) -> Tuple[str, ...]:
|
||||
return self._path.parts
|
||||
|
||||
@property
|
||||
def drive(self) -> str:
|
||||
return self._path.drive
|
||||
|
||||
@property
|
||||
def root(self) -> str:
|
||||
return self._path.root
|
||||
|
||||
@property
|
||||
def anchor(self) -> str:
|
||||
return self._path.anchor
|
||||
|
||||
@property
|
||||
def parents(self) -> Sequence["Path"]:
|
||||
return tuple(Path(p) for p in self._path.parents)
|
||||
|
||||
@property
|
||||
def parent(self) -> "Path":
|
||||
return Path(self._path.parent)
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self._path.name
|
||||
|
||||
@property
|
||||
def suffix(self) -> str:
|
||||
return self._path.suffix
|
||||
|
||||
@property
|
||||
def suffixes(self) -> List[str]:
|
||||
return self._path.suffixes
|
||||
|
||||
@property
|
||||
def stem(self) -> str:
|
||||
return self._path.stem
|
||||
|
||||
async def absolute(self) -> "Path":
|
||||
path = await to_thread.run_sync(self._path.absolute)
|
||||
return Path(path)
|
||||
|
||||
def as_posix(self) -> str:
|
||||
return self._path.as_posix()
|
||||
|
||||
def as_uri(self) -> str:
|
||||
return self._path.as_uri()
|
||||
|
||||
def match(self, path_pattern: str) -> bool:
|
||||
return self._path.match(path_pattern)
|
||||
|
||||
def is_relative_to(self, *other: Union[str, "PathLike[str]"]) -> bool:
|
||||
try:
|
||||
self.relative_to(*other)
|
||||
return True
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
async def chmod(self, mode: int, *, follow_symlinks: bool = True) -> None:
|
||||
func = partial(os.chmod, follow_symlinks=follow_symlinks)
|
||||
return await to_thread.run_sync(func, self._path, mode)
|
||||
|
||||
@classmethod
|
||||
async def cwd(cls) -> "Path":
|
||||
path = await to_thread.run_sync(pathlib.Path.cwd)
|
||||
return cls(path)
|
||||
|
||||
async def exists(self) -> bool:
|
||||
return await to_thread.run_sync(self._path.exists, cancellable=True)
|
||||
|
||||
async def expanduser(self) -> "Path":
|
||||
return Path(await to_thread.run_sync(self._path.expanduser, cancellable=True))
|
||||
|
||||
def glob(self, pattern: str) -> AsyncIterator["Path"]:
|
||||
gen = self._path.glob(pattern)
|
||||
return _PathIterator(gen)
|
||||
|
||||
async def group(self) -> str:
|
||||
return await to_thread.run_sync(self._path.group, cancellable=True)
|
||||
|
||||
async def hardlink_to(self, target: Union[str, pathlib.Path, "Path"]) -> None:
|
||||
if isinstance(target, Path):
|
||||
target = target._path
|
||||
|
||||
await to_thread.run_sync(os.link, target, self)
|
||||
|
||||
@classmethod
|
||||
async def home(cls) -> "Path":
|
||||
home_path = await to_thread.run_sync(pathlib.Path.home)
|
||||
return cls(home_path)
|
||||
|
||||
def is_absolute(self) -> bool:
|
||||
return self._path.is_absolute()
|
||||
|
||||
async def is_block_device(self) -> bool:
|
||||
return await to_thread.run_sync(self._path.is_block_device, cancellable=True)
|
||||
|
||||
async def is_char_device(self) -> bool:
|
||||
return await to_thread.run_sync(self._path.is_char_device, cancellable=True)
|
||||
|
||||
async def is_dir(self) -> bool:
|
||||
return await to_thread.run_sync(self._path.is_dir, cancellable=True)
|
||||
|
||||
async def is_fifo(self) -> bool:
|
||||
return await to_thread.run_sync(self._path.is_fifo, cancellable=True)
|
||||
|
||||
async def is_file(self) -> bool:
|
||||
return await to_thread.run_sync(self._path.is_file, cancellable=True)
|
||||
|
||||
async def is_mount(self) -> bool:
|
||||
return await to_thread.run_sync(os.path.ismount, self._path, cancellable=True)
|
||||
|
||||
def is_reserved(self) -> bool:
|
||||
return self._path.is_reserved()
|
||||
|
||||
async def is_socket(self) -> bool:
|
||||
return await to_thread.run_sync(self._path.is_socket, cancellable=True)
|
||||
|
||||
async def is_symlink(self) -> bool:
|
||||
return await to_thread.run_sync(self._path.is_symlink, cancellable=True)
|
||||
|
||||
def iterdir(self) -> AsyncIterator["Path"]:
|
||||
gen = self._path.iterdir()
|
||||
return _PathIterator(gen)
|
||||
|
||||
def joinpath(self, *args: Union[str, "PathLike[str]"]) -> "Path":
|
||||
return Path(self._path.joinpath(*args))
|
||||
|
||||
async def lchmod(self, mode: int) -> None:
|
||||
await to_thread.run_sync(self._path.lchmod, mode)
|
||||
|
||||
async def lstat(self) -> os.stat_result:
|
||||
return await to_thread.run_sync(self._path.lstat, cancellable=True)
|
||||
|
||||
async def mkdir(
|
||||
self, mode: int = 0o777, parents: bool = False, exist_ok: bool = False
|
||||
) -> None:
|
||||
await to_thread.run_sync(self._path.mkdir, mode, parents, exist_ok)
|
||||
|
||||
@overload
|
||||
async def open(
|
||||
self,
|
||||
mode: OpenBinaryMode,
|
||||
buffering: int = ...,
|
||||
encoding: Optional[str] = ...,
|
||||
errors: Optional[str] = ...,
|
||||
newline: Optional[str] = ...,
|
||||
) -> AsyncFile[bytes]:
|
||||
...
|
||||
|
||||
@overload
|
||||
async def open(
|
||||
self,
|
||||
mode: OpenTextMode = ...,
|
||||
buffering: int = ...,
|
||||
encoding: Optional[str] = ...,
|
||||
errors: Optional[str] = ...,
|
||||
newline: Optional[str] = ...,
|
||||
) -> AsyncFile[str]:
|
||||
...
|
||||
|
||||
async def open(
|
||||
self,
|
||||
mode: str = "r",
|
||||
buffering: int = -1,
|
||||
encoding: Optional[str] = None,
|
||||
errors: Optional[str] = None,
|
||||
newline: Optional[str] = None,
|
||||
) -> AsyncFile[Any]:
|
||||
fp = await to_thread.run_sync(
|
||||
self._path.open, mode, buffering, encoding, errors, newline
|
||||
)
|
||||
return AsyncFile(fp)
|
||||
|
||||
async def owner(self) -> str:
|
||||
return await to_thread.run_sync(self._path.owner, cancellable=True)
|
||||
|
||||
async def read_bytes(self) -> bytes:
|
||||
return await to_thread.run_sync(self._path.read_bytes)
|
||||
|
||||
async def read_text(
|
||||
self, encoding: Optional[str] = None, errors: Optional[str] = None
|
||||
) -> str:
|
||||
return await to_thread.run_sync(self._path.read_text, encoding, errors)
|
||||
|
||||
def relative_to(self, *other: Union[str, "PathLike[str]"]) -> "Path":
|
||||
return Path(self._path.relative_to(*other))
|
||||
|
||||
async def readlink(self) -> "Path":
|
||||
target = await to_thread.run_sync(os.readlink, self._path)
|
||||
return Path(cast(str, target))
|
||||
|
||||
async def rename(self, target: Union[str, pathlib.PurePath, "Path"]) -> "Path":
|
||||
if isinstance(target, Path):
|
||||
target = target._path
|
||||
|
||||
await to_thread.run_sync(self._path.rename, target)
|
||||
return Path(target)
|
||||
|
||||
async def replace(self, target: Union[str, pathlib.PurePath, "Path"]) -> "Path":
|
||||
if isinstance(target, Path):
|
||||
target = target._path
|
||||
|
||||
await to_thread.run_sync(self._path.replace, target)
|
||||
return Path(target)
|
||||
|
||||
async def resolve(self, strict: bool = False) -> "Path":
|
||||
func = partial(self._path.resolve, strict=strict)
|
||||
return Path(await to_thread.run_sync(func, cancellable=True))
|
||||
|
||||
def rglob(self, pattern: str) -> AsyncIterator["Path"]:
|
||||
gen = self._path.rglob(pattern)
|
||||
return _PathIterator(gen)
|
||||
|
||||
async def rmdir(self) -> None:
|
||||
await to_thread.run_sync(self._path.rmdir)
|
||||
|
||||
async def samefile(
|
||||
self, other_path: Union[str, bytes, int, pathlib.Path, "Path"]
|
||||
) -> bool:
|
||||
if isinstance(other_path, Path):
|
||||
other_path = other_path._path
|
||||
|
||||
return await to_thread.run_sync(
|
||||
self._path.samefile, other_path, cancellable=True
|
||||
)
|
||||
|
||||
async def stat(self, *, follow_symlinks: bool = True) -> os.stat_result:
|
||||
func = partial(os.stat, follow_symlinks=follow_symlinks)
|
||||
return await to_thread.run_sync(func, self._path, cancellable=True)
|
||||
|
||||
async def symlink_to(
|
||||
self,
|
||||
target: Union[str, pathlib.Path, "Path"],
|
||||
target_is_directory: bool = False,
|
||||
) -> None:
|
||||
if isinstance(target, Path):
|
||||
target = target._path
|
||||
|
||||
await to_thread.run_sync(self._path.symlink_to, target, target_is_directory)
|
||||
|
||||
async def touch(self, mode: int = 0o666, exist_ok: bool = True) -> None:
|
||||
await to_thread.run_sync(self._path.touch, mode, exist_ok)
|
||||
|
||||
async def unlink(self, missing_ok: bool = False) -> None:
|
||||
try:
|
||||
await to_thread.run_sync(self._path.unlink)
|
||||
except FileNotFoundError:
|
||||
if not missing_ok:
|
||||
raise
|
||||
|
||||
def with_name(self, name: str) -> "Path":
|
||||
return Path(self._path.with_name(name))
|
||||
|
||||
def with_stem(self, stem: str) -> "Path":
|
||||
return Path(self._path.with_name(stem + self._path.suffix))
|
||||
|
||||
def with_suffix(self, suffix: str) -> "Path":
|
||||
return Path(self._path.with_suffix(suffix))
|
||||
|
||||
async def write_bytes(self, data: bytes) -> int:
|
||||
return await to_thread.run_sync(self._path.write_bytes, data)
|
||||
|
||||
async def write_text(
|
||||
self,
|
||||
data: str,
|
||||
encoding: Optional[str] = None,
|
||||
errors: Optional[str] = None,
|
||||
newline: Optional[str] = None,
|
||||
) -> int:
|
||||
# Path.write_text() does not support the "newline" parameter before Python 3.10
|
||||
def sync_write_text() -> int:
|
||||
with self._path.open(
|
||||
"w", encoding=encoding, errors=errors, newline=newline
|
||||
) as fp:
|
||||
return fp.write(data)
|
||||
|
||||
return await to_thread.run_sync(sync_write_text)
|
||||
|
||||
|
||||
PathLike.register(Path)
|
@ -0,0 +1,16 @@
|
||||
from ..abc import AsyncResource
|
||||
from ._tasks import CancelScope
|
||||
|
||||
|
||||
async def aclose_forcefully(resource: AsyncResource) -> None:
|
||||
"""
|
||||
Close an asynchronous resource in a cancelled scope.
|
||||
|
||||
Doing this closes the resource without waiting on anything.
|
||||
|
||||
:param resource: the resource to close
|
||||
|
||||
"""
|
||||
with CancelScope() as scope:
|
||||
scope.cancel()
|
||||
await resource.aclose()
|
@ -0,0 +1,24 @@
|
||||
from typing import AsyncIterator
|
||||
|
||||
from ._compat import DeprecatedAsyncContextManager
|
||||
from ._eventloop import get_asynclib
|
||||
|
||||
|
||||
def open_signal_receiver(
|
||||
*signals: int,
|
||||
) -> DeprecatedAsyncContextManager[AsyncIterator[int]]:
|
||||
"""
|
||||
Start receiving operating system signals.
|
||||
|
||||
:param signals: signals to receive (e.g. ``signal.SIGINT``)
|
||||
:return: an asynchronous context manager for an asynchronous iterator which yields signal
|
||||
numbers
|
||||
|
||||
.. warning:: Windows does not support signals natively so it is best to avoid relying on this
|
||||
in cross-platform applications.
|
||||
|
||||
.. warning:: On asyncio, this permanently replaces any previous signal handler for the given
|
||||
signals, as set via :meth:`~asyncio.loop.add_signal_handler`.
|
||||
|
||||
"""
|
||||
return get_asynclib().open_signal_receiver(*signals)
|
@ -0,0 +1,45 @@
|
||||
import math
|
||||
from typing import Any, Optional, Tuple, Type, TypeVar, overload
|
||||
|
||||
from ..streams.memory import (
|
||||
MemoryObjectReceiveStream,
|
||||
MemoryObjectSendStream,
|
||||
MemoryObjectStreamState,
|
||||
)
|
||||
|
||||
T_Item = TypeVar("T_Item")
|
||||
|
||||
|
||||
@overload
|
||||
def create_memory_object_stream(
|
||||
max_buffer_size: float, item_type: Type[T_Item]
|
||||
) -> Tuple[MemoryObjectSendStream[T_Item], MemoryObjectReceiveStream[T_Item]]:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def create_memory_object_stream(
|
||||
max_buffer_size: float = 0,
|
||||
) -> Tuple[MemoryObjectSendStream[Any], MemoryObjectReceiveStream[Any]]:
|
||||
...
|
||||
|
||||
|
||||
def create_memory_object_stream(
|
||||
max_buffer_size: float = 0, item_type: Optional[Type[T_Item]] = None
|
||||
) -> Tuple[MemoryObjectSendStream[Any], MemoryObjectReceiveStream[Any]]:
|
||||
"""
|
||||
Create a memory object stream.
|
||||
|
||||
:param max_buffer_size: number of items held in the buffer until ``send()`` starts blocking
|
||||
:param item_type: type of item, for marking the streams with the right generic type for
|
||||
static typing (not used at run time)
|
||||
:return: a tuple of (send stream, receive stream)
|
||||
|
||||
"""
|
||||
if max_buffer_size != math.inf and not isinstance(max_buffer_size, int):
|
||||
raise ValueError("max_buffer_size must be either an integer or math.inf")
|
||||
if max_buffer_size < 0:
|
||||
raise ValueError("max_buffer_size cannot be negative")
|
||||
|
||||
state: MemoryObjectStreamState = MemoryObjectStreamState(max_buffer_size)
|
||||
return MemoryObjectSendStream(state), MemoryObjectReceiveStream(state)
|
@ -0,0 +1,136 @@
|
||||
from io import BytesIO
|
||||
from os import PathLike
|
||||
from subprocess import DEVNULL, PIPE, CalledProcessError, CompletedProcess
|
||||
from typing import (
|
||||
IO,
|
||||
Any,
|
||||
AsyncIterable,
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
Sequence,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
from ..abc import Process
|
||||
from ._eventloop import get_asynclib
|
||||
from ._tasks import create_task_group
|
||||
|
||||
|
||||
async def run_process(
|
||||
command: Union[str, bytes, Sequence[Union[str, bytes]]],
|
||||
*,
|
||||
input: Optional[bytes] = None,
|
||||
stdout: Union[int, IO[Any], None] = PIPE,
|
||||
stderr: Union[int, IO[Any], None] = PIPE,
|
||||
check: bool = True,
|
||||
cwd: Union[str, bytes, "PathLike[str]", None] = None,
|
||||
env: Optional[Mapping[str, str]] = None,
|
||||
start_new_session: bool = False,
|
||||
) -> "CompletedProcess[bytes]":
|
||||
"""
|
||||
Run an external command in a subprocess and wait until it completes.
|
||||
|
||||
.. seealso:: :func:`subprocess.run`
|
||||
|
||||
:param command: either a string to pass to the shell, or an iterable of strings containing the
|
||||
executable name or path and its arguments
|
||||
:param input: bytes passed to the standard input of the subprocess
|
||||
:param stdout: either :data:`subprocess.PIPE` or :data:`subprocess.DEVNULL`
|
||||
:param stderr: one of :data:`subprocess.PIPE`, :data:`subprocess.DEVNULL` or
|
||||
:data:`subprocess.STDOUT`
|
||||
:param check: if ``True``, raise :exc:`~subprocess.CalledProcessError` if the process
|
||||
terminates with a return code other than 0
|
||||
:param cwd: If not ``None``, change the working directory to this before running the command
|
||||
:param env: if not ``None``, this mapping replaces the inherited environment variables from the
|
||||
parent process
|
||||
:param start_new_session: if ``true`` the setsid() system call will be made in the child
|
||||
process prior to the execution of the subprocess. (POSIX only)
|
||||
:return: an object representing the completed process
|
||||
:raises ~subprocess.CalledProcessError: if ``check`` is ``True`` and the process exits with a
|
||||
nonzero return code
|
||||
|
||||
"""
|
||||
|
||||
async def drain_stream(stream: AsyncIterable[bytes], index: int) -> None:
|
||||
buffer = BytesIO()
|
||||
async for chunk in stream:
|
||||
buffer.write(chunk)
|
||||
|
||||
stream_contents[index] = buffer.getvalue()
|
||||
|
||||
async with await open_process(
|
||||
command,
|
||||
stdin=PIPE if input else DEVNULL,
|
||||
stdout=stdout,
|
||||
stderr=stderr,
|
||||
cwd=cwd,
|
||||
env=env,
|
||||
start_new_session=start_new_session,
|
||||
) as process:
|
||||
stream_contents: List[Optional[bytes]] = [None, None]
|
||||
try:
|
||||
async with create_task_group() as tg:
|
||||
if process.stdout:
|
||||
tg.start_soon(drain_stream, process.stdout, 0)
|
||||
if process.stderr:
|
||||
tg.start_soon(drain_stream, process.stderr, 1)
|
||||
if process.stdin and input:
|
||||
await process.stdin.send(input)
|
||||
await process.stdin.aclose()
|
||||
|
||||
await process.wait()
|
||||
except BaseException:
|
||||
process.kill()
|
||||
raise
|
||||
|
||||
output, errors = stream_contents
|
||||
if check and process.returncode != 0:
|
||||
raise CalledProcessError(cast(int, process.returncode), command, output, errors)
|
||||
|
||||
return CompletedProcess(command, cast(int, process.returncode), output, errors)
|
||||
|
||||
|
||||
async def open_process(
|
||||
command: Union[str, bytes, Sequence[Union[str, bytes]]],
|
||||
*,
|
||||
stdin: Union[int, IO[Any], None] = PIPE,
|
||||
stdout: Union[int, IO[Any], None] = PIPE,
|
||||
stderr: Union[int, IO[Any], None] = PIPE,
|
||||
cwd: Union[str, bytes, "PathLike[str]", None] = None,
|
||||
env: Optional[Mapping[str, str]] = None,
|
||||
start_new_session: bool = False,
|
||||
) -> Process:
|
||||
"""
|
||||
Start an external command in a subprocess.
|
||||
|
||||
.. seealso:: :class:`subprocess.Popen`
|
||||
|
||||
:param command: either a string to pass to the shell, or an iterable of strings containing the
|
||||
executable name or path and its arguments
|
||||
:param stdin: one of :data:`subprocess.PIPE`, :data:`subprocess.DEVNULL`, a
|
||||
file-like object, or ``None``
|
||||
:param stdout: one of :data:`subprocess.PIPE`, :data:`subprocess.DEVNULL`,
|
||||
a file-like object, or ``None``
|
||||
:param stderr: one of :data:`subprocess.PIPE`, :data:`subprocess.DEVNULL`,
|
||||
:data:`subprocess.STDOUT`, a file-like object, or ``None``
|
||||
:param cwd: If not ``None``, the working directory is changed before executing
|
||||
:param env: If env is not ``None``, it must be a mapping that defines the environment
|
||||
variables for the new process
|
||||
:param start_new_session: if ``true`` the setsid() system call will be made in the child
|
||||
process prior to the execution of the subprocess. (POSIX only)
|
||||
:return: an asynchronous process object
|
||||
|
||||
"""
|
||||
shell = isinstance(command, str)
|
||||
return await get_asynclib().open_process(
|
||||
command,
|
||||
shell=shell,
|
||||
stdin=stdin,
|
||||
stdout=stdout,
|
||||
stderr=stderr,
|
||||
cwd=cwd,
|
||||
env=env,
|
||||
start_new_session=start_new_session,
|
||||
)
|
@ -0,0 +1,595 @@
|
||||
from collections import deque
|
||||
from dataclasses import dataclass
|
||||
from types import TracebackType
|
||||
from typing import Deque, Optional, Tuple, Type
|
||||
from warnings import warn
|
||||
|
||||
from ..lowlevel import cancel_shielded_checkpoint, checkpoint, checkpoint_if_cancelled
|
||||
from ._compat import DeprecatedAwaitable
|
||||
from ._eventloop import get_asynclib
|
||||
from ._exceptions import BusyResourceError, WouldBlock
|
||||
from ._tasks import CancelScope
|
||||
from ._testing import TaskInfo, get_current_task
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class EventStatistics:
|
||||
"""
|
||||
:ivar int tasks_waiting: number of tasks waiting on :meth:`~.Event.wait`
|
||||
"""
|
||||
|
||||
tasks_waiting: int
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CapacityLimiterStatistics:
|
||||
"""
|
||||
:ivar int borrowed_tokens: number of tokens currently borrowed by tasks
|
||||
:ivar float total_tokens: total number of available tokens
|
||||
:ivar tuple borrowers: tasks or other objects currently holding tokens borrowed from this
|
||||
limiter
|
||||
:ivar int tasks_waiting: number of tasks waiting on :meth:`~.CapacityLimiter.acquire` or
|
||||
:meth:`~.CapacityLimiter.acquire_on_behalf_of`
|
||||
"""
|
||||
|
||||
borrowed_tokens: int
|
||||
total_tokens: float
|
||||
borrowers: Tuple[object, ...]
|
||||
tasks_waiting: int
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class LockStatistics:
|
||||
"""
|
||||
:ivar bool locked: flag indicating if this lock is locked or not
|
||||
:ivar ~anyio.TaskInfo owner: task currently holding the lock (or ``None`` if the lock is not
|
||||
held by any task)
|
||||
:ivar int tasks_waiting: number of tasks waiting on :meth:`~.Lock.acquire`
|
||||
"""
|
||||
|
||||
locked: bool
|
||||
owner: Optional[TaskInfo]
|
||||
tasks_waiting: int
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ConditionStatistics:
|
||||
"""
|
||||
:ivar int tasks_waiting: number of tasks blocked on :meth:`~.Condition.wait`
|
||||
:ivar ~anyio.LockStatistics lock_statistics: statistics of the underlying :class:`~.Lock`
|
||||
"""
|
||||
|
||||
tasks_waiting: int
|
||||
lock_statistics: LockStatistics
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class SemaphoreStatistics:
|
||||
"""
|
||||
:ivar int tasks_waiting: number of tasks waiting on :meth:`~.Semaphore.acquire`
|
||||
|
||||
"""
|
||||
|
||||
tasks_waiting: int
|
||||
|
||||
|
||||
class Event:
|
||||
def __new__(cls) -> "Event":
|
||||
return get_asynclib().Event()
|
||||
|
||||
def set(self) -> DeprecatedAwaitable:
|
||||
"""Set the flag, notifying all listeners."""
|
||||
raise NotImplementedError
|
||||
|
||||
def is_set(self) -> bool:
|
||||
"""Return ``True`` if the flag is set, ``False`` if not."""
|
||||
raise NotImplementedError
|
||||
|
||||
async def wait(self) -> None:
|
||||
"""
|
||||
Wait until the flag has been set.
|
||||
|
||||
If the flag has already been set when this method is called, it returns immediately.
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def statistics(self) -> EventStatistics:
|
||||
"""Return statistics about the current state of this event."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class Lock:
|
||||
_owner_task: Optional[TaskInfo] = None
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._waiters: Deque[Tuple[TaskInfo, Event]] = deque()
|
||||
|
||||
async def __aenter__(self) -> None:
|
||||
await self.acquire()
|
||||
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: Optional[Type[BaseException]],
|
||||
exc_val: Optional[BaseException],
|
||||
exc_tb: Optional[TracebackType],
|
||||
) -> None:
|
||||
self.release()
|
||||
|
||||
async def acquire(self) -> None:
|
||||
"""Acquire the lock."""
|
||||
await checkpoint_if_cancelled()
|
||||
try:
|
||||
self.acquire_nowait()
|
||||
except WouldBlock:
|
||||
task = get_current_task()
|
||||
event = Event()
|
||||
token = task, event
|
||||
self._waiters.append(token)
|
||||
try:
|
||||
await event.wait()
|
||||
except BaseException:
|
||||
if not event.is_set():
|
||||
self._waiters.remove(token)
|
||||
elif self._owner_task == task:
|
||||
self.release()
|
||||
|
||||
raise
|
||||
|
||||
assert self._owner_task == task
|
||||
else:
|
||||
try:
|
||||
await cancel_shielded_checkpoint()
|
||||
except BaseException:
|
||||
self.release()
|
||||
raise
|
||||
|
||||
def acquire_nowait(self) -> None:
|
||||
"""
|
||||
Acquire the lock, without blocking.
|
||||
|
||||
:raises ~WouldBlock: if the operation would block
|
||||
|
||||
"""
|
||||
task = get_current_task()
|
||||
if self._owner_task == task:
|
||||
raise RuntimeError("Attempted to acquire an already held Lock")
|
||||
|
||||
if self._owner_task is not None:
|
||||
raise WouldBlock
|
||||
|
||||
self._owner_task = task
|
||||
|
||||
def release(self) -> DeprecatedAwaitable:
|
||||
"""Release the lock."""
|
||||
if self._owner_task != get_current_task():
|
||||
raise RuntimeError("The current task is not holding this lock")
|
||||
|
||||
if self._waiters:
|
||||
self._owner_task, event = self._waiters.popleft()
|
||||
event.set()
|
||||
else:
|
||||
del self._owner_task
|
||||
|
||||
return DeprecatedAwaitable(self.release)
|
||||
|
||||
def locked(self) -> bool:
|
||||
"""Return True if the lock is currently held."""
|
||||
return self._owner_task is not None
|
||||
|
||||
def statistics(self) -> LockStatistics:
|
||||
"""
|
||||
Return statistics about the current state of this lock.
|
||||
|
||||
.. versionadded:: 3.0
|
||||
"""
|
||||
return LockStatistics(self.locked(), self._owner_task, len(self._waiters))
|
||||
|
||||
|
||||
class Condition:
|
||||
_owner_task: Optional[TaskInfo] = None
|
||||
|
||||
def __init__(self, lock: Optional[Lock] = None):
|
||||
self._lock = lock or Lock()
|
||||
self._waiters: Deque[Event] = deque()
|
||||
|
||||
async def __aenter__(self) -> None:
|
||||
await self.acquire()
|
||||
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: Optional[Type[BaseException]],
|
||||
exc_val: Optional[BaseException],
|
||||
exc_tb: Optional[TracebackType],
|
||||
) -> None:
|
||||
self.release()
|
||||
|
||||
def _check_acquired(self) -> None:
|
||||
if self._owner_task != get_current_task():
|
||||
raise RuntimeError("The current task is not holding the underlying lock")
|
||||
|
||||
async def acquire(self) -> None:
|
||||
"""Acquire the underlying lock."""
|
||||
await self._lock.acquire()
|
||||
self._owner_task = get_current_task()
|
||||
|
||||
def acquire_nowait(self) -> None:
|
||||
"""
|
||||
Acquire the underlying lock, without blocking.
|
||||
|
||||
:raises ~WouldBlock: if the operation would block
|
||||
|
||||
"""
|
||||
self._lock.acquire_nowait()
|
||||
self._owner_task = get_current_task()
|
||||
|
||||
def release(self) -> DeprecatedAwaitable:
|
||||
"""Release the underlying lock."""
|
||||
self._lock.release()
|
||||
return DeprecatedAwaitable(self.release)
|
||||
|
||||
def locked(self) -> bool:
|
||||
"""Return True if the lock is set."""
|
||||
return self._lock.locked()
|
||||
|
||||
def notify(self, n: int = 1) -> None:
|
||||
"""Notify exactly n listeners."""
|
||||
self._check_acquired()
|
||||
for _ in range(n):
|
||||
try:
|
||||
event = self._waiters.popleft()
|
||||
except IndexError:
|
||||
break
|
||||
|
||||
event.set()
|
||||
|
||||
def notify_all(self) -> None:
|
||||
"""Notify all the listeners."""
|
||||
self._check_acquired()
|
||||
for event in self._waiters:
|
||||
event.set()
|
||||
|
||||
self._waiters.clear()
|
||||
|
||||
async def wait(self) -> None:
|
||||
"""Wait for a notification."""
|
||||
await checkpoint()
|
||||
event = Event()
|
||||
self._waiters.append(event)
|
||||
self.release()
|
||||
try:
|
||||
await event.wait()
|
||||
except BaseException:
|
||||
if not event.is_set():
|
||||
self._waiters.remove(event)
|
||||
|
||||
raise
|
||||
finally:
|
||||
with CancelScope(shield=True):
|
||||
await self.acquire()
|
||||
|
||||
def statistics(self) -> ConditionStatistics:
|
||||
"""
|
||||
Return statistics about the current state of this condition.
|
||||
|
||||
.. versionadded:: 3.0
|
||||
"""
|
||||
return ConditionStatistics(len(self._waiters), self._lock.statistics())
|
||||
|
||||
|
||||
class Semaphore:
|
||||
def __init__(self, initial_value: int, *, max_value: Optional[int] = None):
|
||||
if not isinstance(initial_value, int):
|
||||
raise TypeError("initial_value must be an integer")
|
||||
if initial_value < 0:
|
||||
raise ValueError("initial_value must be >= 0")
|
||||
if max_value is not None:
|
||||
if not isinstance(max_value, int):
|
||||
raise TypeError("max_value must be an integer or None")
|
||||
if max_value < initial_value:
|
||||
raise ValueError(
|
||||
"max_value must be equal to or higher than initial_value"
|
||||
)
|
||||
|
||||
self._value = initial_value
|
||||
self._max_value = max_value
|
||||
self._waiters: Deque[Event] = deque()
|
||||
|
||||
async def __aenter__(self) -> "Semaphore":
|
||||
await self.acquire()
|
||||
return self
|
||||
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: Optional[Type[BaseException]],
|
||||
exc_val: Optional[BaseException],
|
||||
exc_tb: Optional[TracebackType],
|
||||
) -> None:
|
||||
self.release()
|
||||
|
||||
async def acquire(self) -> None:
|
||||
"""Decrement the semaphore value, blocking if necessary."""
|
||||
await checkpoint_if_cancelled()
|
||||
try:
|
||||
self.acquire_nowait()
|
||||
except WouldBlock:
|
||||
event = Event()
|
||||
self._waiters.append(event)
|
||||
try:
|
||||
await event.wait()
|
||||
except BaseException:
|
||||
if not event.is_set():
|
||||
self._waiters.remove(event)
|
||||
else:
|
||||
self.release()
|
||||
|
||||
raise
|
||||
else:
|
||||
try:
|
||||
await cancel_shielded_checkpoint()
|
||||
except BaseException:
|
||||
self.release()
|
||||
raise
|
||||
|
||||
def acquire_nowait(self) -> None:
|
||||
"""
|
||||
Acquire the underlying lock, without blocking.
|
||||
|
||||
:raises ~WouldBlock: if the operation would block
|
||||
|
||||
"""
|
||||
if self._value == 0:
|
||||
raise WouldBlock
|
||||
|
||||
self._value -= 1
|
||||
|
||||
def release(self) -> DeprecatedAwaitable:
|
||||
"""Increment the semaphore value."""
|
||||
if self._max_value is not None and self._value == self._max_value:
|
||||
raise ValueError("semaphore released too many times")
|
||||
|
||||
if self._waiters:
|
||||
self._waiters.popleft().set()
|
||||
else:
|
||||
self._value += 1
|
||||
|
||||
return DeprecatedAwaitable(self.release)
|
||||
|
||||
@property
|
||||
def value(self) -> int:
|
||||
"""The current value of the semaphore."""
|
||||
return self._value
|
||||
|
||||
@property
|
||||
def max_value(self) -> Optional[int]:
|
||||
"""The maximum value of the semaphore."""
|
||||
return self._max_value
|
||||
|
||||
def statistics(self) -> SemaphoreStatistics:
|
||||
"""
|
||||
Return statistics about the current state of this semaphore.
|
||||
|
||||
.. versionadded:: 3.0
|
||||
"""
|
||||
return SemaphoreStatistics(len(self._waiters))
|
||||
|
||||
|
||||
class CapacityLimiter:
|
||||
def __new__(cls, total_tokens: float) -> "CapacityLimiter":
|
||||
return get_asynclib().CapacityLimiter(total_tokens)
|
||||
|
||||
async def __aenter__(self) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: Optional[Type[BaseException]],
|
||||
exc_val: Optional[BaseException],
|
||||
exc_tb: Optional[TracebackType],
|
||||
) -> Optional[bool]:
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def total_tokens(self) -> float:
|
||||
"""
|
||||
The total number of tokens available for borrowing.
|
||||
|
||||
This is a read-write property. If the total number of tokens is increased, the
|
||||
proportionate number of tasks waiting on this limiter will be granted their tokens.
|
||||
|
||||
.. versionchanged:: 3.0
|
||||
The property is now writable.
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@total_tokens.setter
|
||||
def total_tokens(self, value: float) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
async def set_total_tokens(self, value: float) -> None:
|
||||
warn(
|
||||
"CapacityLimiter.set_total_tokens has been deprecated. Set the value of the"
|
||||
'"total_tokens" attribute directly.',
|
||||
DeprecationWarning,
|
||||
)
|
||||
self.total_tokens = value
|
||||
|
||||
@property
|
||||
def borrowed_tokens(self) -> int:
|
||||
"""The number of tokens that have currently been borrowed."""
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def available_tokens(self) -> float:
|
||||
"""The number of tokens currently available to be borrowed"""
|
||||
raise NotImplementedError
|
||||
|
||||
def acquire_nowait(self) -> DeprecatedAwaitable:
|
||||
"""
|
||||
Acquire a token for the current task without waiting for one to become available.
|
||||
|
||||
:raises ~anyio.WouldBlock: if there are no tokens available for borrowing
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def acquire_on_behalf_of_nowait(self, borrower: object) -> DeprecatedAwaitable:
|
||||
"""
|
||||
Acquire a token without waiting for one to become available.
|
||||
|
||||
:param borrower: the entity borrowing a token
|
||||
:raises ~anyio.WouldBlock: if there are no tokens available for borrowing
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def acquire(self) -> None:
|
||||
"""
|
||||
Acquire a token for the current task, waiting if necessary for one to become available.
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def acquire_on_behalf_of(self, borrower: object) -> None:
|
||||
"""
|
||||
Acquire a token, waiting if necessary for one to become available.
|
||||
|
||||
:param borrower: the entity borrowing a token
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def release(self) -> None:
|
||||
"""
|
||||
Release the token held by the current task.
|
||||
:raises RuntimeError: if the current task has not borrowed a token from this limiter.
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def release_on_behalf_of(self, borrower: object) -> None:
|
||||
"""
|
||||
Release the token held by the given borrower.
|
||||
|
||||
:raises RuntimeError: if the borrower has not borrowed a token from this limiter.
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def statistics(self) -> CapacityLimiterStatistics:
|
||||
"""
|
||||
Return statistics about the current state of this limiter.
|
||||
|
||||
.. versionadded:: 3.0
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def create_lock() -> Lock:
|
||||
"""
|
||||
Create an asynchronous lock.
|
||||
|
||||
:return: a lock object
|
||||
|
||||
.. deprecated:: 3.0
|
||||
Use :class:`~Lock` directly.
|
||||
|
||||
"""
|
||||
warn("create_lock() is deprecated -- use Lock() directly", DeprecationWarning)
|
||||
return Lock()
|
||||
|
||||
|
||||
def create_condition(lock: Optional[Lock] = None) -> Condition:
|
||||
"""
|
||||
Create an asynchronous condition.
|
||||
|
||||
:param lock: the lock to base the condition object on
|
||||
:return: a condition object
|
||||
|
||||
.. deprecated:: 3.0
|
||||
Use :class:`~Condition` directly.
|
||||
|
||||
"""
|
||||
warn(
|
||||
"create_condition() is deprecated -- use Condition() directly",
|
||||
DeprecationWarning,
|
||||
)
|
||||
return Condition(lock=lock)
|
||||
|
||||
|
||||
def create_event() -> Event:
|
||||
"""
|
||||
Create an asynchronous event object.
|
||||
|
||||
:return: an event object
|
||||
|
||||
.. deprecated:: 3.0
|
||||
Use :class:`~Event` directly.
|
||||
|
||||
"""
|
||||
warn("create_event() is deprecated -- use Event() directly", DeprecationWarning)
|
||||
return get_asynclib().Event()
|
||||
|
||||
|
||||
def create_semaphore(value: int, *, max_value: Optional[int] = None) -> Semaphore:
|
||||
"""
|
||||
Create an asynchronous semaphore.
|
||||
|
||||
:param value: the semaphore's initial value
|
||||
:param max_value: if set, makes this a "bounded" semaphore that raises :exc:`ValueError` if the
|
||||
semaphore's value would exceed this number
|
||||
:return: a semaphore object
|
||||
|
||||
.. deprecated:: 3.0
|
||||
Use :class:`~Semaphore` directly.
|
||||
|
||||
"""
|
||||
warn(
|
||||
"create_semaphore() is deprecated -- use Semaphore() directly",
|
||||
DeprecationWarning,
|
||||
)
|
||||
return Semaphore(value, max_value=max_value)
|
||||
|
||||
|
||||
def create_capacity_limiter(total_tokens: float) -> CapacityLimiter:
|
||||
"""
|
||||
Create a capacity limiter.
|
||||
|
||||
:param total_tokens: the total number of tokens available for borrowing (can be an integer or
|
||||
:data:`math.inf`)
|
||||
:return: a capacity limiter object
|
||||
|
||||
.. deprecated:: 3.0
|
||||
Use :class:`~CapacityLimiter` directly.
|
||||
|
||||
"""
|
||||
warn(
|
||||
"create_capacity_limiter() is deprecated -- use CapacityLimiter() directly",
|
||||
DeprecationWarning,
|
||||
)
|
||||
return get_asynclib().CapacityLimiter(total_tokens)
|
||||
|
||||
|
||||
class ResourceGuard:
|
||||
__slots__ = "action", "_guarded"
|
||||
|
||||
def __init__(self, action: str):
|
||||
self.action = action
|
||||
self._guarded = False
|
||||
|
||||
def __enter__(self) -> None:
|
||||
if self._guarded:
|
||||
raise BusyResourceError(self.action)
|
||||
|
||||
self._guarded = True
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: Optional[Type[BaseException]],
|
||||
exc_val: Optional[BaseException],
|
||||
exc_tb: Optional[TracebackType],
|
||||
) -> Optional[bool]:
|
||||
self._guarded = False
|
||||
return None
|
@ -0,0 +1,178 @@
|
||||
import math
|
||||
from types import TracebackType
|
||||
from typing import Optional, Type
|
||||
from warnings import warn
|
||||
|
||||
from ..abc._tasks import TaskGroup, TaskStatus
|
||||
from ._compat import (
|
||||
DeprecatedAsyncContextManager,
|
||||
DeprecatedAwaitable,
|
||||
DeprecatedAwaitableFloat,
|
||||
)
|
||||
from ._eventloop import get_asynclib
|
||||
|
||||
|
||||
class _IgnoredTaskStatus(TaskStatus):
|
||||
def started(self, value: object = None) -> None:
|
||||
pass
|
||||
|
||||
|
||||
TASK_STATUS_IGNORED = _IgnoredTaskStatus()
|
||||
|
||||
|
||||
class CancelScope(DeprecatedAsyncContextManager["CancelScope"]):
|
||||
"""
|
||||
Wraps a unit of work that can be made separately cancellable.
|
||||
|
||||
:param deadline: The time (clock value) when this scope is cancelled automatically
|
||||
:param shield: ``True`` to shield the cancel scope from external cancellation
|
||||
"""
|
||||
|
||||
def __new__(
|
||||
cls, *, deadline: float = math.inf, shield: bool = False
|
||||
) -> "CancelScope":
|
||||
return get_asynclib().CancelScope(shield=shield, deadline=deadline)
|
||||
|
||||
def cancel(self) -> DeprecatedAwaitable:
|
||||
"""Cancel this scope immediately."""
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def deadline(self) -> float:
|
||||
"""
|
||||
The time (clock value) when this scope is cancelled automatically.
|
||||
|
||||
Will be ``float('inf')`` if no timeout has been set.
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@deadline.setter
|
||||
def deadline(self, value: float) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def cancel_called(self) -> bool:
|
||||
"""``True`` if :meth:`cancel` has been called."""
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def shield(self) -> bool:
|
||||
"""
|
||||
``True`` if this scope is shielded from external cancellation.
|
||||
|
||||
While a scope is shielded, it will not receive cancellations from outside.
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@shield.setter
|
||||
def shield(self, value: bool) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def __enter__(self) -> "CancelScope":
|
||||
raise NotImplementedError
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: Optional[Type[BaseException]],
|
||||
exc_val: Optional[BaseException],
|
||||
exc_tb: Optional[TracebackType],
|
||||
) -> Optional[bool]:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def open_cancel_scope(*, shield: bool = False) -> CancelScope:
|
||||
"""
|
||||
Open a cancel scope.
|
||||
|
||||
:param shield: ``True`` to shield the cancel scope from external cancellation
|
||||
:return: a cancel scope
|
||||
|
||||
.. deprecated:: 3.0
|
||||
Use :class:`~CancelScope` directly.
|
||||
|
||||
"""
|
||||
warn(
|
||||
"open_cancel_scope() is deprecated -- use CancelScope() directly",
|
||||
DeprecationWarning,
|
||||
)
|
||||
return get_asynclib().CancelScope(shield=shield)
|
||||
|
||||
|
||||
class FailAfterContextManager(DeprecatedAsyncContextManager[CancelScope]):
|
||||
def __init__(self, cancel_scope: CancelScope):
|
||||
self._cancel_scope = cancel_scope
|
||||
|
||||
def __enter__(self) -> CancelScope:
|
||||
return self._cancel_scope.__enter__()
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: Optional[Type[BaseException]],
|
||||
exc_val: Optional[BaseException],
|
||||
exc_tb: Optional[TracebackType],
|
||||
) -> Optional[bool]:
|
||||
retval = self._cancel_scope.__exit__(exc_type, exc_val, exc_tb)
|
||||
if self._cancel_scope.cancel_called:
|
||||
raise TimeoutError
|
||||
|
||||
return retval
|
||||
|
||||
|
||||
def fail_after(delay: Optional[float], shield: bool = False) -> FailAfterContextManager:
|
||||
"""
|
||||
Create a context manager which raises a :class:`TimeoutError` if does not finish in time.
|
||||
|
||||
:param delay: maximum allowed time (in seconds) before raising the exception, or ``None`` to
|
||||
disable the timeout
|
||||
:param shield: ``True`` to shield the cancel scope from external cancellation
|
||||
:return: a context manager that yields a cancel scope
|
||||
:rtype: :class:`~typing.ContextManager`\\[:class:`~anyio.abc.CancelScope`\\]
|
||||
|
||||
"""
|
||||
deadline = (
|
||||
(get_asynclib().current_time() + delay) if delay is not None else math.inf
|
||||
)
|
||||
cancel_scope = get_asynclib().CancelScope(deadline=deadline, shield=shield)
|
||||
return FailAfterContextManager(cancel_scope)
|
||||
|
||||
|
||||
def move_on_after(delay: Optional[float], shield: bool = False) -> CancelScope:
|
||||
"""
|
||||
Create a cancel scope with a deadline that expires after the given delay.
|
||||
|
||||
:param delay: maximum allowed time (in seconds) before exiting the context block, or ``None``
|
||||
to disable the timeout
|
||||
:param shield: ``True`` to shield the cancel scope from external cancellation
|
||||
:return: a cancel scope
|
||||
|
||||
"""
|
||||
deadline = (
|
||||
(get_asynclib().current_time() + delay) if delay is not None else math.inf
|
||||
)
|
||||
return get_asynclib().CancelScope(deadline=deadline, shield=shield)
|
||||
|
||||
|
||||
def current_effective_deadline() -> DeprecatedAwaitableFloat:
|
||||
"""
|
||||
Return the nearest deadline among all the cancel scopes effective for the current task.
|
||||
|
||||
:return: a clock value from the event loop's internal clock (``float('inf')`` if there is no
|
||||
deadline in effect)
|
||||
:rtype: float
|
||||
|
||||
"""
|
||||
return DeprecatedAwaitableFloat(
|
||||
get_asynclib().current_effective_deadline(), current_effective_deadline
|
||||
)
|
||||
|
||||
|
||||
def create_task_group() -> "TaskGroup":
|
||||
"""
|
||||
Create a task group.
|
||||
|
||||
:return: a task group
|
||||
|
||||
"""
|
||||
return get_asynclib().TaskGroup()
|
@ -0,0 +1,80 @@
|
||||
from typing import Any, Awaitable, Generator, Optional, Union
|
||||
|
||||
from ._compat import DeprecatedAwaitableList, _warn_deprecation
|
||||
from ._eventloop import get_asynclib
|
||||
|
||||
|
||||
class TaskInfo:
|
||||
"""
|
||||
Represents an asynchronous task.
|
||||
|
||||
:ivar int id: the unique identifier of the task
|
||||
:ivar parent_id: the identifier of the parent task, if any
|
||||
:vartype parent_id: Optional[int]
|
||||
:ivar str name: the description of the task (if any)
|
||||
:ivar ~collections.abc.Coroutine coro: the coroutine object of the task
|
||||
"""
|
||||
|
||||
__slots__ = "_name", "id", "parent_id", "name", "coro"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
id: int,
|
||||
parent_id: Optional[int],
|
||||
name: Optional[str],
|
||||
coro: Union[Generator, Awaitable[Any]],
|
||||
):
|
||||
func = get_current_task
|
||||
self._name = f"{func.__module__}.{func.__qualname__}"
|
||||
self.id: int = id
|
||||
self.parent_id: Optional[int] = parent_id
|
||||
self.name: Optional[str] = name
|
||||
self.coro: Union[Generator, Awaitable[Any]] = coro
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if isinstance(other, TaskInfo):
|
||||
return self.id == other.id
|
||||
|
||||
return NotImplemented
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash(self.id)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}(id={self.id!r}, name={self.name!r})"
|
||||
|
||||
def __await__(self) -> Generator[None, None, "TaskInfo"]:
|
||||
_warn_deprecation(self)
|
||||
if False:
|
||||
yield
|
||||
|
||||
return self
|
||||
|
||||
def _unwrap(self) -> "TaskInfo":
|
||||
return self
|
||||
|
||||
|
||||
def get_current_task() -> TaskInfo:
|
||||
"""
|
||||
Return the current task.
|
||||
|
||||
:return: a representation of the current task
|
||||
|
||||
"""
|
||||
return get_asynclib().get_current_task()
|
||||
|
||||
|
||||
def get_running_tasks() -> DeprecatedAwaitableList[TaskInfo]:
|
||||
"""
|
||||
Return a list of running tasks in the current event loop.
|
||||
|
||||
:return: a list of task info objects
|
||||
|
||||
"""
|
||||
tasks = get_asynclib().get_running_tasks()
|
||||
return DeprecatedAwaitableList(tasks, func=get_running_tasks)
|
||||
|
||||
|
||||
async def wait_all_tasks_blocked() -> None:
|
||||
"""Wait until all other tasks are waiting for something."""
|
||||
await get_asynclib().wait_all_tasks_blocked()
|
@ -0,0 +1,81 @@
|
||||
import sys
|
||||
from typing import Any, Callable, Dict, Mapping, TypeVar, Union, overload
|
||||
|
||||
from ._exceptions import TypedAttributeLookupError
|
||||
|
||||
if sys.version_info >= (3, 8):
|
||||
from typing import final
|
||||
else:
|
||||
from typing_extensions import final
|
||||
|
||||
T_Attr = TypeVar("T_Attr")
|
||||
T_Default = TypeVar("T_Default")
|
||||
undefined = object()
|
||||
|
||||
|
||||
def typed_attribute() -> Any:
|
||||
"""Return a unique object, used to mark typed attributes."""
|
||||
return object()
|
||||
|
||||
|
||||
class TypedAttributeSet:
|
||||
"""
|
||||
Superclass for typed attribute collections.
|
||||
|
||||
Checks that every public attribute of every subclass has a type annotation.
|
||||
"""
|
||||
|
||||
def __init_subclass__(cls) -> None:
|
||||
annotations: Dict[str, Any] = getattr(cls, "__annotations__", {})
|
||||
for attrname in dir(cls):
|
||||
if not attrname.startswith("_") and attrname not in annotations:
|
||||
raise TypeError(
|
||||
f"Attribute {attrname!r} is missing its type annotation"
|
||||
)
|
||||
|
||||
super().__init_subclass__()
|
||||
|
||||
|
||||
class TypedAttributeProvider:
|
||||
"""Base class for classes that wish to provide typed extra attributes."""
|
||||
|
||||
@property
|
||||
def extra_attributes(self) -> Mapping[T_Attr, Callable[[], T_Attr]]:
|
||||
"""
|
||||
A mapping of the extra attributes to callables that return the corresponding values.
|
||||
|
||||
If the provider wraps another provider, the attributes from that wrapper should also be
|
||||
included in the returned mapping (but the wrapper may override the callables from the
|
||||
wrapped instance).
|
||||
|
||||
"""
|
||||
return {}
|
||||
|
||||
@overload
|
||||
def extra(self, attribute: T_Attr) -> T_Attr:
|
||||
...
|
||||
|
||||
@overload
|
||||
def extra(self, attribute: T_Attr, default: T_Default) -> Union[T_Attr, T_Default]:
|
||||
...
|
||||
|
||||
@final
|
||||
def extra(self, attribute: Any, default: object = undefined) -> object:
|
||||
"""
|
||||
extra(attribute, default=undefined)
|
||||
|
||||
Return the value of the given typed extra attribute.
|
||||
|
||||
:param attribute: the attribute (member of a :class:`~TypedAttributeSet`) to look for
|
||||
:param default: the value that should be returned if no value is found for the attribute
|
||||
:raises ~anyio.TypedAttributeLookupError: if the search failed and no default value was
|
||||
given
|
||||
|
||||
"""
|
||||
try:
|
||||
return self.extra_attributes[attribute]()
|
||||
except KeyError:
|
||||
if default is undefined:
|
||||
raise TypedAttributeLookupError("Attribute not found") from None
|
||||
else:
|
||||
return default
|
@ -0,0 +1,88 @@
|
||||
__all__ = (
|
||||
"AsyncResource",
|
||||
"IPAddressType",
|
||||
"IPSockAddrType",
|
||||
"SocketAttribute",
|
||||
"SocketStream",
|
||||
"SocketListener",
|
||||
"UDPSocket",
|
||||
"UNIXSocketStream",
|
||||
"UDPPacketType",
|
||||
"ConnectedUDPSocket",
|
||||
"UnreliableObjectReceiveStream",
|
||||
"UnreliableObjectSendStream",
|
||||
"UnreliableObjectStream",
|
||||
"ObjectReceiveStream",
|
||||
"ObjectSendStream",
|
||||
"ObjectStream",
|
||||
"ByteReceiveStream",
|
||||
"ByteSendStream",
|
||||
"ByteStream",
|
||||
"AnyUnreliableByteReceiveStream",
|
||||
"AnyUnreliableByteSendStream",
|
||||
"AnyUnreliableByteStream",
|
||||
"AnyByteReceiveStream",
|
||||
"AnyByteSendStream",
|
||||
"AnyByteStream",
|
||||
"Listener",
|
||||
"Process",
|
||||
"Event",
|
||||
"Condition",
|
||||
"Lock",
|
||||
"Semaphore",
|
||||
"CapacityLimiter",
|
||||
"CancelScope",
|
||||
"TaskGroup",
|
||||
"TaskStatus",
|
||||
"TestRunner",
|
||||
"BlockingPortal",
|
||||
)
|
||||
|
||||
from typing import Any
|
||||
|
||||
from ._resources import AsyncResource
|
||||
from ._sockets import (
|
||||
ConnectedUDPSocket,
|
||||
IPAddressType,
|
||||
IPSockAddrType,
|
||||
SocketAttribute,
|
||||
SocketListener,
|
||||
SocketStream,
|
||||
UDPPacketType,
|
||||
UDPSocket,
|
||||
UNIXSocketStream,
|
||||
)
|
||||
from ._streams import (
|
||||
AnyByteReceiveStream,
|
||||
AnyByteSendStream,
|
||||
AnyByteStream,
|
||||
AnyUnreliableByteReceiveStream,
|
||||
AnyUnreliableByteSendStream,
|
||||
AnyUnreliableByteStream,
|
||||
ByteReceiveStream,
|
||||
ByteSendStream,
|
||||
ByteStream,
|
||||
Listener,
|
||||
ObjectReceiveStream,
|
||||
ObjectSendStream,
|
||||
ObjectStream,
|
||||
UnreliableObjectReceiveStream,
|
||||
UnreliableObjectSendStream,
|
||||
UnreliableObjectStream,
|
||||
)
|
||||
from ._subprocesses import Process
|
||||
from ._tasks import TaskGroup, TaskStatus
|
||||
from ._testing import TestRunner
|
||||
|
||||
# Re-exported here, for backwards compatibility
|
||||
# isort: off
|
||||
from .._core._synchronization import CapacityLimiter, Condition, Event, Lock, Semaphore
|
||||
from .._core._tasks import CancelScope
|
||||
from ..from_thread import BlockingPortal
|
||||
|
||||
# Re-export imports so they look like they live directly in this package
|
||||
key: str
|
||||
value: Any
|
||||
for key, value in list(locals().items()):
|
||||
if getattr(value, "__module__", "").startswith("anyio.abc."):
|
||||
value.__module__ = __name__
|
@ -0,0 +1,29 @@
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from types import TracebackType
|
||||
from typing import Optional, Type, TypeVar
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class AsyncResource(metaclass=ABCMeta):
|
||||
"""
|
||||
Abstract base class for all closeable asynchronous resources.
|
||||
|
||||
Works as an asynchronous context manager which returns the instance itself on enter, and calls
|
||||
:meth:`aclose` on exit.
|
||||
"""
|
||||
|
||||
async def __aenter__(self: T) -> T:
|
||||
return self
|
||||
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: Optional[Type[BaseException]],
|
||||
exc_val: Optional[BaseException],
|
||||
exc_tb: Optional[TracebackType],
|
||||
) -> None:
|
||||
await self.aclose()
|
||||
|
||||
@abstractmethod
|
||||
async def aclose(self) -> None:
|
||||
"""Close the resource."""
|
@ -0,0 +1,183 @@
|
||||
import socket
|
||||
from abc import abstractmethod
|
||||
from io import IOBase
|
||||
from ipaddress import IPv4Address, IPv6Address
|
||||
from socket import AddressFamily
|
||||
from types import TracebackType
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncContextManager,
|
||||
Callable,
|
||||
Collection,
|
||||
Dict,
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
|
||||
from .._core._typedattr import (
|
||||
TypedAttributeProvider,
|
||||
TypedAttributeSet,
|
||||
typed_attribute,
|
||||
)
|
||||
from ._streams import ByteStream, Listener, T_Stream, UnreliableObjectStream
|
||||
from ._tasks import TaskGroup
|
||||
|
||||
IPAddressType = Union[str, IPv4Address, IPv6Address]
|
||||
IPSockAddrType = Tuple[str, int]
|
||||
SockAddrType = Union[IPSockAddrType, str]
|
||||
UDPPacketType = Tuple[bytes, IPSockAddrType]
|
||||
T_Retval = TypeVar("T_Retval")
|
||||
|
||||
|
||||
class _NullAsyncContextManager:
|
||||
async def __aenter__(self) -> None:
|
||||
pass
|
||||
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: Optional[Type[BaseException]],
|
||||
exc_val: Optional[BaseException],
|
||||
exc_tb: Optional[TracebackType],
|
||||
) -> Optional[bool]:
|
||||
return None
|
||||
|
||||
|
||||
class SocketAttribute(TypedAttributeSet):
|
||||
#: the address family of the underlying socket
|
||||
family: AddressFamily = typed_attribute()
|
||||
#: the local socket address of the underlying socket
|
||||
local_address: SockAddrType = typed_attribute()
|
||||
#: for IP addresses, the local port the underlying socket is bound to
|
||||
local_port: int = typed_attribute()
|
||||
#: the underlying stdlib socket object
|
||||
raw_socket: socket.socket = typed_attribute()
|
||||
#: the remote address the underlying socket is connected to
|
||||
remote_address: SockAddrType = typed_attribute()
|
||||
#: for IP addresses, the remote port the underlying socket is connected to
|
||||
remote_port: int = typed_attribute()
|
||||
|
||||
|
||||
class _SocketProvider(TypedAttributeProvider):
|
||||
@property
|
||||
def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
|
||||
from .._core._sockets import convert_ipv6_sockaddr as convert
|
||||
|
||||
attributes: Dict[Any, Callable[[], Any]] = {
|
||||
SocketAttribute.family: lambda: self._raw_socket.family,
|
||||
SocketAttribute.local_address: lambda: convert(
|
||||
self._raw_socket.getsockname()
|
||||
),
|
||||
SocketAttribute.raw_socket: lambda: self._raw_socket,
|
||||
}
|
||||
try:
|
||||
peername: Optional[Tuple[str, int]] = convert(
|
||||
self._raw_socket.getpeername()
|
||||
)
|
||||
except OSError:
|
||||
peername = None
|
||||
|
||||
# Provide the remote address for connected sockets
|
||||
if peername is not None:
|
||||
attributes[SocketAttribute.remote_address] = lambda: peername
|
||||
|
||||
# Provide local and remote ports for IP based sockets
|
||||
if self._raw_socket.family in (AddressFamily.AF_INET, AddressFamily.AF_INET6):
|
||||
attributes[
|
||||
SocketAttribute.local_port
|
||||
] = lambda: self._raw_socket.getsockname()[1]
|
||||
if peername is not None:
|
||||
remote_port = peername[1]
|
||||
attributes[SocketAttribute.remote_port] = lambda: remote_port
|
||||
|
||||
return attributes
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def _raw_socket(self) -> socket.socket:
|
||||
pass
|
||||
|
||||
|
||||
class SocketStream(ByteStream, _SocketProvider):
|
||||
"""
|
||||
Transports bytes over a socket.
|
||||
|
||||
Supports all relevant extra attributes from :class:`~SocketAttribute`.
|
||||
"""
|
||||
|
||||
|
||||
class UNIXSocketStream(SocketStream):
|
||||
@abstractmethod
|
||||
async def send_fds(
|
||||
self, message: bytes, fds: Collection[Union[int, IOBase]]
|
||||
) -> None:
|
||||
"""
|
||||
Send file descriptors along with a message to the peer.
|
||||
|
||||
:param message: a non-empty bytestring
|
||||
:param fds: a collection of files (either numeric file descriptors or open file or socket
|
||||
objects)
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def receive_fds(self, msglen: int, maxfds: int) -> Tuple[bytes, List[int]]:
|
||||
"""
|
||||
Receive file descriptors along with a message from the peer.
|
||||
|
||||
:param msglen: length of the message to expect from the peer
|
||||
:param maxfds: maximum number of file descriptors to expect from the peer
|
||||
:return: a tuple of (message, file descriptors)
|
||||
"""
|
||||
|
||||
|
||||
class SocketListener(Listener[SocketStream], _SocketProvider):
|
||||
"""
|
||||
Listens to incoming socket connections.
|
||||
|
||||
Supports all relevant extra attributes from :class:`~SocketAttribute`.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def accept(self) -> SocketStream:
|
||||
"""Accept an incoming connection."""
|
||||
|
||||
async def serve(
|
||||
self, handler: Callable[[T_Stream], Any], task_group: Optional[TaskGroup] = None
|
||||
) -> None:
|
||||
from .. import create_task_group
|
||||
|
||||
context_manager: AsyncContextManager
|
||||
if task_group is None:
|
||||
task_group = context_manager = create_task_group()
|
||||
else:
|
||||
# Can be replaced with AsyncExitStack once on py3.7+
|
||||
context_manager = _NullAsyncContextManager()
|
||||
|
||||
async with context_manager:
|
||||
while True:
|
||||
stream = await self.accept()
|
||||
task_group.start_soon(handler, stream)
|
||||
|
||||
|
||||
class UDPSocket(UnreliableObjectStream[UDPPacketType], _SocketProvider):
|
||||
"""
|
||||
Represents an unconnected UDP socket.
|
||||
|
||||
Supports all relevant extra attributes from :class:`~SocketAttribute`.
|
||||
"""
|
||||
|
||||
async def sendto(self, data: bytes, host: str, port: int) -> None:
|
||||
"""Alias for :meth:`~.UnreliableObjectSendStream.send` ((data, (host, port)))."""
|
||||
return await self.send((data, (host, port)))
|
||||
|
||||
|
||||
class ConnectedUDPSocket(UnreliableObjectStream[bytes], _SocketProvider):
|
||||
"""
|
||||
Represents an connected UDP socket.
|
||||
|
||||
Supports all relevant extra attributes from :class:`~SocketAttribute`.
|
||||
"""
|
@ -0,0 +1,198 @@
|
||||
from abc import abstractmethod
|
||||
from typing import Any, Callable, Generic, Optional, TypeVar, Union
|
||||
|
||||
from .._core._exceptions import EndOfStream
|
||||
from .._core._typedattr import TypedAttributeProvider
|
||||
from ._resources import AsyncResource
|
||||
from ._tasks import TaskGroup
|
||||
|
||||
T_Item = TypeVar("T_Item")
|
||||
T_Stream = TypeVar("T_Stream")
|
||||
|
||||
|
||||
class UnreliableObjectReceiveStream(
|
||||
Generic[T_Item], AsyncResource, TypedAttributeProvider
|
||||
):
|
||||
"""
|
||||
An interface for receiving objects.
|
||||
|
||||
This interface makes no guarantees that the received messages arrive in the order in which they
|
||||
were sent, or that no messages are missed.
|
||||
|
||||
Asynchronously iterating over objects of this type will yield objects matching the given type
|
||||
parameter.
|
||||
"""
|
||||
|
||||
def __aiter__(self) -> "UnreliableObjectReceiveStream[T_Item]":
|
||||
return self
|
||||
|
||||
async def __anext__(self) -> T_Item:
|
||||
try:
|
||||
return await self.receive()
|
||||
except EndOfStream:
|
||||
raise StopAsyncIteration
|
||||
|
||||
@abstractmethod
|
||||
async def receive(self) -> T_Item:
|
||||
"""
|
||||
Receive the next item.
|
||||
|
||||
:raises ~anyio.ClosedResourceError: if the receive stream has been explicitly
|
||||
closed
|
||||
:raises ~anyio.EndOfStream: if this stream has been closed from the other end
|
||||
:raises ~anyio.BrokenResourceError: if this stream has been rendered unusable
|
||||
due to external causes
|
||||
"""
|
||||
|
||||
|
||||
class UnreliableObjectSendStream(
|
||||
Generic[T_Item], AsyncResource, TypedAttributeProvider
|
||||
):
|
||||
"""
|
||||
An interface for sending objects.
|
||||
|
||||
This interface makes no guarantees that the messages sent will reach the recipient(s) in the
|
||||
same order in which they were sent, or at all.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def send(self, item: T_Item) -> None:
|
||||
"""
|
||||
Send an item to the peer(s).
|
||||
|
||||
:param item: the item to send
|
||||
:raises ~anyio.ClosedResourceError: if the send stream has been explicitly
|
||||
closed
|
||||
:raises ~anyio.BrokenResourceError: if this stream has been rendered unusable
|
||||
due to external causes
|
||||
"""
|
||||
|
||||
|
||||
class UnreliableObjectStream(
|
||||
UnreliableObjectReceiveStream[T_Item], UnreliableObjectSendStream[T_Item]
|
||||
):
|
||||
"""
|
||||
A bidirectional message stream which does not guarantee the order or reliability of message
|
||||
delivery.
|
||||
"""
|
||||
|
||||
|
||||
class ObjectReceiveStream(UnreliableObjectReceiveStream[T_Item]):
|
||||
"""
|
||||
A receive message stream which guarantees that messages are received in the same order in
|
||||
which they were sent, and that no messages are missed.
|
||||
"""
|
||||
|
||||
|
||||
class ObjectSendStream(UnreliableObjectSendStream[T_Item]):
|
||||
"""
|
||||
A send message stream which guarantees that messages are delivered in the same order in which
|
||||
they were sent, without missing any messages in the middle.
|
||||
"""
|
||||
|
||||
|
||||
class ObjectStream(
|
||||
ObjectReceiveStream[T_Item],
|
||||
ObjectSendStream[T_Item],
|
||||
UnreliableObjectStream[T_Item],
|
||||
):
|
||||
"""
|
||||
A bidirectional message stream which guarantees the order and reliability of message delivery.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def send_eof(self) -> None:
|
||||
"""
|
||||
Send an end-of-file indication to the peer.
|
||||
|
||||
You should not try to send any further data to this stream after calling this method.
|
||||
This method is idempotent (does nothing on successive calls).
|
||||
"""
|
||||
|
||||
|
||||
class ByteReceiveStream(AsyncResource, TypedAttributeProvider):
|
||||
"""
|
||||
An interface for receiving bytes from a single peer.
|
||||
|
||||
Iterating this byte stream will yield a byte string of arbitrary length, but no more than
|
||||
65536 bytes.
|
||||
"""
|
||||
|
||||
def __aiter__(self) -> "ByteReceiveStream":
|
||||
return self
|
||||
|
||||
async def __anext__(self) -> bytes:
|
||||
try:
|
||||
return await self.receive()
|
||||
except EndOfStream:
|
||||
raise StopAsyncIteration
|
||||
|
||||
@abstractmethod
|
||||
async def receive(self, max_bytes: int = 65536) -> bytes:
|
||||
"""
|
||||
Receive at most ``max_bytes`` bytes from the peer.
|
||||
|
||||
.. note:: Implementors of this interface should not return an empty :class:`bytes` object,
|
||||
and users should ignore them.
|
||||
|
||||
:param max_bytes: maximum number of bytes to receive
|
||||
:return: the received bytes
|
||||
:raises ~anyio.EndOfStream: if this stream has been closed from the other end
|
||||
"""
|
||||
|
||||
|
||||
class ByteSendStream(AsyncResource, TypedAttributeProvider):
|
||||
"""An interface for sending bytes to a single peer."""
|
||||
|
||||
@abstractmethod
|
||||
async def send(self, item: bytes) -> None:
|
||||
"""
|
||||
Send the given bytes to the peer.
|
||||
|
||||
:param item: the bytes to send
|
||||
"""
|
||||
|
||||
|
||||
class ByteStream(ByteReceiveStream, ByteSendStream):
|
||||
"""A bidirectional byte stream."""
|
||||
|
||||
@abstractmethod
|
||||
async def send_eof(self) -> None:
|
||||
"""
|
||||
Send an end-of-file indication to the peer.
|
||||
|
||||
You should not try to send any further data to this stream after calling this method.
|
||||
This method is idempotent (does nothing on successive calls).
|
||||
"""
|
||||
|
||||
|
||||
#: Type alias for all unreliable bytes-oriented receive streams.
|
||||
AnyUnreliableByteReceiveStream = Union[
|
||||
UnreliableObjectReceiveStream[bytes], ByteReceiveStream
|
||||
]
|
||||
#: Type alias for all unreliable bytes-oriented send streams.
|
||||
AnyUnreliableByteSendStream = Union[UnreliableObjectSendStream[bytes], ByteSendStream]
|
||||
#: Type alias for all unreliable bytes-oriented streams.
|
||||
AnyUnreliableByteStream = Union[UnreliableObjectStream[bytes], ByteStream]
|
||||
#: Type alias for all bytes-oriented receive streams.
|
||||
AnyByteReceiveStream = Union[ObjectReceiveStream[bytes], ByteReceiveStream]
|
||||
#: Type alias for all bytes-oriented send streams.
|
||||
AnyByteSendStream = Union[ObjectSendStream[bytes], ByteSendStream]
|
||||
#: Type alias for all bytes-oriented streams.
|
||||
AnyByteStream = Union[ObjectStream[bytes], ByteStream]
|
||||
|
||||
|
||||
class Listener(Generic[T_Stream], AsyncResource, TypedAttributeProvider):
|
||||
"""An interface for objects that let you accept incoming connections."""
|
||||
|
||||
@abstractmethod
|
||||
async def serve(
|
||||
self, handler: Callable[[T_Stream], Any], task_group: Optional[TaskGroup] = None
|
||||
) -> None:
|
||||
"""
|
||||
Accept incoming connections as they come in and start tasks to handle them.
|
||||
|
||||
:param handler: a callable that will be used to handle each accepted connection
|
||||
:param task_group: the task group that will be used to start tasks for handling each
|
||||
accepted connection (if omitted, an ad-hoc task group will be created)
|
||||
"""
|
@ -0,0 +1,78 @@
|
||||
from abc import abstractmethod
|
||||
from signal import Signals
|
||||
from typing import Optional
|
||||
|
||||
from ._resources import AsyncResource
|
||||
from ._streams import ByteReceiveStream, ByteSendStream
|
||||
|
||||
|
||||
class Process(AsyncResource):
|
||||
"""An asynchronous version of :class:`subprocess.Popen`."""
|
||||
|
||||
@abstractmethod
|
||||
async def wait(self) -> int:
|
||||
"""
|
||||
Wait until the process exits.
|
||||
|
||||
:return: the exit code of the process
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def terminate(self) -> None:
|
||||
"""
|
||||
Terminates the process, gracefully if possible.
|
||||
|
||||
On Windows, this calls ``TerminateProcess()``.
|
||||
On POSIX systems, this sends ``SIGTERM`` to the process.
|
||||
|
||||
.. seealso:: :meth:`subprocess.Popen.terminate`
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def kill(self) -> None:
|
||||
"""
|
||||
Kills the process.
|
||||
|
||||
On Windows, this calls ``TerminateProcess()``.
|
||||
On POSIX systems, this sends ``SIGKILL`` to the process.
|
||||
|
||||
.. seealso:: :meth:`subprocess.Popen.kill`
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def send_signal(self, signal: Signals) -> None:
|
||||
"""
|
||||
Send a signal to the subprocess.
|
||||
|
||||
.. seealso:: :meth:`subprocess.Popen.send_signal`
|
||||
|
||||
:param signal: the signal number (e.g. :data:`signal.SIGHUP`)
|
||||
"""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def pid(self) -> int:
|
||||
"""The process ID of the process."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def returncode(self) -> Optional[int]:
|
||||
"""
|
||||
The return code of the process. If the process has not yet terminated, this will be
|
||||
``None``.
|
||||
"""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def stdin(self) -> Optional[ByteSendStream]:
|
||||
"""The stream for the standard input of the process."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def stdout(self) -> Optional[ByteReceiveStream]:
|
||||
"""The stream for the standard output of the process."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def stderr(self) -> Optional[ByteReceiveStream]:
|
||||
"""The stream for the standard error output of the process."""
|
@ -0,0 +1,104 @@
|
||||
import typing
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from types import TracebackType
|
||||
from typing import Any, Callable, Coroutine, Optional, Type, TypeVar
|
||||
from warnings import warn
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from anyio._core._tasks import CancelScope
|
||||
|
||||
T_Retval = TypeVar("T_Retval")
|
||||
|
||||
|
||||
class TaskStatus(metaclass=ABCMeta):
|
||||
@abstractmethod
|
||||
def started(self, value: object = None) -> None:
|
||||
"""
|
||||
Signal that the task has started.
|
||||
|
||||
:param value: object passed back to the starter of the task
|
||||
"""
|
||||
|
||||
|
||||
class TaskGroup(metaclass=ABCMeta):
|
||||
"""
|
||||
Groups several asynchronous tasks together.
|
||||
|
||||
:ivar cancel_scope: the cancel scope inherited by all child tasks
|
||||
:vartype cancel_scope: CancelScope
|
||||
"""
|
||||
|
||||
cancel_scope: "CancelScope"
|
||||
|
||||
async def spawn(
|
||||
self,
|
||||
func: Callable[..., Coroutine[Any, Any, Any]],
|
||||
*args: object,
|
||||
name: object = None
|
||||
) -> None:
|
||||
"""
|
||||
Start a new task in this task group.
|
||||
|
||||
:param func: a coroutine function
|
||||
:param args: positional arguments to call the function with
|
||||
:param name: name of the task, for the purposes of introspection and debugging
|
||||
|
||||
.. deprecated:: 3.0
|
||||
Use :meth:`start_soon` instead. If your code needs AnyIO 2 compatibility, you
|
||||
can keep using this until AnyIO 4.
|
||||
|
||||
"""
|
||||
warn(
|
||||
'spawn() is deprecated -- use start_soon() (without the "await") instead',
|
||||
DeprecationWarning,
|
||||
)
|
||||
self.start_soon(func, *args, name=name)
|
||||
|
||||
@abstractmethod
|
||||
def start_soon(
|
||||
self,
|
||||
func: Callable[..., Coroutine[Any, Any, Any]],
|
||||
*args: object,
|
||||
name: object = None
|
||||
) -> None:
|
||||
"""
|
||||
Start a new task in this task group.
|
||||
|
||||
:param func: a coroutine function
|
||||
:param args: positional arguments to call the function with
|
||||
:param name: name of the task, for the purposes of introspection and debugging
|
||||
|
||||
.. versionadded:: 3.0
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def start(
|
||||
self,
|
||||
func: Callable[..., Coroutine[Any, Any, Any]],
|
||||
*args: object,
|
||||
name: object = None
|
||||
) -> object:
|
||||
"""
|
||||
Start a new task and wait until it signals for readiness.
|
||||
|
||||
:param func: a coroutine function
|
||||
:param args: positional arguments to call the function with
|
||||
:param name: name of the task, for the purposes of introspection and debugging
|
||||
:return: the value passed to ``task_status.started()``
|
||||
:raises RuntimeError: if the task finishes without calling ``task_status.started()``
|
||||
|
||||
.. versionadded:: 3.0
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def __aenter__(self) -> "TaskGroup":
|
||||
"""Enter the task group context and allow starting new tasks."""
|
||||
|
||||
@abstractmethod
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: Optional[Type[BaseException]],
|
||||
exc_val: Optional[BaseException],
|
||||
exc_tb: Optional[TracebackType],
|
||||
) -> Optional[bool]:
|
||||
"""Exit the task group context waiting for all tasks to finish."""
|
@ -0,0 +1,68 @@
|
||||
import types
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from collections.abc import AsyncGenerator, Iterable
|
||||
from typing import Any, Callable, Coroutine, Dict, Optional, Type, TypeVar
|
||||
|
||||
_T = TypeVar("_T")
|
||||
|
||||
|
||||
class TestRunner(metaclass=ABCMeta):
|
||||
"""
|
||||
Encapsulates a running event loop. Every call made through this object will use the same event
|
||||
loop.
|
||||
"""
|
||||
|
||||
def __enter__(self) -> "TestRunner":
|
||||
return self
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: Optional[Type[BaseException]],
|
||||
exc_val: Optional[BaseException],
|
||||
exc_tb: Optional[types.TracebackType],
|
||||
) -> Optional[bool]:
|
||||
self.close()
|
||||
return None
|
||||
|
||||
@abstractmethod
|
||||
def close(self) -> None:
|
||||
"""Close the event loop."""
|
||||
|
||||
@abstractmethod
|
||||
def run_asyncgen_fixture(
|
||||
self,
|
||||
fixture_func: Callable[..., "AsyncGenerator[_T, Any]"],
|
||||
kwargs: Dict[str, Any],
|
||||
) -> "Iterable[_T]":
|
||||
"""
|
||||
Run an async generator fixture.
|
||||
|
||||
:param fixture_func: the fixture function
|
||||
:param kwargs: keyword arguments to call the fixture function with
|
||||
:return: an iterator yielding the value yielded from the async generator
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def run_fixture(
|
||||
self,
|
||||
fixture_func: Callable[..., Coroutine[Any, Any, _T]],
|
||||
kwargs: Dict[str, Any],
|
||||
) -> _T:
|
||||
"""
|
||||
Run an async fixture.
|
||||
|
||||
:param fixture_func: the fixture function
|
||||
:param kwargs: keyword arguments to call the fixture function with
|
||||
:return: the return value of the fixture function
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def run_test(
|
||||
self, test_func: Callable[..., Coroutine[Any, Any, Any]], kwargs: Dict[str, Any]
|
||||
) -> None:
|
||||
"""
|
||||
Run an async test function.
|
||||
|
||||
:param test_func: the test function
|
||||
:param kwargs: keyword arguments to call the test function with
|
||||
"""
|
@ -0,0 +1,502 @@
|
||||
import threading
|
||||
from asyncio import iscoroutine
|
||||
from concurrent.futures import FIRST_COMPLETED, Future, ThreadPoolExecutor, wait
|
||||
from contextlib import AbstractContextManager, contextmanager
|
||||
from types import TracebackType
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncContextManager,
|
||||
Callable,
|
||||
ContextManager,
|
||||
Coroutine,
|
||||
Dict,
|
||||
Generator,
|
||||
Iterable,
|
||||
Optional,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
cast,
|
||||
overload,
|
||||
)
|
||||
from warnings import warn
|
||||
|
||||
from ._core import _eventloop
|
||||
from ._core._eventloop import get_asynclib, get_cancelled_exc_class, threadlocals
|
||||
from ._core._synchronization import Event
|
||||
from ._core._tasks import CancelScope, create_task_group
|
||||
from .abc._tasks import TaskStatus
|
||||
|
||||
T_Retval = TypeVar("T_Retval")
|
||||
T_co = TypeVar("T_co")
|
||||
|
||||
|
||||
def run(func: Callable[..., Coroutine[Any, Any, T_Retval]], *args: object) -> T_Retval:
|
||||
"""
|
||||
Call a coroutine function from a worker thread.
|
||||
|
||||
:param func: a coroutine function
|
||||
:param args: positional arguments for the callable
|
||||
:return: the return value of the coroutine function
|
||||
|
||||
"""
|
||||
try:
|
||||
asynclib = threadlocals.current_async_module
|
||||
except AttributeError:
|
||||
raise RuntimeError("This function can only be run from an AnyIO worker thread")
|
||||
|
||||
return asynclib.run_async_from_thread(func, *args)
|
||||
|
||||
|
||||
def run_async_from_thread(
|
||||
func: Callable[..., Coroutine[Any, Any, T_Retval]], *args: object
|
||||
) -> T_Retval:
|
||||
warn(
|
||||
"run_async_from_thread() has been deprecated, use anyio.from_thread.run() instead",
|
||||
DeprecationWarning,
|
||||
)
|
||||
return run(func, *args)
|
||||
|
||||
|
||||
def run_sync(func: Callable[..., T_Retval], *args: object) -> T_Retval:
|
||||
"""
|
||||
Call a function in the event loop thread from a worker thread.
|
||||
|
||||
:param func: a callable
|
||||
:param args: positional arguments for the callable
|
||||
:return: the return value of the callable
|
||||
|
||||
"""
|
||||
try:
|
||||
asynclib = threadlocals.current_async_module
|
||||
except AttributeError:
|
||||
raise RuntimeError("This function can only be run from an AnyIO worker thread")
|
||||
|
||||
return asynclib.run_sync_from_thread(func, *args)
|
||||
|
||||
|
||||
def run_sync_from_thread(func: Callable[..., T_Retval], *args: object) -> T_Retval:
|
||||
warn(
|
||||
"run_sync_from_thread() has been deprecated, use anyio.from_thread.run_sync() instead",
|
||||
DeprecationWarning,
|
||||
)
|
||||
return run_sync(func, *args)
|
||||
|
||||
|
||||
class _BlockingAsyncContextManager(AbstractContextManager):
|
||||
_enter_future: Future
|
||||
_exit_future: Future
|
||||
_exit_event: Event
|
||||
_exit_exc_info: Tuple[
|
||||
Optional[Type[BaseException]], Optional[BaseException], Optional[TracebackType]
|
||||
] = (None, None, None)
|
||||
|
||||
def __init__(self, async_cm: AsyncContextManager[T_co], portal: "BlockingPortal"):
|
||||
self._async_cm = async_cm
|
||||
self._portal = portal
|
||||
|
||||
async def run_async_cm(self) -> Optional[bool]:
|
||||
try:
|
||||
self._exit_event = Event()
|
||||
value = await self._async_cm.__aenter__()
|
||||
except BaseException as exc:
|
||||
self._enter_future.set_exception(exc)
|
||||
raise
|
||||
else:
|
||||
self._enter_future.set_result(value)
|
||||
|
||||
try:
|
||||
# Wait for the sync context manager to exit.
|
||||
# This next statement can raise `get_cancelled_exc_class()` if
|
||||
# something went wrong in a task group in this async context
|
||||
# manager.
|
||||
await self._exit_event.wait()
|
||||
finally:
|
||||
# In case of cancellation, it could be that we end up here before
|
||||
# `_BlockingAsyncContextManager.__exit__` is called, and an
|
||||
# `_exit_exc_info` has been set.
|
||||
result = await self._async_cm.__aexit__(*self._exit_exc_info)
|
||||
return result
|
||||
|
||||
def __enter__(self) -> T_co:
|
||||
self._enter_future = Future()
|
||||
self._exit_future = self._portal.start_task_soon(self.run_async_cm)
|
||||
cm = self._enter_future.result()
|
||||
return cast(T_co, cm)
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
__exc_type: Optional[Type[BaseException]],
|
||||
__exc_value: Optional[BaseException],
|
||||
__traceback: Optional[TracebackType],
|
||||
) -> Optional[bool]:
|
||||
self._exit_exc_info = __exc_type, __exc_value, __traceback
|
||||
self._portal.call(self._exit_event.set)
|
||||
return self._exit_future.result()
|
||||
|
||||
|
||||
class _BlockingPortalTaskStatus(TaskStatus):
|
||||
def __init__(self, future: Future):
|
||||
self._future = future
|
||||
|
||||
def started(self, value: object = None) -> None:
|
||||
self._future.set_result(value)
|
||||
|
||||
|
||||
class BlockingPortal:
|
||||
"""An object that lets external threads run code in an asynchronous event loop."""
|
||||
|
||||
def __new__(cls) -> "BlockingPortal":
|
||||
return get_asynclib().BlockingPortal()
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._event_loop_thread_id: Optional[int] = threading.get_ident()
|
||||
self._stop_event = Event()
|
||||
self._task_group = create_task_group()
|
||||
self._cancelled_exc_class = get_cancelled_exc_class()
|
||||
|
||||
async def __aenter__(self) -> "BlockingPortal":
|
||||
await self._task_group.__aenter__()
|
||||
return self
|
||||
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: Optional[Type[BaseException]],
|
||||
exc_val: Optional[BaseException],
|
||||
exc_tb: Optional[TracebackType],
|
||||
) -> Optional[bool]:
|
||||
await self.stop()
|
||||
return await self._task_group.__aexit__(exc_type, exc_val, exc_tb)
|
||||
|
||||
def _check_running(self) -> None:
|
||||
if self._event_loop_thread_id is None:
|
||||
raise RuntimeError("This portal is not running")
|
||||
if self._event_loop_thread_id == threading.get_ident():
|
||||
raise RuntimeError(
|
||||
"This method cannot be called from the event loop thread"
|
||||
)
|
||||
|
||||
async def sleep_until_stopped(self) -> None:
|
||||
"""Sleep until :meth:`stop` is called."""
|
||||
await self._stop_event.wait()
|
||||
|
||||
async def stop(self, cancel_remaining: bool = False) -> None:
|
||||
"""
|
||||
Signal the portal to shut down.
|
||||
|
||||
This marks the portal as no longer accepting new calls and exits from
|
||||
:meth:`sleep_until_stopped`.
|
||||
|
||||
:param cancel_remaining: ``True`` to cancel all the remaining tasks, ``False`` to let them
|
||||
finish before returning
|
||||
|
||||
"""
|
||||
self._event_loop_thread_id = None
|
||||
self._stop_event.set()
|
||||
if cancel_remaining:
|
||||
self._task_group.cancel_scope.cancel()
|
||||
|
||||
async def _call_func(
|
||||
self, func: Callable, args: tuple, kwargs: Dict[str, Any], future: Future
|
||||
) -> None:
|
||||
def callback(f: Future) -> None:
|
||||
if f.cancelled() and self._event_loop_thread_id not in (
|
||||
None,
|
||||
threading.get_ident(),
|
||||
):
|
||||
self.call(scope.cancel)
|
||||
|
||||
try:
|
||||
retval = func(*args, **kwargs)
|
||||
if iscoroutine(retval):
|
||||
with CancelScope() as scope:
|
||||
if future.cancelled():
|
||||
scope.cancel()
|
||||
else:
|
||||
future.add_done_callback(callback)
|
||||
|
||||
retval = await retval
|
||||
except self._cancelled_exc_class:
|
||||
future.cancel()
|
||||
except BaseException as exc:
|
||||
if not future.cancelled():
|
||||
future.set_exception(exc)
|
||||
|
||||
# Let base exceptions fall through
|
||||
if not isinstance(exc, Exception):
|
||||
raise
|
||||
else:
|
||||
if not future.cancelled():
|
||||
future.set_result(retval)
|
||||
finally:
|
||||
scope = None # type: ignore[assignment]
|
||||
|
||||
def _spawn_task_from_thread(
|
||||
self,
|
||||
func: Callable,
|
||||
args: tuple,
|
||||
kwargs: Dict[str, Any],
|
||||
name: object,
|
||||
future: Future,
|
||||
) -> None:
|
||||
"""
|
||||
Spawn a new task using the given callable.
|
||||
|
||||
Implementors must ensure that the future is resolved when the task finishes.
|
||||
|
||||
:param func: a callable
|
||||
:param args: positional arguments to be passed to the callable
|
||||
:param kwargs: keyword arguments to be passed to the callable
|
||||
:param name: name of the task (will be coerced to a string if not ``None``)
|
||||
:param future: a future that will resolve to the return value of the callable, or the
|
||||
exception raised during its execution
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@overload
|
||||
def call(
|
||||
self, func: Callable[..., Coroutine[Any, Any, T_Retval]], *args: object
|
||||
) -> T_Retval:
|
||||
...
|
||||
|
||||
@overload
|
||||
def call(self, func: Callable[..., T_Retval], *args: object) -> T_Retval:
|
||||
...
|
||||
|
||||
def call(
|
||||
self,
|
||||
func: Callable[..., Union[Coroutine[Any, Any, T_Retval], T_Retval]],
|
||||
*args: object
|
||||
) -> T_Retval:
|
||||
"""
|
||||
Call the given function in the event loop thread.
|
||||
|
||||
If the callable returns a coroutine object, it is awaited on.
|
||||
|
||||
:param func: any callable
|
||||
:raises RuntimeError: if the portal is not running or if this method is called from within
|
||||
the event loop thread
|
||||
|
||||
"""
|
||||
return cast(T_Retval, self.start_task_soon(func, *args).result())
|
||||
|
||||
@overload
|
||||
def spawn_task(
|
||||
self,
|
||||
func: Callable[..., Coroutine[Any, Any, T_Retval]],
|
||||
*args: object,
|
||||
name: object = None
|
||||
) -> "Future[T_Retval]":
|
||||
...
|
||||
|
||||
@overload
|
||||
def spawn_task(
|
||||
self, func: Callable[..., T_Retval], *args: object, name: object = None
|
||||
) -> "Future[T_Retval]":
|
||||
...
|
||||
|
||||
def spawn_task(
|
||||
self,
|
||||
func: Callable[..., Union[Coroutine[Any, Any, T_Retval], T_Retval]],
|
||||
*args: object,
|
||||
name: object = None
|
||||
) -> "Future[T_Retval]":
|
||||
"""
|
||||
Start a task in the portal's task group.
|
||||
|
||||
:param func: the target coroutine function
|
||||
:param args: positional arguments passed to ``func``
|
||||
:param name: name of the task (will be coerced to a string if not ``None``)
|
||||
:return: a future that resolves with the return value of the callable if the task completes
|
||||
successfully, or with the exception raised in the task
|
||||
:raises RuntimeError: if the portal is not running or if this method is called from within
|
||||
the event loop thread
|
||||
|
||||
.. versionadded:: 2.1
|
||||
.. deprecated:: 3.0
|
||||
Use :meth:`start_task_soon` instead. If your code needs AnyIO 2 compatibility, you
|
||||
can keep using this until AnyIO 4.
|
||||
|
||||
"""
|
||||
warn(
|
||||
"spawn_task() is deprecated -- use start_task_soon() instead",
|
||||
DeprecationWarning,
|
||||
)
|
||||
return self.start_task_soon(func, *args, name=name) # type: ignore[arg-type]
|
||||
|
||||
@overload
|
||||
def start_task_soon(
|
||||
self,
|
||||
func: Callable[..., Coroutine[Any, Any, T_Retval]],
|
||||
*args: object,
|
||||
name: object = None
|
||||
) -> "Future[T_Retval]":
|
||||
...
|
||||
|
||||
@overload
|
||||
def start_task_soon(
|
||||
self, func: Callable[..., T_Retval], *args: object, name: object = None
|
||||
) -> "Future[T_Retval]":
|
||||
...
|
||||
|
||||
def start_task_soon(
|
||||
self,
|
||||
func: Callable[..., Union[Coroutine[Any, Any, T_Retval], T_Retval]],
|
||||
*args: object,
|
||||
name: object = None
|
||||
) -> "Future[T_Retval]":
|
||||
"""
|
||||
Start a task in the portal's task group.
|
||||
|
||||
The task will be run inside a cancel scope which can be cancelled by cancelling the
|
||||
returned future.
|
||||
|
||||
:param func: the target coroutine function
|
||||
:param args: positional arguments passed to ``func``
|
||||
:param name: name of the task (will be coerced to a string if not ``None``)
|
||||
:return: a future that resolves with the return value of the callable if the task completes
|
||||
successfully, or with the exception raised in the task
|
||||
:raises RuntimeError: if the portal is not running or if this method is called from within
|
||||
the event loop thread
|
||||
|
||||
.. versionadded:: 3.0
|
||||
|
||||
"""
|
||||
self._check_running()
|
||||
f: Future = Future()
|
||||
self._spawn_task_from_thread(func, args, {}, name, f)
|
||||
return f
|
||||
|
||||
def start_task(
|
||||
self,
|
||||
func: Callable[..., Coroutine[Any, Any, Any]],
|
||||
*args: object,
|
||||
name: object = None
|
||||
) -> Tuple["Future[Any]", Any]:
|
||||
"""
|
||||
Start a task in the portal's task group and wait until it signals for readiness.
|
||||
|
||||
This method works the same way as :meth:`TaskGroup.start`.
|
||||
|
||||
:param func: the target coroutine function
|
||||
:param args: positional arguments passed to ``func``
|
||||
:param name: name of the task (will be coerced to a string if not ``None``)
|
||||
:return: a tuple of (future, task_status_value) where the ``task_status_value`` is the
|
||||
value passed to ``task_status.started()`` from within the target function
|
||||
|
||||
.. versionadded:: 3.0
|
||||
|
||||
"""
|
||||
|
||||
def task_done(future: Future) -> None:
|
||||
if not task_status_future.done():
|
||||
if future.cancelled():
|
||||
task_status_future.cancel()
|
||||
elif future.exception():
|
||||
task_status_future.set_exception(future.exception())
|
||||
else:
|
||||
exc = RuntimeError(
|
||||
"Task exited without calling task_status.started()"
|
||||
)
|
||||
task_status_future.set_exception(exc)
|
||||
|
||||
self._check_running()
|
||||
task_status_future: Future = Future()
|
||||
task_status = _BlockingPortalTaskStatus(task_status_future)
|
||||
f: Future = Future()
|
||||
f.add_done_callback(task_done)
|
||||
self._spawn_task_from_thread(func, args, {"task_status": task_status}, name, f)
|
||||
return f, task_status_future.result()
|
||||
|
||||
def wrap_async_context_manager(
|
||||
self, cm: AsyncContextManager[T_co]
|
||||
) -> ContextManager[T_co]:
|
||||
"""
|
||||
Wrap an async context manager as a synchronous context manager via this portal.
|
||||
|
||||
Spawns a task that will call both ``__aenter__()`` and ``__aexit__()``, stopping in the
|
||||
middle until the synchronous context manager exits.
|
||||
|
||||
:param cm: an asynchronous context manager
|
||||
:return: a synchronous context manager
|
||||
|
||||
.. versionadded:: 2.1
|
||||
|
||||
"""
|
||||
return _BlockingAsyncContextManager(cm, self)
|
||||
|
||||
|
||||
def create_blocking_portal() -> BlockingPortal:
|
||||
"""
|
||||
Create a portal for running functions in the event loop thread from external threads.
|
||||
|
||||
Use this function in asynchronous code when you need to allow external threads access to the
|
||||
event loop where your asynchronous code is currently running.
|
||||
|
||||
.. deprecated:: 3.0
|
||||
Use :class:`.BlockingPortal` directly.
|
||||
|
||||
"""
|
||||
warn(
|
||||
"create_blocking_portal() has been deprecated -- use anyio.from_thread.BlockingPortal() "
|
||||
"directly",
|
||||
DeprecationWarning,
|
||||
)
|
||||
return BlockingPortal()
|
||||
|
||||
|
||||
@contextmanager
|
||||
def start_blocking_portal(
|
||||
backend: str = "asyncio", backend_options: Optional[Dict[str, Any]] = None
|
||||
) -> Generator[BlockingPortal, Any, None]:
|
||||
"""
|
||||
Start a new event loop in a new thread and run a blocking portal in its main task.
|
||||
|
||||
The parameters are the same as for :func:`~anyio.run`.
|
||||
|
||||
:param backend: name of the backend
|
||||
:param backend_options: backend options
|
||||
:return: a context manager that yields a blocking portal
|
||||
|
||||
.. versionchanged:: 3.0
|
||||
Usage as a context manager is now required.
|
||||
|
||||
"""
|
||||
|
||||
async def run_portal() -> None:
|
||||
async with BlockingPortal() as portal_:
|
||||
if future.set_running_or_notify_cancel():
|
||||
future.set_result(portal_)
|
||||
await portal_.sleep_until_stopped()
|
||||
|
||||
future: Future[BlockingPortal] = Future()
|
||||
with ThreadPoolExecutor(1) as executor:
|
||||
run_future = executor.submit(
|
||||
_eventloop.run,
|
||||
run_portal, # type: ignore[arg-type]
|
||||
backend=backend,
|
||||
backend_options=backend_options,
|
||||
)
|
||||
try:
|
||||
wait(
|
||||
cast(Iterable[Future], [run_future, future]),
|
||||
return_when=FIRST_COMPLETED,
|
||||
)
|
||||
except BaseException:
|
||||
future.cancel()
|
||||
run_future.cancel()
|
||||
raise
|
||||
|
||||
if future.done():
|
||||
portal = future.result()
|
||||
try:
|
||||
yield portal
|
||||
except BaseException:
|
||||
portal.call(portal.stop, True)
|
||||
raise
|
||||
|
||||
portal.call(portal.stop, False)
|
||||
|
||||
run_future.result()
|
@ -0,0 +1,170 @@
|
||||
import enum
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Generic, Set, TypeVar, Union, overload
|
||||
from weakref import WeakKeyDictionary
|
||||
|
||||
from ._core._eventloop import get_asynclib
|
||||
|
||||
if sys.version_info >= (3, 8):
|
||||
from typing import Literal
|
||||
else:
|
||||
from typing_extensions import Literal
|
||||
|
||||
T = TypeVar("T")
|
||||
D = TypeVar("D")
|
||||
|
||||
|
||||
async def checkpoint() -> None:
|
||||
"""
|
||||
Check for cancellation and allow the scheduler to switch to another task.
|
||||
|
||||
Equivalent to (but more efficient than)::
|
||||
|
||||
await checkpoint_if_cancelled()
|
||||
await cancel_shielded_checkpoint()
|
||||
|
||||
.. versionadded:: 3.0
|
||||
|
||||
"""
|
||||
await get_asynclib().checkpoint()
|
||||
|
||||
|
||||
async def checkpoint_if_cancelled() -> None:
|
||||
"""
|
||||
Enter a checkpoint if the enclosing cancel scope has been cancelled.
|
||||
|
||||
This does not allow the scheduler to switch to a different task.
|
||||
|
||||
.. versionadded:: 3.0
|
||||
|
||||
"""
|
||||
await get_asynclib().checkpoint_if_cancelled()
|
||||
|
||||
|
||||
async def cancel_shielded_checkpoint() -> None:
|
||||
"""
|
||||
Allow the scheduler to switch to another task but without checking for cancellation.
|
||||
|
||||
Equivalent to (but potentially more efficient than)::
|
||||
|
||||
with CancelScope(shield=True):
|
||||
await checkpoint()
|
||||
|
||||
.. versionadded:: 3.0
|
||||
|
||||
"""
|
||||
await get_asynclib().cancel_shielded_checkpoint()
|
||||
|
||||
|
||||
def current_token() -> object:
|
||||
"""Return a backend specific token object that can be used to get back to the event loop."""
|
||||
return get_asynclib().current_token()
|
||||
|
||||
|
||||
_run_vars = WeakKeyDictionary() # type: WeakKeyDictionary[Any, Dict[str, Any]]
|
||||
_token_wrappers: Dict[Any, "_TokenWrapper"] = {}
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class _TokenWrapper:
|
||||
__slots__ = "_token", "__weakref__"
|
||||
_token: object
|
||||
|
||||
|
||||
class _NoValueSet(enum.Enum):
|
||||
NO_VALUE_SET = enum.auto()
|
||||
|
||||
|
||||
class RunvarToken(Generic[T]):
|
||||
__slots__ = "_var", "_value", "_redeemed"
|
||||
|
||||
def __init__(
|
||||
self, var: "RunVar[T]", value: Union[T, Literal[_NoValueSet.NO_VALUE_SET]]
|
||||
):
|
||||
self._var = var
|
||||
self._value: Union[T, Literal[_NoValueSet.NO_VALUE_SET]] = value
|
||||
self._redeemed = False
|
||||
|
||||
|
||||
class RunVar(Generic[T]):
|
||||
"""Like a :class:`~contextvars.ContextVar`, expect scoped to the running event loop."""
|
||||
|
||||
__slots__ = "_name", "_default"
|
||||
|
||||
NO_VALUE_SET: Literal[_NoValueSet.NO_VALUE_SET] = _NoValueSet.NO_VALUE_SET
|
||||
|
||||
_token_wrappers: Set[_TokenWrapper] = set()
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
default: Union[T, Literal[_NoValueSet.NO_VALUE_SET]] = NO_VALUE_SET,
|
||||
):
|
||||
self._name = name
|
||||
self._default = default
|
||||
|
||||
@property
|
||||
def _current_vars(self) -> Dict[str, T]:
|
||||
token = current_token()
|
||||
while True:
|
||||
try:
|
||||
return _run_vars[token]
|
||||
except TypeError:
|
||||
# Happens when token isn't weak referable (TrioToken).
|
||||
# This workaround does mean that some memory will leak on Trio until the problem
|
||||
# is fixed on their end.
|
||||
token = _TokenWrapper(token)
|
||||
self._token_wrappers.add(token)
|
||||
except KeyError:
|
||||
run_vars = _run_vars[token] = {}
|
||||
return run_vars
|
||||
|
||||
@overload
|
||||
def get(self, default: D) -> Union[T, D]:
|
||||
...
|
||||
|
||||
@overload
|
||||
def get(self) -> T:
|
||||
...
|
||||
|
||||
def get(
|
||||
self, default: Union[D, Literal[_NoValueSet.NO_VALUE_SET]] = NO_VALUE_SET
|
||||
) -> Union[T, D]:
|
||||
try:
|
||||
return self._current_vars[self._name]
|
||||
except KeyError:
|
||||
if default is not RunVar.NO_VALUE_SET:
|
||||
return default
|
||||
elif self._default is not RunVar.NO_VALUE_SET:
|
||||
return self._default
|
||||
|
||||
raise LookupError(
|
||||
f'Run variable "{self._name}" has no value and no default set'
|
||||
)
|
||||
|
||||
def set(self, value: T) -> RunvarToken[T]:
|
||||
current_vars = self._current_vars
|
||||
token = RunvarToken(self, current_vars.get(self._name, RunVar.NO_VALUE_SET))
|
||||
current_vars[self._name] = value
|
||||
return token
|
||||
|
||||
def reset(self, token: RunvarToken[T]) -> None:
|
||||
if token._var is not self:
|
||||
raise ValueError("This token does not belong to this RunVar")
|
||||
|
||||
if token._redeemed:
|
||||
raise ValueError("This token has already been used")
|
||||
|
||||
if token._value is _NoValueSet.NO_VALUE_SET:
|
||||
try:
|
||||
del self._current_vars[self._name]
|
||||
except KeyError:
|
||||
pass
|
||||
else:
|
||||
self._current_vars[self._name] = token._value
|
||||
|
||||
token._redeemed = True
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<RunVar name={self._name!r}>"
|
@ -0,0 +1,144 @@
|
||||
from contextlib import contextmanager
|
||||
from inspect import isasyncgenfunction, iscoroutinefunction
|
||||
from typing import TYPE_CHECKING, Any, Dict, Generator, Optional, Tuple, cast
|
||||
|
||||
import pytest
|
||||
import sniffio
|
||||
from _pytest.fixtures import FixtureRequest
|
||||
|
||||
from ._core._eventloop import get_all_backends, get_asynclib
|
||||
from .abc import TestRunner
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from _pytest.config import Config
|
||||
|
||||
_current_runner: Optional[TestRunner] = None
|
||||
|
||||
|
||||
def extract_backend_and_options(backend: object) -> Tuple[str, Dict[str, Any]]:
|
||||
if isinstance(backend, str):
|
||||
return backend, {}
|
||||
elif isinstance(backend, tuple) and len(backend) == 2:
|
||||
if isinstance(backend[0], str) and isinstance(backend[1], dict):
|
||||
return cast(Tuple[str, Dict[str, Any]], backend)
|
||||
|
||||
raise TypeError("anyio_backend must be either a string or tuple of (string, dict)")
|
||||
|
||||
|
||||
@contextmanager
|
||||
def get_runner(
|
||||
backend_name: str, backend_options: Dict[str, Any]
|
||||
) -> Generator[TestRunner, object, None]:
|
||||
global _current_runner
|
||||
if _current_runner:
|
||||
yield _current_runner
|
||||
return
|
||||
|
||||
asynclib = get_asynclib(backend_name)
|
||||
token = None
|
||||
if sniffio.current_async_library_cvar.get(None) is None:
|
||||
# Since we're in control of the event loop, we can cache the name of the async library
|
||||
token = sniffio.current_async_library_cvar.set(backend_name)
|
||||
|
||||
try:
|
||||
backend_options = backend_options or {}
|
||||
with asynclib.TestRunner(**backend_options) as runner:
|
||||
_current_runner = runner
|
||||
yield runner
|
||||
finally:
|
||||
_current_runner = None
|
||||
if token:
|
||||
sniffio.current_async_library_cvar.reset(token)
|
||||
|
||||
|
||||
def pytest_configure(config: "Config") -> None:
|
||||
config.addinivalue_line(
|
||||
"markers",
|
||||
"anyio: mark the (coroutine function) test to be run "
|
||||
"asynchronously via anyio.",
|
||||
)
|
||||
|
||||
|
||||
def pytest_fixture_setup(fixturedef: Any, request: FixtureRequest) -> None:
|
||||
def wrapper(*args, anyio_backend, **kwargs): # type: ignore[no-untyped-def]
|
||||
backend_name, backend_options = extract_backend_and_options(anyio_backend)
|
||||
if has_backend_arg:
|
||||
kwargs["anyio_backend"] = anyio_backend
|
||||
|
||||
with get_runner(backend_name, backend_options) as runner:
|
||||
if isasyncgenfunction(func):
|
||||
yield from runner.run_asyncgen_fixture(func, kwargs)
|
||||
else:
|
||||
yield runner.run_fixture(func, kwargs)
|
||||
|
||||
# Only apply this to coroutine functions and async generator functions in requests that involve
|
||||
# the anyio_backend fixture
|
||||
func = fixturedef.func
|
||||
if isasyncgenfunction(func) or iscoroutinefunction(func):
|
||||
if "anyio_backend" in request.fixturenames:
|
||||
has_backend_arg = "anyio_backend" in fixturedef.argnames
|
||||
fixturedef.func = wrapper
|
||||
if not has_backend_arg:
|
||||
fixturedef.argnames += ("anyio_backend",)
|
||||
|
||||
|
||||
@pytest.hookimpl(tryfirst=True)
|
||||
def pytest_pycollect_makeitem(collector: Any, name: Any, obj: Any) -> None:
|
||||
if collector.istestfunction(obj, name):
|
||||
inner_func = obj.hypothesis.inner_test if hasattr(obj, "hypothesis") else obj
|
||||
if iscoroutinefunction(inner_func):
|
||||
marker = collector.get_closest_marker("anyio")
|
||||
own_markers = getattr(obj, "pytestmark", ())
|
||||
if marker or any(marker.name == "anyio" for marker in own_markers):
|
||||
pytest.mark.usefixtures("anyio_backend")(obj)
|
||||
|
||||
|
||||
@pytest.hookimpl(tryfirst=True)
|
||||
def pytest_pyfunc_call(pyfuncitem: Any) -> Optional[bool]:
|
||||
def run_with_hypothesis(**kwargs: Any) -> None:
|
||||
with get_runner(backend_name, backend_options) as runner:
|
||||
runner.run_test(original_func, kwargs)
|
||||
|
||||
backend = pyfuncitem.funcargs.get("anyio_backend")
|
||||
if backend:
|
||||
backend_name, backend_options = extract_backend_and_options(backend)
|
||||
|
||||
if hasattr(pyfuncitem.obj, "hypothesis"):
|
||||
# Wrap the inner test function unless it's already wrapped
|
||||
original_func = pyfuncitem.obj.hypothesis.inner_test
|
||||
if original_func.__qualname__ != run_with_hypothesis.__qualname__:
|
||||
if iscoroutinefunction(original_func):
|
||||
pyfuncitem.obj.hypothesis.inner_test = run_with_hypothesis
|
||||
|
||||
return None
|
||||
|
||||
if iscoroutinefunction(pyfuncitem.obj):
|
||||
funcargs = pyfuncitem.funcargs
|
||||
testargs = {arg: funcargs[arg] for arg in pyfuncitem._fixtureinfo.argnames}
|
||||
with get_runner(backend_name, backend_options) as runner:
|
||||
runner.run_test(pyfuncitem.obj, testargs)
|
||||
|
||||
return True
|
||||
|
||||
return None
|
||||
|
||||
|
||||
@pytest.fixture(params=get_all_backends())
|
||||
def anyio_backend(request: Any) -> Any:
|
||||
return request.param
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def anyio_backend_name(anyio_backend: Any) -> str:
|
||||
if isinstance(anyio_backend, str):
|
||||
return anyio_backend
|
||||
else:
|
||||
return anyio_backend[0]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def anyio_backend_options(anyio_backend: Any) -> Dict[str, Any]:
|
||||
if isinstance(anyio_backend, str):
|
||||
return {}
|
||||
else:
|
||||
return anyio_backend[1]
|
@ -0,0 +1,116 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Callable, Mapping
|
||||
|
||||
from .. import ClosedResourceError, DelimiterNotFound, EndOfStream, IncompleteRead
|
||||
from ..abc import AnyByteReceiveStream, ByteReceiveStream
|
||||
|
||||
|
||||
@dataclass(eq=False)
|
||||
class BufferedByteReceiveStream(ByteReceiveStream):
|
||||
"""
|
||||
Wraps any bytes-based receive stream and uses a buffer to provide sophisticated receiving
|
||||
capabilities in the form of a byte stream.
|
||||
"""
|
||||
|
||||
receive_stream: AnyByteReceiveStream
|
||||
_buffer: bytearray = field(init=False, default_factory=bytearray)
|
||||
_closed: bool = field(init=False, default=False)
|
||||
|
||||
async def aclose(self) -> None:
|
||||
await self.receive_stream.aclose()
|
||||
self._closed = True
|
||||
|
||||
@property
|
||||
def buffer(self) -> bytes:
|
||||
"""The bytes currently in the buffer."""
|
||||
return bytes(self._buffer)
|
||||
|
||||
@property
|
||||
def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
|
||||
return self.receive_stream.extra_attributes
|
||||
|
||||
async def receive(self, max_bytes: int = 65536) -> bytes:
|
||||
if self._closed:
|
||||
raise ClosedResourceError
|
||||
|
||||
if self._buffer:
|
||||
chunk = bytes(self._buffer[:max_bytes])
|
||||
del self._buffer[:max_bytes]
|
||||
return chunk
|
||||
elif isinstance(self.receive_stream, ByteReceiveStream):
|
||||
return await self.receive_stream.receive(max_bytes)
|
||||
else:
|
||||
# With a bytes-oriented object stream, we need to handle any surplus bytes we get from
|
||||
# the receive() call
|
||||
chunk = await self.receive_stream.receive()
|
||||
if len(chunk) > max_bytes:
|
||||
# Save the surplus bytes in the buffer
|
||||
self._buffer.extend(chunk[max_bytes:])
|
||||
return chunk[:max_bytes]
|
||||
else:
|
||||
return chunk
|
||||
|
||||
async def receive_exactly(self, nbytes: int) -> bytes:
|
||||
"""
|
||||
Read exactly the given amount of bytes from the stream.
|
||||
|
||||
:param nbytes: the number of bytes to read
|
||||
:return: the bytes read
|
||||
:raises ~anyio.IncompleteRead: if the stream was closed before the requested
|
||||
amount of bytes could be read from the stream
|
||||
|
||||
"""
|
||||
while True:
|
||||
remaining = nbytes - len(self._buffer)
|
||||
if remaining <= 0:
|
||||
retval = self._buffer[:nbytes]
|
||||
del self._buffer[:nbytes]
|
||||
return bytes(retval)
|
||||
|
||||
try:
|
||||
if isinstance(self.receive_stream, ByteReceiveStream):
|
||||
chunk = await self.receive_stream.receive(remaining)
|
||||
else:
|
||||
chunk = await self.receive_stream.receive()
|
||||
except EndOfStream as exc:
|
||||
raise IncompleteRead from exc
|
||||
|
||||
self._buffer.extend(chunk)
|
||||
|
||||
async def receive_until(self, delimiter: bytes, max_bytes: int) -> bytes:
|
||||
"""
|
||||
Read from the stream until the delimiter is found or max_bytes have been read.
|
||||
|
||||
:param delimiter: the marker to look for in the stream
|
||||
:param max_bytes: maximum number of bytes that will be read before raising
|
||||
:exc:`~anyio.DelimiterNotFound`
|
||||
:return: the bytes read (not including the delimiter)
|
||||
:raises ~anyio.IncompleteRead: if the stream was closed before the delimiter
|
||||
was found
|
||||
:raises ~anyio.DelimiterNotFound: if the delimiter is not found within the
|
||||
bytes read up to the maximum allowed
|
||||
|
||||
"""
|
||||
delimiter_size = len(delimiter)
|
||||
offset = 0
|
||||
while True:
|
||||
# Check if the delimiter can be found in the current buffer
|
||||
index = self._buffer.find(delimiter, offset)
|
||||
if index >= 0:
|
||||
found = self._buffer[:index]
|
||||
del self._buffer[: index + len(delimiter) :]
|
||||
return bytes(found)
|
||||
|
||||
# Check if the buffer is already at or over the limit
|
||||
if len(self._buffer) >= max_bytes:
|
||||
raise DelimiterNotFound(max_bytes)
|
||||
|
||||
# Read more data into the buffer from the socket
|
||||
try:
|
||||
data = await self.receive_stream.receive()
|
||||
except EndOfStream as exc:
|
||||
raise IncompleteRead from exc
|
||||
|
||||
# Move the offset forward and add the new data to the buffer
|
||||
offset = max(len(self._buffer) - delimiter_size + 1, 0)
|
||||
self._buffer.extend(data)
|
@ -0,0 +1,145 @@
|
||||
from io import SEEK_SET, UnsupportedOperation
|
||||
from os import PathLike
|
||||
from pathlib import Path
|
||||
from typing import Any, BinaryIO, Callable, Dict, Mapping, Union, cast
|
||||
|
||||
from .. import (
|
||||
BrokenResourceError,
|
||||
ClosedResourceError,
|
||||
EndOfStream,
|
||||
TypedAttributeSet,
|
||||
to_thread,
|
||||
typed_attribute,
|
||||
)
|
||||
from ..abc import ByteReceiveStream, ByteSendStream
|
||||
|
||||
|
||||
class FileStreamAttribute(TypedAttributeSet):
|
||||
#: the open file descriptor
|
||||
file: BinaryIO = typed_attribute()
|
||||
#: the path of the file on the file system, if available (file must be a real file)
|
||||
path: Path = typed_attribute()
|
||||
#: the file number, if available (file must be a real file or a TTY)
|
||||
fileno: int = typed_attribute()
|
||||
|
||||
|
||||
class _BaseFileStream:
|
||||
def __init__(self, file: BinaryIO):
|
||||
self._file = file
|
||||
|
||||
async def aclose(self) -> None:
|
||||
await to_thread.run_sync(self._file.close)
|
||||
|
||||
@property
|
||||
def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
|
||||
attributes: Dict[Any, Callable[[], Any]] = {
|
||||
FileStreamAttribute.file: lambda: self._file,
|
||||
}
|
||||
|
||||
if hasattr(self._file, "name"):
|
||||
attributes[FileStreamAttribute.path] = lambda: Path(self._file.name)
|
||||
|
||||
try:
|
||||
self._file.fileno()
|
||||
except UnsupportedOperation:
|
||||
pass
|
||||
else:
|
||||
attributes[FileStreamAttribute.fileno] = lambda: self._file.fileno()
|
||||
|
||||
return attributes
|
||||
|
||||
|
||||
class FileReadStream(_BaseFileStream, ByteReceiveStream):
|
||||
"""
|
||||
A byte stream that reads from a file in the file system.
|
||||
|
||||
:param file: a file that has been opened for reading in binary mode
|
||||
|
||||
.. versionadded:: 3.0
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
async def from_path(cls, path: Union[str, "PathLike[str]"]) -> "FileReadStream":
|
||||
"""
|
||||
Create a file read stream by opening the given file.
|
||||
|
||||
:param path: path of the file to read from
|
||||
|
||||
"""
|
||||
file = await to_thread.run_sync(Path(path).open, "rb")
|
||||
return cls(cast(BinaryIO, file))
|
||||
|
||||
async def receive(self, max_bytes: int = 65536) -> bytes:
|
||||
try:
|
||||
data = await to_thread.run_sync(self._file.read, max_bytes)
|
||||
except ValueError:
|
||||
raise ClosedResourceError from None
|
||||
except OSError as exc:
|
||||
raise BrokenResourceError from exc
|
||||
|
||||
if data:
|
||||
return data
|
||||
else:
|
||||
raise EndOfStream
|
||||
|
||||
async def seek(self, position: int, whence: int = SEEK_SET) -> int:
|
||||
"""
|
||||
Seek the file to the given position.
|
||||
|
||||
.. seealso:: :meth:`io.IOBase.seek`
|
||||
|
||||
.. note:: Not all file descriptors are seekable.
|
||||
|
||||
:param position: position to seek the file to
|
||||
:param whence: controls how ``position`` is interpreted
|
||||
:return: the new absolute position
|
||||
:raises OSError: if the file is not seekable
|
||||
|
||||
"""
|
||||
return await to_thread.run_sync(self._file.seek, position, whence)
|
||||
|
||||
async def tell(self) -> int:
|
||||
"""
|
||||
Return the current stream position.
|
||||
|
||||
.. note:: Not all file descriptors are seekable.
|
||||
|
||||
:return: the current absolute position
|
||||
:raises OSError: if the file is not seekable
|
||||
|
||||
"""
|
||||
return await to_thread.run_sync(self._file.tell)
|
||||
|
||||
|
||||
class FileWriteStream(_BaseFileStream, ByteSendStream):
|
||||
"""
|
||||
A byte stream that writes to a file in the file system.
|
||||
|
||||
:param file: a file that has been opened for writing in binary mode
|
||||
|
||||
.. versionadded:: 3.0
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
async def from_path(
|
||||
cls, path: Union[str, "PathLike[str]"], append: bool = False
|
||||
) -> "FileWriteStream":
|
||||
"""
|
||||
Create a file write stream by opening the given file for writing.
|
||||
|
||||
:param path: path of the file to write to
|
||||
:param append: if ``True``, open the file for appending; if ``False``, any existing file
|
||||
at the given path will be truncated
|
||||
|
||||
"""
|
||||
mode = "ab" if append else "wb"
|
||||
file = await to_thread.run_sync(Path(path).open, mode)
|
||||
return cls(cast(BinaryIO, file))
|
||||
|
||||
async def send(self, item: bytes) -> None:
|
||||
try:
|
||||
await to_thread.run_sync(self._file.write, item)
|
||||
except ValueError:
|
||||
raise ClosedResourceError from None
|
||||
except OSError as exc:
|
||||
raise BrokenResourceError from exc
|
@ -0,0 +1,275 @@
|
||||
from collections import OrderedDict, deque
|
||||
from dataclasses import dataclass, field
|
||||
from types import TracebackType
|
||||
from typing import Deque, Generic, List, NamedTuple, Optional, Type, TypeVar
|
||||
|
||||
from .. import (
|
||||
BrokenResourceError,
|
||||
ClosedResourceError,
|
||||
EndOfStream,
|
||||
WouldBlock,
|
||||
get_cancelled_exc_class,
|
||||
)
|
||||
from .._core._compat import DeprecatedAwaitable
|
||||
from ..abc import Event, ObjectReceiveStream, ObjectSendStream
|
||||
from ..lowlevel import checkpoint
|
||||
|
||||
T_Item = TypeVar("T_Item")
|
||||
|
||||
|
||||
class MemoryObjectStreamStatistics(NamedTuple):
|
||||
current_buffer_used: int #: number of items stored in the buffer
|
||||
#: maximum number of items that can be stored on this stream (or :data:`math.inf`)
|
||||
max_buffer_size: float
|
||||
open_send_streams: int #: number of unclosed clones of the send stream
|
||||
open_receive_streams: int #: number of unclosed clones of the receive stream
|
||||
tasks_waiting_send: int #: number of tasks blocked on :meth:`MemoryObjectSendStream.send`
|
||||
#: number of tasks blocked on :meth:`MemoryObjectReceiveStream.receive`
|
||||
tasks_waiting_receive: int
|
||||
|
||||
|
||||
@dataclass(eq=False)
|
||||
class MemoryObjectStreamState(Generic[T_Item]):
|
||||
max_buffer_size: float = field()
|
||||
buffer: Deque[T_Item] = field(init=False, default_factory=deque)
|
||||
open_send_channels: int = field(init=False, default=0)
|
||||
open_receive_channels: int = field(init=False, default=0)
|
||||
waiting_receivers: "OrderedDict[Event, List[T_Item]]" = field(
|
||||
init=False, default_factory=OrderedDict
|
||||
)
|
||||
waiting_senders: "OrderedDict[Event, T_Item]" = field(
|
||||
init=False, default_factory=OrderedDict
|
||||
)
|
||||
|
||||
def statistics(self) -> MemoryObjectStreamStatistics:
|
||||
return MemoryObjectStreamStatistics(
|
||||
len(self.buffer),
|
||||
self.max_buffer_size,
|
||||
self.open_send_channels,
|
||||
self.open_receive_channels,
|
||||
len(self.waiting_senders),
|
||||
len(self.waiting_receivers),
|
||||
)
|
||||
|
||||
|
||||
@dataclass(eq=False)
|
||||
class MemoryObjectReceiveStream(Generic[T_Item], ObjectReceiveStream[T_Item]):
|
||||
_state: MemoryObjectStreamState[T_Item]
|
||||
_closed: bool = field(init=False, default=False)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self._state.open_receive_channels += 1
|
||||
|
||||
def receive_nowait(self) -> T_Item:
|
||||
"""
|
||||
Receive the next item if it can be done without waiting.
|
||||
|
||||
:return: the received item
|
||||
:raises ~anyio.ClosedResourceError: if this send stream has been closed
|
||||
:raises ~anyio.EndOfStream: if the buffer is empty and this stream has been
|
||||
closed from the sending end
|
||||
:raises ~anyio.WouldBlock: if there are no items in the buffer and no tasks
|
||||
waiting to send
|
||||
|
||||
"""
|
||||
if self._closed:
|
||||
raise ClosedResourceError
|
||||
|
||||
if self._state.waiting_senders:
|
||||
# Get the item from the next sender
|
||||
send_event, item = self._state.waiting_senders.popitem(last=False)
|
||||
self._state.buffer.append(item)
|
||||
send_event.set()
|
||||
|
||||
if self._state.buffer:
|
||||
return self._state.buffer.popleft()
|
||||
elif not self._state.open_send_channels:
|
||||
raise EndOfStream
|
||||
|
||||
raise WouldBlock
|
||||
|
||||
async def receive(self) -> T_Item:
|
||||
await checkpoint()
|
||||
try:
|
||||
return self.receive_nowait()
|
||||
except WouldBlock:
|
||||
# Add ourselves in the queue
|
||||
receive_event = Event()
|
||||
container: List[T_Item] = []
|
||||
self._state.waiting_receivers[receive_event] = container
|
||||
|
||||
try:
|
||||
await receive_event.wait()
|
||||
except get_cancelled_exc_class():
|
||||
# Ignore the immediate cancellation if we already received an item, so as not to
|
||||
# lose it
|
||||
if not container:
|
||||
raise
|
||||
finally:
|
||||
self._state.waiting_receivers.pop(receive_event, None)
|
||||
|
||||
if container:
|
||||
return container[0]
|
||||
else:
|
||||
raise EndOfStream
|
||||
|
||||
def clone(self) -> "MemoryObjectReceiveStream[T_Item]":
|
||||
"""
|
||||
Create a clone of this receive stream.
|
||||
|
||||
Each clone can be closed separately. Only when all clones have been closed will the
|
||||
receiving end of the memory stream be considered closed by the sending ends.
|
||||
|
||||
:return: the cloned stream
|
||||
|
||||
"""
|
||||
if self._closed:
|
||||
raise ClosedResourceError
|
||||
|
||||
return MemoryObjectReceiveStream(_state=self._state)
|
||||
|
||||
def close(self) -> None:
|
||||
"""
|
||||
Close the stream.
|
||||
|
||||
This works the exact same way as :meth:`aclose`, but is provided as a special case for the
|
||||
benefit of synchronous callbacks.
|
||||
|
||||
"""
|
||||
if not self._closed:
|
||||
self._closed = True
|
||||
self._state.open_receive_channels -= 1
|
||||
if self._state.open_receive_channels == 0:
|
||||
send_events = list(self._state.waiting_senders.keys())
|
||||
for event in send_events:
|
||||
event.set()
|
||||
|
||||
async def aclose(self) -> None:
|
||||
self.close()
|
||||
|
||||
def statistics(self) -> MemoryObjectStreamStatistics:
|
||||
"""
|
||||
Return statistics about the current state of this stream.
|
||||
|
||||
.. versionadded:: 3.0
|
||||
"""
|
||||
return self._state.statistics()
|
||||
|
||||
def __enter__(self) -> "MemoryObjectReceiveStream[T_Item]":
|
||||
return self
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: Optional[Type[BaseException]],
|
||||
exc_val: Optional[BaseException],
|
||||
exc_tb: Optional[TracebackType],
|
||||
) -> None:
|
||||
self.close()
|
||||
|
||||
|
||||
@dataclass(eq=False)
|
||||
class MemoryObjectSendStream(Generic[T_Item], ObjectSendStream[T_Item]):
|
||||
_state: MemoryObjectStreamState[T_Item]
|
||||
_closed: bool = field(init=False, default=False)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self._state.open_send_channels += 1
|
||||
|
||||
def send_nowait(self, item: T_Item) -> DeprecatedAwaitable:
|
||||
"""
|
||||
Send an item immediately if it can be done without waiting.
|
||||
|
||||
:param item: the item to send
|
||||
:raises ~anyio.ClosedResourceError: if this send stream has been closed
|
||||
:raises ~anyio.BrokenResourceError: if the stream has been closed from the
|
||||
receiving end
|
||||
:raises ~anyio.WouldBlock: if the buffer is full and there are no tasks waiting
|
||||
to receive
|
||||
|
||||
"""
|
||||
if self._closed:
|
||||
raise ClosedResourceError
|
||||
if not self._state.open_receive_channels:
|
||||
raise BrokenResourceError
|
||||
|
||||
if self._state.waiting_receivers:
|
||||
receive_event, container = self._state.waiting_receivers.popitem(last=False)
|
||||
container.append(item)
|
||||
receive_event.set()
|
||||
elif len(self._state.buffer) < self._state.max_buffer_size:
|
||||
self._state.buffer.append(item)
|
||||
else:
|
||||
raise WouldBlock
|
||||
|
||||
return DeprecatedAwaitable(self.send_nowait)
|
||||
|
||||
async def send(self, item: T_Item) -> None:
|
||||
await checkpoint()
|
||||
try:
|
||||
self.send_nowait(item)
|
||||
except WouldBlock:
|
||||
# Wait until there's someone on the receiving end
|
||||
send_event = Event()
|
||||
self._state.waiting_senders[send_event] = item
|
||||
try:
|
||||
await send_event.wait()
|
||||
except BaseException:
|
||||
self._state.waiting_senders.pop(send_event, None) # type: ignore[arg-type]
|
||||
raise
|
||||
|
||||
if self._state.waiting_senders.pop(send_event, None): # type: ignore[arg-type]
|
||||
raise BrokenResourceError
|
||||
|
||||
def clone(self) -> "MemoryObjectSendStream[T_Item]":
|
||||
"""
|
||||
Create a clone of this send stream.
|
||||
|
||||
Each clone can be closed separately. Only when all clones have been closed will the
|
||||
sending end of the memory stream be considered closed by the receiving ends.
|
||||
|
||||
:return: the cloned stream
|
||||
|
||||
"""
|
||||
if self._closed:
|
||||
raise ClosedResourceError
|
||||
|
||||
return MemoryObjectSendStream(_state=self._state)
|
||||
|
||||
def close(self) -> None:
|
||||
"""
|
||||
Close the stream.
|
||||
|
||||
This works the exact same way as :meth:`aclose`, but is provided as a special case for the
|
||||
benefit of synchronous callbacks.
|
||||
|
||||
"""
|
||||
if not self._closed:
|
||||
self._closed = True
|
||||
self._state.open_send_channels -= 1
|
||||
if self._state.open_send_channels == 0:
|
||||
receive_events = list(self._state.waiting_receivers.keys())
|
||||
self._state.waiting_receivers.clear()
|
||||
for event in receive_events:
|
||||
event.set()
|
||||
|
||||
async def aclose(self) -> None:
|
||||
self.close()
|
||||
|
||||
def statistics(self) -> MemoryObjectStreamStatistics:
|
||||
"""
|
||||
Return statistics about the current state of this stream.
|
||||
|
||||
.. versionadded:: 3.0
|
||||
"""
|
||||
return self._state.statistics()
|
||||
|
||||
def __enter__(self) -> "MemoryObjectSendStream[T_Item]":
|
||||
return self
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: Optional[Type[BaseException]],
|
||||
exc_val: Optional[BaseException],
|
||||
exc_tb: Optional[TracebackType],
|
||||
) -> None:
|
||||
self.close()
|
@ -0,0 +1,138 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Generic, List, Mapping, Optional, Sequence, TypeVar
|
||||
|
||||
from ..abc import (
|
||||
ByteReceiveStream,
|
||||
ByteSendStream,
|
||||
ByteStream,
|
||||
Listener,
|
||||
ObjectReceiveStream,
|
||||
ObjectSendStream,
|
||||
ObjectStream,
|
||||
TaskGroup,
|
||||
)
|
||||
|
||||
T_Item = TypeVar("T_Item")
|
||||
T_Stream = TypeVar("T_Stream")
|
||||
|
||||
|
||||
@dataclass(eq=False)
|
||||
class StapledByteStream(ByteStream):
|
||||
"""
|
||||
Combines two byte streams into a single, bidirectional byte stream.
|
||||
|
||||
Extra attributes will be provided from both streams, with the receive stream providing the
|
||||
values in case of a conflict.
|
||||
|
||||
:param ByteSendStream send_stream: the sending byte stream
|
||||
:param ByteReceiveStream receive_stream: the receiving byte stream
|
||||
"""
|
||||
|
||||
send_stream: ByteSendStream
|
||||
receive_stream: ByteReceiveStream
|
||||
|
||||
async def receive(self, max_bytes: int = 65536) -> bytes:
|
||||
return await self.receive_stream.receive(max_bytes)
|
||||
|
||||
async def send(self, item: bytes) -> None:
|
||||
await self.send_stream.send(item)
|
||||
|
||||
async def send_eof(self) -> None:
|
||||
await self.send_stream.aclose()
|
||||
|
||||
async def aclose(self) -> None:
|
||||
await self.send_stream.aclose()
|
||||
await self.receive_stream.aclose()
|
||||
|
||||
@property
|
||||
def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
|
||||
return {
|
||||
**self.send_stream.extra_attributes,
|
||||
**self.receive_stream.extra_attributes,
|
||||
}
|
||||
|
||||
|
||||
@dataclass(eq=False)
|
||||
class StapledObjectStream(Generic[T_Item], ObjectStream[T_Item]):
|
||||
"""
|
||||
Combines two object streams into a single, bidirectional object stream.
|
||||
|
||||
Extra attributes will be provided from both streams, with the receive stream providing the
|
||||
values in case of a conflict.
|
||||
|
||||
:param ObjectSendStream send_stream: the sending object stream
|
||||
:param ObjectReceiveStream receive_stream: the receiving object stream
|
||||
"""
|
||||
|
||||
send_stream: ObjectSendStream[T_Item]
|
||||
receive_stream: ObjectReceiveStream[T_Item]
|
||||
|
||||
async def receive(self) -> T_Item:
|
||||
return await self.receive_stream.receive()
|
||||
|
||||
async def send(self, item: T_Item) -> None:
|
||||
await self.send_stream.send(item)
|
||||
|
||||
async def send_eof(self) -> None:
|
||||
await self.send_stream.aclose()
|
||||
|
||||
async def aclose(self) -> None:
|
||||
await self.send_stream.aclose()
|
||||
await self.receive_stream.aclose()
|
||||
|
||||
@property
|
||||
def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
|
||||
return {
|
||||
**self.send_stream.extra_attributes,
|
||||
**self.receive_stream.extra_attributes,
|
||||
}
|
||||
|
||||
|
||||
@dataclass(eq=False)
|
||||
class MultiListener(Generic[T_Stream], Listener[T_Stream]):
|
||||
"""
|
||||
Combines multiple listeners into one, serving connections from all of them at once.
|
||||
|
||||
Any MultiListeners in the given collection of listeners will have their listeners moved into
|
||||
this one.
|
||||
|
||||
Extra attributes are provided from each listener, with each successive listener overriding any
|
||||
conflicting attributes from the previous one.
|
||||
|
||||
:param listeners: listeners to serve
|
||||
:type listeners: Sequence[Listener[T_Stream]]
|
||||
"""
|
||||
|
||||
listeners: Sequence[Listener[T_Stream]]
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
listeners: List[Listener[T_Stream]] = []
|
||||
for listener in self.listeners:
|
||||
if isinstance(listener, MultiListener):
|
||||
listeners.extend(listener.listeners)
|
||||
del listener.listeners[:] # type: ignore[attr-defined]
|
||||
else:
|
||||
listeners.append(listener)
|
||||
|
||||
self.listeners = listeners
|
||||
|
||||
async def serve(
|
||||
self, handler: Callable[[T_Stream], Any], task_group: Optional[TaskGroup] = None
|
||||
) -> None:
|
||||
from .. import create_task_group
|
||||
|
||||
async with create_task_group() as tg:
|
||||
for listener in self.listeners:
|
||||
tg.start_soon(listener.serve, handler, task_group)
|
||||
|
||||
async def aclose(self) -> None:
|
||||
for listener in self.listeners:
|
||||
await listener.aclose()
|
||||
|
||||
@property
|
||||
def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
|
||||
attributes: dict = {}
|
||||
for listener in self.listeners:
|
||||
attributes.update(listener.extra_attributes)
|
||||
|
||||
return attributes
|
@ -0,0 +1,141 @@
|
||||
import codecs
|
||||
from dataclasses import InitVar, dataclass, field
|
||||
from typing import Any, Callable, Mapping, Tuple
|
||||
|
||||
from ..abc import (
|
||||
AnyByteReceiveStream,
|
||||
AnyByteSendStream,
|
||||
AnyByteStream,
|
||||
ObjectReceiveStream,
|
||||
ObjectSendStream,
|
||||
ObjectStream,
|
||||
)
|
||||
|
||||
|
||||
@dataclass(eq=False)
|
||||
class TextReceiveStream(ObjectReceiveStream[str]):
|
||||
"""
|
||||
Stream wrapper that decodes bytes to strings using the given encoding.
|
||||
|
||||
Decoding is done using :class:`~codecs.IncrementalDecoder` which returns any completely
|
||||
received unicode characters as soon as they come in.
|
||||
|
||||
:param transport_stream: any bytes-based receive stream
|
||||
:param encoding: character encoding to use for decoding bytes to strings (defaults to
|
||||
``utf-8``)
|
||||
:param errors: handling scheme for decoding errors (defaults to ``strict``; see the
|
||||
`codecs module documentation`_ for a comprehensive list of options)
|
||||
|
||||
.. _codecs module documentation: https://docs.python.org/3/library/codecs.html#codec-objects
|
||||
"""
|
||||
|
||||
transport_stream: AnyByteReceiveStream
|
||||
encoding: InitVar[str] = "utf-8"
|
||||
errors: InitVar[str] = "strict"
|
||||
_decoder: codecs.IncrementalDecoder = field(init=False)
|
||||
|
||||
def __post_init__(self, encoding: str, errors: str) -> None:
|
||||
decoder_class = codecs.getincrementaldecoder(encoding)
|
||||
self._decoder = decoder_class(errors=errors)
|
||||
|
||||
async def receive(self) -> str:
|
||||
while True:
|
||||
chunk = await self.transport_stream.receive()
|
||||
decoded = self._decoder.decode(chunk)
|
||||
if decoded:
|
||||
return decoded
|
||||
|
||||
async def aclose(self) -> None:
|
||||
await self.transport_stream.aclose()
|
||||
self._decoder.reset()
|
||||
|
||||
@property
|
||||
def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
|
||||
return self.transport_stream.extra_attributes
|
||||
|
||||
|
||||
@dataclass(eq=False)
|
||||
class TextSendStream(ObjectSendStream[str]):
|
||||
"""
|
||||
Sends strings to the wrapped stream as bytes using the given encoding.
|
||||
|
||||
:param AnyByteSendStream transport_stream: any bytes-based send stream
|
||||
:param str encoding: character encoding to use for encoding strings to bytes (defaults to
|
||||
``utf-8``)
|
||||
:param str errors: handling scheme for encoding errors (defaults to ``strict``; see the
|
||||
`codecs module documentation`_ for a comprehensive list of options)
|
||||
|
||||
.. _codecs module documentation: https://docs.python.org/3/library/codecs.html#codec-objects
|
||||
"""
|
||||
|
||||
transport_stream: AnyByteSendStream
|
||||
encoding: InitVar[str] = "utf-8"
|
||||
errors: str = "strict"
|
||||
_encoder: Callable[..., Tuple[bytes, int]] = field(init=False)
|
||||
|
||||
def __post_init__(self, encoding: str) -> None:
|
||||
self._encoder = codecs.getencoder(encoding)
|
||||
|
||||
async def send(self, item: str) -> None:
|
||||
encoded = self._encoder(item, self.errors)[0]
|
||||
await self.transport_stream.send(encoded)
|
||||
|
||||
async def aclose(self) -> None:
|
||||
await self.transport_stream.aclose()
|
||||
|
||||
@property
|
||||
def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
|
||||
return self.transport_stream.extra_attributes
|
||||
|
||||
|
||||
@dataclass(eq=False)
|
||||
class TextStream(ObjectStream[str]):
|
||||
"""
|
||||
A bidirectional stream that decodes bytes to strings on receive and encodes strings to bytes on
|
||||
send.
|
||||
|
||||
Extra attributes will be provided from both streams, with the receive stream providing the
|
||||
values in case of a conflict.
|
||||
|
||||
:param AnyByteStream transport_stream: any bytes-based stream
|
||||
:param str encoding: character encoding to use for encoding/decoding strings to/from bytes
|
||||
(defaults to ``utf-8``)
|
||||
:param str errors: handling scheme for encoding errors (defaults to ``strict``; see the
|
||||
`codecs module documentation`_ for a comprehensive list of options)
|
||||
|
||||
.. _codecs module documentation: https://docs.python.org/3/library/codecs.html#codec-objects
|
||||
"""
|
||||
|
||||
transport_stream: AnyByteStream
|
||||
encoding: InitVar[str] = "utf-8"
|
||||
errors: InitVar[str] = "strict"
|
||||
_receive_stream: TextReceiveStream = field(init=False)
|
||||
_send_stream: TextSendStream = field(init=False)
|
||||
|
||||
def __post_init__(self, encoding: str, errors: str) -> None:
|
||||
self._receive_stream = TextReceiveStream(
|
||||
self.transport_stream, encoding=encoding, errors=errors
|
||||
)
|
||||
self._send_stream = TextSendStream(
|
||||
self.transport_stream, encoding=encoding, errors=errors
|
||||
)
|
||||
|
||||
async def receive(self) -> str:
|
||||
return await self._receive_stream.receive()
|
||||
|
||||
async def send(self, item: str) -> None:
|
||||
await self._send_stream.send(item)
|
||||
|
||||
async def send_eof(self) -> None:
|
||||
await self.transport_stream.send_eof()
|
||||
|
||||
async def aclose(self) -> None:
|
||||
await self._send_stream.aclose()
|
||||
await self._receive_stream.aclose()
|
||||
|
||||
@property
|
||||
def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
|
||||
return {
|
||||
**self._send_stream.extra_attributes,
|
||||
**self._receive_stream.extra_attributes,
|
||||
}
|
@ -0,0 +1,317 @@
|
||||
import logging
|
||||
import re
|
||||
import ssl
|
||||
from dataclasses import dataclass
|
||||
from functools import wraps
|
||||
from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, TypeVar, Union
|
||||
|
||||
from .. import (
|
||||
BrokenResourceError,
|
||||
EndOfStream,
|
||||
aclose_forcefully,
|
||||
get_cancelled_exc_class,
|
||||
)
|
||||
from .._core._typedattr import TypedAttributeSet, typed_attribute
|
||||
from ..abc import AnyByteStream, ByteStream, Listener, TaskGroup
|
||||
|
||||
T_Retval = TypeVar("T_Retval")
|
||||
_PCTRTT = Tuple[Tuple[str, str], ...]
|
||||
_PCTRTTT = Tuple[_PCTRTT, ...]
|
||||
|
||||
|
||||
class TLSAttribute(TypedAttributeSet):
|
||||
"""Contains Transport Layer Security related attributes."""
|
||||
|
||||
#: the selected ALPN protocol
|
||||
alpn_protocol: Optional[str] = typed_attribute()
|
||||
#: the channel binding for type ``tls-unique``
|
||||
channel_binding_tls_unique: bytes = typed_attribute()
|
||||
#: the selected cipher
|
||||
cipher: Tuple[str, str, int] = typed_attribute()
|
||||
#: the peer certificate in dictionary form (see :meth:`ssl.SSLSocket.getpeercert` for more
|
||||
#: information)
|
||||
peer_certificate: Optional[
|
||||
Dict[str, Union[str, _PCTRTTT, _PCTRTT]]
|
||||
] = typed_attribute()
|
||||
#: the peer certificate in binary form
|
||||
peer_certificate_binary: Optional[bytes] = typed_attribute()
|
||||
#: ``True`` if this is the server side of the connection
|
||||
server_side: bool = typed_attribute()
|
||||
#: ciphers shared between both ends of the TLS connection
|
||||
shared_ciphers: List[Tuple[str, str, int]] = typed_attribute()
|
||||
#: the :class:`~ssl.SSLObject` used for encryption
|
||||
ssl_object: ssl.SSLObject = typed_attribute()
|
||||
#: ``True`` if this stream does (and expects) a closing TLS handshake when the stream is being
|
||||
#: closed
|
||||
standard_compatible: bool = typed_attribute()
|
||||
#: the TLS protocol version (e.g. ``TLSv1.2``)
|
||||
tls_version: str = typed_attribute()
|
||||
|
||||
|
||||
@dataclass(eq=False)
|
||||
class TLSStream(ByteStream):
|
||||
"""
|
||||
A stream wrapper that encrypts all sent data and decrypts received data.
|
||||
|
||||
This class has no public initializer; use :meth:`wrap` instead.
|
||||
All extra attributes from :class:`~TLSAttribute` are supported.
|
||||
|
||||
:var AnyByteStream transport_stream: the wrapped stream
|
||||
|
||||
"""
|
||||
|
||||
transport_stream: AnyByteStream
|
||||
standard_compatible: bool
|
||||
_ssl_object: ssl.SSLObject
|
||||
_read_bio: ssl.MemoryBIO
|
||||
_write_bio: ssl.MemoryBIO
|
||||
|
||||
@classmethod
|
||||
async def wrap(
|
||||
cls,
|
||||
transport_stream: AnyByteStream,
|
||||
*,
|
||||
server_side: Optional[bool] = None,
|
||||
hostname: Optional[str] = None,
|
||||
ssl_context: Optional[ssl.SSLContext] = None,
|
||||
standard_compatible: bool = True,
|
||||
) -> "TLSStream":
|
||||
"""
|
||||
Wrap an existing stream with Transport Layer Security.
|
||||
|
||||
This performs a TLS handshake with the peer.
|
||||
|
||||
:param transport_stream: a bytes-transporting stream to wrap
|
||||
:param server_side: ``True`` if this is the server side of the connection, ``False`` if
|
||||
this is the client side (if omitted, will be set to ``False`` if ``hostname`` has been
|
||||
provided, ``False`` otherwise). Used only to create a default context when an explicit
|
||||
context has not been provided.
|
||||
:param hostname: host name of the peer (if host name checking is desired)
|
||||
:param ssl_context: the SSLContext object to use (if not provided, a secure default will be
|
||||
created)
|
||||
:param standard_compatible: if ``False``, skip the closing handshake when closing the
|
||||
connection, and don't raise an exception if the peer does the same
|
||||
:raises ~ssl.SSLError: if the TLS handshake fails
|
||||
|
||||
"""
|
||||
if server_side is None:
|
||||
server_side = not hostname
|
||||
|
||||
if not ssl_context:
|
||||
purpose = (
|
||||
ssl.Purpose.CLIENT_AUTH if server_side else ssl.Purpose.SERVER_AUTH
|
||||
)
|
||||
ssl_context = ssl.create_default_context(purpose)
|
||||
|
||||
# Re-enable detection of unexpected EOFs if it was disabled by Python
|
||||
if hasattr(ssl, "OP_IGNORE_UNEXPECTED_EOF"):
|
||||
ssl_context.options ^= ssl.OP_IGNORE_UNEXPECTED_EOF # type: ignore[attr-defined]
|
||||
|
||||
bio_in = ssl.MemoryBIO()
|
||||
bio_out = ssl.MemoryBIO()
|
||||
ssl_object = ssl_context.wrap_bio(
|
||||
bio_in, bio_out, server_side=server_side, server_hostname=hostname
|
||||
)
|
||||
wrapper = cls(
|
||||
transport_stream=transport_stream,
|
||||
standard_compatible=standard_compatible,
|
||||
_ssl_object=ssl_object,
|
||||
_read_bio=bio_in,
|
||||
_write_bio=bio_out,
|
||||
)
|
||||
await wrapper._call_sslobject_method(ssl_object.do_handshake)
|
||||
return wrapper
|
||||
|
||||
async def _call_sslobject_method(
|
||||
self, func: Callable[..., T_Retval], *args: object
|
||||
) -> T_Retval:
|
||||
while True:
|
||||
try:
|
||||
result = func(*args)
|
||||
except ssl.SSLWantReadError:
|
||||
try:
|
||||
# Flush any pending writes first
|
||||
if self._write_bio.pending:
|
||||
await self.transport_stream.send(self._write_bio.read())
|
||||
|
||||
data = await self.transport_stream.receive()
|
||||
except EndOfStream:
|
||||
self._read_bio.write_eof()
|
||||
except OSError as exc:
|
||||
self._read_bio.write_eof()
|
||||
self._write_bio.write_eof()
|
||||
raise BrokenResourceError from exc
|
||||
else:
|
||||
self._read_bio.write(data)
|
||||
except ssl.SSLWantWriteError:
|
||||
await self.transport_stream.send(self._write_bio.read())
|
||||
except ssl.SSLSyscallError as exc:
|
||||
self._read_bio.write_eof()
|
||||
self._write_bio.write_eof()
|
||||
raise BrokenResourceError from exc
|
||||
except ssl.SSLError as exc:
|
||||
self._read_bio.write_eof()
|
||||
self._write_bio.write_eof()
|
||||
if (
|
||||
isinstance(exc, ssl.SSLEOFError)
|
||||
or "UNEXPECTED_EOF_WHILE_READING" in exc.strerror
|
||||
):
|
||||
if self.standard_compatible:
|
||||
raise BrokenResourceError from exc
|
||||
else:
|
||||
raise EndOfStream from None
|
||||
|
||||
raise
|
||||
else:
|
||||
# Flush any pending writes first
|
||||
if self._write_bio.pending:
|
||||
await self.transport_stream.send(self._write_bio.read())
|
||||
|
||||
return result
|
||||
|
||||
async def unwrap(self) -> Tuple[AnyByteStream, bytes]:
|
||||
"""
|
||||
Does the TLS closing handshake.
|
||||
|
||||
:return: a tuple of (wrapped byte stream, bytes left in the read buffer)
|
||||
|
||||
"""
|
||||
await self._call_sslobject_method(self._ssl_object.unwrap)
|
||||
self._read_bio.write_eof()
|
||||
self._write_bio.write_eof()
|
||||
return self.transport_stream, self._read_bio.read()
|
||||
|
||||
async def aclose(self) -> None:
|
||||
if self.standard_compatible:
|
||||
try:
|
||||
await self.unwrap()
|
||||
except BaseException:
|
||||
await aclose_forcefully(self.transport_stream)
|
||||
raise
|
||||
|
||||
await self.transport_stream.aclose()
|
||||
|
||||
async def receive(self, max_bytes: int = 65536) -> bytes:
|
||||
data = await self._call_sslobject_method(self._ssl_object.read, max_bytes)
|
||||
if not data:
|
||||
raise EndOfStream
|
||||
|
||||
return data
|
||||
|
||||
async def send(self, item: bytes) -> None:
|
||||
await self._call_sslobject_method(self._ssl_object.write, item)
|
||||
|
||||
async def send_eof(self) -> None:
|
||||
tls_version = self.extra(TLSAttribute.tls_version)
|
||||
match = re.match(r"TLSv(\d+)(?:\.(\d+))?", tls_version)
|
||||
if match:
|
||||
major, minor = int(match.group(1)), int(match.group(2) or 0)
|
||||
if (major, minor) < (1, 3):
|
||||
raise NotImplementedError(
|
||||
f"send_eof() requires at least TLSv1.3; current "
|
||||
f"session uses {tls_version}"
|
||||
)
|
||||
|
||||
raise NotImplementedError(
|
||||
"send_eof() has not yet been implemented for TLS streams"
|
||||
)
|
||||
|
||||
@property
|
||||
def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
|
||||
return {
|
||||
**self.transport_stream.extra_attributes,
|
||||
TLSAttribute.alpn_protocol: self._ssl_object.selected_alpn_protocol,
|
||||
TLSAttribute.channel_binding_tls_unique: self._ssl_object.get_channel_binding,
|
||||
TLSAttribute.cipher: self._ssl_object.cipher,
|
||||
TLSAttribute.peer_certificate: lambda: self._ssl_object.getpeercert(False),
|
||||
TLSAttribute.peer_certificate_binary: lambda: self._ssl_object.getpeercert(
|
||||
True
|
||||
),
|
||||
TLSAttribute.server_side: lambda: self._ssl_object.server_side,
|
||||
TLSAttribute.shared_ciphers: lambda: self._ssl_object.shared_ciphers(),
|
||||
TLSAttribute.standard_compatible: lambda: self.standard_compatible,
|
||||
TLSAttribute.ssl_object: lambda: self._ssl_object,
|
||||
TLSAttribute.tls_version: self._ssl_object.version,
|
||||
}
|
||||
|
||||
|
||||
@dataclass(eq=False)
|
||||
class TLSListener(Listener[TLSStream]):
|
||||
"""
|
||||
A convenience listener that wraps another listener and auto-negotiates a TLS session on every
|
||||
accepted connection.
|
||||
|
||||
If the TLS handshake times out or raises an exception, :meth:`handle_handshake_error` is
|
||||
called to do whatever post-mortem processing is deemed necessary.
|
||||
|
||||
Supports only the :attr:`~TLSAttribute.standard_compatible` extra attribute.
|
||||
|
||||
:param Listener listener: the listener to wrap
|
||||
:param ssl_context: the SSL context object
|
||||
:param standard_compatible: a flag passed through to :meth:`TLSStream.wrap`
|
||||
:param handshake_timeout: time limit for the TLS handshake
|
||||
(passed to :func:`~anyio.fail_after`)
|
||||
"""
|
||||
|
||||
listener: Listener[Any]
|
||||
ssl_context: ssl.SSLContext
|
||||
standard_compatible: bool = True
|
||||
handshake_timeout: float = 30
|
||||
|
||||
@staticmethod
|
||||
async def handle_handshake_error(exc: BaseException, stream: AnyByteStream) -> None:
|
||||
f"""
|
||||
Handle an exception raised during the TLS handshake.
|
||||
|
||||
This method does 3 things:
|
||||
|
||||
#. Forcefully closes the original stream
|
||||
#. Logs the exception (unless it was a cancellation exception) using the ``{__name__}``
|
||||
logger
|
||||
#. Reraises the exception if it was a base exception or a cancellation exception
|
||||
|
||||
:param exc: the exception
|
||||
:param stream: the original stream
|
||||
|
||||
"""
|
||||
await aclose_forcefully(stream)
|
||||
|
||||
# Log all except cancellation exceptions
|
||||
if not isinstance(exc, get_cancelled_exc_class()):
|
||||
logging.getLogger(__name__).exception("Error during TLS handshake")
|
||||
|
||||
# Only reraise base exceptions and cancellation exceptions
|
||||
if not isinstance(exc, Exception) or isinstance(exc, get_cancelled_exc_class()):
|
||||
raise
|
||||
|
||||
async def serve(
|
||||
self,
|
||||
handler: Callable[[TLSStream], Any],
|
||||
task_group: Optional[TaskGroup] = None,
|
||||
) -> None:
|
||||
@wraps(handler)
|
||||
async def handler_wrapper(stream: AnyByteStream) -> None:
|
||||
from .. import fail_after
|
||||
|
||||
try:
|
||||
with fail_after(self.handshake_timeout):
|
||||
wrapped_stream = await TLSStream.wrap(
|
||||
stream,
|
||||
ssl_context=self.ssl_context,
|
||||
standard_compatible=self.standard_compatible,
|
||||
)
|
||||
except BaseException as exc:
|
||||
await self.handle_handshake_error(exc, stream)
|
||||
else:
|
||||
await handler(wrapped_stream)
|
||||
|
||||
await self.listener.serve(handler_wrapper, task_group)
|
||||
|
||||
async def aclose(self) -> None:
|
||||
await self.listener.aclose()
|
||||
|
||||
@property
|
||||
def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
|
||||
return {
|
||||
TLSAttribute.standard_compatible: lambda: self.standard_compatible,
|
||||
}
|
@ -0,0 +1,247 @@
|
||||
import os
|
||||
import pickle
|
||||
import subprocess
|
||||
import sys
|
||||
from collections import deque
|
||||
from importlib.util import module_from_spec, spec_from_file_location
|
||||
from typing import Callable, Deque, List, Optional, Set, Tuple, TypeVar, cast
|
||||
|
||||
from ._core._eventloop import current_time, get_asynclib, get_cancelled_exc_class
|
||||
from ._core._exceptions import BrokenWorkerProcess
|
||||
from ._core._subprocesses import open_process
|
||||
from ._core._synchronization import CapacityLimiter
|
||||
from ._core._tasks import CancelScope, fail_after
|
||||
from .abc import ByteReceiveStream, ByteSendStream, Process
|
||||
from .lowlevel import RunVar, checkpoint_if_cancelled
|
||||
from .streams.buffered import BufferedByteReceiveStream
|
||||
|
||||
WORKER_MAX_IDLE_TIME = 300 # 5 minutes
|
||||
|
||||
T_Retval = TypeVar("T_Retval")
|
||||
_process_pool_workers: RunVar[Set[Process]] = RunVar("_process_pool_workers")
|
||||
_process_pool_idle_workers: RunVar[Deque[Tuple[Process, float]]] = RunVar(
|
||||
"_process_pool_idle_workers"
|
||||
)
|
||||
_default_process_limiter: RunVar[CapacityLimiter] = RunVar("_default_process_limiter")
|
||||
|
||||
|
||||
async def run_sync(
|
||||
func: Callable[..., T_Retval],
|
||||
*args: object,
|
||||
cancellable: bool = False,
|
||||
limiter: Optional[CapacityLimiter] = None,
|
||||
) -> T_Retval:
|
||||
"""
|
||||
Call the given function with the given arguments in a worker process.
|
||||
|
||||
If the ``cancellable`` option is enabled and the task waiting for its completion is cancelled,
|
||||
the worker process running it will be abruptly terminated using SIGKILL (or
|
||||
``terminateProcess()`` on Windows).
|
||||
|
||||
:param func: a callable
|
||||
:param args: positional arguments for the callable
|
||||
:param cancellable: ``True`` to allow cancellation of the operation while it's running
|
||||
:param limiter: capacity limiter to use to limit the total amount of processes running
|
||||
(if omitted, the default limiter is used)
|
||||
:return: an awaitable that yields the return value of the function.
|
||||
|
||||
"""
|
||||
|
||||
async def send_raw_command(pickled_cmd: bytes) -> object:
|
||||
try:
|
||||
await stdin.send(pickled_cmd)
|
||||
response = await buffered.receive_until(b"\n", 50)
|
||||
status, length = response.split(b" ")
|
||||
if status not in (b"RETURN", b"EXCEPTION"):
|
||||
raise RuntimeError(
|
||||
f"Worker process returned unexpected response: {response!r}"
|
||||
)
|
||||
|
||||
pickled_response = await buffered.receive_exactly(int(length))
|
||||
except BaseException as exc:
|
||||
workers.discard(process)
|
||||
try:
|
||||
process.kill()
|
||||
with CancelScope(shield=True):
|
||||
await process.aclose()
|
||||
except ProcessLookupError:
|
||||
pass
|
||||
|
||||
if isinstance(exc, get_cancelled_exc_class()):
|
||||
raise
|
||||
else:
|
||||
raise BrokenWorkerProcess from exc
|
||||
|
||||
retval = pickle.loads(pickled_response)
|
||||
if status == b"EXCEPTION":
|
||||
assert isinstance(retval, BaseException)
|
||||
raise retval
|
||||
else:
|
||||
return retval
|
||||
|
||||
# First pickle the request before trying to reserve a worker process
|
||||
await checkpoint_if_cancelled()
|
||||
request = pickle.dumps(("run", func, args), protocol=pickle.HIGHEST_PROTOCOL)
|
||||
|
||||
# If this is the first run in this event loop thread, set up the necessary variables
|
||||
try:
|
||||
workers = _process_pool_workers.get()
|
||||
idle_workers = _process_pool_idle_workers.get()
|
||||
except LookupError:
|
||||
workers = set()
|
||||
idle_workers = deque()
|
||||
_process_pool_workers.set(workers)
|
||||
_process_pool_idle_workers.set(idle_workers)
|
||||
get_asynclib().setup_process_pool_exit_at_shutdown(workers)
|
||||
|
||||
async with (limiter or current_default_process_limiter()):
|
||||
# Pop processes from the pool (starting from the most recently used) until we find one that
|
||||
# hasn't exited yet
|
||||
process: Process
|
||||
while idle_workers:
|
||||
process, idle_since = idle_workers.pop()
|
||||
if process.returncode is None:
|
||||
stdin = cast(ByteSendStream, process.stdin)
|
||||
buffered = BufferedByteReceiveStream(
|
||||
cast(ByteReceiveStream, process.stdout)
|
||||
)
|
||||
|
||||
# Prune any other workers that have been idle for WORKER_MAX_IDLE_TIME seconds or
|
||||
# longer
|
||||
now = current_time()
|
||||
killed_processes: List[Process] = []
|
||||
while idle_workers:
|
||||
if now - idle_workers[0][1] < WORKER_MAX_IDLE_TIME:
|
||||
break
|
||||
|
||||
process, idle_since = idle_workers.popleft()
|
||||
process.kill()
|
||||
workers.remove(process)
|
||||
killed_processes.append(process)
|
||||
|
||||
with CancelScope(shield=True):
|
||||
for process in killed_processes:
|
||||
await process.aclose()
|
||||
|
||||
break
|
||||
|
||||
workers.remove(process)
|
||||
else:
|
||||
command = [sys.executable, "-u", "-m", __name__]
|
||||
process = await open_process(
|
||||
command, stdin=subprocess.PIPE, stdout=subprocess.PIPE
|
||||
)
|
||||
try:
|
||||
stdin = cast(ByteSendStream, process.stdin)
|
||||
buffered = BufferedByteReceiveStream(
|
||||
cast(ByteReceiveStream, process.stdout)
|
||||
)
|
||||
with fail_after(20):
|
||||
message = await buffered.receive(6)
|
||||
|
||||
if message != b"READY\n":
|
||||
raise BrokenWorkerProcess(
|
||||
f"Worker process returned unexpected response: {message!r}"
|
||||
)
|
||||
|
||||
main_module_path = getattr(sys.modules["__main__"], "__file__", None)
|
||||
pickled = pickle.dumps(
|
||||
("init", sys.path, main_module_path),
|
||||
protocol=pickle.HIGHEST_PROTOCOL,
|
||||
)
|
||||
await send_raw_command(pickled)
|
||||
except (BrokenWorkerProcess, get_cancelled_exc_class()):
|
||||
raise
|
||||
except BaseException as exc:
|
||||
process.kill()
|
||||
raise BrokenWorkerProcess(
|
||||
"Error during worker process initialization"
|
||||
) from exc
|
||||
|
||||
workers.add(process)
|
||||
|
||||
with CancelScope(shield=not cancellable):
|
||||
try:
|
||||
return cast(T_Retval, await send_raw_command(request))
|
||||
finally:
|
||||
if process in workers:
|
||||
idle_workers.append((process, current_time()))
|
||||
|
||||
|
||||
def current_default_process_limiter() -> CapacityLimiter:
|
||||
"""
|
||||
Return the capacity limiter that is used by default to limit the number of worker processes.
|
||||
|
||||
:return: a capacity limiter object
|
||||
|
||||
"""
|
||||
try:
|
||||
return _default_process_limiter.get()
|
||||
except LookupError:
|
||||
limiter = CapacityLimiter(os.cpu_count() or 2)
|
||||
_default_process_limiter.set(limiter)
|
||||
return limiter
|
||||
|
||||
|
||||
def process_worker() -> None:
|
||||
# Redirect standard streams to os.devnull so that user code won't interfere with the
|
||||
# parent-worker communication
|
||||
stdin = sys.stdin
|
||||
stdout = sys.stdout
|
||||
sys.stdin = open(os.devnull)
|
||||
sys.stdout = open(os.devnull, "w")
|
||||
|
||||
stdout.buffer.write(b"READY\n")
|
||||
while True:
|
||||
retval = exception = None
|
||||
try:
|
||||
command, *args = pickle.load(stdin.buffer)
|
||||
except EOFError:
|
||||
return
|
||||
except BaseException as exc:
|
||||
exception = exc
|
||||
else:
|
||||
if command == "run":
|
||||
func, args = args
|
||||
try:
|
||||
retval = func(*args)
|
||||
except BaseException as exc:
|
||||
exception = exc
|
||||
elif command == "init":
|
||||
main_module_path: Optional[str]
|
||||
sys.path, main_module_path = args
|
||||
del sys.modules["__main__"]
|
||||
if main_module_path:
|
||||
# Load the parent's main module but as __mp_main__ instead of __main__
|
||||
# (like multiprocessing does) to avoid infinite recursion
|
||||
try:
|
||||
spec = spec_from_file_location("__mp_main__", main_module_path)
|
||||
if spec and spec.loader:
|
||||
main = module_from_spec(spec)
|
||||
spec.loader.exec_module(main)
|
||||
sys.modules["__main__"] = main
|
||||
except BaseException as exc:
|
||||
exception = exc
|
||||
|
||||
try:
|
||||
if exception is not None:
|
||||
status = b"EXCEPTION"
|
||||
pickled = pickle.dumps(exception, pickle.HIGHEST_PROTOCOL)
|
||||
else:
|
||||
status = b"RETURN"
|
||||
pickled = pickle.dumps(retval, pickle.HIGHEST_PROTOCOL)
|
||||
except BaseException as exc:
|
||||
exception = exc
|
||||
status = b"EXCEPTION"
|
||||
pickled = pickle.dumps(exc, pickle.HIGHEST_PROTOCOL)
|
||||
|
||||
stdout.buffer.write(b"%s %d\n" % (status, len(pickled)))
|
||||
stdout.buffer.write(pickled)
|
||||
|
||||
# Respect SIGTERM
|
||||
if isinstance(exception, SystemExit):
|
||||
raise exception
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
process_worker()
|
@ -0,0 +1,65 @@
|
||||
from typing import Callable, Optional, TypeVar
|
||||
from warnings import warn
|
||||
|
||||
from ._core._eventloop import get_asynclib
|
||||
from .abc import CapacityLimiter
|
||||
|
||||
T_Retval = TypeVar("T_Retval")
|
||||
|
||||
|
||||
async def run_sync(
|
||||
func: Callable[..., T_Retval],
|
||||
*args: object,
|
||||
cancellable: bool = False,
|
||||
limiter: Optional[CapacityLimiter] = None
|
||||
) -> T_Retval:
|
||||
"""
|
||||
Call the given function with the given arguments in a worker thread.
|
||||
|
||||
If the ``cancellable`` option is enabled and the task waiting for its completion is cancelled,
|
||||
the thread will still run its course but its return value (or any raised exception) will be
|
||||
ignored.
|
||||
|
||||
:param func: a callable
|
||||
:param args: positional arguments for the callable
|
||||
:param cancellable: ``True`` to allow cancellation of the operation
|
||||
:param limiter: capacity limiter to use to limit the total amount of threads running
|
||||
(if omitted, the default limiter is used)
|
||||
:return: an awaitable that yields the return value of the function.
|
||||
|
||||
"""
|
||||
return await get_asynclib().run_sync_in_worker_thread(
|
||||
func, *args, cancellable=cancellable, limiter=limiter
|
||||
)
|
||||
|
||||
|
||||
async def run_sync_in_worker_thread(
|
||||
func: Callable[..., T_Retval],
|
||||
*args: object,
|
||||
cancellable: bool = False,
|
||||
limiter: Optional[CapacityLimiter] = None
|
||||
) -> T_Retval:
|
||||
warn(
|
||||
"run_sync_in_worker_thread() has been deprecated, use anyio.to_thread.run_sync() instead",
|
||||
DeprecationWarning,
|
||||
)
|
||||
return await run_sync(func, *args, cancellable=cancellable, limiter=limiter)
|
||||
|
||||
|
||||
def current_default_thread_limiter() -> CapacityLimiter:
|
||||
"""
|
||||
Return the capacity limiter that is used by default to limit the number of concurrent threads.
|
||||
|
||||
:return: a capacity limiter object
|
||||
|
||||
"""
|
||||
return get_asynclib().current_default_thread_limiter()
|
||||
|
||||
|
||||
def current_default_worker_thread_limiter() -> CapacityLimiter:
|
||||
warn(
|
||||
"current_default_worker_thread_limiter() has been deprecated, "
|
||||
"use anyio.to_thread.current_default_thread_limiter() instead",
|
||||
DeprecationWarning,
|
||||
)
|
||||
return current_default_thread_limiter()
|
@ -0,0 +1 @@
|
||||
pip
|
@ -0,0 +1,28 @@
|
||||
Copyright 2014 Pallets
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
met:
|
||||
|
||||
1. Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
2. Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
3. Neither the name of the copyright holder nor the names of its
|
||||
contributors may be used to endorse or promote products derived from
|
||||
this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
|
||||
PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED
|
||||
TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
@ -0,0 +1,111 @@
|
||||
Metadata-Version: 2.1
|
||||
Name: click
|
||||
Version: 8.1.3
|
||||
Summary: Composable command line interface toolkit
|
||||
Home-page: https://palletsprojects.com/p/click/
|
||||
Author: Armin Ronacher
|
||||
Author-email: armin.ronacher@active-4.com
|
||||
Maintainer: Pallets
|
||||
Maintainer-email: contact@palletsprojects.com
|
||||
License: BSD-3-Clause
|
||||
Project-URL: Donate, https://palletsprojects.com/donate
|
||||
Project-URL: Documentation, https://click.palletsprojects.com/
|
||||
Project-URL: Changes, https://click.palletsprojects.com/changes/
|
||||
Project-URL: Source Code, https://github.com/pallets/click/
|
||||
Project-URL: Issue Tracker, https://github.com/pallets/click/issues/
|
||||
Project-URL: Twitter, https://twitter.com/PalletsTeam
|
||||
Project-URL: Chat, https://discord.gg/pallets
|
||||
Platform: UNKNOWN
|
||||
Classifier: Development Status :: 5 - Production/Stable
|
||||
Classifier: Intended Audience :: Developers
|
||||
Classifier: License :: OSI Approved :: BSD License
|
||||
Classifier: Operating System :: OS Independent
|
||||
Classifier: Programming Language :: Python
|
||||
Requires-Python: >=3.7
|
||||
Description-Content-Type: text/x-rst
|
||||
License-File: LICENSE.rst
|
||||
Requires-Dist: colorama ; platform_system == "Windows"
|
||||
Requires-Dist: importlib-metadata ; python_version < "3.8"
|
||||
|
||||
\$ click\_
|
||||
==========
|
||||
|
||||
Click is a Python package for creating beautiful command line interfaces
|
||||
in a composable way with as little code as necessary. It's the "Command
|
||||
Line Interface Creation Kit". It's highly configurable but comes with
|
||||
sensible defaults out of the box.
|
||||
|
||||
It aims to make the process of writing command line tools quick and fun
|
||||
while also preventing any frustration caused by the inability to
|
||||
implement an intended CLI API.
|
||||
|
||||
Click in three points:
|
||||
|
||||
- Arbitrary nesting of commands
|
||||
- Automatic help page generation
|
||||
- Supports lazy loading of subcommands at runtime
|
||||
|
||||
|
||||
Installing
|
||||
----------
|
||||
|
||||
Install and update using `pip`_:
|
||||
|
||||
.. code-block:: text
|
||||
|
||||
$ pip install -U click
|
||||
|
||||
.. _pip: https://pip.pypa.io/en/stable/getting-started/
|
||||
|
||||
|
||||
A Simple Example
|
||||
----------------
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import click
|
||||
|
||||
@click.command()
|
||||
@click.option("--count", default=1, help="Number of greetings.")
|
||||
@click.option("--name", prompt="Your name", help="The person to greet.")
|
||||
def hello(count, name):
|
||||
"""Simple program that greets NAME for a total of COUNT times."""
|
||||
for _ in range(count):
|
||||
click.echo(f"Hello, {name}!")
|
||||
|
||||
if __name__ == '__main__':
|
||||
hello()
|
||||
|
||||
.. code-block:: text
|
||||
|
||||
$ python hello.py --count=3
|
||||
Your name: Click
|
||||
Hello, Click!
|
||||
Hello, Click!
|
||||
Hello, Click!
|
||||
|
||||
|
||||
Donate
|
||||
------
|
||||
|
||||
The Pallets organization develops and supports Click and other popular
|
||||
packages. In order to grow the community of contributors and users, and
|
||||
allow the maintainers to devote more time to the projects, `please
|
||||
donate today`_.
|
||||
|
||||
.. _please donate today: https://palletsprojects.com/donate
|
||||
|
||||
|
||||
Links
|
||||
-----
|
||||
|
||||
- Documentation: https://click.palletsprojects.com/
|
||||
- Changes: https://click.palletsprojects.com/changes/
|
||||
- PyPI Releases: https://pypi.org/project/click/
|
||||
- Source Code: https://github.com/pallets/click
|
||||
- Issue Tracker: https://github.com/pallets/click/issues
|
||||
- Website: https://palletsprojects.com/p/click
|
||||
- Twitter: https://twitter.com/PalletsTeam
|
||||
- Chat: https://discord.gg/pallets
|
||||
|
||||
|
@ -0,0 +1,39 @@
|
||||
click-8.1.3.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4
|
||||
click-8.1.3.dist-info/LICENSE.rst,sha256=morRBqOU6FO_4h9C9OctWSgZoigF2ZG18ydQKSkrZY0,1475
|
||||
click-8.1.3.dist-info/METADATA,sha256=tFJIX5lOjx7c5LjZbdTPFVDJSgyv9F74XY0XCPp_gnc,3247
|
||||
click-8.1.3.dist-info/RECORD,,
|
||||
click-8.1.3.dist-info/WHEEL,sha256=G16H4A3IeoQmnOrYV4ueZGKSjhipXx8zc8nu9FGlvMA,92
|
||||
click-8.1.3.dist-info/top_level.txt,sha256=J1ZQogalYS4pphY_lPECoNMfw0HzTSrZglC4Yfwo4xA,6
|
||||
click/__init__.py,sha256=rQBLutqg-z6m8nOzivIfigDn_emijB_dKv9BZ2FNi5s,3138
|
||||
click/__pycache__/__init__.cpython-37.pyc,,
|
||||
click/__pycache__/_compat.cpython-37.pyc,,
|
||||
click/__pycache__/_termui_impl.cpython-37.pyc,,
|
||||
click/__pycache__/_textwrap.cpython-37.pyc,,
|
||||
click/__pycache__/_winconsole.cpython-37.pyc,,
|
||||
click/__pycache__/core.cpython-37.pyc,,
|
||||
click/__pycache__/decorators.cpython-37.pyc,,
|
||||
click/__pycache__/exceptions.cpython-37.pyc,,
|
||||
click/__pycache__/formatting.cpython-37.pyc,,
|
||||
click/__pycache__/globals.cpython-37.pyc,,
|
||||
click/__pycache__/parser.cpython-37.pyc,,
|
||||
click/__pycache__/shell_completion.cpython-37.pyc,,
|
||||
click/__pycache__/termui.cpython-37.pyc,,
|
||||
click/__pycache__/testing.cpython-37.pyc,,
|
||||
click/__pycache__/types.cpython-37.pyc,,
|
||||
click/__pycache__/utils.cpython-37.pyc,,
|
||||
click/_compat.py,sha256=JIHLYs7Jzz4KT9t-ds4o4jBzLjnwCiJQKqur-5iwCKI,18810
|
||||
click/_termui_impl.py,sha256=qK6Cfy4mRFxvxE8dya8RBhLpSC8HjF-lvBc6aNrPdwg,23451
|
||||
click/_textwrap.py,sha256=10fQ64OcBUMuK7mFvh8363_uoOxPlRItZBmKzRJDgoY,1353
|
||||
click/_winconsole.py,sha256=5ju3jQkcZD0W27WEMGqmEP4y_crUVzPCqsX_FYb7BO0,7860
|
||||
click/core.py,sha256=mz87bYEKzIoNYEa56BFAiOJnvt1Y0L-i7wD4_ZecieE,112782
|
||||
click/decorators.py,sha256=yo3zvzgUm5q7h5CXjyV6q3h_PJAiUaem178zXwdWUFI,16350
|
||||
click/exceptions.py,sha256=7gDaLGuFZBeCNwY9ERMsF2-Z3R9Fvq09Zc6IZSKjseo,9167
|
||||
click/formatting.py,sha256=Frf0-5W33-loyY_i9qrwXR8-STnW3m5gvyxLVUdyxyk,9706
|
||||
click/globals.py,sha256=TP-qM88STzc7f127h35TD_v920FgfOD2EwzqA0oE8XU,1961
|
||||
click/parser.py,sha256=cAEt1uQR8gq3-S9ysqbVU-fdAZNvilxw4ReJ_T1OQMk,19044
|
||||
click/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
||||
click/shell_completion.py,sha256=qOp_BeC9esEOSZKyu5G7RIxEUaLsXUX-mTb7hB1r4QY,18018
|
||||
click/termui.py,sha256=ACBQVOvFCTSqtD5VREeCAdRtlHd-Imla-Lte4wSfMjA,28355
|
||||
click/testing.py,sha256=ptpMYgRY7dVfE3UDgkgwayu9ePw98sQI3D7zZXiCpj4,16063
|
||||
click/types.py,sha256=rEb1aZSQKq3ciCMmjpG2Uva9vk498XRL7ThrcK2GRss,35805
|
||||
click/utils.py,sha256=33D6E7poH_nrKB-xr-UyDEXnxOcCiQqxuRLtrqeVv6o,18682
|
@ -0,0 +1,5 @@
|
||||
Wheel-Version: 1.0
|
||||
Generator: bdist_wheel (0.37.1)
|
||||
Root-Is-Purelib: true
|
||||
Tag: py3-none-any
|
||||
|
@ -0,0 +1 @@
|
||||
click
|
@ -0,0 +1,73 @@
|
||||
"""
|
||||
Click is a simple Python module inspired by the stdlib optparse to make
|
||||
writing command line scripts fun. Unlike other modules, it's based
|
||||
around a simple API that does not come with too much magic and is
|
||||
composable.
|
||||
"""
|
||||
from .core import Argument as Argument
|
||||
from .core import BaseCommand as BaseCommand
|
||||
from .core import Command as Command
|
||||
from .core import CommandCollection as CommandCollection
|
||||
from .core import Context as Context
|
||||
from .core import Group as Group
|
||||
from .core import MultiCommand as MultiCommand
|
||||
from .core import Option as Option
|
||||
from .core import Parameter as Parameter
|
||||
from .decorators import argument as argument
|
||||
from .decorators import command as command
|
||||
from .decorators import confirmation_option as confirmation_option
|
||||
from .decorators import group as group
|
||||
from .decorators import help_option as help_option
|
||||
from .decorators import make_pass_decorator as make_pass_decorator
|
||||
from .decorators import option as option
|
||||
from .decorators import pass_context as pass_context
|
||||
from .decorators import pass_obj as pass_obj
|
||||
from .decorators import password_option as password_option
|
||||
from .decorators import version_option as version_option
|
||||
from .exceptions import Abort as Abort
|
||||
from .exceptions import BadArgumentUsage as BadArgumentUsage
|
||||
from .exceptions import BadOptionUsage as BadOptionUsage
|
||||
from .exceptions import BadParameter as BadParameter
|
||||
from .exceptions import ClickException as ClickException
|
||||
from .exceptions import FileError as FileError
|
||||
from .exceptions import MissingParameter as MissingParameter
|
||||
from .exceptions import NoSuchOption as NoSuchOption
|
||||
from .exceptions import UsageError as UsageError
|
||||
from .formatting import HelpFormatter as HelpFormatter
|
||||
from .formatting import wrap_text as wrap_text
|
||||
from .globals import get_current_context as get_current_context
|
||||
from .parser import OptionParser as OptionParser
|
||||
from .termui import clear as clear
|
||||
from .termui import confirm as confirm
|
||||
from .termui import echo_via_pager as echo_via_pager
|
||||
from .termui import edit as edit
|
||||
from .termui import getchar as getchar
|
||||
from .termui import launch as launch
|
||||
from .termui import pause as pause
|
||||
from .termui import progressbar as progressbar
|
||||
from .termui import prompt as prompt
|
||||
from .termui import secho as secho
|
||||
from .termui import style as style
|
||||
from .termui import unstyle as unstyle
|
||||
from .types import BOOL as BOOL
|
||||
from .types import Choice as Choice
|
||||
from .types import DateTime as DateTime
|
||||
from .types import File as File
|
||||
from .types import FLOAT as FLOAT
|
||||
from .types import FloatRange as FloatRange
|
||||
from .types import INT as INT
|
||||
from .types import IntRange as IntRange
|
||||
from .types import ParamType as ParamType
|
||||
from .types import Path as Path
|
||||
from .types import STRING as STRING
|
||||
from .types import Tuple as Tuple
|
||||
from .types import UNPROCESSED as UNPROCESSED
|
||||
from .types import UUID as UUID
|
||||
from .utils import echo as echo
|
||||
from .utils import format_filename as format_filename
|
||||
from .utils import get_app_dir as get_app_dir
|
||||
from .utils import get_binary_stream as get_binary_stream
|
||||
from .utils import get_text_stream as get_text_stream
|
||||
from .utils import open_file as open_file
|
||||
|
||||
__version__ = "8.1.3"
|
@ -0,0 +1,626 @@
|
||||
import codecs
|
||||
import io
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import typing as t
|
||||
from weakref import WeakKeyDictionary
|
||||
|
||||
CYGWIN = sys.platform.startswith("cygwin")
|
||||
MSYS2 = sys.platform.startswith("win") and ("GCC" in sys.version)
|
||||
# Determine local App Engine environment, per Google's own suggestion
|
||||
APP_ENGINE = "APPENGINE_RUNTIME" in os.environ and "Development/" in os.environ.get(
|
||||
"SERVER_SOFTWARE", ""
|
||||
)
|
||||
WIN = sys.platform.startswith("win") and not APP_ENGINE and not MSYS2
|
||||
auto_wrap_for_ansi: t.Optional[t.Callable[[t.TextIO], t.TextIO]] = None
|
||||
_ansi_re = re.compile(r"\033\[[;?0-9]*[a-zA-Z]")
|
||||
|
||||
|
||||
def get_filesystem_encoding() -> str:
|
||||
return sys.getfilesystemencoding() or sys.getdefaultencoding()
|
||||
|
||||
|
||||
def _make_text_stream(
|
||||
stream: t.BinaryIO,
|
||||
encoding: t.Optional[str],
|
||||
errors: t.Optional[str],
|
||||
force_readable: bool = False,
|
||||
force_writable: bool = False,
|
||||
) -> t.TextIO:
|
||||
if encoding is None:
|
||||
encoding = get_best_encoding(stream)
|
||||
if errors is None:
|
||||
errors = "replace"
|
||||
return _NonClosingTextIOWrapper(
|
||||
stream,
|
||||
encoding,
|
||||
errors,
|
||||
line_buffering=True,
|
||||
force_readable=force_readable,
|
||||
force_writable=force_writable,
|
||||
)
|
||||
|
||||
|
||||
def is_ascii_encoding(encoding: str) -> bool:
|
||||
"""Checks if a given encoding is ascii."""
|
||||
try:
|
||||
return codecs.lookup(encoding).name == "ascii"
|
||||
except LookupError:
|
||||
return False
|
||||
|
||||
|
||||
def get_best_encoding(stream: t.IO) -> str:
|
||||
"""Returns the default stream encoding if not found."""
|
||||
rv = getattr(stream, "encoding", None) or sys.getdefaultencoding()
|
||||
if is_ascii_encoding(rv):
|
||||
return "utf-8"
|
||||
return rv
|
||||
|
||||
|
||||
class _NonClosingTextIOWrapper(io.TextIOWrapper):
|
||||
def __init__(
|
||||
self,
|
||||
stream: t.BinaryIO,
|
||||
encoding: t.Optional[str],
|
||||
errors: t.Optional[str],
|
||||
force_readable: bool = False,
|
||||
force_writable: bool = False,
|
||||
**extra: t.Any,
|
||||
) -> None:
|
||||
self._stream = stream = t.cast(
|
||||
t.BinaryIO, _FixupStream(stream, force_readable, force_writable)
|
||||
)
|
||||
super().__init__(stream, encoding, errors, **extra)
|
||||
|
||||
def __del__(self) -> None:
|
||||
try:
|
||||
self.detach()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def isatty(self) -> bool:
|
||||
# https://bitbucket.org/pypy/pypy/issue/1803
|
||||
return self._stream.isatty()
|
||||
|
||||
|
||||
class _FixupStream:
|
||||
"""The new io interface needs more from streams than streams
|
||||
traditionally implement. As such, this fix-up code is necessary in
|
||||
some circumstances.
|
||||
|
||||
The forcing of readable and writable flags are there because some tools
|
||||
put badly patched objects on sys (one such offender are certain version
|
||||
of jupyter notebook).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
stream: t.BinaryIO,
|
||||
force_readable: bool = False,
|
||||
force_writable: bool = False,
|
||||
):
|
||||
self._stream = stream
|
||||
self._force_readable = force_readable
|
||||
self._force_writable = force_writable
|
||||
|
||||
def __getattr__(self, name: str) -> t.Any:
|
||||
return getattr(self._stream, name)
|
||||
|
||||
def read1(self, size: int) -> bytes:
|
||||
f = getattr(self._stream, "read1", None)
|
||||
|
||||
if f is not None:
|
||||
return t.cast(bytes, f(size))
|
||||
|
||||
return self._stream.read(size)
|
||||
|
||||
def readable(self) -> bool:
|
||||
if self._force_readable:
|
||||
return True
|
||||
x = getattr(self._stream, "readable", None)
|
||||
if x is not None:
|
||||
return t.cast(bool, x())
|
||||
try:
|
||||
self._stream.read(0)
|
||||
except Exception:
|
||||
return False
|
||||
return True
|
||||
|
||||
def writable(self) -> bool:
|
||||
if self._force_writable:
|
||||
return True
|
||||
x = getattr(self._stream, "writable", None)
|
||||
if x is not None:
|
||||
return t.cast(bool, x())
|
||||
try:
|
||||
self._stream.write("") # type: ignore
|
||||
except Exception:
|
||||
try:
|
||||
self._stream.write(b"")
|
||||
except Exception:
|
||||
return False
|
||||
return True
|
||||
|
||||
def seekable(self) -> bool:
|
||||
x = getattr(self._stream, "seekable", None)
|
||||
if x is not None:
|
||||
return t.cast(bool, x())
|
||||
try:
|
||||
self._stream.seek(self._stream.tell())
|
||||
except Exception:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _is_binary_reader(stream: t.IO, default: bool = False) -> bool:
|
||||
try:
|
||||
return isinstance(stream.read(0), bytes)
|
||||
except Exception:
|
||||
return default
|
||||
# This happens in some cases where the stream was already
|
||||
# closed. In this case, we assume the default.
|
||||
|
||||
|
||||
def _is_binary_writer(stream: t.IO, default: bool = False) -> bool:
|
||||
try:
|
||||
stream.write(b"")
|
||||
except Exception:
|
||||
try:
|
||||
stream.write("")
|
||||
return False
|
||||
except Exception:
|
||||
pass
|
||||
return default
|
||||
return True
|
||||
|
||||
|
||||
def _find_binary_reader(stream: t.IO) -> t.Optional[t.BinaryIO]:
|
||||
# We need to figure out if the given stream is already binary.
|
||||
# This can happen because the official docs recommend detaching
|
||||
# the streams to get binary streams. Some code might do this, so
|
||||
# we need to deal with this case explicitly.
|
||||
if _is_binary_reader(stream, False):
|
||||
return t.cast(t.BinaryIO, stream)
|
||||
|
||||
buf = getattr(stream, "buffer", None)
|
||||
|
||||
# Same situation here; this time we assume that the buffer is
|
||||
# actually binary in case it's closed.
|
||||
if buf is not None and _is_binary_reader(buf, True):
|
||||
return t.cast(t.BinaryIO, buf)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _find_binary_writer(stream: t.IO) -> t.Optional[t.BinaryIO]:
|
||||
# We need to figure out if the given stream is already binary.
|
||||
# This can happen because the official docs recommend detaching
|
||||
# the streams to get binary streams. Some code might do this, so
|
||||
# we need to deal with this case explicitly.
|
||||
if _is_binary_writer(stream, False):
|
||||
return t.cast(t.BinaryIO, stream)
|
||||
|
||||
buf = getattr(stream, "buffer", None)
|
||||
|
||||
# Same situation here; this time we assume that the buffer is
|
||||
# actually binary in case it's closed.
|
||||
if buf is not None and _is_binary_writer(buf, True):
|
||||
return t.cast(t.BinaryIO, buf)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _stream_is_misconfigured(stream: t.TextIO) -> bool:
|
||||
"""A stream is misconfigured if its encoding is ASCII."""
|
||||
# If the stream does not have an encoding set, we assume it's set
|
||||
# to ASCII. This appears to happen in certain unittest
|
||||
# environments. It's not quite clear what the correct behavior is
|
||||
# but this at least will force Click to recover somehow.
|
||||
return is_ascii_encoding(getattr(stream, "encoding", None) or "ascii")
|
||||
|
||||
|
||||
def _is_compat_stream_attr(stream: t.TextIO, attr: str, value: t.Optional[str]) -> bool:
|
||||
"""A stream attribute is compatible if it is equal to the
|
||||
desired value or the desired value is unset and the attribute
|
||||
has a value.
|
||||
"""
|
||||
stream_value = getattr(stream, attr, None)
|
||||
return stream_value == value or (value is None and stream_value is not None)
|
||||
|
||||
|
||||
def _is_compatible_text_stream(
|
||||
stream: t.TextIO, encoding: t.Optional[str], errors: t.Optional[str]
|
||||
) -> bool:
|
||||
"""Check if a stream's encoding and errors attributes are
|
||||
compatible with the desired values.
|
||||
"""
|
||||
return _is_compat_stream_attr(
|
||||
stream, "encoding", encoding
|
||||
) and _is_compat_stream_attr(stream, "errors", errors)
|
||||
|
||||
|
||||
def _force_correct_text_stream(
|
||||
text_stream: t.IO,
|
||||
encoding: t.Optional[str],
|
||||
errors: t.Optional[str],
|
||||
is_binary: t.Callable[[t.IO, bool], bool],
|
||||
find_binary: t.Callable[[t.IO], t.Optional[t.BinaryIO]],
|
||||
force_readable: bool = False,
|
||||
force_writable: bool = False,
|
||||
) -> t.TextIO:
|
||||
if is_binary(text_stream, False):
|
||||
binary_reader = t.cast(t.BinaryIO, text_stream)
|
||||
else:
|
||||
text_stream = t.cast(t.TextIO, text_stream)
|
||||
# If the stream looks compatible, and won't default to a
|
||||
# misconfigured ascii encoding, return it as-is.
|
||||
if _is_compatible_text_stream(text_stream, encoding, errors) and not (
|
||||
encoding is None and _stream_is_misconfigured(text_stream)
|
||||
):
|
||||
return text_stream
|
||||
|
||||
# Otherwise, get the underlying binary reader.
|
||||
possible_binary_reader = find_binary(text_stream)
|
||||
|
||||
# If that's not possible, silently use the original reader
|
||||
# and get mojibake instead of exceptions.
|
||||
if possible_binary_reader is None:
|
||||
return text_stream
|
||||
|
||||
binary_reader = possible_binary_reader
|
||||
|
||||
# Default errors to replace instead of strict in order to get
|
||||
# something that works.
|
||||
if errors is None:
|
||||
errors = "replace"
|
||||
|
||||
# Wrap the binary stream in a text stream with the correct
|
||||
# encoding parameters.
|
||||
return _make_text_stream(
|
||||
binary_reader,
|
||||
encoding,
|
||||
errors,
|
||||
force_readable=force_readable,
|
||||
force_writable=force_writable,
|
||||
)
|
||||
|
||||
|
||||
def _force_correct_text_reader(
|
||||
text_reader: t.IO,
|
||||
encoding: t.Optional[str],
|
||||
errors: t.Optional[str],
|
||||
force_readable: bool = False,
|
||||
) -> t.TextIO:
|
||||
return _force_correct_text_stream(
|
||||
text_reader,
|
||||
encoding,
|
||||
errors,
|
||||
_is_binary_reader,
|
||||
_find_binary_reader,
|
||||
force_readable=force_readable,
|
||||
)
|
||||
|
||||
|
||||
def _force_correct_text_writer(
|
||||
text_writer: t.IO,
|
||||
encoding: t.Optional[str],
|
||||
errors: t.Optional[str],
|
||||
force_writable: bool = False,
|
||||
) -> t.TextIO:
|
||||
return _force_correct_text_stream(
|
||||
text_writer,
|
||||
encoding,
|
||||
errors,
|
||||
_is_binary_writer,
|
||||
_find_binary_writer,
|
||||
force_writable=force_writable,
|
||||
)
|
||||
|
||||
|
||||
def get_binary_stdin() -> t.BinaryIO:
|
||||
reader = _find_binary_reader(sys.stdin)
|
||||
if reader is None:
|
||||
raise RuntimeError("Was not able to determine binary stream for sys.stdin.")
|
||||
return reader
|
||||
|
||||
|
||||
def get_binary_stdout() -> t.BinaryIO:
|
||||
writer = _find_binary_writer(sys.stdout)
|
||||
if writer is None:
|
||||
raise RuntimeError("Was not able to determine binary stream for sys.stdout.")
|
||||
return writer
|
||||
|
||||
|
||||
def get_binary_stderr() -> t.BinaryIO:
|
||||
writer = _find_binary_writer(sys.stderr)
|
||||
if writer is None:
|
||||
raise RuntimeError("Was not able to determine binary stream for sys.stderr.")
|
||||
return writer
|
||||
|
||||
|
||||
def get_text_stdin(
|
||||
encoding: t.Optional[str] = None, errors: t.Optional[str] = None
|
||||
) -> t.TextIO:
|
||||
rv = _get_windows_console_stream(sys.stdin, encoding, errors)
|
||||
if rv is not None:
|
||||
return rv
|
||||
return _force_correct_text_reader(sys.stdin, encoding, errors, force_readable=True)
|
||||
|
||||
|
||||
def get_text_stdout(
|
||||
encoding: t.Optional[str] = None, errors: t.Optional[str] = None
|
||||
) -> t.TextIO:
|
||||
rv = _get_windows_console_stream(sys.stdout, encoding, errors)
|
||||
if rv is not None:
|
||||
return rv
|
||||
return _force_correct_text_writer(sys.stdout, encoding, errors, force_writable=True)
|
||||
|
||||
|
||||
def get_text_stderr(
|
||||
encoding: t.Optional[str] = None, errors: t.Optional[str] = None
|
||||
) -> t.TextIO:
|
||||
rv = _get_windows_console_stream(sys.stderr, encoding, errors)
|
||||
if rv is not None:
|
||||
return rv
|
||||
return _force_correct_text_writer(sys.stderr, encoding, errors, force_writable=True)
|
||||
|
||||
|
||||
def _wrap_io_open(
|
||||
file: t.Union[str, os.PathLike, int],
|
||||
mode: str,
|
||||
encoding: t.Optional[str],
|
||||
errors: t.Optional[str],
|
||||
) -> t.IO:
|
||||
"""Handles not passing ``encoding`` and ``errors`` in binary mode."""
|
||||
if "b" in mode:
|
||||
return open(file, mode)
|
||||
|
||||
return open(file, mode, encoding=encoding, errors=errors)
|
||||
|
||||
|
||||
def open_stream(
|
||||
filename: str,
|
||||
mode: str = "r",
|
||||
encoding: t.Optional[str] = None,
|
||||
errors: t.Optional[str] = "strict",
|
||||
atomic: bool = False,
|
||||
) -> t.Tuple[t.IO, bool]:
|
||||
binary = "b" in mode
|
||||
|
||||
# Standard streams first. These are simple because they ignore the
|
||||
# atomic flag. Use fsdecode to handle Path("-").
|
||||
if os.fsdecode(filename) == "-":
|
||||
if any(m in mode for m in ["w", "a", "x"]):
|
||||
if binary:
|
||||
return get_binary_stdout(), False
|
||||
return get_text_stdout(encoding=encoding, errors=errors), False
|
||||
if binary:
|
||||
return get_binary_stdin(), False
|
||||
return get_text_stdin(encoding=encoding, errors=errors), False
|
||||
|
||||
# Non-atomic writes directly go out through the regular open functions.
|
||||
if not atomic:
|
||||
return _wrap_io_open(filename, mode, encoding, errors), True
|
||||
|
||||
# Some usability stuff for atomic writes
|
||||
if "a" in mode:
|
||||
raise ValueError(
|
||||
"Appending to an existing file is not supported, because that"
|
||||
" would involve an expensive `copy`-operation to a temporary"
|
||||
" file. Open the file in normal `w`-mode and copy explicitly"
|
||||
" if that's what you're after."
|
||||
)
|
||||
if "x" in mode:
|
||||
raise ValueError("Use the `overwrite`-parameter instead.")
|
||||
if "w" not in mode:
|
||||
raise ValueError("Atomic writes only make sense with `w`-mode.")
|
||||
|
||||
# Atomic writes are more complicated. They work by opening a file
|
||||
# as a proxy in the same folder and then using the fdopen
|
||||
# functionality to wrap it in a Python file. Then we wrap it in an
|
||||
# atomic file that moves the file over on close.
|
||||
import errno
|
||||
import random
|
||||
|
||||
try:
|
||||
perm: t.Optional[int] = os.stat(filename).st_mode
|
||||
except OSError:
|
||||
perm = None
|
||||
|
||||
flags = os.O_RDWR | os.O_CREAT | os.O_EXCL
|
||||
|
||||
if binary:
|
||||
flags |= getattr(os, "O_BINARY", 0)
|
||||
|
||||
while True:
|
||||
tmp_filename = os.path.join(
|
||||
os.path.dirname(filename),
|
||||
f".__atomic-write{random.randrange(1 << 32):08x}",
|
||||
)
|
||||
try:
|
||||
fd = os.open(tmp_filename, flags, 0o666 if perm is None else perm)
|
||||
break
|
||||
except OSError as e:
|
||||
if e.errno == errno.EEXIST or (
|
||||
os.name == "nt"
|
||||
and e.errno == errno.EACCES
|
||||
and os.path.isdir(e.filename)
|
||||
and os.access(e.filename, os.W_OK)
|
||||
):
|
||||
continue
|
||||
raise
|
||||
|
||||
if perm is not None:
|
||||
os.chmod(tmp_filename, perm) # in case perm includes bits in umask
|
||||
|
||||
f = _wrap_io_open(fd, mode, encoding, errors)
|
||||
af = _AtomicFile(f, tmp_filename, os.path.realpath(filename))
|
||||
return t.cast(t.IO, af), True
|
||||
|
||||
|
||||
class _AtomicFile:
|
||||
def __init__(self, f: t.IO, tmp_filename: str, real_filename: str) -> None:
|
||||
self._f = f
|
||||
self._tmp_filename = tmp_filename
|
||||
self._real_filename = real_filename
|
||||
self.closed = False
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self._real_filename
|
||||
|
||||
def close(self, delete: bool = False) -> None:
|
||||
if self.closed:
|
||||
return
|
||||
self._f.close()
|
||||
os.replace(self._tmp_filename, self._real_filename)
|
||||
self.closed = True
|
||||
|
||||
def __getattr__(self, name: str) -> t.Any:
|
||||
return getattr(self._f, name)
|
||||
|
||||
def __enter__(self) -> "_AtomicFile":
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, tb): # type: ignore
|
||||
self.close(delete=exc_type is not None)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return repr(self._f)
|
||||
|
||||
|
||||
def strip_ansi(value: str) -> str:
|
||||
return _ansi_re.sub("", value)
|
||||
|
||||
|
||||
def _is_jupyter_kernel_output(stream: t.IO) -> bool:
|
||||
while isinstance(stream, (_FixupStream, _NonClosingTextIOWrapper)):
|
||||
stream = stream._stream
|
||||
|
||||
return stream.__class__.__module__.startswith("ipykernel.")
|
||||
|
||||
|
||||
def should_strip_ansi(
|
||||
stream: t.Optional[t.IO] = None, color: t.Optional[bool] = None
|
||||
) -> bool:
|
||||
if color is None:
|
||||
if stream is None:
|
||||
stream = sys.stdin
|
||||
return not isatty(stream) and not _is_jupyter_kernel_output(stream)
|
||||
return not color
|
||||
|
||||
|
||||
# On Windows, wrap the output streams with colorama to support ANSI
|
||||
# color codes.
|
||||
# NOTE: double check is needed so mypy does not analyze this on Linux
|
||||
if sys.platform.startswith("win") and WIN:
|
||||
from ._winconsole import _get_windows_console_stream
|
||||
|
||||
def _get_argv_encoding() -> str:
|
||||
import locale
|
||||
|
||||
return locale.getpreferredencoding()
|
||||
|
||||
_ansi_stream_wrappers: t.MutableMapping[t.TextIO, t.TextIO] = WeakKeyDictionary()
|
||||
|
||||
def auto_wrap_for_ansi(
|
||||
stream: t.TextIO, color: t.Optional[bool] = None
|
||||
) -> t.TextIO:
|
||||
"""Support ANSI color and style codes on Windows by wrapping a
|
||||
stream with colorama.
|
||||
"""
|
||||
try:
|
||||
cached = _ansi_stream_wrappers.get(stream)
|
||||
except Exception:
|
||||
cached = None
|
||||
|
||||
if cached is not None:
|
||||
return cached
|
||||
|
||||
import colorama
|
||||
|
||||
strip = should_strip_ansi(stream, color)
|
||||
ansi_wrapper = colorama.AnsiToWin32(stream, strip=strip)
|
||||
rv = t.cast(t.TextIO, ansi_wrapper.stream)
|
||||
_write = rv.write
|
||||
|
||||
def _safe_write(s):
|
||||
try:
|
||||
return _write(s)
|
||||
except BaseException:
|
||||
ansi_wrapper.reset_all()
|
||||
raise
|
||||
|
||||
rv.write = _safe_write
|
||||
|
||||
try:
|
||||
_ansi_stream_wrappers[stream] = rv
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return rv
|
||||
|
||||
else:
|
||||
|
||||
def _get_argv_encoding() -> str:
|
||||
return getattr(sys.stdin, "encoding", None) or get_filesystem_encoding()
|
||||
|
||||
def _get_windows_console_stream(
|
||||
f: t.TextIO, encoding: t.Optional[str], errors: t.Optional[str]
|
||||
) -> t.Optional[t.TextIO]:
|
||||
return None
|
||||
|
||||
|
||||
def term_len(x: str) -> int:
|
||||
return len(strip_ansi(x))
|
||||
|
||||
|
||||
def isatty(stream: t.IO) -> bool:
|
||||
try:
|
||||
return stream.isatty()
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def _make_cached_stream_func(
|
||||
src_func: t.Callable[[], t.TextIO], wrapper_func: t.Callable[[], t.TextIO]
|
||||
) -> t.Callable[[], t.TextIO]:
|
||||
cache: t.MutableMapping[t.TextIO, t.TextIO] = WeakKeyDictionary()
|
||||
|
||||
def func() -> t.TextIO:
|
||||
stream = src_func()
|
||||
try:
|
||||
rv = cache.get(stream)
|
||||
except Exception:
|
||||
rv = None
|
||||
if rv is not None:
|
||||
return rv
|
||||
rv = wrapper_func()
|
||||
try:
|
||||
cache[stream] = rv
|
||||
except Exception:
|
||||
pass
|
||||
return rv
|
||||
|
||||
return func
|
||||
|
||||
|
||||
_default_text_stdin = _make_cached_stream_func(lambda: sys.stdin, get_text_stdin)
|
||||
_default_text_stdout = _make_cached_stream_func(lambda: sys.stdout, get_text_stdout)
|
||||
_default_text_stderr = _make_cached_stream_func(lambda: sys.stderr, get_text_stderr)
|
||||
|
||||
|
||||
binary_streams: t.Mapping[str, t.Callable[[], t.BinaryIO]] = {
|
||||
"stdin": get_binary_stdin,
|
||||
"stdout": get_binary_stdout,
|
||||
"stderr": get_binary_stderr,
|
||||
}
|
||||
|
||||
text_streams: t.Mapping[
|
||||
str, t.Callable[[t.Optional[str], t.Optional[str]], t.TextIO]
|
||||
] = {
|
||||
"stdin": get_text_stdin,
|
||||
"stdout": get_text_stdout,
|
||||
"stderr": get_text_stderr,
|
||||
}
|
@ -0,0 +1,717 @@
|
||||
"""
|
||||
This module contains implementations for the termui module. To keep the
|
||||
import time of Click down, some infrequently used functionality is
|
||||
placed in this module and only imported as needed.
|
||||
"""
|
||||
import contextlib
|
||||
import math
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import typing as t
|
||||
from gettext import gettext as _
|
||||
|
||||
from ._compat import _default_text_stdout
|
||||
from ._compat import CYGWIN
|
||||
from ._compat import get_best_encoding
|
||||
from ._compat import isatty
|
||||
from ._compat import open_stream
|
||||
from ._compat import strip_ansi
|
||||
from ._compat import term_len
|
||||
from ._compat import WIN
|
||||
from .exceptions import ClickException
|
||||
from .utils import echo
|
||||
|
||||
V = t.TypeVar("V")
|
||||
|
||||
if os.name == "nt":
|
||||
BEFORE_BAR = "\r"
|
||||
AFTER_BAR = "\n"
|
||||
else:
|
||||
BEFORE_BAR = "\r\033[?25l"
|
||||
AFTER_BAR = "\033[?25h\n"
|
||||
|
||||
|
||||
class ProgressBar(t.Generic[V]):
|
||||
def __init__(
|
||||
self,
|
||||
iterable: t.Optional[t.Iterable[V]],
|
||||
length: t.Optional[int] = None,
|
||||
fill_char: str = "#",
|
||||
empty_char: str = " ",
|
||||
bar_template: str = "%(bar)s",
|
||||
info_sep: str = " ",
|
||||
show_eta: bool = True,
|
||||
show_percent: t.Optional[bool] = None,
|
||||
show_pos: bool = False,
|
||||
item_show_func: t.Optional[t.Callable[[t.Optional[V]], t.Optional[str]]] = None,
|
||||
label: t.Optional[str] = None,
|
||||
file: t.Optional[t.TextIO] = None,
|
||||
color: t.Optional[bool] = None,
|
||||
update_min_steps: int = 1,
|
||||
width: int = 30,
|
||||
) -> None:
|
||||
self.fill_char = fill_char
|
||||
self.empty_char = empty_char
|
||||
self.bar_template = bar_template
|
||||
self.info_sep = info_sep
|
||||
self.show_eta = show_eta
|
||||
self.show_percent = show_percent
|
||||
self.show_pos = show_pos
|
||||
self.item_show_func = item_show_func
|
||||
self.label = label or ""
|
||||
if file is None:
|
||||
file = _default_text_stdout()
|
||||
self.file = file
|
||||
self.color = color
|
||||
self.update_min_steps = update_min_steps
|
||||
self._completed_intervals = 0
|
||||
self.width = width
|
||||
self.autowidth = width == 0
|
||||
|
||||
if length is None:
|
||||
from operator import length_hint
|
||||
|
||||
length = length_hint(iterable, -1)
|
||||
|
||||
if length == -1:
|
||||
length = None
|
||||
if iterable is None:
|
||||
if length is None:
|
||||
raise TypeError("iterable or length is required")
|
||||
iterable = t.cast(t.Iterable[V], range(length))
|
||||
self.iter = iter(iterable)
|
||||
self.length = length
|
||||
self.pos = 0
|
||||
self.avg: t.List[float] = []
|
||||
self.start = self.last_eta = time.time()
|
||||
self.eta_known = False
|
||||
self.finished = False
|
||||
self.max_width: t.Optional[int] = None
|
||||
self.entered = False
|
||||
self.current_item: t.Optional[V] = None
|
||||
self.is_hidden = not isatty(self.file)
|
||||
self._last_line: t.Optional[str] = None
|
||||
|
||||
def __enter__(self) -> "ProgressBar":
|
||||
self.entered = True
|
||||
self.render_progress()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, tb): # type: ignore
|
||||
self.render_finish()
|
||||
|
||||
def __iter__(self) -> t.Iterator[V]:
|
||||
if not self.entered:
|
||||
raise RuntimeError("You need to use progress bars in a with block.")
|
||||
self.render_progress()
|
||||
return self.generator()
|
||||
|
||||
def __next__(self) -> V:
|
||||
# Iteration is defined in terms of a generator function,
|
||||
# returned by iter(self); use that to define next(). This works
|
||||
# because `self.iter` is an iterable consumed by that generator,
|
||||
# so it is re-entry safe. Calling `next(self.generator())`
|
||||
# twice works and does "what you want".
|
||||
return next(iter(self))
|
||||
|
||||
def render_finish(self) -> None:
|
||||
if self.is_hidden:
|
||||
return
|
||||
self.file.write(AFTER_BAR)
|
||||
self.file.flush()
|
||||
|
||||
@property
|
||||
def pct(self) -> float:
|
||||
if self.finished:
|
||||
return 1.0
|
||||
return min(self.pos / (float(self.length or 1) or 1), 1.0)
|
||||
|
||||
@property
|
||||
def time_per_iteration(self) -> float:
|
||||
if not self.avg:
|
||||
return 0.0
|
||||
return sum(self.avg) / float(len(self.avg))
|
||||
|
||||
@property
|
||||
def eta(self) -> float:
|
||||
if self.length is not None and not self.finished:
|
||||
return self.time_per_iteration * (self.length - self.pos)
|
||||
return 0.0
|
||||
|
||||
def format_eta(self) -> str:
|
||||
if self.eta_known:
|
||||
t = int(self.eta)
|
||||
seconds = t % 60
|
||||
t //= 60
|
||||
minutes = t % 60
|
||||
t //= 60
|
||||
hours = t % 24
|
||||
t //= 24
|
||||
if t > 0:
|
||||
return f"{t}d {hours:02}:{minutes:02}:{seconds:02}"
|
||||
else:
|
||||
return f"{hours:02}:{minutes:02}:{seconds:02}"
|
||||
return ""
|
||||
|
||||
def format_pos(self) -> str:
|
||||
pos = str(self.pos)
|
||||
if self.length is not None:
|
||||
pos += f"/{self.length}"
|
||||
return pos
|
||||
|
||||
def format_pct(self) -> str:
|
||||
return f"{int(self.pct * 100): 4}%"[1:]
|
||||
|
||||
def format_bar(self) -> str:
|
||||
if self.length is not None:
|
||||
bar_length = int(self.pct * self.width)
|
||||
bar = self.fill_char * bar_length
|
||||
bar += self.empty_char * (self.width - bar_length)
|
||||
elif self.finished:
|
||||
bar = self.fill_char * self.width
|
||||
else:
|
||||
chars = list(self.empty_char * (self.width or 1))
|
||||
if self.time_per_iteration != 0:
|
||||
chars[
|
||||
int(
|
||||
(math.cos(self.pos * self.time_per_iteration) / 2.0 + 0.5)
|
||||
* self.width
|
||||
)
|
||||
] = self.fill_char
|
||||
bar = "".join(chars)
|
||||
return bar
|
||||
|
||||
def format_progress_line(self) -> str:
|
||||
show_percent = self.show_percent
|
||||
|
||||
info_bits = []
|
||||
if self.length is not None and show_percent is None:
|
||||
show_percent = not self.show_pos
|
||||
|
||||
if self.show_pos:
|
||||
info_bits.append(self.format_pos())
|
||||
if show_percent:
|
||||
info_bits.append(self.format_pct())
|
||||
if self.show_eta and self.eta_known and not self.finished:
|
||||
info_bits.append(self.format_eta())
|
||||
if self.item_show_func is not None:
|
||||
item_info = self.item_show_func(self.current_item)
|
||||
if item_info is not None:
|
||||
info_bits.append(item_info)
|
||||
|
||||
return (
|
||||
self.bar_template
|
||||
% {
|
||||
"label": self.label,
|
||||
"bar": self.format_bar(),
|
||||
"info": self.info_sep.join(info_bits),
|
||||
}
|
||||
).rstrip()
|
||||
|
||||
def render_progress(self) -> None:
|
||||
import shutil
|
||||
|
||||
if self.is_hidden:
|
||||
# Only output the label as it changes if the output is not a
|
||||
# TTY. Use file=stderr if you expect to be piping stdout.
|
||||
if self._last_line != self.label:
|
||||
self._last_line = self.label
|
||||
echo(self.label, file=self.file, color=self.color)
|
||||
|
||||
return
|
||||
|
||||
buf = []
|
||||
# Update width in case the terminal has been resized
|
||||
if self.autowidth:
|
||||
old_width = self.width
|
||||
self.width = 0
|
||||
clutter_length = term_len(self.format_progress_line())
|
||||
new_width = max(0, shutil.get_terminal_size().columns - clutter_length)
|
||||
if new_width < old_width:
|
||||
buf.append(BEFORE_BAR)
|
||||
buf.append(" " * self.max_width) # type: ignore
|
||||
self.max_width = new_width
|
||||
self.width = new_width
|
||||
|
||||
clear_width = self.width
|
||||
if self.max_width is not None:
|
||||
clear_width = self.max_width
|
||||
|
||||
buf.append(BEFORE_BAR)
|
||||
line = self.format_progress_line()
|
||||
line_len = term_len(line)
|
||||
if self.max_width is None or self.max_width < line_len:
|
||||
self.max_width = line_len
|
||||
|
||||
buf.append(line)
|
||||
buf.append(" " * (clear_width - line_len))
|
||||
line = "".join(buf)
|
||||
# Render the line only if it changed.
|
||||
|
||||
if line != self._last_line:
|
||||
self._last_line = line
|
||||
echo(line, file=self.file, color=self.color, nl=False)
|
||||
self.file.flush()
|
||||
|
||||
def make_step(self, n_steps: int) -> None:
|
||||
self.pos += n_steps
|
||||
if self.length is not None and self.pos >= self.length:
|
||||
self.finished = True
|
||||
|
||||
if (time.time() - self.last_eta) < 1.0:
|
||||
return
|
||||
|
||||
self.last_eta = time.time()
|
||||
|
||||
# self.avg is a rolling list of length <= 7 of steps where steps are
|
||||
# defined as time elapsed divided by the total progress through
|
||||
# self.length.
|
||||
if self.pos:
|
||||
step = (time.time() - self.start) / self.pos
|
||||
else:
|
||||
step = time.time() - self.start
|
||||
|
||||
self.avg = self.avg[-6:] + [step]
|
||||
|
||||
self.eta_known = self.length is not None
|
||||
|
||||
def update(self, n_steps: int, current_item: t.Optional[V] = None) -> None:
|
||||
"""Update the progress bar by advancing a specified number of
|
||||
steps, and optionally set the ``current_item`` for this new
|
||||
position.
|
||||
|
||||
:param n_steps: Number of steps to advance.
|
||||
:param current_item: Optional item to set as ``current_item``
|
||||
for the updated position.
|
||||
|
||||
.. versionchanged:: 8.0
|
||||
Added the ``current_item`` optional parameter.
|
||||
|
||||
.. versionchanged:: 8.0
|
||||
Only render when the number of steps meets the
|
||||
``update_min_steps`` threshold.
|
||||
"""
|
||||
if current_item is not None:
|
||||
self.current_item = current_item
|
||||
|
||||
self._completed_intervals += n_steps
|
||||
|
||||
if self._completed_intervals >= self.update_min_steps:
|
||||
self.make_step(self._completed_intervals)
|
||||
self.render_progress()
|
||||
self._completed_intervals = 0
|
||||
|
||||
def finish(self) -> None:
|
||||
self.eta_known = False
|
||||
self.current_item = None
|
||||
self.finished = True
|
||||
|
||||
def generator(self) -> t.Iterator[V]:
|
||||
"""Return a generator which yields the items added to the bar
|
||||
during construction, and updates the progress bar *after* the
|
||||
yielded block returns.
|
||||
"""
|
||||
# WARNING: the iterator interface for `ProgressBar` relies on
|
||||
# this and only works because this is a simple generator which
|
||||
# doesn't create or manage additional state. If this function
|
||||
# changes, the impact should be evaluated both against
|
||||
# `iter(bar)` and `next(bar)`. `next()` in particular may call
|
||||
# `self.generator()` repeatedly, and this must remain safe in
|
||||
# order for that interface to work.
|
||||
if not self.entered:
|
||||
raise RuntimeError("You need to use progress bars in a with block.")
|
||||
|
||||
if self.is_hidden:
|
||||
yield from self.iter
|
||||
else:
|
||||
for rv in self.iter:
|
||||
self.current_item = rv
|
||||
|
||||
# This allows show_item_func to be updated before the
|
||||
# item is processed. Only trigger at the beginning of
|
||||
# the update interval.
|
||||
if self._completed_intervals == 0:
|
||||
self.render_progress()
|
||||
|
||||
yield rv
|
||||
self.update(1)
|
||||
|
||||
self.finish()
|
||||
self.render_progress()
|
||||
|
||||
|
||||
def pager(generator: t.Iterable[str], color: t.Optional[bool] = None) -> None:
|
||||
"""Decide what method to use for paging through text."""
|
||||
stdout = _default_text_stdout()
|
||||
if not isatty(sys.stdin) or not isatty(stdout):
|
||||
return _nullpager(stdout, generator, color)
|
||||
pager_cmd = (os.environ.get("PAGER", None) or "").strip()
|
||||
if pager_cmd:
|
||||
if WIN:
|
||||
return _tempfilepager(generator, pager_cmd, color)
|
||||
return _pipepager(generator, pager_cmd, color)
|
||||
if os.environ.get("TERM") in ("dumb", "emacs"):
|
||||
return _nullpager(stdout, generator, color)
|
||||
if WIN or sys.platform.startswith("os2"):
|
||||
return _tempfilepager(generator, "more <", color)
|
||||
if hasattr(os, "system") and os.system("(less) 2>/dev/null") == 0:
|
||||
return _pipepager(generator, "less", color)
|
||||
|
||||
import tempfile
|
||||
|
||||
fd, filename = tempfile.mkstemp()
|
||||
os.close(fd)
|
||||
try:
|
||||
if hasattr(os, "system") and os.system(f'more "{filename}"') == 0:
|
||||
return _pipepager(generator, "more", color)
|
||||
return _nullpager(stdout, generator, color)
|
||||
finally:
|
||||
os.unlink(filename)
|
||||
|
||||
|
||||
def _pipepager(generator: t.Iterable[str], cmd: str, color: t.Optional[bool]) -> None:
|
||||
"""Page through text by feeding it to another program. Invoking a
|
||||
pager through this might support colors.
|
||||
"""
|
||||
import subprocess
|
||||
|
||||
env = dict(os.environ)
|
||||
|
||||
# If we're piping to less we might support colors under the
|
||||
# condition that
|
||||
cmd_detail = cmd.rsplit("/", 1)[-1].split()
|
||||
if color is None and cmd_detail[0] == "less":
|
||||
less_flags = f"{os.environ.get('LESS', '')}{' '.join(cmd_detail[1:])}"
|
||||
if not less_flags:
|
||||
env["LESS"] = "-R"
|
||||
color = True
|
||||
elif "r" in less_flags or "R" in less_flags:
|
||||
color = True
|
||||
|
||||
c = subprocess.Popen(cmd, shell=True, stdin=subprocess.PIPE, env=env)
|
||||
stdin = t.cast(t.BinaryIO, c.stdin)
|
||||
encoding = get_best_encoding(stdin)
|
||||
try:
|
||||
for text in generator:
|
||||
if not color:
|
||||
text = strip_ansi(text)
|
||||
|
||||
stdin.write(text.encode(encoding, "replace"))
|
||||
except (OSError, KeyboardInterrupt):
|
||||
pass
|
||||
else:
|
||||
stdin.close()
|
||||
|
||||
# Less doesn't respect ^C, but catches it for its own UI purposes (aborting
|
||||
# search or other commands inside less).
|
||||
#
|
||||
# That means when the user hits ^C, the parent process (click) terminates,
|
||||
# but less is still alive, paging the output and messing up the terminal.
|
||||
#
|
||||
# If the user wants to make the pager exit on ^C, they should set
|
||||
# `LESS='-K'`. It's not our decision to make.
|
||||
while True:
|
||||
try:
|
||||
c.wait()
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
else:
|
||||
break
|
||||
|
||||
|
||||
def _tempfilepager(
|
||||
generator: t.Iterable[str], cmd: str, color: t.Optional[bool]
|
||||
) -> None:
|
||||
"""Page through text by invoking a program on a temporary file."""
|
||||
import tempfile
|
||||
|
||||
fd, filename = tempfile.mkstemp()
|
||||
# TODO: This never terminates if the passed generator never terminates.
|
||||
text = "".join(generator)
|
||||
if not color:
|
||||
text = strip_ansi(text)
|
||||
encoding = get_best_encoding(sys.stdout)
|
||||
with open_stream(filename, "wb")[0] as f:
|
||||
f.write(text.encode(encoding))
|
||||
try:
|
||||
os.system(f'{cmd} "{filename}"')
|
||||
finally:
|
||||
os.close(fd)
|
||||
os.unlink(filename)
|
||||
|
||||
|
||||
def _nullpager(
|
||||
stream: t.TextIO, generator: t.Iterable[str], color: t.Optional[bool]
|
||||
) -> None:
|
||||
"""Simply print unformatted text. This is the ultimate fallback."""
|
||||
for text in generator:
|
||||
if not color:
|
||||
text = strip_ansi(text)
|
||||
stream.write(text)
|
||||
|
||||
|
||||
class Editor:
|
||||
def __init__(
|
||||
self,
|
||||
editor: t.Optional[str] = None,
|
||||
env: t.Optional[t.Mapping[str, str]] = None,
|
||||
require_save: bool = True,
|
||||
extension: str = ".txt",
|
||||
) -> None:
|
||||
self.editor = editor
|
||||
self.env = env
|
||||
self.require_save = require_save
|
||||
self.extension = extension
|
||||
|
||||
def get_editor(self) -> str:
|
||||
if self.editor is not None:
|
||||
return self.editor
|
||||
for key in "VISUAL", "EDITOR":
|
||||
rv = os.environ.get(key)
|
||||
if rv:
|
||||
return rv
|
||||
if WIN:
|
||||
return "notepad"
|
||||
for editor in "sensible-editor", "vim", "nano":
|
||||
if os.system(f"which {editor} >/dev/null 2>&1") == 0:
|
||||
return editor
|
||||
return "vi"
|
||||
|
||||
def edit_file(self, filename: str) -> None:
|
||||
import subprocess
|
||||
|
||||
editor = self.get_editor()
|
||||
environ: t.Optional[t.Dict[str, str]] = None
|
||||
|
||||
if self.env:
|
||||
environ = os.environ.copy()
|
||||
environ.update(self.env)
|
||||
|
||||
try:
|
||||
c = subprocess.Popen(f'{editor} "{filename}"', env=environ, shell=True)
|
||||
exit_code = c.wait()
|
||||
if exit_code != 0:
|
||||
raise ClickException(
|
||||
_("{editor}: Editing failed").format(editor=editor)
|
||||
)
|
||||
except OSError as e:
|
||||
raise ClickException(
|
||||
_("{editor}: Editing failed: {e}").format(editor=editor, e=e)
|
||||
) from e
|
||||
|
||||
def edit(self, text: t.Optional[t.AnyStr]) -> t.Optional[t.AnyStr]:
|
||||
import tempfile
|
||||
|
||||
if not text:
|
||||
data = b""
|
||||
elif isinstance(text, (bytes, bytearray)):
|
||||
data = text
|
||||
else:
|
||||
if text and not text.endswith("\n"):
|
||||
text += "\n"
|
||||
|
||||
if WIN:
|
||||
data = text.replace("\n", "\r\n").encode("utf-8-sig")
|
||||
else:
|
||||
data = text.encode("utf-8")
|
||||
|
||||
fd, name = tempfile.mkstemp(prefix="editor-", suffix=self.extension)
|
||||
f: t.BinaryIO
|
||||
|
||||
try:
|
||||
with os.fdopen(fd, "wb") as f:
|
||||
f.write(data)
|
||||
|
||||
# If the filesystem resolution is 1 second, like Mac OS
|
||||
# 10.12 Extended, or 2 seconds, like FAT32, and the editor
|
||||
# closes very fast, require_save can fail. Set the modified
|
||||
# time to be 2 seconds in the past to work around this.
|
||||
os.utime(name, (os.path.getatime(name), os.path.getmtime(name) - 2))
|
||||
# Depending on the resolution, the exact value might not be
|
||||
# recorded, so get the new recorded value.
|
||||
timestamp = os.path.getmtime(name)
|
||||
|
||||
self.edit_file(name)
|
||||
|
||||
if self.require_save and os.path.getmtime(name) == timestamp:
|
||||
return None
|
||||
|
||||
with open(name, "rb") as f:
|
||||
rv = f.read()
|
||||
|
||||
if isinstance(text, (bytes, bytearray)):
|
||||
return rv
|
||||
|
||||
return rv.decode("utf-8-sig").replace("\r\n", "\n") # type: ignore
|
||||
finally:
|
||||
os.unlink(name)
|
||||
|
||||
|
||||
def open_url(url: str, wait: bool = False, locate: bool = False) -> int:
|
||||
import subprocess
|
||||
|
||||
def _unquote_file(url: str) -> str:
|
||||
from urllib.parse import unquote
|
||||
|
||||
if url.startswith("file://"):
|
||||
url = unquote(url[7:])
|
||||
|
||||
return url
|
||||
|
||||
if sys.platform == "darwin":
|
||||
args = ["open"]
|
||||
if wait:
|
||||
args.append("-W")
|
||||
if locate:
|
||||
args.append("-R")
|
||||
args.append(_unquote_file(url))
|
||||
null = open("/dev/null", "w")
|
||||
try:
|
||||
return subprocess.Popen(args, stderr=null).wait()
|
||||
finally:
|
||||
null.close()
|
||||
elif WIN:
|
||||
if locate:
|
||||
url = _unquote_file(url.replace('"', ""))
|
||||
args = f'explorer /select,"{url}"'
|
||||
else:
|
||||
url = url.replace('"', "")
|
||||
wait_str = "/WAIT" if wait else ""
|
||||
args = f'start {wait_str} "" "{url}"'
|
||||
return os.system(args)
|
||||
elif CYGWIN:
|
||||
if locate:
|
||||
url = os.path.dirname(_unquote_file(url).replace('"', ""))
|
||||
args = f'cygstart "{url}"'
|
||||
else:
|
||||
url = url.replace('"', "")
|
||||
wait_str = "-w" if wait else ""
|
||||
args = f'cygstart {wait_str} "{url}"'
|
||||
return os.system(args)
|
||||
|
||||
try:
|
||||
if locate:
|
||||
url = os.path.dirname(_unquote_file(url)) or "."
|
||||
else:
|
||||
url = _unquote_file(url)
|
||||
c = subprocess.Popen(["xdg-open", url])
|
||||
if wait:
|
||||
return c.wait()
|
||||
return 0
|
||||
except OSError:
|
||||
if url.startswith(("http://", "https://")) and not locate and not wait:
|
||||
import webbrowser
|
||||
|
||||
webbrowser.open(url)
|
||||
return 0
|
||||
return 1
|
||||
|
||||
|
||||
def _translate_ch_to_exc(ch: str) -> t.Optional[BaseException]:
|
||||
if ch == "\x03":
|
||||
raise KeyboardInterrupt()
|
||||
|
||||
if ch == "\x04" and not WIN: # Unix-like, Ctrl+D
|
||||
raise EOFError()
|
||||
|
||||
if ch == "\x1a" and WIN: # Windows, Ctrl+Z
|
||||
raise EOFError()
|
||||
|
||||
return None
|
||||
|
||||
|
||||
if WIN:
|
||||
import msvcrt
|
||||
|
||||
@contextlib.contextmanager
|
||||
def raw_terminal() -> t.Iterator[int]:
|
||||
yield -1
|
||||
|
||||
def getchar(echo: bool) -> str:
|
||||
# The function `getch` will return a bytes object corresponding to
|
||||
# the pressed character. Since Windows 10 build 1803, it will also
|
||||
# return \x00 when called a second time after pressing a regular key.
|
||||
#
|
||||
# `getwch` does not share this probably-bugged behavior. Moreover, it
|
||||
# returns a Unicode object by default, which is what we want.
|
||||
#
|
||||
# Either of these functions will return \x00 or \xe0 to indicate
|
||||
# a special key, and you need to call the same function again to get
|
||||
# the "rest" of the code. The fun part is that \u00e0 is
|
||||
# "latin small letter a with grave", so if you type that on a French
|
||||
# keyboard, you _also_ get a \xe0.
|
||||
# E.g., consider the Up arrow. This returns \xe0 and then \x48. The
|
||||
# resulting Unicode string reads as "a with grave" + "capital H".
|
||||
# This is indistinguishable from when the user actually types
|
||||
# "a with grave" and then "capital H".
|
||||
#
|
||||
# When \xe0 is returned, we assume it's part of a special-key sequence
|
||||
# and call `getwch` again, but that means that when the user types
|
||||
# the \u00e0 character, `getchar` doesn't return until a second
|
||||
# character is typed.
|
||||
# The alternative is returning immediately, but that would mess up
|
||||
# cross-platform handling of arrow keys and others that start with
|
||||
# \xe0. Another option is using `getch`, but then we can't reliably
|
||||
# read non-ASCII characters, because return values of `getch` are
|
||||
# limited to the current 8-bit codepage.
|
||||
#
|
||||
# Anyway, Click doesn't claim to do this Right(tm), and using `getwch`
|
||||
# is doing the right thing in more situations than with `getch`.
|
||||
func: t.Callable[[], str]
|
||||
|
||||
if echo:
|
||||
func = msvcrt.getwche # type: ignore
|
||||
else:
|
||||
func = msvcrt.getwch # type: ignore
|
||||
|
||||
rv = func()
|
||||
|
||||
if rv in ("\x00", "\xe0"):
|
||||
# \x00 and \xe0 are control characters that indicate special key,
|
||||
# see above.
|
||||
rv += func()
|
||||
|
||||
_translate_ch_to_exc(rv)
|
||||
return rv
|
||||
|
||||
else:
|
||||
import tty
|
||||
import termios
|
||||
|
||||
@contextlib.contextmanager
|
||||
def raw_terminal() -> t.Iterator[int]:
|
||||
f: t.Optional[t.TextIO]
|
||||
fd: int
|
||||
|
||||
if not isatty(sys.stdin):
|
||||
f = open("/dev/tty")
|
||||
fd = f.fileno()
|
||||
else:
|
||||
fd = sys.stdin.fileno()
|
||||
f = None
|
||||
|
||||
try:
|
||||
old_settings = termios.tcgetattr(fd)
|
||||
|
||||
try:
|
||||
tty.setraw(fd)
|
||||
yield fd
|
||||
finally:
|
||||
termios.tcsetattr(fd, termios.TCSADRAIN, old_settings)
|
||||
sys.stdout.flush()
|
||||
|
||||
if f is not None:
|
||||
f.close()
|
||||
except termios.error:
|
||||
pass
|
||||
|
||||
def getchar(echo: bool) -> str:
|
||||
with raw_terminal() as fd:
|
||||
ch = os.read(fd, 32).decode(get_best_encoding(sys.stdin), "replace")
|
||||
|
||||
if echo and isatty(sys.stdout):
|
||||
sys.stdout.write(ch)
|
||||
|
||||
_translate_ch_to_exc(ch)
|
||||
return ch
|
@ -0,0 +1,49 @@
|
||||
import textwrap
|
||||
import typing as t
|
||||
from contextlib import contextmanager
|
||||
|
||||
|
||||
class TextWrapper(textwrap.TextWrapper):
|
||||
def _handle_long_word(
|
||||
self,
|
||||
reversed_chunks: t.List[str],
|
||||
cur_line: t.List[str],
|
||||
cur_len: int,
|
||||
width: int,
|
||||
) -> None:
|
||||
space_left = max(width - cur_len, 1)
|
||||
|
||||
if self.break_long_words:
|
||||
last = reversed_chunks[-1]
|
||||
cut = last[:space_left]
|
||||
res = last[space_left:]
|
||||
cur_line.append(cut)
|
||||
reversed_chunks[-1] = res
|
||||
elif not cur_line:
|
||||
cur_line.append(reversed_chunks.pop())
|
||||
|
||||
@contextmanager
|
||||
def extra_indent(self, indent: str) -> t.Iterator[None]:
|
||||
old_initial_indent = self.initial_indent
|
||||
old_subsequent_indent = self.subsequent_indent
|
||||
self.initial_indent += indent
|
||||
self.subsequent_indent += indent
|
||||
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
self.initial_indent = old_initial_indent
|
||||
self.subsequent_indent = old_subsequent_indent
|
||||
|
||||
def indent_only(self, text: str) -> str:
|
||||
rv = []
|
||||
|
||||
for idx, line in enumerate(text.splitlines()):
|
||||
indent = self.initial_indent
|
||||
|
||||
if idx > 0:
|
||||
indent = self.subsequent_indent
|
||||
|
||||
rv.append(f"{indent}{line}")
|
||||
|
||||
return "\n".join(rv)
|
@ -0,0 +1,279 @@
|
||||
# This module is based on the excellent work by Adam Bartoš who
|
||||
# provided a lot of what went into the implementation here in
|
||||
# the discussion to issue1602 in the Python bug tracker.
|
||||
#
|
||||
# There are some general differences in regards to how this works
|
||||
# compared to the original patches as we do not need to patch
|
||||
# the entire interpreter but just work in our little world of
|
||||
# echo and prompt.
|
||||
import io
|
||||
import sys
|
||||
import time
|
||||
import typing as t
|
||||
from ctypes import byref
|
||||
from ctypes import c_char
|
||||
from ctypes import c_char_p
|
||||
from ctypes import c_int
|
||||
from ctypes import c_ssize_t
|
||||
from ctypes import c_ulong
|
||||
from ctypes import c_void_p
|
||||
from ctypes import POINTER
|
||||
from ctypes import py_object
|
||||
from ctypes import Structure
|
||||
from ctypes.wintypes import DWORD
|
||||
from ctypes.wintypes import HANDLE
|
||||
from ctypes.wintypes import LPCWSTR
|
||||
from ctypes.wintypes import LPWSTR
|
||||
|
||||
from ._compat import _NonClosingTextIOWrapper
|
||||
|
||||
assert sys.platform == "win32"
|
||||
import msvcrt # noqa: E402
|
||||
from ctypes import windll # noqa: E402
|
||||
from ctypes import WINFUNCTYPE # noqa: E402
|
||||
|
||||
c_ssize_p = POINTER(c_ssize_t)
|
||||
|
||||
kernel32 = windll.kernel32
|
||||
GetStdHandle = kernel32.GetStdHandle
|
||||
ReadConsoleW = kernel32.ReadConsoleW
|
||||
WriteConsoleW = kernel32.WriteConsoleW
|
||||
GetConsoleMode = kernel32.GetConsoleMode
|
||||
GetLastError = kernel32.GetLastError
|
||||
GetCommandLineW = WINFUNCTYPE(LPWSTR)(("GetCommandLineW", windll.kernel32))
|
||||
CommandLineToArgvW = WINFUNCTYPE(POINTER(LPWSTR), LPCWSTR, POINTER(c_int))(
|
||||
("CommandLineToArgvW", windll.shell32)
|
||||
)
|
||||
LocalFree = WINFUNCTYPE(c_void_p, c_void_p)(("LocalFree", windll.kernel32))
|
||||
|
||||
STDIN_HANDLE = GetStdHandle(-10)
|
||||
STDOUT_HANDLE = GetStdHandle(-11)
|
||||
STDERR_HANDLE = GetStdHandle(-12)
|
||||
|
||||
PyBUF_SIMPLE = 0
|
||||
PyBUF_WRITABLE = 1
|
||||
|
||||
ERROR_SUCCESS = 0
|
||||
ERROR_NOT_ENOUGH_MEMORY = 8
|
||||
ERROR_OPERATION_ABORTED = 995
|
||||
|
||||
STDIN_FILENO = 0
|
||||
STDOUT_FILENO = 1
|
||||
STDERR_FILENO = 2
|
||||
|
||||
EOF = b"\x1a"
|
||||
MAX_BYTES_WRITTEN = 32767
|
||||
|
||||
try:
|
||||
from ctypes import pythonapi
|
||||
except ImportError:
|
||||
# On PyPy we cannot get buffers so our ability to operate here is
|
||||
# severely limited.
|
||||
get_buffer = None
|
||||
else:
|
||||
|
||||
class Py_buffer(Structure):
|
||||
_fields_ = [
|
||||
("buf", c_void_p),
|
||||
("obj", py_object),
|
||||
("len", c_ssize_t),
|
||||
("itemsize", c_ssize_t),
|
||||
("readonly", c_int),
|
||||
("ndim", c_int),
|
||||
("format", c_char_p),
|
||||
("shape", c_ssize_p),
|
||||
("strides", c_ssize_p),
|
||||
("suboffsets", c_ssize_p),
|
||||
("internal", c_void_p),
|
||||
]
|
||||
|
||||
PyObject_GetBuffer = pythonapi.PyObject_GetBuffer
|
||||
PyBuffer_Release = pythonapi.PyBuffer_Release
|
||||
|
||||
def get_buffer(obj, writable=False):
|
||||
buf = Py_buffer()
|
||||
flags = PyBUF_WRITABLE if writable else PyBUF_SIMPLE
|
||||
PyObject_GetBuffer(py_object(obj), byref(buf), flags)
|
||||
|
||||
try:
|
||||
buffer_type = c_char * buf.len
|
||||
return buffer_type.from_address(buf.buf)
|
||||
finally:
|
||||
PyBuffer_Release(byref(buf))
|
||||
|
||||
|
||||
class _WindowsConsoleRawIOBase(io.RawIOBase):
|
||||
def __init__(self, handle):
|
||||
self.handle = handle
|
||||
|
||||
def isatty(self):
|
||||
super().isatty()
|
||||
return True
|
||||
|
||||
|
||||
class _WindowsConsoleReader(_WindowsConsoleRawIOBase):
|
||||
def readable(self):
|
||||
return True
|
||||
|
||||
def readinto(self, b):
|
||||
bytes_to_be_read = len(b)
|
||||
if not bytes_to_be_read:
|
||||
return 0
|
||||
elif bytes_to_be_read % 2:
|
||||
raise ValueError(
|
||||
"cannot read odd number of bytes from UTF-16-LE encoded console"
|
||||
)
|
||||
|
||||
buffer = get_buffer(b, writable=True)
|
||||
code_units_to_be_read = bytes_to_be_read // 2
|
||||
code_units_read = c_ulong()
|
||||
|
||||
rv = ReadConsoleW(
|
||||
HANDLE(self.handle),
|
||||
buffer,
|
||||
code_units_to_be_read,
|
||||
byref(code_units_read),
|
||||
None,
|
||||
)
|
||||
if GetLastError() == ERROR_OPERATION_ABORTED:
|
||||
# wait for KeyboardInterrupt
|
||||
time.sleep(0.1)
|
||||
if not rv:
|
||||
raise OSError(f"Windows error: {GetLastError()}")
|
||||
|
||||
if buffer[0] == EOF:
|
||||
return 0
|
||||
return 2 * code_units_read.value
|
||||
|
||||
|
||||
class _WindowsConsoleWriter(_WindowsConsoleRawIOBase):
|
||||
def writable(self):
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def _get_error_message(errno):
|
||||
if errno == ERROR_SUCCESS:
|
||||
return "ERROR_SUCCESS"
|
||||
elif errno == ERROR_NOT_ENOUGH_MEMORY:
|
||||
return "ERROR_NOT_ENOUGH_MEMORY"
|
||||
return f"Windows error {errno}"
|
||||
|
||||
def write(self, b):
|
||||
bytes_to_be_written = len(b)
|
||||
buf = get_buffer(b)
|
||||
code_units_to_be_written = min(bytes_to_be_written, MAX_BYTES_WRITTEN) // 2
|
||||
code_units_written = c_ulong()
|
||||
|
||||
WriteConsoleW(
|
||||
HANDLE(self.handle),
|
||||
buf,
|
||||
code_units_to_be_written,
|
||||
byref(code_units_written),
|
||||
None,
|
||||
)
|
||||
bytes_written = 2 * code_units_written.value
|
||||
|
||||
if bytes_written == 0 and bytes_to_be_written > 0:
|
||||
raise OSError(self._get_error_message(GetLastError()))
|
||||
return bytes_written
|
||||
|
||||
|
||||
class ConsoleStream:
|
||||
def __init__(self, text_stream: t.TextIO, byte_stream: t.BinaryIO) -> None:
|
||||
self._text_stream = text_stream
|
||||
self.buffer = byte_stream
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self.buffer.name
|
||||
|
||||
def write(self, x: t.AnyStr) -> int:
|
||||
if isinstance(x, str):
|
||||
return self._text_stream.write(x)
|
||||
try:
|
||||
self.flush()
|
||||
except Exception:
|
||||
pass
|
||||
return self.buffer.write(x)
|
||||
|
||||
def writelines(self, lines: t.Iterable[t.AnyStr]) -> None:
|
||||
for line in lines:
|
||||
self.write(line)
|
||||
|
||||
def __getattr__(self, name: str) -> t.Any:
|
||||
return getattr(self._text_stream, name)
|
||||
|
||||
def isatty(self) -> bool:
|
||||
return self.buffer.isatty()
|
||||
|
||||
def __repr__(self):
|
||||
return f"<ConsoleStream name={self.name!r} encoding={self.encoding!r}>"
|
||||
|
||||
|
||||
def _get_text_stdin(buffer_stream: t.BinaryIO) -> t.TextIO:
|
||||
text_stream = _NonClosingTextIOWrapper(
|
||||
io.BufferedReader(_WindowsConsoleReader(STDIN_HANDLE)),
|
||||
"utf-16-le",
|
||||
"strict",
|
||||
line_buffering=True,
|
||||
)
|
||||
return t.cast(t.TextIO, ConsoleStream(text_stream, buffer_stream))
|
||||
|
||||
|
||||
def _get_text_stdout(buffer_stream: t.BinaryIO) -> t.TextIO:
|
||||
text_stream = _NonClosingTextIOWrapper(
|
||||
io.BufferedWriter(_WindowsConsoleWriter(STDOUT_HANDLE)),
|
||||
"utf-16-le",
|
||||
"strict",
|
||||
line_buffering=True,
|
||||
)
|
||||
return t.cast(t.TextIO, ConsoleStream(text_stream, buffer_stream))
|
||||
|
||||
|
||||
def _get_text_stderr(buffer_stream: t.BinaryIO) -> t.TextIO:
|
||||
text_stream = _NonClosingTextIOWrapper(
|
||||
io.BufferedWriter(_WindowsConsoleWriter(STDERR_HANDLE)),
|
||||
"utf-16-le",
|
||||
"strict",
|
||||
line_buffering=True,
|
||||
)
|
||||
return t.cast(t.TextIO, ConsoleStream(text_stream, buffer_stream))
|
||||
|
||||
|
||||
_stream_factories: t.Mapping[int, t.Callable[[t.BinaryIO], t.TextIO]] = {
|
||||
0: _get_text_stdin,
|
||||
1: _get_text_stdout,
|
||||
2: _get_text_stderr,
|
||||
}
|
||||
|
||||
|
||||
def _is_console(f: t.TextIO) -> bool:
|
||||
if not hasattr(f, "fileno"):
|
||||
return False
|
||||
|
||||
try:
|
||||
fileno = f.fileno()
|
||||
except (OSError, io.UnsupportedOperation):
|
||||
return False
|
||||
|
||||
handle = msvcrt.get_osfhandle(fileno)
|
||||
return bool(GetConsoleMode(handle, byref(DWORD())))
|
||||
|
||||
|
||||
def _get_windows_console_stream(
|
||||
f: t.TextIO, encoding: t.Optional[str], errors: t.Optional[str]
|
||||
) -> t.Optional[t.TextIO]:
|
||||
if (
|
||||
get_buffer is not None
|
||||
and encoding in {"utf-16-le", None}
|
||||
and errors in {"strict", None}
|
||||
and _is_console(f)
|
||||
):
|
||||
func = _stream_factories.get(f.fileno())
|
||||
if func is not None:
|
||||
b = getattr(f, "buffer", None)
|
||||
|
||||
if b is None:
|
||||
return None
|
||||
|
||||
return func(b)
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,497 @@
|
||||
import inspect
|
||||
import types
|
||||
import typing as t
|
||||
from functools import update_wrapper
|
||||
from gettext import gettext as _
|
||||
|
||||
from .core import Argument
|
||||
from .core import Command
|
||||
from .core import Context
|
||||
from .core import Group
|
||||
from .core import Option
|
||||
from .core import Parameter
|
||||
from .globals import get_current_context
|
||||
from .utils import echo
|
||||
|
||||
F = t.TypeVar("F", bound=t.Callable[..., t.Any])
|
||||
FC = t.TypeVar("FC", bound=t.Union[t.Callable[..., t.Any], Command])
|
||||
|
||||
|
||||
def pass_context(f: F) -> F:
|
||||
"""Marks a callback as wanting to receive the current context
|
||||
object as first argument.
|
||||
"""
|
||||
|
||||
def new_func(*args, **kwargs): # type: ignore
|
||||
return f(get_current_context(), *args, **kwargs)
|
||||
|
||||
return update_wrapper(t.cast(F, new_func), f)
|
||||
|
||||
|
||||
def pass_obj(f: F) -> F:
|
||||
"""Similar to :func:`pass_context`, but only pass the object on the
|
||||
context onwards (:attr:`Context.obj`). This is useful if that object
|
||||
represents the state of a nested system.
|
||||
"""
|
||||
|
||||
def new_func(*args, **kwargs): # type: ignore
|
||||
return f(get_current_context().obj, *args, **kwargs)
|
||||
|
||||
return update_wrapper(t.cast(F, new_func), f)
|
||||
|
||||
|
||||
def make_pass_decorator(
|
||||
object_type: t.Type, ensure: bool = False
|
||||
) -> "t.Callable[[F], F]":
|
||||
"""Given an object type this creates a decorator that will work
|
||||
similar to :func:`pass_obj` but instead of passing the object of the
|
||||
current context, it will find the innermost context of type
|
||||
:func:`object_type`.
|
||||
|
||||
This generates a decorator that works roughly like this::
|
||||
|
||||
from functools import update_wrapper
|
||||
|
||||
def decorator(f):
|
||||
@pass_context
|
||||
def new_func(ctx, *args, **kwargs):
|
||||
obj = ctx.find_object(object_type)
|
||||
return ctx.invoke(f, obj, *args, **kwargs)
|
||||
return update_wrapper(new_func, f)
|
||||
return decorator
|
||||
|
||||
:param object_type: the type of the object to pass.
|
||||
:param ensure: if set to `True`, a new object will be created and
|
||||
remembered on the context if it's not there yet.
|
||||
"""
|
||||
|
||||
def decorator(f: F) -> F:
|
||||
def new_func(*args, **kwargs): # type: ignore
|
||||
ctx = get_current_context()
|
||||
|
||||
if ensure:
|
||||
obj = ctx.ensure_object(object_type)
|
||||
else:
|
||||
obj = ctx.find_object(object_type)
|
||||
|
||||
if obj is None:
|
||||
raise RuntimeError(
|
||||
"Managed to invoke callback without a context"
|
||||
f" object of type {object_type.__name__!r}"
|
||||
" existing."
|
||||
)
|
||||
|
||||
return ctx.invoke(f, obj, *args, **kwargs)
|
||||
|
||||
return update_wrapper(t.cast(F, new_func), f)
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def pass_meta_key(
|
||||
key: str, *, doc_description: t.Optional[str] = None
|
||||
) -> "t.Callable[[F], F]":
|
||||
"""Create a decorator that passes a key from
|
||||
:attr:`click.Context.meta` as the first argument to the decorated
|
||||
function.
|
||||
|
||||
:param key: Key in ``Context.meta`` to pass.
|
||||
:param doc_description: Description of the object being passed,
|
||||
inserted into the decorator's docstring. Defaults to "the 'key'
|
||||
key from Context.meta".
|
||||
|
||||
.. versionadded:: 8.0
|
||||
"""
|
||||
|
||||
def decorator(f: F) -> F:
|
||||
def new_func(*args, **kwargs): # type: ignore
|
||||
ctx = get_current_context()
|
||||
obj = ctx.meta[key]
|
||||
return ctx.invoke(f, obj, *args, **kwargs)
|
||||
|
||||
return update_wrapper(t.cast(F, new_func), f)
|
||||
|
||||
if doc_description is None:
|
||||
doc_description = f"the {key!r} key from :attr:`click.Context.meta`"
|
||||
|
||||
decorator.__doc__ = (
|
||||
f"Decorator that passes {doc_description} as the first argument"
|
||||
" to the decorated function."
|
||||
)
|
||||
return decorator
|
||||
|
||||
|
||||
CmdType = t.TypeVar("CmdType", bound=Command)
|
||||
|
||||
|
||||
@t.overload
|
||||
def command(
|
||||
__func: t.Callable[..., t.Any],
|
||||
) -> Command:
|
||||
...
|
||||
|
||||
|
||||
@t.overload
|
||||
def command(
|
||||
name: t.Optional[str] = None,
|
||||
**attrs: t.Any,
|
||||
) -> t.Callable[..., Command]:
|
||||
...
|
||||
|
||||
|
||||
@t.overload
|
||||
def command(
|
||||
name: t.Optional[str] = None,
|
||||
cls: t.Type[CmdType] = ...,
|
||||
**attrs: t.Any,
|
||||
) -> t.Callable[..., CmdType]:
|
||||
...
|
||||
|
||||
|
||||
def command(
|
||||
name: t.Union[str, t.Callable[..., t.Any], None] = None,
|
||||
cls: t.Optional[t.Type[Command]] = None,
|
||||
**attrs: t.Any,
|
||||
) -> t.Union[Command, t.Callable[..., Command]]:
|
||||
r"""Creates a new :class:`Command` and uses the decorated function as
|
||||
callback. This will also automatically attach all decorated
|
||||
:func:`option`\s and :func:`argument`\s as parameters to the command.
|
||||
|
||||
The name of the command defaults to the name of the function with
|
||||
underscores replaced by dashes. If you want to change that, you can
|
||||
pass the intended name as the first argument.
|
||||
|
||||
All keyword arguments are forwarded to the underlying command class.
|
||||
For the ``params`` argument, any decorated params are appended to
|
||||
the end of the list.
|
||||
|
||||
Once decorated the function turns into a :class:`Command` instance
|
||||
that can be invoked as a command line utility or be attached to a
|
||||
command :class:`Group`.
|
||||
|
||||
:param name: the name of the command. This defaults to the function
|
||||
name with underscores replaced by dashes.
|
||||
:param cls: the command class to instantiate. This defaults to
|
||||
:class:`Command`.
|
||||
|
||||
.. versionchanged:: 8.1
|
||||
This decorator can be applied without parentheses.
|
||||
|
||||
.. versionchanged:: 8.1
|
||||
The ``params`` argument can be used. Decorated params are
|
||||
appended to the end of the list.
|
||||
"""
|
||||
|
||||
func: t.Optional[t.Callable[..., t.Any]] = None
|
||||
|
||||
if callable(name):
|
||||
func = name
|
||||
name = None
|
||||
assert cls is None, "Use 'command(cls=cls)(callable)' to specify a class."
|
||||
assert not attrs, "Use 'command(**kwargs)(callable)' to provide arguments."
|
||||
|
||||
if cls is None:
|
||||
cls = Command
|
||||
|
||||
def decorator(f: t.Callable[..., t.Any]) -> Command:
|
||||
if isinstance(f, Command):
|
||||
raise TypeError("Attempted to convert a callback into a command twice.")
|
||||
|
||||
attr_params = attrs.pop("params", None)
|
||||
params = attr_params if attr_params is not None else []
|
||||
|
||||
try:
|
||||
decorator_params = f.__click_params__ # type: ignore
|
||||
except AttributeError:
|
||||
pass
|
||||
else:
|
||||
del f.__click_params__ # type: ignore
|
||||
params.extend(reversed(decorator_params))
|
||||
|
||||
if attrs.get("help") is None:
|
||||
attrs["help"] = f.__doc__
|
||||
|
||||
cmd = cls( # type: ignore[misc]
|
||||
name=name or f.__name__.lower().replace("_", "-"), # type: ignore[arg-type]
|
||||
callback=f,
|
||||
params=params,
|
||||
**attrs,
|
||||
)
|
||||
cmd.__doc__ = f.__doc__
|
||||
return cmd
|
||||
|
||||
if func is not None:
|
||||
return decorator(func)
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
@t.overload
|
||||
def group(
|
||||
__func: t.Callable[..., t.Any],
|
||||
) -> Group:
|
||||
...
|
||||
|
||||
|
||||
@t.overload
|
||||
def group(
|
||||
name: t.Optional[str] = None,
|
||||
**attrs: t.Any,
|
||||
) -> t.Callable[[F], Group]:
|
||||
...
|
||||
|
||||
|
||||
def group(
|
||||
name: t.Union[str, t.Callable[..., t.Any], None] = None, **attrs: t.Any
|
||||
) -> t.Union[Group, t.Callable[[F], Group]]:
|
||||
"""Creates a new :class:`Group` with a function as callback. This
|
||||
works otherwise the same as :func:`command` just that the `cls`
|
||||
parameter is set to :class:`Group`.
|
||||
|
||||
.. versionchanged:: 8.1
|
||||
This decorator can be applied without parentheses.
|
||||
"""
|
||||
if attrs.get("cls") is None:
|
||||
attrs["cls"] = Group
|
||||
|
||||
if callable(name):
|
||||
grp: t.Callable[[F], Group] = t.cast(Group, command(**attrs))
|
||||
return grp(name)
|
||||
|
||||
return t.cast(Group, command(name, **attrs))
|
||||
|
||||
|
||||
def _param_memo(f: FC, param: Parameter) -> None:
|
||||
if isinstance(f, Command):
|
||||
f.params.append(param)
|
||||
else:
|
||||
if not hasattr(f, "__click_params__"):
|
||||
f.__click_params__ = [] # type: ignore
|
||||
|
||||
f.__click_params__.append(param) # type: ignore
|
||||
|
||||
|
||||
def argument(*param_decls: str, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
"""Attaches an argument to the command. All positional arguments are
|
||||
passed as parameter declarations to :class:`Argument`; all keyword
|
||||
arguments are forwarded unchanged (except ``cls``).
|
||||
This is equivalent to creating an :class:`Argument` instance manually
|
||||
and attaching it to the :attr:`Command.params` list.
|
||||
|
||||
:param cls: the argument class to instantiate. This defaults to
|
||||
:class:`Argument`.
|
||||
"""
|
||||
|
||||
def decorator(f: FC) -> FC:
|
||||
ArgumentClass = attrs.pop("cls", None) or Argument
|
||||
_param_memo(f, ArgumentClass(param_decls, **attrs))
|
||||
return f
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def option(*param_decls: str, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
"""Attaches an option to the command. All positional arguments are
|
||||
passed as parameter declarations to :class:`Option`; all keyword
|
||||
arguments are forwarded unchanged (except ``cls``).
|
||||
This is equivalent to creating an :class:`Option` instance manually
|
||||
and attaching it to the :attr:`Command.params` list.
|
||||
|
||||
:param cls: the option class to instantiate. This defaults to
|
||||
:class:`Option`.
|
||||
"""
|
||||
|
||||
def decorator(f: FC) -> FC:
|
||||
# Issue 926, copy attrs, so pre-defined options can re-use the same cls=
|
||||
option_attrs = attrs.copy()
|
||||
OptionClass = option_attrs.pop("cls", None) or Option
|
||||
_param_memo(f, OptionClass(param_decls, **option_attrs))
|
||||
return f
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def confirmation_option(*param_decls: str, **kwargs: t.Any) -> t.Callable[[FC], FC]:
|
||||
"""Add a ``--yes`` option which shows a prompt before continuing if
|
||||
not passed. If the prompt is declined, the program will exit.
|
||||
|
||||
:param param_decls: One or more option names. Defaults to the single
|
||||
value ``"--yes"``.
|
||||
:param kwargs: Extra arguments are passed to :func:`option`.
|
||||
"""
|
||||
|
||||
def callback(ctx: Context, param: Parameter, value: bool) -> None:
|
||||
if not value:
|
||||
ctx.abort()
|
||||
|
||||
if not param_decls:
|
||||
param_decls = ("--yes",)
|
||||
|
||||
kwargs.setdefault("is_flag", True)
|
||||
kwargs.setdefault("callback", callback)
|
||||
kwargs.setdefault("expose_value", False)
|
||||
kwargs.setdefault("prompt", "Do you want to continue?")
|
||||
kwargs.setdefault("help", "Confirm the action without prompting.")
|
||||
return option(*param_decls, **kwargs)
|
||||
|
||||
|
||||
def password_option(*param_decls: str, **kwargs: t.Any) -> t.Callable[[FC], FC]:
|
||||
"""Add a ``--password`` option which prompts for a password, hiding
|
||||
input and asking to enter the value again for confirmation.
|
||||
|
||||
:param param_decls: One or more option names. Defaults to the single
|
||||
value ``"--password"``.
|
||||
:param kwargs: Extra arguments are passed to :func:`option`.
|
||||
"""
|
||||
if not param_decls:
|
||||
param_decls = ("--password",)
|
||||
|
||||
kwargs.setdefault("prompt", True)
|
||||
kwargs.setdefault("confirmation_prompt", True)
|
||||
kwargs.setdefault("hide_input", True)
|
||||
return option(*param_decls, **kwargs)
|
||||
|
||||
|
||||
def version_option(
|
||||
version: t.Optional[str] = None,
|
||||
*param_decls: str,
|
||||
package_name: t.Optional[str] = None,
|
||||
prog_name: t.Optional[str] = None,
|
||||
message: t.Optional[str] = None,
|
||||
**kwargs: t.Any,
|
||||
) -> t.Callable[[FC], FC]:
|
||||
"""Add a ``--version`` option which immediately prints the version
|
||||
number and exits the program.
|
||||
|
||||
If ``version`` is not provided, Click will try to detect it using
|
||||
:func:`importlib.metadata.version` to get the version for the
|
||||
``package_name``. On Python < 3.8, the ``importlib_metadata``
|
||||
backport must be installed.
|
||||
|
||||
If ``package_name`` is not provided, Click will try to detect it by
|
||||
inspecting the stack frames. This will be used to detect the
|
||||
version, so it must match the name of the installed package.
|
||||
|
||||
:param version: The version number to show. If not provided, Click
|
||||
will try to detect it.
|
||||
:param param_decls: One or more option names. Defaults to the single
|
||||
value ``"--version"``.
|
||||
:param package_name: The package name to detect the version from. If
|
||||
not provided, Click will try to detect it.
|
||||
:param prog_name: The name of the CLI to show in the message. If not
|
||||
provided, it will be detected from the command.
|
||||
:param message: The message to show. The values ``%(prog)s``,
|
||||
``%(package)s``, and ``%(version)s`` are available. Defaults to
|
||||
``"%(prog)s, version %(version)s"``.
|
||||
:param kwargs: Extra arguments are passed to :func:`option`.
|
||||
:raise RuntimeError: ``version`` could not be detected.
|
||||
|
||||
.. versionchanged:: 8.0
|
||||
Add the ``package_name`` parameter, and the ``%(package)s``
|
||||
value for messages.
|
||||
|
||||
.. versionchanged:: 8.0
|
||||
Use :mod:`importlib.metadata` instead of ``pkg_resources``. The
|
||||
version is detected based on the package name, not the entry
|
||||
point name. The Python package name must match the installed
|
||||
package name, or be passed with ``package_name=``.
|
||||
"""
|
||||
if message is None:
|
||||
message = _("%(prog)s, version %(version)s")
|
||||
|
||||
if version is None and package_name is None:
|
||||
frame = inspect.currentframe()
|
||||
f_back = frame.f_back if frame is not None else None
|
||||
f_globals = f_back.f_globals if f_back is not None else None
|
||||
# break reference cycle
|
||||
# https://docs.python.org/3/library/inspect.html#the-interpreter-stack
|
||||
del frame
|
||||
|
||||
if f_globals is not None:
|
||||
package_name = f_globals.get("__name__")
|
||||
|
||||
if package_name == "__main__":
|
||||
package_name = f_globals.get("__package__")
|
||||
|
||||
if package_name:
|
||||
package_name = package_name.partition(".")[0]
|
||||
|
||||
def callback(ctx: Context, param: Parameter, value: bool) -> None:
|
||||
if not value or ctx.resilient_parsing:
|
||||
return
|
||||
|
||||
nonlocal prog_name
|
||||
nonlocal version
|
||||
|
||||
if prog_name is None:
|
||||
prog_name = ctx.find_root().info_name
|
||||
|
||||
if version is None and package_name is not None:
|
||||
metadata: t.Optional[types.ModuleType]
|
||||
|
||||
try:
|
||||
from importlib import metadata # type: ignore
|
||||
except ImportError:
|
||||
# Python < 3.8
|
||||
import importlib_metadata as metadata # type: ignore
|
||||
|
||||
try:
|
||||
version = metadata.version(package_name) # type: ignore
|
||||
except metadata.PackageNotFoundError: # type: ignore
|
||||
raise RuntimeError(
|
||||
f"{package_name!r} is not installed. Try passing"
|
||||
" 'package_name' instead."
|
||||
) from None
|
||||
|
||||
if version is None:
|
||||
raise RuntimeError(
|
||||
f"Could not determine the version for {package_name!r} automatically."
|
||||
)
|
||||
|
||||
echo(
|
||||
t.cast(str, message)
|
||||
% {"prog": prog_name, "package": package_name, "version": version},
|
||||
color=ctx.color,
|
||||
)
|
||||
ctx.exit()
|
||||
|
||||
if not param_decls:
|
||||
param_decls = ("--version",)
|
||||
|
||||
kwargs.setdefault("is_flag", True)
|
||||
kwargs.setdefault("expose_value", False)
|
||||
kwargs.setdefault("is_eager", True)
|
||||
kwargs.setdefault("help", _("Show the version and exit."))
|
||||
kwargs["callback"] = callback
|
||||
return option(*param_decls, **kwargs)
|
||||
|
||||
|
||||
def help_option(*param_decls: str, **kwargs: t.Any) -> t.Callable[[FC], FC]:
|
||||
"""Add a ``--help`` option which immediately prints the help page
|
||||
and exits the program.
|
||||
|
||||
This is usually unnecessary, as the ``--help`` option is added to
|
||||
each command automatically unless ``add_help_option=False`` is
|
||||
passed.
|
||||
|
||||
:param param_decls: One or more option names. Defaults to the single
|
||||
value ``"--help"``.
|
||||
:param kwargs: Extra arguments are passed to :func:`option`.
|
||||
"""
|
||||
|
||||
def callback(ctx: Context, param: Parameter, value: bool) -> None:
|
||||
if not value or ctx.resilient_parsing:
|
||||
return
|
||||
|
||||
echo(ctx.get_help(), color=ctx.color)
|
||||
ctx.exit()
|
||||
|
||||
if not param_decls:
|
||||
param_decls = ("--help",)
|
||||
|
||||
kwargs.setdefault("is_flag", True)
|
||||
kwargs.setdefault("expose_value", False)
|
||||
kwargs.setdefault("is_eager", True)
|
||||
kwargs.setdefault("help", _("Show this message and exit."))
|
||||
kwargs["callback"] = callback
|
||||
return option(*param_decls, **kwargs)
|
@ -0,0 +1,287 @@
|
||||
import os
|
||||
import typing as t
|
||||
from gettext import gettext as _
|
||||
from gettext import ngettext
|
||||
|
||||
from ._compat import get_text_stderr
|
||||
from .utils import echo
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from .core import Context
|
||||
from .core import Parameter
|
||||
|
||||
|
||||
def _join_param_hints(
|
||||
param_hint: t.Optional[t.Union[t.Sequence[str], str]]
|
||||
) -> t.Optional[str]:
|
||||
if param_hint is not None and not isinstance(param_hint, str):
|
||||
return " / ".join(repr(x) for x in param_hint)
|
||||
|
||||
return param_hint
|
||||
|
||||
|
||||
class ClickException(Exception):
|
||||
"""An exception that Click can handle and show to the user."""
|
||||
|
||||
#: The exit code for this exception.
|
||||
exit_code = 1
|
||||
|
||||
def __init__(self, message: str) -> None:
|
||||
super().__init__(message)
|
||||
self.message = message
|
||||
|
||||
def format_message(self) -> str:
|
||||
return self.message
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.message
|
||||
|
||||
def show(self, file: t.Optional[t.IO] = None) -> None:
|
||||
if file is None:
|
||||
file = get_text_stderr()
|
||||
|
||||
echo(_("Error: {message}").format(message=self.format_message()), file=file)
|
||||
|
||||
|
||||
class UsageError(ClickException):
|
||||
"""An internal exception that signals a usage error. This typically
|
||||
aborts any further handling.
|
||||
|
||||
:param message: the error message to display.
|
||||
:param ctx: optionally the context that caused this error. Click will
|
||||
fill in the context automatically in some situations.
|
||||
"""
|
||||
|
||||
exit_code = 2
|
||||
|
||||
def __init__(self, message: str, ctx: t.Optional["Context"] = None) -> None:
|
||||
super().__init__(message)
|
||||
self.ctx = ctx
|
||||
self.cmd = self.ctx.command if self.ctx else None
|
||||
|
||||
def show(self, file: t.Optional[t.IO] = None) -> None:
|
||||
if file is None:
|
||||
file = get_text_stderr()
|
||||
color = None
|
||||
hint = ""
|
||||
if (
|
||||
self.ctx is not None
|
||||
and self.ctx.command.get_help_option(self.ctx) is not None
|
||||
):
|
||||
hint = _("Try '{command} {option}' for help.").format(
|
||||
command=self.ctx.command_path, option=self.ctx.help_option_names[0]
|
||||
)
|
||||
hint = f"{hint}\n"
|
||||
if self.ctx is not None:
|
||||
color = self.ctx.color
|
||||
echo(f"{self.ctx.get_usage()}\n{hint}", file=file, color=color)
|
||||
echo(
|
||||
_("Error: {message}").format(message=self.format_message()),
|
||||
file=file,
|
||||
color=color,
|
||||
)
|
||||
|
||||
|
||||
class BadParameter(UsageError):
|
||||
"""An exception that formats out a standardized error message for a
|
||||
bad parameter. This is useful when thrown from a callback or type as
|
||||
Click will attach contextual information to it (for instance, which
|
||||
parameter it is).
|
||||
|
||||
.. versionadded:: 2.0
|
||||
|
||||
:param param: the parameter object that caused this error. This can
|
||||
be left out, and Click will attach this info itself
|
||||
if possible.
|
||||
:param param_hint: a string that shows up as parameter name. This
|
||||
can be used as alternative to `param` in cases
|
||||
where custom validation should happen. If it is
|
||||
a string it's used as such, if it's a list then
|
||||
each item is quoted and separated.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
ctx: t.Optional["Context"] = None,
|
||||
param: t.Optional["Parameter"] = None,
|
||||
param_hint: t.Optional[str] = None,
|
||||
) -> None:
|
||||
super().__init__(message, ctx)
|
||||
self.param = param
|
||||
self.param_hint = param_hint
|
||||
|
||||
def format_message(self) -> str:
|
||||
if self.param_hint is not None:
|
||||
param_hint = self.param_hint
|
||||
elif self.param is not None:
|
||||
param_hint = self.param.get_error_hint(self.ctx) # type: ignore
|
||||
else:
|
||||
return _("Invalid value: {message}").format(message=self.message)
|
||||
|
||||
return _("Invalid value for {param_hint}: {message}").format(
|
||||
param_hint=_join_param_hints(param_hint), message=self.message
|
||||
)
|
||||
|
||||
|
||||
class MissingParameter(BadParameter):
|
||||
"""Raised if click required an option or argument but it was not
|
||||
provided when invoking the script.
|
||||
|
||||
.. versionadded:: 4.0
|
||||
|
||||
:param param_type: a string that indicates the type of the parameter.
|
||||
The default is to inherit the parameter type from
|
||||
the given `param`. Valid values are ``'parameter'``,
|
||||
``'option'`` or ``'argument'``.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: t.Optional[str] = None,
|
||||
ctx: t.Optional["Context"] = None,
|
||||
param: t.Optional["Parameter"] = None,
|
||||
param_hint: t.Optional[str] = None,
|
||||
param_type: t.Optional[str] = None,
|
||||
) -> None:
|
||||
super().__init__(message or "", ctx, param, param_hint)
|
||||
self.param_type = param_type
|
||||
|
||||
def format_message(self) -> str:
|
||||
if self.param_hint is not None:
|
||||
param_hint: t.Optional[str] = self.param_hint
|
||||
elif self.param is not None:
|
||||
param_hint = self.param.get_error_hint(self.ctx) # type: ignore
|
||||
else:
|
||||
param_hint = None
|
||||
|
||||
param_hint = _join_param_hints(param_hint)
|
||||
param_hint = f" {param_hint}" if param_hint else ""
|
||||
|
||||
param_type = self.param_type
|
||||
if param_type is None and self.param is not None:
|
||||
param_type = self.param.param_type_name
|
||||
|
||||
msg = self.message
|
||||
if self.param is not None:
|
||||
msg_extra = self.param.type.get_missing_message(self.param)
|
||||
if msg_extra:
|
||||
if msg:
|
||||
msg += f". {msg_extra}"
|
||||
else:
|
||||
msg = msg_extra
|
||||
|
||||
msg = f" {msg}" if msg else ""
|
||||
|
||||
# Translate param_type for known types.
|
||||
if param_type == "argument":
|
||||
missing = _("Missing argument")
|
||||
elif param_type == "option":
|
||||
missing = _("Missing option")
|
||||
elif param_type == "parameter":
|
||||
missing = _("Missing parameter")
|
||||
else:
|
||||
missing = _("Missing {param_type}").format(param_type=param_type)
|
||||
|
||||
return f"{missing}{param_hint}.{msg}"
|
||||
|
||||
def __str__(self) -> str:
|
||||
if not self.message:
|
||||
param_name = self.param.name if self.param else None
|
||||
return _("Missing parameter: {param_name}").format(param_name=param_name)
|
||||
else:
|
||||
return self.message
|
||||
|
||||
|
||||
class NoSuchOption(UsageError):
|
||||
"""Raised if click attempted to handle an option that does not
|
||||
exist.
|
||||
|
||||
.. versionadded:: 4.0
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
option_name: str,
|
||||
message: t.Optional[str] = None,
|
||||
possibilities: t.Optional[t.Sequence[str]] = None,
|
||||
ctx: t.Optional["Context"] = None,
|
||||
) -> None:
|
||||
if message is None:
|
||||
message = _("No such option: {name}").format(name=option_name)
|
||||
|
||||
super().__init__(message, ctx)
|
||||
self.option_name = option_name
|
||||
self.possibilities = possibilities
|
||||
|
||||
def format_message(self) -> str:
|
||||
if not self.possibilities:
|
||||
return self.message
|
||||
|
||||
possibility_str = ", ".join(sorted(self.possibilities))
|
||||
suggest = ngettext(
|
||||
"Did you mean {possibility}?",
|
||||
"(Possible options: {possibilities})",
|
||||
len(self.possibilities),
|
||||
).format(possibility=possibility_str, possibilities=possibility_str)
|
||||
return f"{self.message} {suggest}"
|
||||
|
||||
|
||||
class BadOptionUsage(UsageError):
|
||||
"""Raised if an option is generally supplied but the use of the option
|
||||
was incorrect. This is for instance raised if the number of arguments
|
||||
for an option is not correct.
|
||||
|
||||
.. versionadded:: 4.0
|
||||
|
||||
:param option_name: the name of the option being used incorrectly.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, option_name: str, message: str, ctx: t.Optional["Context"] = None
|
||||
) -> None:
|
||||
super().__init__(message, ctx)
|
||||
self.option_name = option_name
|
||||
|
||||
|
||||
class BadArgumentUsage(UsageError):
|
||||
"""Raised if an argument is generally supplied but the use of the argument
|
||||
was incorrect. This is for instance raised if the number of values
|
||||
for an argument is not correct.
|
||||
|
||||
.. versionadded:: 6.0
|
||||
"""
|
||||
|
||||
|
||||
class FileError(ClickException):
|
||||
"""Raised if a file cannot be opened."""
|
||||
|
||||
def __init__(self, filename: str, hint: t.Optional[str] = None) -> None:
|
||||
if hint is None:
|
||||
hint = _("unknown error")
|
||||
|
||||
super().__init__(hint)
|
||||
self.ui_filename = os.fsdecode(filename)
|
||||
self.filename = filename
|
||||
|
||||
def format_message(self) -> str:
|
||||
return _("Could not open file {filename!r}: {message}").format(
|
||||
filename=self.ui_filename, message=self.message
|
||||
)
|
||||
|
||||
|
||||
class Abort(RuntimeError):
|
||||
"""An internal signalling exception that signals Click to abort."""
|
||||
|
||||
|
||||
class Exit(RuntimeError):
|
||||
"""An exception that indicates that the application should exit with some
|
||||
status code.
|
||||
|
||||
:param code: the status code to exit with.
|
||||
"""
|
||||
|
||||
__slots__ = ("exit_code",)
|
||||
|
||||
def __init__(self, code: int = 0) -> None:
|
||||
self.exit_code = code
|
@ -0,0 +1,301 @@
|
||||
import typing as t
|
||||
from contextlib import contextmanager
|
||||
from gettext import gettext as _
|
||||
|
||||
from ._compat import term_len
|
||||
from .parser import split_opt
|
||||
|
||||
# Can force a width. This is used by the test system
|
||||
FORCED_WIDTH: t.Optional[int] = None
|
||||
|
||||
|
||||
def measure_table(rows: t.Iterable[t.Tuple[str, str]]) -> t.Tuple[int, ...]:
|
||||
widths: t.Dict[int, int] = {}
|
||||
|
||||
for row in rows:
|
||||
for idx, col in enumerate(row):
|
||||
widths[idx] = max(widths.get(idx, 0), term_len(col))
|
||||
|
||||
return tuple(y for x, y in sorted(widths.items()))
|
||||
|
||||
|
||||
def iter_rows(
|
||||
rows: t.Iterable[t.Tuple[str, str]], col_count: int
|
||||
) -> t.Iterator[t.Tuple[str, ...]]:
|
||||
for row in rows:
|
||||
yield row + ("",) * (col_count - len(row))
|
||||
|
||||
|
||||
def wrap_text(
|
||||
text: str,
|
||||
width: int = 78,
|
||||
initial_indent: str = "",
|
||||
subsequent_indent: str = "",
|
||||
preserve_paragraphs: bool = False,
|
||||
) -> str:
|
||||
"""A helper function that intelligently wraps text. By default, it
|
||||
assumes that it operates on a single paragraph of text but if the
|
||||
`preserve_paragraphs` parameter is provided it will intelligently
|
||||
handle paragraphs (defined by two empty lines).
|
||||
|
||||
If paragraphs are handled, a paragraph can be prefixed with an empty
|
||||
line containing the ``\\b`` character (``\\x08``) to indicate that
|
||||
no rewrapping should happen in that block.
|
||||
|
||||
:param text: the text that should be rewrapped.
|
||||
:param width: the maximum width for the text.
|
||||
:param initial_indent: the initial indent that should be placed on the
|
||||
first line as a string.
|
||||
:param subsequent_indent: the indent string that should be placed on
|
||||
each consecutive line.
|
||||
:param preserve_paragraphs: if this flag is set then the wrapping will
|
||||
intelligently handle paragraphs.
|
||||
"""
|
||||
from ._textwrap import TextWrapper
|
||||
|
||||
text = text.expandtabs()
|
||||
wrapper = TextWrapper(
|
||||
width,
|
||||
initial_indent=initial_indent,
|
||||
subsequent_indent=subsequent_indent,
|
||||
replace_whitespace=False,
|
||||
)
|
||||
if not preserve_paragraphs:
|
||||
return wrapper.fill(text)
|
||||
|
||||
p: t.List[t.Tuple[int, bool, str]] = []
|
||||
buf: t.List[str] = []
|
||||
indent = None
|
||||
|
||||
def _flush_par() -> None:
|
||||
if not buf:
|
||||
return
|
||||
if buf[0].strip() == "\b":
|
||||
p.append((indent or 0, True, "\n".join(buf[1:])))
|
||||
else:
|
||||
p.append((indent or 0, False, " ".join(buf)))
|
||||
del buf[:]
|
||||
|
||||
for line in text.splitlines():
|
||||
if not line:
|
||||
_flush_par()
|
||||
indent = None
|
||||
else:
|
||||
if indent is None:
|
||||
orig_len = term_len(line)
|
||||
line = line.lstrip()
|
||||
indent = orig_len - term_len(line)
|
||||
buf.append(line)
|
||||
_flush_par()
|
||||
|
||||
rv = []
|
||||
for indent, raw, text in p:
|
||||
with wrapper.extra_indent(" " * indent):
|
||||
if raw:
|
||||
rv.append(wrapper.indent_only(text))
|
||||
else:
|
||||
rv.append(wrapper.fill(text))
|
||||
|
||||
return "\n\n".join(rv)
|
||||
|
||||
|
||||
class HelpFormatter:
|
||||
"""This class helps with formatting text-based help pages. It's
|
||||
usually just needed for very special internal cases, but it's also
|
||||
exposed so that developers can write their own fancy outputs.
|
||||
|
||||
At present, it always writes into memory.
|
||||
|
||||
:param indent_increment: the additional increment for each level.
|
||||
:param width: the width for the text. This defaults to the terminal
|
||||
width clamped to a maximum of 78.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
indent_increment: int = 2,
|
||||
width: t.Optional[int] = None,
|
||||
max_width: t.Optional[int] = None,
|
||||
) -> None:
|
||||
import shutil
|
||||
|
||||
self.indent_increment = indent_increment
|
||||
if max_width is None:
|
||||
max_width = 80
|
||||
if width is None:
|
||||
width = FORCED_WIDTH
|
||||
if width is None:
|
||||
width = max(min(shutil.get_terminal_size().columns, max_width) - 2, 50)
|
||||
self.width = width
|
||||
self.current_indent = 0
|
||||
self.buffer: t.List[str] = []
|
||||
|
||||
def write(self, string: str) -> None:
|
||||
"""Writes a unicode string into the internal buffer."""
|
||||
self.buffer.append(string)
|
||||
|
||||
def indent(self) -> None:
|
||||
"""Increases the indentation."""
|
||||
self.current_indent += self.indent_increment
|
||||
|
||||
def dedent(self) -> None:
|
||||
"""Decreases the indentation."""
|
||||
self.current_indent -= self.indent_increment
|
||||
|
||||
def write_usage(
|
||||
self, prog: str, args: str = "", prefix: t.Optional[str] = None
|
||||
) -> None:
|
||||
"""Writes a usage line into the buffer.
|
||||
|
||||
:param prog: the program name.
|
||||
:param args: whitespace separated list of arguments.
|
||||
:param prefix: The prefix for the first line. Defaults to
|
||||
``"Usage: "``.
|
||||
"""
|
||||
if prefix is None:
|
||||
prefix = f"{_('Usage:')} "
|
||||
|
||||
usage_prefix = f"{prefix:>{self.current_indent}}{prog} "
|
||||
text_width = self.width - self.current_indent
|
||||
|
||||
if text_width >= (term_len(usage_prefix) + 20):
|
||||
# The arguments will fit to the right of the prefix.
|
||||
indent = " " * term_len(usage_prefix)
|
||||
self.write(
|
||||
wrap_text(
|
||||
args,
|
||||
text_width,
|
||||
initial_indent=usage_prefix,
|
||||
subsequent_indent=indent,
|
||||
)
|
||||
)
|
||||
else:
|
||||
# The prefix is too long, put the arguments on the next line.
|
||||
self.write(usage_prefix)
|
||||
self.write("\n")
|
||||
indent = " " * (max(self.current_indent, term_len(prefix)) + 4)
|
||||
self.write(
|
||||
wrap_text(
|
||||
args, text_width, initial_indent=indent, subsequent_indent=indent
|
||||
)
|
||||
)
|
||||
|
||||
self.write("\n")
|
||||
|
||||
def write_heading(self, heading: str) -> None:
|
||||
"""Writes a heading into the buffer."""
|
||||
self.write(f"{'':>{self.current_indent}}{heading}:\n")
|
||||
|
||||
def write_paragraph(self) -> None:
|
||||
"""Writes a paragraph into the buffer."""
|
||||
if self.buffer:
|
||||
self.write("\n")
|
||||
|
||||
def write_text(self, text: str) -> None:
|
||||
"""Writes re-indented text into the buffer. This rewraps and
|
||||
preserves paragraphs.
|
||||
"""
|
||||
indent = " " * self.current_indent
|
||||
self.write(
|
||||
wrap_text(
|
||||
text,
|
||||
self.width,
|
||||
initial_indent=indent,
|
||||
subsequent_indent=indent,
|
||||
preserve_paragraphs=True,
|
||||
)
|
||||
)
|
||||
self.write("\n")
|
||||
|
||||
def write_dl(
|
||||
self,
|
||||
rows: t.Sequence[t.Tuple[str, str]],
|
||||
col_max: int = 30,
|
||||
col_spacing: int = 2,
|
||||
) -> None:
|
||||
"""Writes a definition list into the buffer. This is how options
|
||||
and commands are usually formatted.
|
||||
|
||||
:param rows: a list of two item tuples for the terms and values.
|
||||
:param col_max: the maximum width of the first column.
|
||||
:param col_spacing: the number of spaces between the first and
|
||||
second column.
|
||||
"""
|
||||
rows = list(rows)
|
||||
widths = measure_table(rows)
|
||||
if len(widths) != 2:
|
||||
raise TypeError("Expected two columns for definition list")
|
||||
|
||||
first_col = min(widths[0], col_max) + col_spacing
|
||||
|
||||
for first, second in iter_rows(rows, len(widths)):
|
||||
self.write(f"{'':>{self.current_indent}}{first}")
|
||||
if not second:
|
||||
self.write("\n")
|
||||
continue
|
||||
if term_len(first) <= first_col - col_spacing:
|
||||
self.write(" " * (first_col - term_len(first)))
|
||||
else:
|
||||
self.write("\n")
|
||||
self.write(" " * (first_col + self.current_indent))
|
||||
|
||||
text_width = max(self.width - first_col - 2, 10)
|
||||
wrapped_text = wrap_text(second, text_width, preserve_paragraphs=True)
|
||||
lines = wrapped_text.splitlines()
|
||||
|
||||
if lines:
|
||||
self.write(f"{lines[0]}\n")
|
||||
|
||||
for line in lines[1:]:
|
||||
self.write(f"{'':>{first_col + self.current_indent}}{line}\n")
|
||||
else:
|
||||
self.write("\n")
|
||||
|
||||
@contextmanager
|
||||
def section(self, name: str) -> t.Iterator[None]:
|
||||
"""Helpful context manager that writes a paragraph, a heading,
|
||||
and the indents.
|
||||
|
||||
:param name: the section name that is written as heading.
|
||||
"""
|
||||
self.write_paragraph()
|
||||
self.write_heading(name)
|
||||
self.indent()
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
self.dedent()
|
||||
|
||||
@contextmanager
|
||||
def indentation(self) -> t.Iterator[None]:
|
||||
"""A context manager that increases the indentation."""
|
||||
self.indent()
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
self.dedent()
|
||||
|
||||
def getvalue(self) -> str:
|
||||
"""Returns the buffer contents."""
|
||||
return "".join(self.buffer)
|
||||
|
||||
|
||||
def join_options(options: t.Sequence[str]) -> t.Tuple[str, bool]:
|
||||
"""Given a list of option strings this joins them in the most appropriate
|
||||
way and returns them in the form ``(formatted_string,
|
||||
any_prefix_is_slash)`` where the second item in the tuple is a flag that
|
||||
indicates if any of the option prefixes was a slash.
|
||||
"""
|
||||
rv = []
|
||||
any_prefix_is_slash = False
|
||||
|
||||
for opt in options:
|
||||
prefix = split_opt(opt)[0]
|
||||
|
||||
if prefix == "/":
|
||||
any_prefix_is_slash = True
|
||||
|
||||
rv.append((len(prefix), opt))
|
||||
|
||||
rv.sort(key=lambda x: x[0])
|
||||
return ", ".join(x[1] for x in rv), any_prefix_is_slash
|
@ -0,0 +1,68 @@
|
||||
import typing as t
|
||||
from threading import local
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
import typing_extensions as te
|
||||
from .core import Context
|
||||
|
||||
_local = local()
|
||||
|
||||
|
||||
@t.overload
|
||||
def get_current_context(silent: "te.Literal[False]" = False) -> "Context":
|
||||
...
|
||||
|
||||
|
||||
@t.overload
|
||||
def get_current_context(silent: bool = ...) -> t.Optional["Context"]:
|
||||
...
|
||||
|
||||
|
||||
def get_current_context(silent: bool = False) -> t.Optional["Context"]:
|
||||
"""Returns the current click context. This can be used as a way to
|
||||
access the current context object from anywhere. This is a more implicit
|
||||
alternative to the :func:`pass_context` decorator. This function is
|
||||
primarily useful for helpers such as :func:`echo` which might be
|
||||
interested in changing its behavior based on the current context.
|
||||
|
||||
To push the current context, :meth:`Context.scope` can be used.
|
||||
|
||||
.. versionadded:: 5.0
|
||||
|
||||
:param silent: if set to `True` the return value is `None` if no context
|
||||
is available. The default behavior is to raise a
|
||||
:exc:`RuntimeError`.
|
||||
"""
|
||||
try:
|
||||
return t.cast("Context", _local.stack[-1])
|
||||
except (AttributeError, IndexError) as e:
|
||||
if not silent:
|
||||
raise RuntimeError("There is no active click context.") from e
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def push_context(ctx: "Context") -> None:
|
||||
"""Pushes a new context to the current stack."""
|
||||
_local.__dict__.setdefault("stack", []).append(ctx)
|
||||
|
||||
|
||||
def pop_context() -> None:
|
||||
"""Removes the top level from the stack."""
|
||||
_local.stack.pop()
|
||||
|
||||
|
||||
def resolve_color_default(color: t.Optional[bool] = None) -> t.Optional[bool]:
|
||||
"""Internal helper to get the default value of the color flag. If a
|
||||
value is passed it's returned unchanged, otherwise it's looked up from
|
||||
the current context.
|
||||
"""
|
||||
if color is not None:
|
||||
return color
|
||||
|
||||
ctx = get_current_context(silent=True)
|
||||
|
||||
if ctx is not None:
|
||||
return ctx.color
|
||||
|
||||
return None
|
@ -0,0 +1,529 @@
|
||||
"""
|
||||
This module started out as largely a copy paste from the stdlib's
|
||||
optparse module with the features removed that we do not need from
|
||||
optparse because we implement them in Click on a higher level (for
|
||||
instance type handling, help formatting and a lot more).
|
||||
|
||||
The plan is to remove more and more from here over time.
|
||||
|
||||
The reason this is a different module and not optparse from the stdlib
|
||||
is that there are differences in 2.x and 3.x about the error messages
|
||||
generated and optparse in the stdlib uses gettext for no good reason
|
||||
and might cause us issues.
|
||||
|
||||
Click uses parts of optparse written by Gregory P. Ward and maintained
|
||||
by the Python Software Foundation. This is limited to code in parser.py.
|
||||
|
||||
Copyright 2001-2006 Gregory P. Ward. All rights reserved.
|
||||
Copyright 2002-2006 Python Software Foundation. All rights reserved.
|
||||
"""
|
||||
# This code uses parts of optparse written by Gregory P. Ward and
|
||||
# maintained by the Python Software Foundation.
|
||||
# Copyright 2001-2006 Gregory P. Ward
|
||||
# Copyright 2002-2006 Python Software Foundation
|
||||
import typing as t
|
||||
from collections import deque
|
||||
from gettext import gettext as _
|
||||
from gettext import ngettext
|
||||
|
||||
from .exceptions import BadArgumentUsage
|
||||
from .exceptions import BadOptionUsage
|
||||
from .exceptions import NoSuchOption
|
||||
from .exceptions import UsageError
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
import typing_extensions as te
|
||||
from .core import Argument as CoreArgument
|
||||
from .core import Context
|
||||
from .core import Option as CoreOption
|
||||
from .core import Parameter as CoreParameter
|
||||
|
||||
V = t.TypeVar("V")
|
||||
|
||||
# Sentinel value that indicates an option was passed as a flag without a
|
||||
# value but is not a flag option. Option.consume_value uses this to
|
||||
# prompt or use the flag_value.
|
||||
_flag_needs_value = object()
|
||||
|
||||
|
||||
def _unpack_args(
|
||||
args: t.Sequence[str], nargs_spec: t.Sequence[int]
|
||||
) -> t.Tuple[t.Sequence[t.Union[str, t.Sequence[t.Optional[str]], None]], t.List[str]]:
|
||||
"""Given an iterable of arguments and an iterable of nargs specifications,
|
||||
it returns a tuple with all the unpacked arguments at the first index
|
||||
and all remaining arguments as the second.
|
||||
|
||||
The nargs specification is the number of arguments that should be consumed
|
||||
or `-1` to indicate that this position should eat up all the remainders.
|
||||
|
||||
Missing items are filled with `None`.
|
||||
"""
|
||||
args = deque(args)
|
||||
nargs_spec = deque(nargs_spec)
|
||||
rv: t.List[t.Union[str, t.Tuple[t.Optional[str], ...], None]] = []
|
||||
spos: t.Optional[int] = None
|
||||
|
||||
def _fetch(c: "te.Deque[V]") -> t.Optional[V]:
|
||||
try:
|
||||
if spos is None:
|
||||
return c.popleft()
|
||||
else:
|
||||
return c.pop()
|
||||
except IndexError:
|
||||
return None
|
||||
|
||||
while nargs_spec:
|
||||
nargs = _fetch(nargs_spec)
|
||||
|
||||
if nargs is None:
|
||||
continue
|
||||
|
||||
if nargs == 1:
|
||||
rv.append(_fetch(args))
|
||||
elif nargs > 1:
|
||||
x = [_fetch(args) for _ in range(nargs)]
|
||||
|
||||
# If we're reversed, we're pulling in the arguments in reverse,
|
||||
# so we need to turn them around.
|
||||
if spos is not None:
|
||||
x.reverse()
|
||||
|
||||
rv.append(tuple(x))
|
||||
elif nargs < 0:
|
||||
if spos is not None:
|
||||
raise TypeError("Cannot have two nargs < 0")
|
||||
|
||||
spos = len(rv)
|
||||
rv.append(None)
|
||||
|
||||
# spos is the position of the wildcard (star). If it's not `None`,
|
||||
# we fill it with the remainder.
|
||||
if spos is not None:
|
||||
rv[spos] = tuple(args)
|
||||
args = []
|
||||
rv[spos + 1 :] = reversed(rv[spos + 1 :])
|
||||
|
||||
return tuple(rv), list(args)
|
||||
|
||||
|
||||
def split_opt(opt: str) -> t.Tuple[str, str]:
|
||||
first = opt[:1]
|
||||
if first.isalnum():
|
||||
return "", opt
|
||||
if opt[1:2] == first:
|
||||
return opt[:2], opt[2:]
|
||||
return first, opt[1:]
|
||||
|
||||
|
||||
def normalize_opt(opt: str, ctx: t.Optional["Context"]) -> str:
|
||||
if ctx is None or ctx.token_normalize_func is None:
|
||||
return opt
|
||||
prefix, opt = split_opt(opt)
|
||||
return f"{prefix}{ctx.token_normalize_func(opt)}"
|
||||
|
||||
|
||||
def split_arg_string(string: str) -> t.List[str]:
|
||||
"""Split an argument string as with :func:`shlex.split`, but don't
|
||||
fail if the string is incomplete. Ignores a missing closing quote or
|
||||
incomplete escape sequence and uses the partial token as-is.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
split_arg_string("example 'my file")
|
||||
["example", "my file"]
|
||||
|
||||
split_arg_string("example my\\")
|
||||
["example", "my"]
|
||||
|
||||
:param string: String to split.
|
||||
"""
|
||||
import shlex
|
||||
|
||||
lex = shlex.shlex(string, posix=True)
|
||||
lex.whitespace_split = True
|
||||
lex.commenters = ""
|
||||
out = []
|
||||
|
||||
try:
|
||||
for token in lex:
|
||||
out.append(token)
|
||||
except ValueError:
|
||||
# Raised when end-of-string is reached in an invalid state. Use
|
||||
# the partial token as-is. The quote or escape character is in
|
||||
# lex.state, not lex.token.
|
||||
out.append(lex.token)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class Option:
|
||||
def __init__(
|
||||
self,
|
||||
obj: "CoreOption",
|
||||
opts: t.Sequence[str],
|
||||
dest: t.Optional[str],
|
||||
action: t.Optional[str] = None,
|
||||
nargs: int = 1,
|
||||
const: t.Optional[t.Any] = None,
|
||||
):
|
||||
self._short_opts = []
|
||||
self._long_opts = []
|
||||
self.prefixes = set()
|
||||
|
||||
for opt in opts:
|
||||
prefix, value = split_opt(opt)
|
||||
if not prefix:
|
||||
raise ValueError(f"Invalid start character for option ({opt})")
|
||||
self.prefixes.add(prefix[0])
|
||||
if len(prefix) == 1 and len(value) == 1:
|
||||
self._short_opts.append(opt)
|
||||
else:
|
||||
self._long_opts.append(opt)
|
||||
self.prefixes.add(prefix)
|
||||
|
||||
if action is None:
|
||||
action = "store"
|
||||
|
||||
self.dest = dest
|
||||
self.action = action
|
||||
self.nargs = nargs
|
||||
self.const = const
|
||||
self.obj = obj
|
||||
|
||||
@property
|
||||
def takes_value(self) -> bool:
|
||||
return self.action in ("store", "append")
|
||||
|
||||
def process(self, value: str, state: "ParsingState") -> None:
|
||||
if self.action == "store":
|
||||
state.opts[self.dest] = value # type: ignore
|
||||
elif self.action == "store_const":
|
||||
state.opts[self.dest] = self.const # type: ignore
|
||||
elif self.action == "append":
|
||||
state.opts.setdefault(self.dest, []).append(value) # type: ignore
|
||||
elif self.action == "append_const":
|
||||
state.opts.setdefault(self.dest, []).append(self.const) # type: ignore
|
||||
elif self.action == "count":
|
||||
state.opts[self.dest] = state.opts.get(self.dest, 0) + 1 # type: ignore
|
||||
else:
|
||||
raise ValueError(f"unknown action '{self.action}'")
|
||||
state.order.append(self.obj)
|
||||
|
||||
|
||||
class Argument:
|
||||
def __init__(self, obj: "CoreArgument", dest: t.Optional[str], nargs: int = 1):
|
||||
self.dest = dest
|
||||
self.nargs = nargs
|
||||
self.obj = obj
|
||||
|
||||
def process(
|
||||
self,
|
||||
value: t.Union[t.Optional[str], t.Sequence[t.Optional[str]]],
|
||||
state: "ParsingState",
|
||||
) -> None:
|
||||
if self.nargs > 1:
|
||||
assert value is not None
|
||||
holes = sum(1 for x in value if x is None)
|
||||
if holes == len(value):
|
||||
value = None
|
||||
elif holes != 0:
|
||||
raise BadArgumentUsage(
|
||||
_("Argument {name!r} takes {nargs} values.").format(
|
||||
name=self.dest, nargs=self.nargs
|
||||
)
|
||||
)
|
||||
|
||||
if self.nargs == -1 and self.obj.envvar is not None and value == ():
|
||||
# Replace empty tuple with None so that a value from the
|
||||
# environment may be tried.
|
||||
value = None
|
||||
|
||||
state.opts[self.dest] = value # type: ignore
|
||||
state.order.append(self.obj)
|
||||
|
||||
|
||||
class ParsingState:
|
||||
def __init__(self, rargs: t.List[str]) -> None:
|
||||
self.opts: t.Dict[str, t.Any] = {}
|
||||
self.largs: t.List[str] = []
|
||||
self.rargs = rargs
|
||||
self.order: t.List["CoreParameter"] = []
|
||||
|
||||
|
||||
class OptionParser:
|
||||
"""The option parser is an internal class that is ultimately used to
|
||||
parse options and arguments. It's modelled after optparse and brings
|
||||
a similar but vastly simplified API. It should generally not be used
|
||||
directly as the high level Click classes wrap it for you.
|
||||
|
||||
It's not nearly as extensible as optparse or argparse as it does not
|
||||
implement features that are implemented on a higher level (such as
|
||||
types or defaults).
|
||||
|
||||
:param ctx: optionally the :class:`~click.Context` where this parser
|
||||
should go with.
|
||||
"""
|
||||
|
||||
def __init__(self, ctx: t.Optional["Context"] = None) -> None:
|
||||
#: The :class:`~click.Context` for this parser. This might be
|
||||
#: `None` for some advanced use cases.
|
||||
self.ctx = ctx
|
||||
#: This controls how the parser deals with interspersed arguments.
|
||||
#: If this is set to `False`, the parser will stop on the first
|
||||
#: non-option. Click uses this to implement nested subcommands
|
||||
#: safely.
|
||||
self.allow_interspersed_args = True
|
||||
#: This tells the parser how to deal with unknown options. By
|
||||
#: default it will error out (which is sensible), but there is a
|
||||
#: second mode where it will ignore it and continue processing
|
||||
#: after shifting all the unknown options into the resulting args.
|
||||
self.ignore_unknown_options = False
|
||||
|
||||
if ctx is not None:
|
||||
self.allow_interspersed_args = ctx.allow_interspersed_args
|
||||
self.ignore_unknown_options = ctx.ignore_unknown_options
|
||||
|
||||
self._short_opt: t.Dict[str, Option] = {}
|
||||
self._long_opt: t.Dict[str, Option] = {}
|
||||
self._opt_prefixes = {"-", "--"}
|
||||
self._args: t.List[Argument] = []
|
||||
|
||||
def add_option(
|
||||
self,
|
||||
obj: "CoreOption",
|
||||
opts: t.Sequence[str],
|
||||
dest: t.Optional[str],
|
||||
action: t.Optional[str] = None,
|
||||
nargs: int = 1,
|
||||
const: t.Optional[t.Any] = None,
|
||||
) -> None:
|
||||
"""Adds a new option named `dest` to the parser. The destination
|
||||
is not inferred (unlike with optparse) and needs to be explicitly
|
||||
provided. Action can be any of ``store``, ``store_const``,
|
||||
``append``, ``append_const`` or ``count``.
|
||||
|
||||
The `obj` can be used to identify the option in the order list
|
||||
that is returned from the parser.
|
||||
"""
|
||||
opts = [normalize_opt(opt, self.ctx) for opt in opts]
|
||||
option = Option(obj, opts, dest, action=action, nargs=nargs, const=const)
|
||||
self._opt_prefixes.update(option.prefixes)
|
||||
for opt in option._short_opts:
|
||||
self._short_opt[opt] = option
|
||||
for opt in option._long_opts:
|
||||
self._long_opt[opt] = option
|
||||
|
||||
def add_argument(
|
||||
self, obj: "CoreArgument", dest: t.Optional[str], nargs: int = 1
|
||||
) -> None:
|
||||
"""Adds a positional argument named `dest` to the parser.
|
||||
|
||||
The `obj` can be used to identify the option in the order list
|
||||
that is returned from the parser.
|
||||
"""
|
||||
self._args.append(Argument(obj, dest=dest, nargs=nargs))
|
||||
|
||||
def parse_args(
|
||||
self, args: t.List[str]
|
||||
) -> t.Tuple[t.Dict[str, t.Any], t.List[str], t.List["CoreParameter"]]:
|
||||
"""Parses positional arguments and returns ``(values, args, order)``
|
||||
for the parsed options and arguments as well as the leftover
|
||||
arguments if there are any. The order is a list of objects as they
|
||||
appear on the command line. If arguments appear multiple times they
|
||||
will be memorized multiple times as well.
|
||||
"""
|
||||
state = ParsingState(args)
|
||||
try:
|
||||
self._process_args_for_options(state)
|
||||
self._process_args_for_args(state)
|
||||
except UsageError:
|
||||
if self.ctx is None or not self.ctx.resilient_parsing:
|
||||
raise
|
||||
return state.opts, state.largs, state.order
|
||||
|
||||
def _process_args_for_args(self, state: ParsingState) -> None:
|
||||
pargs, args = _unpack_args(
|
||||
state.largs + state.rargs, [x.nargs for x in self._args]
|
||||
)
|
||||
|
||||
for idx, arg in enumerate(self._args):
|
||||
arg.process(pargs[idx], state)
|
||||
|
||||
state.largs = args
|
||||
state.rargs = []
|
||||
|
||||
def _process_args_for_options(self, state: ParsingState) -> None:
|
||||
while state.rargs:
|
||||
arg = state.rargs.pop(0)
|
||||
arglen = len(arg)
|
||||
# Double dashes always handled explicitly regardless of what
|
||||
# prefixes are valid.
|
||||
if arg == "--":
|
||||
return
|
||||
elif arg[:1] in self._opt_prefixes and arglen > 1:
|
||||
self._process_opts(arg, state)
|
||||
elif self.allow_interspersed_args:
|
||||
state.largs.append(arg)
|
||||
else:
|
||||
state.rargs.insert(0, arg)
|
||||
return
|
||||
|
||||
# Say this is the original argument list:
|
||||
# [arg0, arg1, ..., arg(i-1), arg(i), arg(i+1), ..., arg(N-1)]
|
||||
# ^
|
||||
# (we are about to process arg(i)).
|
||||
#
|
||||
# Then rargs is [arg(i), ..., arg(N-1)] and largs is a *subset* of
|
||||
# [arg0, ..., arg(i-1)] (any options and their arguments will have
|
||||
# been removed from largs).
|
||||
#
|
||||
# The while loop will usually consume 1 or more arguments per pass.
|
||||
# If it consumes 1 (eg. arg is an option that takes no arguments),
|
||||
# then after _process_arg() is done the situation is:
|
||||
#
|
||||
# largs = subset of [arg0, ..., arg(i)]
|
||||
# rargs = [arg(i+1), ..., arg(N-1)]
|
||||
#
|
||||
# If allow_interspersed_args is false, largs will always be
|
||||
# *empty* -- still a subset of [arg0, ..., arg(i-1)], but
|
||||
# not a very interesting subset!
|
||||
|
||||
def _match_long_opt(
|
||||
self, opt: str, explicit_value: t.Optional[str], state: ParsingState
|
||||
) -> None:
|
||||
if opt not in self._long_opt:
|
||||
from difflib import get_close_matches
|
||||
|
||||
possibilities = get_close_matches(opt, self._long_opt)
|
||||
raise NoSuchOption(opt, possibilities=possibilities, ctx=self.ctx)
|
||||
|
||||
option = self._long_opt[opt]
|
||||
if option.takes_value:
|
||||
# At this point it's safe to modify rargs by injecting the
|
||||
# explicit value, because no exception is raised in this
|
||||
# branch. This means that the inserted value will be fully
|
||||
# consumed.
|
||||
if explicit_value is not None:
|
||||
state.rargs.insert(0, explicit_value)
|
||||
|
||||
value = self._get_value_from_state(opt, option, state)
|
||||
|
||||
elif explicit_value is not None:
|
||||
raise BadOptionUsage(
|
||||
opt, _("Option {name!r} does not take a value.").format(name=opt)
|
||||
)
|
||||
|
||||
else:
|
||||
value = None
|
||||
|
||||
option.process(value, state)
|
||||
|
||||
def _match_short_opt(self, arg: str, state: ParsingState) -> None:
|
||||
stop = False
|
||||
i = 1
|
||||
prefix = arg[0]
|
||||
unknown_options = []
|
||||
|
||||
for ch in arg[1:]:
|
||||
opt = normalize_opt(f"{prefix}{ch}", self.ctx)
|
||||
option = self._short_opt.get(opt)
|
||||
i += 1
|
||||
|
||||
if not option:
|
||||
if self.ignore_unknown_options:
|
||||
unknown_options.append(ch)
|
||||
continue
|
||||
raise NoSuchOption(opt, ctx=self.ctx)
|
||||
if option.takes_value:
|
||||
# Any characters left in arg? Pretend they're the
|
||||
# next arg, and stop consuming characters of arg.
|
||||
if i < len(arg):
|
||||
state.rargs.insert(0, arg[i:])
|
||||
stop = True
|
||||
|
||||
value = self._get_value_from_state(opt, option, state)
|
||||
|
||||
else:
|
||||
value = None
|
||||
|
||||
option.process(value, state)
|
||||
|
||||
if stop:
|
||||
break
|
||||
|
||||
# If we got any unknown options we re-combinate the string of the
|
||||
# remaining options and re-attach the prefix, then report that
|
||||
# to the state as new larg. This way there is basic combinatorics
|
||||
# that can be achieved while still ignoring unknown arguments.
|
||||
if self.ignore_unknown_options and unknown_options:
|
||||
state.largs.append(f"{prefix}{''.join(unknown_options)}")
|
||||
|
||||
def _get_value_from_state(
|
||||
self, option_name: str, option: Option, state: ParsingState
|
||||
) -> t.Any:
|
||||
nargs = option.nargs
|
||||
|
||||
if len(state.rargs) < nargs:
|
||||
if option.obj._flag_needs_value:
|
||||
# Option allows omitting the value.
|
||||
value = _flag_needs_value
|
||||
else:
|
||||
raise BadOptionUsage(
|
||||
option_name,
|
||||
ngettext(
|
||||
"Option {name!r} requires an argument.",
|
||||
"Option {name!r} requires {nargs} arguments.",
|
||||
nargs,
|
||||
).format(name=option_name, nargs=nargs),
|
||||
)
|
||||
elif nargs == 1:
|
||||
next_rarg = state.rargs[0]
|
||||
|
||||
if (
|
||||
option.obj._flag_needs_value
|
||||
and isinstance(next_rarg, str)
|
||||
and next_rarg[:1] in self._opt_prefixes
|
||||
and len(next_rarg) > 1
|
||||
):
|
||||
# The next arg looks like the start of an option, don't
|
||||
# use it as the value if omitting the value is allowed.
|
||||
value = _flag_needs_value
|
||||
else:
|
||||
value = state.rargs.pop(0)
|
||||
else:
|
||||
value = tuple(state.rargs[:nargs])
|
||||
del state.rargs[:nargs]
|
||||
|
||||
return value
|
||||
|
||||
def _process_opts(self, arg: str, state: ParsingState) -> None:
|
||||
explicit_value = None
|
||||
# Long option handling happens in two parts. The first part is
|
||||
# supporting explicitly attached values. In any case, we will try
|
||||
# to long match the option first.
|
||||
if "=" in arg:
|
||||
long_opt, explicit_value = arg.split("=", 1)
|
||||
else:
|
||||
long_opt = arg
|
||||
norm_long_opt = normalize_opt(long_opt, self.ctx)
|
||||
|
||||
# At this point we will match the (assumed) long option through
|
||||
# the long option matching code. Note that this allows options
|
||||
# like "-foo" to be matched as long options.
|
||||
try:
|
||||
self._match_long_opt(norm_long_opt, explicit_value, state)
|
||||
except NoSuchOption:
|
||||
# At this point the long option matching failed, and we need
|
||||
# to try with short options. However there is a special rule
|
||||
# which says, that if we have a two character options prefix
|
||||
# (applies to "--foo" for instance), we do not dispatch to the
|
||||
# short option code and will instead raise the no option
|
||||
# error.
|
||||
if arg[:2] not in self._opt_prefixes:
|
||||
self._match_short_opt(arg, state)
|
||||
return
|
||||
|
||||
if not self.ignore_unknown_options:
|
||||
raise
|
||||
|
||||
state.largs.append(arg)
|
@ -0,0 +1,580 @@
|
||||
import os
|
||||
import re
|
||||
import typing as t
|
||||
from gettext import gettext as _
|
||||
|
||||
from .core import Argument
|
||||
from .core import BaseCommand
|
||||
from .core import Context
|
||||
from .core import MultiCommand
|
||||
from .core import Option
|
||||
from .core import Parameter
|
||||
from .core import ParameterSource
|
||||
from .parser import split_arg_string
|
||||
from .utils import echo
|
||||
|
||||
|
||||
def shell_complete(
|
||||
cli: BaseCommand,
|
||||
ctx_args: t.Dict[str, t.Any],
|
||||
prog_name: str,
|
||||
complete_var: str,
|
||||
instruction: str,
|
||||
) -> int:
|
||||
"""Perform shell completion for the given CLI program.
|
||||
|
||||
:param cli: Command being called.
|
||||
:param ctx_args: Extra arguments to pass to
|
||||
``cli.make_context``.
|
||||
:param prog_name: Name of the executable in the shell.
|
||||
:param complete_var: Name of the environment variable that holds
|
||||
the completion instruction.
|
||||
:param instruction: Value of ``complete_var`` with the completion
|
||||
instruction and shell, in the form ``instruction_shell``.
|
||||
:return: Status code to exit with.
|
||||
"""
|
||||
shell, _, instruction = instruction.partition("_")
|
||||
comp_cls = get_completion_class(shell)
|
||||
|
||||
if comp_cls is None:
|
||||
return 1
|
||||
|
||||
comp = comp_cls(cli, ctx_args, prog_name, complete_var)
|
||||
|
||||
if instruction == "source":
|
||||
echo(comp.source())
|
||||
return 0
|
||||
|
||||
if instruction == "complete":
|
||||
echo(comp.complete())
|
||||
return 0
|
||||
|
||||
return 1
|
||||
|
||||
|
||||
class CompletionItem:
|
||||
"""Represents a completion value and metadata about the value. The
|
||||
default metadata is ``type`` to indicate special shell handling,
|
||||
and ``help`` if a shell supports showing a help string next to the
|
||||
value.
|
||||
|
||||
Arbitrary parameters can be passed when creating the object, and
|
||||
accessed using ``item.attr``. If an attribute wasn't passed,
|
||||
accessing it returns ``None``.
|
||||
|
||||
:param value: The completion suggestion.
|
||||
:param type: Tells the shell script to provide special completion
|
||||
support for the type. Click uses ``"dir"`` and ``"file"``.
|
||||
:param help: String shown next to the value if supported.
|
||||
:param kwargs: Arbitrary metadata. The built-in implementations
|
||||
don't use this, but custom type completions paired with custom
|
||||
shell support could use it.
|
||||
"""
|
||||
|
||||
__slots__ = ("value", "type", "help", "_info")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
value: t.Any,
|
||||
type: str = "plain",
|
||||
help: t.Optional[str] = None,
|
||||
**kwargs: t.Any,
|
||||
) -> None:
|
||||
self.value = value
|
||||
self.type = type
|
||||
self.help = help
|
||||
self._info = kwargs
|
||||
|
||||
def __getattr__(self, name: str) -> t.Any:
|
||||
return self._info.get(name)
|
||||
|
||||
|
||||
# Only Bash >= 4.4 has the nosort option.
|
||||
_SOURCE_BASH = """\
|
||||
%(complete_func)s() {
|
||||
local IFS=$'\\n'
|
||||
local response
|
||||
|
||||
response=$(env COMP_WORDS="${COMP_WORDS[*]}" COMP_CWORD=$COMP_CWORD \
|
||||
%(complete_var)s=bash_complete $1)
|
||||
|
||||
for completion in $response; do
|
||||
IFS=',' read type value <<< "$completion"
|
||||
|
||||
if [[ $type == 'dir' ]]; then
|
||||
COMPREPLY=()
|
||||
compopt -o dirnames
|
||||
elif [[ $type == 'file' ]]; then
|
||||
COMPREPLY=()
|
||||
compopt -o default
|
||||
elif [[ $type == 'plain' ]]; then
|
||||
COMPREPLY+=($value)
|
||||
fi
|
||||
done
|
||||
|
||||
return 0
|
||||
}
|
||||
|
||||
%(complete_func)s_setup() {
|
||||
complete -o nosort -F %(complete_func)s %(prog_name)s
|
||||
}
|
||||
|
||||
%(complete_func)s_setup;
|
||||
"""
|
||||
|
||||
_SOURCE_ZSH = """\
|
||||
#compdef %(prog_name)s
|
||||
|
||||
%(complete_func)s() {
|
||||
local -a completions
|
||||
local -a completions_with_descriptions
|
||||
local -a response
|
||||
(( ! $+commands[%(prog_name)s] )) && return 1
|
||||
|
||||
response=("${(@f)$(env COMP_WORDS="${words[*]}" COMP_CWORD=$((CURRENT-1)) \
|
||||
%(complete_var)s=zsh_complete %(prog_name)s)}")
|
||||
|
||||
for type key descr in ${response}; do
|
||||
if [[ "$type" == "plain" ]]; then
|
||||
if [[ "$descr" == "_" ]]; then
|
||||
completions+=("$key")
|
||||
else
|
||||
completions_with_descriptions+=("$key":"$descr")
|
||||
fi
|
||||
elif [[ "$type" == "dir" ]]; then
|
||||
_path_files -/
|
||||
elif [[ "$type" == "file" ]]; then
|
||||
_path_files -f
|
||||
fi
|
||||
done
|
||||
|
||||
if [ -n "$completions_with_descriptions" ]; then
|
||||
_describe -V unsorted completions_with_descriptions -U
|
||||
fi
|
||||
|
||||
if [ -n "$completions" ]; then
|
||||
compadd -U -V unsorted -a completions
|
||||
fi
|
||||
}
|
||||
|
||||
compdef %(complete_func)s %(prog_name)s;
|
||||
"""
|
||||
|
||||
_SOURCE_FISH = """\
|
||||
function %(complete_func)s;
|
||||
set -l response;
|
||||
|
||||
for value in (env %(complete_var)s=fish_complete COMP_WORDS=(commandline -cp) \
|
||||
COMP_CWORD=(commandline -t) %(prog_name)s);
|
||||
set response $response $value;
|
||||
end;
|
||||
|
||||
for completion in $response;
|
||||
set -l metadata (string split "," $completion);
|
||||
|
||||
if test $metadata[1] = "dir";
|
||||
__fish_complete_directories $metadata[2];
|
||||
else if test $metadata[1] = "file";
|
||||
__fish_complete_path $metadata[2];
|
||||
else if test $metadata[1] = "plain";
|
||||
echo $metadata[2];
|
||||
end;
|
||||
end;
|
||||
end;
|
||||
|
||||
complete --no-files --command %(prog_name)s --arguments \
|
||||
"(%(complete_func)s)";
|
||||
"""
|
||||
|
||||
|
||||
class ShellComplete:
|
||||
"""Base class for providing shell completion support. A subclass for
|
||||
a given shell will override attributes and methods to implement the
|
||||
completion instructions (``source`` and ``complete``).
|
||||
|
||||
:param cli: Command being called.
|
||||
:param prog_name: Name of the executable in the shell.
|
||||
:param complete_var: Name of the environment variable that holds
|
||||
the completion instruction.
|
||||
|
||||
.. versionadded:: 8.0
|
||||
"""
|
||||
|
||||
name: t.ClassVar[str]
|
||||
"""Name to register the shell as with :func:`add_completion_class`.
|
||||
This is used in completion instructions (``{name}_source`` and
|
||||
``{name}_complete``).
|
||||
"""
|
||||
|
||||
source_template: t.ClassVar[str]
|
||||
"""Completion script template formatted by :meth:`source`. This must
|
||||
be provided by subclasses.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
cli: BaseCommand,
|
||||
ctx_args: t.Dict[str, t.Any],
|
||||
prog_name: str,
|
||||
complete_var: str,
|
||||
) -> None:
|
||||
self.cli = cli
|
||||
self.ctx_args = ctx_args
|
||||
self.prog_name = prog_name
|
||||
self.complete_var = complete_var
|
||||
|
||||
@property
|
||||
def func_name(self) -> str:
|
||||
"""The name of the shell function defined by the completion
|
||||
script.
|
||||
"""
|
||||
safe_name = re.sub(r"\W*", "", self.prog_name.replace("-", "_"), re.ASCII)
|
||||
return f"_{safe_name}_completion"
|
||||
|
||||
def source_vars(self) -> t.Dict[str, t.Any]:
|
||||
"""Vars for formatting :attr:`source_template`.
|
||||
|
||||
By default this provides ``complete_func``, ``complete_var``,
|
||||
and ``prog_name``.
|
||||
"""
|
||||
return {
|
||||
"complete_func": self.func_name,
|
||||
"complete_var": self.complete_var,
|
||||
"prog_name": self.prog_name,
|
||||
}
|
||||
|
||||
def source(self) -> str:
|
||||
"""Produce the shell script that defines the completion
|
||||
function. By default this ``%``-style formats
|
||||
:attr:`source_template` with the dict returned by
|
||||
:meth:`source_vars`.
|
||||
"""
|
||||
return self.source_template % self.source_vars()
|
||||
|
||||
def get_completion_args(self) -> t.Tuple[t.List[str], str]:
|
||||
"""Use the env vars defined by the shell script to return a
|
||||
tuple of ``args, incomplete``. This must be implemented by
|
||||
subclasses.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def get_completions(
|
||||
self, args: t.List[str], incomplete: str
|
||||
) -> t.List[CompletionItem]:
|
||||
"""Determine the context and last complete command or parameter
|
||||
from the complete args. Call that object's ``shell_complete``
|
||||
method to get the completions for the incomplete value.
|
||||
|
||||
:param args: List of complete args before the incomplete value.
|
||||
:param incomplete: Value being completed. May be empty.
|
||||
"""
|
||||
ctx = _resolve_context(self.cli, self.ctx_args, self.prog_name, args)
|
||||
obj, incomplete = _resolve_incomplete(ctx, args, incomplete)
|
||||
return obj.shell_complete(ctx, incomplete)
|
||||
|
||||
def format_completion(self, item: CompletionItem) -> str:
|
||||
"""Format a completion item into the form recognized by the
|
||||
shell script. This must be implemented by subclasses.
|
||||
|
||||
:param item: Completion item to format.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def complete(self) -> str:
|
||||
"""Produce the completion data to send back to the shell.
|
||||
|
||||
By default this calls :meth:`get_completion_args`, gets the
|
||||
completions, then calls :meth:`format_completion` for each
|
||||
completion.
|
||||
"""
|
||||
args, incomplete = self.get_completion_args()
|
||||
completions = self.get_completions(args, incomplete)
|
||||
out = [self.format_completion(item) for item in completions]
|
||||
return "\n".join(out)
|
||||
|
||||
|
||||
class BashComplete(ShellComplete):
|
||||
"""Shell completion for Bash."""
|
||||
|
||||
name = "bash"
|
||||
source_template = _SOURCE_BASH
|
||||
|
||||
def _check_version(self) -> None:
|
||||
import subprocess
|
||||
|
||||
output = subprocess.run(
|
||||
["bash", "-c", "echo ${BASH_VERSION}"], stdout=subprocess.PIPE
|
||||
)
|
||||
match = re.search(r"^(\d+)\.(\d+)\.\d+", output.stdout.decode())
|
||||
|
||||
if match is not None:
|
||||
major, minor = match.groups()
|
||||
|
||||
if major < "4" or major == "4" and minor < "4":
|
||||
raise RuntimeError(
|
||||
_(
|
||||
"Shell completion is not supported for Bash"
|
||||
" versions older than 4.4."
|
||||
)
|
||||
)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
_("Couldn't detect Bash version, shell completion is not supported.")
|
||||
)
|
||||
|
||||
def source(self) -> str:
|
||||
self._check_version()
|
||||
return super().source()
|
||||
|
||||
def get_completion_args(self) -> t.Tuple[t.List[str], str]:
|
||||
cwords = split_arg_string(os.environ["COMP_WORDS"])
|
||||
cword = int(os.environ["COMP_CWORD"])
|
||||
args = cwords[1:cword]
|
||||
|
||||
try:
|
||||
incomplete = cwords[cword]
|
||||
except IndexError:
|
||||
incomplete = ""
|
||||
|
||||
return args, incomplete
|
||||
|
||||
def format_completion(self, item: CompletionItem) -> str:
|
||||
return f"{item.type},{item.value}"
|
||||
|
||||
|
||||
class ZshComplete(ShellComplete):
|
||||
"""Shell completion for Zsh."""
|
||||
|
||||
name = "zsh"
|
||||
source_template = _SOURCE_ZSH
|
||||
|
||||
def get_completion_args(self) -> t.Tuple[t.List[str], str]:
|
||||
cwords = split_arg_string(os.environ["COMP_WORDS"])
|
||||
cword = int(os.environ["COMP_CWORD"])
|
||||
args = cwords[1:cword]
|
||||
|
||||
try:
|
||||
incomplete = cwords[cword]
|
||||
except IndexError:
|
||||
incomplete = ""
|
||||
|
||||
return args, incomplete
|
||||
|
||||
def format_completion(self, item: CompletionItem) -> str:
|
||||
return f"{item.type}\n{item.value}\n{item.help if item.help else '_'}"
|
||||
|
||||
|
||||
class FishComplete(ShellComplete):
|
||||
"""Shell completion for Fish."""
|
||||
|
||||
name = "fish"
|
||||
source_template = _SOURCE_FISH
|
||||
|
||||
def get_completion_args(self) -> t.Tuple[t.List[str], str]:
|
||||
cwords = split_arg_string(os.environ["COMP_WORDS"])
|
||||
incomplete = os.environ["COMP_CWORD"]
|
||||
args = cwords[1:]
|
||||
|
||||
# Fish stores the partial word in both COMP_WORDS and
|
||||
# COMP_CWORD, remove it from complete args.
|
||||
if incomplete and args and args[-1] == incomplete:
|
||||
args.pop()
|
||||
|
||||
return args, incomplete
|
||||
|
||||
def format_completion(self, item: CompletionItem) -> str:
|
||||
if item.help:
|
||||
return f"{item.type},{item.value}\t{item.help}"
|
||||
|
||||
return f"{item.type},{item.value}"
|
||||
|
||||
|
||||
_available_shells: t.Dict[str, t.Type[ShellComplete]] = {
|
||||
"bash": BashComplete,
|
||||
"fish": FishComplete,
|
||||
"zsh": ZshComplete,
|
||||
}
|
||||
|
||||
|
||||
def add_completion_class(
|
||||
cls: t.Type[ShellComplete], name: t.Optional[str] = None
|
||||
) -> None:
|
||||
"""Register a :class:`ShellComplete` subclass under the given name.
|
||||
The name will be provided by the completion instruction environment
|
||||
variable during completion.
|
||||
|
||||
:param cls: The completion class that will handle completion for the
|
||||
shell.
|
||||
:param name: Name to register the class under. Defaults to the
|
||||
class's ``name`` attribute.
|
||||
"""
|
||||
if name is None:
|
||||
name = cls.name
|
||||
|
||||
_available_shells[name] = cls
|
||||
|
||||
|
||||
def get_completion_class(shell: str) -> t.Optional[t.Type[ShellComplete]]:
|
||||
"""Look up a registered :class:`ShellComplete` subclass by the name
|
||||
provided by the completion instruction environment variable. If the
|
||||
name isn't registered, returns ``None``.
|
||||
|
||||
:param shell: Name the class is registered under.
|
||||
"""
|
||||
return _available_shells.get(shell)
|
||||
|
||||
|
||||
def _is_incomplete_argument(ctx: Context, param: Parameter) -> bool:
|
||||
"""Determine if the given parameter is an argument that can still
|
||||
accept values.
|
||||
|
||||
:param ctx: Invocation context for the command represented by the
|
||||
parsed complete args.
|
||||
:param param: Argument object being checked.
|
||||
"""
|
||||
if not isinstance(param, Argument):
|
||||
return False
|
||||
|
||||
assert param.name is not None
|
||||
value = ctx.params[param.name]
|
||||
return (
|
||||
param.nargs == -1
|
||||
or ctx.get_parameter_source(param.name) is not ParameterSource.COMMANDLINE
|
||||
or (
|
||||
param.nargs > 1
|
||||
and isinstance(value, (tuple, list))
|
||||
and len(value) < param.nargs
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def _start_of_option(ctx: Context, value: str) -> bool:
|
||||
"""Check if the value looks like the start of an option."""
|
||||
if not value:
|
||||
return False
|
||||
|
||||
c = value[0]
|
||||
return c in ctx._opt_prefixes
|
||||
|
||||
|
||||
def _is_incomplete_option(ctx: Context, args: t.List[str], param: Parameter) -> bool:
|
||||
"""Determine if the given parameter is an option that needs a value.
|
||||
|
||||
:param args: List of complete args before the incomplete value.
|
||||
:param param: Option object being checked.
|
||||
"""
|
||||
if not isinstance(param, Option):
|
||||
return False
|
||||
|
||||
if param.is_flag or param.count:
|
||||
return False
|
||||
|
||||
last_option = None
|
||||
|
||||
for index, arg in enumerate(reversed(args)):
|
||||
if index + 1 > param.nargs:
|
||||
break
|
||||
|
||||
if _start_of_option(ctx, arg):
|
||||
last_option = arg
|
||||
|
||||
return last_option is not None and last_option in param.opts
|
||||
|
||||
|
||||
def _resolve_context(
|
||||
cli: BaseCommand, ctx_args: t.Dict[str, t.Any], prog_name: str, args: t.List[str]
|
||||
) -> Context:
|
||||
"""Produce the context hierarchy starting with the command and
|
||||
traversing the complete arguments. This only follows the commands,
|
||||
it doesn't trigger input prompts or callbacks.
|
||||
|
||||
:param cli: Command being called.
|
||||
:param prog_name: Name of the executable in the shell.
|
||||
:param args: List of complete args before the incomplete value.
|
||||
"""
|
||||
ctx_args["resilient_parsing"] = True
|
||||
ctx = cli.make_context(prog_name, args.copy(), **ctx_args)
|
||||
args = ctx.protected_args + ctx.args
|
||||
|
||||
while args:
|
||||
command = ctx.command
|
||||
|
||||
if isinstance(command, MultiCommand):
|
||||
if not command.chain:
|
||||
name, cmd, args = command.resolve_command(ctx, args)
|
||||
|
||||
if cmd is None:
|
||||
return ctx
|
||||
|
||||
ctx = cmd.make_context(name, args, parent=ctx, resilient_parsing=True)
|
||||
args = ctx.protected_args + ctx.args
|
||||
else:
|
||||
while args:
|
||||
name, cmd, args = command.resolve_command(ctx, args)
|
||||
|
||||
if cmd is None:
|
||||
return ctx
|
||||
|
||||
sub_ctx = cmd.make_context(
|
||||
name,
|
||||
args,
|
||||
parent=ctx,
|
||||
allow_extra_args=True,
|
||||
allow_interspersed_args=False,
|
||||
resilient_parsing=True,
|
||||
)
|
||||
args = sub_ctx.args
|
||||
|
||||
ctx = sub_ctx
|
||||
args = [*sub_ctx.protected_args, *sub_ctx.args]
|
||||
else:
|
||||
break
|
||||
|
||||
return ctx
|
||||
|
||||
|
||||
def _resolve_incomplete(
|
||||
ctx: Context, args: t.List[str], incomplete: str
|
||||
) -> t.Tuple[t.Union[BaseCommand, Parameter], str]:
|
||||
"""Find the Click object that will handle the completion of the
|
||||
incomplete value. Return the object and the incomplete value.
|
||||
|
||||
:param ctx: Invocation context for the command represented by
|
||||
the parsed complete args.
|
||||
:param args: List of complete args before the incomplete value.
|
||||
:param incomplete: Value being completed. May be empty.
|
||||
"""
|
||||
# Different shells treat an "=" between a long option name and
|
||||
# value differently. Might keep the value joined, return the "="
|
||||
# as a separate item, or return the split name and value. Always
|
||||
# split and discard the "=" to make completion easier.
|
||||
if incomplete == "=":
|
||||
incomplete = ""
|
||||
elif "=" in incomplete and _start_of_option(ctx, incomplete):
|
||||
name, _, incomplete = incomplete.partition("=")
|
||||
args.append(name)
|
||||
|
||||
# The "--" marker tells Click to stop treating values as options
|
||||
# even if they start with the option character. If it hasn't been
|
||||
# given and the incomplete arg looks like an option, the current
|
||||
# command will provide option name completions.
|
||||
if "--" not in args and _start_of_option(ctx, incomplete):
|
||||
return ctx.command, incomplete
|
||||
|
||||
params = ctx.command.get_params(ctx)
|
||||
|
||||
# If the last complete arg is an option name with an incomplete
|
||||
# value, the option will provide value completions.
|
||||
for param in params:
|
||||
if _is_incomplete_option(ctx, args, param):
|
||||
return param, incomplete
|
||||
|
||||
# It's not an option name or value. The first argument without a
|
||||
# parsed value will provide value completions.
|
||||
for param in params:
|
||||
if _is_incomplete_argument(ctx, param):
|
||||
return param, incomplete
|
||||
|
||||
# There were no unparsed arguments, the command may be a group that
|
||||
# will provide command name completions.
|
||||
return ctx.command, incomplete
|
@ -0,0 +1,787 @@
|
||||
import inspect
|
||||
import io
|
||||
import itertools
|
||||
import os
|
||||
import sys
|
||||
import typing as t
|
||||
from gettext import gettext as _
|
||||
|
||||
from ._compat import isatty
|
||||
from ._compat import strip_ansi
|
||||
from ._compat import WIN
|
||||
from .exceptions import Abort
|
||||
from .exceptions import UsageError
|
||||
from .globals import resolve_color_default
|
||||
from .types import Choice
|
||||
from .types import convert_type
|
||||
from .types import ParamType
|
||||
from .utils import echo
|
||||
from .utils import LazyFile
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from ._termui_impl import ProgressBar
|
||||
|
||||
V = t.TypeVar("V")
|
||||
|
||||
# The prompt functions to use. The doc tools currently override these
|
||||
# functions to customize how they work.
|
||||
visible_prompt_func: t.Callable[[str], str] = input
|
||||
|
||||
_ansi_colors = {
|
||||
"black": 30,
|
||||
"red": 31,
|
||||
"green": 32,
|
||||
"yellow": 33,
|
||||
"blue": 34,
|
||||
"magenta": 35,
|
||||
"cyan": 36,
|
||||
"white": 37,
|
||||
"reset": 39,
|
||||
"bright_black": 90,
|
||||
"bright_red": 91,
|
||||
"bright_green": 92,
|
||||
"bright_yellow": 93,
|
||||
"bright_blue": 94,
|
||||
"bright_magenta": 95,
|
||||
"bright_cyan": 96,
|
||||
"bright_white": 97,
|
||||
}
|
||||
_ansi_reset_all = "\033[0m"
|
||||
|
||||
|
||||
def hidden_prompt_func(prompt: str) -> str:
|
||||
import getpass
|
||||
|
||||
return getpass.getpass(prompt)
|
||||
|
||||
|
||||
def _build_prompt(
|
||||
text: str,
|
||||
suffix: str,
|
||||
show_default: bool = False,
|
||||
default: t.Optional[t.Any] = None,
|
||||
show_choices: bool = True,
|
||||
type: t.Optional[ParamType] = None,
|
||||
) -> str:
|
||||
prompt = text
|
||||
if type is not None and show_choices and isinstance(type, Choice):
|
||||
prompt += f" ({', '.join(map(str, type.choices))})"
|
||||
if default is not None and show_default:
|
||||
prompt = f"{prompt} [{_format_default(default)}]"
|
||||
return f"{prompt}{suffix}"
|
||||
|
||||
|
||||
def _format_default(default: t.Any) -> t.Any:
|
||||
if isinstance(default, (io.IOBase, LazyFile)) and hasattr(default, "name"):
|
||||
return default.name # type: ignore
|
||||
|
||||
return default
|
||||
|
||||
|
||||
def prompt(
|
||||
text: str,
|
||||
default: t.Optional[t.Any] = None,
|
||||
hide_input: bool = False,
|
||||
confirmation_prompt: t.Union[bool, str] = False,
|
||||
type: t.Optional[t.Union[ParamType, t.Any]] = None,
|
||||
value_proc: t.Optional[t.Callable[[str], t.Any]] = None,
|
||||
prompt_suffix: str = ": ",
|
||||
show_default: bool = True,
|
||||
err: bool = False,
|
||||
show_choices: bool = True,
|
||||
) -> t.Any:
|
||||
"""Prompts a user for input. This is a convenience function that can
|
||||
be used to prompt a user for input later.
|
||||
|
||||
If the user aborts the input by sending an interrupt signal, this
|
||||
function will catch it and raise a :exc:`Abort` exception.
|
||||
|
||||
:param text: the text to show for the prompt.
|
||||
:param default: the default value to use if no input happens. If this
|
||||
is not given it will prompt until it's aborted.
|
||||
:param hide_input: if this is set to true then the input value will
|
||||
be hidden.
|
||||
:param confirmation_prompt: Prompt a second time to confirm the
|
||||
value. Can be set to a string instead of ``True`` to customize
|
||||
the message.
|
||||
:param type: the type to use to check the value against.
|
||||
:param value_proc: if this parameter is provided it's a function that
|
||||
is invoked instead of the type conversion to
|
||||
convert a value.
|
||||
:param prompt_suffix: a suffix that should be added to the prompt.
|
||||
:param show_default: shows or hides the default value in the prompt.
|
||||
:param err: if set to true the file defaults to ``stderr`` instead of
|
||||
``stdout``, the same as with echo.
|
||||
:param show_choices: Show or hide choices if the passed type is a Choice.
|
||||
For example if type is a Choice of either day or week,
|
||||
show_choices is true and text is "Group by" then the
|
||||
prompt will be "Group by (day, week): ".
|
||||
|
||||
.. versionadded:: 8.0
|
||||
``confirmation_prompt`` can be a custom string.
|
||||
|
||||
.. versionadded:: 7.0
|
||||
Added the ``show_choices`` parameter.
|
||||
|
||||
.. versionadded:: 6.0
|
||||
Added unicode support for cmd.exe on Windows.
|
||||
|
||||
.. versionadded:: 4.0
|
||||
Added the `err` parameter.
|
||||
|
||||
"""
|
||||
|
||||
def prompt_func(text: str) -> str:
|
||||
f = hidden_prompt_func if hide_input else visible_prompt_func
|
||||
try:
|
||||
# Write the prompt separately so that we get nice
|
||||
# coloring through colorama on Windows
|
||||
echo(text.rstrip(" "), nl=False, err=err)
|
||||
# Echo a space to stdout to work around an issue where
|
||||
# readline causes backspace to clear the whole line.
|
||||
return f(" ")
|
||||
except (KeyboardInterrupt, EOFError):
|
||||
# getpass doesn't print a newline if the user aborts input with ^C.
|
||||
# Allegedly this behavior is inherited from getpass(3).
|
||||
# A doc bug has been filed at https://bugs.python.org/issue24711
|
||||
if hide_input:
|
||||
echo(None, err=err)
|
||||
raise Abort() from None
|
||||
|
||||
if value_proc is None:
|
||||
value_proc = convert_type(type, default)
|
||||
|
||||
prompt = _build_prompt(
|
||||
text, prompt_suffix, show_default, default, show_choices, type
|
||||
)
|
||||
|
||||
if confirmation_prompt:
|
||||
if confirmation_prompt is True:
|
||||
confirmation_prompt = _("Repeat for confirmation")
|
||||
|
||||
confirmation_prompt = _build_prompt(confirmation_prompt, prompt_suffix)
|
||||
|
||||
while True:
|
||||
while True:
|
||||
value = prompt_func(prompt)
|
||||
if value:
|
||||
break
|
||||
elif default is not None:
|
||||
value = default
|
||||
break
|
||||
try:
|
||||
result = value_proc(value)
|
||||
except UsageError as e:
|
||||
if hide_input:
|
||||
echo(_("Error: The value you entered was invalid."), err=err)
|
||||
else:
|
||||
echo(_("Error: {e.message}").format(e=e), err=err) # noqa: B306
|
||||
continue
|
||||
if not confirmation_prompt:
|
||||
return result
|
||||
while True:
|
||||
value2 = prompt_func(confirmation_prompt)
|
||||
is_empty = not value and not value2
|
||||
if value2 or is_empty:
|
||||
break
|
||||
if value == value2:
|
||||
return result
|
||||
echo(_("Error: The two entered values do not match."), err=err)
|
||||
|
||||
|
||||
def confirm(
|
||||
text: str,
|
||||
default: t.Optional[bool] = False,
|
||||
abort: bool = False,
|
||||
prompt_suffix: str = ": ",
|
||||
show_default: bool = True,
|
||||
err: bool = False,
|
||||
) -> bool:
|
||||
"""Prompts for confirmation (yes/no question).
|
||||
|
||||
If the user aborts the input by sending a interrupt signal this
|
||||
function will catch it and raise a :exc:`Abort` exception.
|
||||
|
||||
:param text: the question to ask.
|
||||
:param default: The default value to use when no input is given. If
|
||||
``None``, repeat until input is given.
|
||||
:param abort: if this is set to `True` a negative answer aborts the
|
||||
exception by raising :exc:`Abort`.
|
||||
:param prompt_suffix: a suffix that should be added to the prompt.
|
||||
:param show_default: shows or hides the default value in the prompt.
|
||||
:param err: if set to true the file defaults to ``stderr`` instead of
|
||||
``stdout``, the same as with echo.
|
||||
|
||||
.. versionchanged:: 8.0
|
||||
Repeat until input is given if ``default`` is ``None``.
|
||||
|
||||
.. versionadded:: 4.0
|
||||
Added the ``err`` parameter.
|
||||
"""
|
||||
prompt = _build_prompt(
|
||||
text,
|
||||
prompt_suffix,
|
||||
show_default,
|
||||
"y/n" if default is None else ("Y/n" if default else "y/N"),
|
||||
)
|
||||
|
||||
while True:
|
||||
try:
|
||||
# Write the prompt separately so that we get nice
|
||||
# coloring through colorama on Windows
|
||||
echo(prompt.rstrip(" "), nl=False, err=err)
|
||||
# Echo a space to stdout to work around an issue where
|
||||
# readline causes backspace to clear the whole line.
|
||||
value = visible_prompt_func(" ").lower().strip()
|
||||
except (KeyboardInterrupt, EOFError):
|
||||
raise Abort() from None
|
||||
if value in ("y", "yes"):
|
||||
rv = True
|
||||
elif value in ("n", "no"):
|
||||
rv = False
|
||||
elif default is not None and value == "":
|
||||
rv = default
|
||||
else:
|
||||
echo(_("Error: invalid input"), err=err)
|
||||
continue
|
||||
break
|
||||
if abort and not rv:
|
||||
raise Abort()
|
||||
return rv
|
||||
|
||||
|
||||
def echo_via_pager(
|
||||
text_or_generator: t.Union[t.Iterable[str], t.Callable[[], t.Iterable[str]], str],
|
||||
color: t.Optional[bool] = None,
|
||||
) -> None:
|
||||
"""This function takes a text and shows it via an environment specific
|
||||
pager on stdout.
|
||||
|
||||
.. versionchanged:: 3.0
|
||||
Added the `color` flag.
|
||||
|
||||
:param text_or_generator: the text to page, or alternatively, a
|
||||
generator emitting the text to page.
|
||||
:param color: controls if the pager supports ANSI colors or not. The
|
||||
default is autodetection.
|
||||
"""
|
||||
color = resolve_color_default(color)
|
||||
|
||||
if inspect.isgeneratorfunction(text_or_generator):
|
||||
i = t.cast(t.Callable[[], t.Iterable[str]], text_or_generator)()
|
||||
elif isinstance(text_or_generator, str):
|
||||
i = [text_or_generator]
|
||||
else:
|
||||
i = iter(t.cast(t.Iterable[str], text_or_generator))
|
||||
|
||||
# convert every element of i to a text type if necessary
|
||||
text_generator = (el if isinstance(el, str) else str(el) for el in i)
|
||||
|
||||
from ._termui_impl import pager
|
||||
|
||||
return pager(itertools.chain(text_generator, "\n"), color)
|
||||
|
||||
|
||||
def progressbar(
|
||||
iterable: t.Optional[t.Iterable[V]] = None,
|
||||
length: t.Optional[int] = None,
|
||||
label: t.Optional[str] = None,
|
||||
show_eta: bool = True,
|
||||
show_percent: t.Optional[bool] = None,
|
||||
show_pos: bool = False,
|
||||
item_show_func: t.Optional[t.Callable[[t.Optional[V]], t.Optional[str]]] = None,
|
||||
fill_char: str = "#",
|
||||
empty_char: str = "-",
|
||||
bar_template: str = "%(label)s [%(bar)s] %(info)s",
|
||||
info_sep: str = " ",
|
||||
width: int = 36,
|
||||
file: t.Optional[t.TextIO] = None,
|
||||
color: t.Optional[bool] = None,
|
||||
update_min_steps: int = 1,
|
||||
) -> "ProgressBar[V]":
|
||||
"""This function creates an iterable context manager that can be used
|
||||
to iterate over something while showing a progress bar. It will
|
||||
either iterate over the `iterable` or `length` items (that are counted
|
||||
up). While iteration happens, this function will print a rendered
|
||||
progress bar to the given `file` (defaults to stdout) and will attempt
|
||||
to calculate remaining time and more. By default, this progress bar
|
||||
will not be rendered if the file is not a terminal.
|
||||
|
||||
The context manager creates the progress bar. When the context
|
||||
manager is entered the progress bar is already created. With every
|
||||
iteration over the progress bar, the iterable passed to the bar is
|
||||
advanced and the bar is updated. When the context manager exits,
|
||||
a newline is printed and the progress bar is finalized on screen.
|
||||
|
||||
Note: The progress bar is currently designed for use cases where the
|
||||
total progress can be expected to take at least several seconds.
|
||||
Because of this, the ProgressBar class object won't display
|
||||
progress that is considered too fast, and progress where the time
|
||||
between steps is less than a second.
|
||||
|
||||
No printing must happen or the progress bar will be unintentionally
|
||||
destroyed.
|
||||
|
||||
Example usage::
|
||||
|
||||
with progressbar(items) as bar:
|
||||
for item in bar:
|
||||
do_something_with(item)
|
||||
|
||||
Alternatively, if no iterable is specified, one can manually update the
|
||||
progress bar through the `update()` method instead of directly
|
||||
iterating over the progress bar. The update method accepts the number
|
||||
of steps to increment the bar with::
|
||||
|
||||
with progressbar(length=chunks.total_bytes) as bar:
|
||||
for chunk in chunks:
|
||||
process_chunk(chunk)
|
||||
bar.update(chunks.bytes)
|
||||
|
||||
The ``update()`` method also takes an optional value specifying the
|
||||
``current_item`` at the new position. This is useful when used
|
||||
together with ``item_show_func`` to customize the output for each
|
||||
manual step::
|
||||
|
||||
with click.progressbar(
|
||||
length=total_size,
|
||||
label='Unzipping archive',
|
||||
item_show_func=lambda a: a.filename
|
||||
) as bar:
|
||||
for archive in zip_file:
|
||||
archive.extract()
|
||||
bar.update(archive.size, archive)
|
||||
|
||||
:param iterable: an iterable to iterate over. If not provided the length
|
||||
is required.
|
||||
:param length: the number of items to iterate over. By default the
|
||||
progressbar will attempt to ask the iterator about its
|
||||
length, which might or might not work. If an iterable is
|
||||
also provided this parameter can be used to override the
|
||||
length. If an iterable is not provided the progress bar
|
||||
will iterate over a range of that length.
|
||||
:param label: the label to show next to the progress bar.
|
||||
:param show_eta: enables or disables the estimated time display. This is
|
||||
automatically disabled if the length cannot be
|
||||
determined.
|
||||
:param show_percent: enables or disables the percentage display. The
|
||||
default is `True` if the iterable has a length or
|
||||
`False` if not.
|
||||
:param show_pos: enables or disables the absolute position display. The
|
||||
default is `False`.
|
||||
:param item_show_func: A function called with the current item which
|
||||
can return a string to show next to the progress bar. If the
|
||||
function returns ``None`` nothing is shown. The current item can
|
||||
be ``None``, such as when entering and exiting the bar.
|
||||
:param fill_char: the character to use to show the filled part of the
|
||||
progress bar.
|
||||
:param empty_char: the character to use to show the non-filled part of
|
||||
the progress bar.
|
||||
:param bar_template: the format string to use as template for the bar.
|
||||
The parameters in it are ``label`` for the label,
|
||||
``bar`` for the progress bar and ``info`` for the
|
||||
info section.
|
||||
:param info_sep: the separator between multiple info items (eta etc.)
|
||||
:param width: the width of the progress bar in characters, 0 means full
|
||||
terminal width
|
||||
:param file: The file to write to. If this is not a terminal then
|
||||
only the label is printed.
|
||||
:param color: controls if the terminal supports ANSI colors or not. The
|
||||
default is autodetection. This is only needed if ANSI
|
||||
codes are included anywhere in the progress bar output
|
||||
which is not the case by default.
|
||||
:param update_min_steps: Render only when this many updates have
|
||||
completed. This allows tuning for very fast iterators.
|
||||
|
||||
.. versionchanged:: 8.0
|
||||
Output is shown even if execution time is less than 0.5 seconds.
|
||||
|
||||
.. versionchanged:: 8.0
|
||||
``item_show_func`` shows the current item, not the previous one.
|
||||
|
||||
.. versionchanged:: 8.0
|
||||
Labels are echoed if the output is not a TTY. Reverts a change
|
||||
in 7.0 that removed all output.
|
||||
|
||||
.. versionadded:: 8.0
|
||||
Added the ``update_min_steps`` parameter.
|
||||
|
||||
.. versionchanged:: 4.0
|
||||
Added the ``color`` parameter. Added the ``update`` method to
|
||||
the object.
|
||||
|
||||
.. versionadded:: 2.0
|
||||
"""
|
||||
from ._termui_impl import ProgressBar
|
||||
|
||||
color = resolve_color_default(color)
|
||||
return ProgressBar(
|
||||
iterable=iterable,
|
||||
length=length,
|
||||
show_eta=show_eta,
|
||||
show_percent=show_percent,
|
||||
show_pos=show_pos,
|
||||
item_show_func=item_show_func,
|
||||
fill_char=fill_char,
|
||||
empty_char=empty_char,
|
||||
bar_template=bar_template,
|
||||
info_sep=info_sep,
|
||||
file=file,
|
||||
label=label,
|
||||
width=width,
|
||||
color=color,
|
||||
update_min_steps=update_min_steps,
|
||||
)
|
||||
|
||||
|
||||
def clear() -> None:
|
||||
"""Clears the terminal screen. This will have the effect of clearing
|
||||
the whole visible space of the terminal and moving the cursor to the
|
||||
top left. This does not do anything if not connected to a terminal.
|
||||
|
||||
.. versionadded:: 2.0
|
||||
"""
|
||||
if not isatty(sys.stdout):
|
||||
return
|
||||
if WIN:
|
||||
os.system("cls")
|
||||
else:
|
||||
sys.stdout.write("\033[2J\033[1;1H")
|
||||
|
||||
|
||||
def _interpret_color(
|
||||
color: t.Union[int, t.Tuple[int, int, int], str], offset: int = 0
|
||||
) -> str:
|
||||
if isinstance(color, int):
|
||||
return f"{38 + offset};5;{color:d}"
|
||||
|
||||
if isinstance(color, (tuple, list)):
|
||||
r, g, b = color
|
||||
return f"{38 + offset};2;{r:d};{g:d};{b:d}"
|
||||
|
||||
return str(_ansi_colors[color] + offset)
|
||||
|
||||
|
||||
def style(
|
||||
text: t.Any,
|
||||
fg: t.Optional[t.Union[int, t.Tuple[int, int, int], str]] = None,
|
||||
bg: t.Optional[t.Union[int, t.Tuple[int, int, int], str]] = None,
|
||||
bold: t.Optional[bool] = None,
|
||||
dim: t.Optional[bool] = None,
|
||||
underline: t.Optional[bool] = None,
|
||||
overline: t.Optional[bool] = None,
|
||||
italic: t.Optional[bool] = None,
|
||||
blink: t.Optional[bool] = None,
|
||||
reverse: t.Optional[bool] = None,
|
||||
strikethrough: t.Optional[bool] = None,
|
||||
reset: bool = True,
|
||||
) -> str:
|
||||
"""Styles a text with ANSI styles and returns the new string. By
|
||||
default the styling is self contained which means that at the end
|
||||
of the string a reset code is issued. This can be prevented by
|
||||
passing ``reset=False``.
|
||||
|
||||
Examples::
|
||||
|
||||
click.echo(click.style('Hello World!', fg='green'))
|
||||
click.echo(click.style('ATTENTION!', blink=True))
|
||||
click.echo(click.style('Some things', reverse=True, fg='cyan'))
|
||||
click.echo(click.style('More colors', fg=(255, 12, 128), bg=117))
|
||||
|
||||
Supported color names:
|
||||
|
||||
* ``black`` (might be a gray)
|
||||
* ``red``
|
||||
* ``green``
|
||||
* ``yellow`` (might be an orange)
|
||||
* ``blue``
|
||||
* ``magenta``
|
||||
* ``cyan``
|
||||
* ``white`` (might be light gray)
|
||||
* ``bright_black``
|
||||
* ``bright_red``
|
||||
* ``bright_green``
|
||||
* ``bright_yellow``
|
||||
* ``bright_blue``
|
||||
* ``bright_magenta``
|
||||
* ``bright_cyan``
|
||||
* ``bright_white``
|
||||
* ``reset`` (reset the color code only)
|
||||
|
||||
If the terminal supports it, color may also be specified as:
|
||||
|
||||
- An integer in the interval [0, 255]. The terminal must support
|
||||
8-bit/256-color mode.
|
||||
- An RGB tuple of three integers in [0, 255]. The terminal must
|
||||
support 24-bit/true-color mode.
|
||||
|
||||
See https://en.wikipedia.org/wiki/ANSI_color and
|
||||
https://gist.github.com/XVilka/8346728 for more information.
|
||||
|
||||
:param text: the string to style with ansi codes.
|
||||
:param fg: if provided this will become the foreground color.
|
||||
:param bg: if provided this will become the background color.
|
||||
:param bold: if provided this will enable or disable bold mode.
|
||||
:param dim: if provided this will enable or disable dim mode. This is
|
||||
badly supported.
|
||||
:param underline: if provided this will enable or disable underline.
|
||||
:param overline: if provided this will enable or disable overline.
|
||||
:param italic: if provided this will enable or disable italic.
|
||||
:param blink: if provided this will enable or disable blinking.
|
||||
:param reverse: if provided this will enable or disable inverse
|
||||
rendering (foreground becomes background and the
|
||||
other way round).
|
||||
:param strikethrough: if provided this will enable or disable
|
||||
striking through text.
|
||||
:param reset: by default a reset-all code is added at the end of the
|
||||
string which means that styles do not carry over. This
|
||||
can be disabled to compose styles.
|
||||
|
||||
.. versionchanged:: 8.0
|
||||
A non-string ``message`` is converted to a string.
|
||||
|
||||
.. versionchanged:: 8.0
|
||||
Added support for 256 and RGB color codes.
|
||||
|
||||
.. versionchanged:: 8.0
|
||||
Added the ``strikethrough``, ``italic``, and ``overline``
|
||||
parameters.
|
||||
|
||||
.. versionchanged:: 7.0
|
||||
Added support for bright colors.
|
||||
|
||||
.. versionadded:: 2.0
|
||||
"""
|
||||
if not isinstance(text, str):
|
||||
text = str(text)
|
||||
|
||||
bits = []
|
||||
|
||||
if fg:
|
||||
try:
|
||||
bits.append(f"\033[{_interpret_color(fg)}m")
|
||||
except KeyError:
|
||||
raise TypeError(f"Unknown color {fg!r}") from None
|
||||
|
||||
if bg:
|
||||
try:
|
||||
bits.append(f"\033[{_interpret_color(bg, 10)}m")
|
||||
except KeyError:
|
||||
raise TypeError(f"Unknown color {bg!r}") from None
|
||||
|
||||
if bold is not None:
|
||||
bits.append(f"\033[{1 if bold else 22}m")
|
||||
if dim is not None:
|
||||
bits.append(f"\033[{2 if dim else 22}m")
|
||||
if underline is not None:
|
||||
bits.append(f"\033[{4 if underline else 24}m")
|
||||
if overline is not None:
|
||||
bits.append(f"\033[{53 if overline else 55}m")
|
||||
if italic is not None:
|
||||
bits.append(f"\033[{3 if italic else 23}m")
|
||||
if blink is not None:
|
||||
bits.append(f"\033[{5 if blink else 25}m")
|
||||
if reverse is not None:
|
||||
bits.append(f"\033[{7 if reverse else 27}m")
|
||||
if strikethrough is not None:
|
||||
bits.append(f"\033[{9 if strikethrough else 29}m")
|
||||
bits.append(text)
|
||||
if reset:
|
||||
bits.append(_ansi_reset_all)
|
||||
return "".join(bits)
|
||||
|
||||
|
||||
def unstyle(text: str) -> str:
|
||||
"""Removes ANSI styling information from a string. Usually it's not
|
||||
necessary to use this function as Click's echo function will
|
||||
automatically remove styling if necessary.
|
||||
|
||||
.. versionadded:: 2.0
|
||||
|
||||
:param text: the text to remove style information from.
|
||||
"""
|
||||
return strip_ansi(text)
|
||||
|
||||
|
||||
def secho(
|
||||
message: t.Optional[t.Any] = None,
|
||||
file: t.Optional[t.IO[t.AnyStr]] = None,
|
||||
nl: bool = True,
|
||||
err: bool = False,
|
||||
color: t.Optional[bool] = None,
|
||||
**styles: t.Any,
|
||||
) -> None:
|
||||
"""This function combines :func:`echo` and :func:`style` into one
|
||||
call. As such the following two calls are the same::
|
||||
|
||||
click.secho('Hello World!', fg='green')
|
||||
click.echo(click.style('Hello World!', fg='green'))
|
||||
|
||||
All keyword arguments are forwarded to the underlying functions
|
||||
depending on which one they go with.
|
||||
|
||||
Non-string types will be converted to :class:`str`. However,
|
||||
:class:`bytes` are passed directly to :meth:`echo` without applying
|
||||
style. If you want to style bytes that represent text, call
|
||||
:meth:`bytes.decode` first.
|
||||
|
||||
.. versionchanged:: 8.0
|
||||
A non-string ``message`` is converted to a string. Bytes are
|
||||
passed through without style applied.
|
||||
|
||||
.. versionadded:: 2.0
|
||||
"""
|
||||
if message is not None and not isinstance(message, (bytes, bytearray)):
|
||||
message = style(message, **styles)
|
||||
|
||||
return echo(message, file=file, nl=nl, err=err, color=color)
|
||||
|
||||
|
||||
def edit(
|
||||
text: t.Optional[t.AnyStr] = None,
|
||||
editor: t.Optional[str] = None,
|
||||
env: t.Optional[t.Mapping[str, str]] = None,
|
||||
require_save: bool = True,
|
||||
extension: str = ".txt",
|
||||
filename: t.Optional[str] = None,
|
||||
) -> t.Optional[t.AnyStr]:
|
||||
r"""Edits the given text in the defined editor. If an editor is given
|
||||
(should be the full path to the executable but the regular operating
|
||||
system search path is used for finding the executable) it overrides
|
||||
the detected editor. Optionally, some environment variables can be
|
||||
used. If the editor is closed without changes, `None` is returned. In
|
||||
case a file is edited directly the return value is always `None` and
|
||||
`require_save` and `extension` are ignored.
|
||||
|
||||
If the editor cannot be opened a :exc:`UsageError` is raised.
|
||||
|
||||
Note for Windows: to simplify cross-platform usage, the newlines are
|
||||
automatically converted from POSIX to Windows and vice versa. As such,
|
||||
the message here will have ``\n`` as newline markers.
|
||||
|
||||
:param text: the text to edit.
|
||||
:param editor: optionally the editor to use. Defaults to automatic
|
||||
detection.
|
||||
:param env: environment variables to forward to the editor.
|
||||
:param require_save: if this is true, then not saving in the editor
|
||||
will make the return value become `None`.
|
||||
:param extension: the extension to tell the editor about. This defaults
|
||||
to `.txt` but changing this might change syntax
|
||||
highlighting.
|
||||
:param filename: if provided it will edit this file instead of the
|
||||
provided text contents. It will not use a temporary
|
||||
file as an indirection in that case.
|
||||
"""
|
||||
from ._termui_impl import Editor
|
||||
|
||||
ed = Editor(editor=editor, env=env, require_save=require_save, extension=extension)
|
||||
|
||||
if filename is None:
|
||||
return ed.edit(text)
|
||||
|
||||
ed.edit_file(filename)
|
||||
return None
|
||||
|
||||
|
||||
def launch(url: str, wait: bool = False, locate: bool = False) -> int:
|
||||
"""This function launches the given URL (or filename) in the default
|
||||
viewer application for this file type. If this is an executable, it
|
||||
might launch the executable in a new session. The return value is
|
||||
the exit code of the launched application. Usually, ``0`` indicates
|
||||
success.
|
||||
|
||||
Examples::
|
||||
|
||||
click.launch('https://click.palletsprojects.com/')
|
||||
click.launch('/my/downloaded/file', locate=True)
|
||||
|
||||
.. versionadded:: 2.0
|
||||
|
||||
:param url: URL or filename of the thing to launch.
|
||||
:param wait: Wait for the program to exit before returning. This
|
||||
only works if the launched program blocks. In particular,
|
||||
``xdg-open`` on Linux does not block.
|
||||
:param locate: if this is set to `True` then instead of launching the
|
||||
application associated with the URL it will attempt to
|
||||
launch a file manager with the file located. This
|
||||
might have weird effects if the URL does not point to
|
||||
the filesystem.
|
||||
"""
|
||||
from ._termui_impl import open_url
|
||||
|
||||
return open_url(url, wait=wait, locate=locate)
|
||||
|
||||
|
||||
# If this is provided, getchar() calls into this instead. This is used
|
||||
# for unittesting purposes.
|
||||
_getchar: t.Optional[t.Callable[[bool], str]] = None
|
||||
|
||||
|
||||
def getchar(echo: bool = False) -> str:
|
||||
"""Fetches a single character from the terminal and returns it. This
|
||||
will always return a unicode character and under certain rare
|
||||
circumstances this might return more than one character. The
|
||||
situations which more than one character is returned is when for
|
||||
whatever reason multiple characters end up in the terminal buffer or
|
||||
standard input was not actually a terminal.
|
||||
|
||||
Note that this will always read from the terminal, even if something
|
||||
is piped into the standard input.
|
||||
|
||||
Note for Windows: in rare cases when typing non-ASCII characters, this
|
||||
function might wait for a second character and then return both at once.
|
||||
This is because certain Unicode characters look like special-key markers.
|
||||
|
||||
.. versionadded:: 2.0
|
||||
|
||||
:param echo: if set to `True`, the character read will also show up on
|
||||
the terminal. The default is to not show it.
|
||||
"""
|
||||
global _getchar
|
||||
|
||||
if _getchar is None:
|
||||
from ._termui_impl import getchar as f
|
||||
|
||||
_getchar = f
|
||||
|
||||
return _getchar(echo)
|
||||
|
||||
|
||||
def raw_terminal() -> t.ContextManager[int]:
|
||||
from ._termui_impl import raw_terminal as f
|
||||
|
||||
return f()
|
||||
|
||||
|
||||
def pause(info: t.Optional[str] = None, err: bool = False) -> None:
|
||||
"""This command stops execution and waits for the user to press any
|
||||
key to continue. This is similar to the Windows batch "pause"
|
||||
command. If the program is not run through a terminal, this command
|
||||
will instead do nothing.
|
||||
|
||||
.. versionadded:: 2.0
|
||||
|
||||
.. versionadded:: 4.0
|
||||
Added the `err` parameter.
|
||||
|
||||
:param info: The message to print before pausing. Defaults to
|
||||
``"Press any key to continue..."``.
|
||||
:param err: if set to message goes to ``stderr`` instead of
|
||||
``stdout``, the same as with echo.
|
||||
"""
|
||||
if not isatty(sys.stdin) or not isatty(sys.stdout):
|
||||
return
|
||||
|
||||
if info is None:
|
||||
info = _("Press any key to continue...")
|
||||
|
||||
try:
|
||||
if info:
|
||||
echo(info, nl=False, err=err)
|
||||
try:
|
||||
getchar()
|
||||
except (KeyboardInterrupt, EOFError):
|
||||
pass
|
||||
finally:
|
||||
if info:
|
||||
echo(err=err)
|
@ -0,0 +1,479 @@
|
||||
import contextlib
|
||||
import io
|
||||
import os
|
||||
import shlex
|
||||
import shutil
|
||||
import sys
|
||||
import tempfile
|
||||
import typing as t
|
||||
from types import TracebackType
|
||||
|
||||
from . import formatting
|
||||
from . import termui
|
||||
from . import utils
|
||||
from ._compat import _find_binary_reader
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from .core import BaseCommand
|
||||
|
||||
|
||||
class EchoingStdin:
|
||||
def __init__(self, input: t.BinaryIO, output: t.BinaryIO) -> None:
|
||||
self._input = input
|
||||
self._output = output
|
||||
self._paused = False
|
||||
|
||||
def __getattr__(self, x: str) -> t.Any:
|
||||
return getattr(self._input, x)
|
||||
|
||||
def _echo(self, rv: bytes) -> bytes:
|
||||
if not self._paused:
|
||||
self._output.write(rv)
|
||||
|
||||
return rv
|
||||
|
||||
def read(self, n: int = -1) -> bytes:
|
||||
return self._echo(self._input.read(n))
|
||||
|
||||
def read1(self, n: int = -1) -> bytes:
|
||||
return self._echo(self._input.read1(n)) # type: ignore
|
||||
|
||||
def readline(self, n: int = -1) -> bytes:
|
||||
return self._echo(self._input.readline(n))
|
||||
|
||||
def readlines(self) -> t.List[bytes]:
|
||||
return [self._echo(x) for x in self._input.readlines()]
|
||||
|
||||
def __iter__(self) -> t.Iterator[bytes]:
|
||||
return iter(self._echo(x) for x in self._input)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return repr(self._input)
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _pause_echo(stream: t.Optional[EchoingStdin]) -> t.Iterator[None]:
|
||||
if stream is None:
|
||||
yield
|
||||
else:
|
||||
stream._paused = True
|
||||
yield
|
||||
stream._paused = False
|
||||
|
||||
|
||||
class _NamedTextIOWrapper(io.TextIOWrapper):
|
||||
def __init__(
|
||||
self, buffer: t.BinaryIO, name: str, mode: str, **kwargs: t.Any
|
||||
) -> None:
|
||||
super().__init__(buffer, **kwargs)
|
||||
self._name = name
|
||||
self._mode = mode
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self._name
|
||||
|
||||
@property
|
||||
def mode(self) -> str:
|
||||
return self._mode
|
||||
|
||||
|
||||
def make_input_stream(
|
||||
input: t.Optional[t.Union[str, bytes, t.IO]], charset: str
|
||||
) -> t.BinaryIO:
|
||||
# Is already an input stream.
|
||||
if hasattr(input, "read"):
|
||||
rv = _find_binary_reader(t.cast(t.IO, input))
|
||||
|
||||
if rv is not None:
|
||||
return rv
|
||||
|
||||
raise TypeError("Could not find binary reader for input stream.")
|
||||
|
||||
if input is None:
|
||||
input = b""
|
||||
elif isinstance(input, str):
|
||||
input = input.encode(charset)
|
||||
|
||||
return io.BytesIO(t.cast(bytes, input))
|
||||
|
||||
|
||||
class Result:
|
||||
"""Holds the captured result of an invoked CLI script."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
runner: "CliRunner",
|
||||
stdout_bytes: bytes,
|
||||
stderr_bytes: t.Optional[bytes],
|
||||
return_value: t.Any,
|
||||
exit_code: int,
|
||||
exception: t.Optional[BaseException],
|
||||
exc_info: t.Optional[
|
||||
t.Tuple[t.Type[BaseException], BaseException, TracebackType]
|
||||
] = None,
|
||||
):
|
||||
#: The runner that created the result
|
||||
self.runner = runner
|
||||
#: The standard output as bytes.
|
||||
self.stdout_bytes = stdout_bytes
|
||||
#: The standard error as bytes, or None if not available
|
||||
self.stderr_bytes = stderr_bytes
|
||||
#: The value returned from the invoked command.
|
||||
#:
|
||||
#: .. versionadded:: 8.0
|
||||
self.return_value = return_value
|
||||
#: The exit code as integer.
|
||||
self.exit_code = exit_code
|
||||
#: The exception that happened if one did.
|
||||
self.exception = exception
|
||||
#: The traceback
|
||||
self.exc_info = exc_info
|
||||
|
||||
@property
|
||||
def output(self) -> str:
|
||||
"""The (standard) output as unicode string."""
|
||||
return self.stdout
|
||||
|
||||
@property
|
||||
def stdout(self) -> str:
|
||||
"""The standard output as unicode string."""
|
||||
return self.stdout_bytes.decode(self.runner.charset, "replace").replace(
|
||||
"\r\n", "\n"
|
||||
)
|
||||
|
||||
@property
|
||||
def stderr(self) -> str:
|
||||
"""The standard error as unicode string."""
|
||||
if self.stderr_bytes is None:
|
||||
raise ValueError("stderr not separately captured")
|
||||
return self.stderr_bytes.decode(self.runner.charset, "replace").replace(
|
||||
"\r\n", "\n"
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
exc_str = repr(self.exception) if self.exception else "okay"
|
||||
return f"<{type(self).__name__} {exc_str}>"
|
||||
|
||||
|
||||
class CliRunner:
|
||||
"""The CLI runner provides functionality to invoke a Click command line
|
||||
script for unittesting purposes in a isolated environment. This only
|
||||
works in single-threaded systems without any concurrency as it changes the
|
||||
global interpreter state.
|
||||
|
||||
:param charset: the character set for the input and output data.
|
||||
:param env: a dictionary with environment variables for overriding.
|
||||
:param echo_stdin: if this is set to `True`, then reading from stdin writes
|
||||
to stdout. This is useful for showing examples in
|
||||
some circumstances. Note that regular prompts
|
||||
will automatically echo the input.
|
||||
:param mix_stderr: if this is set to `False`, then stdout and stderr are
|
||||
preserved as independent streams. This is useful for
|
||||
Unix-philosophy apps that have predictable stdout and
|
||||
noisy stderr, such that each may be measured
|
||||
independently
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
charset: str = "utf-8",
|
||||
env: t.Optional[t.Mapping[str, t.Optional[str]]] = None,
|
||||
echo_stdin: bool = False,
|
||||
mix_stderr: bool = True,
|
||||
) -> None:
|
||||
self.charset = charset
|
||||
self.env = env or {}
|
||||
self.echo_stdin = echo_stdin
|
||||
self.mix_stderr = mix_stderr
|
||||
|
||||
def get_default_prog_name(self, cli: "BaseCommand") -> str:
|
||||
"""Given a command object it will return the default program name
|
||||
for it. The default is the `name` attribute or ``"root"`` if not
|
||||
set.
|
||||
"""
|
||||
return cli.name or "root"
|
||||
|
||||
def make_env(
|
||||
self, overrides: t.Optional[t.Mapping[str, t.Optional[str]]] = None
|
||||
) -> t.Mapping[str, t.Optional[str]]:
|
||||
"""Returns the environment overrides for invoking a script."""
|
||||
rv = dict(self.env)
|
||||
if overrides:
|
||||
rv.update(overrides)
|
||||
return rv
|
||||
|
||||
@contextlib.contextmanager
|
||||
def isolation(
|
||||
self,
|
||||
input: t.Optional[t.Union[str, bytes, t.IO]] = None,
|
||||
env: t.Optional[t.Mapping[str, t.Optional[str]]] = None,
|
||||
color: bool = False,
|
||||
) -> t.Iterator[t.Tuple[io.BytesIO, t.Optional[io.BytesIO]]]:
|
||||
"""A context manager that sets up the isolation for invoking of a
|
||||
command line tool. This sets up stdin with the given input data
|
||||
and `os.environ` with the overrides from the given dictionary.
|
||||
This also rebinds some internals in Click to be mocked (like the
|
||||
prompt functionality).
|
||||
|
||||
This is automatically done in the :meth:`invoke` method.
|
||||
|
||||
:param input: the input stream to put into sys.stdin.
|
||||
:param env: the environment overrides as dictionary.
|
||||
:param color: whether the output should contain color codes. The
|
||||
application can still override this explicitly.
|
||||
|
||||
.. versionchanged:: 8.0
|
||||
``stderr`` is opened with ``errors="backslashreplace"``
|
||||
instead of the default ``"strict"``.
|
||||
|
||||
.. versionchanged:: 4.0
|
||||
Added the ``color`` parameter.
|
||||
"""
|
||||
bytes_input = make_input_stream(input, self.charset)
|
||||
echo_input = None
|
||||
|
||||
old_stdin = sys.stdin
|
||||
old_stdout = sys.stdout
|
||||
old_stderr = sys.stderr
|
||||
old_forced_width = formatting.FORCED_WIDTH
|
||||
formatting.FORCED_WIDTH = 80
|
||||
|
||||
env = self.make_env(env)
|
||||
|
||||
bytes_output = io.BytesIO()
|
||||
|
||||
if self.echo_stdin:
|
||||
bytes_input = echo_input = t.cast(
|
||||
t.BinaryIO, EchoingStdin(bytes_input, bytes_output)
|
||||
)
|
||||
|
||||
sys.stdin = text_input = _NamedTextIOWrapper(
|
||||
bytes_input, encoding=self.charset, name="<stdin>", mode="r"
|
||||
)
|
||||
|
||||
if self.echo_stdin:
|
||||
# Force unbuffered reads, otherwise TextIOWrapper reads a
|
||||
# large chunk which is echoed early.
|
||||
text_input._CHUNK_SIZE = 1 # type: ignore
|
||||
|
||||
sys.stdout = _NamedTextIOWrapper(
|
||||
bytes_output, encoding=self.charset, name="<stdout>", mode="w"
|
||||
)
|
||||
|
||||
bytes_error = None
|
||||
if self.mix_stderr:
|
||||
sys.stderr = sys.stdout
|
||||
else:
|
||||
bytes_error = io.BytesIO()
|
||||
sys.stderr = _NamedTextIOWrapper(
|
||||
bytes_error,
|
||||
encoding=self.charset,
|
||||
name="<stderr>",
|
||||
mode="w",
|
||||
errors="backslashreplace",
|
||||
)
|
||||
|
||||
@_pause_echo(echo_input) # type: ignore
|
||||
def visible_input(prompt: t.Optional[str] = None) -> str:
|
||||
sys.stdout.write(prompt or "")
|
||||
val = text_input.readline().rstrip("\r\n")
|
||||
sys.stdout.write(f"{val}\n")
|
||||
sys.stdout.flush()
|
||||
return val
|
||||
|
||||
@_pause_echo(echo_input) # type: ignore
|
||||
def hidden_input(prompt: t.Optional[str] = None) -> str:
|
||||
sys.stdout.write(f"{prompt or ''}\n")
|
||||
sys.stdout.flush()
|
||||
return text_input.readline().rstrip("\r\n")
|
||||
|
||||
@_pause_echo(echo_input) # type: ignore
|
||||
def _getchar(echo: bool) -> str:
|
||||
char = sys.stdin.read(1)
|
||||
|
||||
if echo:
|
||||
sys.stdout.write(char)
|
||||
|
||||
sys.stdout.flush()
|
||||
return char
|
||||
|
||||
default_color = color
|
||||
|
||||
def should_strip_ansi(
|
||||
stream: t.Optional[t.IO] = None, color: t.Optional[bool] = None
|
||||
) -> bool:
|
||||
if color is None:
|
||||
return not default_color
|
||||
return not color
|
||||
|
||||
old_visible_prompt_func = termui.visible_prompt_func
|
||||
old_hidden_prompt_func = termui.hidden_prompt_func
|
||||
old__getchar_func = termui._getchar
|
||||
old_should_strip_ansi = utils.should_strip_ansi # type: ignore
|
||||
termui.visible_prompt_func = visible_input
|
||||
termui.hidden_prompt_func = hidden_input
|
||||
termui._getchar = _getchar
|
||||
utils.should_strip_ansi = should_strip_ansi # type: ignore
|
||||
|
||||
old_env = {}
|
||||
try:
|
||||
for key, value in env.items():
|
||||
old_env[key] = os.environ.get(key)
|
||||
if value is None:
|
||||
try:
|
||||
del os.environ[key]
|
||||
except Exception:
|
||||
pass
|
||||
else:
|
||||
os.environ[key] = value
|
||||
yield (bytes_output, bytes_error)
|
||||
finally:
|
||||
for key, value in old_env.items():
|
||||
if value is None:
|
||||
try:
|
||||
del os.environ[key]
|
||||
except Exception:
|
||||
pass
|
||||
else:
|
||||
os.environ[key] = value
|
||||
sys.stdout = old_stdout
|
||||
sys.stderr = old_stderr
|
||||
sys.stdin = old_stdin
|
||||
termui.visible_prompt_func = old_visible_prompt_func
|
||||
termui.hidden_prompt_func = old_hidden_prompt_func
|
||||
termui._getchar = old__getchar_func
|
||||
utils.should_strip_ansi = old_should_strip_ansi # type: ignore
|
||||
formatting.FORCED_WIDTH = old_forced_width
|
||||
|
||||
def invoke(
|
||||
self,
|
||||
cli: "BaseCommand",
|
||||
args: t.Optional[t.Union[str, t.Sequence[str]]] = None,
|
||||
input: t.Optional[t.Union[str, bytes, t.IO]] = None,
|
||||
env: t.Optional[t.Mapping[str, t.Optional[str]]] = None,
|
||||
catch_exceptions: bool = True,
|
||||
color: bool = False,
|
||||
**extra: t.Any,
|
||||
) -> Result:
|
||||
"""Invokes a command in an isolated environment. The arguments are
|
||||
forwarded directly to the command line script, the `extra` keyword
|
||||
arguments are passed to the :meth:`~clickpkg.Command.main` function of
|
||||
the command.
|
||||
|
||||
This returns a :class:`Result` object.
|
||||
|
||||
:param cli: the command to invoke
|
||||
:param args: the arguments to invoke. It may be given as an iterable
|
||||
or a string. When given as string it will be interpreted
|
||||
as a Unix shell command. More details at
|
||||
:func:`shlex.split`.
|
||||
:param input: the input data for `sys.stdin`.
|
||||
:param env: the environment overrides.
|
||||
:param catch_exceptions: Whether to catch any other exceptions than
|
||||
``SystemExit``.
|
||||
:param extra: the keyword arguments to pass to :meth:`main`.
|
||||
:param color: whether the output should contain color codes. The
|
||||
application can still override this explicitly.
|
||||
|
||||
.. versionchanged:: 8.0
|
||||
The result object has the ``return_value`` attribute with
|
||||
the value returned from the invoked command.
|
||||
|
||||
.. versionchanged:: 4.0
|
||||
Added the ``color`` parameter.
|
||||
|
||||
.. versionchanged:: 3.0
|
||||
Added the ``catch_exceptions`` parameter.
|
||||
|
||||
.. versionchanged:: 3.0
|
||||
The result object has the ``exc_info`` attribute with the
|
||||
traceback if available.
|
||||
"""
|
||||
exc_info = None
|
||||
with self.isolation(input=input, env=env, color=color) as outstreams:
|
||||
return_value = None
|
||||
exception: t.Optional[BaseException] = None
|
||||
exit_code = 0
|
||||
|
||||
if isinstance(args, str):
|
||||
args = shlex.split(args)
|
||||
|
||||
try:
|
||||
prog_name = extra.pop("prog_name")
|
||||
except KeyError:
|
||||
prog_name = self.get_default_prog_name(cli)
|
||||
|
||||
try:
|
||||
return_value = cli.main(args=args or (), prog_name=prog_name, **extra)
|
||||
except SystemExit as e:
|
||||
exc_info = sys.exc_info()
|
||||
e_code = t.cast(t.Optional[t.Union[int, t.Any]], e.code)
|
||||
|
||||
if e_code is None:
|
||||
e_code = 0
|
||||
|
||||
if e_code != 0:
|
||||
exception = e
|
||||
|
||||
if not isinstance(e_code, int):
|
||||
sys.stdout.write(str(e_code))
|
||||
sys.stdout.write("\n")
|
||||
e_code = 1
|
||||
|
||||
exit_code = e_code
|
||||
|
||||
except Exception as e:
|
||||
if not catch_exceptions:
|
||||
raise
|
||||
exception = e
|
||||
exit_code = 1
|
||||
exc_info = sys.exc_info()
|
||||
finally:
|
||||
sys.stdout.flush()
|
||||
stdout = outstreams[0].getvalue()
|
||||
if self.mix_stderr:
|
||||
stderr = None
|
||||
else:
|
||||
stderr = outstreams[1].getvalue() # type: ignore
|
||||
|
||||
return Result(
|
||||
runner=self,
|
||||
stdout_bytes=stdout,
|
||||
stderr_bytes=stderr,
|
||||
return_value=return_value,
|
||||
exit_code=exit_code,
|
||||
exception=exception,
|
||||
exc_info=exc_info, # type: ignore
|
||||
)
|
||||
|
||||
@contextlib.contextmanager
|
||||
def isolated_filesystem(
|
||||
self, temp_dir: t.Optional[t.Union[str, os.PathLike]] = None
|
||||
) -> t.Iterator[str]:
|
||||
"""A context manager that creates a temporary directory and
|
||||
changes the current working directory to it. This isolates tests
|
||||
that affect the contents of the CWD to prevent them from
|
||||
interfering with each other.
|
||||
|
||||
:param temp_dir: Create the temporary directory under this
|
||||
directory. If given, the created directory is not removed
|
||||
when exiting.
|
||||
|
||||
.. versionchanged:: 8.0
|
||||
Added the ``temp_dir`` parameter.
|
||||
"""
|
||||
cwd = os.getcwd()
|
||||
dt = tempfile.mkdtemp(dir=temp_dir) # type: ignore[type-var]
|
||||
os.chdir(dt)
|
||||
|
||||
try:
|
||||
yield t.cast(str, dt)
|
||||
finally:
|
||||
os.chdir(cwd)
|
||||
|
||||
if temp_dir is None:
|
||||
try:
|
||||
shutil.rmtree(dt)
|
||||
except OSError: # noqa: B014
|
||||
pass
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,580 @@
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import typing as t
|
||||
from functools import update_wrapper
|
||||
from types import ModuleType
|
||||
|
||||
from ._compat import _default_text_stderr
|
||||
from ._compat import _default_text_stdout
|
||||
from ._compat import _find_binary_writer
|
||||
from ._compat import auto_wrap_for_ansi
|
||||
from ._compat import binary_streams
|
||||
from ._compat import get_filesystem_encoding
|
||||
from ._compat import open_stream
|
||||
from ._compat import should_strip_ansi
|
||||
from ._compat import strip_ansi
|
||||
from ._compat import text_streams
|
||||
from ._compat import WIN
|
||||
from .globals import resolve_color_default
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
import typing_extensions as te
|
||||
|
||||
F = t.TypeVar("F", bound=t.Callable[..., t.Any])
|
||||
|
||||
|
||||
def _posixify(name: str) -> str:
|
||||
return "-".join(name.split()).lower()
|
||||
|
||||
|
||||
def safecall(func: F) -> F:
|
||||
"""Wraps a function so that it swallows exceptions."""
|
||||
|
||||
def wrapper(*args, **kwargs): # type: ignore
|
||||
try:
|
||||
return func(*args, **kwargs)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return update_wrapper(t.cast(F, wrapper), func)
|
||||
|
||||
|
||||
def make_str(value: t.Any) -> str:
|
||||
"""Converts a value into a valid string."""
|
||||
if isinstance(value, bytes):
|
||||
try:
|
||||
return value.decode(get_filesystem_encoding())
|
||||
except UnicodeError:
|
||||
return value.decode("utf-8", "replace")
|
||||
return str(value)
|
||||
|
||||
|
||||
def make_default_short_help(help: str, max_length: int = 45) -> str:
|
||||
"""Returns a condensed version of help string."""
|
||||
# Consider only the first paragraph.
|
||||
paragraph_end = help.find("\n\n")
|
||||
|
||||
if paragraph_end != -1:
|
||||
help = help[:paragraph_end]
|
||||
|
||||
# Collapse newlines, tabs, and spaces.
|
||||
words = help.split()
|
||||
|
||||
if not words:
|
||||
return ""
|
||||
|
||||
# The first paragraph started with a "no rewrap" marker, ignore it.
|
||||
if words[0] == "\b":
|
||||
words = words[1:]
|
||||
|
||||
total_length = 0
|
||||
last_index = len(words) - 1
|
||||
|
||||
for i, word in enumerate(words):
|
||||
total_length += len(word) + (i > 0)
|
||||
|
||||
if total_length > max_length: # too long, truncate
|
||||
break
|
||||
|
||||
if word[-1] == ".": # sentence end, truncate without "..."
|
||||
return " ".join(words[: i + 1])
|
||||
|
||||
if total_length == max_length and i != last_index:
|
||||
break # not at sentence end, truncate with "..."
|
||||
else:
|
||||
return " ".join(words) # no truncation needed
|
||||
|
||||
# Account for the length of the suffix.
|
||||
total_length += len("...")
|
||||
|
||||
# remove words until the length is short enough
|
||||
while i > 0:
|
||||
total_length -= len(words[i]) + (i > 0)
|
||||
|
||||
if total_length <= max_length:
|
||||
break
|
||||
|
||||
i -= 1
|
||||
|
||||
return " ".join(words[:i]) + "..."
|
||||
|
||||
|
||||
class LazyFile:
|
||||
"""A lazy file works like a regular file but it does not fully open
|
||||
the file but it does perform some basic checks early to see if the
|
||||
filename parameter does make sense. This is useful for safely opening
|
||||
files for writing.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
filename: str,
|
||||
mode: str = "r",
|
||||
encoding: t.Optional[str] = None,
|
||||
errors: t.Optional[str] = "strict",
|
||||
atomic: bool = False,
|
||||
):
|
||||
self.name = filename
|
||||
self.mode = mode
|
||||
self.encoding = encoding
|
||||
self.errors = errors
|
||||
self.atomic = atomic
|
||||
self._f: t.Optional[t.IO]
|
||||
|
||||
if filename == "-":
|
||||
self._f, self.should_close = open_stream(filename, mode, encoding, errors)
|
||||
else:
|
||||
if "r" in mode:
|
||||
# Open and close the file in case we're opening it for
|
||||
# reading so that we can catch at least some errors in
|
||||
# some cases early.
|
||||
open(filename, mode).close()
|
||||
self._f = None
|
||||
self.should_close = True
|
||||
|
||||
def __getattr__(self, name: str) -> t.Any:
|
||||
return getattr(self.open(), name)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
if self._f is not None:
|
||||
return repr(self._f)
|
||||
return f"<unopened file '{self.name}' {self.mode}>"
|
||||
|
||||
def open(self) -> t.IO:
|
||||
"""Opens the file if it's not yet open. This call might fail with
|
||||
a :exc:`FileError`. Not handling this error will produce an error
|
||||
that Click shows.
|
||||
"""
|
||||
if self._f is not None:
|
||||
return self._f
|
||||
try:
|
||||
rv, self.should_close = open_stream(
|
||||
self.name, self.mode, self.encoding, self.errors, atomic=self.atomic
|
||||
)
|
||||
except OSError as e: # noqa: E402
|
||||
from .exceptions import FileError
|
||||
|
||||
raise FileError(self.name, hint=e.strerror) from e
|
||||
self._f = rv
|
||||
return rv
|
||||
|
||||
def close(self) -> None:
|
||||
"""Closes the underlying file, no matter what."""
|
||||
if self._f is not None:
|
||||
self._f.close()
|
||||
|
||||
def close_intelligently(self) -> None:
|
||||
"""This function only closes the file if it was opened by the lazy
|
||||
file wrapper. For instance this will never close stdin.
|
||||
"""
|
||||
if self.should_close:
|
||||
self.close()
|
||||
|
||||
def __enter__(self) -> "LazyFile":
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, tb): # type: ignore
|
||||
self.close_intelligently()
|
||||
|
||||
def __iter__(self) -> t.Iterator[t.AnyStr]:
|
||||
self.open()
|
||||
return iter(self._f) # type: ignore
|
||||
|
||||
|
||||
class KeepOpenFile:
|
||||
def __init__(self, file: t.IO) -> None:
|
||||
self._file = file
|
||||
|
||||
def __getattr__(self, name: str) -> t.Any:
|
||||
return getattr(self._file, name)
|
||||
|
||||
def __enter__(self) -> "KeepOpenFile":
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, tb): # type: ignore
|
||||
pass
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return repr(self._file)
|
||||
|
||||
def __iter__(self) -> t.Iterator[t.AnyStr]:
|
||||
return iter(self._file)
|
||||
|
||||
|
||||
def echo(
|
||||
message: t.Optional[t.Any] = None,
|
||||
file: t.Optional[t.IO[t.Any]] = None,
|
||||
nl: bool = True,
|
||||
err: bool = False,
|
||||
color: t.Optional[bool] = None,
|
||||
) -> None:
|
||||
"""Print a message and newline to stdout or a file. This should be
|
||||
used instead of :func:`print` because it provides better support
|
||||
for different data, files, and environments.
|
||||
|
||||
Compared to :func:`print`, this does the following:
|
||||
|
||||
- Ensures that the output encoding is not misconfigured on Linux.
|
||||
- Supports Unicode in the Windows console.
|
||||
- Supports writing to binary outputs, and supports writing bytes
|
||||
to text outputs.
|
||||
- Supports colors and styles on Windows.
|
||||
- Removes ANSI color and style codes if the output does not look
|
||||
like an interactive terminal.
|
||||
- Always flushes the output.
|
||||
|
||||
:param message: The string or bytes to output. Other objects are
|
||||
converted to strings.
|
||||
:param file: The file to write to. Defaults to ``stdout``.
|
||||
:param err: Write to ``stderr`` instead of ``stdout``.
|
||||
:param nl: Print a newline after the message. Enabled by default.
|
||||
:param color: Force showing or hiding colors and other styles. By
|
||||
default Click will remove color if the output does not look like
|
||||
an interactive terminal.
|
||||
|
||||
.. versionchanged:: 6.0
|
||||
Support Unicode output on the Windows console. Click does not
|
||||
modify ``sys.stdout``, so ``sys.stdout.write()`` and ``print()``
|
||||
will still not support Unicode.
|
||||
|
||||
.. versionchanged:: 4.0
|
||||
Added the ``color`` parameter.
|
||||
|
||||
.. versionadded:: 3.0
|
||||
Added the ``err`` parameter.
|
||||
|
||||
.. versionchanged:: 2.0
|
||||
Support colors on Windows if colorama is installed.
|
||||
"""
|
||||
if file is None:
|
||||
if err:
|
||||
file = _default_text_stderr()
|
||||
else:
|
||||
file = _default_text_stdout()
|
||||
|
||||
# Convert non bytes/text into the native string type.
|
||||
if message is not None and not isinstance(message, (str, bytes, bytearray)):
|
||||
out: t.Optional[t.Union[str, bytes]] = str(message)
|
||||
else:
|
||||
out = message
|
||||
|
||||
if nl:
|
||||
out = out or ""
|
||||
if isinstance(out, str):
|
||||
out += "\n"
|
||||
else:
|
||||
out += b"\n"
|
||||
|
||||
if not out:
|
||||
file.flush()
|
||||
return
|
||||
|
||||
# If there is a message and the value looks like bytes, we manually
|
||||
# need to find the binary stream and write the message in there.
|
||||
# This is done separately so that most stream types will work as you
|
||||
# would expect. Eg: you can write to StringIO for other cases.
|
||||
if isinstance(out, (bytes, bytearray)):
|
||||
binary_file = _find_binary_writer(file)
|
||||
|
||||
if binary_file is not None:
|
||||
file.flush()
|
||||
binary_file.write(out)
|
||||
binary_file.flush()
|
||||
return
|
||||
|
||||
# ANSI style code support. For no message or bytes, nothing happens.
|
||||
# When outputting to a file instead of a terminal, strip codes.
|
||||
else:
|
||||
color = resolve_color_default(color)
|
||||
|
||||
if should_strip_ansi(file, color):
|
||||
out = strip_ansi(out)
|
||||
elif WIN:
|
||||
if auto_wrap_for_ansi is not None:
|
||||
file = auto_wrap_for_ansi(file) # type: ignore
|
||||
elif not color:
|
||||
out = strip_ansi(out)
|
||||
|
||||
file.write(out) # type: ignore
|
||||
file.flush()
|
||||
|
||||
|
||||
def get_binary_stream(name: "te.Literal['stdin', 'stdout', 'stderr']") -> t.BinaryIO:
|
||||
"""Returns a system stream for byte processing.
|
||||
|
||||
:param name: the name of the stream to open. Valid names are ``'stdin'``,
|
||||
``'stdout'`` and ``'stderr'``
|
||||
"""
|
||||
opener = binary_streams.get(name)
|
||||
if opener is None:
|
||||
raise TypeError(f"Unknown standard stream '{name}'")
|
||||
return opener()
|
||||
|
||||
|
||||
def get_text_stream(
|
||||
name: "te.Literal['stdin', 'stdout', 'stderr']",
|
||||
encoding: t.Optional[str] = None,
|
||||
errors: t.Optional[str] = "strict",
|
||||
) -> t.TextIO:
|
||||
"""Returns a system stream for text processing. This usually returns
|
||||
a wrapped stream around a binary stream returned from
|
||||
:func:`get_binary_stream` but it also can take shortcuts for already
|
||||
correctly configured streams.
|
||||
|
||||
:param name: the name of the stream to open. Valid names are ``'stdin'``,
|
||||
``'stdout'`` and ``'stderr'``
|
||||
:param encoding: overrides the detected default encoding.
|
||||
:param errors: overrides the default error mode.
|
||||
"""
|
||||
opener = text_streams.get(name)
|
||||
if opener is None:
|
||||
raise TypeError(f"Unknown standard stream '{name}'")
|
||||
return opener(encoding, errors)
|
||||
|
||||
|
||||
def open_file(
|
||||
filename: str,
|
||||
mode: str = "r",
|
||||
encoding: t.Optional[str] = None,
|
||||
errors: t.Optional[str] = "strict",
|
||||
lazy: bool = False,
|
||||
atomic: bool = False,
|
||||
) -> t.IO:
|
||||
"""Open a file, with extra behavior to handle ``'-'`` to indicate
|
||||
a standard stream, lazy open on write, and atomic write. Similar to
|
||||
the behavior of the :class:`~click.File` param type.
|
||||
|
||||
If ``'-'`` is given to open ``stdout`` or ``stdin``, the stream is
|
||||
wrapped so that using it in a context manager will not close it.
|
||||
This makes it possible to use the function without accidentally
|
||||
closing a standard stream:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
with open_file(filename) as f:
|
||||
...
|
||||
|
||||
:param filename: The name of the file to open, or ``'-'`` for
|
||||
``stdin``/``stdout``.
|
||||
:param mode: The mode in which to open the file.
|
||||
:param encoding: The encoding to decode or encode a file opened in
|
||||
text mode.
|
||||
:param errors: The error handling mode.
|
||||
:param lazy: Wait to open the file until it is accessed. For read
|
||||
mode, the file is temporarily opened to raise access errors
|
||||
early, then closed until it is read again.
|
||||
:param atomic: Write to a temporary file and replace the given file
|
||||
on close.
|
||||
|
||||
.. versionadded:: 3.0
|
||||
"""
|
||||
if lazy:
|
||||
return t.cast(t.IO, LazyFile(filename, mode, encoding, errors, atomic=atomic))
|
||||
|
||||
f, should_close = open_stream(filename, mode, encoding, errors, atomic=atomic)
|
||||
|
||||
if not should_close:
|
||||
f = t.cast(t.IO, KeepOpenFile(f))
|
||||
|
||||
return f
|
||||
|
||||
|
||||
def format_filename(
|
||||
filename: t.Union[str, bytes, os.PathLike], shorten: bool = False
|
||||
) -> str:
|
||||
"""Formats a filename for user display. The main purpose of this
|
||||
function is to ensure that the filename can be displayed at all. This
|
||||
will decode the filename to unicode if necessary in a way that it will
|
||||
not fail. Optionally, it can shorten the filename to not include the
|
||||
full path to the filename.
|
||||
|
||||
:param filename: formats a filename for UI display. This will also convert
|
||||
the filename into unicode without failing.
|
||||
:param shorten: this optionally shortens the filename to strip of the
|
||||
path that leads up to it.
|
||||
"""
|
||||
if shorten:
|
||||
filename = os.path.basename(filename)
|
||||
|
||||
return os.fsdecode(filename)
|
||||
|
||||
|
||||
def get_app_dir(app_name: str, roaming: bool = True, force_posix: bool = False) -> str:
|
||||
r"""Returns the config folder for the application. The default behavior
|
||||
is to return whatever is most appropriate for the operating system.
|
||||
|
||||
To give you an idea, for an app called ``"Foo Bar"``, something like
|
||||
the following folders could be returned:
|
||||
|
||||
Mac OS X:
|
||||
``~/Library/Application Support/Foo Bar``
|
||||
Mac OS X (POSIX):
|
||||
``~/.foo-bar``
|
||||
Unix:
|
||||
``~/.config/foo-bar``
|
||||
Unix (POSIX):
|
||||
``~/.foo-bar``
|
||||
Windows (roaming):
|
||||
``C:\Users\<user>\AppData\Roaming\Foo Bar``
|
||||
Windows (not roaming):
|
||||
``C:\Users\<user>\AppData\Local\Foo Bar``
|
||||
|
||||
.. versionadded:: 2.0
|
||||
|
||||
:param app_name: the application name. This should be properly capitalized
|
||||
and can contain whitespace.
|
||||
:param roaming: controls if the folder should be roaming or not on Windows.
|
||||
Has no affect otherwise.
|
||||
:param force_posix: if this is set to `True` then on any POSIX system the
|
||||
folder will be stored in the home folder with a leading
|
||||
dot instead of the XDG config home or darwin's
|
||||
application support folder.
|
||||
"""
|
||||
if WIN:
|
||||
key = "APPDATA" if roaming else "LOCALAPPDATA"
|
||||
folder = os.environ.get(key)
|
||||
if folder is None:
|
||||
folder = os.path.expanduser("~")
|
||||
return os.path.join(folder, app_name)
|
||||
if force_posix:
|
||||
return os.path.join(os.path.expanduser(f"~/.{_posixify(app_name)}"))
|
||||
if sys.platform == "darwin":
|
||||
return os.path.join(
|
||||
os.path.expanduser("~/Library/Application Support"), app_name
|
||||
)
|
||||
return os.path.join(
|
||||
os.environ.get("XDG_CONFIG_HOME", os.path.expanduser("~/.config")),
|
||||
_posixify(app_name),
|
||||
)
|
||||
|
||||
|
||||
class PacifyFlushWrapper:
|
||||
"""This wrapper is used to catch and suppress BrokenPipeErrors resulting
|
||||
from ``.flush()`` being called on broken pipe during the shutdown/final-GC
|
||||
of the Python interpreter. Notably ``.flush()`` is always called on
|
||||
``sys.stdout`` and ``sys.stderr``. So as to have minimal impact on any
|
||||
other cleanup code, and the case where the underlying file is not a broken
|
||||
pipe, all calls and attributes are proxied.
|
||||
"""
|
||||
|
||||
def __init__(self, wrapped: t.IO) -> None:
|
||||
self.wrapped = wrapped
|
||||
|
||||
def flush(self) -> None:
|
||||
try:
|
||||
self.wrapped.flush()
|
||||
except OSError as e:
|
||||
import errno
|
||||
|
||||
if e.errno != errno.EPIPE:
|
||||
raise
|
||||
|
||||
def __getattr__(self, attr: str) -> t.Any:
|
||||
return getattr(self.wrapped, attr)
|
||||
|
||||
|
||||
def _detect_program_name(
|
||||
path: t.Optional[str] = None, _main: t.Optional[ModuleType] = None
|
||||
) -> str:
|
||||
"""Determine the command used to run the program, for use in help
|
||||
text. If a file or entry point was executed, the file name is
|
||||
returned. If ``python -m`` was used to execute a module or package,
|
||||
``python -m name`` is returned.
|
||||
|
||||
This doesn't try to be too precise, the goal is to give a concise
|
||||
name for help text. Files are only shown as their name without the
|
||||
path. ``python`` is only shown for modules, and the full path to
|
||||
``sys.executable`` is not shown.
|
||||
|
||||
:param path: The Python file being executed. Python puts this in
|
||||
``sys.argv[0]``, which is used by default.
|
||||
:param _main: The ``__main__`` module. This should only be passed
|
||||
during internal testing.
|
||||
|
||||
.. versionadded:: 8.0
|
||||
Based on command args detection in the Werkzeug reloader.
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
if _main is None:
|
||||
_main = sys.modules["__main__"]
|
||||
|
||||
if not path:
|
||||
path = sys.argv[0]
|
||||
|
||||
# The value of __package__ indicates how Python was called. It may
|
||||
# not exist if a setuptools script is installed as an egg. It may be
|
||||
# set incorrectly for entry points created with pip on Windows.
|
||||
if getattr(_main, "__package__", None) is None or (
|
||||
os.name == "nt"
|
||||
and _main.__package__ == ""
|
||||
and not os.path.exists(path)
|
||||
and os.path.exists(f"{path}.exe")
|
||||
):
|
||||
# Executed a file, like "python app.py".
|
||||
return os.path.basename(path)
|
||||
|
||||
# Executed a module, like "python -m example".
|
||||
# Rewritten by Python from "-m script" to "/path/to/script.py".
|
||||
# Need to look at main module to determine how it was executed.
|
||||
py_module = t.cast(str, _main.__package__)
|
||||
name = os.path.splitext(os.path.basename(path))[0]
|
||||
|
||||
# A submodule like "example.cli".
|
||||
if name != "__main__":
|
||||
py_module = f"{py_module}.{name}"
|
||||
|
||||
return f"python -m {py_module.lstrip('.')}"
|
||||
|
||||
|
||||
def _expand_args(
|
||||
args: t.Iterable[str],
|
||||
*,
|
||||
user: bool = True,
|
||||
env: bool = True,
|
||||
glob_recursive: bool = True,
|
||||
) -> t.List[str]:
|
||||
"""Simulate Unix shell expansion with Python functions.
|
||||
|
||||
See :func:`glob.glob`, :func:`os.path.expanduser`, and
|
||||
:func:`os.path.expandvars`.
|
||||
|
||||
This is intended for use on Windows, where the shell does not do any
|
||||
expansion. It may not exactly match what a Unix shell would do.
|
||||
|
||||
:param args: List of command line arguments to expand.
|
||||
:param user: Expand user home directory.
|
||||
:param env: Expand environment variables.
|
||||
:param glob_recursive: ``**`` matches directories recursively.
|
||||
|
||||
.. versionchanged:: 8.1
|
||||
Invalid glob patterns are treated as empty expansions rather
|
||||
than raising an error.
|
||||
|
||||
.. versionadded:: 8.0
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
from glob import glob
|
||||
|
||||
out = []
|
||||
|
||||
for arg in args:
|
||||
if user:
|
||||
arg = os.path.expanduser(arg)
|
||||
|
||||
if env:
|
||||
arg = os.path.expandvars(arg)
|
||||
|
||||
try:
|
||||
matches = glob(arg, recursive=glob_recursive)
|
||||
except re.error:
|
||||
matches = []
|
||||
|
||||
if not matches:
|
||||
out.append(arg)
|
||||
else:
|
||||
out.extend(matches)
|
||||
|
||||
return out
|
@ -0,0 +1,49 @@
|
||||
from typing import Any, Optional
|
||||
|
||||
from .main import (dotenv_values, find_dotenv, get_key, load_dotenv, set_key,
|
||||
unset_key)
|
||||
|
||||
|
||||
def load_ipython_extension(ipython: Any) -> None:
|
||||
from .ipython import load_ipython_extension
|
||||
load_ipython_extension(ipython)
|
||||
|
||||
|
||||
def get_cli_string(
|
||||
path: Optional[str] = None,
|
||||
action: Optional[str] = None,
|
||||
key: Optional[str] = None,
|
||||
value: Optional[str] = None,
|
||||
quote: Optional[str] = None,
|
||||
):
|
||||
"""Returns a string suitable for running as a shell script.
|
||||
|
||||
Useful for converting a arguments passed to a fabric task
|
||||
to be passed to a `local` or `run` command.
|
||||
"""
|
||||
command = ['dotenv']
|
||||
if quote:
|
||||
command.append('-q %s' % quote)
|
||||
if path:
|
||||
command.append('-f %s' % path)
|
||||
if action:
|
||||
command.append(action)
|
||||
if key:
|
||||
command.append(key)
|
||||
if value:
|
||||
if ' ' in value:
|
||||
command.append('"%s"' % value)
|
||||
else:
|
||||
command.append(value)
|
||||
|
||||
return ' '.join(command).strip()
|
||||
|
||||
|
||||
__all__ = ['get_cli_string',
|
||||
'load_dotenv',
|
||||
'dotenv_values',
|
||||
'get_key',
|
||||
'set_key',
|
||||
'unset_key',
|
||||
'find_dotenv',
|
||||
'load_ipython_extension']
|
@ -0,0 +1,164 @@
|
||||
import os
|
||||
import sys
|
||||
from subprocess import Popen
|
||||
from typing import Any, Dict, List
|
||||
|
||||
try:
|
||||
import click
|
||||
except ImportError:
|
||||
sys.stderr.write('It seems python-dotenv is not installed with cli option. \n'
|
||||
'Run pip install "python-dotenv[cli]" to fix this.')
|
||||
sys.exit(1)
|
||||
|
||||
from .main import dotenv_values, get_key, set_key, unset_key
|
||||
from .version import __version__
|
||||
|
||||
|
||||
@click.group()
|
||||
@click.option('-f', '--file', default=os.path.join(os.getcwd(), '.env'),
|
||||
type=click.Path(file_okay=True),
|
||||
help="Location of the .env file, defaults to .env file in current working directory.")
|
||||
@click.option('-q', '--quote', default='always',
|
||||
type=click.Choice(['always', 'never', 'auto']),
|
||||
help="Whether to quote or not the variable values. Default mode is always. This does not affect parsing.")
|
||||
@click.option('-e', '--export', default=False,
|
||||
type=click.BOOL,
|
||||
help="Whether to write the dot file as an executable bash script.")
|
||||
@click.version_option(version=__version__)
|
||||
@click.pass_context
|
||||
def cli(ctx: click.Context, file: Any, quote: Any, export: Any) -> None:
|
||||
'''This script is used to set, get or unset values from a .env file.'''
|
||||
ctx.obj = {}
|
||||
ctx.obj['QUOTE'] = quote
|
||||
ctx.obj['EXPORT'] = export
|
||||
ctx.obj['FILE'] = file
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.pass_context
|
||||
def list(ctx: click.Context) -> None:
|
||||
'''Display all the stored key/value.'''
|
||||
file = ctx.obj['FILE']
|
||||
if not os.path.isfile(file):
|
||||
raise click.BadParameter(
|
||||
'Path "%s" does not exist.' % (file),
|
||||
ctx=ctx
|
||||
)
|
||||
dotenv_as_dict = dotenv_values(file)
|
||||
for k, v in dotenv_as_dict.items():
|
||||
click.echo('%s=%s' % (k, v))
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.pass_context
|
||||
@click.argument('key', required=True)
|
||||
@click.argument('value', required=True)
|
||||
def set(ctx: click.Context, key: Any, value: Any) -> None:
|
||||
'''Store the given key/value.'''
|
||||
file = ctx.obj['FILE']
|
||||
quote = ctx.obj['QUOTE']
|
||||
export = ctx.obj['EXPORT']
|
||||
success, key, value = set_key(file, key, value, quote, export)
|
||||
if success:
|
||||
click.echo('%s=%s' % (key, value))
|
||||
else:
|
||||
exit(1)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.pass_context
|
||||
@click.argument('key', required=True)
|
||||
def get(ctx: click.Context, key: Any) -> None:
|
||||
'''Retrieve the value for the given key.'''
|
||||
file = ctx.obj['FILE']
|
||||
if not os.path.isfile(file):
|
||||
raise click.BadParameter(
|
||||
'Path "%s" does not exist.' % (file),
|
||||
ctx=ctx
|
||||
)
|
||||
stored_value = get_key(file, key)
|
||||
if stored_value:
|
||||
click.echo(stored_value)
|
||||
else:
|
||||
exit(1)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.pass_context
|
||||
@click.argument('key', required=True)
|
||||
def unset(ctx: click.Context, key: Any) -> None:
|
||||
'''Removes the given key.'''
|
||||
file = ctx.obj['FILE']
|
||||
quote = ctx.obj['QUOTE']
|
||||
success, key = unset_key(file, key, quote)
|
||||
if success:
|
||||
click.echo("Successfully removed %s" % key)
|
||||
else:
|
||||
exit(1)
|
||||
|
||||
|
||||
@cli.command(context_settings={'ignore_unknown_options': True})
|
||||
@click.pass_context
|
||||
@click.option(
|
||||
"--override/--no-override",
|
||||
default=True,
|
||||
help="Override variables from the environment file with those from the .env file.",
|
||||
)
|
||||
@click.argument('commandline', nargs=-1, type=click.UNPROCESSED)
|
||||
def run(ctx: click.Context, override: bool, commandline: List[str]) -> None:
|
||||
"""Run command with environment variables present."""
|
||||
file = ctx.obj['FILE']
|
||||
if not os.path.isfile(file):
|
||||
raise click.BadParameter(
|
||||
'Invalid value for \'-f\' "%s" does not exist.' % (file),
|
||||
ctx=ctx
|
||||
)
|
||||
dotenv_as_dict = {
|
||||
k: v
|
||||
for (k, v) in dotenv_values(file).items()
|
||||
if v is not None and (override or k not in os.environ)
|
||||
}
|
||||
|
||||
if not commandline:
|
||||
click.echo('No command given.')
|
||||
exit(1)
|
||||
ret = run_command(commandline, dotenv_as_dict)
|
||||
exit(ret)
|
||||
|
||||
|
||||
def run_command(command: List[str], env: Dict[str, str]) -> int:
|
||||
"""Run command in sub process.
|
||||
|
||||
Runs the command in a sub process with the variables from `env`
|
||||
added in the current environment variables.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
command: List[str]
|
||||
The command and it's parameters
|
||||
env: Dict
|
||||
The additional environment variables
|
||||
|
||||
Returns
|
||||
-------
|
||||
int
|
||||
The return code of the command
|
||||
|
||||
"""
|
||||
# copy the current environment variables and add the vales from
|
||||
# `env`
|
||||
cmd_env = os.environ.copy()
|
||||
cmd_env.update(env)
|
||||
|
||||
p = Popen(command,
|
||||
universal_newlines=True,
|
||||
bufsize=0,
|
||||
shell=False,
|
||||
env=cmd_env)
|
||||
_, _ = p.communicate()
|
||||
|
||||
return p.returncode
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
cli()
|
@ -0,0 +1,39 @@
|
||||
from IPython.core.magic import Magics, line_magic, magics_class # type: ignore
|
||||
from IPython.core.magic_arguments import (argument, magic_arguments, # type: ignore
|
||||
parse_argstring) # type: ignore
|
||||
|
||||
from .main import find_dotenv, load_dotenv
|
||||
|
||||
|
||||
@magics_class
|
||||
class IPythonDotEnv(Magics):
|
||||
|
||||
@magic_arguments()
|
||||
@argument(
|
||||
'-o', '--override', action='store_true',
|
||||
help="Indicate to override existing variables"
|
||||
)
|
||||
@argument(
|
||||
'-v', '--verbose', action='store_true',
|
||||
help="Indicate function calls to be verbose"
|
||||
)
|
||||
@argument('dotenv_path', nargs='?', type=str, default='.env',
|
||||
help='Search in increasingly higher folders for the `dotenv_path`')
|
||||
@line_magic
|
||||
def dotenv(self, line):
|
||||
args = parse_argstring(self.dotenv, line)
|
||||
# Locate the .env file
|
||||
dotenv_path = args.dotenv_path
|
||||
try:
|
||||
dotenv_path = find_dotenv(dotenv_path, True, True)
|
||||
except IOError:
|
||||
print("cannot find .env file")
|
||||
return
|
||||
|
||||
# Load the .env file
|
||||
load_dotenv(dotenv_path, verbose=args.verbose, override=args.override)
|
||||
|
||||
|
||||
def load_ipython_extension(ipython):
|
||||
"""Register the %dotenv magic."""
|
||||
ipython.register_magics(IPythonDotEnv)
|
@ -0,0 +1,373 @@
|
||||
import io
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
import tempfile
|
||||
from collections import OrderedDict
|
||||
from contextlib import contextmanager
|
||||
from typing import (IO, Dict, Iterable, Iterator, Mapping, Optional, Tuple,
|
||||
Union)
|
||||
|
||||
from .parser import Binding, parse_stream
|
||||
from .variables import parse_variables
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
if sys.version_info >= (3, 6):
|
||||
_PathLike = os.PathLike
|
||||
else:
|
||||
_PathLike = str
|
||||
|
||||
|
||||
def with_warn_for_invalid_lines(mappings: Iterator[Binding]) -> Iterator[Binding]:
|
||||
for mapping in mappings:
|
||||
if mapping.error:
|
||||
logger.warning(
|
||||
"Python-dotenv could not parse statement starting at line %s",
|
||||
mapping.original.line,
|
||||
)
|
||||
yield mapping
|
||||
|
||||
|
||||
class DotEnv():
|
||||
def __init__(
|
||||
self,
|
||||
dotenv_path: Optional[Union[str, _PathLike]],
|
||||
stream: Optional[IO[str]] = None,
|
||||
verbose: bool = False,
|
||||
encoding: Union[None, str] = None,
|
||||
interpolate: bool = True,
|
||||
override: bool = True,
|
||||
) -> None:
|
||||
self.dotenv_path = dotenv_path # type: Optional[Union[str, _PathLike]]
|
||||
self.stream = stream # type: Optional[IO[str]]
|
||||
self._dict = None # type: Optional[Dict[str, Optional[str]]]
|
||||
self.verbose = verbose # type: bool
|
||||
self.encoding = encoding # type: Union[None, str]
|
||||
self.interpolate = interpolate # type: bool
|
||||
self.override = override # type: bool
|
||||
|
||||
@contextmanager
|
||||
def _get_stream(self) -> Iterator[IO[str]]:
|
||||
if self.dotenv_path and os.path.isfile(self.dotenv_path):
|
||||
with io.open(self.dotenv_path, encoding=self.encoding) as stream:
|
||||
yield stream
|
||||
elif self.stream is not None:
|
||||
yield self.stream
|
||||
else:
|
||||
if self.verbose:
|
||||
logger.info(
|
||||
"Python-dotenv could not find configuration file %s.",
|
||||
self.dotenv_path or '.env',
|
||||
)
|
||||
yield io.StringIO('')
|
||||
|
||||
def dict(self) -> Dict[str, Optional[str]]:
|
||||
"""Return dotenv as dict"""
|
||||
if self._dict:
|
||||
return self._dict
|
||||
|
||||
raw_values = self.parse()
|
||||
|
||||
if self.interpolate:
|
||||
self._dict = OrderedDict(resolve_variables(raw_values, override=self.override))
|
||||
else:
|
||||
self._dict = OrderedDict(raw_values)
|
||||
|
||||
return self._dict
|
||||
|
||||
def parse(self) -> Iterator[Tuple[str, Optional[str]]]:
|
||||
with self._get_stream() as stream:
|
||||
for mapping in with_warn_for_invalid_lines(parse_stream(stream)):
|
||||
if mapping.key is not None:
|
||||
yield mapping.key, mapping.value
|
||||
|
||||
def set_as_environment_variables(self) -> bool:
|
||||
"""
|
||||
Load the current dotenv as system environment variable.
|
||||
"""
|
||||
for k, v in self.dict().items():
|
||||
if k in os.environ and not self.override:
|
||||
continue
|
||||
if v is not None:
|
||||
os.environ[k] = v
|
||||
|
||||
return True
|
||||
|
||||
def get(self, key: str) -> Optional[str]:
|
||||
"""
|
||||
"""
|
||||
data = self.dict()
|
||||
|
||||
if key in data:
|
||||
return data[key]
|
||||
|
||||
if self.verbose:
|
||||
logger.warning("Key %s not found in %s.", key, self.dotenv_path)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def get_key(
|
||||
dotenv_path: Union[str, _PathLike],
|
||||
key_to_get: str,
|
||||
encoding: Optional[str] = "utf-8",
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Get the value of a given key from the given .env.
|
||||
|
||||
Returns `None` if the key isn't found or doesn't have a value.
|
||||
"""
|
||||
return DotEnv(dotenv_path, verbose=True, encoding=encoding).get(key_to_get)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def rewrite(
|
||||
path: Union[str, _PathLike],
|
||||
encoding: Optional[str],
|
||||
) -> Iterator[Tuple[IO[str], IO[str]]]:
|
||||
try:
|
||||
if not os.path.isfile(path):
|
||||
with io.open(path, "w+", encoding=encoding) as source:
|
||||
source.write("")
|
||||
with tempfile.NamedTemporaryFile(mode="w+", delete=False, encoding=encoding) as dest:
|
||||
with io.open(path, encoding=encoding) as source:
|
||||
yield (source, dest) # type: ignore
|
||||
except BaseException:
|
||||
if os.path.isfile(dest.name):
|
||||
os.unlink(dest.name)
|
||||
raise
|
||||
else:
|
||||
shutil.move(dest.name, path)
|
||||
|
||||
|
||||
def set_key(
|
||||
dotenv_path: Union[str, _PathLike],
|
||||
key_to_set: str,
|
||||
value_to_set: str,
|
||||
quote_mode: str = "always",
|
||||
export: bool = False,
|
||||
encoding: Optional[str] = "utf-8",
|
||||
) -> Tuple[Optional[bool], str, str]:
|
||||
"""
|
||||
Adds or Updates a key/value to the given .env
|
||||
|
||||
If the .env path given doesn't exist, fails instead of risking creating
|
||||
an orphan .env somewhere in the filesystem
|
||||
"""
|
||||
if quote_mode not in ("always", "auto", "never"):
|
||||
raise ValueError("Unknown quote_mode: {}".format(quote_mode))
|
||||
|
||||
quote = (
|
||||
quote_mode == "always"
|
||||
or (quote_mode == "auto" and not value_to_set.isalnum())
|
||||
)
|
||||
|
||||
if quote:
|
||||
value_out = "'{}'".format(value_to_set.replace("'", "\\'"))
|
||||
else:
|
||||
value_out = value_to_set
|
||||
if export:
|
||||
line_out = 'export {}={}\n'.format(key_to_set, value_out)
|
||||
else:
|
||||
line_out = "{}={}\n".format(key_to_set, value_out)
|
||||
|
||||
with rewrite(dotenv_path, encoding=encoding) as (source, dest):
|
||||
replaced = False
|
||||
missing_newline = False
|
||||
for mapping in with_warn_for_invalid_lines(parse_stream(source)):
|
||||
if mapping.key == key_to_set:
|
||||
dest.write(line_out)
|
||||
replaced = True
|
||||
else:
|
||||
dest.write(mapping.original.string)
|
||||
missing_newline = not mapping.original.string.endswith("\n")
|
||||
if not replaced:
|
||||
if missing_newline:
|
||||
dest.write("\n")
|
||||
dest.write(line_out)
|
||||
|
||||
return True, key_to_set, value_to_set
|
||||
|
||||
|
||||
def unset_key(
|
||||
dotenv_path: Union[str, _PathLike],
|
||||
key_to_unset: str,
|
||||
quote_mode: str = "always",
|
||||
encoding: Optional[str] = "utf-8",
|
||||
) -> Tuple[Optional[bool], str]:
|
||||
"""
|
||||
Removes a given key from the given .env
|
||||
|
||||
If the .env path given doesn't exist, fails
|
||||
If the given key doesn't exist in the .env, fails
|
||||
"""
|
||||
if not os.path.exists(dotenv_path):
|
||||
logger.warning("Can't delete from %s - it doesn't exist.", dotenv_path)
|
||||
return None, key_to_unset
|
||||
|
||||
removed = False
|
||||
with rewrite(dotenv_path, encoding=encoding) as (source, dest):
|
||||
for mapping in with_warn_for_invalid_lines(parse_stream(source)):
|
||||
if mapping.key == key_to_unset:
|
||||
removed = True
|
||||
else:
|
||||
dest.write(mapping.original.string)
|
||||
|
||||
if not removed:
|
||||
logger.warning("Key %s not removed from %s - key doesn't exist.", key_to_unset, dotenv_path)
|
||||
return None, key_to_unset
|
||||
|
||||
return removed, key_to_unset
|
||||
|
||||
|
||||
def resolve_variables(
|
||||
values: Iterable[Tuple[str, Optional[str]]],
|
||||
override: bool,
|
||||
) -> Mapping[str, Optional[str]]:
|
||||
new_values = {} # type: Dict[str, Optional[str]]
|
||||
|
||||
for (name, value) in values:
|
||||
if value is None:
|
||||
result = None
|
||||
else:
|
||||
atoms = parse_variables(value)
|
||||
env = {} # type: Dict[str, Optional[str]]
|
||||
if override:
|
||||
env.update(os.environ) # type: ignore
|
||||
env.update(new_values)
|
||||
else:
|
||||
env.update(new_values)
|
||||
env.update(os.environ) # type: ignore
|
||||
result = "".join(atom.resolve(env) for atom in atoms)
|
||||
|
||||
new_values[name] = result
|
||||
|
||||
return new_values
|
||||
|
||||
|
||||
def _walk_to_root(path: str) -> Iterator[str]:
|
||||
"""
|
||||
Yield directories starting from the given directory up to the root
|
||||
"""
|
||||
if not os.path.exists(path):
|
||||
raise IOError('Starting path not found')
|
||||
|
||||
if os.path.isfile(path):
|
||||
path = os.path.dirname(path)
|
||||
|
||||
last_dir = None
|
||||
current_dir = os.path.abspath(path)
|
||||
while last_dir != current_dir:
|
||||
yield current_dir
|
||||
parent_dir = os.path.abspath(os.path.join(current_dir, os.path.pardir))
|
||||
last_dir, current_dir = current_dir, parent_dir
|
||||
|
||||
|
||||
def find_dotenv(
|
||||
filename: str = '.env',
|
||||
raise_error_if_not_found: bool = False,
|
||||
usecwd: bool = False,
|
||||
) -> str:
|
||||
"""
|
||||
Search in increasingly higher folders for the given file
|
||||
|
||||
Returns path to the file if found, or an empty string otherwise
|
||||
"""
|
||||
|
||||
def _is_interactive():
|
||||
""" Decide whether this is running in a REPL or IPython notebook """
|
||||
main = __import__('__main__', None, None, fromlist=['__file__'])
|
||||
return not hasattr(main, '__file__')
|
||||
|
||||
if usecwd or _is_interactive() or getattr(sys, 'frozen', False):
|
||||
# Should work without __file__, e.g. in REPL or IPython notebook.
|
||||
path = os.getcwd()
|
||||
else:
|
||||
# will work for .py files
|
||||
frame = sys._getframe()
|
||||
current_file = __file__
|
||||
|
||||
while frame.f_code.co_filename == current_file:
|
||||
assert frame.f_back is not None
|
||||
frame = frame.f_back
|
||||
frame_filename = frame.f_code.co_filename
|
||||
path = os.path.dirname(os.path.abspath(frame_filename))
|
||||
|
||||
for dirname in _walk_to_root(path):
|
||||
check_path = os.path.join(dirname, filename)
|
||||
if os.path.isfile(check_path):
|
||||
return check_path
|
||||
|
||||
if raise_error_if_not_found:
|
||||
raise IOError('File not found')
|
||||
|
||||
return ''
|
||||
|
||||
|
||||
def load_dotenv(
|
||||
dotenv_path: Union[str, _PathLike, None] = None,
|
||||
stream: Optional[IO[str]] = None,
|
||||
verbose: bool = False,
|
||||
override: bool = False,
|
||||
interpolate: bool = True,
|
||||
encoding: Optional[str] = "utf-8",
|
||||
) -> bool:
|
||||
"""Parse a .env file and then load all the variables found as environment variables.
|
||||
|
||||
- *dotenv_path*: absolute or relative path to .env file.
|
||||
- *stream*: Text stream (such as `io.StringIO`) with .env content, used if
|
||||
`dotenv_path` is `None`.
|
||||
- *verbose*: whether to output a warning the .env file is missing. Defaults to
|
||||
`False`.
|
||||
- *override*: whether to override the system environment variables with the variables
|
||||
in `.env` file. Defaults to `False`.
|
||||
- *encoding*: encoding to be used to read the file.
|
||||
|
||||
If both `dotenv_path` and `stream`, `find_dotenv()` is used to find the .env file.
|
||||
"""
|
||||
if dotenv_path is None and stream is None:
|
||||
dotenv_path = find_dotenv()
|
||||
|
||||
dotenv = DotEnv(
|
||||
dotenv_path=dotenv_path,
|
||||
stream=stream,
|
||||
verbose=verbose,
|
||||
interpolate=interpolate,
|
||||
override=override,
|
||||
encoding=encoding,
|
||||
)
|
||||
return dotenv.set_as_environment_variables()
|
||||
|
||||
|
||||
def dotenv_values(
|
||||
dotenv_path: Union[str, _PathLike, None] = None,
|
||||
stream: Optional[IO[str]] = None,
|
||||
verbose: bool = False,
|
||||
interpolate: bool = True,
|
||||
encoding: Optional[str] = "utf-8",
|
||||
) -> Dict[str, Optional[str]]:
|
||||
"""
|
||||
Parse a .env file and return its content as a dict.
|
||||
|
||||
- *dotenv_path*: absolute or relative path to .env file.
|
||||
- *stream*: `StringIO` object with .env content, used if `dotenv_path` is `None`.
|
||||
- *verbose*: whether to output a warning the .env file is missing. Defaults to
|
||||
`False`.
|
||||
in `.env` file. Defaults to `False`.
|
||||
- *encoding*: encoding to be used to read the file.
|
||||
|
||||
If both `dotenv_path` and `stream`, `find_dotenv()` is used to find the .env file.
|
||||
"""
|
||||
if dotenv_path is None and stream is None:
|
||||
dotenv_path = find_dotenv()
|
||||
|
||||
return DotEnv(
|
||||
dotenv_path=dotenv_path,
|
||||
stream=stream,
|
||||
verbose=verbose,
|
||||
interpolate=interpolate,
|
||||
override=True,
|
||||
encoding=encoding,
|
||||
).dict()
|
@ -0,0 +1,182 @@
|
||||
import codecs
|
||||
import re
|
||||
from typing import (IO, Iterator, Match, NamedTuple, Optional, # noqa:F401
|
||||
Pattern, Sequence, Tuple)
|
||||
|
||||
|
||||
def make_regex(string: str, extra_flags: int = 0) -> Pattern[str]:
|
||||
return re.compile(string, re.UNICODE | extra_flags)
|
||||
|
||||
|
||||
_newline = make_regex(r"(\r\n|\n|\r)")
|
||||
_multiline_whitespace = make_regex(r"\s*", extra_flags=re.MULTILINE)
|
||||
_whitespace = make_regex(r"[^\S\r\n]*")
|
||||
_export = make_regex(r"(?:export[^\S\r\n]+)?")
|
||||
_single_quoted_key = make_regex(r"'([^']+)'")
|
||||
_unquoted_key = make_regex(r"([^=\#\s]+)")
|
||||
_equal_sign = make_regex(r"(=[^\S\r\n]*)")
|
||||
_single_quoted_value = make_regex(r"'((?:\\'|[^'])*)'")
|
||||
_double_quoted_value = make_regex(r'"((?:\\"|[^"])*)"')
|
||||
_unquoted_value = make_regex(r"([^\r\n]*)")
|
||||
_comment = make_regex(r"(?:[^\S\r\n]*#[^\r\n]*)?")
|
||||
_end_of_line = make_regex(r"[^\S\r\n]*(?:\r\n|\n|\r|$)")
|
||||
_rest_of_line = make_regex(r"[^\r\n]*(?:\r|\n|\r\n)?")
|
||||
_double_quote_escapes = make_regex(r"\\[\\'\"abfnrtv]")
|
||||
_single_quote_escapes = make_regex(r"\\[\\']")
|
||||
|
||||
|
||||
Original = NamedTuple(
|
||||
"Original",
|
||||
[
|
||||
("string", str),
|
||||
("line", int),
|
||||
],
|
||||
)
|
||||
|
||||
Binding = NamedTuple(
|
||||
"Binding",
|
||||
[
|
||||
("key", Optional[str]),
|
||||
("value", Optional[str]),
|
||||
("original", Original),
|
||||
("error", bool),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
class Position:
|
||||
def __init__(self, chars: int, line: int) -> None:
|
||||
self.chars = chars
|
||||
self.line = line
|
||||
|
||||
@classmethod
|
||||
def start(cls) -> "Position":
|
||||
return cls(chars=0, line=1)
|
||||
|
||||
def set(self, other: "Position") -> None:
|
||||
self.chars = other.chars
|
||||
self.line = other.line
|
||||
|
||||
def advance(self, string: str) -> None:
|
||||
self.chars += len(string)
|
||||
self.line += len(re.findall(_newline, string))
|
||||
|
||||
|
||||
class Error(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class Reader:
|
||||
def __init__(self, stream: IO[str]) -> None:
|
||||
self.string = stream.read()
|
||||
self.position = Position.start()
|
||||
self.mark = Position.start()
|
||||
|
||||
def has_next(self) -> bool:
|
||||
return self.position.chars < len(self.string)
|
||||
|
||||
def set_mark(self) -> None:
|
||||
self.mark.set(self.position)
|
||||
|
||||
def get_marked(self) -> Original:
|
||||
return Original(
|
||||
string=self.string[self.mark.chars:self.position.chars],
|
||||
line=self.mark.line,
|
||||
)
|
||||
|
||||
def peek(self, count: int) -> str:
|
||||
return self.string[self.position.chars:self.position.chars + count]
|
||||
|
||||
def read(self, count: int) -> str:
|
||||
result = self.string[self.position.chars:self.position.chars + count]
|
||||
if len(result) < count:
|
||||
raise Error("read: End of string")
|
||||
self.position.advance(result)
|
||||
return result
|
||||
|
||||
def read_regex(self, regex: Pattern[str]) -> Sequence[str]:
|
||||
match = regex.match(self.string, self.position.chars)
|
||||
if match is None:
|
||||
raise Error("read_regex: Pattern not found")
|
||||
self.position.advance(self.string[match.start():match.end()])
|
||||
return match.groups()
|
||||
|
||||
|
||||
def decode_escapes(regex: Pattern[str], string: str) -> str:
|
||||
def decode_match(match: Match[str]) -> str:
|
||||
return codecs.decode(match.group(0), 'unicode-escape') # type: ignore
|
||||
|
||||
return regex.sub(decode_match, string)
|
||||
|
||||
|
||||
def parse_key(reader: Reader) -> Optional[str]:
|
||||
char = reader.peek(1)
|
||||
if char == "#":
|
||||
return None
|
||||
elif char == "'":
|
||||
(key,) = reader.read_regex(_single_quoted_key)
|
||||
else:
|
||||
(key,) = reader.read_regex(_unquoted_key)
|
||||
return key
|
||||
|
||||
|
||||
def parse_unquoted_value(reader: Reader) -> str:
|
||||
(part,) = reader.read_regex(_unquoted_value)
|
||||
return re.sub(r"\s+#.*", "", part).rstrip()
|
||||
|
||||
|
||||
def parse_value(reader: Reader) -> str:
|
||||
char = reader.peek(1)
|
||||
if char == u"'":
|
||||
(value,) = reader.read_regex(_single_quoted_value)
|
||||
return decode_escapes(_single_quote_escapes, value)
|
||||
elif char == u'"':
|
||||
(value,) = reader.read_regex(_double_quoted_value)
|
||||
return decode_escapes(_double_quote_escapes, value)
|
||||
elif char in (u"", u"\n", u"\r"):
|
||||
return u""
|
||||
else:
|
||||
return parse_unquoted_value(reader)
|
||||
|
||||
|
||||
def parse_binding(reader: Reader) -> Binding:
|
||||
reader.set_mark()
|
||||
try:
|
||||
reader.read_regex(_multiline_whitespace)
|
||||
if not reader.has_next():
|
||||
return Binding(
|
||||
key=None,
|
||||
value=None,
|
||||
original=reader.get_marked(),
|
||||
error=False,
|
||||
)
|
||||
reader.read_regex(_export)
|
||||
key = parse_key(reader)
|
||||
reader.read_regex(_whitespace)
|
||||
if reader.peek(1) == "=":
|
||||
reader.read_regex(_equal_sign)
|
||||
value = parse_value(reader) # type: Optional[str]
|
||||
else:
|
||||
value = None
|
||||
reader.read_regex(_comment)
|
||||
reader.read_regex(_end_of_line)
|
||||
return Binding(
|
||||
key=key,
|
||||
value=value,
|
||||
original=reader.get_marked(),
|
||||
error=False,
|
||||
)
|
||||
except Error:
|
||||
reader.read_regex(_rest_of_line)
|
||||
return Binding(
|
||||
key=None,
|
||||
value=None,
|
||||
original=reader.get_marked(),
|
||||
error=True,
|
||||
)
|
||||
|
||||
|
||||
def parse_stream(stream: IO[str]) -> Iterator[Binding]:
|
||||
reader = Reader(stream)
|
||||
while reader.has_next():
|
||||
yield parse_binding(reader)
|
@ -0,0 +1 @@
|
||||
# Marker file for PEP 561
|
@ -0,0 +1,88 @@
|
||||
import re
|
||||
from abc import ABCMeta
|
||||
from typing import Iterator, Mapping, Optional, Pattern
|
||||
|
||||
_posix_variable = re.compile(
|
||||
r"""
|
||||
\$\{
|
||||
(?P<name>[^\}:]*)
|
||||
(?::-
|
||||
(?P<default>[^\}]*)
|
||||
)?
|
||||
\}
|
||||
""",
|
||||
re.VERBOSE,
|
||||
) # type: Pattern[str]
|
||||
|
||||
|
||||
class Atom():
|
||||
__metaclass__ = ABCMeta
|
||||
|
||||
def __ne__(self, other: object) -> bool:
|
||||
result = self.__eq__(other)
|
||||
if result is NotImplemented:
|
||||
return NotImplemented
|
||||
return not result
|
||||
|
||||
def resolve(self, env: Mapping[str, Optional[str]]) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class Literal(Atom):
|
||||
def __init__(self, value: str) -> None:
|
||||
self.value = value
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return "Literal(value={})".format(self.value)
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if not isinstance(other, self.__class__):
|
||||
return NotImplemented
|
||||
return self.value == other.value
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash((self.__class__, self.value))
|
||||
|
||||
def resolve(self, env: Mapping[str, Optional[str]]) -> str:
|
||||
return self.value
|
||||
|
||||
|
||||
class Variable(Atom):
|
||||
def __init__(self, name: str, default: Optional[str]) -> None:
|
||||
self.name = name
|
||||
self.default = default
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return "Variable(name={}, default={})".format(self.name, self.default)
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if not isinstance(other, self.__class__):
|
||||
return NotImplemented
|
||||
return (self.name, self.default) == (other.name, other.default)
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash((self.__class__, self.name, self.default))
|
||||
|
||||
def resolve(self, env: Mapping[str, Optional[str]]) -> str:
|
||||
default = self.default if self.default is not None else ""
|
||||
result = env.get(self.name, default)
|
||||
return result if result is not None else ""
|
||||
|
||||
|
||||
def parse_variables(value: str) -> Iterator[Atom]:
|
||||
cursor = 0
|
||||
|
||||
for match in _posix_variable.finditer(value):
|
||||
(start, end) = match.span()
|
||||
name = match.groupdict()["name"]
|
||||
default = match.groupdict()["default"]
|
||||
|
||||
if start > cursor:
|
||||
yield Literal(value=value[cursor:start])
|
||||
|
||||
yield Variable(name=name, default=default)
|
||||
cursor = end
|
||||
|
||||
length = len(value)
|
||||
if cursor < length:
|
||||
yield Literal(value=value[cursor:length])
|
@ -0,0 +1 @@
|
||||
__version__ = "0.20.0"
|
@ -0,0 +1 @@
|
||||
pip
|
@ -0,0 +1,21 @@
|
||||
The MIT License (MIT)
|
||||
|
||||
Copyright (c) 2018 Sebastián Ramírez
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in
|
||||
all copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
THE SOFTWARE.
|
@ -0,0 +1,91 @@
|
||||
fastapi-0.81.0.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4
|
||||
fastapi-0.81.0.dist-info/LICENSE,sha256=Tsif_IFIW5f-xYSy1KlhAy7v_oNEU4lP2cEnSQbMdE4,1086
|
||||
fastapi-0.81.0.dist-info/METADATA,sha256=CD9zEFul1ZeqIjXG_CDdJOsBUO--4HnPgtRpq_8jSTs,25126
|
||||
fastapi-0.81.0.dist-info/RECORD,,
|
||||
fastapi-0.81.0.dist-info/REQUESTED,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
||||
fastapi-0.81.0.dist-info/WHEEL,sha256=4TfKIB_xu-04bc2iKz6_zFt-gEFEEDU_31HGhqzOCE8,81
|
||||
fastapi/__init__.py,sha256=zwiRX7r_siNxDDxGccYn9vMq5KoZEHagieRbpPmpiYI,1015
|
||||
fastapi/__pycache__/__init__.cpython-37.pyc,,
|
||||
fastapi/__pycache__/applications.cpython-37.pyc,,
|
||||
fastapi/__pycache__/background.cpython-37.pyc,,
|
||||
fastapi/__pycache__/concurrency.cpython-37.pyc,,
|
||||
fastapi/__pycache__/datastructures.cpython-37.pyc,,
|
||||
fastapi/__pycache__/encoders.cpython-37.pyc,,
|
||||
fastapi/__pycache__/exception_handlers.cpython-37.pyc,,
|
||||
fastapi/__pycache__/exceptions.cpython-37.pyc,,
|
||||
fastapi/__pycache__/logger.cpython-37.pyc,,
|
||||
fastapi/__pycache__/param_functions.cpython-37.pyc,,
|
||||
fastapi/__pycache__/params.cpython-37.pyc,,
|
||||
fastapi/__pycache__/requests.cpython-37.pyc,,
|
||||
fastapi/__pycache__/responses.cpython-37.pyc,,
|
||||
fastapi/__pycache__/routing.cpython-37.pyc,,
|
||||
fastapi/__pycache__/staticfiles.cpython-37.pyc,,
|
||||
fastapi/__pycache__/templating.cpython-37.pyc,,
|
||||
fastapi/__pycache__/testclient.cpython-37.pyc,,
|
||||
fastapi/__pycache__/types.cpython-37.pyc,,
|
||||
fastapi/__pycache__/utils.cpython-37.pyc,,
|
||||
fastapi/__pycache__/websockets.cpython-37.pyc,,
|
||||
fastapi/applications.py,sha256=DGDGZRPpNW_KRKMeyoo5gl4YuE1c9uCNWigLmsqXJ30,38062
|
||||
fastapi/background.py,sha256=HtN5_pJJrOdalSbuGSMKJAPNWUU5h7rY_BXXubu7-IQ,76
|
||||
fastapi/concurrency.py,sha256=G9Q9--9t5362vQuzLzB0pL8e6BbxcVU9EpBhVyaEGJ8,1078
|
||||
fastapi/datastructures.py,sha256=oW6xuU0C-sBwbcyXI-MlBO0tSS4BSPB2lYUa1yCw8-A,1905
|
||||
fastapi/dependencies/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
||||
fastapi/dependencies/__pycache__/__init__.cpython-37.pyc,,
|
||||
fastapi/dependencies/__pycache__/models.cpython-37.pyc,,
|
||||
fastapi/dependencies/__pycache__/utils.cpython-37.pyc,,
|
||||
fastapi/dependencies/models.py,sha256=zNbioxICuOeb-9ADDVQ45hUHOC0PBtPVEfVU3f1l_nc,2494
|
||||
fastapi/dependencies/utils.py,sha256=zsbThdKoayRv0pWplZ6jWjK9ekIoix036NO_nv565xc,27373
|
||||
fastapi/encoders.py,sha256=0Zd7Ad60YMVzB5A0rcHkUZwwriY9In7FY6ziyX5BZpw,5885
|
||||
fastapi/exception_handlers.py,sha256=UVYCCe4qt5-5_NuQ3SxTXjDvOdKMHiTfcLp3RUKXhg8,912
|
||||
fastapi/exceptions.py,sha256=Wy1sP3EisJohtxr-uKoH58QumPWmqHp6cpXOD3TTPOs,1117
|
||||
fastapi/logger.py,sha256=I9NNi3ov8AcqbsbC9wl1X-hdItKgYt2XTrx1f99Zpl4,54
|
||||
fastapi/middleware/__init__.py,sha256=oQDxiFVcc1fYJUOIFvphnK7pTT5kktmfL32QXpBFvvo,58
|
||||
fastapi/middleware/__pycache__/__init__.cpython-37.pyc,,
|
||||
fastapi/middleware/__pycache__/asyncexitstack.cpython-37.pyc,,
|
||||
fastapi/middleware/__pycache__/cors.cpython-37.pyc,,
|
||||
fastapi/middleware/__pycache__/gzip.cpython-37.pyc,,
|
||||
fastapi/middleware/__pycache__/httpsredirect.cpython-37.pyc,,
|
||||
fastapi/middleware/__pycache__/trustedhost.cpython-37.pyc,,
|
||||
fastapi/middleware/__pycache__/wsgi.cpython-37.pyc,,
|
||||
fastapi/middleware/asyncexitstack.py,sha256=72XjQmQ_tB_tTs9xOc0akXF_7TwZUPdyfc8gsN5LV8E,1197
|
||||
fastapi/middleware/cors.py,sha256=ynwjWQZoc_vbhzZ3_ZXceoaSrslHFHPdoM52rXr0WUU,79
|
||||
fastapi/middleware/gzip.py,sha256=xM5PcsH8QlAimZw4VDvcmTnqQamslThsfe3CVN2voa0,79
|
||||
fastapi/middleware/httpsredirect.py,sha256=rL8eXMnmLijwVkH7_400zHri1AekfeBd6D6qs8ix950,115
|
||||
fastapi/middleware/trustedhost.py,sha256=eE5XGRxGa7c5zPnMJDGp3BxaL25k5iVQlhnv-Pk0Pss,109
|
||||
fastapi/middleware/wsgi.py,sha256=Z3Ue-7wni4lUZMvH3G9ek__acgYdJstbnpZX_HQAboY,79
|
||||
fastapi/openapi/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
||||
fastapi/openapi/__pycache__/__init__.cpython-37.pyc,,
|
||||
fastapi/openapi/__pycache__/constants.cpython-37.pyc,,
|
||||
fastapi/openapi/__pycache__/docs.cpython-37.pyc,,
|
||||
fastapi/openapi/__pycache__/models.cpython-37.pyc,,
|
||||
fastapi/openapi/__pycache__/utils.cpython-37.pyc,,
|
||||
fastapi/openapi/constants.py,sha256=mWxYBjED6PU-tQ9X227Qkq2SsW2cv-C1jYFKt63xxEs,107
|
||||
fastapi/openapi/docs.py,sha256=JBRaq7EEmeC-xoRSRFj6qZWQxfOZW_jvTw0r-PiKcZ4,6532
|
||||
fastapi/openapi/models.py,sha256=_XWDBU4Zlp5M9V6YI1kmXbqYKwK_xZxaqIA_DhLwqHk,11027
|
||||
fastapi/openapi/utils.py,sha256=NUrPv65N8976sqfG1kxHiHJrPYxT53NkjCWgmuyAOF0,18409
|
||||
fastapi/param_functions.py,sha256=mhV6aNZmXuf_A7rZ830o3V-DFqbonIDRs_prDTetLs4,7521
|
||||
fastapi/params.py,sha256=LRoO2H1XBBIfBGB82gHtfnXhDZiDz-7CIordN3FoU1I,10600
|
||||
fastapi/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
||||
fastapi/requests.py,sha256=zayepKFcienBllv3snmWI20Gk0oHNVLU4DDhqXBb4LU,142
|
||||
fastapi/responses.py,sha256=K9Gzj0E5GfwTgLiuCE-BuiXn9JdqGJHeXBMkvd8lFC8,1196
|
||||
fastapi/routing.py,sha256=gq17oux-hW8G5sr4ZhBRl-m6R3mx0jJRDHCMEMNTPcU,53459
|
||||
fastapi/security/__init__.py,sha256=bO8pNmxqVRXUjfl2mOKiVZLn0FpBQ61VUYVjmppnbJw,881
|
||||
fastapi/security/__pycache__/__init__.cpython-37.pyc,,
|
||||
fastapi/security/__pycache__/api_key.cpython-37.pyc,,
|
||||
fastapi/security/__pycache__/base.cpython-37.pyc,,
|
||||
fastapi/security/__pycache__/http.cpython-37.pyc,,
|
||||
fastapi/security/__pycache__/oauth2.cpython-37.pyc,,
|
||||
fastapi/security/__pycache__/open_id_connect_url.cpython-37.pyc,,
|
||||
fastapi/security/__pycache__/utils.cpython-37.pyc,,
|
||||
fastapi/security/api_key.py,sha256=NbVpS9TxDOaipoZa8-SREHyMtTcM3bmy5szMiQxEX9s,2793
|
||||
fastapi/security/base.py,sha256=dl4pvbC-RxjfbWgPtCWd8MVU-7CB2SZ22rJDXVCXO6c,141
|
||||
fastapi/security/http.py,sha256=ZSy3DFKFDLa3-I4vwsY1r8hQB_VrtAXw4-fMJauZIK0,5984
|
||||
fastapi/security/oauth2.py,sha256=1NPA12T1_r2uo4iQWxJCKjUqVVdb532YDvX9e3PVpcE,8212
|
||||
fastapi/security/open_id_connect_url.py,sha256=iikzuJCz_DG44Q77VrupqSoCbJYaiXkuo_W-kdmAzeo,1145
|
||||
fastapi/security/utils.py,sha256=izlh-HBaL1VnJeOeRTQnyNgI3hgTFs73eCyLy-snb4A,266
|
||||
fastapi/staticfiles.py,sha256=iirGIt3sdY2QZXd36ijs3Cj-T0FuGFda3cd90kM9Ikw,69
|
||||
fastapi/templating.py,sha256=4zsuTWgcjcEainMJFAlW6-gnslm6AgOS1SiiDWfmQxk,76
|
||||
fastapi/testclient.py,sha256=nBvaAmX66YldReJNZXPOk1sfuo2Q6hs8bOvIaCep6LQ,66
|
||||
fastapi/types.py,sha256=r6MngTHzkZOP9lzXgduje9yeZe5EInWAzCLuRJlhIuE,118
|
||||
fastapi/utils.py,sha256=g5JDiJ4vwc-l2ffpLdkkBDEjJJz3dqZbQ2cwz-yNVwI,6594
|
||||
fastapi/websockets.py,sha256=SroIkqE-lfChvtRP3mFaNKKtD6TmePDWBZtQfgM4noo,148
|
@ -0,0 +1,4 @@
|
||||
Wheel-Version: 1.0
|
||||
Generator: flit 3.7.1
|
||||
Root-Is-Purelib: true
|
||||
Tag: py3-none-any
|
@ -0,0 +1,24 @@
|
||||
"""FastAPI framework, high performance, easy to learn, fast to code, ready for production"""
|
||||
|
||||
__version__ = "0.81.0"
|
||||
|
||||
from starlette import status as status
|
||||
|
||||
from .applications import FastAPI as FastAPI
|
||||
from .background import BackgroundTasks as BackgroundTasks
|
||||
from .datastructures import UploadFile as UploadFile
|
||||
from .exceptions import HTTPException as HTTPException
|
||||
from .param_functions import Body as Body
|
||||
from .param_functions import Cookie as Cookie
|
||||
from .param_functions import Depends as Depends
|
||||
from .param_functions import File as File
|
||||
from .param_functions import Form as Form
|
||||
from .param_functions import Header as Header
|
||||
from .param_functions import Path as Path
|
||||
from .param_functions import Query as Query
|
||||
from .param_functions import Security as Security
|
||||
from .requests import Request as Request
|
||||
from .responses import Response as Response
|
||||
from .routing import APIRouter as APIRouter
|
||||
from .websockets import WebSocket as WebSocket
|
||||
from .websockets import WebSocketDisconnect as WebSocketDisconnect
|
@ -0,0 +1,871 @@
|
||||
from enum import Enum
|
||||
from typing import (
|
||||
Any,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Coroutine,
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Type,
|
||||
Union,
|
||||
)
|
||||
|
||||
from fastapi import routing
|
||||
from fastapi.datastructures import Default, DefaultPlaceholder
|
||||
from fastapi.encoders import DictIntStrAny, SetIntStr
|
||||
from fastapi.exception_handlers import (
|
||||
http_exception_handler,
|
||||
request_validation_exception_handler,
|
||||
)
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.logger import logger
|
||||
from fastapi.middleware.asyncexitstack import AsyncExitStackMiddleware
|
||||
from fastapi.openapi.docs import (
|
||||
get_redoc_html,
|
||||
get_swagger_ui_html,
|
||||
get_swagger_ui_oauth2_redirect_html,
|
||||
)
|
||||
from fastapi.openapi.utils import get_openapi
|
||||
from fastapi.params import Depends
|
||||
from fastapi.types import DecoratedCallable
|
||||
from fastapi.utils import generate_unique_id
|
||||
from starlette.applications import Starlette
|
||||
from starlette.datastructures import State
|
||||
from starlette.exceptions import ExceptionMiddleware, HTTPException
|
||||
from starlette.middleware import Middleware
|
||||
from starlette.middleware.errors import ServerErrorMiddleware
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import HTMLResponse, JSONResponse, Response
|
||||
from starlette.routing import BaseRoute
|
||||
from starlette.types import ASGIApp, Receive, Scope, Send
|
||||
|
||||
|
||||
class FastAPI(Starlette):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
debug: bool = False,
|
||||
routes: Optional[List[BaseRoute]] = None,
|
||||
title: str = "FastAPI",
|
||||
description: str = "",
|
||||
version: str = "0.1.0",
|
||||
openapi_url: Optional[str] = "/openapi.json",
|
||||
openapi_tags: Optional[List[Dict[str, Any]]] = None,
|
||||
servers: Optional[List[Dict[str, Union[str, Any]]]] = None,
|
||||
dependencies: Optional[Sequence[Depends]] = None,
|
||||
default_response_class: Type[Response] = Default(JSONResponse),
|
||||
docs_url: Optional[str] = "/docs",
|
||||
redoc_url: Optional[str] = "/redoc",
|
||||
swagger_ui_oauth2_redirect_url: Optional[str] = "/docs/oauth2-redirect",
|
||||
swagger_ui_init_oauth: Optional[Dict[str, Any]] = None,
|
||||
middleware: Optional[Sequence[Middleware]] = None,
|
||||
exception_handlers: Optional[
|
||||
Dict[
|
||||
Union[int, Type[Exception]],
|
||||
Callable[[Request, Any], Coroutine[Any, Any, Response]],
|
||||
]
|
||||
] = None,
|
||||
on_startup: Optional[Sequence[Callable[[], Any]]] = None,
|
||||
on_shutdown: Optional[Sequence[Callable[[], Any]]] = None,
|
||||
terms_of_service: Optional[str] = None,
|
||||
contact: Optional[Dict[str, Union[str, Any]]] = None,
|
||||
license_info: Optional[Dict[str, Union[str, Any]]] = None,
|
||||
openapi_prefix: str = "",
|
||||
root_path: str = "",
|
||||
root_path_in_servers: bool = True,
|
||||
responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None,
|
||||
callbacks: Optional[List[BaseRoute]] = None,
|
||||
deprecated: Optional[bool] = None,
|
||||
include_in_schema: bool = True,
|
||||
swagger_ui_parameters: Optional[Dict[str, Any]] = None,
|
||||
generate_unique_id_function: Callable[[routing.APIRoute], str] = Default(
|
||||
generate_unique_id
|
||||
),
|
||||
**extra: Any,
|
||||
) -> None:
|
||||
self._debug: bool = debug
|
||||
self.title = title
|
||||
self.description = description
|
||||
self.version = version
|
||||
self.terms_of_service = terms_of_service
|
||||
self.contact = contact
|
||||
self.license_info = license_info
|
||||
self.openapi_url = openapi_url
|
||||
self.openapi_tags = openapi_tags
|
||||
self.root_path_in_servers = root_path_in_servers
|
||||
self.docs_url = docs_url
|
||||
self.redoc_url = redoc_url
|
||||
self.swagger_ui_oauth2_redirect_url = swagger_ui_oauth2_redirect_url
|
||||
self.swagger_ui_init_oauth = swagger_ui_init_oauth
|
||||
self.swagger_ui_parameters = swagger_ui_parameters
|
||||
self.servers = servers or []
|
||||
self.extra = extra
|
||||
self.openapi_version = "3.0.2"
|
||||
self.openapi_schema: Optional[Dict[str, Any]] = None
|
||||
if self.openapi_url:
|
||||
assert self.title, "A title must be provided for OpenAPI, e.g.: 'My API'"
|
||||
assert self.version, "A version must be provided for OpenAPI, e.g.: '2.1.0'"
|
||||
# TODO: remove when discarding the openapi_prefix parameter
|
||||
if openapi_prefix:
|
||||
logger.warning(
|
||||
'"openapi_prefix" has been deprecated in favor of "root_path", which '
|
||||
"follows more closely the ASGI standard, is simpler, and more "
|
||||
"automatic. Check the docs at "
|
||||
"https://fastapi.tiangolo.com/advanced/sub-applications/"
|
||||
)
|
||||
self.root_path = root_path or openapi_prefix
|
||||
self.state: State = State()
|
||||
self.dependency_overrides: Dict[Callable[..., Any], Callable[..., Any]] = {}
|
||||
self.router: routing.APIRouter = routing.APIRouter(
|
||||
routes=routes,
|
||||
dependency_overrides_provider=self,
|
||||
on_startup=on_startup,
|
||||
on_shutdown=on_shutdown,
|
||||
default_response_class=default_response_class,
|
||||
dependencies=dependencies,
|
||||
callbacks=callbacks,
|
||||
deprecated=deprecated,
|
||||
include_in_schema=include_in_schema,
|
||||
responses=responses,
|
||||
generate_unique_id_function=generate_unique_id_function,
|
||||
)
|
||||
self.exception_handlers: Dict[
|
||||
Any, Callable[[Request, Any], Union[Response, Awaitable[Response]]]
|
||||
] = ({} if exception_handlers is None else dict(exception_handlers))
|
||||
self.exception_handlers.setdefault(HTTPException, http_exception_handler)
|
||||
self.exception_handlers.setdefault(
|
||||
RequestValidationError, request_validation_exception_handler
|
||||
)
|
||||
|
||||
self.user_middleware: List[Middleware] = (
|
||||
[] if middleware is None else list(middleware)
|
||||
)
|
||||
self.middleware_stack: ASGIApp = self.build_middleware_stack()
|
||||
self.setup()
|
||||
|
||||
def build_middleware_stack(self) -> ASGIApp:
|
||||
# Duplicate/override from Starlette to add AsyncExitStackMiddleware
|
||||
# inside of ExceptionMiddleware, inside of custom user middlewares
|
||||
debug = self.debug
|
||||
error_handler = None
|
||||
exception_handlers = {}
|
||||
|
||||
for key, value in self.exception_handlers.items():
|
||||
if key in (500, Exception):
|
||||
error_handler = value
|
||||
else:
|
||||
exception_handlers[key] = value
|
||||
|
||||
middleware = (
|
||||
[Middleware(ServerErrorMiddleware, handler=error_handler, debug=debug)]
|
||||
+ self.user_middleware
|
||||
+ [
|
||||
Middleware(
|
||||
ExceptionMiddleware, handlers=exception_handlers, debug=debug
|
||||
),
|
||||
# Add FastAPI-specific AsyncExitStackMiddleware for dependencies with
|
||||
# contextvars.
|
||||
# This needs to happen after user middlewares because those create a
|
||||
# new contextvars context copy by using a new AnyIO task group.
|
||||
# The initial part of dependencies with yield is executed in the
|
||||
# FastAPI code, inside all the middlewares, but the teardown part
|
||||
# (after yield) is executed in the AsyncExitStack in this middleware,
|
||||
# if the AsyncExitStack lived outside of the custom middlewares and
|
||||
# contextvars were set in a dependency with yield in that internal
|
||||
# contextvars context, the values would not be available in the
|
||||
# outside context of the AsyncExitStack.
|
||||
# By putting the middleware and the AsyncExitStack here, inside all
|
||||
# user middlewares, the code before and after yield in dependencies
|
||||
# with yield is executed in the same contextvars context, so all values
|
||||
# set in contextvars before yield is still available after yield as
|
||||
# would be expected.
|
||||
# Additionally, by having this AsyncExitStack here, after the
|
||||
# ExceptionMiddleware, now dependencies can catch handled exceptions,
|
||||
# e.g. HTTPException, to customize the teardown code (e.g. DB session
|
||||
# rollback).
|
||||
Middleware(AsyncExitStackMiddleware),
|
||||
]
|
||||
)
|
||||
|
||||
app = self.router
|
||||
for cls, options in reversed(middleware):
|
||||
app = cls(app=app, **options)
|
||||
return app
|
||||
|
||||
def openapi(self) -> Dict[str, Any]:
|
||||
if not self.openapi_schema:
|
||||
self.openapi_schema = get_openapi(
|
||||
title=self.title,
|
||||
version=self.version,
|
||||
openapi_version=self.openapi_version,
|
||||
description=self.description,
|
||||
terms_of_service=self.terms_of_service,
|
||||
contact=self.contact,
|
||||
license_info=self.license_info,
|
||||
routes=self.routes,
|
||||
tags=self.openapi_tags,
|
||||
servers=self.servers,
|
||||
)
|
||||
return self.openapi_schema
|
||||
|
||||
def setup(self) -> None:
|
||||
if self.openapi_url:
|
||||
urls = (server_data.get("url") for server_data in self.servers)
|
||||
server_urls = {url for url in urls if url}
|
||||
|
||||
async def openapi(req: Request) -> JSONResponse:
|
||||
root_path = req.scope.get("root_path", "").rstrip("/")
|
||||
if root_path not in server_urls:
|
||||
if root_path and self.root_path_in_servers:
|
||||
self.servers.insert(0, {"url": root_path})
|
||||
server_urls.add(root_path)
|
||||
return JSONResponse(self.openapi())
|
||||
|
||||
self.add_route(self.openapi_url, openapi, include_in_schema=False)
|
||||
if self.openapi_url and self.docs_url:
|
||||
|
||||
async def swagger_ui_html(req: Request) -> HTMLResponse:
|
||||
root_path = req.scope.get("root_path", "").rstrip("/")
|
||||
openapi_url = root_path + self.openapi_url
|
||||
oauth2_redirect_url = self.swagger_ui_oauth2_redirect_url
|
||||
if oauth2_redirect_url:
|
||||
oauth2_redirect_url = root_path + oauth2_redirect_url
|
||||
return get_swagger_ui_html(
|
||||
openapi_url=openapi_url,
|
||||
title=self.title + " - Swagger UI",
|
||||
oauth2_redirect_url=oauth2_redirect_url,
|
||||
init_oauth=self.swagger_ui_init_oauth,
|
||||
swagger_ui_parameters=self.swagger_ui_parameters,
|
||||
)
|
||||
|
||||
self.add_route(self.docs_url, swagger_ui_html, include_in_schema=False)
|
||||
|
||||
if self.swagger_ui_oauth2_redirect_url:
|
||||
|
||||
async def swagger_ui_redirect(req: Request) -> HTMLResponse:
|
||||
return get_swagger_ui_oauth2_redirect_html()
|
||||
|
||||
self.add_route(
|
||||
self.swagger_ui_oauth2_redirect_url,
|
||||
swagger_ui_redirect,
|
||||
include_in_schema=False,
|
||||
)
|
||||
if self.openapi_url and self.redoc_url:
|
||||
|
||||
async def redoc_html(req: Request) -> HTMLResponse:
|
||||
root_path = req.scope.get("root_path", "").rstrip("/")
|
||||
openapi_url = root_path + self.openapi_url
|
||||
return get_redoc_html(
|
||||
openapi_url=openapi_url, title=self.title + " - ReDoc"
|
||||
)
|
||||
|
||||
self.add_route(self.redoc_url, redoc_html, include_in_schema=False)
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
if self.root_path:
|
||||
scope["root_path"] = self.root_path
|
||||
await super().__call__(scope, receive, send)
|
||||
|
||||
def add_api_route(
|
||||
self,
|
||||
path: str,
|
||||
endpoint: Callable[..., Coroutine[Any, Any, Response]],
|
||||
*,
|
||||
response_model: Any = None,
|
||||
status_code: Optional[int] = None,
|
||||
tags: Optional[List[Union[str, Enum]]] = None,
|
||||
dependencies: Optional[Sequence[Depends]] = None,
|
||||
summary: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
response_description: str = "Successful Response",
|
||||
responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None,
|
||||
deprecated: Optional[bool] = None,
|
||||
methods: Optional[List[str]] = None,
|
||||
operation_id: Optional[str] = None,
|
||||
response_model_include: Optional[Union[SetIntStr, DictIntStrAny]] = None,
|
||||
response_model_exclude: Optional[Union[SetIntStr, DictIntStrAny]] = None,
|
||||
response_model_by_alias: bool = True,
|
||||
response_model_exclude_unset: bool = False,
|
||||
response_model_exclude_defaults: bool = False,
|
||||
response_model_exclude_none: bool = False,
|
||||
include_in_schema: bool = True,
|
||||
response_class: Union[Type[Response], DefaultPlaceholder] = Default(
|
||||
JSONResponse
|
||||
),
|
||||
name: Optional[str] = None,
|
||||
openapi_extra: Optional[Dict[str, Any]] = None,
|
||||
generate_unique_id_function: Callable[[routing.APIRoute], str] = Default(
|
||||
generate_unique_id
|
||||
),
|
||||
) -> None:
|
||||
self.router.add_api_route(
|
||||
path,
|
||||
endpoint=endpoint,
|
||||
response_model=response_model,
|
||||
status_code=status_code,
|
||||
tags=tags,
|
||||
dependencies=dependencies,
|
||||
summary=summary,
|
||||
description=description,
|
||||
response_description=response_description,
|
||||
responses=responses,
|
||||
deprecated=deprecated,
|
||||
methods=methods,
|
||||
operation_id=operation_id,
|
||||
response_model_include=response_model_include,
|
||||
response_model_exclude=response_model_exclude,
|
||||
response_model_by_alias=response_model_by_alias,
|
||||
response_model_exclude_unset=response_model_exclude_unset,
|
||||
response_model_exclude_defaults=response_model_exclude_defaults,
|
||||
response_model_exclude_none=response_model_exclude_none,
|
||||
include_in_schema=include_in_schema,
|
||||
response_class=response_class,
|
||||
name=name,
|
||||
openapi_extra=openapi_extra,
|
||||
generate_unique_id_function=generate_unique_id_function,
|
||||
)
|
||||
|
||||
def api_route(
|
||||
self,
|
||||
path: str,
|
||||
*,
|
||||
response_model: Any = None,
|
||||
status_code: Optional[int] = None,
|
||||
tags: Optional[List[Union[str, Enum]]] = None,
|
||||
dependencies: Optional[Sequence[Depends]] = None,
|
||||
summary: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
response_description: str = "Successful Response",
|
||||
responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None,
|
||||
deprecated: Optional[bool] = None,
|
||||
methods: Optional[List[str]] = None,
|
||||
operation_id: Optional[str] = None,
|
||||
response_model_include: Optional[Union[SetIntStr, DictIntStrAny]] = None,
|
||||
response_model_exclude: Optional[Union[SetIntStr, DictIntStrAny]] = None,
|
||||
response_model_by_alias: bool = True,
|
||||
response_model_exclude_unset: bool = False,
|
||||
response_model_exclude_defaults: bool = False,
|
||||
response_model_exclude_none: bool = False,
|
||||
include_in_schema: bool = True,
|
||||
response_class: Type[Response] = Default(JSONResponse),
|
||||
name: Optional[str] = None,
|
||||
openapi_extra: Optional[Dict[str, Any]] = None,
|
||||
generate_unique_id_function: Callable[[routing.APIRoute], str] = Default(
|
||||
generate_unique_id
|
||||
),
|
||||
) -> Callable[[DecoratedCallable], DecoratedCallable]:
|
||||
def decorator(func: DecoratedCallable) -> DecoratedCallable:
|
||||
self.router.add_api_route(
|
||||
path,
|
||||
func,
|
||||
response_model=response_model,
|
||||
status_code=status_code,
|
||||
tags=tags,
|
||||
dependencies=dependencies,
|
||||
summary=summary,
|
||||
description=description,
|
||||
response_description=response_description,
|
||||
responses=responses,
|
||||
deprecated=deprecated,
|
||||
methods=methods,
|
||||
operation_id=operation_id,
|
||||
response_model_include=response_model_include,
|
||||
response_model_exclude=response_model_exclude,
|
||||
response_model_by_alias=response_model_by_alias,
|
||||
response_model_exclude_unset=response_model_exclude_unset,
|
||||
response_model_exclude_defaults=response_model_exclude_defaults,
|
||||
response_model_exclude_none=response_model_exclude_none,
|
||||
include_in_schema=include_in_schema,
|
||||
response_class=response_class,
|
||||
name=name,
|
||||
openapi_extra=openapi_extra,
|
||||
generate_unique_id_function=generate_unique_id_function,
|
||||
)
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def add_api_websocket_route(
|
||||
self, path: str, endpoint: Callable[..., Any], name: Optional[str] = None
|
||||
) -> None:
|
||||
self.router.add_api_websocket_route(path, endpoint, name=name)
|
||||
|
||||
def websocket(
|
||||
self, path: str, name: Optional[str] = None
|
||||
) -> Callable[[DecoratedCallable], DecoratedCallable]:
|
||||
def decorator(func: DecoratedCallable) -> DecoratedCallable:
|
||||
self.add_api_websocket_route(path, func, name=name)
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def include_router(
|
||||
self,
|
||||
router: routing.APIRouter,
|
||||
*,
|
||||
prefix: str = "",
|
||||
tags: Optional[List[Union[str, Enum]]] = None,
|
||||
dependencies: Optional[Sequence[Depends]] = None,
|
||||
responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None,
|
||||
deprecated: Optional[bool] = None,
|
||||
include_in_schema: bool = True,
|
||||
default_response_class: Type[Response] = Default(JSONResponse),
|
||||
callbacks: Optional[List[BaseRoute]] = None,
|
||||
generate_unique_id_function: Callable[[routing.APIRoute], str] = Default(
|
||||
generate_unique_id
|
||||
),
|
||||
) -> None:
|
||||
self.router.include_router(
|
||||
router,
|
||||
prefix=prefix,
|
||||
tags=tags,
|
||||
dependencies=dependencies,
|
||||
responses=responses,
|
||||
deprecated=deprecated,
|
||||
include_in_schema=include_in_schema,
|
||||
default_response_class=default_response_class,
|
||||
callbacks=callbacks,
|
||||
generate_unique_id_function=generate_unique_id_function,
|
||||
)
|
||||
|
||||
def get(
|
||||
self,
|
||||
path: str,
|
||||
*,
|
||||
response_model: Any = None,
|
||||
status_code: Optional[int] = None,
|
||||
tags: Optional[List[Union[str, Enum]]] = None,
|
||||
dependencies: Optional[Sequence[Depends]] = None,
|
||||
summary: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
response_description: str = "Successful Response",
|
||||
responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None,
|
||||
deprecated: Optional[bool] = None,
|
||||
operation_id: Optional[str] = None,
|
||||
response_model_include: Optional[Union[SetIntStr, DictIntStrAny]] = None,
|
||||
response_model_exclude: Optional[Union[SetIntStr, DictIntStrAny]] = None,
|
||||
response_model_by_alias: bool = True,
|
||||
response_model_exclude_unset: bool = False,
|
||||
response_model_exclude_defaults: bool = False,
|
||||
response_model_exclude_none: bool = False,
|
||||
include_in_schema: bool = True,
|
||||
response_class: Type[Response] = Default(JSONResponse),
|
||||
name: Optional[str] = None,
|
||||
callbacks: Optional[List[BaseRoute]] = None,
|
||||
openapi_extra: Optional[Dict[str, Any]] = None,
|
||||
generate_unique_id_function: Callable[[routing.APIRoute], str] = Default(
|
||||
generate_unique_id
|
||||
),
|
||||
) -> Callable[[DecoratedCallable], DecoratedCallable]:
|
||||
return self.router.get(
|
||||
path,
|
||||
response_model=response_model,
|
||||
status_code=status_code,
|
||||
tags=tags,
|
||||
dependencies=dependencies,
|
||||
summary=summary,
|
||||
description=description,
|
||||
response_description=response_description,
|
||||
responses=responses,
|
||||
deprecated=deprecated,
|
||||
operation_id=operation_id,
|
||||
response_model_include=response_model_include,
|
||||
response_model_exclude=response_model_exclude,
|
||||
response_model_by_alias=response_model_by_alias,
|
||||
response_model_exclude_unset=response_model_exclude_unset,
|
||||
response_model_exclude_defaults=response_model_exclude_defaults,
|
||||
response_model_exclude_none=response_model_exclude_none,
|
||||
include_in_schema=include_in_schema,
|
||||
response_class=response_class,
|
||||
name=name,
|
||||
callbacks=callbacks,
|
||||
openapi_extra=openapi_extra,
|
||||
generate_unique_id_function=generate_unique_id_function,
|
||||
)
|
||||
|
||||
def put(
|
||||
self,
|
||||
path: str,
|
||||
*,
|
||||
response_model: Any = None,
|
||||
status_code: Optional[int] = None,
|
||||
tags: Optional[List[Union[str, Enum]]] = None,
|
||||
dependencies: Optional[Sequence[Depends]] = None,
|
||||
summary: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
response_description: str = "Successful Response",
|
||||
responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None,
|
||||
deprecated: Optional[bool] = None,
|
||||
operation_id: Optional[str] = None,
|
||||
response_model_include: Optional[Union[SetIntStr, DictIntStrAny]] = None,
|
||||
response_model_exclude: Optional[Union[SetIntStr, DictIntStrAny]] = None,
|
||||
response_model_by_alias: bool = True,
|
||||
response_model_exclude_unset: bool = False,
|
||||
response_model_exclude_defaults: bool = False,
|
||||
response_model_exclude_none: bool = False,
|
||||
include_in_schema: bool = True,
|
||||
response_class: Type[Response] = Default(JSONResponse),
|
||||
name: Optional[str] = None,
|
||||
callbacks: Optional[List[BaseRoute]] = None,
|
||||
openapi_extra: Optional[Dict[str, Any]] = None,
|
||||
generate_unique_id_function: Callable[[routing.APIRoute], str] = Default(
|
||||
generate_unique_id
|
||||
),
|
||||
) -> Callable[[DecoratedCallable], DecoratedCallable]:
|
||||
return self.router.put(
|
||||
path,
|
||||
response_model=response_model,
|
||||
status_code=status_code,
|
||||
tags=tags,
|
||||
dependencies=dependencies,
|
||||
summary=summary,
|
||||
description=description,
|
||||
response_description=response_description,
|
||||
responses=responses,
|
||||
deprecated=deprecated,
|
||||
operation_id=operation_id,
|
||||
response_model_include=response_model_include,
|
||||
response_model_exclude=response_model_exclude,
|
||||
response_model_by_alias=response_model_by_alias,
|
||||
response_model_exclude_unset=response_model_exclude_unset,
|
||||
response_model_exclude_defaults=response_model_exclude_defaults,
|
||||
response_model_exclude_none=response_model_exclude_none,
|
||||
include_in_schema=include_in_schema,
|
||||
response_class=response_class,
|
||||
name=name,
|
||||
callbacks=callbacks,
|
||||
openapi_extra=openapi_extra,
|
||||
generate_unique_id_function=generate_unique_id_function,
|
||||
)
|
||||
|
||||
def post(
|
||||
self,
|
||||
path: str,
|
||||
*,
|
||||
response_model: Any = None,
|
||||
status_code: Optional[int] = None,
|
||||
tags: Optional[List[Union[str, Enum]]] = None,
|
||||
dependencies: Optional[Sequence[Depends]] = None,
|
||||
summary: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
response_description: str = "Successful Response",
|
||||
responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None,
|
||||
deprecated: Optional[bool] = None,
|
||||
operation_id: Optional[str] = None,
|
||||
response_model_include: Optional[Union[SetIntStr, DictIntStrAny]] = None,
|
||||
response_model_exclude: Optional[Union[SetIntStr, DictIntStrAny]] = None,
|
||||
response_model_by_alias: bool = True,
|
||||
response_model_exclude_unset: bool = False,
|
||||
response_model_exclude_defaults: bool = False,
|
||||
response_model_exclude_none: bool = False,
|
||||
include_in_schema: bool = True,
|
||||
response_class: Type[Response] = Default(JSONResponse),
|
||||
name: Optional[str] = None,
|
||||
callbacks: Optional[List[BaseRoute]] = None,
|
||||
openapi_extra: Optional[Dict[str, Any]] = None,
|
||||
generate_unique_id_function: Callable[[routing.APIRoute], str] = Default(
|
||||
generate_unique_id
|
||||
),
|
||||
) -> Callable[[DecoratedCallable], DecoratedCallable]:
|
||||
return self.router.post(
|
||||
path,
|
||||
response_model=response_model,
|
||||
status_code=status_code,
|
||||
tags=tags,
|
||||
dependencies=dependencies,
|
||||
summary=summary,
|
||||
description=description,
|
||||
response_description=response_description,
|
||||
responses=responses,
|
||||
deprecated=deprecated,
|
||||
operation_id=operation_id,
|
||||
response_model_include=response_model_include,
|
||||
response_model_exclude=response_model_exclude,
|
||||
response_model_by_alias=response_model_by_alias,
|
||||
response_model_exclude_unset=response_model_exclude_unset,
|
||||
response_model_exclude_defaults=response_model_exclude_defaults,
|
||||
response_model_exclude_none=response_model_exclude_none,
|
||||
include_in_schema=include_in_schema,
|
||||
response_class=response_class,
|
||||
name=name,
|
||||
callbacks=callbacks,
|
||||
openapi_extra=openapi_extra,
|
||||
generate_unique_id_function=generate_unique_id_function,
|
||||
)
|
||||
|
||||
def delete(
|
||||
self,
|
||||
path: str,
|
||||
*,
|
||||
response_model: Any = None,
|
||||
status_code: Optional[int] = None,
|
||||
tags: Optional[List[Union[str, Enum]]] = None,
|
||||
dependencies: Optional[Sequence[Depends]] = None,
|
||||
summary: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
response_description: str = "Successful Response",
|
||||
responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None,
|
||||
deprecated: Optional[bool] = None,
|
||||
operation_id: Optional[str] = None,
|
||||
response_model_include: Optional[Union[SetIntStr, DictIntStrAny]] = None,
|
||||
response_model_exclude: Optional[Union[SetIntStr, DictIntStrAny]] = None,
|
||||
response_model_by_alias: bool = True,
|
||||
response_model_exclude_unset: bool = False,
|
||||
response_model_exclude_defaults: bool = False,
|
||||
response_model_exclude_none: bool = False,
|
||||
include_in_schema: bool = True,
|
||||
response_class: Type[Response] = Default(JSONResponse),
|
||||
name: Optional[str] = None,
|
||||
callbacks: Optional[List[BaseRoute]] = None,
|
||||
openapi_extra: Optional[Dict[str, Any]] = None,
|
||||
generate_unique_id_function: Callable[[routing.APIRoute], str] = Default(
|
||||
generate_unique_id
|
||||
),
|
||||
) -> Callable[[DecoratedCallable], DecoratedCallable]:
|
||||
return self.router.delete(
|
||||
path,
|
||||
response_model=response_model,
|
||||
status_code=status_code,
|
||||
tags=tags,
|
||||
dependencies=dependencies,
|
||||
summary=summary,
|
||||
description=description,
|
||||
response_description=response_description,
|
||||
responses=responses,
|
||||
deprecated=deprecated,
|
||||
response_model_include=response_model_include,
|
||||
response_model_exclude=response_model_exclude,
|
||||
response_model_by_alias=response_model_by_alias,
|
||||
operation_id=operation_id,
|
||||
response_model_exclude_unset=response_model_exclude_unset,
|
||||
response_model_exclude_defaults=response_model_exclude_defaults,
|
||||
response_model_exclude_none=response_model_exclude_none,
|
||||
include_in_schema=include_in_schema,
|
||||
response_class=response_class,
|
||||
name=name,
|
||||
callbacks=callbacks,
|
||||
openapi_extra=openapi_extra,
|
||||
generate_unique_id_function=generate_unique_id_function,
|
||||
)
|
||||
|
||||
def options(
|
||||
self,
|
||||
path: str,
|
||||
*,
|
||||
response_model: Any = None,
|
||||
status_code: Optional[int] = None,
|
||||
tags: Optional[List[Union[str, Enum]]] = None,
|
||||
dependencies: Optional[Sequence[Depends]] = None,
|
||||
summary: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
response_description: str = "Successful Response",
|
||||
responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None,
|
||||
deprecated: Optional[bool] = None,
|
||||
operation_id: Optional[str] = None,
|
||||
response_model_include: Optional[Union[SetIntStr, DictIntStrAny]] = None,
|
||||
response_model_exclude: Optional[Union[SetIntStr, DictIntStrAny]] = None,
|
||||
response_model_by_alias: bool = True,
|
||||
response_model_exclude_unset: bool = False,
|
||||
response_model_exclude_defaults: bool = False,
|
||||
response_model_exclude_none: bool = False,
|
||||
include_in_schema: bool = True,
|
||||
response_class: Type[Response] = Default(JSONResponse),
|
||||
name: Optional[str] = None,
|
||||
callbacks: Optional[List[BaseRoute]] = None,
|
||||
openapi_extra: Optional[Dict[str, Any]] = None,
|
||||
generate_unique_id_function: Callable[[routing.APIRoute], str] = Default(
|
||||
generate_unique_id
|
||||
),
|
||||
) -> Callable[[DecoratedCallable], DecoratedCallable]:
|
||||
return self.router.options(
|
||||
path,
|
||||
response_model=response_model,
|
||||
status_code=status_code,
|
||||
tags=tags,
|
||||
dependencies=dependencies,
|
||||
summary=summary,
|
||||
description=description,
|
||||
response_description=response_description,
|
||||
responses=responses,
|
||||
deprecated=deprecated,
|
||||
operation_id=operation_id,
|
||||
response_model_include=response_model_include,
|
||||
response_model_exclude=response_model_exclude,
|
||||
response_model_by_alias=response_model_by_alias,
|
||||
response_model_exclude_unset=response_model_exclude_unset,
|
||||
response_model_exclude_defaults=response_model_exclude_defaults,
|
||||
response_model_exclude_none=response_model_exclude_none,
|
||||
include_in_schema=include_in_schema,
|
||||
response_class=response_class,
|
||||
name=name,
|
||||
callbacks=callbacks,
|
||||
openapi_extra=openapi_extra,
|
||||
generate_unique_id_function=generate_unique_id_function,
|
||||
)
|
||||
|
||||
def head(
|
||||
self,
|
||||
path: str,
|
||||
*,
|
||||
response_model: Any = None,
|
||||
status_code: Optional[int] = None,
|
||||
tags: Optional[List[Union[str, Enum]]] = None,
|
||||
dependencies: Optional[Sequence[Depends]] = None,
|
||||
summary: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
response_description: str = "Successful Response",
|
||||
responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None,
|
||||
deprecated: Optional[bool] = None,
|
||||
operation_id: Optional[str] = None,
|
||||
response_model_include: Optional[Union[SetIntStr, DictIntStrAny]] = None,
|
||||
response_model_exclude: Optional[Union[SetIntStr, DictIntStrAny]] = None,
|
||||
response_model_by_alias: bool = True,
|
||||
response_model_exclude_unset: bool = False,
|
||||
response_model_exclude_defaults: bool = False,
|
||||
response_model_exclude_none: bool = False,
|
||||
include_in_schema: bool = True,
|
||||
response_class: Type[Response] = Default(JSONResponse),
|
||||
name: Optional[str] = None,
|
||||
callbacks: Optional[List[BaseRoute]] = None,
|
||||
openapi_extra: Optional[Dict[str, Any]] = None,
|
||||
generate_unique_id_function: Callable[[routing.APIRoute], str] = Default(
|
||||
generate_unique_id
|
||||
),
|
||||
) -> Callable[[DecoratedCallable], DecoratedCallable]:
|
||||
return self.router.head(
|
||||
path,
|
||||
response_model=response_model,
|
||||
status_code=status_code,
|
||||
tags=tags,
|
||||
dependencies=dependencies,
|
||||
summary=summary,
|
||||
description=description,
|
||||
response_description=response_description,
|
||||
responses=responses,
|
||||
deprecated=deprecated,
|
||||
operation_id=operation_id,
|
||||
response_model_include=response_model_include,
|
||||
response_model_exclude=response_model_exclude,
|
||||
response_model_by_alias=response_model_by_alias,
|
||||
response_model_exclude_unset=response_model_exclude_unset,
|
||||
response_model_exclude_defaults=response_model_exclude_defaults,
|
||||
response_model_exclude_none=response_model_exclude_none,
|
||||
include_in_schema=include_in_schema,
|
||||
response_class=response_class,
|
||||
name=name,
|
||||
callbacks=callbacks,
|
||||
openapi_extra=openapi_extra,
|
||||
generate_unique_id_function=generate_unique_id_function,
|
||||
)
|
||||
|
||||
def patch(
|
||||
self,
|
||||
path: str,
|
||||
*,
|
||||
response_model: Any = None,
|
||||
status_code: Optional[int] = None,
|
||||
tags: Optional[List[Union[str, Enum]]] = None,
|
||||
dependencies: Optional[Sequence[Depends]] = None,
|
||||
summary: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
response_description: str = "Successful Response",
|
||||
responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None,
|
||||
deprecated: Optional[bool] = None,
|
||||
operation_id: Optional[str] = None,
|
||||
response_model_include: Optional[Union[SetIntStr, DictIntStrAny]] = None,
|
||||
response_model_exclude: Optional[Union[SetIntStr, DictIntStrAny]] = None,
|
||||
response_model_by_alias: bool = True,
|
||||
response_model_exclude_unset: bool = False,
|
||||
response_model_exclude_defaults: bool = False,
|
||||
response_model_exclude_none: bool = False,
|
||||
include_in_schema: bool = True,
|
||||
response_class: Type[Response] = Default(JSONResponse),
|
||||
name: Optional[str] = None,
|
||||
callbacks: Optional[List[BaseRoute]] = None,
|
||||
openapi_extra: Optional[Dict[str, Any]] = None,
|
||||
generate_unique_id_function: Callable[[routing.APIRoute], str] = Default(
|
||||
generate_unique_id
|
||||
),
|
||||
) -> Callable[[DecoratedCallable], DecoratedCallable]:
|
||||
return self.router.patch(
|
||||
path,
|
||||
response_model=response_model,
|
||||
status_code=status_code,
|
||||
tags=tags,
|
||||
dependencies=dependencies,
|
||||
summary=summary,
|
||||
description=description,
|
||||
response_description=response_description,
|
||||
responses=responses,
|
||||
deprecated=deprecated,
|
||||
operation_id=operation_id,
|
||||
response_model_include=response_model_include,
|
||||
response_model_exclude=response_model_exclude,
|
||||
response_model_by_alias=response_model_by_alias,
|
||||
response_model_exclude_unset=response_model_exclude_unset,
|
||||
response_model_exclude_defaults=response_model_exclude_defaults,
|
||||
response_model_exclude_none=response_model_exclude_none,
|
||||
include_in_schema=include_in_schema,
|
||||
response_class=response_class,
|
||||
name=name,
|
||||
callbacks=callbacks,
|
||||
openapi_extra=openapi_extra,
|
||||
generate_unique_id_function=generate_unique_id_function,
|
||||
)
|
||||
|
||||
def trace(
|
||||
self,
|
||||
path: str,
|
||||
*,
|
||||
response_model: Any = None,
|
||||
status_code: Optional[int] = None,
|
||||
tags: Optional[List[Union[str, Enum]]] = None,
|
||||
dependencies: Optional[Sequence[Depends]] = None,
|
||||
summary: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
response_description: str = "Successful Response",
|
||||
responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None,
|
||||
deprecated: Optional[bool] = None,
|
||||
operation_id: Optional[str] = None,
|
||||
response_model_include: Optional[Union[SetIntStr, DictIntStrAny]] = None,
|
||||
response_model_exclude: Optional[Union[SetIntStr, DictIntStrAny]] = None,
|
||||
response_model_by_alias: bool = True,
|
||||
response_model_exclude_unset: bool = False,
|
||||
response_model_exclude_defaults: bool = False,
|
||||
response_model_exclude_none: bool = False,
|
||||
include_in_schema: bool = True,
|
||||
response_class: Type[Response] = Default(JSONResponse),
|
||||
name: Optional[str] = None,
|
||||
callbacks: Optional[List[BaseRoute]] = None,
|
||||
openapi_extra: Optional[Dict[str, Any]] = None,
|
||||
generate_unique_id_function: Callable[[routing.APIRoute], str] = Default(
|
||||
generate_unique_id
|
||||
),
|
||||
) -> Callable[[DecoratedCallable], DecoratedCallable]:
|
||||
return self.router.trace(
|
||||
path,
|
||||
response_model=response_model,
|
||||
status_code=status_code,
|
||||
tags=tags,
|
||||
dependencies=dependencies,
|
||||
summary=summary,
|
||||
description=description,
|
||||
response_description=response_description,
|
||||
responses=responses,
|
||||
deprecated=deprecated,
|
||||
operation_id=operation_id,
|
||||
response_model_include=response_model_include,
|
||||
response_model_exclude=response_model_exclude,
|
||||
response_model_by_alias=response_model_by_alias,
|
||||
response_model_exclude_unset=response_model_exclude_unset,
|
||||
response_model_exclude_defaults=response_model_exclude_defaults,
|
||||
response_model_exclude_none=response_model_exclude_none,
|
||||
include_in_schema=include_in_schema,
|
||||
response_class=response_class,
|
||||
name=name,
|
||||
callbacks=callbacks,
|
||||
openapi_extra=openapi_extra,
|
||||
generate_unique_id_function=generate_unique_id_function,
|
||||
)
|
@ -0,0 +1 @@
|
||||
from starlette.background import BackgroundTasks as BackgroundTasks # noqa
|
@ -0,0 +1,32 @@
|
||||
import sys
|
||||
from typing import AsyncGenerator, ContextManager, TypeVar
|
||||
|
||||
from starlette.concurrency import iterate_in_threadpool as iterate_in_threadpool # noqa
|
||||
from starlette.concurrency import run_in_threadpool as run_in_threadpool # noqa
|
||||
from starlette.concurrency import ( # noqa
|
||||
run_until_first_complete as run_until_first_complete,
|
||||
)
|
||||
|
||||
if sys.version_info >= (3, 7):
|
||||
from contextlib import AsyncExitStack as AsyncExitStack
|
||||
from contextlib import asynccontextmanager as asynccontextmanager
|
||||
else:
|
||||
from contextlib2 import AsyncExitStack as AsyncExitStack # noqa
|
||||
from contextlib2 import asynccontextmanager as asynccontextmanager # noqa
|
||||
|
||||
|
||||
_T = TypeVar("_T")
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def contextmanager_in_threadpool(
|
||||
cm: ContextManager[_T],
|
||||
) -> AsyncGenerator[_T, None]:
|
||||
try:
|
||||
yield await run_in_threadpool(cm.__enter__)
|
||||
except Exception as e:
|
||||
ok: bool = await run_in_threadpool(cm.__exit__, type(e), e, None)
|
||||
if not ok:
|
||||
raise e
|
||||
else:
|
||||
await run_in_threadpool(cm.__exit__, None, None, None)
|
@ -0,0 +1,56 @@
|
||||
from typing import Any, Callable, Dict, Iterable, Type, TypeVar
|
||||
|
||||
from starlette.datastructures import URL as URL # noqa: F401
|
||||
from starlette.datastructures import Address as Address # noqa: F401
|
||||
from starlette.datastructures import FormData as FormData # noqa: F401
|
||||
from starlette.datastructures import Headers as Headers # noqa: F401
|
||||
from starlette.datastructures import QueryParams as QueryParams # noqa: F401
|
||||
from starlette.datastructures import State as State # noqa: F401
|
||||
from starlette.datastructures import UploadFile as StarletteUploadFile
|
||||
|
||||
|
||||
class UploadFile(StarletteUploadFile):
|
||||
@classmethod
|
||||
def __get_validators__(cls: Type["UploadFile"]) -> Iterable[Callable[..., Any]]:
|
||||
yield cls.validate
|
||||
|
||||
@classmethod
|
||||
def validate(cls: Type["UploadFile"], v: Any) -> Any:
|
||||
if not isinstance(v, StarletteUploadFile):
|
||||
raise ValueError(f"Expected UploadFile, received: {type(v)}")
|
||||
return v
|
||||
|
||||
@classmethod
|
||||
def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None:
|
||||
field_schema.update({"type": "string", "format": "binary"})
|
||||
|
||||
|
||||
class DefaultPlaceholder:
|
||||
"""
|
||||
You shouldn't use this class directly.
|
||||
|
||||
It's used internally to recognize when a default value has been overwritten, even
|
||||
if the overridden default value was truthy.
|
||||
"""
|
||||
|
||||
def __init__(self, value: Any):
|
||||
self.value = value
|
||||
|
||||
def __bool__(self) -> bool:
|
||||
return bool(self.value)
|
||||
|
||||
def __eq__(self, o: object) -> bool:
|
||||
return isinstance(o, DefaultPlaceholder) and o.value == self.value
|
||||
|
||||
|
||||
DefaultType = TypeVar("DefaultType")
|
||||
|
||||
|
||||
def Default(value: DefaultType) -> DefaultType:
|
||||
"""
|
||||
You shouldn't use this function directly.
|
||||
|
||||
It's used internally to recognize when a default value has been overwritten, even
|
||||
if the overridden default value was truthy.
|
||||
"""
|
||||
return DefaultPlaceholder(value) # type: ignore
|
@ -0,0 +1,58 @@
|
||||
from typing import Any, Callable, List, Optional, Sequence
|
||||
|
||||
from fastapi.security.base import SecurityBase
|
||||
from pydantic.fields import ModelField
|
||||
|
||||
|
||||
class SecurityRequirement:
|
||||
def __init__(
|
||||
self, security_scheme: SecurityBase, scopes: Optional[Sequence[str]] = None
|
||||
):
|
||||
self.security_scheme = security_scheme
|
||||
self.scopes = scopes
|
||||
|
||||
|
||||
class Dependant:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
path_params: Optional[List[ModelField]] = None,
|
||||
query_params: Optional[List[ModelField]] = None,
|
||||
header_params: Optional[List[ModelField]] = None,
|
||||
cookie_params: Optional[List[ModelField]] = None,
|
||||
body_params: Optional[List[ModelField]] = None,
|
||||
dependencies: Optional[List["Dependant"]] = None,
|
||||
security_schemes: Optional[List[SecurityRequirement]] = None,
|
||||
name: Optional[str] = None,
|
||||
call: Optional[Callable[..., Any]] = None,
|
||||
request_param_name: Optional[str] = None,
|
||||
websocket_param_name: Optional[str] = None,
|
||||
http_connection_param_name: Optional[str] = None,
|
||||
response_param_name: Optional[str] = None,
|
||||
background_tasks_param_name: Optional[str] = None,
|
||||
security_scopes_param_name: Optional[str] = None,
|
||||
security_scopes: Optional[List[str]] = None,
|
||||
use_cache: bool = True,
|
||||
path: Optional[str] = None,
|
||||
) -> None:
|
||||
self.path_params = path_params or []
|
||||
self.query_params = query_params or []
|
||||
self.header_params = header_params or []
|
||||
self.cookie_params = cookie_params or []
|
||||
self.body_params = body_params or []
|
||||
self.dependencies = dependencies or []
|
||||
self.security_requirements = security_schemes or []
|
||||
self.request_param_name = request_param_name
|
||||
self.websocket_param_name = websocket_param_name
|
||||
self.http_connection_param_name = http_connection_param_name
|
||||
self.response_param_name = response_param_name
|
||||
self.background_tasks_param_name = background_tasks_param_name
|
||||
self.security_scopes = security_scopes
|
||||
self.security_scopes_param_name = security_scopes_param_name
|
||||
self.name = name
|
||||
self.call = call
|
||||
self.use_cache = use_cache
|
||||
# Store the path to be able to re-generate a dependable from it in overrides
|
||||
self.path = path
|
||||
# Save the cache key at creation to optimize performance
|
||||
self.cache_key = (self.call, tuple(sorted(set(self.security_scopes or []))))
|
@ -0,0 +1,756 @@
|
||||
import dataclasses
|
||||
import inspect
|
||||
from contextlib import contextmanager
|
||||
from copy import deepcopy
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Coroutine,
|
||||
Dict,
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
import anyio
|
||||
from fastapi import params
|
||||
from fastapi.concurrency import (
|
||||
AsyncExitStack,
|
||||
asynccontextmanager,
|
||||
contextmanager_in_threadpool,
|
||||
)
|
||||
from fastapi.dependencies.models import Dependant, SecurityRequirement
|
||||
from fastapi.logger import logger
|
||||
from fastapi.security.base import SecurityBase
|
||||
from fastapi.security.oauth2 import OAuth2, SecurityScopes
|
||||
from fastapi.security.open_id_connect_url import OpenIdConnect
|
||||
from fastapi.utils import create_response_field, get_path_param_names
|
||||
from pydantic import BaseModel, create_model
|
||||
from pydantic.error_wrappers import ErrorWrapper
|
||||
from pydantic.errors import MissingError
|
||||
from pydantic.fields import (
|
||||
SHAPE_FROZENSET,
|
||||
SHAPE_LIST,
|
||||
SHAPE_SEQUENCE,
|
||||
SHAPE_SET,
|
||||
SHAPE_SINGLETON,
|
||||
SHAPE_TUPLE,
|
||||
SHAPE_TUPLE_ELLIPSIS,
|
||||
FieldInfo,
|
||||
ModelField,
|
||||
Required,
|
||||
Undefined,
|
||||
)
|
||||
from pydantic.schema import get_annotation_from_field_info
|
||||
from pydantic.typing import ForwardRef, evaluate_forwardref
|
||||
from pydantic.utils import lenient_issubclass
|
||||
from starlette.background import BackgroundTasks
|
||||
from starlette.concurrency import run_in_threadpool
|
||||
from starlette.datastructures import FormData, Headers, QueryParams, UploadFile
|
||||
from starlette.requests import HTTPConnection, Request
|
||||
from starlette.responses import Response
|
||||
from starlette.websockets import WebSocket
|
||||
|
||||
sequence_shapes = {
|
||||
SHAPE_LIST,
|
||||
SHAPE_SET,
|
||||
SHAPE_FROZENSET,
|
||||
SHAPE_TUPLE,
|
||||
SHAPE_SEQUENCE,
|
||||
SHAPE_TUPLE_ELLIPSIS,
|
||||
}
|
||||
sequence_types = (list, set, tuple)
|
||||
sequence_shape_to_type = {
|
||||
SHAPE_LIST: list,
|
||||
SHAPE_SET: set,
|
||||
SHAPE_TUPLE: tuple,
|
||||
SHAPE_SEQUENCE: list,
|
||||
SHAPE_TUPLE_ELLIPSIS: list,
|
||||
}
|
||||
|
||||
|
||||
multipart_not_installed_error = (
|
||||
'Form data requires "python-multipart" to be installed. \n'
|
||||
'You can install "python-multipart" with: \n\n'
|
||||
"pip install python-multipart\n"
|
||||
)
|
||||
multipart_incorrect_install_error = (
|
||||
'Form data requires "python-multipart" to be installed. '
|
||||
'It seems you installed "multipart" instead. \n'
|
||||
'You can remove "multipart" with: \n\n'
|
||||
"pip uninstall multipart\n\n"
|
||||
'And then install "python-multipart" with: \n\n'
|
||||
"pip install python-multipart\n"
|
||||
)
|
||||
|
||||
|
||||
def check_file_field(field: ModelField) -> None:
|
||||
field_info = field.field_info
|
||||
if isinstance(field_info, params.Form):
|
||||
try:
|
||||
# __version__ is available in both multiparts, and can be mocked
|
||||
from multipart import __version__ # type: ignore
|
||||
|
||||
assert __version__
|
||||
try:
|
||||
# parse_options_header is only available in the right multipart
|
||||
from multipart.multipart import parse_options_header # type: ignore
|
||||
|
||||
assert parse_options_header
|
||||
except ImportError:
|
||||
logger.error(multipart_incorrect_install_error)
|
||||
raise RuntimeError(multipart_incorrect_install_error)
|
||||
except ImportError:
|
||||
logger.error(multipart_not_installed_error)
|
||||
raise RuntimeError(multipart_not_installed_error)
|
||||
|
||||
|
||||
def get_param_sub_dependant(
|
||||
*, param: inspect.Parameter, path: str, security_scopes: Optional[List[str]] = None
|
||||
) -> Dependant:
|
||||
depends: params.Depends = param.default
|
||||
if depends.dependency:
|
||||
dependency = depends.dependency
|
||||
else:
|
||||
dependency = param.annotation
|
||||
return get_sub_dependant(
|
||||
depends=depends,
|
||||
dependency=dependency,
|
||||
path=path,
|
||||
name=param.name,
|
||||
security_scopes=security_scopes,
|
||||
)
|
||||
|
||||
|
||||
def get_parameterless_sub_dependant(*, depends: params.Depends, path: str) -> Dependant:
|
||||
assert callable(
|
||||
depends.dependency
|
||||
), "A parameter-less dependency must have a callable dependency"
|
||||
return get_sub_dependant(depends=depends, dependency=depends.dependency, path=path)
|
||||
|
||||
|
||||
def get_sub_dependant(
|
||||
*,
|
||||
depends: params.Depends,
|
||||
dependency: Callable[..., Any],
|
||||
path: str,
|
||||
name: Optional[str] = None,
|
||||
security_scopes: Optional[List[str]] = None,
|
||||
) -> Dependant:
|
||||
security_requirement = None
|
||||
security_scopes = security_scopes or []
|
||||
if isinstance(depends, params.Security):
|
||||
dependency_scopes = depends.scopes
|
||||
security_scopes.extend(dependency_scopes)
|
||||
if isinstance(dependency, SecurityBase):
|
||||
use_scopes: List[str] = []
|
||||
if isinstance(dependency, (OAuth2, OpenIdConnect)):
|
||||
use_scopes = security_scopes
|
||||
security_requirement = SecurityRequirement(
|
||||
security_scheme=dependency, scopes=use_scopes
|
||||
)
|
||||
sub_dependant = get_dependant(
|
||||
path=path,
|
||||
call=dependency,
|
||||
name=name,
|
||||
security_scopes=security_scopes,
|
||||
use_cache=depends.use_cache,
|
||||
)
|
||||
if security_requirement:
|
||||
sub_dependant.security_requirements.append(security_requirement)
|
||||
return sub_dependant
|
||||
|
||||
|
||||
CacheKey = Tuple[Optional[Callable[..., Any]], Tuple[str, ...]]
|
||||
|
||||
|
||||
def get_flat_dependant(
|
||||
dependant: Dependant,
|
||||
*,
|
||||
skip_repeats: bool = False,
|
||||
visited: Optional[List[CacheKey]] = None,
|
||||
) -> Dependant:
|
||||
if visited is None:
|
||||
visited = []
|
||||
visited.append(dependant.cache_key)
|
||||
|
||||
flat_dependant = Dependant(
|
||||
path_params=dependant.path_params.copy(),
|
||||
query_params=dependant.query_params.copy(),
|
||||
header_params=dependant.header_params.copy(),
|
||||
cookie_params=dependant.cookie_params.copy(),
|
||||
body_params=dependant.body_params.copy(),
|
||||
security_schemes=dependant.security_requirements.copy(),
|
||||
use_cache=dependant.use_cache,
|
||||
path=dependant.path,
|
||||
)
|
||||
for sub_dependant in dependant.dependencies:
|
||||
if skip_repeats and sub_dependant.cache_key in visited:
|
||||
continue
|
||||
flat_sub = get_flat_dependant(
|
||||
sub_dependant, skip_repeats=skip_repeats, visited=visited
|
||||
)
|
||||
flat_dependant.path_params.extend(flat_sub.path_params)
|
||||
flat_dependant.query_params.extend(flat_sub.query_params)
|
||||
flat_dependant.header_params.extend(flat_sub.header_params)
|
||||
flat_dependant.cookie_params.extend(flat_sub.cookie_params)
|
||||
flat_dependant.body_params.extend(flat_sub.body_params)
|
||||
flat_dependant.security_requirements.extend(flat_sub.security_requirements)
|
||||
return flat_dependant
|
||||
|
||||
|
||||
def get_flat_params(dependant: Dependant) -> List[ModelField]:
|
||||
flat_dependant = get_flat_dependant(dependant, skip_repeats=True)
|
||||
return (
|
||||
flat_dependant.path_params
|
||||
+ flat_dependant.query_params
|
||||
+ flat_dependant.header_params
|
||||
+ flat_dependant.cookie_params
|
||||
)
|
||||
|
||||
|
||||
def is_scalar_field(field: ModelField) -> bool:
|
||||
field_info = field.field_info
|
||||
if not (
|
||||
field.shape == SHAPE_SINGLETON
|
||||
and not lenient_issubclass(field.type_, BaseModel)
|
||||
and not lenient_issubclass(field.type_, sequence_types + (dict,))
|
||||
and not dataclasses.is_dataclass(field.type_)
|
||||
and not isinstance(field_info, params.Body)
|
||||
):
|
||||
return False
|
||||
if field.sub_fields:
|
||||
if not all(is_scalar_field(f) for f in field.sub_fields):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def is_scalar_sequence_field(field: ModelField) -> bool:
|
||||
if (field.shape in sequence_shapes) and not lenient_issubclass(
|
||||
field.type_, BaseModel
|
||||
):
|
||||
if field.sub_fields is not None:
|
||||
for sub_field in field.sub_fields:
|
||||
if not is_scalar_field(sub_field):
|
||||
return False
|
||||
return True
|
||||
if lenient_issubclass(field.type_, sequence_types):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature:
|
||||
signature = inspect.signature(call)
|
||||
globalns = getattr(call, "__globals__", {})
|
||||
typed_params = [
|
||||
inspect.Parameter(
|
||||
name=param.name,
|
||||
kind=param.kind,
|
||||
default=param.default,
|
||||
annotation=get_typed_annotation(param, globalns),
|
||||
)
|
||||
for param in signature.parameters.values()
|
||||
]
|
||||
typed_signature = inspect.Signature(typed_params)
|
||||
return typed_signature
|
||||
|
||||
|
||||
def get_typed_annotation(param: inspect.Parameter, globalns: Dict[str, Any]) -> Any:
|
||||
annotation = param.annotation
|
||||
if isinstance(annotation, str):
|
||||
annotation = ForwardRef(annotation)
|
||||
annotation = evaluate_forwardref(annotation, globalns, globalns)
|
||||
return annotation
|
||||
|
||||
|
||||
def get_dependant(
|
||||
*,
|
||||
path: str,
|
||||
call: Callable[..., Any],
|
||||
name: Optional[str] = None,
|
||||
security_scopes: Optional[List[str]] = None,
|
||||
use_cache: bool = True,
|
||||
) -> Dependant:
|
||||
path_param_names = get_path_param_names(path)
|
||||
endpoint_signature = get_typed_signature(call)
|
||||
signature_params = endpoint_signature.parameters
|
||||
dependant = Dependant(
|
||||
call=call,
|
||||
name=name,
|
||||
path=path,
|
||||
security_scopes=security_scopes,
|
||||
use_cache=use_cache,
|
||||
)
|
||||
for param_name, param in signature_params.items():
|
||||
if isinstance(param.default, params.Depends):
|
||||
sub_dependant = get_param_sub_dependant(
|
||||
param=param, path=path, security_scopes=security_scopes
|
||||
)
|
||||
dependant.dependencies.append(sub_dependant)
|
||||
continue
|
||||
if add_non_field_param_to_dependency(param=param, dependant=dependant):
|
||||
continue
|
||||
param_field = get_param_field(
|
||||
param=param, default_field_info=params.Query, param_name=param_name
|
||||
)
|
||||
if param_name in path_param_names:
|
||||
assert is_scalar_field(
|
||||
field=param_field
|
||||
), "Path params must be of one of the supported types"
|
||||
ignore_default = not isinstance(param.default, params.Path)
|
||||
param_field = get_param_field(
|
||||
param=param,
|
||||
param_name=param_name,
|
||||
default_field_info=params.Path,
|
||||
force_type=params.ParamTypes.path,
|
||||
ignore_default=ignore_default,
|
||||
)
|
||||
add_param_to_fields(field=param_field, dependant=dependant)
|
||||
elif is_scalar_field(field=param_field):
|
||||
add_param_to_fields(field=param_field, dependant=dependant)
|
||||
elif isinstance(
|
||||
param.default, (params.Query, params.Header)
|
||||
) and is_scalar_sequence_field(param_field):
|
||||
add_param_to_fields(field=param_field, dependant=dependant)
|
||||
else:
|
||||
field_info = param_field.field_info
|
||||
assert isinstance(
|
||||
field_info, params.Body
|
||||
), f"Param: {param_field.name} can only be a request body, using Body()"
|
||||
dependant.body_params.append(param_field)
|
||||
return dependant
|
||||
|
||||
|
||||
def add_non_field_param_to_dependency(
|
||||
*, param: inspect.Parameter, dependant: Dependant
|
||||
) -> Optional[bool]:
|
||||
if lenient_issubclass(param.annotation, Request):
|
||||
dependant.request_param_name = param.name
|
||||
return True
|
||||
elif lenient_issubclass(param.annotation, WebSocket):
|
||||
dependant.websocket_param_name = param.name
|
||||
return True
|
||||
elif lenient_issubclass(param.annotation, HTTPConnection):
|
||||
dependant.http_connection_param_name = param.name
|
||||
return True
|
||||
elif lenient_issubclass(param.annotation, Response):
|
||||
dependant.response_param_name = param.name
|
||||
return True
|
||||
elif lenient_issubclass(param.annotation, BackgroundTasks):
|
||||
dependant.background_tasks_param_name = param.name
|
||||
return True
|
||||
elif lenient_issubclass(param.annotation, SecurityScopes):
|
||||
dependant.security_scopes_param_name = param.name
|
||||
return True
|
||||
return None
|
||||
|
||||
|
||||
def get_param_field(
|
||||
*,
|
||||
param: inspect.Parameter,
|
||||
param_name: str,
|
||||
default_field_info: Type[params.Param] = params.Param,
|
||||
force_type: Optional[params.ParamTypes] = None,
|
||||
ignore_default: bool = False,
|
||||
) -> ModelField:
|
||||
default_value: Any = Undefined
|
||||
had_schema = False
|
||||
if not param.default == param.empty and ignore_default is False:
|
||||
default_value = param.default
|
||||
if isinstance(default_value, FieldInfo):
|
||||
had_schema = True
|
||||
field_info = default_value
|
||||
default_value = field_info.default
|
||||
if (
|
||||
isinstance(field_info, params.Param)
|
||||
and getattr(field_info, "in_", None) is None
|
||||
):
|
||||
field_info.in_ = default_field_info.in_
|
||||
if force_type:
|
||||
field_info.in_ = force_type # type: ignore
|
||||
else:
|
||||
field_info = default_field_info(default=default_value)
|
||||
required = True
|
||||
if default_value is Required or ignore_default:
|
||||
required = True
|
||||
default_value = None
|
||||
elif default_value is not Undefined:
|
||||
required = False
|
||||
annotation: Any = Any
|
||||
if not param.annotation == param.empty:
|
||||
annotation = param.annotation
|
||||
annotation = get_annotation_from_field_info(annotation, field_info, param_name)
|
||||
if not field_info.alias and getattr(field_info, "convert_underscores", None):
|
||||
alias = param.name.replace("_", "-")
|
||||
else:
|
||||
alias = field_info.alias or param.name
|
||||
field = create_response_field(
|
||||
name=param.name,
|
||||
type_=annotation,
|
||||
default=default_value,
|
||||
alias=alias,
|
||||
required=required,
|
||||
field_info=field_info,
|
||||
)
|
||||
if not had_schema and not is_scalar_field(field=field):
|
||||
field.field_info = params.Body(field_info.default)
|
||||
if not had_schema and lenient_issubclass(field.type_, UploadFile):
|
||||
field.field_info = params.File(field_info.default)
|
||||
|
||||
return field
|
||||
|
||||
|
||||
def add_param_to_fields(*, field: ModelField, dependant: Dependant) -> None:
|
||||
field_info = cast(params.Param, field.field_info)
|
||||
if field_info.in_ == params.ParamTypes.path:
|
||||
dependant.path_params.append(field)
|
||||
elif field_info.in_ == params.ParamTypes.query:
|
||||
dependant.query_params.append(field)
|
||||
elif field_info.in_ == params.ParamTypes.header:
|
||||
dependant.header_params.append(field)
|
||||
else:
|
||||
assert (
|
||||
field_info.in_ == params.ParamTypes.cookie
|
||||
), f"non-body parameters must be in path, query, header or cookie: {field.name}"
|
||||
dependant.cookie_params.append(field)
|
||||
|
||||
|
||||
def is_coroutine_callable(call: Callable[..., Any]) -> bool:
|
||||
if inspect.isroutine(call):
|
||||
return inspect.iscoroutinefunction(call)
|
||||
if inspect.isclass(call):
|
||||
return False
|
||||
call = getattr(call, "__call__", None)
|
||||
return inspect.iscoroutinefunction(call)
|
||||
|
||||
|
||||
def is_async_gen_callable(call: Callable[..., Any]) -> bool:
|
||||
if inspect.isasyncgenfunction(call):
|
||||
return True
|
||||
call = getattr(call, "__call__", None)
|
||||
return inspect.isasyncgenfunction(call)
|
||||
|
||||
|
||||
def is_gen_callable(call: Callable[..., Any]) -> bool:
|
||||
if inspect.isgeneratorfunction(call):
|
||||
return True
|
||||
call = getattr(call, "__call__", None)
|
||||
return inspect.isgeneratorfunction(call)
|
||||
|
||||
|
||||
async def solve_generator(
|
||||
*, call: Callable[..., Any], stack: AsyncExitStack, sub_values: Dict[str, Any]
|
||||
) -> Any:
|
||||
if is_gen_callable(call):
|
||||
cm = contextmanager_in_threadpool(contextmanager(call)(**sub_values))
|
||||
elif is_async_gen_callable(call):
|
||||
cm = asynccontextmanager(call)(**sub_values)
|
||||
return await stack.enter_async_context(cm)
|
||||
|
||||
|
||||
async def solve_dependencies(
|
||||
*,
|
||||
request: Union[Request, WebSocket],
|
||||
dependant: Dependant,
|
||||
body: Optional[Union[Dict[str, Any], FormData]] = None,
|
||||
background_tasks: Optional[BackgroundTasks] = None,
|
||||
response: Optional[Response] = None,
|
||||
dependency_overrides_provider: Optional[Any] = None,
|
||||
dependency_cache: Optional[Dict[Tuple[Callable[..., Any], Tuple[str]], Any]] = None,
|
||||
) -> Tuple[
|
||||
Dict[str, Any],
|
||||
List[ErrorWrapper],
|
||||
Optional[BackgroundTasks],
|
||||
Response,
|
||||
Dict[Tuple[Callable[..., Any], Tuple[str]], Any],
|
||||
]:
|
||||
values: Dict[str, Any] = {}
|
||||
errors: List[ErrorWrapper] = []
|
||||
if response is None:
|
||||
response = Response()
|
||||
del response.headers["content-length"]
|
||||
response.status_code = None # type: ignore
|
||||
dependency_cache = dependency_cache or {}
|
||||
sub_dependant: Dependant
|
||||
for sub_dependant in dependant.dependencies:
|
||||
sub_dependant.call = cast(Callable[..., Any], sub_dependant.call)
|
||||
sub_dependant.cache_key = cast(
|
||||
Tuple[Callable[..., Any], Tuple[str]], sub_dependant.cache_key
|
||||
)
|
||||
call = sub_dependant.call
|
||||
use_sub_dependant = sub_dependant
|
||||
if (
|
||||
dependency_overrides_provider
|
||||
and dependency_overrides_provider.dependency_overrides
|
||||
):
|
||||
original_call = sub_dependant.call
|
||||
call = getattr(
|
||||
dependency_overrides_provider, "dependency_overrides", {}
|
||||
).get(original_call, original_call)
|
||||
use_path: str = sub_dependant.path # type: ignore
|
||||
use_sub_dependant = get_dependant(
|
||||
path=use_path,
|
||||
call=call,
|
||||
name=sub_dependant.name,
|
||||
security_scopes=sub_dependant.security_scopes,
|
||||
)
|
||||
|
||||
solved_result = await solve_dependencies(
|
||||
request=request,
|
||||
dependant=use_sub_dependant,
|
||||
body=body,
|
||||
background_tasks=background_tasks,
|
||||
response=response,
|
||||
dependency_overrides_provider=dependency_overrides_provider,
|
||||
dependency_cache=dependency_cache,
|
||||
)
|
||||
(
|
||||
sub_values,
|
||||
sub_errors,
|
||||
background_tasks,
|
||||
_, # the subdependency returns the same response we have
|
||||
sub_dependency_cache,
|
||||
) = solved_result
|
||||
dependency_cache.update(sub_dependency_cache)
|
||||
if sub_errors:
|
||||
errors.extend(sub_errors)
|
||||
continue
|
||||
if sub_dependant.use_cache and sub_dependant.cache_key in dependency_cache:
|
||||
solved = dependency_cache[sub_dependant.cache_key]
|
||||
elif is_gen_callable(call) or is_async_gen_callable(call):
|
||||
stack = request.scope.get("fastapi_astack")
|
||||
assert isinstance(stack, AsyncExitStack)
|
||||
solved = await solve_generator(
|
||||
call=call, stack=stack, sub_values=sub_values
|
||||
)
|
||||
elif is_coroutine_callable(call):
|
||||
solved = await call(**sub_values)
|
||||
else:
|
||||
solved = await run_in_threadpool(call, **sub_values)
|
||||
if sub_dependant.name is not None:
|
||||
values[sub_dependant.name] = solved
|
||||
if sub_dependant.cache_key not in dependency_cache:
|
||||
dependency_cache[sub_dependant.cache_key] = solved
|
||||
path_values, path_errors = request_params_to_args(
|
||||
dependant.path_params, request.path_params
|
||||
)
|
||||
query_values, query_errors = request_params_to_args(
|
||||
dependant.query_params, request.query_params
|
||||
)
|
||||
header_values, header_errors = request_params_to_args(
|
||||
dependant.header_params, request.headers
|
||||
)
|
||||
cookie_values, cookie_errors = request_params_to_args(
|
||||
dependant.cookie_params, request.cookies
|
||||
)
|
||||
values.update(path_values)
|
||||
values.update(query_values)
|
||||
values.update(header_values)
|
||||
values.update(cookie_values)
|
||||
errors += path_errors + query_errors + header_errors + cookie_errors
|
||||
if dependant.body_params:
|
||||
(
|
||||
body_values,
|
||||
body_errors,
|
||||
) = await request_body_to_args( # body_params checked above
|
||||
required_params=dependant.body_params, received_body=body
|
||||
)
|
||||
values.update(body_values)
|
||||
errors.extend(body_errors)
|
||||
if dependant.http_connection_param_name:
|
||||
values[dependant.http_connection_param_name] = request
|
||||
if dependant.request_param_name and isinstance(request, Request):
|
||||
values[dependant.request_param_name] = request
|
||||
elif dependant.websocket_param_name and isinstance(request, WebSocket):
|
||||
values[dependant.websocket_param_name] = request
|
||||
if dependant.background_tasks_param_name:
|
||||
if background_tasks is None:
|
||||
background_tasks = BackgroundTasks()
|
||||
values[dependant.background_tasks_param_name] = background_tasks
|
||||
if dependant.response_param_name:
|
||||
values[dependant.response_param_name] = response
|
||||
if dependant.security_scopes_param_name:
|
||||
values[dependant.security_scopes_param_name] = SecurityScopes(
|
||||
scopes=dependant.security_scopes
|
||||
)
|
||||
return values, errors, background_tasks, response, dependency_cache
|
||||
|
||||
|
||||
def request_params_to_args(
|
||||
required_params: Sequence[ModelField],
|
||||
received_params: Union[Mapping[str, Any], QueryParams, Headers],
|
||||
) -> Tuple[Dict[str, Any], List[ErrorWrapper]]:
|
||||
values = {}
|
||||
errors = []
|
||||
for field in required_params:
|
||||
if is_scalar_sequence_field(field) and isinstance(
|
||||
received_params, (QueryParams, Headers)
|
||||
):
|
||||
value = received_params.getlist(field.alias) or field.default
|
||||
else:
|
||||
value = received_params.get(field.alias)
|
||||
field_info = field.field_info
|
||||
assert isinstance(
|
||||
field_info, params.Param
|
||||
), "Params must be subclasses of Param"
|
||||
if value is None:
|
||||
if field.required:
|
||||
errors.append(
|
||||
ErrorWrapper(
|
||||
MissingError(), loc=(field_info.in_.value, field.alias)
|
||||
)
|
||||
)
|
||||
else:
|
||||
values[field.name] = deepcopy(field.default)
|
||||
continue
|
||||
v_, errors_ = field.validate(
|
||||
value, values, loc=(field_info.in_.value, field.alias)
|
||||
)
|
||||
if isinstance(errors_, ErrorWrapper):
|
||||
errors.append(errors_)
|
||||
elif isinstance(errors_, list):
|
||||
errors.extend(errors_)
|
||||
else:
|
||||
values[field.name] = v_
|
||||
return values, errors
|
||||
|
||||
|
||||
async def request_body_to_args(
|
||||
required_params: List[ModelField],
|
||||
received_body: Optional[Union[Dict[str, Any], FormData]],
|
||||
) -> Tuple[Dict[str, Any], List[ErrorWrapper]]:
|
||||
values = {}
|
||||
errors = []
|
||||
if required_params:
|
||||
field = required_params[0]
|
||||
field_info = field.field_info
|
||||
embed = getattr(field_info, "embed", None)
|
||||
field_alias_omitted = len(required_params) == 1 and not embed
|
||||
if field_alias_omitted:
|
||||
received_body = {field.alias: received_body}
|
||||
|
||||
for field in required_params:
|
||||
loc: Tuple[str, ...]
|
||||
if field_alias_omitted:
|
||||
loc = ("body",)
|
||||
else:
|
||||
loc = ("body", field.alias)
|
||||
|
||||
value: Optional[Any] = None
|
||||
if received_body is not None:
|
||||
if (
|
||||
field.shape in sequence_shapes or field.type_ in sequence_types
|
||||
) and isinstance(received_body, FormData):
|
||||
value = received_body.getlist(field.alias)
|
||||
else:
|
||||
try:
|
||||
value = received_body.get(field.alias)
|
||||
except AttributeError:
|
||||
errors.append(get_missing_field_error(loc))
|
||||
continue
|
||||
if (
|
||||
value is None
|
||||
or (isinstance(field_info, params.Form) and value == "")
|
||||
or (
|
||||
isinstance(field_info, params.Form)
|
||||
and field.shape in sequence_shapes
|
||||
and len(value) == 0
|
||||
)
|
||||
):
|
||||
if field.required:
|
||||
errors.append(get_missing_field_error(loc))
|
||||
else:
|
||||
values[field.name] = deepcopy(field.default)
|
||||
continue
|
||||
if (
|
||||
isinstance(field_info, params.File)
|
||||
and lenient_issubclass(field.type_, bytes)
|
||||
and isinstance(value, UploadFile)
|
||||
):
|
||||
value = await value.read()
|
||||
elif (
|
||||
field.shape in sequence_shapes
|
||||
and isinstance(field_info, params.File)
|
||||
and lenient_issubclass(field.type_, bytes)
|
||||
and isinstance(value, sequence_types)
|
||||
):
|
||||
results: List[Union[bytes, str]] = []
|
||||
|
||||
async def process_fn(
|
||||
fn: Callable[[], Coroutine[Any, Any, Any]]
|
||||
) -> None:
|
||||
result = await fn()
|
||||
results.append(result)
|
||||
|
||||
async with anyio.create_task_group() as tg:
|
||||
for sub_value in value:
|
||||
tg.start_soon(process_fn, sub_value.read)
|
||||
value = sequence_shape_to_type[field.shape](results)
|
||||
|
||||
v_, errors_ = field.validate(value, values, loc=loc)
|
||||
|
||||
if isinstance(errors_, ErrorWrapper):
|
||||
errors.append(errors_)
|
||||
elif isinstance(errors_, list):
|
||||
errors.extend(errors_)
|
||||
else:
|
||||
values[field.name] = v_
|
||||
return values, errors
|
||||
|
||||
|
||||
def get_missing_field_error(loc: Tuple[str, ...]) -> ErrorWrapper:
|
||||
missing_field_error = ErrorWrapper(MissingError(), loc=loc)
|
||||
return missing_field_error
|
||||
|
||||
|
||||
def get_body_field(*, dependant: Dependant, name: str) -> Optional[ModelField]:
|
||||
flat_dependant = get_flat_dependant(dependant)
|
||||
if not flat_dependant.body_params:
|
||||
return None
|
||||
first_param = flat_dependant.body_params[0]
|
||||
field_info = first_param.field_info
|
||||
embed = getattr(field_info, "embed", None)
|
||||
body_param_names_set = {param.name for param in flat_dependant.body_params}
|
||||
if len(body_param_names_set) == 1 and not embed:
|
||||
check_file_field(first_param)
|
||||
return first_param
|
||||
# If one field requires to embed, all have to be embedded
|
||||
# in case a sub-dependency is evaluated with a single unique body field
|
||||
# That is combined (embedded) with other body fields
|
||||
for param in flat_dependant.body_params:
|
||||
setattr(param.field_info, "embed", True)
|
||||
model_name = "Body_" + name
|
||||
BodyModel: Type[BaseModel] = create_model(model_name)
|
||||
for f in flat_dependant.body_params:
|
||||
BodyModel.__fields__[f.name] = f
|
||||
required = any(True for f in flat_dependant.body_params if f.required)
|
||||
|
||||
BodyFieldInfo_kwargs: Dict[str, Any] = dict(default=None)
|
||||
if any(isinstance(f.field_info, params.File) for f in flat_dependant.body_params):
|
||||
BodyFieldInfo: Type[params.Body] = params.File
|
||||
elif any(isinstance(f.field_info, params.Form) for f in flat_dependant.body_params):
|
||||
BodyFieldInfo = params.Form
|
||||
else:
|
||||
BodyFieldInfo = params.Body
|
||||
|
||||
body_param_media_types = [
|
||||
getattr(f.field_info, "media_type")
|
||||
for f in flat_dependant.body_params
|
||||
if isinstance(f.field_info, params.Body)
|
||||
]
|
||||
if len(set(body_param_media_types)) == 1:
|
||||
BodyFieldInfo_kwargs["media_type"] = body_param_media_types[0]
|
||||
final_field = create_response_field(
|
||||
name="body",
|
||||
type_=BodyModel,
|
||||
required=required,
|
||||
alias="body",
|
||||
field_info=BodyFieldInfo(**BodyFieldInfo_kwargs),
|
||||
)
|
||||
check_file_field(final_field)
|
||||
return final_field
|
@ -0,0 +1,167 @@
|
||||
import dataclasses
|
||||
from collections import defaultdict
|
||||
from enum import Enum
|
||||
from pathlib import PurePath
|
||||
from types import GeneratorType
|
||||
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic.json import ENCODERS_BY_TYPE
|
||||
|
||||
SetIntStr = Set[Union[int, str]]
|
||||
DictIntStrAny = Dict[Union[int, str], Any]
|
||||
|
||||
|
||||
def generate_encoders_by_class_tuples(
|
||||
type_encoder_map: Dict[Any, Callable[[Any], Any]]
|
||||
) -> Dict[Callable[[Any], Any], Tuple[Any, ...]]:
|
||||
encoders_by_class_tuples: Dict[Callable[[Any], Any], Tuple[Any, ...]] = defaultdict(
|
||||
tuple
|
||||
)
|
||||
for type_, encoder in type_encoder_map.items():
|
||||
encoders_by_class_tuples[encoder] += (type_,)
|
||||
return encoders_by_class_tuples
|
||||
|
||||
|
||||
encoders_by_class_tuples = generate_encoders_by_class_tuples(ENCODERS_BY_TYPE)
|
||||
|
||||
|
||||
def jsonable_encoder(
|
||||
obj: Any,
|
||||
include: Optional[Union[SetIntStr, DictIntStrAny]] = None,
|
||||
exclude: Optional[Union[SetIntStr, DictIntStrAny]] = None,
|
||||
by_alias: bool = True,
|
||||
exclude_unset: bool = False,
|
||||
exclude_defaults: bool = False,
|
||||
exclude_none: bool = False,
|
||||
custom_encoder: Optional[Dict[Any, Callable[[Any], Any]]] = None,
|
||||
sqlalchemy_safe: bool = True,
|
||||
) -> Any:
|
||||
custom_encoder = custom_encoder or {}
|
||||
if custom_encoder:
|
||||
if type(obj) in custom_encoder:
|
||||
return custom_encoder[type(obj)](obj)
|
||||
else:
|
||||
for encoder_type, encoder_instance in custom_encoder.items():
|
||||
if isinstance(obj, encoder_type):
|
||||
return encoder_instance(obj)
|
||||
if include is not None and not isinstance(include, (set, dict)):
|
||||
include = set(include)
|
||||
if exclude is not None and not isinstance(exclude, (set, dict)):
|
||||
exclude = set(exclude)
|
||||
if isinstance(obj, BaseModel):
|
||||
encoder = getattr(obj.__config__, "json_encoders", {})
|
||||
if custom_encoder:
|
||||
encoder.update(custom_encoder)
|
||||
obj_dict = obj.dict(
|
||||
include=include, # type: ignore # in Pydantic
|
||||
exclude=exclude, # type: ignore # in Pydantic
|
||||
by_alias=by_alias,
|
||||
exclude_unset=exclude_unset,
|
||||
exclude_none=exclude_none,
|
||||
exclude_defaults=exclude_defaults,
|
||||
)
|
||||
if "__root__" in obj_dict:
|
||||
obj_dict = obj_dict["__root__"]
|
||||
return jsonable_encoder(
|
||||
obj_dict,
|
||||
exclude_none=exclude_none,
|
||||
exclude_defaults=exclude_defaults,
|
||||
custom_encoder=encoder,
|
||||
sqlalchemy_safe=sqlalchemy_safe,
|
||||
)
|
||||
if dataclasses.is_dataclass(obj):
|
||||
obj_dict = dataclasses.asdict(obj)
|
||||
return jsonable_encoder(
|
||||
obj_dict,
|
||||
exclude_none=exclude_none,
|
||||
exclude_defaults=exclude_defaults,
|
||||
custom_encoder=custom_encoder,
|
||||
sqlalchemy_safe=sqlalchemy_safe,
|
||||
)
|
||||
if isinstance(obj, Enum):
|
||||
return obj.value
|
||||
if isinstance(obj, PurePath):
|
||||
return str(obj)
|
||||
if isinstance(obj, (str, int, float, type(None))):
|
||||
return obj
|
||||
if isinstance(obj, dict):
|
||||
encoded_dict = {}
|
||||
allowed_keys = set(obj.keys())
|
||||
if include is not None:
|
||||
allowed_keys &= set(include)
|
||||
if exclude is not None:
|
||||
allowed_keys -= set(exclude)
|
||||
for key, value in obj.items():
|
||||
if (
|
||||
(
|
||||
not sqlalchemy_safe
|
||||
or (not isinstance(key, str))
|
||||
or (not key.startswith("_sa"))
|
||||
)
|
||||
and (value is not None or not exclude_none)
|
||||
and key in allowed_keys
|
||||
):
|
||||
encoded_key = jsonable_encoder(
|
||||
key,
|
||||
by_alias=by_alias,
|
||||
exclude_unset=exclude_unset,
|
||||
exclude_none=exclude_none,
|
||||
custom_encoder=custom_encoder,
|
||||
sqlalchemy_safe=sqlalchemy_safe,
|
||||
)
|
||||
encoded_value = jsonable_encoder(
|
||||
value,
|
||||
by_alias=by_alias,
|
||||
exclude_unset=exclude_unset,
|
||||
exclude_none=exclude_none,
|
||||
custom_encoder=custom_encoder,
|
||||
sqlalchemy_safe=sqlalchemy_safe,
|
||||
)
|
||||
encoded_dict[encoded_key] = encoded_value
|
||||
return encoded_dict
|
||||
if isinstance(obj, (list, set, frozenset, GeneratorType, tuple)):
|
||||
encoded_list = []
|
||||
for item in obj:
|
||||
encoded_list.append(
|
||||
jsonable_encoder(
|
||||
item,
|
||||
include=include,
|
||||
exclude=exclude,
|
||||
by_alias=by_alias,
|
||||
exclude_unset=exclude_unset,
|
||||
exclude_defaults=exclude_defaults,
|
||||
exclude_none=exclude_none,
|
||||
custom_encoder=custom_encoder,
|
||||
sqlalchemy_safe=sqlalchemy_safe,
|
||||
)
|
||||
)
|
||||
return encoded_list
|
||||
|
||||
if type(obj) in ENCODERS_BY_TYPE:
|
||||
return ENCODERS_BY_TYPE[type(obj)](obj)
|
||||
for encoder, classes_tuple in encoders_by_class_tuples.items():
|
||||
if isinstance(obj, classes_tuple):
|
||||
return encoder(obj)
|
||||
|
||||
try:
|
||||
data = dict(obj)
|
||||
except Exception as e:
|
||||
errors: List[Exception] = []
|
||||
errors.append(e)
|
||||
try:
|
||||
data = vars(obj)
|
||||
except Exception as e:
|
||||
errors.append(e)
|
||||
raise ValueError(errors)
|
||||
return jsonable_encoder(
|
||||
data,
|
||||
include=include,
|
||||
exclude=exclude,
|
||||
by_alias=by_alias,
|
||||
exclude_unset=exclude_unset,
|
||||
exclude_defaults=exclude_defaults,
|
||||
exclude_none=exclude_none,
|
||||
custom_encoder=custom_encoder,
|
||||
sqlalchemy_safe=sqlalchemy_safe,
|
||||
)
|
@ -0,0 +1,25 @@
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from starlette.exceptions import HTTPException
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import JSONResponse
|
||||
from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY
|
||||
|
||||
|
||||
async def http_exception_handler(request: Request, exc: HTTPException) -> JSONResponse:
|
||||
headers = getattr(exc, "headers", None)
|
||||
if headers:
|
||||
return JSONResponse(
|
||||
{"detail": exc.detail}, status_code=exc.status_code, headers=headers
|
||||
)
|
||||
else:
|
||||
return JSONResponse({"detail": exc.detail}, status_code=exc.status_code)
|
||||
|
||||
|
||||
async def request_validation_exception_handler(
|
||||
request: Request, exc: RequestValidationError
|
||||
) -> JSONResponse:
|
||||
return JSONResponse(
|
||||
status_code=HTTP_422_UNPROCESSABLE_ENTITY,
|
||||
content={"detail": jsonable_encoder(exc.errors())},
|
||||
)
|
@ -0,0 +1,36 @@
|
||||
from typing import Any, Dict, Optional, Sequence, Type
|
||||
|
||||
from pydantic import BaseModel, ValidationError, create_model
|
||||
from pydantic.error_wrappers import ErrorList
|
||||
from starlette.exceptions import HTTPException as StarletteHTTPException
|
||||
|
||||
|
||||
class HTTPException(StarletteHTTPException):
|
||||
def __init__(
|
||||
self,
|
||||
status_code: int,
|
||||
detail: Any = None,
|
||||
headers: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
super().__init__(status_code=status_code, detail=detail, headers=headers)
|
||||
|
||||
|
||||
RequestErrorModel: Type[BaseModel] = create_model("Request")
|
||||
WebSocketErrorModel: Type[BaseModel] = create_model("WebSocket")
|
||||
|
||||
|
||||
class FastAPIError(RuntimeError):
|
||||
"""
|
||||
A generic, FastAPI-specific error.
|
||||
"""
|
||||
|
||||
|
||||
class RequestValidationError(ValidationError):
|
||||
def __init__(self, errors: Sequence[ErrorList], *, body: Any = None) -> None:
|
||||
self.body = body
|
||||
super().__init__(errors, RequestErrorModel)
|
||||
|
||||
|
||||
class WebSocketRequestValidationError(ValidationError):
|
||||
def __init__(self, errors: Sequence[ErrorList]) -> None:
|
||||
super().__init__(errors, WebSocketErrorModel)
|
@ -0,0 +1,3 @@
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger("fastapi")
|
@ -0,0 +1 @@
|
||||
from starlette.middleware import Middleware as Middleware
|
@ -0,0 +1,28 @@
|
||||
from typing import Optional
|
||||
|
||||
from fastapi.concurrency import AsyncExitStack
|
||||
from starlette.types import ASGIApp, Receive, Scope, Send
|
||||
|
||||
|
||||
class AsyncExitStackMiddleware:
|
||||
def __init__(self, app: ASGIApp, context_name: str = "fastapi_astack") -> None:
|
||||
self.app = app
|
||||
self.context_name = context_name
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
if AsyncExitStack:
|
||||
dependency_exception: Optional[Exception] = None
|
||||
async with AsyncExitStack() as stack:
|
||||
scope[self.context_name] = stack
|
||||
try:
|
||||
await self.app(scope, receive, send)
|
||||
except Exception as e:
|
||||
dependency_exception = e
|
||||
raise e
|
||||
if dependency_exception:
|
||||
# This exception was possibly handled by the dependency but it should
|
||||
# still bubble up so that the ServerErrorMiddleware can return a 500
|
||||
# or the ExceptionMiddleware can catch and handle any other exceptions
|
||||
raise dependency_exception
|
||||
else:
|
||||
await self.app(scope, receive, send) # pragma: no cover
|
@ -0,0 +1 @@
|
||||
from starlette.middleware.cors import CORSMiddleware as CORSMiddleware # noqa
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in new issue