Edit on GitHub

serde.compat

Compatibility layer that mostly handles differences in the `typing` module between Python versions.

   1"""
   2Compatibility layer which handles mostly differences of `typing` module between python versions.
   3"""
   4
   5import dataclasses
   6import datetime
   7import decimal
   8import enum
   9import functools
  10import ipaddress
  11import itertools
  12import pathlib
  13import types
  14import uuid
  15import typing
  16import typing_extensions
  17from collections import defaultdict, deque, Counter
  18from collections.abc import Iterator, Sequence, MutableSequence
  19from collections.abc import Mapping, MutableMapping, Set, MutableSet
  20from dataclasses import is_dataclass
  21from typing import TypeVar, Generic, Any, ClassVar, Optional, NewType, Union, Hashable, Callable
  22
  23import typing_inspect
  24from typing_extensions import TypeGuard, ParamSpec
  25
  26# `typing_extensions.TypeAliasType` isn't always an alias to `typing.TypeAliasType`
  27# depending on certain versions of `typing_extensions` and python.
  28_PEP695_TYPES: tuple[type, ...]
  29if hasattr(typing, "TypeAliasType"):
  30    _PEP695_TYPES = (typing_extensions.TypeAliasType, typing.TypeAliasType)
  31else:
  32    _PEP695_TYPES = (typing_extensions.TypeAliasType,)
  33
  34
  35# Lazy SQLAlchemy imports to improve startup time
  36
  37
  38# Lazy SQLAlchemy import wrapper to improve startup time
  39def _is_sqlalchemy_inspectable(subject: Any) -> bool:
  40    from .sqlalchemy import is_sqlalchemy_inspectable
  41
  42    return is_sqlalchemy_inspectable(subject)
  43
  44
  45def get_np_origin(tp: type[Any]) -> Any | None:
  46    return None
  47
  48
  49def get_np_args(tp: type[Any]) -> tuple[Any, ...]:
  50    return ()
  51
  52
  53__all__: list[str] = []
  54
  55T = TypeVar("T")
  56
  57
  58StrSerializableTypes = (
  59    decimal.Decimal,
  60    pathlib.Path,
  61    pathlib.PosixPath,
  62    pathlib.WindowsPath,
  63    pathlib.PurePath,
  64    pathlib.PurePosixPath,
  65    pathlib.PureWindowsPath,
  66    uuid.UUID,
  67    ipaddress.IPv4Address,
  68    ipaddress.IPv6Address,
  69    ipaddress.IPv4Network,
  70    ipaddress.IPv6Network,
  71    ipaddress.IPv4Interface,
  72    ipaddress.IPv6Interface,
  73)
  74""" List of standard types (de)serializable to str """
  75
  76DateTimeTypes = (datetime.date, datetime.time, datetime.datetime)
  77""" List of datetime types """
  78
  79
  80@dataclasses.dataclass(unsafe_hash=True)
  81class _WithTagging(Generic[T]):
  82    """
  83    Intermediate data structure for (de)serializaing Union without dataclass.
  84    """
  85
  86    inner: T
  87    """ Union type .e.g Union[Foo,Bar] passed in from_obj. """
  88    tagging: Any
  89    """ Union Tagging """
  90
  91
  92class SerdeError(Exception):
  93    """
  94    Serde error class.
  95    """
  96
  97
  98@dataclasses.dataclass
  99class UserError(Exception):
 100    """
 101    Error from user code e.g. __post_init__.
 102    """
 103
 104    inner: Exception
 105
 106
 107class SerdeSkip(Exception):
 108    """
 109    Skip a field in custom (de)serializer.
 110    """
 111
 112
 113def is_hashable(typ: Any) -> TypeGuard[Hashable]:
 114    """
 115    Test is an object hashable
 116    """
 117    try:
 118        hash(typ)
 119    except TypeError:
 120        return False
 121    return True
 122
 123
 124P = ParamSpec("P")
 125
 126
 127def cache(f: Callable[P, T]) -> Callable[P, T]:
 128    """
 129    Wrapper for `functools.cache` to avoid `Hashable` related type errors.
 130    """
 131
 132    def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
 133        return f(*args, **kwargs)
 134
 135    return functools.cache(wrapper)  # type: ignore
 136
 137
 138@cache
 139def get_origin(typ: Any) -> Any | None:
 140    """
 141    Provide `get_origin` that works in all python versions.
 142    """
 143    return typing.get_origin(typ) or get_np_origin(typ)
 144
 145
 146@cache
 147def get_args(typ: type[Any]) -> tuple[Any, ...]:
 148    """
 149    Provide `get_args` that works in all python versions.
 150    """
 151    return typing.get_args(typ) or get_np_args(typ)
 152
 153
 154@cache
 155def typename(typ: Any, with_typing_module: bool = False) -> str:
 156    """
 157    >>> from typing import Any
 158    >>> typename(int)
 159    'int'
 160    >>> class Foo: pass
 161    >>> typename(Foo)
 162    'Foo'
 163    >>> typename(list[Foo])
 164    'list[Foo]'
 165    >>> typename(dict[str, Foo])
 166    'dict[str, Foo]'
 167    >>> typename(tuple[int, str, Foo, list[int], dict[str, Foo]])
 168    'tuple[int, str, Foo, list[int], dict[str, Foo]]'
 169    >>> typename(Optional[list[Foo]])
 170    'Optional[list[Foo]]'
 171    >>> typename(Union[Optional[Foo], list[Foo], Union[str, int]])
 172    'Union[Optional[Foo], list[Foo], str, int]'
 173    >>> typename(set[Foo])
 174    'set[Foo]'
 175    >>> typename(Any)
 176    'Any'
 177    """
 178    mod = "typing." if with_typing_module else ""
 179    thisfunc = functools.partial(typename, with_typing_module=with_typing_module)
 180    if is_opt(typ):
 181        args = type_args(typ)
 182        if args:
 183            return f"{mod}Optional[{thisfunc(type_args(typ)[0])}]"
 184        else:
 185            return f"{mod}Optional"
 186    elif is_union(typ):
 187        args = union_args(typ)
 188        if args:
 189            return f'{mod}Union[{", ".join([thisfunc(e) for e in args])}]'
 190        else:
 191            return f"{mod}Union"
 192    elif is_list(typ):
 193        args = type_args(typ)
 194        if args:
 195            et = thisfunc(args[0])
 196            return f"{mod}list[{et}]"
 197        else:
 198            return f"{mod}list"
 199    elif is_set(typ):
 200        args = type_args(typ)
 201        if args:
 202            et = thisfunc(args[0])
 203            return f"{mod}set[{et}]"
 204        else:
 205            return f"{mod}set"
 206    elif is_dict(typ):
 207        args = type_args(typ)
 208        if args and len(args) == 2:
 209            kt = thisfunc(args[0])
 210            vt = thisfunc(args[1])
 211            return f"{mod}dict[{kt}, {vt}]"
 212        else:
 213            return f"{mod}dict"
 214    elif is_deque(typ):
 215        args = type_args(typ)
 216        if args:
 217            et = thisfunc(args[0])
 218            return f"deque[{et}]"
 219        else:
 220            return "deque"
 221    elif is_counter(typ):
 222        args = type_args(typ)
 223        if args:
 224            et = thisfunc(args[0])
 225            return f"Counter[{et}]"
 226        else:
 227            return "Counter"
 228    elif is_tuple(typ):
 229        args = type_args(typ)
 230        if args:
 231            return f'{mod}tuple[{", ".join([thisfunc(e) for e in args])}]'
 232        else:
 233            return f"{mod}tuple"
 234    elif is_generic(typ):
 235        origin = get_origin(typ)
 236        if origin is None:
 237            raise SerdeError("Could not extract origin class from generic class")
 238
 239        if not isinstance(origin.__name__, str):
 240            raise SerdeError("Name of generic class is not string")
 241
 242        return origin.__name__
 243
 244    elif is_literal(typ):
 245        args = type_args(typ)
 246        if not args:
 247            raise TypeError("Literal type requires at least one literal argument")
 248        return f'Literal[{", ".join(stringify_literal(e) for e in args)}]'
 249    elif typ is Any:
 250        return f"{mod}Any"
 251    elif is_ellipsis(typ):
 252        return "..."
 253    else:
 254        # Get super type for NewType
 255        inner = getattr(typ, "__supertype__", None)
 256        if inner:
 257            return typename(inner)
 258
 259        name: str | None = getattr(typ, "_name", None)
 260        if name:
 261            return name
 262        else:
 263            name = getattr(typ, "__name__", None)
 264            if isinstance(name, str):
 265                return name
 266            else:
 267                raise SerdeError(f"Could not get a type name from: {typ}")
 268
 269
 270def stringify_literal(v: Any) -> str:
 271    if isinstance(v, str):
 272        return f"'{v}'"
 273    else:
 274        return str(v)
 275
 276
 277def type_args(typ: Any) -> tuple[type[Any], ...]:
 278    """
 279    Wrapper to suppress type error for accessing private members.
 280    """
 281    try:
 282        args: tuple[type[Any, ...]] | None = typ.__args__  # type: ignore
 283    except AttributeError:
 284        return get_args(typ)
 285
 286    # Some typing objects expose __args__ as a member_descriptor (e.g. typing.Union),
 287    # which isn't iterable. Fall back to typing.get_args in that case.
 288    if isinstance(args, tuple):
 289        return args
 290    return get_args(typ)
 291
 292
 293def union_args(typ: Any) -> tuple[type[Any], ...]:
 294    if not is_union(typ):
 295        raise TypeError(f"{typ} is not Union")
 296    args = type_args(typ)
 297    if len(args) == 1:
 298        return (args[0],)
 299    it = iter(args)
 300    types = []
 301    for i1, i2 in itertools.zip_longest(it, it):
 302        if not i2:
 303            types.append(i1)
 304        elif is_none(i2):
 305            types.append(Optional[i1])
 306        else:
 307            types.extend((i1, i2))
 308    return tuple(types)
 309
 310
 311def dataclass_fields(cls: type[Any]) -> Iterator[dataclasses.Field]:  # type: ignore
 312    raw_fields = dataclasses.fields(cls)
 313
 314    try:
 315        # this resolves types when string forward reference
 316        # or PEP 563: "from __future__ import annotations" are used
 317        resolved_hints = typing.get_type_hints(cls)
 318    except Exception as e:
 319        raise SerdeError(
 320            f"Failed to resolve type hints for {typename(cls)}:\n"
 321            f"{e.__class__.__name__}: {e}\n\n"
 322            f"If you are using forward references make sure you are calling deserialize & "
 323            "serialize after all classes are globally visible."
 324        ) from e
 325
 326    for f in raw_fields:
 327        real_type = resolved_hints.get(f.name)
 328        if real_type is not None:
 329            f.type = real_type
 330            if is_generic(real_type) and _is_sqlalchemy_inspectable(cls):
 331                f.type = get_args(real_type)[0]
 332
 333    return iter(raw_fields)
 334
 335
 336TypeLike = Union[type[Any], typing.Any]
 337
 338
 339def iter_types(cls: type[Any]) -> list[type[Any]]:
 340    """
 341    Iterate field types recursively.
 342
 343    The correct return type is `Iterator[Union[Type, typing._specialform]],
 344    but `typing._specialform` doesn't exist for python 3.6. Use `Any` instead.
 345    """
 346    lst: set[Union[type[Any], Any]] = set()
 347
 348    def recursive(cls: Union[type[Any], Any]) -> None:
 349        if cls in lst:
 350            return
 351
 352        if is_dataclass(cls):
 353            lst.add(cls)
 354            if isinstance(cls, type):
 355                for f in dataclass_fields(cls):
 356                    recursive(f.type)
 357        elif isinstance(cls, str):
 358            lst.add(cls)
 359        elif is_opt(cls):
 360            lst.add(Optional)
 361            args = type_args(cls)
 362            if args:
 363                recursive(args[0])
 364        elif is_union(cls):
 365            lst.add(Union)
 366            for arg in type_args(cls):
 367                recursive(arg)
 368        elif is_list(cls):
 369            lst.add(list)
 370            args = type_args(cls)
 371            if args:
 372                recursive(args[0])
 373        elif is_set(cls):
 374            lst.add(set)
 375            args = type_args(cls)
 376            if args:
 377                recursive(args[0])
 378        elif is_deque(cls):
 379            lst.add(deque)
 380            args = type_args(cls)
 381            if args:
 382                recursive(args[0])
 383        elif is_counter(cls):
 384            lst.add(Counter)
 385            args = type_args(cls)
 386            if args:
 387                recursive(args[0])
 388        elif is_tuple(cls):
 389            lst.add(tuple)
 390            for arg in type_args(cls):
 391                recursive(arg)
 392        elif is_dict(cls):
 393            lst.add(dict)
 394            args = type_args(cls)
 395            if args and len(args) >= 2:
 396                recursive(args[0])
 397                recursive(args[1])
 398        elif is_pep695_type_alias(cls):
 399            recursive(cls.__value__)
 400        else:
 401            lst.add(cls)
 402
 403    recursive(cls)
 404    return list(lst)
 405
 406
 407def iter_unions(cls: TypeLike) -> list[TypeLike]:
 408    """
 409    Iterate over all unions that are used in the dataclass
 410    """
 411    lst: list[TypeLike] = []
 412    stack: list[TypeLike] = []  # To prevent infinite recursion
 413
 414    def recursive(cls: TypeLike) -> None:
 415        if cls in stack:
 416            return
 417
 418        if is_union(cls):
 419            lst.append(cls)
 420            for arg in type_args(cls):
 421                recursive(arg)
 422        elif is_pep695_type_alias(cls):
 423            recursive(cls.__value__)
 424        if is_dataclass(cls):
 425            stack.append(cls)
 426            if isinstance(cls, type):
 427                for f in dataclass_fields(cls):
 428                    recursive(f.type)
 429            stack.pop()
 430        elif is_opt(cls):
 431            args = type_args(cls)
 432            if args:
 433                recursive(args[0])
 434        elif is_list(cls) or is_set(cls) or is_deque(cls) or is_counter(cls):
 435            args = type_args(cls)
 436            if args:
 437                recursive(args[0])
 438        elif is_tuple(cls):
 439            for arg in type_args(cls):
 440                recursive(arg)
 441        elif is_dict(cls):
 442            args = type_args(cls)
 443            if args and len(args) >= 2:
 444                recursive(args[0])
 445                recursive(args[1])
 446
 447    recursive(cls)
 448    return lst
 449
 450
 451def iter_literals(cls: type[Any]) -> list[TypeLike]:
 452    """
 453    Iterate over all literals that are used in the dataclass
 454    """
 455    lst: set[Union[type[Any], Any]] = set()
 456    stack: list[TypeLike] = []  # To prevent infinite recursion
 457
 458    def recursive(cls: Union[type[Any], Any]) -> None:
 459        if cls in stack:
 460            return
 461
 462        if is_literal(cls):
 463            lst.add(cls)
 464        if is_union(cls):
 465            for arg in type_args(cls):
 466                recursive(arg)
 467        if is_dataclass(cls):
 468            stack.append(cls)
 469            if isinstance(cls, type):
 470                for f in dataclass_fields(cls):
 471                    recursive(f.type)
 472            stack.pop()
 473        elif is_opt(cls):
 474            args = type_args(cls)
 475            if args:
 476                recursive(args[0])
 477        elif is_list(cls) or is_set(cls) or is_deque(cls) or is_counter(cls):
 478            args = type_args(cls)
 479            if args:
 480                recursive(args[0])
 481        elif is_tuple(cls):
 482            for arg in type_args(cls):
 483                recursive(arg)
 484        elif is_dict(cls):
 485            args = type_args(cls)
 486            if args and len(args) >= 2:
 487                recursive(args[0])
 488                recursive(args[1])
 489
 490    recursive(cls)
 491    return list(lst)
 492
 493
 494@cache
 495def is_union(typ: Any) -> bool:
 496    """
 497    Test if the type is `typing.Union`.
 498
 499    >>> is_union(Union[int, str])
 500    True
 501    """
 502
 503    try:
 504        # When `_WithTagging` is received, it will check inner type.
 505        if isinstance(typ, _WithTagging):
 506            return is_union(typ.inner)
 507    except Exception:
 508        pass
 509
 510    # Python 3.10+ Union operator e.g. str | int
 511    try:
 512        if isinstance(typ, types.UnionType):
 513            return True
 514    except Exception:
 515        pass
 516
 517    # typing.Union
 518    return typing_inspect.is_union_type(typ)  # type: ignore
 519
 520
 521@cache
 522def is_opt(typ: Any) -> bool:
 523    """
 524    Test if the type is `typing.Optional`.
 525
 526    >>> is_opt(Optional[int])
 527    True
 528    >>> is_opt(Optional)
 529    True
 530    >>> is_opt(None.__class__)
 531    False
 532    """
 533
 534    # Python 3.10+ Union operator e.g. str | None
 535    is_union_type = False
 536    try:
 537        if isinstance(typ, types.UnionType):
 538            is_union_type = True
 539    except Exception:
 540        pass
 541
 542    # typing.Optional
 543    is_typing_union = typing_inspect.is_optional_type(typ)
 544
 545    args = type_args(typ)
 546    if args:
 547        return (
 548            (is_union_type or is_typing_union)
 549            and len(args) == 2
 550            and not is_none(args[0])
 551            and is_none(args[1])
 552        )
 553    else:
 554        return typ is Optional
 555
 556
 557@cache
 558def is_bare_opt(typ: Any) -> bool:
 559    """
 560    Test if the type is `typing.Optional` without type args.
 561    >>> is_bare_opt(Optional[int])
 562    False
 563    >>> is_bare_opt(Optional)
 564    True
 565    >>> is_bare_opt(None.__class__)
 566    False
 567    """
 568    return not type_args(typ) and typ is Optional
 569
 570
 571@cache
 572def is_opt_dataclass(typ: Any) -> bool:
 573    """
 574    Test if the type is optional dataclass.
 575
 576    >>> is_opt_dataclass(Optional[int])
 577    False
 578    >>> @dataclasses.dataclass
 579    ... class Foo:
 580    ...     pass
 581    >>> is_opt_dataclass(Foo)
 582    False
 583    >>> is_opt_dataclass(Optional[Foo])
 584    False
 585    """
 586    args = get_args(typ)
 587    return is_opt(typ) and len(args) > 0 and is_dataclass(args[0])
 588
 589
 590@cache
 591def is_list(typ: type[Any]) -> bool:
 592    """
 593    Test if the type is `list`, `collections.abc.Sequence`, or `collections.abc.MutableSequence`.
 594
 595    >>> is_list(list[int])
 596    True
 597    >>> is_list(list)
 598    True
 599    >>> is_list(Sequence[int])
 600    True
 601    >>> is_list(Sequence)
 602    True
 603    >>> is_list(MutableSequence[int])
 604    True
 605    >>> is_list(MutableSequence)
 606    True
 607    """
 608    origin = get_origin(typ)
 609    if origin is None:
 610        return typ in (list, Sequence, MutableSequence)
 611    return origin in (list, Sequence, MutableSequence)
 612
 613
 614@cache
 615def is_bare_list(typ: type[Any]) -> bool:
 616    """
 617    Test if the type is `list`/`collections.abc.Sequence`/`collections.abc.MutableSequence`
 618    without type args.
 619
 620    >>> is_bare_list(list[int])
 621    False
 622    >>> is_bare_list(list)
 623    True
 624    >>> is_bare_list(Sequence[int])
 625    False
 626    >>> is_bare_list(Sequence)
 627    True
 628    >>> is_bare_list(MutableSequence[int])
 629    False
 630    >>> is_bare_list(MutableSequence)
 631    True
 632    """
 633    origin = get_origin(typ)
 634    if origin in (list, Sequence, MutableSequence):
 635        return not type_args(typ)
 636    return typ in (list, Sequence, MutableSequence)
 637
 638
 639@cache
 640def is_tuple(typ: Any) -> bool:
 641    """
 642    Test if the type is tuple.
 643    """
 644    try:
 645        return issubclass(get_origin(typ), tuple)  # type: ignore
 646    except TypeError:
 647        return typ is tuple
 648
 649
 650@cache
 651def is_bare_tuple(typ: type[Any]) -> bool:
 652    """
 653    Test if the type is tuple without type args.
 654
 655    >>> is_bare_tuple(tuple[int, str])
 656    False
 657    >>> is_bare_tuple(tuple)
 658    True
 659    """
 660    return typ is tuple
 661
 662
 663@cache
 664def is_variable_tuple(typ: type[Any]) -> bool:
 665    """
 666    Test if the type is a variable length of tuple tuple[T, ...]`.
 667
 668    >>> is_variable_tuple(tuple[int, ...])
 669    True
 670    >>> is_variable_tuple(tuple[int, bool])
 671    False
 672    >>> is_variable_tuple(tuple[()])
 673    False
 674    """
 675    istuple = is_tuple(typ) and not is_bare_tuple(typ)
 676    args = get_args(typ)
 677    return istuple and len(args) == 2 and is_ellipsis(args[1])
 678
 679
 680@cache
 681def is_set(typ: type[Any]) -> bool:
 682    """
 683    Test if the type is set-like.
 684
 685    >>> is_set(set[int])
 686    True
 687    >>> is_set(set)
 688    True
 689    >>> is_set(frozenset[int])
 690    True
 691    >>> from collections.abc import Set, MutableSet
 692    >>> is_set(Set[int])
 693    True
 694    >>> is_set(Set)
 695    True
 696    >>> is_set(MutableSet[int])
 697    True
 698    >>> is_set(MutableSet)
 699    True
 700    """
 701    try:
 702        return issubclass(get_origin(typ), (set, frozenset, Set, MutableSet))  # type: ignore[arg-type]
 703    except TypeError:
 704        return typ in (set, frozenset, Set, MutableSet)
 705
 706
 707@cache
 708def is_bare_set(typ: type[Any]) -> bool:
 709    """
 710    Test if the type is `set`/`frozenset`/`Set`/`MutableSet` without type args.
 711
 712    >>> is_bare_set(set[int])
 713    False
 714    >>> is_bare_set(set)
 715    True
 716    >>> from collections.abc import Set, MutableSet
 717    >>> is_bare_set(Set)
 718    True
 719    >>> is_bare_set(MutableSet)
 720    True
 721    """
 722    origin = get_origin(typ)
 723    if origin in (set, frozenset, Set, MutableSet):
 724        return not type_args(typ)
 725    return typ in (set, frozenset, Set, MutableSet)
 726
 727
 728@cache
 729def is_frozen_set(typ: type[Any]) -> bool:
 730    """
 731    Test if the type is `frozenset`.
 732
 733    >>> is_frozen_set(frozenset[int])
 734    True
 735    >>> is_frozen_set(set)
 736    False
 737    """
 738    try:
 739        return issubclass(get_origin(typ), frozenset)  # type: ignore
 740    except TypeError:
 741        return typ is frozenset
 742
 743
 744@cache
 745def is_dict(typ: type[Any]) -> bool:
 746    """
 747    Test if the type is dict-like.
 748
 749    >>> is_dict(dict[int, int])
 750    True
 751    >>> is_dict(dict)
 752    True
 753    >>> is_dict(defaultdict[int, int])
 754    True
 755    >>> from collections.abc import Mapping, MutableMapping
 756    >>> is_dict(Mapping[str, int])
 757    True
 758    >>> is_dict(Mapping)
 759    True
 760    >>> is_dict(MutableMapping[str, int])
 761    True
 762    >>> is_dict(MutableMapping)
 763    True
 764    """
 765    try:
 766        return issubclass(
 767            get_origin(typ), (dict, defaultdict, Mapping, MutableMapping)  # type: ignore[arg-type]
 768        )
 769    except TypeError:
 770        return typ in (dict, defaultdict, Mapping, MutableMapping)
 771
 772
 773@cache
 774@cache
 775def is_bare_dict(typ: type[Any]) -> bool:
 776    """
 777    Test if the type is `dict`/`Mapping`/`MutableMapping` without type args.
 778
 779    >>> is_bare_dict(dict[int, str])
 780    False
 781    >>> is_bare_dict(dict)
 782    True
 783    >>> from collections.abc import Mapping, MutableMapping
 784    >>> is_bare_dict(Mapping)
 785    True
 786    >>> is_bare_dict(MutableMapping)
 787    True
 788    """
 789    origin = get_origin(typ)
 790    if origin in (dict, Mapping, MutableMapping):
 791        return not type_args(typ)
 792    return typ in (dict, Mapping, MutableMapping)
 793
 794
 795@cache
 796def is_default_dict(typ: type[Any]) -> bool:
 797    """
 798    Test if the type is `defaultdict`.
 799
 800    >>> is_default_dict(defaultdict[int, int])
 801    True
 802    >>> is_default_dict(dict[int, int])
 803    False
 804    """
 805    try:
 806        return issubclass(get_origin(typ), defaultdict)  # type: ignore
 807    except TypeError:
 808        return typ is defaultdict
 809
 810
 811@cache
 812def is_deque(typ: type[Any]) -> bool:
 813    """
 814    Test if the type is `collections.deque`.
 815
 816    >>> is_deque(deque[int])
 817    True
 818    >>> is_deque(deque)
 819    True
 820    >>> is_deque(list[int])
 821    False
 822    """
 823    try:
 824        return issubclass(get_origin(typ), deque)  # type: ignore
 825    except TypeError:
 826        return typ is deque
 827
 828
 829@cache
 830def is_bare_deque(typ: type[Any]) -> bool:
 831    """
 832    Test if the type is `collections.deque` without type args.
 833
 834    >>> is_bare_deque(deque[int])
 835    False
 836    >>> is_bare_deque(deque)
 837    True
 838    """
 839    origin = get_origin(typ)
 840    if origin is deque:
 841        return not type_args(typ)
 842    return typ is deque
 843
 844
 845@cache
 846def is_counter(typ: type[Any]) -> bool:
 847    """
 848    Test if the type is `collections.Counter`.
 849
 850    >>> is_counter(Counter[str])
 851    True
 852    >>> is_counter(Counter)
 853    True
 854    >>> is_counter(dict[str, int])
 855    False
 856    """
 857    try:
 858        return issubclass(get_origin(typ), Counter)  # type: ignore
 859    except TypeError:
 860        return typ is Counter
 861
 862
 863@cache
 864def is_bare_counter(typ: type[Any]) -> bool:
 865    """
 866    Test if the type is `collections.Counter` without type args.
 867
 868    >>> is_bare_counter(Counter[str])
 869    False
 870    >>> is_bare_counter(Counter)
 871    True
 872    """
 873    origin = get_origin(typ)
 874    if origin is Counter:
 875        return not type_args(typ)
 876    return typ is Counter
 877
 878
 879@cache
 880def is_none(typ: type[Any]) -> bool:
 881    """
 882    >>> is_none(int)
 883    False
 884    >>> is_none(type(None))
 885    True
 886    >>> is_none(None)
 887    False
 888    """
 889    return typ is type(None)  # noqa
 890
 891
 892PRIMITIVES = [int, float, bool, str]
 893
 894
 895@cache
 896def is_enum(typ: type[Any]) -> TypeGuard[enum.Enum]:
 897    """
 898    Test if the type is `enum.Enum`.
 899    """
 900    try:
 901        return issubclass(typ, enum.Enum)
 902    except TypeError:
 903        return isinstance(typ, enum.Enum)
 904
 905
 906@cache
 907def is_primitive_subclass(typ: type[Any]) -> bool:
 908    """
 909    Test if the type is a subclass of primitive type.
 910
 911    >>> is_primitive_subclass(str)
 912    False
 913    >>> class Str(str):
 914    ...     pass
 915    >>> is_primitive_subclass(Str)
 916    True
 917    """
 918    return is_primitive(typ) and typ not in PRIMITIVES and not is_new_type_primitive(typ)
 919
 920
 921@cache
 922def is_primitive(typ: type[Any] | NewType) -> bool:
 923    """
 924    Test if the type is primitive.
 925
 926    >>> is_primitive(int)
 927    True
 928    >>> class CustomInt(int):
 929    ...     pass
 930    >>> is_primitive(CustomInt)
 931    True
 932    """
 933    try:
 934        return any(issubclass(typ, ty) for ty in PRIMITIVES)  # type: ignore
 935    except TypeError:
 936        return is_new_type_primitive(typ)
 937
 938
 939@cache
 940def is_new_type_primitive(typ: type[Any] | NewType) -> bool:
 941    """
 942    Test if the type is a NewType of primitives.
 943    """
 944    inner = getattr(typ, "__supertype__", None)
 945    if inner:
 946        return is_primitive(inner)
 947    else:
 948        return any(isinstance(typ, ty) for ty in PRIMITIVES)
 949
 950
 951@cache
 952def has_generic_base(typ: Any) -> bool:
 953    return Generic in getattr(typ, "__mro__", ()) or Generic in getattr(typ, "__bases__", ())
 954
 955
 956@cache
 957def is_generic(typ: Any) -> bool:
 958    """
 959    Test if the type is derived from `typing.Generic`.
 960
 961    >>> T = typing.TypeVar('T')
 962    >>> class GenericFoo(typing.Generic[T]):
 963    ...     pass
 964    >>> is_generic(GenericFoo[int])
 965    True
 966    >>> is_generic(GenericFoo)
 967    False
 968    """
 969    origin = get_origin(typ)
 970    return origin is not None and has_generic_base(origin)
 971
 972
 973@cache
 974def is_class_var(typ: type[Any]) -> bool:
 975    """
 976    Test if the type is `typing.ClassVar`.
 977
 978    >>> is_class_var(ClassVar[int])
 979    True
 980    >>> is_class_var(ClassVar)
 981    True
 982    """
 983    return get_origin(typ) is ClassVar or typ is ClassVar  # type: ignore
 984
 985
 986@cache
 987def is_literal(typ: type[Any]) -> bool:
 988    """
 989    Test if the type is derived from `typing.Literal`.
 990
 991    >>> T = typing.TypeVar('T')
 992    >>> class GenericFoo(typing.Generic[T]):
 993    ...     pass
 994    >>> is_generic(GenericFoo[int])
 995    True
 996    >>> is_generic(GenericFoo)
 997    False
 998    """
 999    origin = get_origin(typ)
1000    return origin is not None and origin is typing.Literal
1001
1002
@cache
def is_any(typ: Any) -> bool:
    """
    Test whether `typ` is the `typing.Any` special form itself.
    """
    return Any is typ
1009
1010
@cache
def is_str_serializable(typ: type[Any]) -> bool:
    """
    Test whether `typ` is (de)serialized through `str`: one of `StrSerializableTypes`,
    or a plain class (metaclass exactly `type`) deriving from one of them.
    """
    if typ in StrSerializableTypes:
        return True
    return type(typ) is type and issubclass(typ, StrSerializableTypes)
1019
1020
def is_datetime(
    typ: type[Any],
) -> TypeGuard[datetime.date | datetime.time | datetime.datetime]:
    """
    Test whether `typ` is one of `DateTimeTypes`, or a plain class (metaclass exactly
    `type`) deriving from one of them.
    """
    if typ in DateTimeTypes:
        return True
    return type(typ) is type and issubclass(typ, DateTimeTypes)
1028
1029
def is_str_serializable_instance(obj: Any) -> bool:
    """
    Test whether the value `obj` is an instance of one of `StrSerializableTypes`.
    """
    str_types = StrSerializableTypes
    return isinstance(obj, str_types)
1032
1033
def is_datetime_instance(obj: Any) -> bool:
    """
    Test whether the value `obj` is an instance of one of `DateTimeTypes`.
    """
    dt_types = DateTimeTypes
    return isinstance(obj, dt_types)
1036
1037
def is_ellipsis(typ: Any) -> bool:
    """
    Test whether `typ` is the ``...`` singleton, as in ``tuple[int, ...]``.
    """
    return Ellipsis is typ
1040
1041
def is_pep695_type_alias(typ: Any) -> bool:
    """
    Test whether `typ` is a PEP 695 type alias (a `TypeAliasType` instance from
    `typing` or `typing_extensions`, whichever are available).
    """
    alias_classes = _PEP695_TYPES
    return isinstance(typ, alias_classes)
1047
1048
@cache
def get_type_var_names(cls: type[Any]) -> list[str] | None:
    """
    Return the names of the type arguments in `cls.__orig_bases__`, or None when the
    class has no (or empty) `__orig_bases__`.

    Arguments without a `__name__` are skipped; names are listed base by base.

    >>> T = typing.TypeVar('T')
    >>> class GenericFoo(typing.Generic[T]):
    ...     pass
    >>> get_type_var_names(GenericFoo)
    ['T']
    >>> get_type_var_names(int)
    """
    orig_bases = getattr(cls, "__orig_bases__", ())
    if not orig_bases:
        return None

    return [
        arg.__name__
        for base in orig_bases
        for arg in get_args(base)
        if hasattr(arg, "__name__")
    ]
1070
1071
def find_generic_arg(cls: type[Any], field: TypeVar) -> int:
    """
    Find a type variable among the type arguments of `cls.__orig_bases__`.

    Bases are scanned in order and arguments are matched by `__name__`. The index
    returned is the position within the first base that mentions `field`; arguments
    without a `__name__` are skipped. Returns -1 if no base mentions `field`.

    >>> T = typing.TypeVar('T')
    >>> U = typing.TypeVar('U')
    >>> V = typing.TypeVar('V')
    >>> class GenericFoo(typing.Generic[T, U]):
    ...     pass
    >>> find_generic_arg(GenericFoo, T)
    0
    >>> find_generic_arg(GenericFoo, U)
    1
    >>> find_generic_arg(GenericFoo, V)
    -1

    Raises:
        Exception: if `cls` has no (or an empty) `__orig_bases__`.
    """
    bases = getattr(cls, "__orig_bases__", ())
    if not bases:
        raise Exception(f'"__orig_bases__" property was not found: {cls}')

    for base in bases:
        for n, arg in enumerate(get_args(base)):
            # Not every type argument has a name (e.g. some typing constructs); skip those,
            # consistent with `get_type_var_names`.
            if getattr(arg, "__name__", None) == field.__name__:
                return n

    return -1
1101
1102
def get_generic_arg(
    typ: Any,
    maybe_generic_type_vars: list[str] | None,
    variable_type_args: list[str] | None,
    index: int,
) -> Any:
    """
    Resolve the concrete type argument of the subscripted generic `typ` that
    corresponds to a field's type variable.

    `variable_type_args[index]` names the type variable used by the field;
    its position in `maybe_generic_type_vars` selects the matching argument of `typ`.
    Returns `typing.Any` when `typ` is not generic, either name list is None, or the
    type variable name is not found.

    >>> T = typing.TypeVar('T')
    >>> U = typing.TypeVar('U')
    >>> class GenericFoo(typing.Generic[T, U]):
    ...     pass
    >>> get_generic_arg(GenericFoo[int, str], ['T', 'U'], ['T', 'U'], 0).__name__
    'int'
    >>> get_generic_arg(GenericFoo[int, str], ['T', 'U'], ['T', 'U'], 1).__name__
    'str'
    >>> get_generic_arg(GenericFoo[int, str], ['T', 'U'], ['U'], 0).__name__
    'str'

    Raises:
        SerdeError: if `typ` has a different number of arguments than
            `maybe_generic_type_vars` has names.
    """
    if not is_generic(typ):
        return typing.Any
    if maybe_generic_type_vars is None or variable_type_args is None:
        return typing.Any

    concrete = get_args(typ)
    if len(concrete) != len(maybe_generic_type_vars):
        raise SerdeError(
            f"Number of type args for {typ} does not match number of generic type vars: "
            f"\n  type args: {concrete}\n  type_vars: {maybe_generic_type_vars}"
        )

    # Name of the type var this field uses in the parent class definition.
    wanted = variable_type_args[index]
    if wanted not in maybe_generic_type_vars:
        return typing.Any

    # Its slot in the original generic class definition selects the concrete argument.
    return concrete[maybe_generic_type_vars.index(wanted)]