Edit on GitHub

serde.compat

Compatibility layer that mainly handles differences in the `typing` module between Python versions.

  1"""
  2Compatibility layer which handles mostly differences of `typing` module between python versions.
  3"""
  4
  5import dataclasses
  6import datetime
  7import decimal
  8import enum
  9import functools
 10import ipaddress
 11import itertools
 12import pathlib
 13import sys
 14import types
 15import uuid
 16import typing
 17from collections import defaultdict
 18from collections.abc import Iterator
 19from dataclasses import is_dataclass
 20from typing import TypeVar, Generic, Any, ClassVar, Optional, NewType, Union, Hashable, Callable
 21
 22import typing_inspect
 23from typing_extensions import TypeGuard, ParamSpec
 24
 25from .sqlalchemy import is_sqlalchemy_inspectable
 26
 27
 28def get_np_origin(tp: type[Any]) -> Optional[Any]:
 29    return None
 30
 31
 32def get_np_args(tp: type[Any]) -> tuple[Any, ...]:
 33    return ()
 34
 35
# Empty, so `from serde.compat import *` exports no names.
__all__: list[str] = []

# Type variable shared by the generic helpers and classes in this module.
T = TypeVar("T")


# Standard-library types converted to/from `str`; consulted by
# `is_str_serializable` and `is_str_serializable_instance`.
StrSerializableTypes = (
    decimal.Decimal,
    pathlib.Path,
    pathlib.PosixPath,
    pathlib.WindowsPath,
    pathlib.PurePath,
    pathlib.PurePosixPath,
    pathlib.PureWindowsPath,
    uuid.UUID,
    ipaddress.IPv4Address,
    ipaddress.IPv6Address,
    ipaddress.IPv4Network,
    ipaddress.IPv6Network,
    ipaddress.IPv4Interface,
    ipaddress.IPv6Interface,
)
""" List of standard types (de)serializable to str """

# Date/time types; consulted by `is_datetime` and `is_datetime_instance`.
DateTimeTypes = (datetime.date, datetime.time, datetime.datetime)
""" List of datetime types """
 61
 62
@dataclasses.dataclass(unsafe_hash=True)
class _WithTagging(Generic[T]):
    """
    Intermediate data structure for (de)serializing a Union that is not a dataclass field.

    ``unsafe_hash=True`` makes instances hashable (as long as ``inner`` and
    ``tagging`` are), which the ``@cache``-decorated helpers in this module,
    such as ``is_union``, require of their arguments.
    """

    inner: T
    """ Union type, e.g. Union[Foo, Bar], passed in from_obj. """
    tagging: Any
    """ Union tagging to apply. """
 73
 74
 75class SerdeError(Exception):
 76    """
 77    Serde error class.
 78    """
 79
 80
 81@dataclasses.dataclass
 82class UserError(Exception):
 83    """
 84    Error from user code e.g. __post_init__.
 85    """
 86
 87    inner: Exception
 88
 89
 90class SerdeSkip(Exception):
 91    """
 92    Skip a field in custom (de)serializer.
 93    """
 94
 95
 96def is_hashable(typ: Any) -> TypeGuard[Hashable]:
 97    """
 98    Test is an object hashable
 99    """
100    try:
101        hash(typ)
102    except TypeError:
103        return False
104    return True
105
106
P = ParamSpec("P")


def cache(f: Callable[P, T]) -> Callable[P, T]:
    """
    Wrapper for `functools.cache` to avoid `Hashable` related type errors.

    The inner wrapper is decorated with `functools.wraps`, so the cached
    function keeps the ``__name__``, ``__doc__`` (and thereby its doctests)
    and ``__wrapped__`` of ``f``. Without it every cached function was
    reported as ``wrapper`` with no docstring.
    """

    @functools.wraps(f)
    def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
        return f(*args, **kwargs)

    return functools.cache(wrapper)  # type: ignore
119
120
@cache
def get_origin(typ: Any) -> Optional[Any]:
    """
    Return the origin of ``typ`` (e.g. ``list`` for ``list[int]``).

    Uses ``typing.get_origin`` first and falls back to ``get_np_origin``
    when that yields nothing.
    """
    origin = typing.get_origin(typ)
    if origin:
        return origin
    return get_np_origin(typ)
127
128
@cache
def get_args(typ: type[Any]) -> tuple[Any, ...]:
    """
    Return the type arguments of ``typ`` (e.g. ``(int,)`` for ``list[int]``).

    Uses ``typing.get_args`` first and falls back to ``get_np_args`` when
    that yields an empty tuple.
    """
    found = typing.get_args(typ)
    if found:
        return found
    return get_np_args(typ)
135
136
@cache
def typename(typ: Any, with_typing_module: bool = False) -> str:
    """
    Return a readable name for ``typ``, rendering type arguments recursively.

    :param typ: The type to name.
    :param with_typing_module: When True, ``typing.`` is prefixed to the rendered
        ``Optional``, ``Union``, ``Literal``, ``Any`` and container (``list``,
        ``set``, ``dict``, ``tuple``) names. The flag is passed on to every
        nested type argument and to the supertype of a ``NewType``.
    :raises SerdeError: If a generic class has no usable origin/name, or no name
        can be determined for ``typ``.
    :raises TypeError: If ``typ`` is a ``Literal`` without arguments.

    >>> from typing import Any
    >>> typename(int)
    'int'
    >>> class Foo: pass
    >>> typename(Foo)
    'Foo'
    >>> typename(list[Foo])
    'list[Foo]'
    >>> typename(dict[str, Foo])
    'dict[str, Foo]'
    >>> typename(tuple[int, str, Foo, list[int], dict[str, Foo]])
    'tuple[int, str, Foo, list[int], dict[str, Foo]]'
    >>> typename(Optional[list[Foo]])
    'Optional[list[Foo]]'
    >>> typename(Union[Optional[Foo], list[Foo], Union[str, int]])
    'Union[Optional[Foo], list[Foo], str, int]'
    >>> typename(set[Foo])
    'set[Foo]'
    >>> typename(Any)
    'Any'
    >>> typename(Optional[int], with_typing_module=True)
    'typing.Optional[int]'
    >>> typename(typing.Literal['a', 1], with_typing_module=True)
    "typing.Literal['a', 1]"
    >>> UserId = NewType('UserId', Optional[int])
    >>> typename(UserId, with_typing_module=True)
    'typing.Optional[int]'
    """
    mod = "typing." if with_typing_module else ""
    # Recurse into type arguments with the same prefix setting.
    thisfunc = functools.partial(typename, with_typing_module=with_typing_module)
    if is_opt(typ):
        args = type_args(typ)
        if args:
            return f"{mod}Optional[{thisfunc(args[0])}]"
        else:
            return f"{mod}Optional"
    elif is_union(typ):
        args = union_args(typ)
        if args:
            return f'{mod}Union[{", ".join([thisfunc(e) for e in args])}]'
        else:
            return f"{mod}Union"
    elif is_list(typ):
        args = type_args(typ)
        if args:
            et = thisfunc(args[0])
            return f"{mod}list[{et}]"
        else:
            return f"{mod}list"
    elif is_set(typ):
        args = type_args(typ)
        if args:
            et = thisfunc(args[0])
            return f"{mod}set[{et}]"
        else:
            return f"{mod}set"
    elif is_dict(typ):
        args = type_args(typ)
        if args and len(args) == 2:
            kt = thisfunc(args[0])
            vt = thisfunc(args[1])
            return f"{mod}dict[{kt}, {vt}]"
        else:
            return f"{mod}dict"
    elif is_tuple(typ):
        args = type_args(typ)
        if args:
            return f'{mod}tuple[{", ".join([thisfunc(e) for e in args])}]'
        else:
            return f"{mod}tuple"
    elif is_generic(typ):
        # Only the generic class name is rendered; its type arguments are dropped.
        origin = get_origin(typ)
        if origin is None:
            raise SerdeError("Could not extract origin class from generic class")

        if not isinstance(origin.__name__, str):
            raise SerdeError("Name of generic class is not string")

        return origin.__name__

    elif is_literal(typ):
        args = type_args(typ)
        if not args:
            raise TypeError("Literal type requires at least one literal argument")
        # Prefixed like every other typing construct rendered above.
        return f'{mod}Literal[{", ".join(stringify_literal(e) for e in args)}]'
    elif typ is Any:
        return f"{mod}Any"
    elif is_ellipsis(typ):
        return "..."
    else:
        # A NewType is named after its supertype; keep honoring the prefix flag
        # (the previous plain `typename(inner)` call silently dropped it).
        inner = getattr(typ, "__supertype__", None)
        if inner:
            return thisfunc(inner)

        # `_name` is checked first (it is set on some typing aliases), then `__name__`.
        name: Optional[str] = getattr(typ, "_name", None)
        if name:
            return name
        else:
            name = getattr(typ, "__name__", None)
            if isinstance(name, str):
                return name
            else:
                raise SerdeError(f"Could not get a type name from: {typ}")
237
238
def stringify_literal(v: Any) -> str:
    """
    Render one ``Literal`` argument: strings are wrapped in single quotes,
    any other value is rendered with ``str``.
    """
    return f"'{v}'" if isinstance(v, str) else str(v)
244
245
def type_args(typ: Any) -> tuple[type[Any], ...]:
    """
    Return the type arguments of ``typ``.

    Reads the private ``typ.__args__`` directly (this wrapper exists to keep the
    type checker quiet about that) and falls back to ``get_args`` when the
    attribute does not exist. An ``__args__`` of ``None`` becomes ``()``.
    """
    # Only the attribute access can raise AttributeError; keep the try that narrow.
    try:
        args: tuple[type[Any], ...] = typ.__args__  # type: ignore
    except AttributeError:
        return get_args(typ)
    if args is None:
        return ()
    return args
258
259
def union_args(typ: Any) -> tuple[type[Any], ...]:
    """
    Return the members of a Union, folding ``X, None`` pairs back into ``Optional[X]``.

    Members are consumed two at a time in declaration order. A pair whose second
    element is ``NoneType`` is emitted as ``Optional[first]``; any other pair is
    emitted unchanged. An unpaired trailing member is kept as is.

    :raises TypeError: If ``typ`` is not a Union.
    """
    if not is_union(typ):
        raise TypeError(f"{typ} is not Union")
    args = type_args(typ)
    if len(args) == 1:
        return (args[0],)
    it = iter(args)
    # Not named `types`, which would shadow the `types` module used elsewhere in this file.
    members: list[Any] = []
    # zip_longest(it, it) draws twice from one iterator, yielding consecutive pairs;
    # an odd trailing member is paired with the fill value None.
    for first, second in itertools.zip_longest(it, it):
        # Compare with `is None`: a member may itself be falsy (e.g. a class whose
        # metaclass defines `__len__`, such as an Enum without members) and must
        # not be mistaken for the zip_longest filler.
        if second is None:
            members.append(first)
        elif is_none(second):
            members.append(Optional[first])
        else:
            members.extend((first, second))
    return tuple(members)
276
277
def dataclass_fields(cls: type[Any]) -> Iterator[dataclasses.Field]:  # type: ignore
    """
    Return an iterator over the fields of dataclass ``cls`` with resolved types.

    ``typing.get_type_hints`` turns string annotations (forward references, or
    every annotation under PEP 563 ``from __future__ import annotations``) into
    real types. When ``cls`` is SQLAlchemy-inspectable, a generic field type is
    replaced by its first type argument.

    The ``Field`` objects come straight from ``dataclasses.fields`` and are
    updated in place, so the resolved types are also written to the dataclass
    field metadata of ``cls`` itself.

    :raises SerdeError: If the type hints of ``cls`` cannot be resolved.
    """
    fields_of_cls = dataclasses.fields(cls)

    try:
        hints = typing.get_type_hints(cls)
    except Exception as e:
        raise SerdeError(
            f"Failed to resolve type hints for {typename(cls)}:\n"
            f"{e.__class__.__name__}: {e}\n\n"
            f"If you are using forward references make sure you are calling deserialize & "
            "serialize after all classes are globally visible."
        ) from e

    for fld in fields_of_cls:
        resolved = hints.get(fld.name)
        if resolved is None:
            continue
        fld.type = resolved
        if is_generic(resolved) and is_sqlalchemy_inspectable(cls):
            # Presumably unwraps SQLAlchemy's `Mapped[X]` to `X` — confirm.
            fld.type = get_args(resolved)[0]

    return iter(fields_of_cls)
301
302
# Annotation alias for "anything usable as a type": a class or any typing construct.
TypeLike = Union[type[Any], typing.Any]
304
305
def iter_types(cls: type[Any]) -> list[type[Any]]:
    """
    Collect the types reachable from ``cls``.

    Dataclass fields and the arguments of ``Optional``, ``Union``, ``list``,
    ``set``, ``tuple`` and ``dict`` are followed recursively.

    - For those containers and typing constructs only a marker is recorded
      (``Optional``, ``Union``, ``list``, ``set``, ``tuple`` or ``dict``), not the
      parameterized type.
    - Dataclasses, string forward references and every other type are recorded
      as they are.
    - A dataclass is recorded before its fields are walked, so self-referencing
      dataclasses do not recurse forever.

    The items are returned as a list in no particular order. The items may be
    classes or typing special forms, hence the loose ``Any`` in the annotations.
    """
    visited: set[Union[type[Any], Any]] = set()

    def visit(tp: Union[type[Any], Any]) -> None:
        if tp in visited:
            return

        if is_dataclass(tp):
            visited.add(tp)
            for fld in dataclass_fields(tp):
                visit(fld.type)
        elif isinstance(tp, str):
            visited.add(tp)
        elif is_opt(tp):
            visited.add(Optional)
            opt_args = type_args(tp)
            if opt_args:
                visit(opt_args[0])
        elif is_union(tp):
            visited.add(Union)
            for member in type_args(tp):
                visit(member)
        elif is_list(tp):
            visited.add(list)
            elem_args = type_args(tp)
            if elem_args:
                visit(elem_args[0])
        elif is_set(tp):
            visited.add(set)
            elem_args = type_args(tp)
            if elem_args:
                visit(elem_args[0])
        elif is_tuple(tp):
            visited.add(tuple)
            for member in type_args(tp):
                visit(member)
        elif is_dict(tp):
            visited.add(dict)
            kv_args = type_args(tp)
            if kv_args and len(kv_args) >= 2:
                visit(kv_args[0])
                visit(kv_args[1])
        else:
            visited.add(tp)

    visit(cls)
    return list(visited)
359
360
def iter_unions(cls: TypeLike) -> list[TypeLike]:
    """
    Iterate over all unions that are used in the dataclass.

    Walks dataclass fields and the arguments of Union, Optional, list, set,
    tuple and dict recursively, and returns every Union in traversal order.
    ``Optional[X]`` is a Union itself and is reported as one. A union used in
    several places is reported once per place.
    """
    lst: list[TypeLike] = []
    # Dataclasses currently being walked; prevents infinite recursion on recursive dataclasses.
    stack: list[TypeLike] = []

    def recursive(cls: TypeLike) -> None:
        if cls in stack:
            return

        union = is_union(cls)
        if union:
            lst.append(cls)
            for arg in type_args(cls):
                recursive(arg)
        if is_dataclass(cls):
            stack.append(cls)
            for f in dataclass_fields(cls):
                recursive(f.type)
            stack.pop()
        elif is_opt(cls):
            # `Optional[X]` is also a Union, so its arguments were walked just above.
            # Walking `X` again reported every union inside it twice and doubled the
            # work for each level of nested Optional fields.
            if not union:
                args = type_args(cls)
                if args:
                    recursive(args[0])
        elif is_list(cls) or is_set(cls):
            args = type_args(cls)
            if args:
                recursive(args[0])
        elif is_tuple(cls):
            for arg in type_args(cls):
                recursive(arg)
        elif is_dict(cls):
            args = type_args(cls)
            if args and len(args) >= 2:
                recursive(args[0])
                recursive(args[1])

    recursive(cls)
    return lst
400
401
def iter_literals(cls: type[Any]) -> list[TypeLike]:
    """
    Iterate over all literals that are used in the dataclass.

    Walks dataclass fields and the arguments of Union, Optional, list, set,
    tuple and dict recursively, and returns each distinct ``Literal`` type found,
    in no particular order. Only ``Literal`` types are returned; dataclasses
    passed through on the way are no longer included.
    """
    literals: set[Union[type[Any], Any]] = set()
    # Dataclasses already walked. Kept apart from `literals` so they are not
    # reported as literals, while still stopping infinite recursion.
    visited: set[Union[type[Any], Any]] = set()

    def recursive(cls: Union[type[Any], Any]) -> None:
        if cls in literals or cls in visited:
            return

        if is_literal(cls):
            literals.add(cls)
        union = is_union(cls)
        if union:
            for arg in type_args(cls):
                recursive(arg)
        if is_dataclass(cls):
            visited.add(cls)
            for f in dataclass_fields(cls):
                recursive(f.type)
        elif is_opt(cls):
            # `Optional[X]` is also a Union whose arguments were walked just above;
            # walking `X` again only doubled the work per nested Optional.
            if not union:
                args = type_args(cls)
                if args:
                    recursive(args[0])
        elif is_list(cls) or is_set(cls):
            args = type_args(cls)
            if args:
                recursive(args[0])
        elif is_tuple(cls):
            for arg in type_args(cls):
                recursive(arg)
        elif is_dict(cls):
            args = type_args(cls)
            if args and len(args) >= 2:
                recursive(args[0])
                recursive(args[1])

    recursive(cls)
    return list(literals)
440
441
@cache
def is_union(typ: Any) -> bool:
    """
    Test if the type is a Union.

    Recognizes ``typing.Union[...]`` (including ``Optional[...]``), the
    python 3.10+ ``X | Y`` operator, and a ``_WithTagging`` wrapper whose
    ``inner`` type is a Union.

    >>> is_union(Union[int, str])
    True
    >>> is_union(int)
    False
    """

    try:
        # When `_WithTagging` is received, it will check inner type.
        # Any exception raised here (including from the recursive call) is
        # swallowed and the remaining checks are applied to `typ` itself.
        if isinstance(typ, _WithTagging):
            return is_union(typ.inner)
    except Exception:
        pass

    # Python 3.10 Union operator e.g. str | int
    if sys.version_info[:2] >= (3, 10):
        try:
            if isinstance(typ, types.UnionType):
                return True
        except Exception:
            pass

    # typing.Union
    return typing_inspect.is_union_type(typ)  # type: ignore
468
469
@cache
def is_opt(typ: Any) -> bool:
    """
    Test if the type is `typing.Optional`.

    Matches a Union with exactly two members whose second member is ``None``
    and whose first is not: ``Optional[X]``, ``Union[X, None]`` and (python
    3.10+) ``X | None``. The bare ``typing.Optional`` also matches.

    >>> is_opt(Optional[int])
    True
    >>> is_opt(Optional)
    True
    >>> is_opt(None.__class__)
    False
    >>> is_opt(Union[None, int])
    False
    >>> is_opt(Union[int, str, None])
    False
    """
    # Python 3.10 Union operator e.g. str | None
    union_operator = False
    if sys.version_info[:2] >= (3, 10):
        try:
            union_operator = isinstance(typ, types.UnionType)
        except Exception:
            pass

    # typing.Optional
    typing_optional = typing_inspect.is_optional_type(typ)

    args = type_args(typ)
    if not args:
        return typ is Optional
    if not (union_operator or typing_optional) or len(args) != 2:
        return False
    return not is_none(args[0]) and is_none(args[1])
505
506
@cache
def is_bare_opt(typ: Any) -> bool:
    """
    Test if the type is ``typing.Optional`` used without type arguments.

    >>> is_bare_opt(Optional[int])
    False
    >>> is_bare_opt(Optional)
    True
    >>> is_bare_opt(None.__class__)
    False
    """
    if type_args(typ):
        return False
    return typ is Optional
519
520
@cache
def is_opt_dataclass(typ: Any) -> bool:
    """
    Test if the type is an ``Optional`` of a dataclass, e.g. ``Optional[Foo]``.

    The previous doctest claimed ``Optional[Foo]`` yields False, which
    contradicts the implementation; the expectation below is the real result.

    >>> is_opt_dataclass(Optional[int])
    False
    >>> @dataclasses.dataclass
    ... class Foo:
    ...     pass
    >>> is_opt_dataclass(Foo)
    False
    >>> is_opt_dataclass(Optional[Foo])
    True
    """
    args = get_args(typ)
    # `is_opt` guarantees two args when true, so `args[0]` is the wrapped type.
    return is_opt(typ) and len(args) > 0 and is_dataclass(args[0])
538
539
@cache
def is_list(typ: type[Any]) -> bool:
    """
    Test if the type is ``list`` or a parameterized list type.

    Parameterized list types include ``list[int]``, ``typing.List[int]`` or a
    parameterized list subclass.

    >>> is_list(list[int])
    True
    >>> is_list(list)
    True
    >>> is_list(tuple[int])
    False
    """
    try:
        origin = get_origin(typ)
        return issubclass(origin, list)  # type: ignore
    except TypeError:
        # No class origin (e.g. a bare class): only `list` itself qualifies.
        return typ is list
554
555
@cache
def is_bare_list(typ: type[Any]) -> bool:
    """
    Test if the type is the builtin ``list`` with no type arguments.

    >>> is_bare_list(list[int])
    False
    >>> is_bare_list(list)
    True
    """
    return list is typ
567
568
@cache
def is_tuple(typ: Any) -> bool:
    """
    Test if the type is ``tuple`` or a parameterized tuple such as ``tuple[int, str]``.

    >>> is_tuple(tuple[int, str])
    True
    >>> is_tuple(tuple)
    True
    """
    try:
        origin = get_origin(typ)
        return issubclass(origin, tuple)  # type: ignore
    except TypeError:
        # No class origin: only `tuple` itself qualifies.
        return typ is tuple
578
579
@cache
def is_bare_tuple(typ: type[Any]) -> bool:
    """
    Test if the type is the builtin ``tuple`` with no type arguments.

    >>> is_bare_tuple(tuple[int, str])
    False
    >>> is_bare_tuple(tuple)
    True
    """
    return tuple is typ
591
592
@cache
def is_variable_tuple(typ: type[Any]) -> bool:
    """
    Test if the type is a variable-length tuple, i.e. ``tuple[T, ...]``.

    >>> is_variable_tuple(tuple[int, ...])
    True
    >>> is_variable_tuple(tuple[int, bool])
    False
    >>> is_variable_tuple(tuple[()])
    False
    """
    if not is_tuple(typ) or is_bare_tuple(typ):
        return False
    targs = get_args(typ)
    # Exactly one element type followed by a literal `...`.
    return len(targs) == 2 and is_ellipsis(targs[1])
608
609
@cache
def is_set(typ: type[Any]) -> bool:
    """
    Test if the type is ``set`` or ``frozenset``, bare or parameterized.

    >>> is_set(set[int])
    True
    >>> is_set(set)
    True
    >>> is_set(frozenset[int])
    True
    """
    set_types = (set, frozenset)
    try:
        return issubclass(get_origin(typ), set_types)  # type: ignore
    except TypeError:
        return typ in set_types
626
627
@cache
def is_bare_set(typ: type[Any]) -> bool:
    """
    Test if the type is ``set`` or ``frozenset`` with no type arguments.

    >>> is_bare_set(set[int])
    False
    >>> is_bare_set(set)
    True
    >>> is_bare_set(frozenset)
    True
    """
    bare_set_types = (set, frozenset)
    return typ in bare_set_types
639
640
@cache
def is_frozen_set(typ: type[Any]) -> bool:
    """
    Test if the type is ``frozenset``, bare or parameterized.

    >>> is_frozen_set(frozenset[int])
    True
    >>> is_frozen_set(set)
    False
    """
    try:
        origin = get_origin(typ)
        return issubclass(origin, frozenset)  # type: ignore
    except TypeError:
        return typ is frozenset
655
656
@cache
def is_dict(typ: type[Any]) -> bool:
    """
    Test if the type is ``dict`` or ``defaultdict``, bare or parameterized.

    >>> is_dict(dict[int, int])
    True
    >>> is_dict(dict)
    True
    >>> is_dict(defaultdict[int, int])
    True
    """
    dict_types = (dict, defaultdict)
    try:
        return issubclass(get_origin(typ), dict_types)  # type: ignore
    except TypeError:
        return typ in dict_types
673
674
@cache
def is_bare_dict(typ: type[Any]) -> bool:
    """
    Test if the type is the builtin ``dict`` with no type arguments.

    >>> is_bare_dict(dict[int, str])
    False
    >>> is_bare_dict(dict)
    True
    """
    return dict is typ
686
687
@cache
def is_default_dict(typ: type[Any]) -> bool:
    """
    Test if the type is ``defaultdict``, bare or parameterized.

    >>> is_default_dict(defaultdict[int, int])
    True
    >>> is_default_dict(dict[int, int])
    False
    """
    try:
        origin = get_origin(typ)
        return issubclass(origin, defaultdict)  # type: ignore
    except TypeError:
        return typ is defaultdict
702
703
@cache
def is_none(typ: type[Any]) -> bool:
    """
    Test if the type is ``NoneType``. The value ``None`` itself is not matched.

    >>> is_none(int)
    False
    >>> is_none(type(None))
    True
    >>> is_none(None)
    False
    """
    none_type = type(None)
    return typ is none_type
715
716
# Builtin types treated as primitives; `is_primitive` also accepts their subclasses.
PRIMITIVES = [int, float, bool, str]
718
719
720@cache
721def is_enum(typ: type[Any]) -> TypeGuard[enum.Enum]:
722    """
723    Test if the type is `enum.Enum`.
724    """
725    try:
726        return issubclass(typ, enum.Enum)
727    except TypeError:
728        return isinstance(typ, enum.Enum)
729
730
@cache
def is_primitive_subclass(typ: type[Any]) -> bool:
    """
    Test if the type derives from a primitive without being one of
    `PRIMITIVES` itself or a NewType of a primitive.

    >>> is_primitive_subclass(str)
    False
    >>> class Str(str):
    ...     pass
    >>> is_primitive_subclass(Str)
    True
    """
    if not is_primitive(typ):
        return False
    if typ in PRIMITIVES:
        return False
    return not is_new_type_primitive(typ)
744
745
@cache
def is_primitive(typ: Union[type[Any], NewType]) -> bool:
    """
    Test if the type is a primitive: a subclass of one of `PRIMITIVES`, or,
    when ``typ`` is not a class, whatever `is_new_type_primitive` decides.

    >>> is_primitive(int)
    True
    >>> class CustomInt(int):
    ...     pass
    >>> is_primitive(CustomInt)
    True
    """
    try:
        for prim in PRIMITIVES:
            if issubclass(typ, prim):  # type: ignore
                return True
        return False
    except TypeError:
        # Not a class (e.g. a NewType): defer to the NewType-aware check.
        return is_new_type_primitive(typ)
762
763
@cache
def is_new_type_primitive(typ: Union[type[Any], NewType]) -> bool:
    """
    Test if the type is a NewType whose supertype is primitive.

    Objects without a ``__supertype__`` are instead checked for being
    instances of one of `PRIMITIVES`.
    """
    supertype = getattr(typ, "__supertype__", None)
    if not supertype:
        return any(isinstance(typ, prim) for prim in PRIMITIVES)
    return is_primitive(supertype)
774
775
@cache
def has_generic_base(typ: Any) -> bool:
    """
    Test if ``typing.Generic`` appears in the MRO or the direct bases of ``typ``.
    """
    if Generic in getattr(typ, "__mro__", ()):
        return True
    return Generic in getattr(typ, "__bases__", ())
779
780
@cache
def is_generic(typ: Any) -> bool:
    """
    Test if the type is a parameterized class derived from `typing.Generic`.

    The unparameterized generic class itself has no origin and is not matched.

    >>> T = typing.TypeVar('T')
    >>> class GenericFoo(typing.Generic[T]):
    ...     pass
    >>> is_generic(GenericFoo[int])
    True
    >>> is_generic(GenericFoo)
    False
    """
    origin = get_origin(typ)
    if origin is None:
        return False
    return has_generic_base(origin)
796
797
@cache
def is_class_var(typ: type[Any]) -> bool:
    """
    Test if the type is `typing.ClassVar`, bare or parameterized.

    >>> is_class_var(ClassVar[int])
    True
    >>> is_class_var(ClassVar)
    True
    """
    if get_origin(typ) is ClassVar:  # type: ignore
        return True
    return typ is ClassVar
809
810
@cache
def is_literal(typ: type[Any]) -> bool:
    """
    Test if the type is a parameterized `typing.Literal`.

    >>> is_literal(typing.Literal['a', 1])
    True
    >>> is_literal(str)
    False
    """
    return get_origin(typ) is typing.Literal
826
827
@cache
def is_any(typ: type[Any]) -> bool:
    """
    Test if the type is exactly `typing.Any`.
    """
    return Any is typ  # type: ignore
834
835
@cache
def is_str_serializable(typ: type[Any]) -> bool:
    """
    Test if the type is serializable to `str`.

    Matches one of `StrSerializableTypes`, or a class created by the plain
    ``type`` metaclass that derives from one of them.
    """
    if typ in StrSerializableTypes:
        return True
    return type(typ) is type and issubclass(typ, StrSerializableTypes)
844
845
def is_datetime(
    typ: type[Any],
) -> TypeGuard[Union[datetime.date, datetime.time, datetime.datetime]]:
    """
    Test if the type is one of `DateTimeTypes`, or a class created by the
    plain ``type`` metaclass that derives from one of them.
    """
    if typ in DateTimeTypes:
        return True
    return type(typ) is type and issubclass(typ, DateTimeTypes)
853
854
def is_str_serializable_instance(obj: Any) -> bool:
    """
    Test if ``obj`` is an instance of one of `StrSerializableTypes`.
    """
    serializable_types = StrSerializableTypes
    return isinstance(obj, serializable_types)
857
858
def is_datetime_instance(obj: Any) -> bool:
    """
    Test if ``obj`` is an instance of one of `DateTimeTypes`.
    """
    datetime_types = DateTimeTypes
    return isinstance(obj, datetime_types)
861
862
def is_ellipsis(typ: Any) -> bool:
    """
    Test if ``typ`` is the ``...`` singleton (as used in ``tuple[int, ...]``).
    """
    return typ is ...
865
866
@cache
def get_type_var_names(cls: type[Any]) -> Optional[list[str]]:
    """
    Get the names of the type arguments declared in the generic bases of ``cls``.

    Every argument of every entry of ``__orig_bases__`` that has a ``__name__``
    is included, in order. Returns ``None`` when ``cls`` has no
    ``__orig_bases__``. The result is cached, so callers should not mutate it.

    >>> T = typing.TypeVar('T')
    >>> class GenericFoo(typing.Generic[T]):
    ...     pass
    >>> get_type_var_names(GenericFoo)
    ['T']
    >>> get_type_var_names(int)
    """
    orig_bases = getattr(cls, "__orig_bases__", ())
    if not orig_bases:
        return None
    return [
        arg.__name__
        for base in orig_bases
        for arg in get_args(base)
        if hasattr(arg, "__name__")
    ]
888
889
def find_generic_arg(cls: type[Any], field: TypeVar) -> int:
    """
    Find the position of a type variable among the generic parameters of ``cls``.

    The entries of ``cls.__orig_bases__`` are scanned in order. The position of
    the first type argument whose ``__name__`` equals ``field.__name__`` is
    returned; the position is relative to the base it was found in. Type
    arguments without a ``__name__`` (e.g. the ``1`` in ``list[1]``) are skipped
    instead of raising ``AttributeError``.

    :returns: The index, or ``-1`` when no base mentions ``field``.
    :raises Exception: If ``cls`` has no ``__orig_bases__``.

    >>> T = typing.TypeVar('T')
    >>> U = typing.TypeVar('U')
    >>> V = typing.TypeVar('V')
    >>> class GenericFoo(typing.Generic[T, U]):
    ...     pass
    >>> find_generic_arg(GenericFoo, T)
    0
    >>> find_generic_arg(GenericFoo, U)
    1
    >>> find_generic_arg(GenericFoo, V)
    -1
    """
    bases = getattr(cls, "__orig_bases__", ())
    if not bases:
        raise Exception(f'"__orig_bases__" property was not found: {cls}')

    for base in bases:
        for n, arg in enumerate(get_args(base)):
            # Not every type argument is a named type; unnamed ones never match.
            if getattr(arg, "__name__", None) == field.__name__:
                return n

    # The former `if not bases: raise ...` here was unreachable (checked above).
    return -1
919
920
def get_generic_arg(
    typ: Any,
    maybe_generic_type_vars: Optional[list[str]],
    variable_type_args: Optional[list[str]],
    index: int,
) -> Any:
    """
    Get the concrete type bound to a field's type variable in a parameterized generic.

    :param typ: The parameterized generic, e.g. ``GenericFoo[int, str]``.
    :param maybe_generic_type_vars: Type-variable names of the generic class,
        in declaration order.
    :param variable_type_args: Type-variable names used by the field in the
        parent class definition.
    :param index: Which entry of ``variable_type_args`` to resolve.
    :returns: The matching type argument, or ``typing.Any`` when ``typ`` is not
        generic, a name list is missing, or the type variable is unknown.
    :raises SerdeError: If the number of type arguments differs from the
        number of type variables.

    >>> T = typing.TypeVar('T')
    >>> U = typing.TypeVar('U')
    >>> class GenericFoo(typing.Generic[T, U]):
    ...     pass
    >>> get_generic_arg(GenericFoo[int, str], ['T', 'U'], ['T', 'U'], 0).__name__
    'int'
    >>> get_generic_arg(GenericFoo[int, str], ['T', 'U'], ['T', 'U'], 1).__name__
    'str'
    >>> get_generic_arg(GenericFoo[int, str], ['T', 'U'], ['U'], 0).__name__
    'str'
    """
    if not is_generic(typ) or maybe_generic_type_vars is None or variable_type_args is None:
        return typing.Any

    args = get_args(typ)
    if len(args) != len(maybe_generic_type_vars):
        raise SerdeError(
            f"Number of type args for {typ} does not match number of generic type vars: "
            f"\n  type args: {args}\n  type_vars: {maybe_generic_type_vars}"
        )

    # Name of the type var this field uses in the parent class definition.
    wanted = variable_type_args[index]
    if wanted not in maybe_generic_type_vars:
        return typing.Any
    # Slot of that type var in the original generic class definition.
    return args[maybe_generic_type_vars.index(wanted)]