serde.compat
Compatibility layer that mostly handles differences in the `typing` module between Python versions.
"""
Compatibility layer that mostly handles differences in the `typing` module between Python
versions.
"""

import dataclasses
import datetime
import decimal
import enum
import functools
import ipaddress
import itertools
import pathlib
import types
import typing
import uuid
from collections import defaultdict
from collections.abc import Iterator, Sequence, MutableSequence
from collections.abc import Mapping, MutableMapping, Set, MutableSet
from dataclasses import is_dataclass
from typing import TypeVar, Generic, Any, ClassVar, Optional, NewType, Union, Hashable, Callable

import typing_extensions
import typing_inspect
from typing_extensions import TypeGuard, ParamSpec

# `typing_extensions.TypeAliasType` isn't always an alias to `typing.TypeAliasType`
# depending on certain versions of `typing_extensions` and python, so when the running
# python provides `typing.TypeAliasType` both classes are accepted.
_PEP695_TYPES: tuple[type, ...]
if hasattr(typing, "TypeAliasType"):
    _PEP695_TYPES = (typing_extensions.TypeAliasType, typing.TypeAliasType)
else:
    _PEP695_TYPES = (typing_extensions.TypeAliasType,)


def _is_sqlalchemy_inspectable(subject: Any) -> bool:
    """
    Lazy wrapper around `is_sqlalchemy_inspectable` from the sibling `.sqlalchemy` module.

    The import is performed at call time instead of at module import time to improve
    startup time.
    """
    from .sqlalchemy import is_sqlalchemy_inspectable

    return is_sqlalchemy_inspectable(subject)


def get_np_origin(tp: type[Any]) -> Any | None:
    """
    Fallback origin lookup, consulted when `typing.get_origin` finds nothing.

    This implementation recognizes no additional types and always returns `None`.
    """
    return None


def get_np_args(tp: type[Any]) -> tuple[Any, ...]:
    """
    Fallback type-argument lookup, consulted when `typing.get_args` finds nothing.

    This implementation recognizes no additional types and always returns an empty tuple.
    """
    return ()


__all__: list[str] = []

T = TypeVar("T")


StrSerializableTypes = (
    decimal.Decimal,
    pathlib.Path,
    pathlib.PosixPath,
    pathlib.WindowsPath,
    pathlib.PurePath,
    pathlib.PurePosixPath,
    pathlib.PureWindowsPath,
    uuid.UUID,
    ipaddress.IPv4Address,
    ipaddress.IPv6Address,
    ipaddress.IPv4Network,
    ipaddress.IPv6Network,
    ipaddress.IPv4Interface,
    ipaddress.IPv6Interface,
)
""" List of standard types (de)serializable to str """
DateTimeTypes = (datetime.date, datetime.time, datetime.datetime)
""" List of datetime types """


@dataclasses.dataclass(unsafe_hash=True)
class _WithTagging(Generic[T]):
    """
    Intermediate data structure for (de)serializing Union without dataclass.
    """

    inner: T
    """ Union type e.g. Union[Foo,Bar] passed in from_obj. """
    tagging: Any
    """ Union Tagging """


class SerdeError(Exception):
    """
    Serde error class.
    """


@dataclasses.dataclass
class UserError(Exception):
    """
    Error from user code e.g. __post_init__.
    """

    inner: Exception


class SerdeSkip(Exception):
    """
    Skip a field in custom (de)serializer.
    """


def is_hashable(typ: Any) -> TypeGuard[Hashable]:
    """
    Test whether an object is hashable, i.e. whether `hash(typ)` succeeds.
    """
    try:
        hash(typ)
    except TypeError:
        return False
    return True


P = ParamSpec("P")


def cache(f: Callable[P, T]) -> Callable[P, T]:
    """
    Wrapper for `functools.cache` to avoid `Hashable` related type errors.

    The inner wrapper is decorated with `functools.wraps` so the cached function keeps the
    `__name__`, `__qualname__`, `__doc__` and `__module__` of `f`. Without it, every function
    decorated with `@cache` would be reported as `wrapper` with no docstring, which hides its
    documentation and its doctests.
    """

    @functools.wraps(f)
    def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
        return f(*args, **kwargs)

    return functools.cache(wrapper)  # type: ignore


@cache
def get_origin(typ: Any) -> Any | None:
    """
    Provide `get_origin` that works in all python versions.

    Falls back to `get_np_origin` when `typing.get_origin` returns a falsy value.
    """
    return typing.get_origin(typ) or get_np_origin(typ)


@cache
def get_args(typ: type[Any]) -> tuple[Any, ...]:
    """
    Provide `get_args` that works in all python versions.

    Falls back to `get_np_args` when `typing.get_args` returns an empty tuple.
    """
    return typing.get_args(typ) or get_np_args(typ)


@cache
def typename(typ: Any, with_typing_module: bool = False) -> str:
    """
    Return a readable name for a type.

    When `with_typing_module` is true, the `Optional`, `Union`, `list`, `set`, `dict`,
    `tuple` and `Any` forms are prefixed with "typing.".

    Raises `SerdeError` when no name can be determined and `TypeError` for a `Literal`
    without arguments.

    >>> from typing import Any
    >>> typename(int)
    'int'
    >>> class Foo: pass
    >>> typename(Foo)
    'Foo'
    >>> typename(list[Foo])
    'list[Foo]'
    >>> typename(dict[str, Foo])
    'dict[str, Foo]'
    >>> typename(tuple[int, str, Foo, list[int], dict[str, Foo]])
    'tuple[int, str, Foo, list[int], dict[str, Foo]]'
    >>> typename(Optional[list[Foo]])
    'Optional[list[Foo]]'
    >>> typename(Union[Optional[Foo], list[Foo], Union[str, int]])
    'Union[Optional[Foo], list[Foo], str, int]'
    >>> typename(set[Foo])
    'set[Foo]'
    >>> typename(Any)
    'Any'
    """
    mod = "typing." if with_typing_module else ""
    thisfunc = functools.partial(typename, with_typing_module=with_typing_module)
    if is_opt(typ):
        args = type_args(typ)
        if args:
            return f"{mod}Optional[{thisfunc(args[0])}]"
        else:
            return f"{mod}Optional"
    elif is_union(typ):
        args = union_args(typ)
        if args:
            return f'{mod}Union[{", ".join([thisfunc(e) for e in args])}]'
        else:
            return f"{mod}Union"
    elif is_list(typ):
        args = type_args(typ)
        if args:
            et = thisfunc(args[0])
            return f"{mod}list[{et}]"
        else:
            return f"{mod}list"
    elif is_set(typ):
        args = type_args(typ)
        if args:
            et = thisfunc(args[0])
            return f"{mod}set[{et}]"
        else:
            return f"{mod}set"
    elif is_dict(typ):
        args = type_args(typ)
        if args and len(args) == 2:
            kt = thisfunc(args[0])
            vt = thisfunc(args[1])
            return f"{mod}dict[{kt}, {vt}]"
        else:
            return f"{mod}dict"
    elif is_tuple(typ):
        args = type_args(typ)
        if args:
            return f'{mod}tuple[{", ".join([thisfunc(e) for e in args])}]'
        else:
            return f"{mod}tuple"
    elif is_generic(typ):
        origin = get_origin(typ)
        if origin is None:
            raise SerdeError("Could not extract origin class from generic class")

        if not isinstance(origin.__name__, str):
            raise SerdeError("Name of generic class is not string")

        # Only the origin class name is returned; type arguments are dropped.
        return origin.__name__

    elif is_literal(typ):
        args = type_args(typ)
        if not args:
            raise TypeError("Literal type requires at least one literal argument")
        return f'Literal[{", ".join(stringify_literal(e) for e in args)}]'
    elif typ is Any:
        return f"{mod}Any"
    elif is_ellipsis(typ):
        return "..."
    else:
        # Get super type for NewType.
        # NOTE(review): `with_typing_module` is not forwarded here; confirm this is intended.
        inner = getattr(typ, "__supertype__", None)
        if inner:
            return typename(inner)

        name: str | None = getattr(typ, "_name", None)
        if name:
            return name
        else:
            name = getattr(typ, "__name__", None)
            if isinstance(name, str):
                return name
            else:
                raise SerdeError(f"Could not get a type name from: {typ}")


def stringify_literal(v: Any) -> str:
    """
    Render a `Literal` argument: strings are wrapped in single quotes, others use `str()`.
    """
    if isinstance(v, str):
        return f"'{v}'"
    else:
        return str(v)


def type_args(typ: Any) -> tuple[type[Any], ...]:
    """
    Return the type arguments of `typ`.

    Reads `typ.__args__` directly (wrapper to suppress type errors for accessing private
    members) and falls back to `get_args` when the attribute is missing or not a tuple.
    """
    try:
        args: Any = typ.__args__
    except AttributeError:
        return get_args(typ)

    # Some typing objects expose __args__ as a member_descriptor (e.g. typing.Union),
    # which isn't iterable. Fall back to typing.get_args in that case.
    if isinstance(args, tuple):
        return args
    return get_args(typ)


def union_args(typ: Any) -> tuple[type[Any], ...]:
    """
    Return the arguments of a Union, folding `X, NoneType` pairs back into `Optional[X]`.

    The arguments are consumed two at a time; when the second element of a pair is
    `NoneType`, the pair is replaced by `Optional[first]`.

    Raises:
        TypeError: if `typ` is not a Union.
    """
    if not is_union(typ):
        raise TypeError(f"{typ} is not Union")
    args = type_args(typ)
    if len(args) == 1:
        return (args[0],)
    it = iter(args)
    flattened: list[Any] = []
    # Passing the same iterator twice yields consecutive pairs; zip_longest pads an odd
    # trailing element with None.
    for first, second in itertools.zip_longest(it, it):
        if second is None:
            flattened.append(first)
        elif is_none(second):
            flattened.append(Optional[first])
        else:
            flattened.extend((first, second))
    return tuple(flattened)


def dataclass_fields(cls: type[Any]) -> Iterator[dataclasses.Field]:  # type: ignore
    """
    Return an iterator over the fields of dataclass `cls` with resolved types.

    Each field's `type` is replaced by the hint resolved by `typing.get_type_hints`, which
    handles string forward references and PEP 563 postponed annotations. For generic fields
    of SQLAlchemy-inspectable classes, the first type argument is used instead. The
    assignment is made on the `Field` objects returned by `dataclasses.fields`.

    Raises:
        SerdeError: if the type hints of `cls` cannot be resolved.
    """
    raw_fields = dataclasses.fields(cls)

    try:
        # this resolves types when string forward reference
        # or PEP 563: "from __future__ import annotations" are used
        resolved_hints = typing.get_type_hints(cls)
    except Exception as e:
        raise SerdeError(
            f"Failed to resolve type hints for {typename(cls)}:\n"
            f"{e.__class__.__name__}: {e}\n\n"
            f"If you are using forward references make sure you are calling deserialize & "
            "serialize after all classes are globally visible."
        ) from e

    for f in raw_fields:
        real_type = resolved_hints.get(f.name)
        if real_type is not None:
            f.type = real_type
            if is_generic(real_type) and _is_sqlalchemy_inspectable(cls):
                f.type = get_args(real_type)[0]

    return iter(raw_fields)


TypeLike = Union[type[Any], typing.Any]


def iter_types(cls: type[Any]) -> list[type[Any]]:
    """
    Iterate field types recursively.

    Starting from `cls`, walks dataclass fields and the arguments of `Optional`, `Union`,
    list-like, set-like, tuple and dict-like types, and returns every type encountered.
    Container forms are recorded by their bare form (`Optional`, `Union`, `list`, `set`,
    `tuple`, `dict`), and PEP 695 type aliases are replaced by their value. The result is
    built from a set, so its order is unspecified.

    The declared return type is `list[type[Any]]`, but the list may also contain typing
    special forms and string forward references.
    """
    lst: set[Union[type[Any], Any]] = set()

    def recursive(cls: Union[type[Any], Any]) -> None:
        # Already-seen types are skipped; this also stops self-referencing dataclasses.
        if cls in lst:
            return

        if is_dataclass(cls):
            lst.add(cls)
            if isinstance(cls, type):
                for f in dataclass_fields(cls):
                    recursive(f.type)
        elif isinstance(cls, str):
            lst.add(cls)
        elif is_opt(cls):
            lst.add(Optional)
            args = type_args(cls)
            if args:
                recursive(args[0])
        elif is_union(cls):
            lst.add(Union)
            for arg in type_args(cls):
                recursive(arg)
        elif is_list(cls):
            lst.add(list)
            args = type_args(cls)
            if args:
                recursive(args[0])
        elif is_set(cls):
            lst.add(set)
            args = type_args(cls)
            if args:
                recursive(args[0])
        elif is_tuple(cls):
            lst.add(tuple)
            for arg in type_args(cls):
                recursive(arg)
        elif is_dict(cls):
            lst.add(dict)
            args = type_args(cls)
            if args and len(args) >= 2:
                recursive(args[0])
                recursive(args[1])
        elif is_pep695_type_alias(cls):
            recursive(cls.__value__)
        else:
            lst.add(cls)

    recursive(cls)
    return list(lst)


def iter_unions(cls: TypeLike) -> list[TypeLike]:
    """
    Iterate over all unions that are used in the dataclass.

    Unions are returned in discovery order. `stack` holds the dataclasses currently being
    walked so that self-referencing dataclasses do not recurse forever.
    """
    lst: list[TypeLike] = []
    stack: list[TypeLike] = []  # To prevent infinite recursion

    def recursive(cls: TypeLike) -> None:
        if cls in stack:
            return

        if is_union(cls):
            lst.append(cls)
            for arg in type_args(cls):
                recursive(arg)
        elif is_pep695_type_alias(cls):
            recursive(cls.__value__)
        # NOTE(review): this starts a second, independent `if` chain, so an Optional union is
        # visited again through the `is_opt` branch below and unions nested in its argument
        # may be appended more than once. Kept as-is to preserve existing output.
        if is_dataclass(cls):
            stack.append(cls)
            if isinstance(cls, type):
                for f in dataclass_fields(cls):
                    recursive(f.type)
            stack.pop()
        elif is_opt(cls):
            args = type_args(cls)
            if args:
                recursive(args[0])
        elif is_list(cls) or is_set(cls):
            args = type_args(cls)
            if args:
                recursive(args[0])
        elif is_tuple(cls):
            for arg in type_args(cls):
                recursive(arg)
        elif is_dict(cls):
            args = type_args(cls)
            if args and len(args) >= 2:
                recursive(args[0])
                recursive(args[1])

    recursive(cls)
    return lst


def iter_literals(cls: type[Any]) -> list[TypeLike]:
    """
    Iterate over all literals that are used in the dataclass.

    The visited dataclasses are added to the same set (which also stops recursion on
    self-referencing dataclasses), so the returned list contains them too; presumably
    callers filter with `is_literal` — verify against callers. Order is unspecified.
    """
    lst: set[Union[type[Any], Any]] = set()

    def recursive(cls: Union[type[Any], Any]) -> None:
        if cls in lst:
            return

        if is_literal(cls):
            lst.add(cls)
        if is_union(cls):
            for arg in type_args(cls):
                recursive(arg)
        if is_dataclass(cls):
            lst.add(cls)
            if isinstance(cls, type):
                for f in dataclass_fields(cls):
                    recursive(f.type)
        elif is_opt(cls):
            args = type_args(cls)
            if args:
                recursive(args[0])
        elif is_list(cls) or is_set(cls):
            args = type_args(cls)
            if args:
                recursive(args[0])
        elif is_tuple(cls):
            for arg in type_args(cls):
                recursive(arg)
        elif is_dict(cls):
            args = type_args(cls)
            if args and len(args) >= 2:
                recursive(args[0])
                recursive(args[1])

    recursive(cls)
    return list(lst)


@cache
def is_union(typ: Any) -> bool:
    """
    Test if the type is `typing.Union`.

    >>> is_union(Union[int, str])
    True
    """

    try:
        # When `_WithTagging` is received, it will check inner type.
        if isinstance(typ, _WithTagging):
            return is_union(typ.inner)
    except Exception:
        pass

    # Python 3.10+ Union operator e.g. str | int
    try:
        if isinstance(typ, types.UnionType):
            return True
    except Exception:
        pass

    # typing.Union
    return typing_inspect.is_union_type(typ)  # type: ignore


@cache
def is_opt(typ: Any) -> bool:
    """
    Test if the type is `typing.Optional`.

    A parameterized type counts only when it is a two-argument union whose second argument
    is `NoneType` and whose first is not.

    >>> is_opt(Optional[int])
    True
    >>> is_opt(Optional)
    True
    >>> is_opt(None.__class__)
    False
    """

    # Python 3.10+ Union operator e.g. str | None
    is_union_type = False
    try:
        if isinstance(typ, types.UnionType):
            is_union_type = True
    except Exception:
        pass

    # typing.Optional
    is_typing_union = typing_inspect.is_optional_type(typ)

    args = type_args(typ)
    if args:
        return (
            (is_union_type or is_typing_union)
            and len(args) == 2
            and not is_none(args[0])
            and is_none(args[1])
        )
    else:
        return typ is Optional


@cache
def is_bare_opt(typ: Any) -> bool:
    """
    Test if the type is `typing.Optional` without type args.

    >>> is_bare_opt(Optional[int])
    False
    >>> is_bare_opt(Optional)
    True
    >>> is_bare_opt(None.__class__)
    False
    """
    return not type_args(typ) and typ is Optional


@cache
def is_opt_dataclass(typ: Any) -> bool:
    """
    Test if the type is optional dataclass.

    >>> is_opt_dataclass(Optional[int])
    False
    >>> @dataclasses.dataclass
    ... class Foo:
    ...     pass
    >>> is_opt_dataclass(Foo)
    False
    >>> is_opt_dataclass(Optional[Foo])
    True
    """
    args = get_args(typ)
    return is_opt(typ) and len(args) > 0 and is_dataclass(args[0])


@cache
def is_list(typ: type[Any]) -> bool:
    """
    Test if the type is `list`, `collections.abc.Sequence`, or `collections.abc.MutableSequence`.

    >>> is_list(list[int])
    True
    >>> is_list(list)
    True
    >>> is_list(Sequence[int])
    True
    >>> is_list(Sequence)
    True
    >>> is_list(MutableSequence[int])
    True
    >>> is_list(MutableSequence)
    True
    """
    origin = get_origin(typ)
    if origin is None:
        return typ in (list, Sequence, MutableSequence)
    return origin in (list, Sequence, MutableSequence)


@cache
def is_bare_list(typ: type[Any]) -> bool:
    """
    Test if the type is `list`/`collections.abc.Sequence`/`collections.abc.MutableSequence`
    without type args.

    >>> is_bare_list(list[int])
    False
    >>> is_bare_list(list)
    True
    >>> is_bare_list(Sequence[int])
    False
    >>> is_bare_list(Sequence)
    True
    >>> is_bare_list(MutableSequence[int])
    False
    >>> is_bare_list(MutableSequence)
    True
    """
    origin = get_origin(typ)
    if origin in (list, Sequence, MutableSequence):
        return not type_args(typ)
    return typ in (list, Sequence, MutableSequence)


@cache
def is_tuple(typ: Any) -> bool:
    """
    Test if the type is tuple.
    """
    try:
        return issubclass(get_origin(typ), tuple)  # type: ignore
    except TypeError:
        # get_origin returned None (or a non-class): compare against bare `tuple`.
        return typ is tuple


@cache
def is_bare_tuple(typ: type[Any]) -> bool:
    """
    Test if the type is tuple without type args.

    >>> is_bare_tuple(tuple[int, str])
    False
    >>> is_bare_tuple(tuple)
    True
    """
    return typ is tuple


@cache
def is_variable_tuple(typ: type[Any]) -> bool:
    """
    Test if the type is a variable-length tuple, i.e. `tuple[T, ...]`.

    >>> is_variable_tuple(tuple[int, ...])
    True
    >>> is_variable_tuple(tuple[int, bool])
    False
    >>> is_variable_tuple(tuple[()])
    False
    """
    istuple = is_tuple(typ) and not is_bare_tuple(typ)
    args = get_args(typ)
    return istuple and len(args) == 2 and is_ellipsis(args[1])


@cache
def is_set(typ: type[Any]) -> bool:
    """
    Test if the type is set-like.

    >>> is_set(set[int])
    True
    >>> is_set(set)
    True
    >>> is_set(frozenset[int])
    True
    >>> from collections.abc import Set, MutableSet
    >>> is_set(Set[int])
    True
    >>> is_set(Set)
    True
    >>> is_set(MutableSet[int])
    True
    >>> is_set(MutableSet)
    True
    """
    try:
        return issubclass(get_origin(typ), (set, frozenset, Set, MutableSet))  # type: ignore[arg-type]
    except TypeError:
        return typ in (set, frozenset, Set, MutableSet)


@cache
def is_bare_set(typ: type[Any]) -> bool:
    """
    Test if the type is `set`/`frozenset`/`Set`/`MutableSet` without type args.

    >>> is_bare_set(set[int])
    False
    >>> is_bare_set(set)
    True
    >>> from collections.abc import Set, MutableSet
    >>> is_bare_set(Set)
    True
    >>> is_bare_set(MutableSet)
    True
    """
    origin = get_origin(typ)
    if origin in (set, frozenset, Set, MutableSet):
        return not type_args(typ)
    return typ in (set, frozenset, Set, MutableSet)


@cache
def is_frozen_set(typ: type[Any]) -> bool:
    """
    Test if the type is `frozenset`.

    >>> is_frozen_set(frozenset[int])
    True
    >>> is_frozen_set(set)
    False
    """
    try:
        return issubclass(get_origin(typ), frozenset)  # type: ignore
    except TypeError:
        return typ is frozenset


@cache
def is_dict(typ: type[Any]) -> bool:
    """
    Test if the type is dict-like.

    >>> is_dict(dict[int, int])
    True
    >>> is_dict(dict)
    True
    >>> is_dict(defaultdict[int, int])
    True
    >>> from collections.abc import Mapping, MutableMapping
    >>> is_dict(Mapping[str, int])
    True
    >>> is_dict(Mapping)
    True
    >>> is_dict(MutableMapping[str, int])
    True
    >>> is_dict(MutableMapping)
    True
    """
    try:
        return issubclass(
            get_origin(typ), (dict, defaultdict, Mapping, MutableMapping)  # type: ignore[arg-type]
        )
    except TypeError:
        return typ in (dict, defaultdict, Mapping, MutableMapping)


@cache
def is_bare_dict(typ: type[Any]) -> bool:
    """
    Test if the type is `dict`/`Mapping`/`MutableMapping` without type args.

    >>> is_bare_dict(dict[int, str])
    False
    >>> is_bare_dict(dict)
    True
    >>> from collections.abc import Mapping, MutableMapping
    >>> is_bare_dict(Mapping)
    True
    >>> is_bare_dict(MutableMapping)
    True
    """
    origin = get_origin(typ)
    if origin in (dict, Mapping, MutableMapping):
        return not type_args(typ)
    return typ in (dict, Mapping, MutableMapping)


@cache
def is_default_dict(typ: type[Any]) -> bool:
    """
    Test if the type is `defaultdict`.

    >>> is_default_dict(defaultdict[int, int])
    True
    >>> is_default_dict(dict[int, int])
    False
    """
    try:
        return issubclass(get_origin(typ), defaultdict)  # type: ignore
    except TypeError:
        return typ is defaultdict


@cache
def is_none(typ: type[Any]) -> bool:
    """
    Test if the type is `NoneType` (the value `None` itself is not).

    >>> is_none(int)
    False
    >>> is_none(type(None))
    True
    >>> is_none(None)
    False
    """
    return typ is type(None)  # noqa


PRIMITIVES = [int, float, bool, str]


@cache
def is_enum(typ: type[Any]) -> TypeGuard[enum.Enum]:
    """
    Test if the type is an `enum.Enum` subclass, or, for non-class values, an enum member.
    """
    try:
        return issubclass(typ, enum.Enum)
    except TypeError:
        return isinstance(typ, enum.Enum)


@cache
def is_primitive_subclass(typ: type[Any]) -> bool:
    """
    Test if the type is a subclass of primitive type.

    >>> is_primitive_subclass(str)
    False
    >>> class Str(str):
    ...     pass
    >>> is_primitive_subclass(Str)
    True
    """
    return is_primitive(typ) and typ not in PRIMITIVES and not is_new_type_primitive(typ)


@cache
def is_primitive(typ: type[Any] | NewType) -> bool:
    """
    Test if the type is primitive: a subclass of one of `PRIMITIVES`, or (for non-classes)
    a NewType of a primitive.

    >>> is_primitive(int)
    True
    >>> class CustomInt(int):
    ...     pass
    >>> is_primitive(CustomInt)
    True
    """
    try:
        return any(issubclass(typ, ty) for ty in PRIMITIVES)  # type: ignore
    except TypeError:
        return is_new_type_primitive(typ)


@cache
def is_new_type_primitive(typ: type[Any] | NewType) -> bool:
    """
    Test if the type is a NewType of primitives.
    """
    inner = getattr(typ, "__supertype__", None)
    if inner:
        return is_primitive(inner)
    else:
        return any(isinstance(typ, ty) for ty in PRIMITIVES)


@cache
def has_generic_base(typ: Any) -> bool:
    """
    Test whether `typing.Generic` appears in the MRO or the direct bases of `typ`.
    """
    return Generic in getattr(typ, "__mro__", ()) or Generic in getattr(typ, "__bases__", ())


@cache
def is_generic(typ: Any) -> bool:
    """
    Test if the type is derived from `typing.Generic`.

    >>> T = typing.TypeVar('T')
    >>> class GenericFoo(typing.Generic[T]):
    ...     pass
    >>> is_generic(GenericFoo[int])
    True
    >>> is_generic(GenericFoo)
    False
    """
    origin = get_origin(typ)
    return origin is not None and has_generic_base(origin)


@cache
def is_class_var(typ: type[Any]) -> bool:
    """
    Test if the type is `typing.ClassVar`.

    >>> is_class_var(ClassVar[int])
    True
    >>> is_class_var(ClassVar)
    True
    """
    return get_origin(typ) is ClassVar or typ is ClassVar  # type: ignore


@cache
def is_literal(typ: type[Any]) -> bool:
    """
    Test if the type is derived from `typing.Literal`.

    >>> is_literal(typing.Literal["a", 1])
    True
    >>> is_literal(str)
    False
    """
    origin = get_origin(typ)
    return origin is not None and origin is typing.Literal


@cache
def is_any(typ: Any) -> bool:
    """
    Test if the type is `typing.Any`.
    """
    return typ is Any


@cache
def is_str_serializable(typ: type[Any]) -> bool:
    """
    Test if the type is serializable to `str`.
    """
    return typ in StrSerializableTypes or (
        type(typ) is type and issubclass(typ, StrSerializableTypes)
    )


def is_datetime(
    typ: type[Any],
) -> TypeGuard[datetime.date | datetime.time | datetime.datetime]:
    """
    Test if the type is any of the datetime types.
    """
    return typ in DateTimeTypes or (type(typ) is type and issubclass(typ, DateTimeTypes))


def is_str_serializable_instance(obj: Any) -> bool:
    """
    Test if the object is an instance of one of `StrSerializableTypes`.
    """
    return isinstance(obj, StrSerializableTypes)


def is_datetime_instance(obj: Any) -> bool:
    """
    Test if the object is an instance of one of `DateTimeTypes`.
    """
    return isinstance(obj, DateTimeTypes)


def is_ellipsis(typ: Any) -> bool:
    """
    Test if the value is `Ellipsis` (`...`).
    """
    return typ is Ellipsis


def is_pep695_type_alias(typ: Any) -> bool:
    """
    Test if the type is of PEP695 type alias.
    """
    return isinstance(typ, _PEP695_TYPES)


@cache
def get_type_var_names(cls: type[Any]) -> list[str] | None:
    """
    Get type argument names of a generic class.

    Returns `None` when `cls` has no `__orig_bases__`.

    >>> T = typing.TypeVar('T')
    >>> class GenericFoo(typing.Generic[T]):
    ...     pass
    >>> get_type_var_names(GenericFoo)
    ['T']
    >>> get_type_var_names(int)
    """
    bases = getattr(cls, "__orig_bases__", ())
    if not bases:
        return None

    type_arg_names: list[str] = []
    for base in bases:
        type_arg_names.extend(arg.__name__ for arg in get_args(base) if hasattr(arg, "__name__"))

    return type_arg_names


def find_generic_arg(cls: type[Any], field: TypeVar) -> int:
    """
    Find a type in generic parameters.

    Returns the position of the type variable named like `field` within the arguments of
    the first base in `__orig_bases__` that contains it, or -1 when none does. Base
    arguments without a `__name__` are skipped, which matches `get_type_var_names`.

    Raises:
        Exception: if `cls` has no `__orig_bases__`.

    >>> T = typing.TypeVar('T')
    >>> U = typing.TypeVar('U')
    >>> V = typing.TypeVar('V')
    >>> class GenericFoo(typing.Generic[T, U]):
    ...     pass
    >>> find_generic_arg(GenericFoo, T)
    0
    >>> find_generic_arg(GenericFoo, U)
    1
    >>> find_generic_arg(GenericFoo, V)
    -1
    """
    bases = getattr(cls, "__orig_bases__", ())
    if not bases:
        raise Exception(f'"__orig_bases__" property was not found: {cls}')

    for base in bases:
        for n, arg in enumerate(get_args(base)):
            if getattr(arg, "__name__", None) == field.__name__:
                return n

    return -1


def get_generic_arg(
    typ: Any,
    maybe_generic_type_vars: list[str] | None,
    variable_type_args: list[str] | None,
    index: int,
) -> Any:
    """
    Get generic type argument.

    Returns `typing.Any` when `typ` is not generic, when either name list is `None`, or when
    the type var name at `index` is not among `maybe_generic_type_vars`.

    Raises:
        SerdeError: if the number of type args of `typ` differs from the number of type vars.

    >>> T = typing.TypeVar('T')
    >>> U = typing.TypeVar('U')
    >>> class GenericFoo(typing.Generic[T, U]):
    ...     pass
    >>> get_generic_arg(GenericFoo[int, str], ['T', 'U'], ['T', 'U'], 0).__name__
    'int'
    >>> get_generic_arg(GenericFoo[int, str], ['T', 'U'], ['T', 'U'], 1).__name__
    'str'
    >>> get_generic_arg(GenericFoo[int, str], ['T', 'U'], ['U'], 0).__name__
    'str'
    """
    if not is_generic(typ) or maybe_generic_type_vars is None or variable_type_args is None:
        return typing.Any

    args = get_args(typ)

    if len(args) != len(maybe_generic_type_vars):
        raise SerdeError(
            f"Number of type args for {typ} does not match number of generic type vars: "
            f"\n type args: {args}\n type_vars: {maybe_generic_type_vars}"
        )

    # Get the name of the type var used for this field in the parent class definition
    type_var_name = variable_type_args[index]

    try:
        # Find the slot of that type var in the original generic class definition
        orig_index = maybe_generic_type_vars.index(type_var_name)
    except ValueError:
        return typing.Any

    return args[orig_index]