serde.compat
Compatibility layer that mostly handles differences in the `typing` module between Python versions.
"""
Compatibility layer which mostly handles differences of the `typing` module between Python
versions.
"""

import dataclasses
import datetime
import decimal
import enum
import functools
import ipaddress
import itertools
import pathlib
import types
import uuid
import typing
import typing_extensions
from collections import defaultdict, deque, Counter
from collections.abc import Iterator, Sequence, MutableSequence
from collections.abc import Mapping, MutableMapping, Set, MutableSet
from dataclasses import is_dataclass
from typing import TypeVar, Generic, Any, ClassVar, Optional, NewType, Union, Hashable, Callable

import typing_inspect
from typing_extensions import TypeGuard, ParamSpec

# `typing_extensions.TypeAliasType` isn't always an alias to `typing.TypeAliasType`
# depending on certain versions of `typing_extensions` and python.
_PEP695_TYPES: tuple[type, ...]
if hasattr(typing, "TypeAliasType"):
    _PEP695_TYPES = (typing_extensions.TypeAliasType, typing.TypeAliasType)
else:
    _PEP695_TYPES = (typing_extensions.TypeAliasType,)


# Lazy SQLAlchemy import wrapper to improve startup time
def _is_sqlalchemy_inspectable(subject: Any) -> bool:
    """
    Return whether SQLAlchemy can inspect ``subject``.

    ``.sqlalchemy`` is imported on first call instead of at module import time,
    to keep the import of this module cheap.
    """
    from .sqlalchemy import is_sqlalchemy_inspectable

    return is_sqlalchemy_inspectable(subject)


def get_np_origin(tp: type[Any]) -> Any | None:
    """
    Extension point used by `get_origin` (presumably for numpy types, judging by
    the name). Currently always returns None.
    """
    return None


def get_np_args(tp: type[Any]) -> tuple[Any, ...]:
    """
    Extension point used by `get_args` (presumably for numpy types, judging by
    the name). Currently always returns an empty tuple.
    """
    return ()


__all__: list[str] = []

T = TypeVar("T")


StrSerializableTypes = (
    decimal.Decimal,
    pathlib.Path,
    pathlib.PosixPath,
    pathlib.WindowsPath,
    pathlib.PurePath,
    pathlib.PurePosixPath,
    pathlib.PureWindowsPath,
    uuid.UUID,
    ipaddress.IPv4Address,
    ipaddress.IPv6Address,
    ipaddress.IPv4Network,
    ipaddress.IPv6Network,
    ipaddress.IPv4Interface,
    ipaddress.IPv6Interface,
)
""" List of standard types (de)serializable to str """

DateTimeTypes = (datetime.date, datetime.time, datetime.datetime)
""" List of datetime types """


@dataclasses.dataclass(unsafe_hash=True)
class _WithTagging(Generic[T]):
    """
    Intermediate data structure for (de)serializing Union without dataclass.
    """

    inner: T
    """ Union type e.g. Union[Foo,Bar] passed in from_obj. """
    tagging: Any
    """ Union Tagging """


class SerdeError(Exception):
    """
    Serde error class.
    """


@dataclasses.dataclass
class UserError(Exception):
    """
    Error from user code e.g. __post_init__.
    """

    inner: Exception


class SerdeSkip(Exception):
    """
    Skip a field in custom (de)serializer.
    """


def is_hashable(typ: Any) -> TypeGuard[Hashable]:
    """
    Test whether an object is hashable, i.e. whether `hash()` succeeds on it.
    """
    try:
        hash(typ)
    except TypeError:
        return False
    return True


P = ParamSpec("P")


def cache(f: Callable[P, T]) -> Callable[P, T]:
    """
    Wrapper for `functools.cache` to avoid `Hashable` related type errors.

    The returned callable keeps ``f``'s ``__name__``, ``__qualname__``, ``__module__``
    and ``__doc__`` (via `functools.wraps`). Without that, every cached function
    would be named ``wrapper`` and lose its docstring, so `help()`, generated
    docs and doctest collection would not see it. As with `functools.cache`,
    all arguments must be hashable at runtime.
    """

    @functools.wraps(f)
    def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
        return f(*args, **kwargs)

    return functools.cache(wrapper)  # type: ignore


@cache
def get_origin(typ: Any) -> Any | None:
    """
    Provide `get_origin` that works in all python versions.
    """
    return typing.get_origin(typ) or get_np_origin(typ)


@cache
def get_args(typ: type[Any]) -> tuple[Any, ...]:
    """
    Provide `get_args` that works in all python versions.
    """
    return typing.get_args(typ) or get_np_args(typ)


@cache
def typename(typ: Any, with_typing_module: bool = False) -> str:
    """
    Return a human readable name for ``typ``.

    When ``with_typing_module`` is True, the names produced for Optional, Union,
    list, set, dict, tuple and Any are prefixed with ``typing.``.

    Raises `TypeError` for a Literal without arguments, and `SerdeError` when no
    name can be determined.

    >>> from typing import Any
    >>> typename(int)
    'int'
    >>> class Foo: pass
    >>> typename(Foo)
    'Foo'
    >>> typename(list[Foo])
    'list[Foo]'
    >>> typename(dict[str, Foo])
    'dict[str, Foo]'
    >>> typename(tuple[int, str, Foo, list[int], dict[str, Foo]])
    'tuple[int, str, Foo, list[int], dict[str, Foo]]'
    >>> typename(Optional[list[Foo]])
    'Optional[list[Foo]]'
    >>> typename(Union[Optional[Foo], list[Foo], Union[str, int]])
    'Union[Optional[Foo], list[Foo], str, int]'
    >>> typename(set[Foo])
    'set[Foo]'
    >>> typename(Any)
    'Any'
    """
    mod = "typing." if with_typing_module else ""
    thisfunc = functools.partial(typename, with_typing_module=with_typing_module)
    if is_opt(typ):
        args = type_args(typ)
        if args:
            return f"{mod}Optional[{thisfunc(args[0])}]"
        else:
            return f"{mod}Optional"
    elif is_union(typ):
        args = union_args(typ)
        if args:
            return f'{mod}Union[{", ".join([thisfunc(e) for e in args])}]'
        else:
            return f"{mod}Union"
    elif is_list(typ):
        args = type_args(typ)
        if args:
            et = thisfunc(args[0])
            return f"{mod}list[{et}]"
        else:
            return f"{mod}list"
    elif is_set(typ):
        args = type_args(typ)
        if args:
            et = thisfunc(args[0])
            return f"{mod}set[{et}]"
        else:
            return f"{mod}set"
    elif is_dict(typ):
        args = type_args(typ)
        if args and len(args) == 2:
            kt = thisfunc(args[0])
            vt = thisfunc(args[1])
            return f"{mod}dict[{kt}, {vt}]"
        else:
            return f"{mod}dict"
    elif is_deque(typ):
        args = type_args(typ)
        if args:
            et = thisfunc(args[0])
            return f"deque[{et}]"
        else:
            return "deque"
    elif is_counter(typ):
        args = type_args(typ)
        if args:
            et = thisfunc(args[0])
            return f"Counter[{et}]"
        else:
            return "Counter"
    elif is_tuple(typ):
        args = type_args(typ)
        if args:
            return f'{mod}tuple[{", ".join([thisfunc(e) for e in args])}]'
        else:
            return f"{mod}tuple"
    elif is_generic(typ):
        origin = get_origin(typ)
        if origin is None:
            raise SerdeError("Could not extract origin class from generic class")

        if not isinstance(origin.__name__, str):
            raise SerdeError("Name of generic class is not string")

        return origin.__name__

    elif is_literal(typ):
        args = type_args(typ)
        if not args:
            raise TypeError("Literal type requires at least one literal argument")
        return f'Literal[{", ".join(stringify_literal(e) for e in args)}]'
    elif typ is Any:
        return f"{mod}Any"
    elif is_ellipsis(typ):
        return "..."
    else:
        # Get super type for NewType
        inner = getattr(typ, "__supertype__", None)
        if inner:
            return typename(inner)

        name: str | None = getattr(typ, "_name", None)
        if name:
            return name
        else:
            name = getattr(typ, "__name__", None)
            if isinstance(name, str):
                return name
            else:
                raise SerdeError(f"Could not get a type name from: {typ}")


def stringify_literal(v: Any) -> str:
    """
    Render one Literal argument for `typename`: strings are wrapped in single
    quotes (without escaping), everything else goes through `str()`.
    """
    if isinstance(v, str):
        return f"'{v}'"
    else:
        return str(v)


def type_args(typ: Any) -> tuple[type[Any], ...]:
    """
    Wrapper to suppress type error for accessing private members.

    Returns ``typ.__args__`` when it is a tuple, otherwise falls back to `get_args`.
    """
    try:
        args: tuple[type[Any], ...] | None = typ.__args__  # type: ignore
    except AttributeError:
        return get_args(typ)

    # Some typing objects expose __args__ as a member_descriptor (e.g. typing.Union),
    # which isn't iterable. Fall back to typing.get_args in that case.
    if isinstance(args, tuple):
        return args
    return get_args(typ)


def union_args(typ: Any) -> tuple[type[Any], ...]:
    """
    Return the arguments of a Union type.

    The flattened arguments are scanned in consecutive pairs; a pair whose second
    element is ``NoneType`` is collapsed into ``Optional[first]``, any other pair
    is kept as-is.

    Raises `TypeError` if ``typ`` is not a Union.
    """
    if not is_union(typ):
        raise TypeError(f"{typ} is not Union")
    args = type_args(typ)
    if len(args) == 1:
        return (args[0],)
    it = iter(args)
    flattened = []
    for i1, i2 in itertools.zip_longest(it, it):
        # `zip_longest` pads a trailing odd element with None. Compare by identity
        # so a legitimately falsy type arg (e.g. an Enum class without members)
        # is not mistaken for the padding.
        if i2 is None:
            flattened.append(i1)
        elif is_none(i2):
            flattened.append(Optional[i1])
        else:
            flattened.extend((i1, i2))
    return tuple(flattened)


def dataclass_fields(cls: type[Any]) -> Iterator[dataclasses.Field]:  # type: ignore
    """
    Return an iterator over the fields of dataclass ``cls`` with resolved types.

    Each field's ``type`` is replaced in place by the hint from
    `typing.get_type_hints`, which resolves string forward references and
    PEP 563 annotations. If ``cls`` is SQLAlchemy-inspectable and the resolved
    hint is generic, its first type argument is used instead (presumably to
    unwrap SQLAlchemy's ``Mapped[...]``; verify).

    Raises `SerdeError` if the type hints cannot be resolved.
    """
    raw_fields = dataclasses.fields(cls)

    try:
        # this resolves types when string forward reference
        # or PEP 563: "from __future__ import annotations" are used
        resolved_hints = typing.get_type_hints(cls)
    except Exception as e:
        raise SerdeError(
            f"Failed to resolve type hints for {typename(cls)}:\n"
            f"{e.__class__.__name__}: {e}\n\n"
            f"If you are using forward references make sure you are calling deserialize & "
            "serialize after all classes are globally visible."
        ) from e

    for f in raw_fields:
        real_type = resolved_hints.get(f.name)
        if real_type is not None:
            f.type = real_type
            if is_generic(real_type) and _is_sqlalchemy_inspectable(cls):
                f.type = get_args(real_type)[0]

    return iter(raw_fields)


TypeLike = Union[type[Any], typing.Any]


def iter_types(cls: type[Any]) -> list[type[Any]]:
    """
    Iterate field types recursively.

    Returns every type reachable from ``cls``: dataclasses (and, recursively,
    their field types), string forward references, containers (recorded as their
    bare origin such as `list`, `dict`, `tuple`, plus their element types), the
    `Optional`/`Union` markers plus their arguments, and any other leaf type.
    PEP 695 type aliases are unwrapped. The result is collected in a set, so its
    order is unspecified.

    The correct element type is `Union[Type, typing._SpecialForm]`; `Any` is used
    in the annotations because `typing._SpecialForm` is private.
    """
    lst: set[Union[type[Any], Any]] = set()

    def recursive(cls: Union[type[Any], Any]) -> None:
        # Already-visited types (including dataclasses currently being expanded)
        # are skipped, which also stops infinite recursion on self-references.
        if cls in lst:
            return

        if is_dataclass(cls):
            lst.add(cls)
            if isinstance(cls, type):
                for f in dataclass_fields(cls):
                    recursive(f.type)
        elif isinstance(cls, str):
            lst.add(cls)
        elif is_opt(cls):
            lst.add(Optional)
            args = type_args(cls)
            if args:
                recursive(args[0])
        elif is_union(cls):
            lst.add(Union)
            for arg in type_args(cls):
                recursive(arg)
        elif is_list(cls):
            lst.add(list)
            args = type_args(cls)
            if args:
                recursive(args[0])
        elif is_set(cls):
            lst.add(set)
            args = type_args(cls)
            if args:
                recursive(args[0])
        elif is_deque(cls):
            lst.add(deque)
            args = type_args(cls)
            if args:
                recursive(args[0])
        elif is_counter(cls):
            lst.add(Counter)
            args = type_args(cls)
            if args:
                recursive(args[0])
        elif is_tuple(cls):
            lst.add(tuple)
            for arg in type_args(cls):
                recursive(arg)
        elif is_dict(cls):
            lst.add(dict)
            args = type_args(cls)
            if args and len(args) >= 2:
                recursive(args[0])
                recursive(args[1])
        elif is_pep695_type_alias(cls):
            recursive(cls.__value__)
        else:
            lst.add(cls)

    recursive(cls)
    return list(lst)


def iter_unions(cls: TypeLike) -> list[TypeLike]:
    """
    Iterate over all unions that are used in the dataclass.

    Walks dataclass fields, containers, Optional and PEP 695 aliases and returns
    each Union type found, in discovery order. Each union is recorded once per
    place it is reached; the arguments of a union are visited exactly once.
    """
    lst: list[TypeLike] = []
    stack: list[TypeLike] = []  # To prevent infinite recursion

    def recursive(cls: TypeLike) -> None:
        if cls in stack:
            return

        # A single if/elif chain: an Optional is also a Union, so the union branch
        # already visits its arguments. A separate `if` would re-visit them through
        # the `is_opt` branch and record nested unions twice.
        if is_union(cls):
            lst.append(cls)
            for arg in type_args(cls):
                recursive(arg)
        elif is_pep695_type_alias(cls):
            recursive(cls.__value__)
        elif is_dataclass(cls):
            stack.append(cls)
            if isinstance(cls, type):
                for f in dataclass_fields(cls):
                    recursive(f.type)
            stack.pop()
        elif is_opt(cls):
            args = type_args(cls)
            if args:
                recursive(args[0])
        elif is_list(cls) or is_set(cls) or is_deque(cls) or is_counter(cls):
            args = type_args(cls)
            if args:
                recursive(args[0])
        elif is_tuple(cls):
            for arg in type_args(cls):
                recursive(arg)
        elif is_dict(cls):
            args = type_args(cls)
            if args and len(args) >= 2:
                recursive(args[0])
                recursive(args[1])

    recursive(cls)
    return lst


def iter_literals(cls: type[Any]) -> list[TypeLike]:
    """
    Iterate over all literals that are used in the dataclass.

    Walks dataclass fields, unions, Optional and containers and returns the
    distinct Literal types found. The result is collected in a set, so its order
    is unspecified.
    """
    lst: set[Union[type[Any], Any]] = set()
    stack: list[TypeLike] = []  # To prevent infinite recursion

    def recursive(cls: Union[type[Any], Any]) -> None:
        if cls in stack:
            return

        if is_literal(cls):
            lst.add(cls)
        elif is_union(cls):
            # Also covers Optional, whose arguments are visited here.
            for arg in type_args(cls):
                recursive(arg)
        elif is_dataclass(cls):
            stack.append(cls)
            if isinstance(cls, type):
                for f in dataclass_fields(cls):
                    recursive(f.type)
            stack.pop()
        elif is_opt(cls):
            args = type_args(cls)
            if args:
                recursive(args[0])
        elif is_list(cls) or is_set(cls) or is_deque(cls) or is_counter(cls):
            args = type_args(cls)
            if args:
                recursive(args[0])
        elif is_tuple(cls):
            for arg in type_args(cls):
                recursive(arg)
        elif is_dict(cls):
            args = type_args(cls)
            if args and len(args) >= 2:
                recursive(args[0])
                recursive(args[1])

    recursive(cls)
    return list(lst)


@cache
def is_union(typ: Any) -> bool:
    """
    Test if the type is `typing.Union`.

    A `_WithTagging` wrapper is tested by its inner type; the Python 3.10+
    ``X | Y`` form is also recognized.

    >>> is_union(Union[int, str])
    True
    """

    try:
        # When `_WithTagging` is received, it will check inner type.
        if isinstance(typ, _WithTagging):
            return is_union(typ.inner)
    except Exception:
        pass

    # Python 3.10+ Union operator e.g. str | int
    try:
        if isinstance(typ, types.UnionType):
            return True
    except Exception:
        pass

    # typing.Union
    return typing_inspect.is_union_type(typ)  # type: ignore


@cache
def is_opt(typ: Any) -> bool:
    """
    Test if the type is `typing.Optional`.

    Only a two-argument union whose second argument is ``NoneType`` (and whose
    first is not) counts, plus the bare `Optional` itself.

    >>> is_opt(Optional[int])
    True
    >>> is_opt(Optional)
    True
    >>> is_opt(None.__class__)
    False
    """

    # Python 3.10+ Union operator e.g. str | None
    is_union_type = False
    try:
        if isinstance(typ, types.UnionType):
            is_union_type = True
    except Exception:
        pass

    # typing.Optional
    is_typing_union = typing_inspect.is_optional_type(typ)

    args = type_args(typ)
    if args:
        return (
            (is_union_type or is_typing_union)
            and len(args) == 2
            and not is_none(args[0])
            and is_none(args[1])
        )
    else:
        return typ is Optional


@cache
def is_bare_opt(typ: Any) -> bool:
    """
    Test if the type is `typing.Optional` without type args.

    >>> is_bare_opt(Optional[int])
    False
    >>> is_bare_opt(Optional)
    True
    >>> is_bare_opt(None.__class__)
    False
    """
    return not type_args(typ) and typ is Optional


@cache
def is_opt_dataclass(typ: Any) -> bool:
    """
    Test if the type is optional dataclass.

    >>> is_opt_dataclass(Optional[int])
    False
    >>> @dataclasses.dataclass
    ... class Foo:
    ...     pass
    >>> is_opt_dataclass(Foo)
    False
    >>> is_opt_dataclass(Optional[Foo])
    True
    """
    args = get_args(typ)
    return is_opt(typ) and len(args) > 0 and is_dataclass(args[0])


@cache
def is_list(typ: type[Any]) -> bool:
    """
    Test if the type is `list`, `collections.abc.Sequence`, or `collections.abc.MutableSequence`.

    >>> is_list(list[int])
    True
    >>> is_list(list)
    True
    >>> is_list(Sequence[int])
    True
    >>> is_list(Sequence)
    True
    >>> is_list(MutableSequence[int])
    True
    >>> is_list(MutableSequence)
    True
    """
    origin = get_origin(typ)
    if origin is None:
        return typ in (list, Sequence, MutableSequence)
    return origin in (list, Sequence, MutableSequence)


@cache
def is_bare_list(typ: type[Any]) -> bool:
    """
    Test if the type is `list`/`collections.abc.Sequence`/`collections.abc.MutableSequence`
    without type args.

    >>> is_bare_list(list[int])
    False
    >>> is_bare_list(list)
    True
    >>> is_bare_list(Sequence[int])
    False
    >>> is_bare_list(Sequence)
    True
    >>> is_bare_list(MutableSequence[int])
    False
    >>> is_bare_list(MutableSequence)
    True
    """
    origin = get_origin(typ)
    if origin in (list, Sequence, MutableSequence):
        return not type_args(typ)
    return typ in (list, Sequence, MutableSequence)


@cache
def is_tuple(typ: Any) -> bool:
    """
    Test if the type is tuple (bare `tuple`, or a parameterized type whose origin
    is a tuple subclass).
    """
    try:
        return issubclass(get_origin(typ), tuple)  # type: ignore
    except TypeError:
        return typ is tuple


@cache
def is_bare_tuple(typ: type[Any]) -> bool:
    """
    Test if the type is tuple without type args.

    >>> is_bare_tuple(tuple[int, str])
    False
    >>> is_bare_tuple(tuple)
    True
    """
    return typ is tuple


@cache
def is_variable_tuple(typ: type[Any]) -> bool:
    """
    Test if the type is a variable length tuple `tuple[T, ...]`.

    >>> is_variable_tuple(tuple[int, ...])
    True
    >>> is_variable_tuple(tuple[int, bool])
    False
    >>> is_variable_tuple(tuple[()])
    False
    """
    istuple = is_tuple(typ) and not is_bare_tuple(typ)
    args = get_args(typ)
    return istuple and len(args) == 2 and is_ellipsis(args[1])


@cache
def is_set(typ: type[Any]) -> bool:
    """
    Test if the type is set-like.

    >>> is_set(set[int])
    True
    >>> is_set(set)
    True
    >>> is_set(frozenset[int])
    True
    >>> from collections.abc import Set, MutableSet
    >>> is_set(Set[int])
    True
    >>> is_set(Set)
    True
    >>> is_set(MutableSet[int])
    True
    >>> is_set(MutableSet)
    True
    """
    try:
        return issubclass(get_origin(typ), (set, frozenset, Set, MutableSet))  # type: ignore[arg-type]
    except TypeError:
        return typ in (set, frozenset, Set, MutableSet)


@cache
def is_bare_set(typ: type[Any]) -> bool:
    """
    Test if the type is `set`/`frozenset`/`Set`/`MutableSet` without type args.

    >>> is_bare_set(set[int])
    False
    >>> is_bare_set(set)
    True
    >>> from collections.abc import Set, MutableSet
    >>> is_bare_set(Set)
    True
    >>> is_bare_set(MutableSet)
    True
    """
    origin = get_origin(typ)
    if origin in (set, frozenset, Set, MutableSet):
        return not type_args(typ)
    return typ in (set, frozenset, Set, MutableSet)


@cache
def is_frozen_set(typ: type[Any]) -> bool:
    """
    Test if the type is `frozenset`.

    >>> is_frozen_set(frozenset[int])
    True
    >>> is_frozen_set(set)
    False
    """
    try:
        return issubclass(get_origin(typ), frozenset)  # type: ignore
    except TypeError:
        return typ is frozenset


@cache
def is_dict(typ: type[Any]) -> bool:
    """
    Test if the type is dict-like.

    >>> is_dict(dict[int, int])
    True
    >>> is_dict(dict)
    True
    >>> is_dict(defaultdict[int, int])
    True
    >>> from collections.abc import Mapping, MutableMapping
    >>> is_dict(Mapping[str, int])
    True
    >>> is_dict(Mapping)
    True
    >>> is_dict(MutableMapping[str, int])
    True
    >>> is_dict(MutableMapping)
    True
    """
    try:
        return issubclass(
            get_origin(typ), (dict, defaultdict, Mapping, MutableMapping)  # type: ignore[arg-type]
        )
    except TypeError:
        return typ in (dict, defaultdict, Mapping, MutableMapping)


@cache
def is_bare_dict(typ: type[Any]) -> bool:
    """
    Test if the type is `dict`/`Mapping`/`MutableMapping` without type args.

    >>> is_bare_dict(dict[int, str])
    False
    >>> is_bare_dict(dict)
    True
    >>> from collections.abc import Mapping, MutableMapping
    >>> is_bare_dict(Mapping)
    True
    >>> is_bare_dict(MutableMapping)
    True
    """
    origin = get_origin(typ)
    if origin in (dict, Mapping, MutableMapping):
        return not type_args(typ)
    return typ in (dict, Mapping, MutableMapping)


@cache
def is_default_dict(typ: type[Any]) -> bool:
    """
    Test if the type is `defaultdict`.

    >>> is_default_dict(defaultdict[int, int])
    True
    >>> is_default_dict(dict[int, int])
    False
    """
    try:
        return issubclass(get_origin(typ), defaultdict)  # type: ignore
    except TypeError:
        return typ is defaultdict


@cache
def is_deque(typ: type[Any]) -> bool:
    """
    Test if the type is `collections.deque`.

    >>> is_deque(deque[int])
    True
    >>> is_deque(deque)
    True
    >>> is_deque(list[int])
    False
    """
    try:
        return issubclass(get_origin(typ), deque)  # type: ignore
    except TypeError:
        return typ is deque


@cache
def is_bare_deque(typ: type[Any]) -> bool:
    """
    Test if the type is `collections.deque` without type args.

    >>> is_bare_deque(deque[int])
    False
    >>> is_bare_deque(deque)
    True
    """
    origin = get_origin(typ)
    if origin is deque:
        return not type_args(typ)
    return typ is deque


@cache
def is_counter(typ: type[Any]) -> bool:
    """
    Test if the type is `collections.Counter`.

    >>> is_counter(Counter[str])
    True
    >>> is_counter(Counter)
    True
    >>> is_counter(dict[str, int])
    False
    """
    try:
        return issubclass(get_origin(typ), Counter)  # type: ignore
    except TypeError:
        return typ is Counter


@cache
def is_bare_counter(typ: type[Any]) -> bool:
    """
    Test if the type is `collections.Counter` without type args.

    >>> is_bare_counter(Counter[str])
    False
    >>> is_bare_counter(Counter)
    True
    """
    origin = get_origin(typ)
    if origin is Counter:
        return not type_args(typ)
    return typ is Counter


@cache
def is_none(typ: type[Any]) -> bool:
    """
    Test if the type is ``NoneType`` (the type, not the value ``None``).

    >>> is_none(int)
    False
    >>> is_none(type(None))
    True
    >>> is_none(None)
    False
    """
    return typ is type(None)  # noqa


PRIMITIVES = [int, float, bool, str]


@cache
def is_enum(typ: type[Any]) -> TypeGuard[enum.Enum]:
    """
    Test if the type is `enum.Enum`.

    Accepts both Enum classes and Enum members.
    """
    try:
        return issubclass(typ, enum.Enum)
    except TypeError:
        return isinstance(typ, enum.Enum)


@cache
def is_primitive_subclass(typ: type[Any]) -> bool:
    """
    Test if the type is a subclass of primitive type.

    >>> is_primitive_subclass(str)
    False
    >>> class Str(str):
    ...     pass
    >>> is_primitive_subclass(Str)
    True
    """
    return is_primitive(typ) and typ not in PRIMITIVES and not is_new_type_primitive(typ)


@cache
def is_primitive(typ: type[Any] | NewType) -> bool:
    """
    Test if the type is primitive.

    >>> is_primitive(int)
    True
    >>> class CustomInt(int):
    ...     pass
    >>> is_primitive(CustomInt)
    True
    """
    try:
        return any(issubclass(typ, ty) for ty in PRIMITIVES)  # type: ignore
    except TypeError:
        return is_new_type_primitive(typ)


@cache
def is_new_type_primitive(typ: type[Any] | NewType) -> bool:
    """
    Test if the type is a NewType of primitives.

    Objects without ``__supertype__`` are instead tested with `isinstance`
    against the primitive types.
    """
    inner = getattr(typ, "__supertype__", None)
    if inner:
        return is_primitive(inner)
    else:
        return any(isinstance(typ, ty) for ty in PRIMITIVES)


@cache
def has_generic_base(typ: Any) -> bool:
    """
    Test whether `typing.Generic` appears in ``typ``'s MRO or direct bases.
    """
    return Generic in getattr(typ, "__mro__", ()) or Generic in getattr(typ, "__bases__", ())


@cache
def is_generic(typ: Any) -> bool:
    """
    Test if the type is derived from `typing.Generic`.

    Only parameterized generics match; the unparameterized class does not.

    >>> T = typing.TypeVar('T')
    >>> class GenericFoo(typing.Generic[T]):
    ...     pass
    >>> is_generic(GenericFoo[int])
    True
    >>> is_generic(GenericFoo)
    False
    """
    origin = get_origin(typ)
    return origin is not None and has_generic_base(origin)


@cache
def is_class_var(typ: type[Any]) -> bool:
    """
    Test if the type is `typing.ClassVar`.

    >>> is_class_var(ClassVar[int])
    True
    >>> is_class_var(ClassVar)
    True
    """
    return get_origin(typ) is ClassVar or typ is ClassVar  # type: ignore


@cache
def is_literal(typ: type[Any]) -> bool:
    """
    Test if the type is derived from `typing.Literal`.

    Only parameterized literals match; the bare `typing.Literal` does not.

    >>> is_literal(typing.Literal["a", "b"])
    True
    >>> is_literal(typing.Literal)
    False
    >>> is_literal(str)
    False
    """
    origin = get_origin(typ)
    return origin is not None and origin is typing.Literal


@cache
def is_any(typ: Any) -> bool:
    """
    Test if the type is `typing.Any`.
    """
    return typ is Any


@cache
def is_str_serializable(typ: type[Any]) -> bool:
    """
    Test if the type is serializable to `str`.
    """
    return typ in StrSerializableTypes or (
        type(typ) is type and issubclass(typ, StrSerializableTypes)
    )


def is_datetime(
    typ: type[Any],
) -> TypeGuard[datetime.date | datetime.time | datetime.datetime]:
    """
    Test if the type is any of the datetime types.
    """
    return typ in DateTimeTypes or (type(typ) is type and issubclass(typ, DateTimeTypes))


def is_str_serializable_instance(obj: Any) -> bool:
    """
    Test if ``obj`` is an instance of one of `StrSerializableTypes`.
    """
    return isinstance(obj, StrSerializableTypes)


def is_datetime_instance(obj: Any) -> bool:
    """
    Test if ``obj`` is an instance of one of `DateTimeTypes`.
    """
    return isinstance(obj, DateTimeTypes)


def is_ellipsis(typ: Any) -> bool:
    """
    Test if ``typ`` is the ``...`` singleton.
    """
    return typ is Ellipsis


def is_pep695_type_alias(typ: Any) -> bool:
    """
    Test if the type is of PEP695 type alias.
    """
    return isinstance(typ, _PEP695_TYPES)


@cache
def get_type_var_names(cls: type[Any]) -> list[str] | None:
    """
    Get type argument names of a generic class.

    Returns None if ``cls`` has no ``__orig_bases__``. The returned list is cached
    and shared between calls, so callers should not mutate it.

    >>> T = typing.TypeVar('T')
    >>> class GenericFoo(typing.Generic[T]):
    ...     pass
    >>> get_type_var_names(GenericFoo)
    ['T']
    >>> get_type_var_names(int)
    """
    bases = getattr(cls, "__orig_bases__", ())
    if not bases:
        return None

    type_arg_names: list[str] = []
    for base in bases:
        type_arg_names.extend(arg.__name__ for arg in get_args(base) if hasattr(arg, "__name__"))

    return type_arg_names


def find_generic_arg(cls: type[Any], field: TypeVar) -> int:
    """
    Find a type in generic parameters.

    The type arguments of each entry of ``cls.__orig_bases__`` are scanned in
    order and compared with ``field`` by ``__name__``. The position of the first
    match within its base is returned, or -1 if no base mentions ``field``.

    Raises `Exception` if ``cls`` has no ``__orig_bases__``.

    >>> T = typing.TypeVar('T')
    >>> U = typing.TypeVar('U')
    >>> V = typing.TypeVar('V')
    >>> class GenericFoo(typing.Generic[T, U]):
    ...     pass
    >>> find_generic_arg(GenericFoo, T)
    0
    >>> find_generic_arg(GenericFoo, U)
    1
    >>> find_generic_arg(GenericFoo, V)
    -1
    """
    bases = getattr(cls, "__orig_bases__", ())
    if not bases:
        raise Exception(f'"__orig_bases__" property was not found: {cls}')

    for base in bases:
        for n, arg in enumerate(get_args(base)):
            # Arguments such as string forward references have no `__name__`;
            # treat them as non-matching instead of raising AttributeError
            # (same tolerance as `get_type_var_names`).
            if getattr(arg, "__name__", None) == field.__name__:
                return n

    return -1


def get_generic_arg(
    typ: Any,
    maybe_generic_type_vars: list[str] | None,
    variable_type_args: list[str] | None,
    index: int,
) -> Any:
    """
    Get generic type argument.

    ``variable_type_args[index]`` names the type variable used by a field;
    its position in ``maybe_generic_type_vars`` selects the concrete argument
    of ``typ``. Returns `typing.Any` if ``typ`` is not generic, either list is
    None, or the type variable is not found. Raises `SerdeError` if the number
    of type args of ``typ`` differs from ``len(maybe_generic_type_vars)``.

    >>> T = typing.TypeVar('T')
    >>> U = typing.TypeVar('U')
    >>> class GenericFoo(typing.Generic[T, U]):
    ...     pass
    >>> get_generic_arg(GenericFoo[int, str], ['T', 'U'], ['T', 'U'], 0).__name__
    'int'
    >>> get_generic_arg(GenericFoo[int, str], ['T', 'U'], ['T', 'U'], 1).__name__
    'str'
    >>> get_generic_arg(GenericFoo[int, str], ['T', 'U'], ['U'], 0).__name__
    'str'
    """
    if not is_generic(typ) or maybe_generic_type_vars is None or variable_type_args is None:
        return typing.Any

    args = get_args(typ)

    if len(args) != len(maybe_generic_type_vars):
        raise SerdeError(
            f"Number of type args for {typ} does not match number of generic type vars: "
            f"\n type args: {args}\n type_vars: {maybe_generic_type_vars}"
        )

    # Get the name of the type var used for this field in the parent class definition
    type_var_name = variable_type_args[index]

    try:
        # Find the slot of that type var in the original generic class definition
        orig_index = maybe_generic_type_vars.index(type_var_name)
    except ValueError:
        return typing.Any

    return args[orig_index]