You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
113 lines
3.3 KiB
113 lines
3.3 KiB
import datetime
|
|
from collections import deque
|
|
from decimal import Decimal
|
|
from enum import Enum
|
|
from ipaddress import IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6Interface, IPv6Network
|
|
from pathlib import Path
|
|
from re import Pattern
|
|
from types import GeneratorType
|
|
from typing import Any, Callable, Dict, Type, Union
|
|
from uuid import UUID
|
|
|
|
from .color import Color
|
|
from .networks import NameEmail
|
|
from .types import SecretBytes, SecretStr
|
|
|
|
# Public API of this module (the value is a tuple of names).
__all__ = (
    'pydantic_encoder',
    'custom_pydantic_encoder',
    'timedelta_isoformat',
)
|
|
|
|
|
|
def isoformat(o: Union[datetime.date, datetime.time]) -> str:
    """Return the ISO 8601 text form of a date, datetime or time via its own ``isoformat()``."""
    render = o.isoformat
    return render()
|
|
|
|
|
|
def decimal_encoder(dec_value: Decimal) -> Union[int, float]:
    """
    Encodes a Decimal as int if its exponent is non-negative (no fractional
    digits), otherwise as float.

    This is useful when we use ConstrainedDecimal to represent Numeric(x,0)
    where an integer (but not int typed) is used. Encoding this as a float
    results in failed round-tripping between encode and parse.
    Our Id type is a prime example of this.

    Special values (NaN, Infinity) have no integer exponent and are encoded
    as float.

    >>> decimal_encoder(Decimal("1.0"))
    1.0

    >>> decimal_encoder(Decimal("1"))
    1

    >>> decimal_encoder(Decimal("Infinity"))
    inf
    """
    exponent = dec_value.as_tuple().exponent
    # For NaN / sNaN / Infinity, ``as_tuple().exponent`` is the string 'n',
    # 'N' or 'F'. Comparing it with 0 would raise TypeError, so only a
    # genuine integer exponent may take the int path.
    if isinstance(exponent, int) and exponent >= 0:
        return int(dec_value)
    return float(dec_value)
|
|
|
|
|
|
# Type -> converter producing a value the stdlib json encoder can emit.
# pydantic_encoder looks entries up along an object's MRO, so an entry for a
# base class (e.g. Enum) also covers its subclasses.
ENCODERS_BY_TYPE: Dict[Type[Any], Callable[[Any], Any]] = {
    # raw bytes are decoded to text with the default codec
    bytes: lambda raw: raw.decode(),
    Color: str,
    # temporal values
    datetime.date: isoformat,
    datetime.datetime: isoformat,
    datetime.time: isoformat,
    datetime.timedelta: lambda delta: delta.total_seconds(),
    Decimal: decimal_encoder,
    Enum: lambda member: member.value,
    # containers json cannot emit directly are materialized as lists
    frozenset: list,
    deque: list,
    GeneratorType: list,
    # address / network objects are rendered with str()
    IPv4Address: str,
    IPv4Interface: str,
    IPv4Network: str,
    IPv6Address: str,
    IPv6Interface: str,
    IPv6Network: str,
    NameEmail: str,
    Path: str,
    # compiled regular expressions are emitted as their source pattern
    Pattern: lambda compiled: compiled.pattern,
    SecretBytes: str,
    SecretStr: str,
    set: list,
    UUID: str,
}
|
|
|
|
|
|
def pydantic_encoder(obj: Any) -> Any:
    """
    Convert *obj* into a value that can be JSON serialized.

    BaseModel instances go through ``.dict()`` and dataclasses through
    ``dataclasses.asdict``. Anything else is handled by the first
    ``ENCODERS_BY_TYPE`` entry found along the object's MRO, most specific
    class first, with the trailing ``object`` skipped.

    Raises:
        TypeError: if no encoder is registered for any class in the MRO.
    """
    # Imported inside the function; presumably to avoid a circular import
    # with .main -- NOTE(review): confirm.
    from dataclasses import asdict, is_dataclass

    from .main import BaseModel

    if isinstance(obj, BaseModel):
        return obj.dict()
    if is_dataclass(obj):
        return asdict(obj)

    for klass in obj.__class__.__mro__[:-1]:
        if klass in ENCODERS_BY_TYPE:
            return ENCODERS_BY_TYPE[klass](obj)

    # The loop finished without finding a registered encoder.
    raise TypeError(f"Object of type '{obj.__class__.__name__}' is not JSON serializable")
|
|
|
|
|
|
def custom_pydantic_encoder(type_encoders: Dict[Any, Callable[[Any], Any]], obj: Any) -> Any:
    """
    Encode *obj* with a user-supplied encoder, falling back to ``pydantic_encoder``.

    The object's MRO is searched most specific class first, with the trailing
    ``object`` excluded. An encoder registered for a base class therefore also
    applies to instances of its subclasses.

    Args:
        type_encoders: mapping of type -> encoder. Each encoder is called with
            the object itself (an instance, not its type).
        obj: the value to encode.

    Returns:
        The matching encoder's result, or ``pydantic_encoder(obj)`` when no
        entry in ``type_encoders`` matches.
    """
    for base in obj.__class__.__mro__[:-1]:
        try:
            encoder = type_encoders[base]
        except KeyError:
            continue
        return encoder(obj)

    # No user-supplied encoder matched: defer to the built-in encoders.
    return pydantic_encoder(obj)
|
|
|
|
|
|
def timedelta_isoformat(td: datetime.timedelta) -> str:
    """
    ISO 8601 encoding for Python timedelta object.

    The result has the form ``[-]P<days>DT<h>H<m>M<s>.<ffffff>S``. A negative
    duration is encoded as a leading ``-`` followed by its magnitude, e.g.
    ``timedelta(seconds=-1)`` -> ``'-P0DT0H0M1.000000S'``.
    """
    # timedelta stores negative durations as days < 0 with non-negative
    # seconds/microseconds (e.g. -1s is days=-1, seconds=86399). Formatting
    # those fields directly would mix a negative day count with positive
    # seconds, so format the absolute value and prepend the sign instead.
    sign = '-' if td.days < 0 else ''
    magnitude = abs(td)
    minutes, seconds = divmod(magnitude.seconds, 60)
    hours, minutes = divmod(minutes, 60)
    return f'{sign}P{magnitude.days}DT{hours:d}H{minutes:d}M{seconds:d}.{magnitude.microseconds:06d}S'
|