Edit on GitHub

serde.compat

Compatibility layer that mainly handles differences in the `typing` module between Python versions.

"""
Compatibility layer that mainly handles differences in the `typing` module between Python versions.
"""
   4
   5import dataclasses
   6import datetime
   7import decimal
   8import enum
   9import functools
  10import ipaddress
  11import itertools
  12import pathlib
  13import types
  14import uuid
  15import typing
  16import typing_extensions
  17from collections import defaultdict
  18from collections.abc import Iterator, Sequence, MutableSequence
  19from collections.abc import Mapping, MutableMapping, Set, MutableSet
  20from dataclasses import is_dataclass
  21from typing import TypeVar, Generic, Any, ClassVar, Optional, NewType, Union, Hashable, Callable
  22
  23import typing_inspect
  24from typing_extensions import TypeGuard, ParamSpec
  25
  26# `typing_extensions.TypeAliasType` isn't always an alias to `typing.TypeAliasType`
  27# depending on certain versions of `typing_extensions` and python.
  28_PEP695_TYPES: tuple[type, ...]
  29if hasattr(typing, "TypeAliasType"):
  30    _PEP695_TYPES = (typing_extensions.TypeAliasType, typing.TypeAliasType)
  31else:
  32    _PEP695_TYPES = (typing_extensions.TypeAliasType,)
  33
  34
  35# Lazy SQLAlchemy imports to improve startup time
  36
  37
  38# Lazy SQLAlchemy import wrapper to improve startup time
  39def _is_sqlalchemy_inspectable(subject: Any) -> bool:
  40    from .sqlalchemy import is_sqlalchemy_inspectable
  41
  42    return is_sqlalchemy_inspectable(subject)
  43
  44
  45def get_np_origin(tp: type[Any]) -> Any | None:
  46    return None
  47
  48
  49def get_np_args(tp: type[Any]) -> tuple[Any, ...]:
  50    return ()
  51
  52
# Empty, so `from <this module> import *` exports no names.
__all__: list[str] = []

# Type variable used by the generic definitions below (`_WithTagging`, `cache`).
T = TypeVar("T")


# Standard-library types whose values are (de)serialized as `str`.
# `is_str_serializable` accepts these classes and their subclasses;
# `is_str_serializable_instance` accepts instances of them.
StrSerializableTypes = (
    decimal.Decimal,
    pathlib.Path,
    pathlib.PosixPath,
    pathlib.WindowsPath,
    pathlib.PurePath,
    pathlib.PurePosixPath,
    pathlib.PureWindowsPath,
    uuid.UUID,
    ipaddress.IPv4Address,
    ipaddress.IPv6Address,
    ipaddress.IPv4Network,
    ipaddress.IPv6Network,
    ipaddress.IPv4Interface,
    ipaddress.IPv6Interface,
)
""" List of standard types (de)serializable to str """

# Date/time types; checked by `is_datetime` (classes) and `is_datetime_instance` (values).
DateTimeTypes = (datetime.date, datetime.time, datetime.datetime)
""" List of datetime types """
  78
  79
  80@dataclasses.dataclass(unsafe_hash=True)
  81class _WithTagging(Generic[T]):
  82    """
  83    Intermediate data structure for (de)serializaing Union without dataclass.
  84    """
  85
  86    inner: T
  87    """ Union type .e.g Union[Foo,Bar] passed in from_obj. """
  88    tagging: Any
  89    """ Union Tagging """
  90
  91
  92class SerdeError(Exception):
  93    """
  94    Serde error class.
  95    """
  96
  97
  98@dataclasses.dataclass
  99class UserError(Exception):
 100    """
 101    Error from user code e.g. __post_init__.
 102    """
 103
 104    inner: Exception
 105
 106
 107class SerdeSkip(Exception):
 108    """
 109    Skip a field in custom (de)serializer.
 110    """
 111
 112
 113def is_hashable(typ: Any) -> TypeGuard[Hashable]:
 114    """
 115    Test is an object hashable
 116    """
 117    try:
 118        hash(typ)
 119    except TypeError:
 120        return False
 121    return True
 122
 123
 124P = ParamSpec("P")
 125
 126
 127def cache(f: Callable[P, T]) -> Callable[P, T]:
 128    """
 129    Wrapper for `functools.cache` to avoid `Hashable` related type errors.
 130    """
 131
 132    def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
 133        return f(*args, **kwargs)
 134
 135    return functools.cache(wrapper)  # type: ignore
 136
 137
 138@cache
 139def get_origin(typ: Any) -> Any | None:
 140    """
 141    Provide `get_origin` that works in all python versions.
 142    """
 143    return typing.get_origin(typ) or get_np_origin(typ)
 144
 145
 146@cache
 147def get_args(typ: type[Any]) -> tuple[Any, ...]:
 148    """
 149    Provide `get_args` that works in all python versions.
 150    """
 151    return typing.get_args(typ) or get_np_args(typ)
 152
 153
 154@cache
 155def typename(typ: Any, with_typing_module: bool = False) -> str:
 156    """
 157    >>> from typing import Any
 158    >>> typename(int)
 159    'int'
 160    >>> class Foo: pass
 161    >>> typename(Foo)
 162    'Foo'
 163    >>> typename(list[Foo])
 164    'list[Foo]'
 165    >>> typename(dict[str, Foo])
 166    'dict[str, Foo]'
 167    >>> typename(tuple[int, str, Foo, list[int], dict[str, Foo]])
 168    'tuple[int, str, Foo, list[int], dict[str, Foo]]'
 169    >>> typename(Optional[list[Foo]])
 170    'Optional[list[Foo]]'
 171    >>> typename(Union[Optional[Foo], list[Foo], Union[str, int]])
 172    'Union[Optional[Foo], list[Foo], str, int]'
 173    >>> typename(set[Foo])
 174    'set[Foo]'
 175    >>> typename(Any)
 176    'Any'
 177    """
 178    mod = "typing." if with_typing_module else ""
 179    thisfunc = functools.partial(typename, with_typing_module=with_typing_module)
 180    if is_opt(typ):
 181        args = type_args(typ)
 182        if args:
 183            return f"{mod}Optional[{thisfunc(type_args(typ)[0])}]"
 184        else:
 185            return f"{mod}Optional"
 186    elif is_union(typ):
 187        args = union_args(typ)
 188        if args:
 189            return f'{mod}Union[{", ".join([thisfunc(e) for e in args])}]'
 190        else:
 191            return f"{mod}Union"
 192    elif is_list(typ):
 193        args = type_args(typ)
 194        if args:
 195            et = thisfunc(args[0])
 196            return f"{mod}list[{et}]"
 197        else:
 198            return f"{mod}list"
 199    elif is_set(typ):
 200        args = type_args(typ)
 201        if args:
 202            et = thisfunc(args[0])
 203            return f"{mod}set[{et}]"
 204        else:
 205            return f"{mod}set"
 206    elif is_dict(typ):
 207        args = type_args(typ)
 208        if args and len(args) == 2:
 209            kt = thisfunc(args[0])
 210            vt = thisfunc(args[1])
 211            return f"{mod}dict[{kt}, {vt}]"
 212        else:
 213            return f"{mod}dict"
 214    elif is_tuple(typ):
 215        args = type_args(typ)
 216        if args:
 217            return f'{mod}tuple[{", ".join([thisfunc(e) for e in args])}]'
 218        else:
 219            return f"{mod}tuple"
 220    elif is_generic(typ):
 221        origin = get_origin(typ)
 222        if origin is None:
 223            raise SerdeError("Could not extract origin class from generic class")
 224
 225        if not isinstance(origin.__name__, str):
 226            raise SerdeError("Name of generic class is not string")
 227
 228        return origin.__name__
 229
 230    elif is_literal(typ):
 231        args = type_args(typ)
 232        if not args:
 233            raise TypeError("Literal type requires at least one literal argument")
 234        return f'Literal[{", ".join(stringify_literal(e) for e in args)}]'
 235    elif typ is Any:
 236        return f"{mod}Any"
 237    elif is_ellipsis(typ):
 238        return "..."
 239    else:
 240        # Get super type for NewType
 241        inner = getattr(typ, "__supertype__", None)
 242        if inner:
 243            return typename(inner)
 244
 245        name: str | None = getattr(typ, "_name", None)
 246        if name:
 247            return name
 248        else:
 249            name = getattr(typ, "__name__", None)
 250            if isinstance(name, str):
 251                return name
 252            else:
 253                raise SerdeError(f"Could not get a type name from: {typ}")
 254
 255
 256def stringify_literal(v: Any) -> str:
 257    if isinstance(v, str):
 258        return f"'{v}'"
 259    else:
 260        return str(v)
 261
 262
 263def type_args(typ: Any) -> tuple[type[Any], ...]:
 264    """
 265    Wrapper to suppress type error for accessing private members.
 266    """
 267    try:
 268        args: tuple[type[Any, ...]] | None = typ.__args__  # type: ignore
 269    except AttributeError:
 270        return get_args(typ)
 271
 272    # Some typing objects expose __args__ as a member_descriptor (e.g. typing.Union),
 273    # which isn't iterable. Fall back to typing.get_args in that case.
 274    if isinstance(args, tuple):
 275        return args
 276    return get_args(typ)
 277
 278
 279def union_args(typ: Any) -> tuple[type[Any], ...]:
 280    if not is_union(typ):
 281        raise TypeError(f"{typ} is not Union")
 282    args = type_args(typ)
 283    if len(args) == 1:
 284        return (args[0],)
 285    it = iter(args)
 286    types = []
 287    for i1, i2 in itertools.zip_longest(it, it):
 288        if not i2:
 289            types.append(i1)
 290        elif is_none(i2):
 291            types.append(Optional[i1])
 292        else:
 293            types.extend((i1, i2))
 294    return tuple(types)
 295
 296
 297def dataclass_fields(cls: type[Any]) -> Iterator[dataclasses.Field]:  # type: ignore
 298    raw_fields = dataclasses.fields(cls)
 299
 300    try:
 301        # this resolves types when string forward reference
 302        # or PEP 563: "from __future__ import annotations" are used
 303        resolved_hints = typing.get_type_hints(cls)
 304    except Exception as e:
 305        raise SerdeError(
 306            f"Failed to resolve type hints for {typename(cls)}:\n"
 307            f"{e.__class__.__name__}: {e}\n\n"
 308            f"If you are using forward references make sure you are calling deserialize & "
 309            "serialize after all classes are globally visible."
 310        ) from e
 311
 312    for f in raw_fields:
 313        real_type = resolved_hints.get(f.name)
 314        if real_type is not None:
 315            f.type = real_type
 316            if is_generic(real_type) and _is_sqlalchemy_inspectable(cls):
 317                f.type = get_args(real_type)[0]
 318
 319    return iter(raw_fields)
 320
 321
# Anything usable as a type annotation: a class, or a typing construct such as
# `Union[...]` / `Literal[...]` (covered by the `typing.Any` member).
TypeLike = Union[type[Any], typing.Any]
 323
 324
 325def iter_types(cls: type[Any]) -> list[type[Any]]:
 326    """
 327    Iterate field types recursively.
 328
 329    The correct return type is `Iterator[Union[Type, typing._specialform]],
 330    but `typing._specialform` doesn't exist for python 3.6. Use `Any` instead.
 331    """
 332    lst: set[Union[type[Any], Any]] = set()
 333
 334    def recursive(cls: Union[type[Any], Any]) -> None:
 335        if cls in lst:
 336            return
 337
 338        if is_dataclass(cls):
 339            lst.add(cls)
 340            if isinstance(cls, type):
 341                for f in dataclass_fields(cls):
 342                    recursive(f.type)
 343        elif isinstance(cls, str):
 344            lst.add(cls)
 345        elif is_opt(cls):
 346            lst.add(Optional)
 347            args = type_args(cls)
 348            if args:
 349                recursive(args[0])
 350        elif is_union(cls):
 351            lst.add(Union)
 352            for arg in type_args(cls):
 353                recursive(arg)
 354        elif is_list(cls):
 355            lst.add(list)
 356            args = type_args(cls)
 357            if args:
 358                recursive(args[0])
 359        elif is_set(cls):
 360            lst.add(set)
 361            args = type_args(cls)
 362            if args:
 363                recursive(args[0])
 364        elif is_tuple(cls):
 365            lst.add(tuple)
 366            for arg in type_args(cls):
 367                recursive(arg)
 368        elif is_dict(cls):
 369            lst.add(dict)
 370            args = type_args(cls)
 371            if args and len(args) >= 2:
 372                recursive(args[0])
 373                recursive(args[1])
 374        elif is_pep695_type_alias(cls):
 375            recursive(cls.__value__)
 376        else:
 377            lst.add(cls)
 378
 379    recursive(cls)
 380    return list(lst)
 381
 382
 383def iter_unions(cls: TypeLike) -> list[TypeLike]:
 384    """
 385    Iterate over all unions that are used in the dataclass
 386    """
 387    lst: list[TypeLike] = []
 388    stack: list[TypeLike] = []  # To prevent infinite recursion
 389
 390    def recursive(cls: TypeLike) -> None:
 391        if cls in stack:
 392            return
 393
 394        if is_union(cls):
 395            lst.append(cls)
 396            for arg in type_args(cls):
 397                recursive(arg)
 398        elif is_pep695_type_alias(cls):
 399            recursive(cls.__value__)
 400        if is_dataclass(cls):
 401            stack.append(cls)
 402            if isinstance(cls, type):
 403                for f in dataclass_fields(cls):
 404                    recursive(f.type)
 405            stack.pop()
 406        elif is_opt(cls):
 407            args = type_args(cls)
 408            if args:
 409                recursive(args[0])
 410        elif is_list(cls) or is_set(cls):
 411            args = type_args(cls)
 412            if args:
 413                recursive(args[0])
 414        elif is_tuple(cls):
 415            for arg in type_args(cls):
 416                recursive(arg)
 417        elif is_dict(cls):
 418            args = type_args(cls)
 419            if args and len(args) >= 2:
 420                recursive(args[0])
 421                recursive(args[1])
 422
 423    recursive(cls)
 424    return lst
 425
 426
 427def iter_literals(cls: type[Any]) -> list[TypeLike]:
 428    """
 429    Iterate over all literals that are used in the dataclass
 430    """
 431    lst: set[Union[type[Any], Any]] = set()
 432
 433    def recursive(cls: Union[type[Any], Any]) -> None:
 434        if cls in lst:
 435            return
 436
 437        if is_literal(cls):
 438            lst.add(cls)
 439        if is_union(cls):
 440            for arg in type_args(cls):
 441                recursive(arg)
 442        if is_dataclass(cls):
 443            lst.add(cls)
 444            if isinstance(cls, type):
 445                for f in dataclass_fields(cls):
 446                    recursive(f.type)
 447        elif is_opt(cls):
 448            args = type_args(cls)
 449            if args:
 450                recursive(args[0])
 451        elif is_list(cls) or is_set(cls):
 452            args = type_args(cls)
 453            if args:
 454                recursive(args[0])
 455        elif is_tuple(cls):
 456            for arg in type_args(cls):
 457                recursive(arg)
 458        elif is_dict(cls):
 459            args = type_args(cls)
 460            if args and len(args) >= 2:
 461                recursive(args[0])
 462                recursive(args[1])
 463
 464    recursive(cls)
 465    return list(lst)
 466
 467
 468@cache
 469def is_union(typ: Any) -> bool:
 470    """
 471    Test if the type is `typing.Union`.
 472
 473    >>> is_union(Union[int, str])
 474    True
 475    """
 476
 477    try:
 478        # When `_WithTagging` is received, it will check inner type.
 479        if isinstance(typ, _WithTagging):
 480            return is_union(typ.inner)
 481    except Exception:
 482        pass
 483
 484    # Python 3.10+ Union operator e.g. str | int
 485    try:
 486        if isinstance(typ, types.UnionType):
 487            return True
 488    except Exception:
 489        pass
 490
 491    # typing.Union
 492    return typing_inspect.is_union_type(typ)  # type: ignore
 493
 494
 495@cache
 496def is_opt(typ: Any) -> bool:
 497    """
 498    Test if the type is `typing.Optional`.
 499
 500    >>> is_opt(Optional[int])
 501    True
 502    >>> is_opt(Optional)
 503    True
 504    >>> is_opt(None.__class__)
 505    False
 506    """
 507
 508    # Python 3.10+ Union operator e.g. str | None
 509    is_union_type = False
 510    try:
 511        if isinstance(typ, types.UnionType):
 512            is_union_type = True
 513    except Exception:
 514        pass
 515
 516    # typing.Optional
 517    is_typing_union = typing_inspect.is_optional_type(typ)
 518
 519    args = type_args(typ)
 520    if args:
 521        return (
 522            (is_union_type or is_typing_union)
 523            and len(args) == 2
 524            and not is_none(args[0])
 525            and is_none(args[1])
 526        )
 527    else:
 528        return typ is Optional
 529
 530
 531@cache
 532def is_bare_opt(typ: Any) -> bool:
 533    """
 534    Test if the type is `typing.Optional` without type args.
 535    >>> is_bare_opt(Optional[int])
 536    False
 537    >>> is_bare_opt(Optional)
 538    True
 539    >>> is_bare_opt(None.__class__)
 540    False
 541    """
 542    return not type_args(typ) and typ is Optional
 543
 544
 545@cache
 546def is_opt_dataclass(typ: Any) -> bool:
 547    """
 548    Test if the type is optional dataclass.
 549
 550    >>> is_opt_dataclass(Optional[int])
 551    False
 552    >>> @dataclasses.dataclass
 553    ... class Foo:
 554    ...     pass
 555    >>> is_opt_dataclass(Foo)
 556    False
 557    >>> is_opt_dataclass(Optional[Foo])
 558    False
 559    """
 560    args = get_args(typ)
 561    return is_opt(typ) and len(args) > 0 and is_dataclass(args[0])
 562
 563
 564@cache
 565def is_list(typ: type[Any]) -> bool:
 566    """
 567    Test if the type is `list`, `collections.abc.Sequence`, or `collections.abc.MutableSequence`.
 568
 569    >>> is_list(list[int])
 570    True
 571    >>> is_list(list)
 572    True
 573    >>> is_list(Sequence[int])
 574    True
 575    >>> is_list(Sequence)
 576    True
 577    >>> is_list(MutableSequence[int])
 578    True
 579    >>> is_list(MutableSequence)
 580    True
 581    """
 582    origin = get_origin(typ)
 583    if origin is None:
 584        return typ in (list, Sequence, MutableSequence)
 585    return origin in (list, Sequence, MutableSequence)
 586
 587
 588@cache
 589def is_bare_list(typ: type[Any]) -> bool:
 590    """
 591    Test if the type is `list`/`collections.abc.Sequence`/`collections.abc.MutableSequence`
 592    without type args.
 593
 594    >>> is_bare_list(list[int])
 595    False
 596    >>> is_bare_list(list)
 597    True
 598    >>> is_bare_list(Sequence[int])
 599    False
 600    >>> is_bare_list(Sequence)
 601    True
 602    >>> is_bare_list(MutableSequence[int])
 603    False
 604    >>> is_bare_list(MutableSequence)
 605    True
 606    """
 607    origin = get_origin(typ)
 608    if origin in (list, Sequence, MutableSequence):
 609        return not type_args(typ)
 610    return typ in (list, Sequence, MutableSequence)
 611
 612
 613@cache
 614def is_tuple(typ: Any) -> bool:
 615    """
 616    Test if the type is tuple.
 617    """
 618    try:
 619        return issubclass(get_origin(typ), tuple)  # type: ignore
 620    except TypeError:
 621        return typ is tuple
 622
 623
 624@cache
 625def is_bare_tuple(typ: type[Any]) -> bool:
 626    """
 627    Test if the type is tuple without type args.
 628
 629    >>> is_bare_tuple(tuple[int, str])
 630    False
 631    >>> is_bare_tuple(tuple)
 632    True
 633    """
 634    return typ is tuple
 635
 636
 637@cache
 638def is_variable_tuple(typ: type[Any]) -> bool:
 639    """
 640    Test if the type is a variable length of tuple tuple[T, ...]`.
 641
 642    >>> is_variable_tuple(tuple[int, ...])
 643    True
 644    >>> is_variable_tuple(tuple[int, bool])
 645    False
 646    >>> is_variable_tuple(tuple[()])
 647    False
 648    """
 649    istuple = is_tuple(typ) and not is_bare_tuple(typ)
 650    args = get_args(typ)
 651    return istuple and len(args) == 2 and is_ellipsis(args[1])
 652
 653
 654@cache
 655def is_set(typ: type[Any]) -> bool:
 656    """
 657    Test if the type is set-like.
 658
 659    >>> is_set(set[int])
 660    True
 661    >>> is_set(set)
 662    True
 663    >>> is_set(frozenset[int])
 664    True
 665    >>> from collections.abc import Set, MutableSet
 666    >>> is_set(Set[int])
 667    True
 668    >>> is_set(Set)
 669    True
 670    >>> is_set(MutableSet[int])
 671    True
 672    >>> is_set(MutableSet)
 673    True
 674    """
 675    try:
 676        return issubclass(get_origin(typ), (set, frozenset, Set, MutableSet))  # type: ignore[arg-type]
 677    except TypeError:
 678        return typ in (set, frozenset, Set, MutableSet)
 679
 680
 681@cache
 682def is_bare_set(typ: type[Any]) -> bool:
 683    """
 684    Test if the type is `set`/`frozenset`/`Set`/`MutableSet` without type args.
 685
 686    >>> is_bare_set(set[int])
 687    False
 688    >>> is_bare_set(set)
 689    True
 690    >>> from collections.abc import Set, MutableSet
 691    >>> is_bare_set(Set)
 692    True
 693    >>> is_bare_set(MutableSet)
 694    True
 695    """
 696    origin = get_origin(typ)
 697    if origin in (set, frozenset, Set, MutableSet):
 698        return not type_args(typ)
 699    return typ in (set, frozenset, Set, MutableSet)
 700
 701
 702@cache
 703def is_frozen_set(typ: type[Any]) -> bool:
 704    """
 705    Test if the type is `frozenset`.
 706
 707    >>> is_frozen_set(frozenset[int])
 708    True
 709    >>> is_frozen_set(set)
 710    False
 711    """
 712    try:
 713        return issubclass(get_origin(typ), frozenset)  # type: ignore
 714    except TypeError:
 715        return typ is frozenset
 716
 717
 718@cache
 719def is_dict(typ: type[Any]) -> bool:
 720    """
 721    Test if the type is dict-like.
 722
 723    >>> is_dict(dict[int, int])
 724    True
 725    >>> is_dict(dict)
 726    True
 727    >>> is_dict(defaultdict[int, int])
 728    True
 729    >>> from collections.abc import Mapping, MutableMapping
 730    >>> is_dict(Mapping[str, int])
 731    True
 732    >>> is_dict(Mapping)
 733    True
 734    >>> is_dict(MutableMapping[str, int])
 735    True
 736    >>> is_dict(MutableMapping)
 737    True
 738    """
 739    try:
 740        return issubclass(
 741            get_origin(typ), (dict, defaultdict, Mapping, MutableMapping)  # type: ignore[arg-type]
 742        )
 743    except TypeError:
 744        return typ in (dict, defaultdict, Mapping, MutableMapping)
 745
 746
 747@cache
 748@cache
 749def is_bare_dict(typ: type[Any]) -> bool:
 750    """
 751    Test if the type is `dict`/`Mapping`/`MutableMapping` without type args.
 752
 753    >>> is_bare_dict(dict[int, str])
 754    False
 755    >>> is_bare_dict(dict)
 756    True
 757    >>> from collections.abc import Mapping, MutableMapping
 758    >>> is_bare_dict(Mapping)
 759    True
 760    >>> is_bare_dict(MutableMapping)
 761    True
 762    """
 763    origin = get_origin(typ)
 764    if origin in (dict, Mapping, MutableMapping):
 765        return not type_args(typ)
 766    return typ in (dict, Mapping, MutableMapping)
 767
 768
 769@cache
 770def is_default_dict(typ: type[Any]) -> bool:
 771    """
 772    Test if the type is `defaultdict`.
 773
 774    >>> is_default_dict(defaultdict[int, int])
 775    True
 776    >>> is_default_dict(dict[int, int])
 777    False
 778    """
 779    try:
 780        return issubclass(get_origin(typ), defaultdict)  # type: ignore
 781    except TypeError:
 782        return typ is defaultdict
 783
 784
 785@cache
 786def is_none(typ: type[Any]) -> bool:
 787    """
 788    >>> is_none(int)
 789    False
 790    >>> is_none(type(None))
 791    True
 792    >>> is_none(None)
 793    False
 794    """
 795    return typ is type(None)  # noqa
 796
 797
# Builtin types treated as primitives; `is_primitive` also accepts their subclasses.
PRIMITIVES = [int, float, bool, str]
 799
 800
 801@cache
 802def is_enum(typ: type[Any]) -> TypeGuard[enum.Enum]:
 803    """
 804    Test if the type is `enum.Enum`.
 805    """
 806    try:
 807        return issubclass(typ, enum.Enum)
 808    except TypeError:
 809        return isinstance(typ, enum.Enum)
 810
 811
 812@cache
 813def is_primitive_subclass(typ: type[Any]) -> bool:
 814    """
 815    Test if the type is a subclass of primitive type.
 816
 817    >>> is_primitive_subclass(str)
 818    False
 819    >>> class Str(str):
 820    ...     pass
 821    >>> is_primitive_subclass(Str)
 822    True
 823    """
 824    return is_primitive(typ) and typ not in PRIMITIVES and not is_new_type_primitive(typ)
 825
 826
 827@cache
 828def is_primitive(typ: type[Any] | NewType) -> bool:
 829    """
 830    Test if the type is primitive.
 831
 832    >>> is_primitive(int)
 833    True
 834    >>> class CustomInt(int):
 835    ...     pass
 836    >>> is_primitive(CustomInt)
 837    True
 838    """
 839    try:
 840        return any(issubclass(typ, ty) for ty in PRIMITIVES)  # type: ignore
 841    except TypeError:
 842        return is_new_type_primitive(typ)
 843
 844
 845@cache
 846def is_new_type_primitive(typ: type[Any] | NewType) -> bool:
 847    """
 848    Test if the type is a NewType of primitives.
 849    """
 850    inner = getattr(typ, "__supertype__", None)
 851    if inner:
 852        return is_primitive(inner)
 853    else:
 854        return any(isinstance(typ, ty) for ty in PRIMITIVES)
 855
 856
 857@cache
 858def has_generic_base(typ: Any) -> bool:
 859    return Generic in getattr(typ, "__mro__", ()) or Generic in getattr(typ, "__bases__", ())
 860
 861
 862@cache
 863def is_generic(typ: Any) -> bool:
 864    """
 865    Test if the type is derived from `typing.Generic`.
 866
 867    >>> T = typing.TypeVar('T')
 868    >>> class GenericFoo(typing.Generic[T]):
 869    ...     pass
 870    >>> is_generic(GenericFoo[int])
 871    True
 872    >>> is_generic(GenericFoo)
 873    False
 874    """
 875    origin = get_origin(typ)
 876    return origin is not None and has_generic_base(origin)
 877
 878
 879@cache
 880def is_class_var(typ: type[Any]) -> bool:
 881    """
 882    Test if the type is `typing.ClassVar`.
 883
 884    >>> is_class_var(ClassVar[int])
 885    True
 886    >>> is_class_var(ClassVar)
 887    True
 888    """
 889    return get_origin(typ) is ClassVar or typ is ClassVar  # type: ignore
 890
 891
 892@cache
 893def is_literal(typ: type[Any]) -> bool:
 894    """
 895    Test if the type is derived from `typing.Literal`.
 896
 897    >>> T = typing.TypeVar('T')
 898    >>> class GenericFoo(typing.Generic[T]):
 899    ...     pass
 900    >>> is_generic(GenericFoo[int])
 901    True
 902    >>> is_generic(GenericFoo)
 903    False
 904    """
 905    origin = get_origin(typ)
 906    return origin is not None and origin is typing.Literal
 907
 908
 909@cache
 910def is_any(typ: Any) -> bool:
 911    """
 912    Test if the type is `typing.Any`.
 913    """
 914    return typ is Any
 915
 916
 917@cache
 918def is_str_serializable(typ: type[Any]) -> bool:
 919    """
 920    Test if the type is serializable to `str`.
 921    """
 922    return typ in StrSerializableTypes or (
 923        type(typ) is type and issubclass(typ, StrSerializableTypes)
 924    )
 925
 926
 927def is_datetime(
 928    typ: type[Any],
 929) -> TypeGuard[datetime.date | datetime.time | datetime.datetime]:
 930    """
 931    Test if the type is any of the datetime types..
 932    """
 933    return typ in DateTimeTypes or (type(typ) is type and issubclass(typ, DateTimeTypes))
 934
 935
 936def is_str_serializable_instance(obj: Any) -> bool:
 937    return isinstance(obj, StrSerializableTypes)
 938
 939
 940def is_datetime_instance(obj: Any) -> bool:
 941    return isinstance(obj, DateTimeTypes)
 942
 943
 944def is_ellipsis(typ: Any) -> bool:
 945    return typ is Ellipsis
 946
 947
 948def is_pep695_type_alias(typ: Any) -> bool:
 949    """
 950    Test if the type is of PEP695 type alias.
 951    """
 952    return isinstance(typ, _PEP695_TYPES)
 953
 954
 955@cache
 956def get_type_var_names(cls: type[Any]) -> list[str] | None:
 957    """
 958    Get type argument names of a generic class.
 959
 960    >>> T = typing.TypeVar('T')
 961    >>> class GenericFoo(typing.Generic[T]):
 962    ...     pass
 963    >>> get_type_var_names(GenericFoo)
 964    ['T']
 965    >>> get_type_var_names(int)
 966    """
 967    bases = getattr(cls, "__orig_bases__", ())
 968    if not bases:
 969        return None
 970
 971    type_arg_names: list[str] = []
 972    for base in bases:
 973        type_arg_names.extend(arg.__name__ for arg in get_args(base) if hasattr(arg, "__name__"))
 974
 975    return type_arg_names
 976
 977
 978def find_generic_arg(cls: type[Any], field: TypeVar) -> int:
 979    """
 980    Find a type in generic parameters.
 981
 982    >>> T = typing.TypeVar('T')
 983    >>> U = typing.TypeVar('U')
 984    >>> V = typing.TypeVar('V')
 985    >>> class GenericFoo(typing.Generic[T, U]):
 986    ...     pass
 987    >>> find_generic_arg(GenericFoo, T)
 988    0
 989    >>> find_generic_arg(GenericFoo, U)
 990    1
 991    >>> find_generic_arg(GenericFoo, V)
 992    -1
 993    """
 994    bases = getattr(cls, "__orig_bases__", ())
 995    if not bases:
 996        raise Exception(f'"__orig_bases__" property was not found: {cls}')
 997
 998    for base in bases:
 999        for n, arg in enumerate(get_args(base)):
1000            if arg.__name__ == field.__name__:
1001                return n
1002
1003    if not bases:
1004        raise Exception(f"Generic field not found in class: {bases}")
1005
1006    return -1
1007
1008
def get_generic_arg(
    typ: Any,
    maybe_generic_type_vars: list[str] | None,
    variable_type_args: list[str] | None,
    index: int,
) -> Any:
    """
    Get generic type argument.

    ``variable_type_args[index]`` names the type variable used by a field. Its
    position in ``maybe_generic_type_vars`` (the generic class's own type variables)
    selects the matching concrete argument of ``typ``. ``typing.Any`` is returned
    when ``typ`` is not a parameterized generic, when either name list is None, or
    when the name is not one of the class's type variables.

    Raises:
        SerdeError: If ``typ`` has a different number of arguments than
            ``maybe_generic_type_vars``.

    >>> T = typing.TypeVar('T')
    >>> U = typing.TypeVar('U')
    >>> class GenericFoo(typing.Generic[T, U]):
    ...     pass
    >>> get_generic_arg(GenericFoo[int, str], ['T', 'U'], ['T', 'U'], 0).__name__
    'int'
    >>> get_generic_arg(GenericFoo[int, str], ['T', 'U'], ['T', 'U'], 1).__name__
    'str'
    >>> get_generic_arg(GenericFoo[int, str], ['T', 'U'], ['U'], 0).__name__
    'str'
    """
    if not is_generic(typ):
        return typing.Any
    if maybe_generic_type_vars is None or variable_type_args is None:
        return typing.Any

    concrete_args = get_args(typ)
    if len(concrete_args) != len(maybe_generic_type_vars):
        raise SerdeError(
            f"Number of type args for {typ} does not match number of generic type vars: "
            f"\n  type args: {concrete_args}\n  type_vars: {maybe_generic_type_vars}"
        )

    # Type-variable name this field uses in the parent class definition.
    wanted = variable_type_args[index]
    if wanted not in maybe_generic_type_vars:
        return typing.Any
    # Its slot in the generic class definition selects the concrete argument.
    return concrete_args[maybe_generic_type_vars.index(wanted)]