Edit on GitHub

serde.compat

Compatibility layer that mainly handles differences in the `typing` module between Python versions.

  1"""
  2Compatibility layer which handles mostly differences of `typing` module between python versions.
  3"""
  4
  5import dataclasses
  6import datetime
  7import decimal
  8import enum
  9import functools
 10import ipaddress
 11import itertools
 12import pathlib
 13import sys
 14import types
 15import uuid
 16import typing
 17from collections import defaultdict
 18from collections.abc import Iterator
 19from dataclasses import is_dataclass
 20from typing import TypeVar, Generic, Any, ClassVar, Optional, NewType, Union
 21
 22import typing_inspect
 23from typing_extensions import TypeGuard
 24
 25
 26def get_np_origin(tp: type[Any]) -> Optional[Any]:
 27    return None
 28
 29
 30def get_np_args(tp: Any) -> tuple[Any, ...]:
 31    return ()
 32
 33
# An empty `__all__` means `from <this module> import *` binds no names;
# import the helpers you need explicitly.
__all__: list[str] = []

# Type variable used to parameterize `_WithTagging`.
T = TypeVar("T")
 37
 38
# Standard-library types whose values are (de)serialized through `str`.
# Consulted by `is_str_serializable` (types) and
# `is_str_serializable_instance` (values).
StrSerializableTypes = (
    decimal.Decimal,
    # Concrete and pure path classes, POSIX and Windows flavours.
    pathlib.Path,
    pathlib.PosixPath,
    pathlib.WindowsPath,
    pathlib.PurePath,
    pathlib.PurePosixPath,
    pathlib.PureWindowsPath,
    uuid.UUID,
    # IPv4/IPv6 addresses, networks and interfaces.
    ipaddress.IPv4Address,
    ipaddress.IPv6Address,
    ipaddress.IPv4Network,
    ipaddress.IPv6Network,
    ipaddress.IPv4Interface,
    ipaddress.IPv6Interface,
)
""" List of standard types (de)serializable to str """

# Consulted by `is_datetime` (types) and `is_datetime_instance` (values).
# `datetime.datetime` subclasses `datetime.date`, so subclass/instance checks
# against this tuple match datetimes through either entry.
DateTimeTypes = (datetime.date, datetime.time, datetime.datetime)
""" List of datetime types """
 59
 60
 61@dataclasses.dataclass
 62class _WithTagging(Generic[T]):
 63    """
 64    Intermediate data structure for (de)serializaing Union without dataclass.
 65    """
 66
 67    inner: T
 68    """ Union type .e.g Union[Foo,Bar] passed in from_obj. """
 69    tagging: Any
 70    """ Union Tagging """
 71
 72
 73class SerdeError(Exception):
 74    """
 75    Serde error class.
 76    """
 77
 78
 79@dataclasses.dataclass
 80class UserError(Exception):
 81    """
 82    Error from user code e.g. __post_init__.
 83    """
 84
 85    inner: Exception
 86
 87
 88class SerdeSkip(Exception):
 89    """
 90    Skip a field in custom (de)serializer.
 91    """
 92
 93
 94def get_origin(typ: Any) -> Optional[Any]:
 95    """
 96    Provide `get_origin` that works in all python versions.
 97    """
 98    return typing.get_origin(typ) or get_np_origin(typ)
 99
100
def get_args(typ: Any) -> tuple[Any, ...]:
    """
    Version-independent `get_args`.

    Uses `typing.get_args` and, when that yields an empty tuple, falls back
    to `get_np_args`.
    """
    args = typing.get_args(typ)
    if args:
        return args
    return get_np_args(typ)
106
107
def typename(typ: Any, with_typing_module: bool = False) -> str:
    """
    Return a readable name for the type annotation ``typ``.

    Optional, Union, list, set, dict, tuple and Literal forms are rendered
    recursively. When ``with_typing_module`` is true, the names built here
    (Optional, Union, list, set, dict, tuple, Literal, Any) are prefixed with
    ``typing.`` and the flag is propagated to every nested argument,
    including the supertype of a ``NewType``.

    >>> from typing import Any, Literal
    >>> typename(int)
    'int'
    >>> class Foo: pass
    >>> typename(Foo)
    'Foo'
    >>> typename(list[Foo])
    'list[Foo]'
    >>> typename(dict[str, Foo])
    'dict[str, Foo]'
    >>> typename(tuple[int, str, Foo, list[int], dict[str, Foo]])
    'tuple[int, str, Foo, list[int], dict[str, Foo]]'
    >>> typename(Optional[list[Foo]])
    'Optional[list[Foo]]'
    >>> typename(Union[Optional[Foo], list[Foo], Union[str, int]])
    'Union[Optional[Foo], list[Foo], str, int]'
    >>> typename(set[Foo])
    'set[Foo]'
    >>> typename(Any)
    'Any'
    >>> typename(Literal["a", 1])
    "Literal['a', 1]"
    >>> typename(NewType("MaybeInt", Optional[int]), with_typing_module=True)
    'typing.Optional[int]'

    Raises:
        TypeError: for a ``Literal`` without arguments.
        SerdeError: if no name can be derived for ``typ``.
    """
    mod = "typing." if with_typing_module else ""
    thisfunc = functools.partial(typename, with_typing_module=with_typing_module)
    if is_opt(typ):
        args = type_args(typ)
        if args:
            return f"{mod}Optional[{thisfunc(args[0])}]"
        else:
            return f"{mod}Optional"
    elif is_union(typ):
        args = union_args(typ)
        if args:
            return f'{mod}Union[{", ".join([thisfunc(e) for e in args])}]'
        else:
            return f"{mod}Union"
    elif is_list(typ):
        args = type_args(typ)
        if args:
            et = thisfunc(args[0])
            return f"{mod}list[{et}]"
        else:
            return f"{mod}list"
    elif is_set(typ):
        args = type_args(typ)
        if args:
            et = thisfunc(args[0])
            return f"{mod}set[{et}]"
        else:
            return f"{mod}set"
    elif is_dict(typ):
        args = type_args(typ)
        if args and len(args) == 2:
            kt = thisfunc(args[0])
            vt = thisfunc(args[1])
            return f"{mod}dict[{kt}, {vt}]"
        else:
            return f"{mod}dict"
    elif is_tuple(typ):
        args = type_args(typ)
        if args:
            return f'{mod}tuple[{", ".join([thisfunc(e) for e in args])}]'
        else:
            return f"{mod}tuple"
    elif is_generic(typ):
        # Only the origin class name is rendered; type arguments are omitted.
        origin = get_origin(typ)
        if origin is None:
            raise SerdeError("Could not extract origin class from generic class")

        if not isinstance(origin.__name__, str):
            raise SerdeError("Name of generic class is not string")

        return origin.__name__

    elif is_literal(typ):
        args = type_args(typ)
        if not args:
            raise TypeError("Literal type requires at least one literal argument")
        # Quote plain string values so that e.g. Literal["1"] and Literal[1]
        # get distinct names; all other values keep their str() form.
        values = ", ".join(repr(e) if type(e) is str else str(e) for e in args)
        return f"{mod}Literal[{values}]"
    elif typ is Any:
        return f"{mod}Any"
    elif is_ellipsis(typ):
        return "..."
    else:
        # NewType: name it after its supertype, keeping the same
        # `with_typing_module` setting as every other recursive call.
        inner = getattr(typ, "__supertype__", None)
        if inner is not None:
            return thisfunc(inner)

        # `typing` aliases expose `_name`; ordinary classes expose `__name__`.
        name: Optional[str] = getattr(typ, "_name", None)
        if name:
            return name
        else:
            name = getattr(typ, "__name__", None)
            if isinstance(name, str):
                return name
            else:
                raise SerdeError(f"Could not get a type name from: {typ}")
207
208
def type_args(typ: Any) -> tuple[type[Any], ...]:
    """
    Return the type arguments of ``typ``.

    Reads the private ``__args__`` attribute directly (hence the
    ``type: ignore``), normalising a ``None`` value to ``()``, and falls back
    to `get_args` when the attribute is absent.

    >>> type_args(list[int]) == (int,)
    True
    >>> type_args(dict[str, int]) == (str, int)
    True
    """
    try:
        args: Optional[tuple[type[Any], ...]] = typ.__args__  # type: ignore
    except AttributeError:
        return get_args(typ)
    # An object may expose `__args__ = None`; treat it as "no arguments".
    return () if args is None else args
221
222
def union_args(typ: Any) -> tuple[type[Any], ...]:
    """
    Return the members of the Union ``typ``, re-grouping ``X, None`` pairs
    into ``Optional[X]``.

    The flattened members are consumed two at a time, in order. A pair whose
    second member is ``NoneType`` becomes ``Optional[first]``, and any other
    pair is kept as-is. A trailing unpaired member is kept unchanged.

    >>> union_args(Union[int, None, str]) == (Optional[int], str)
    True

    Raises:
        TypeError: if ``typ`` is not a Union.
    """
    if not is_union(typ):
        raise TypeError(f"{typ} is not Union")
    args = type_args(typ)
    if len(args) == 1:
        return (args[0],)

    # Unique padding value, so a missing partner can never be mistaken for
    # (or hide) a real member the way a falsy check could.
    pad = object()
    it = iter(args)
    # Named `members` rather than `types` to avoid shadowing the `types` module.
    members: list[Any] = []
    for first, second in itertools.zip_longest(it, it, fillvalue=pad):
        if second is pad:
            members.append(first)
        elif is_none(second):
            members.append(Optional[first])
        else:
            members.extend((first, second))
    return tuple(members)
239
240
def dataclass_fields(cls: type[Any]) -> Iterator[dataclasses.Field]:  # type: ignore
    """
    Return an iterator over the fields of dataclass ``cls`` with resolved types.

    String annotations (forward references, or every annotation under PEP 563
    ``from __future__ import annotations``) are resolved with
    `typing.get_type_hints` and written back onto each field's ``type``.
    `dataclasses.fields` hands back the class's own ``Field`` objects, so the
    resolved types persist on ``cls``. Resolution happens eagerly, before the
    iterator is returned.

    Raises:
        SerdeError: if the type hints of ``cls`` cannot be resolved.
    """
    fields_of_cls = dataclasses.fields(cls)

    try:
        hints = typing.get_type_hints(cls)
    except Exception as exc:
        raise SerdeError(
            f"Failed to resolve type hints for {typename(cls)}:\n"
            f"{exc.__class__.__name__}: {exc}\n\n"
            f"If you are using forward references make sure you are calling deserialize & "
            "serialize after all classes are globally visible."
        ) from exc

    for field in fields_of_cls:
        resolved = hints.get(field.name)
        if resolved is None:
            continue
        field.type = resolved

    return iter(fields_of_cls)
262
263
# Anything usable where a type is expected: a real class, or a typing
# construct such as a special form or generic alias.
TypeLike = Union[type[Any], Any]
265
266
def iter_types(cls: TypeLike) -> list[TypeLike]:
    """
    Collect every type reachable from ``cls``.

    Dataclass fields and the arguments of Optional/Union/list/set/tuple/dict
    annotations are walked recursively. Container annotations contribute
    their bare form (``Optional``, ``Union``, ``list``, ``set``, ``tuple``,
    ``dict``). Dataclasses, string forward references and all other types
    contribute themselves. Each entry appears once, in no particular order.
    """
    collected: set[TypeLike] = set()

    def visit(tp: TypeLike) -> None:
        if tp in collected:
            return

        if is_dataclass(tp):
            # Record the dataclass before descending so self-references terminate.
            collected.add(tp)
            for field in dataclass_fields(tp):
                visit(field.type)
            return

        if isinstance(tp, str):
            # Unresolved forward reference: recorded verbatim.
            collected.add(tp)
            return

        if is_opt(tp):
            collected.add(Optional)
            opt_params = type_args(tp)
            if opt_params:
                visit(opt_params[0])
            return

        if is_union(tp):
            collected.add(Union)
            for member in type_args(tp):
                visit(member)
            return

        if is_list(tp):
            collected.add(list)
            elem_params = type_args(tp)
            if elem_params:
                visit(elem_params[0])
            return

        if is_set(tp):
            collected.add(set)
            elem_params = type_args(tp)
            if elem_params:
                visit(elem_params[0])
            return

        if is_tuple(tp):
            collected.add(tuple)
            for member in type_args(tp):
                visit(member)
            return

        if is_dict(tp):
            collected.add(dict)
            kv_params = type_args(tp)
            if kv_params and len(kv_params) >= 2:
                visit(kv_params[0])
                visit(kv_params[1])
            return

        collected.add(tp)

    visit(cls)
    return list(collected)
320
321
def iter_unions(cls: TypeLike) -> list[TypeLike]:
    """
    Collect every Union annotation reachable from ``cls``.

    Union members, dataclass fields and the arguments of
    Optional/list/set/tuple/dict annotations are walked recursively. Only
    Union types are returned, each once, in no particular order. Dataclasses
    on the current recursion path are tracked so self-referencing dataclasses
    terminate.
    """
    found: set[TypeLike] = set()
    walking: list[TypeLike] = []  # dataclasses on the current recursion path

    def visit(tp: TypeLike) -> None:
        if tp in found or tp in walking:
            return

        if is_union(tp):
            found.add(tp)
            for member in type_args(tp):
                visit(member)

        # Deliberately a fresh `if`: a Union (e.g. an Optional) also goes
        # through the checks below.
        if is_dataclass(tp):
            walking.append(tp)
            for field in dataclass_fields(tp):
                visit(field.type)
            walking.pop()
        elif is_opt(tp) or is_list(tp) or is_set(tp):
            # Only the first type argument is followed for these forms.
            params = type_args(tp)
            if params:
                visit(params[0])
        elif is_tuple(tp):
            for member in type_args(tp):
                visit(member)
        elif is_dict(tp):
            params = type_args(tp)
            if params and len(params) >= 2:
                visit(params[0])
                visit(params[1])

    visit(cls)
    return list(found)
363
364
def iter_literals(cls: TypeLike) -> list[TypeLike]:
    """
    Iterate over all literals that are used in the dataclass.

    Recursively walks ``cls``: dataclass fields, Union members, and the type
    arguments of Optional/list/set/tuple/dict annotations, collecting every
    `typing.Literal` form encountered. The result is built from a set, so it
    has no particular order.

    NOTE(review): visited dataclasses are added to the same set (this doubles
    as the recursion guard), so the returned list can contain dataclass types
    as well as Literal forms — confirm callers expect that.
    """
    lst: set[TypeLike] = set()

    def recursive(cls: TypeLike) -> None:
        # Already collected (a Literal, or a dataclass already visited).
        if cls in lst:
            return

        if is_literal(cls):
            lst.add(cls)
        # Not `elif`: after walking a Union's members, the Union still falls
        # through to the chain below (e.g. `is_opt` for an Optional).
        if is_union(cls):
            for arg in type_args(cls):
                recursive(arg)
        if is_dataclass(cls):
            # Mark before descending so self-referencing dataclasses terminate.
            lst.add(cls)
            for f in dataclass_fields(cls):
                recursive(f.type)
        elif is_opt(cls):
            args = type_args(cls)
            if args:
                recursive(args[0])
        elif is_list(cls) or is_set(cls):
            args = type_args(cls)
            if args:
                recursive(args[0])
        elif is_tuple(cls):
            for arg in type_args(cls):
                recursive(arg)
        elif is_dict(cls):
            # Key type, then value type.
            args = type_args(cls)
            if args and len(args) >= 2:
                recursive(args[0])
                recursive(args[1])

    recursive(cls)
    return list(lst)
403
404
def is_union(typ: Any) -> bool:
    """
    Test if the type is a Union.

    Recognises `typing.Union[...]`, PEP 604 unions such as ``int | str`` (on
    Python 3.10+), and `_WithTagging` wrappers whose ``inner`` is a Union.

    >>> is_union(Union[int, str])
    True
    """
    # A tagging wrapper counts as a Union when the type it wraps does. Errors
    # raised while unwrapping are ignored and the checks below apply to `typ`.
    try:
        if isinstance(typ, _WithTagging):
            return is_union(typ.inner)
    except Exception:
        pass

    # `X | Y` objects are instances of `types.UnionType`, which exists from 3.10.
    if sys.version_info >= (3, 10):
        try:
            pep604 = isinstance(typ, types.UnionType)
        except Exception:
            pep604 = False
        if pep604:
            return True

    return typing_inspect.is_union_type(typ)  # type: ignore
430
431
def is_opt(typ: Any) -> bool:
    """
    Test if the type is `typing.Optional`.

    A parameterized type counts only in the two-member form ``Union[X, None]``
    / ``X | None`` with ``None`` last and ``X`` not ``None``, so e.g.
    ``Union[None, X]`` is not reported as Optional. The bare
    `typing.Optional` special form also counts.

    >>> is_opt(Optional[int])
    True
    >>> is_opt(Optional)
    True
    >>> is_opt(None.__class__)
    False
    """
    # PEP 604 unions (`X | None`) are `types.UnionType` on Python 3.10+.
    pep604 = False
    if sys.version_info >= (3, 10):
        try:
            pep604 = isinstance(typ, types.UnionType)
        except Exception:
            pass

    via_typing = typing_inspect.is_optional_type(typ)

    members = type_args(typ)
    if not members:
        return typ is Optional

    return (
        (pep604 or via_typing)
        and len(members) == 2
        and not is_none(members[0])
        and is_none(members[1])
    )
466
467
def is_bare_opt(typ: Any) -> bool:
    """
    Test if the type is the unparameterized `typing.Optional` special form.

    >>> is_bare_opt(Optional[int])
    False
    >>> is_bare_opt(Optional)
    True
    >>> is_bare_opt(None.__class__)
    False
    """
    if type_args(typ):
        return False
    return typ is Optional
479
480
def is_list(typ: type[Any]) -> bool:
    """
    Test if the type is `list`: the bare class, or a parameterized alias
    whose origin is `list` or a subclass of it.

    >>> is_list(list[int])
    True
    >>> is_list(list)
    True
    """
    try:
        origin = get_origin(typ)
        return issubclass(origin, list)  # type: ignore
    except TypeError:
        # issubclass() rejects a non-class origin (e.g. None for a bare class).
        return typ is list
494
495
def is_bare_list(typ: type[Any]) -> bool:
    """
    Test if the type is exactly the unparameterized `list` class.

    >>> is_bare_list(list[int])
    False
    >>> is_bare_list(list)
    True
    """
    return list is typ
506
507
def is_tuple(typ: Any) -> bool:
    """
    Test if the type is `tuple`: the bare class, or a parameterized alias
    whose origin is `tuple` or a subclass of it.
    """
    try:
        origin = get_origin(typ)
        return issubclass(origin, tuple)  # type: ignore
    except TypeError:
        # No class origin to test: compare against the bare class instead.
        return typ is tuple
516
517
def is_bare_tuple(typ: type[Any]) -> bool:
    """
    Test if the type is exactly the unparameterized `tuple` class.

    >>> is_bare_tuple(tuple[int, str])
    False
    >>> is_bare_tuple(tuple)
    True
    """
    return tuple is typ
528
529
def is_variable_tuple(typ: type[Any]) -> bool:
    """
    Test if the type is a variable-length tuple, ``tuple[T, ...]``.

    >>> is_variable_tuple(tuple[int, ...])
    True
    >>> is_variable_tuple(tuple[int, bool])
    False
    >>> is_variable_tuple(tuple[()])
    False
    """
    parameterized = is_tuple(typ) and not is_bare_tuple(typ)
    tuple_params = get_args(typ)
    if not parameterized:
        return False
    # Exactly one element type followed by the literal `...`.
    return len(tuple_params) == 2 and is_ellipsis(tuple_params[1])
544
545
def is_set(typ: type[Any]) -> bool:
    """
    Test if the type is `set` or `frozenset`, bare or parameterized
    (parameterized subclasses included).

    >>> is_set(set[int])
    True
    >>> is_set(set)
    True
    >>> is_set(frozenset[int])
    True
    """
    try:
        origin = get_origin(typ)
        return issubclass(origin, (set, frozenset))  # type: ignore
    except TypeError:
        # No class origin to test: compare against the bare classes instead.
        return typ in (set, frozenset)
561
562
def is_bare_set(typ: type[Any]) -> bool:
    """
    Test if the type is the unparameterized `set` or `frozenset` class.

    >>> is_bare_set(set[int])
    False
    >>> is_bare_set(set)
    True
    """
    bare_set_types = (set, frozenset)
    return typ in bare_set_types
573
574
def is_frozen_set(typ: type[Any]) -> bool:
    """
    Test if the type is `frozenset`, bare or parameterized.

    >>> is_frozen_set(frozenset[int])
    True
    >>> is_frozen_set(set)
    False
    """
    try:
        origin = get_origin(typ)
        return issubclass(origin, frozenset)  # type: ignore
    except TypeError:
        # No class origin to test: compare against the bare class instead.
        return typ is frozenset
588
589
def is_dict(typ: type[Any]) -> bool:
    """
    Test if the type is `dict` or `defaultdict`, bare or parameterized.

    >>> is_dict(dict[int, int])
    True
    >>> is_dict(dict)
    True
    >>> is_dict(defaultdict[int, int])
    True
    """
    try:
        origin = get_origin(typ)
        # `defaultdict` subclasses `dict`, so a single check covers both.
        return issubclass(origin, dict)  # type: ignore
    except TypeError:
        # Bare classes have no origin; compare them directly.
        return typ in (dict, defaultdict)
605
606
def is_bare_dict(typ: type[Any]) -> bool:
    """
    Test if the type is exactly the unparameterized `dict` class.

    >>> is_bare_dict(dict[int, str])
    False
    >>> is_bare_dict(dict)
    True
    """
    return dict is typ
617
618
def is_default_dict(typ: type[Any]) -> bool:
    """
    Test if the type is `collections.defaultdict`, bare or parameterized.

    >>> is_default_dict(defaultdict[int, int])
    True
    >>> is_default_dict(dict[int, int])
    False
    """
    try:
        origin = get_origin(typ)
        return issubclass(origin, defaultdict)  # type: ignore
    except TypeError:
        # No class origin to test: compare against the bare class instead.
        return typ is defaultdict
632
633
def is_none(typ: type[Any]) -> bool:
    """
    Test if ``typ`` is ``NoneType`` (the type of ``None``, not ``None`` itself).

    >>> is_none(int)
    False
    >>> is_none(type(None))
    True
    >>> is_none(None)
    False
    """
    none_type = type(None)
    return typ is none_type
644
645
# Builtin scalar types treated as primitives by `is_primitive` and related
# helpers. Note `bool` is a subclass of `int`, so subclass checks against
# `int` already match `bool`.
PRIMITIVES = [int, float, bool, str]
647
648
649def is_enum(typ: type[Any]) -> TypeGuard[enum.Enum]:
650    """
651    Test if the type is `enum.Enum`.
652    """
653    try:
654        return issubclass(typ, enum.Enum)
655    except TypeError:
656        return isinstance(typ, enum.Enum)
657
658
def is_primitive_subclass(typ: type[Any]) -> bool:
    """
    Test if the type is a user-defined subclass of a primitive type: it must
    be primitive, not itself one of `PRIMITIVES`, and not a NewType of a
    primitive.

    >>> is_primitive_subclass(str)
    False
    >>> class Str(str):
    ...     pass
    >>> is_primitive_subclass(Str)
    True
    """
    if not is_primitive(typ):
        return False
    if typ in PRIMITIVES:
        return False
    return not is_new_type_primitive(typ)
671
672
def is_primitive(typ: Union[type[Any], NewType]) -> bool:
    """
    Test if the type is primitive: a subclass of one of `PRIMITIVES`, or,
    for non-class inputs, whatever `is_new_type_primitive` reports.

    >>> is_primitive(int)
    True
    >>> class CustomInt(int):
    ...     pass
    >>> is_primitive(CustomInt)
    True
    """
    try:
        for primitive in PRIMITIVES:
            if issubclass(typ, primitive):  # type: ignore
                return True
        return False
    except TypeError:
        # Not a class (e.g. a NewType): delegate to the NewType check.
        return is_new_type_primitive(typ)
688
689
def is_new_type_primitive(typ: Union[type[Any], NewType]) -> bool:
    """
    Test if the type is a NewType of a primitive.

    For inputs without a (truthy) ``__supertype__``, this instead reports
    whether ``typ`` is itself an instance of one of `PRIMITIVES`.
    """
    supertype = getattr(typ, "__supertype__", None)
    if not supertype:
        return any(isinstance(typ, ty) for ty in PRIMITIVES)
    return is_primitive(supertype)
699
700
def is_generic(typ: Any) -> bool:
    """
    Test if the type is a parameterized class deriving from `typing.Generic`.

    Only the origin's direct ``__bases__`` are inspected.
    NOTE(review): a class that inherits ``Generic`` only through another
    generic base is therefore not detected — confirm this is intended.

    >>> T = typing.TypeVar('T')
    >>> class GenericFoo(typing.Generic[T]):
    ...     pass
    >>> is_generic(GenericFoo[int])
    True
    >>> is_generic(GenericFoo)
    False
    """
    origin = get_origin(typ)
    if origin is None:
        return False
    return Generic in getattr(origin, "__bases__", ())
715
716
def is_class_var(typ: type[Any]) -> bool:
    """
    Test if the type is `typing.ClassVar`, bare or parameterized.

    >>> is_class_var(ClassVar[int])
    True
    >>> is_class_var(ClassVar)
    True
    """
    origin = get_origin(typ)
    return origin is ClassVar or typ is ClassVar  # type: ignore
727
728
def is_literal(typ: type[Any]) -> bool:
    """
    Test if the type is a `typing.Literal` form.

    >>> is_literal(typing.Literal["a", "b"])
    True
    >>> is_literal(str)
    False
    """
    # `typing.Literal` is never None, so no separate None check is needed.
    return get_origin(typ) is typing.Literal
743
744
def is_any(typ: type[Any]) -> bool:
    """
    Test if the type is exactly `typing.Any`.

    >>> is_any(Any)
    True
    >>> is_any(object)
    False
    """
    return Any is typ
750
751
def is_str_serializable(typ: type[Any]) -> bool:
    """
    Test if the type is serializable to `str`: one of `StrSerializableTypes`,
    or a plain class (metaclass exactly `type`) that subclasses one of them.
    """
    if typ in StrSerializableTypes:
        return True
    return type(typ) is type and issubclass(typ, StrSerializableTypes)
759
760
def is_datetime(
    typ: type[Any],
) -> TypeGuard[Union[datetime.date, datetime.time, datetime.datetime]]:
    """
    Test if the type is one of `DateTimeTypes`, or a plain class (metaclass
    exactly `type`) that subclasses one of them.
    """
    if typ in DateTimeTypes:
        return True
    return type(typ) is type and issubclass(typ, DateTimeTypes)
768
769
def is_str_serializable_instance(obj: Any) -> bool:
    """
    Test if ``obj`` is an instance of one of `StrSerializableTypes`
    (instances of subclasses included).
    """
    supported = StrSerializableTypes
    return isinstance(obj, supported)
772
773
def is_datetime_instance(obj: Any) -> bool:
    """
    Test if ``obj`` is an instance of one of `DateTimeTypes`
    (instances of subclasses included).
    """
    supported = DateTimeTypes
    return isinstance(obj, supported)
776
777
def is_ellipsis(typ: Any) -> bool:
    """
    Test if ``typ`` is the ``...`` singleton (as in ``tuple[int, ...]``).
    """
    return typ is ...
780
781
def get_type_var_names(cls: type[Any]) -> Optional[list[str]]:
    """
    Get the type argument names of a generic class.

    Returns ``None`` when ``cls`` has no (or empty) ``__orig_bases__``.
    Otherwise returns the ``__name__`` of every argument of every recorded
    base, in order. NOTE(review): a concretely parameterized base (e.g.
    ``Base[int]``) contributes its argument's name (``'int'``) too.

    >>> T = typing.TypeVar('T')
    >>> class GenericFoo(typing.Generic[T]):
    ...     pass
    >>> get_type_var_names(GenericFoo)
    ['T']
    >>> get_type_var_names(int)
    """
    orig_bases = getattr(cls, "__orig_bases__", ())
    if not orig_bases:
        return None

    names: list[str] = [arg.__name__ for base in orig_bases for arg in get_args(base)]
    return names
802
803
def find_generic_arg(cls: type[Any], field: TypeVar) -> int:
    """
    Find the position of a type variable among a generic class's parameters.

    Scans the parameterized bases in ``cls.__orig_bases__`` and returns the
    index (within that base's arguments) of the first argument whose
    ``__name__`` equals ``field.__name__``, or ``-1`` if none matches.

    >>> T = typing.TypeVar('T')
    >>> U = typing.TypeVar('U')
    >>> V = typing.TypeVar('V')
    >>> class GenericFoo(typing.Generic[T, U]):
    ...     pass
    >>> find_generic_arg(GenericFoo, T)
    0
    >>> find_generic_arg(GenericFoo, U)
    1
    >>> find_generic_arg(GenericFoo, V)
    -1

    Raises:
        Exception: if ``cls`` has no (or an empty) ``__orig_bases__``.
    """
    bases = getattr(cls, "__orig_bases__", ())
    if not bases:
        raise Exception(f'"__orig_bases__" property was not found: {cls}')

    for base in bases:
        for n, arg in enumerate(get_args(base)):
            if arg.__name__ == field.__name__:
                return n

    # `bases` is known to be non-empty here, so "not found" is reported as -1
    # rather than an exception.
    return -1
833
834
def get_generic_arg(
    typ: Any,
    maybe_generic_type_vars: Optional[list[str]],
    variable_type_args: Optional[list[str]],
    index: int,
) -> Any:
    """
    Get the concrete type argument that fills a generic class's type variable.

    ``maybe_generic_type_vars`` lists the type-variable names of the generic
    class, in declaration order. ``variable_type_args[index]`` names the type
    variable used for the field of interest. Returns `typing.Any` when
    ``typ`` is not generic, when either list is ``None``, or when the named
    type variable is not among ``maybe_generic_type_vars``.

    >>> T = typing.TypeVar('T')
    >>> U = typing.TypeVar('U')
    >>> class GenericFoo(typing.Generic[T, U]):
    ...     pass
    >>> get_generic_arg(GenericFoo[int, str], ['T', 'U'], ['T', 'U'], 0).__name__
    'int'
    >>> get_generic_arg(GenericFoo[int, str], ['T', 'U'], ['T', 'U'], 1).__name__
    'str'
    >>> get_generic_arg(GenericFoo[int, str], ['T', 'U'], ['U'], 0).__name__
    'str'

    Raises:
        SerdeError: if the number of type arguments of ``typ`` differs from
            the number of generic type variables.
    """
    if not is_generic(typ):
        return typing.Any
    if maybe_generic_type_vars is None or variable_type_args is None:
        return typing.Any

    concrete_args = get_args(typ)
    if len(concrete_args) != len(maybe_generic_type_vars):
        raise SerdeError(
            f"Number of type args for {typ} does not match number of generic type vars: "
            f"\n  type args: {concrete_args}\n  type_vars: {maybe_generic_type_vars}"
        )

    # Type var used for this field in the parent class definition, mapped to
    # its slot in the original generic class definition.
    var_name = variable_type_args[index]
    if var_name not in maybe_generic_type_vars:
        return typing.Any
    return concrete_args[maybe_generic_type_vars.index(var_name)]