Edit on GitHub

serde.compat

Compatibility layer that mostly handles differences in the typing module between Python versions.

  1"""
  2Compatibility layer which handles mostly differences of `typing` module between python versions.
  3"""
  4
  5import dataclasses
  6import datetime
  7import decimal
  8import enum
  9import functools
 10import ipaddress
 11import itertools
 12import pathlib
 13import sys
 14import types
 15import uuid
 16import typing
 17from collections import defaultdict
 18from collections.abc import Iterator
 19from dataclasses import is_dataclass
 20from typing import TypeVar, Generic, Any, ClassVar, Optional, NewType, Union, Hashable, Callable
 21
 22import typing_inspect
 23from typing_extensions import TypeGuard, ParamSpec, TypeAliasType
 24
 25from .sqlalchemy import is_sqlalchemy_inspectable
 26
 27
 28def get_np_origin(tp: type[Any]) -> Optional[Any]:
 29    return None
 30
 31
 32def get_np_args(tp: type[Any]) -> tuple[Any, ...]:
 33    return ()
 34
 35
# Explicit public API of this module; declared empty here.
__all__: list[str] = []

# Generic type variable used by `_WithTagging` and the `cache` decorator.
T = TypeVar("T")


# Checked by `is_str_serializable` (membership / subclass) and
# `is_str_serializable_instance` (isinstance).
StrSerializableTypes = (
    decimal.Decimal,
    pathlib.Path,
    pathlib.PosixPath,
    pathlib.WindowsPath,
    pathlib.PurePath,
    pathlib.PurePosixPath,
    pathlib.PureWindowsPath,
    uuid.UUID,
    ipaddress.IPv4Address,
    ipaddress.IPv6Address,
    ipaddress.IPv4Network,
    ipaddress.IPv6Network,
    ipaddress.IPv4Interface,
    ipaddress.IPv6Interface,
)
""" List of standard types (de)serializable to str """

# Checked by `is_datetime` and `is_datetime_instance`.
DateTimeTypes = (datetime.date, datetime.time, datetime.datetime)
""" List of datetime types """
 61
 62
@dataclasses.dataclass(unsafe_hash=True)
class _WithTagging(Generic[T]):
    """
    Intermediate data structure for (de)serializing a Union without a dataclass.

    `unsafe_hash=True` generates `__hash__` from the fields, so instances can
    be passed to the `@cache`-decorated helpers in this module (e.g. `is_union`
    unwraps `inner` when it receives an instance of this class).
    """

    inner: T
    """ Union type .e.g Union[Foo,Bar] passed in from_obj. """
    tagging: Any
    """ Union Tagging """
 73
 74
class SerdeError(Exception):
    """
    Serde error class.

    Raised by helpers in this module, e.g. when a type name or type hints
    cannot be resolved.
    """
 79
 80
@dataclasses.dataclass
class UserError(Exception):
    """
    Error from user code e.g. __post_init__.

    Declared as a dataclass so the wrapped exception is stored in `inner`.
    """

    # The wrapped exception (presumably the one raised by user code -- confirm at call sites).
    inner: Exception
 88
 89
class SerdeSkip(Exception):
    """
    Skip a field in custom (de)serializer.

    Carries no extra state; it is a plain `Exception` subclass used as a signal.
    """
 94
 95
 96def is_hashable(typ: Any) -> TypeGuard[Hashable]:
 97    """
 98    Test is an object hashable
 99    """
100    try:
101        hash(typ)
102    except TypeError:
103        return False
104    return True
105
106
107P = ParamSpec("P")
108
109
110def cache(f: Callable[P, T]) -> Callable[P, T]:
111    """
112    Wrapper for `functools.cache` to avoid `Hashable` related type errors.
113    """
114
115    def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
116        return f(*args, **kwargs)
117
118    return functools.cache(wrapper)  # type: ignore
119
120
@cache
def get_origin(typ: Any) -> Optional[Any]:
    """
    Return the origin of a (generic) type.

    `typing.get_origin` is consulted first; when it yields a falsy result the
    answer of `get_np_origin` is returned instead.
    """
    origin = typing.get_origin(typ)
    if origin:
        return origin
    return get_np_origin(typ)
127
128
@cache
def get_args(typ: type[Any]) -> tuple[Any, ...]:
    """
    Return the type arguments of a (generic) type.

    `typing.get_args` is consulted first; when it yields an empty result the
    answer of `get_np_args` is returned instead.
    """
    args = typing.get_args(typ)
    if args:
        return args
    return get_np_args(typ)
135
136
@cache
def typename(typ: Any, with_typing_module: bool = False) -> str:
    """
    Render a type as a readable name.

    >>> from typing import Any
    >>> typename(int)
    'int'
    >>> class Foo: pass
    >>> typename(Foo)
    'Foo'
    >>> typename(list[Foo])
    'list[Foo]'
    >>> typename(dict[str, Foo])
    'dict[str, Foo]'
    >>> typename(tuple[int, str, Foo, list[int], dict[str, Foo]])
    'tuple[int, str, Foo, list[int], dict[str, Foo]]'
    >>> typename(Optional[list[Foo]])
    'Optional[list[Foo]]'
    >>> typename(Union[Optional[Foo], list[Foo], Union[str, int]])
    'Union[Optional[Foo], list[Foo], str, int]'
    >>> typename(set[Foo])
    'set[Foo]'
    >>> typename(Any)
    'Any'

    When `with_typing_module` is true, the names produced for Optional, Union,
    list, set, dict, tuple and Any are prefixed with ``typing.``; the flag is
    forwarded to every nested type, including the supertype of a `NewType`.

    Raises `SerdeError` when no name can be determined, and `TypeError` for a
    `Literal` without arguments.
    """
    mod = "typing." if with_typing_module else ""
    # Recursive calls must keep the same prefix setting as the outer call.
    thisfunc = functools.partial(typename, with_typing_module=with_typing_module)
    if is_opt(typ):
        args = type_args(typ)
        if args:
            return f"{mod}Optional[{thisfunc(args[0])}]"
        else:
            return f"{mod}Optional"
    elif is_union(typ):
        args = union_args(typ)
        if args:
            return f'{mod}Union[{", ".join([thisfunc(e) for e in args])}]'
        else:
            return f"{mod}Union"
    elif is_list(typ):
        args = type_args(typ)
        if args:
            et = thisfunc(args[0])
            return f"{mod}list[{et}]"
        else:
            return f"{mod}list"
    elif is_set(typ):
        args = type_args(typ)
        if args:
            et = thisfunc(args[0])
            return f"{mod}set[{et}]"
        else:
            return f"{mod}set"
    elif is_dict(typ):
        args = type_args(typ)
        if args and len(args) == 2:
            kt = thisfunc(args[0])
            vt = thisfunc(args[1])
            return f"{mod}dict[{kt}, {vt}]"
        else:
            return f"{mod}dict"
    elif is_tuple(typ):
        args = type_args(typ)
        if args:
            return f'{mod}tuple[{", ".join([thisfunc(e) for e in args])}]'
        else:
            return f"{mod}tuple"
    elif is_generic(typ):
        # A parameterized user generic is named after its origin class only.
        origin = get_origin(typ)
        if origin is None:
            raise SerdeError("Could not extract origin class from generic class")

        if not isinstance(origin.__name__, str):
            raise SerdeError("Name of generic class is not string")

        return origin.__name__

    elif is_literal(typ):
        args = type_args(typ)
        if not args:
            raise TypeError("Literal type requires at least one literal argument")
        return f'Literal[{", ".join(stringify_literal(e) for e in args)}]'
    elif typ is Any:
        return f"{mod}Any"
    elif is_ellipsis(typ):
        return "..."
    else:
        # Get super type for NewType
        inner = getattr(typ, "__supertype__", None)
        if inner:
            # Use `thisfunc` so `with_typing_module` is not lost for the supertype.
            return thisfunc(inner)

        # Prefer the private `_name` attribute, then fall back to `__name__`.
        name: Optional[str] = getattr(typ, "_name", None)
        if name:
            return name
        else:
            name = getattr(typ, "__name__", None)
            if isinstance(name, str):
                return name
            else:
                raise SerdeError(f"Could not get a type name from: {typ}")
237
238
def stringify_literal(v: Any) -> str:
    """
    Render one `Literal` argument: strings are wrapped in single quotes,
    every other value is passed through `str()`.
    """
    return f"'{v}'" if isinstance(v, str) else str(v)
244
245
def type_args(typ: Any) -> tuple[type[Any], ...]:
    """
    Wrapper to suppress type error for accessing private members.

    Returns `typ.__args__` when the attribute exists, mapping a `None` value
    to an empty tuple. When the attribute is missing, the result of
    `get_args(typ)` is returned instead.
    """
    try:
        args: Optional[tuple[type[Any], ...]] = typ.__args__
    except AttributeError:
        # No `__args__` attribute: defer to the generic lookup.
        return get_args(typ)
    return () if args is None else args
258
259
def union_args(typ: Any) -> tuple[type[Any], ...]:
    """
    Return the members of a Union, folding ``X, NoneType`` pairs into ``Optional[X]``.

    The arguments are consumed two at a time: a pair whose second element is
    `NoneType` becomes `Optional[first]`, any other pair is kept as two
    separate members, and a trailing unpaired argument is kept as is.

    Raises `TypeError` when `typ` is not a Union.
    """
    if not is_union(typ):
        raise TypeError(f"{typ} is not Union")
    args = type_args(typ)
    if len(args) == 1:
        return (args[0],)
    it = iter(args)
    members: list[Any] = []
    # Passing the same iterator twice yields consecutive, non-overlapping pairs;
    # `zip_longest` pads a trailing single element with `None`.
    for first, second in itertools.zip_longest(it, it):
        if second is None:
            # Only the padding value is `None`; a real member that happens to
            # be falsy must not be dropped.
            members.append(first)
        elif is_none(second):
            members.append(Optional[first])
        else:
            members.extend((first, second))
    return tuple(members)
276
277
def dataclass_fields(cls: type[Any]) -> Iterator[dataclasses.Field]:  # type: ignore
    """
    Iterate over the dataclass fields of `cls` with resolved type hints.

    The `type` attribute of each `Field` object returned by
    `dataclasses.fields` is overwritten in place with the hint resolved by
    `typing.get_type_hints`. For classes that `is_sqlalchemy_inspectable`
    accepts, a generic hint is replaced by its first type argument.

    Raises `SerdeError` (chained to the original error) when the hints cannot
    be resolved.
    """
    fields = dataclasses.fields(cls)

    try:
        # Resolves string forward references and annotations deferred by
        # PEP 563 ("from __future__ import annotations").
        hints = typing.get_type_hints(cls)
    except Exception as e:
        raise SerdeError(
            f"Failed to resolve type hints for {typename(cls)}:\n"
            f"{e.__class__.__name__}: {e}\n\n"
            f"If you are using forward references make sure you are calling deserialize & "
            "serialize after all classes are globally visible."
        ) from e

    for field in fields:
        resolved = hints.get(field.name)
        if resolved is None:
            continue
        field.type = resolved
        if is_generic(resolved) and is_sqlalchemy_inspectable(cls):
            field.type = get_args(resolved)[0]

    return iter(fields)
301
302
# Alias for "a class or any other typing construct"; used in the signatures of
# `iter_unions` and `iter_literals`.
TypeLike = Union[type[Any], typing.Any]
304
305
def iter_types(cls: type[Any]) -> list[type[Any]]:
    """
    Collect the types reachable from `cls`, recursing into dataclass fields
    and into the arguments of Optional, Union, list, set, tuple and dict.

    Containers are recorded by their bare marker (`Optional`, `Union`, `list`,
    `set`, `tuple`, `dict`); PEP 695 aliases are replaced by their value.
    The result is built from a set, so it has no duplicates and no defined order.
    """
    seen: set[Union[type[Any], Any]] = set()

    def visit_first_arg(tp: Union[type[Any], Any]) -> None:
        # Containers with a single element type only descend into that type.
        args = type_args(tp)
        if args:
            visit(args[0])

    def visit(tp: Union[type[Any], Any]) -> None:
        if tp in seen:
            return

        if is_dataclass(tp):
            # Record before descending so self-referencing dataclasses terminate.
            seen.add(tp)
            for field in dataclass_fields(tp):
                visit(field.type)
        elif isinstance(tp, str):
            seen.add(tp)
        elif is_opt(tp):
            seen.add(Optional)
            visit_first_arg(tp)
        elif is_union(tp):
            seen.add(Union)
            for member in type_args(tp):
                visit(member)
        elif is_list(tp):
            seen.add(list)
            visit_first_arg(tp)
        elif is_set(tp):
            seen.add(set)
            visit_first_arg(tp)
        elif is_tuple(tp):
            seen.add(tuple)
            for member in type_args(tp):
                visit(member)
        elif is_dict(tp):
            seen.add(dict)
            args = type_args(tp)
            if args and len(args) >= 2:
                visit(args[0])
                visit(args[1])
        elif is_pep695_type_alias(tp):
            visit(tp.__value__)
        else:
            seen.add(tp)

    visit(cls)
    return list(seen)
361
362
def iter_unions(cls: TypeLike) -> list[TypeLike]:
    """
    Iterate over all unions that are used in the dataclass

    Walks `cls` recursively through dataclass fields, PEP 695 aliases and the
    arguments of Union, Optional, list, set, tuple and dict, and returns every
    union type encountered, in visiting order.
    """
    lst: list[TypeLike] = []
    stack: list[TypeLike] = []  # To prevent infinite recursion

    def recursive(cls: TypeLike) -> None:
        if cls in stack:
            return

        # A single chain: each type is handled by exactly one branch. An
        # `Optional[X]` is a union, so its arguments are visited once by the
        # union branch instead of a second time by the optional branch.
        if is_union(cls):
            lst.append(cls)
            for arg in type_args(cls):
                recursive(arg)
        elif is_pep695_type_alias(cls):
            recursive(cls.__value__)
        elif is_dataclass(cls):
            # Track the dataclasses currently being expanded so that
            # self-referencing dataclasses do not recurse forever.
            stack.append(cls)
            for f in dataclass_fields(cls):
                recursive(f.type)
            stack.pop()
        elif is_opt(cls):
            args = type_args(cls)
            if args:
                recursive(args[0])
        elif is_list(cls) or is_set(cls):
            args = type_args(cls)
            if args:
                recursive(args[0])
        elif is_tuple(cls):
            for arg in type_args(cls):
                recursive(arg)
        elif is_dict(cls):
            args = type_args(cls)
            if args and len(args) >= 2:
                recursive(args[0])
                recursive(args[1])

    recursive(cls)
    return lst
404
405
def iter_literals(cls: type[Any]) -> list[TypeLike]:
    """
    Collect the `Literal` types reachable from `cls`.

    The walk descends into union members, dataclass fields and the arguments
    of Optional, list, set, tuple and dict. Visited dataclasses are stored in
    the same set as the literals (which also stops repeated visits), so they
    are part of the returned list too. The result has no defined order.
    """
    found: set[Union[type[Any], Any]] = set()

    def walk(tp: Union[type[Any], Any]) -> None:
        if tp in found:
            return

        if is_literal(tp):
            found.add(tp)
        if is_union(tp):
            for member in type_args(tp):
                walk(member)
        if is_dataclass(tp):
            found.add(tp)
            for field in dataclass_fields(tp):
                walk(field.type)
        elif is_opt(tp):
            opt_args = type_args(tp)
            if opt_args:
                walk(opt_args[0])
        elif is_list(tp) or is_set(tp):
            elem_args = type_args(tp)
            if elem_args:
                walk(elem_args[0])
        elif is_tuple(tp):
            for member in type_args(tp):
                walk(member)
        elif is_dict(tp):
            kv_args = type_args(tp)
            if kv_args and len(kv_args) >= 2:
                walk(kv_args[0])
                walk(kv_args[1])

    walk(cls)
    return list(found)
444
445
@cache
def is_union(typ: Any) -> bool:
    """
    Check whether the type is a union (`typing.Union` or, on Python 3.10+,
    the `X | Y` operator form).

    >>> is_union(Union[int, str])
    True
    """

    # A `_WithTagging` wrapper is judged by the type it wraps. Any error
    # raised while doing so is ignored and the checks below are used instead.
    try:
        if isinstance(typ, _WithTagging):
            return is_union(typ.inner)
    except Exception:
        pass

    # `types.UnionType` (e.g. `str | int`) only exists on Python 3.10+.
    if sys.version_info[:2] >= (3, 10):
        try:
            operator_union = isinstance(typ, types.UnionType)
        except Exception:
            operator_union = False
        if operator_union:
            return True

    # Everything else is decided by typing_inspect.
    return typing_inspect.is_union_type(typ)  # type: ignore
472
473
@cache
def is_opt(typ: Any) -> bool:
    """
    Check whether the type is `typing.Optional`, i.e. a two-member union
    whose first member is not `NoneType` and whose second member is.

    >>> is_opt(Optional[int])
    True
    >>> is_opt(Optional)
    True
    >>> is_opt(None.__class__)
    False
    """

    # `types.UnionType` (e.g. `str | None`) only exists on Python 3.10+.
    operator_union = False
    if sys.version_info[:2] >= (3, 10):
        try:
            operator_union = isinstance(typ, types.UnionType)
        except Exception:
            pass

    # typing.Optional
    typing_optional = typing_inspect.is_optional_type(typ)

    args = type_args(typ)
    if not args:
        # Without arguments only the bare `Optional` special form qualifies.
        return typ is Optional

    return (
        (operator_union or typing_optional)
        and len(args) == 2
        and not is_none(args[0])
        and is_none(args[1])
    )
509
510
@cache
def is_bare_opt(typ: Any) -> bool:
    """
    Check whether the type is the unparameterized `typing.Optional`.

    >>> is_bare_opt(Optional[int])
    False
    >>> is_bare_opt(Optional)
    True
    >>> is_bare_opt(None.__class__)
    False
    """
    has_args = bool(type_args(typ))
    return not has_args and typ is Optional
523
524
@cache
def is_opt_dataclass(typ: Any) -> bool:
    """
    Test if the type is optional dataclass.

    The type must satisfy `is_opt` and its first type argument must be a
    dataclass.

    >>> is_opt_dataclass(Optional[int])
    False
    >>> @dataclasses.dataclass
    ... class Foo:
    ...     pass
    >>> is_opt_dataclass(Foo)
    False
    >>> is_opt_dataclass(Optional[Foo])
    True
    """
    args = get_args(typ)
    return is_opt(typ) and len(args) > 0 and is_dataclass(args[0])
542
543
@cache
def is_list(typ: type[Any]) -> bool:
    """
    Check whether the type is `list`, parameterized or bare.

    >>> is_list(list[int])
    True
    >>> is_list(list)
    True
    """
    origin = get_origin(typ)
    try:
        return issubclass(origin, list)  # type: ignore
    except TypeError:
        # The origin is not a class (e.g. `None` for a bare type).
        return typ is list
558
559
@cache
def is_bare_list(typ: type[Any]) -> bool:
    """
    Check whether the type is exactly the unparameterized `list`.

    >>> is_bare_list(list[int])
    False
    >>> is_bare_list(list)
    True
    """
    return list is typ
571
572
@cache
def is_tuple(typ: Any) -> bool:
    """
    Check whether the type is `tuple`, parameterized or bare.
    """
    origin = get_origin(typ)
    try:
        return issubclass(origin, tuple)  # type: ignore
    except TypeError:
        # The origin is not a class (e.g. `None` for a bare type).
        return typ is tuple
582
583
@cache
def is_bare_tuple(typ: type[Any]) -> bool:
    """
    Check whether the type is exactly the unparameterized `tuple`.

    >>> is_bare_tuple(tuple[int, str])
    False
    >>> is_bare_tuple(tuple)
    True
    """
    return tuple is typ
595
596
@cache
def is_variable_tuple(typ: type[Any]) -> bool:
    """
    Check whether the type is a variable-length tuple, i.e. `tuple[T, ...]`.

    >>> is_variable_tuple(tuple[int, ...])
    True
    >>> is_variable_tuple(tuple[int, bool])
    False
    >>> is_variable_tuple(tuple[()])
    False
    """
    # Only parameterized tuples can be variable-length.
    if not is_tuple(typ) or is_bare_tuple(typ):
        return False
    args = get_args(typ)
    return len(args) == 2 and is_ellipsis(args[1])
612
613
@cache
def is_set(typ: type[Any]) -> bool:
    """
    Check whether the type is `set` or `frozenset`, parameterized or bare.

    >>> is_set(set[int])
    True
    >>> is_set(set)
    True
    >>> is_set(frozenset[int])
    True
    """
    set_types = (set, frozenset)
    origin = get_origin(typ)
    try:
        return issubclass(origin, set_types)  # type: ignore
    except TypeError:
        # The origin is not a class (e.g. `None` for a bare type).
        return typ in set_types
630
631
@cache
def is_bare_set(typ: type[Any]) -> bool:
    """
    Check whether the type is the unparameterized `set` (or `frozenset`).

    >>> is_bare_set(set[int])
    False
    >>> is_bare_set(set)
    True
    """
    bare_set_types = (set, frozenset)
    return typ in bare_set_types
643
644
@cache
def is_frozen_set(typ: type[Any]) -> bool:
    """
    Check whether the type is `frozenset`, parameterized or bare.

    >>> is_frozen_set(frozenset[int])
    True
    >>> is_frozen_set(set)
    False
    """
    origin = get_origin(typ)
    try:
        return issubclass(origin, frozenset)  # type: ignore
    except TypeError:
        # The origin is not a class (e.g. `None` for a bare type).
        return typ is frozenset
659
660
@cache
def is_dict(typ: type[Any]) -> bool:
    """
    Check whether the type is `dict` or `defaultdict`, parameterized or bare.

    >>> is_dict(dict[int, int])
    True
    >>> is_dict(dict)
    True
    >>> is_dict(defaultdict[int, int])
    True
    """
    dict_types = (dict, defaultdict)
    origin = get_origin(typ)
    try:
        return issubclass(origin, dict_types)  # type: ignore
    except TypeError:
        # The origin is not a class (e.g. `None` for a bare type).
        return typ in dict_types
677
678
@cache
def is_bare_dict(typ: type[Any]) -> bool:
    """
    Check whether the type is exactly the unparameterized `dict`.

    >>> is_bare_dict(dict[int, str])
    False
    >>> is_bare_dict(dict)
    True
    """
    return dict is typ
690
691
@cache
def is_default_dict(typ: type[Any]) -> bool:
    """
    Check whether the type is `defaultdict`, parameterized or bare.

    >>> is_default_dict(defaultdict[int, int])
    True
    >>> is_default_dict(dict[int, int])
    False
    """
    origin = get_origin(typ)
    try:
        return issubclass(origin, defaultdict)  # type: ignore
    except TypeError:
        # The origin is not a class (e.g. `None` for a bare type).
        return typ is defaultdict
706
707
@cache
def is_none(typ: type[Any]) -> bool:
    """
    Check whether the type is `NoneType`. The value `None` itself does not count.

    >>> is_none(int)
    False
    >>> is_none(type(None))
    True
    >>> is_none(None)
    False
    """
    none_type = type(None)
    return typ is none_type
719
720
# Types treated as primitives by `is_primitive`, `is_primitive_subclass` and
# `is_new_type_primitive`.
PRIMITIVES = [int, float, bool, str]
722
723
724@cache
725def is_enum(typ: type[Any]) -> TypeGuard[enum.Enum]:
726    """
727    Test if the type is `enum.Enum`.
728    """
729    try:
730        return issubclass(typ, enum.Enum)
731    except TypeError:
732        return isinstance(typ, enum.Enum)
733
734
@cache
def is_primitive_subclass(typ: type[Any]) -> bool:
    """
    Check whether the type derives from a primitive type without being one of
    the primitives itself (and without being a NewType of a primitive).

    >>> is_primitive_subclass(str)
    False
    >>> class Str(str):
    ...     pass
    >>> is_primitive_subclass(Str)
    True
    """
    if not is_primitive(typ):
        return False
    if typ in PRIMITIVES:
        return False
    return not is_new_type_primitive(typ)
748
749
@cache
def is_primitive(typ: Union[type[Any], NewType]) -> bool:
    """
    Check whether the type is one of `PRIMITIVES` or a subclass of one.

    >>> is_primitive(int)
    True
    >>> class CustomInt(int):
    ...     pass
    >>> is_primitive(CustomInt)
    True
    """
    try:
        return issubclass(typ, tuple(PRIMITIVES))  # type: ignore
    except TypeError:
        # Not a class: it may still be a NewType wrapping a primitive.
        return is_new_type_primitive(typ)
766
767
@cache
def is_new_type_primitive(typ: Union[type[Any], NewType]) -> bool:
    """
    Check whether the argument is a NewType of a primitive.

    When the argument exposes a truthy `__supertype__`, that supertype is
    tested with `is_primitive`; otherwise the argument itself is tested as an
    instance of one of `PRIMITIVES`.
    """
    supertype = getattr(typ, "__supertype__", None)
    if supertype:
        return is_primitive(supertype)
    return isinstance(typ, tuple(PRIMITIVES))
778
779
@cache
def has_generic_base(typ: Any) -> bool:
    """
    Check whether `typing.Generic` appears in the `__mro__` or `__bases__`
    of the argument; a missing attribute counts as empty.
    """
    if Generic in getattr(typ, "__mro__", ()):
        return True
    return Generic in getattr(typ, "__bases__", ())
783
784
@cache
def is_generic(typ: Any) -> bool:
    """
    Check whether the type is a parameterized class derived from
    `typing.Generic`. The unparameterized class itself does not qualify.

    >>> T = typing.TypeVar('T')
    >>> class GenericFoo(typing.Generic[T]):
    ...     pass
    >>> is_generic(GenericFoo[int])
    True
    >>> is_generic(GenericFoo)
    False
    """
    origin = get_origin(typ)
    if origin is None:
        return False
    return has_generic_base(origin)
800
801
@cache
def is_class_var(typ: type[Any]) -> bool:
    """
    Check whether the type is `typing.ClassVar`, parameterized or bare.

    >>> is_class_var(ClassVar[int])
    True
    >>> is_class_var(ClassVar)
    True
    """
    if get_origin(typ) is ClassVar:
        return True
    return typ is ClassVar  # type: ignore
813
814
@cache
def is_literal(typ: type[Any]) -> bool:
    """
    Test if the type is a parameterized `typing.Literal`.

    >>> is_literal(typing.Literal["a", "b"])
    True
    >>> is_literal(typing.Literal[1])
    True
    >>> is_literal(str)
    False
    """
    # `get_origin` yields `typing.Literal` only for `Literal[...]` types; for
    # anything else the identity check below is simply false.
    return get_origin(typ) is typing.Literal
830
831
@cache
def is_any(typ: type[Any]) -> bool:
    """
    Check whether the type is exactly `typing.Any`.
    """
    return Any is typ  # type: ignore
838
839
@cache
def is_str_serializable(typ: type[Any]) -> bool:
    """
    Check whether the type is listed in `StrSerializableTypes` or is a plain
    class (its metaclass is exactly `type`) deriving from one of them.
    """
    if typ in StrSerializableTypes:
        return True
    return type(typ) is type and issubclass(typ, StrSerializableTypes)
848
849
def is_datetime(
    typ: type[Any],
) -> TypeGuard[Union[datetime.date, datetime.time, datetime.datetime]]:
    """
    Check whether the type is listed in `DateTimeTypes` or is a plain class
    (its metaclass is exactly `type`) deriving from one of them.
    """
    if typ in DateTimeTypes:
        return True
    return type(typ) is type and issubclass(typ, DateTimeTypes)
857
858
def is_str_serializable_instance(obj: Any) -> bool:
    """
    Check whether the object is an instance of one of `StrSerializableTypes`.
    """
    return isinstance(obj, StrSerializableTypes)
861
862
def is_datetime_instance(obj: Any) -> bool:
    """
    Check whether the object is an instance of one of `DateTimeTypes`.
    """
    return isinstance(obj, DateTimeTypes)
865
866
def is_ellipsis(typ: Any) -> bool:
    """
    Check whether the argument is the `Ellipsis` singleton (`...`).
    """
    return typ is ...
869
870
def is_pep695_type_alias(typ: Any) -> bool:
    """
    Check whether the argument is a PEP 695 type alias, i.e. an instance of
    `TypeAliasType`.
    """
    alias_cls = TypeAliasType
    return isinstance(typ, alias_cls)
876
877
@cache
def get_type_var_names(cls: type[Any]) -> Optional[list[str]]:
    """
    List the names of the type arguments found on the `__orig_bases__` of a
    generic class, or return `None` when the class has no `__orig_bases__`.

    >>> T = typing.TypeVar('T')
    >>> class GenericFoo(typing.Generic[T]):
    ...     pass
    >>> get_type_var_names(GenericFoo)
    ['T']
    >>> get_type_var_names(int)
    """
    bases = getattr(cls, "__orig_bases__", ())
    if not bases:
        return None

    names: list[str] = []
    for base in bases:
        for arg in get_args(base):
            # Arguments without a `__name__` are skipped.
            if hasattr(arg, "__name__"):
                names.append(arg.__name__)
    return names
899
900
def find_generic_arg(cls: type[Any], field: TypeVar) -> int:
    """
    Find a type in generic parameters.

    The type arguments of every entry in `cls.__orig_bases__` are compared by
    name with `field`. Returns the position of the first match within its
    base, or -1 when nothing matches. Type arguments that have no `__name__`
    (e.g. `...`) never match.

    Raises `Exception` when `cls` has no `__orig_bases__`.

    >>> T = typing.TypeVar('T')
    >>> U = typing.TypeVar('U')
    >>> V = typing.TypeVar('V')
    >>> class GenericFoo(typing.Generic[T, U]):
    ...     pass
    >>> find_generic_arg(GenericFoo, T)
    0
    >>> find_generic_arg(GenericFoo, U)
    1
    >>> find_generic_arg(GenericFoo, V)
    -1
    """
    bases = getattr(cls, "__orig_bases__", ())
    if not bases:
        raise Exception(f'"__orig_bases__" property was not found: {cls}')

    target = field.__name__
    for base in bases:
        for n, arg in enumerate(get_args(base)):
            # Not every type argument carries a name; treat those as non-matching
            # instead of failing with AttributeError.
            if getattr(arg, "__name__", None) == target:
                return n

    return -1
930
931
def get_generic_arg(
    typ: Any,
    maybe_generic_type_vars: Optional[list[str]],
    variable_type_args: Optional[list[str]],
    index: int,
) -> Any:
    """
    Look up the concrete type argument of a parameterized generic.

    `variable_type_args[index]` names a type variable; its position in
    `maybe_generic_type_vars` selects the argument returned from `typ`.
    `typing.Any` is returned when `typ` is not generic, when either list is
    `None`, or when the name is not among `maybe_generic_type_vars`.

    Raises `SerdeError` when the number of type arguments of `typ` differs
    from the number of generic type variables.

    >>> T = typing.TypeVar('T')
    >>> U = typing.TypeVar('U')
    >>> class GenericFoo(typing.Generic[T, U]):
    ...     pass
    >>> get_generic_arg(GenericFoo[int, str], ['T', 'U'], ['T', 'U'], 0).__name__
    'int'
    >>> get_generic_arg(GenericFoo[int, str], ['T', 'U'], ['T', 'U'], 1).__name__
    'str'
    >>> get_generic_arg(GenericFoo[int, str], ['T', 'U'], ['U'], 0).__name__
    'str'
    """
    not_applicable = (
        not is_generic(typ) or maybe_generic_type_vars is None or variable_type_args is None
    )
    if not_applicable:
        return typing.Any

    args = get_args(typ)

    if len(args) != len(maybe_generic_type_vars):
        raise SerdeError(
            f"Number of type args for {typ} does not match number of generic type vars: "
            f"\n  type args: {args}\n  type_vars: {maybe_generic_type_vars}"
        )

    # Name of the type var used for this field in the parent class definition.
    wanted = variable_type_args[index]

    # Without a slot for that type var in the original generic class
    # definition there is nothing concrete to return.
    if wanted not in maybe_generic_type_vars:
        return typing.Any

    return args[maybe_generic_type_vars.index(wanted)]