serde.compat
Compatibility layer that mostly handles differences in the `typing` module between Python versions.
"""
Compatibility layer that mostly handles differences in the `typing` module between Python versions.
"""

import dataclasses
import datetime
import decimal
import enum
import functools
import ipaddress
import itertools
import pathlib
import types
import uuid
import typing
import typing_extensions
from collections import defaultdict, deque, Counter
from collections.abc import Iterator, Sequence, MutableSequence
from collections.abc import Mapping, MutableMapping, Set, MutableSet
from dataclasses import is_dataclass
from typing import TypeVar, Generic, Any, ClassVar, Optional, NewType, Union, Hashable, Callable

import typing_inspect
from typing_extensions import TypeGuard, ParamSpec

# `typing_extensions.TypeAliasType` isn't always an alias to `typing.TypeAliasType`
# depending on certain versions of `typing_extensions` and python.
_PEP695_TYPES: tuple[type, ...]
if hasattr(typing, "TypeAliasType"):
    _PEP695_TYPES = (typing_extensions.TypeAliasType, typing.TypeAliasType)
else:
    _PEP695_TYPES = (typing_extensions.TypeAliasType,)


def _is_sqlalchemy_inspectable(subject: Any) -> bool:
    """
    Lazy wrapper around `is_sqlalchemy_inspectable` from the sibling `.sqlalchemy` module.

    The import is deferred to call time to improve startup time.
    """
    from .sqlalchemy import is_sqlalchemy_inspectable

    return is_sqlalchemy_inspectable(subject)


def get_np_origin(tp: type[Any]) -> Any | None:
    """
    Fallback origin lookup used by `get_origin` when `typing.get_origin` finds nothing.

    This implementation recognizes nothing and always returns `None`.
    """
    return None


def get_np_args(tp: type[Any]) -> tuple[Any, ...]:
    """
    Fallback argument lookup used by `get_args` when `typing.get_args` finds nothing.

    This implementation recognizes nothing and always returns `()`.
    """
    return ()


__all__: list[str] = []

T = TypeVar("T")


StrSerializableTypes = (
    decimal.Decimal,
    pathlib.Path,
    pathlib.PosixPath,
    pathlib.WindowsPath,
    pathlib.PurePath,
    pathlib.PurePosixPath,
    pathlib.PureWindowsPath,
    uuid.UUID,
    ipaddress.IPv4Address,
    ipaddress.IPv6Address,
    ipaddress.IPv4Network,
    ipaddress.IPv6Network,
    ipaddress.IPv4Interface,
    ipaddress.IPv6Interface,
)
""" List of standard types (de)serializable to str """

DateTimeTypes = (datetime.date, datetime.time, datetime.datetime)
""" List of datetime types """


@dataclasses.dataclass(unsafe_hash=True)
class _WithTagging(Generic[T]):
    """
    Intermediate data structure for (de)serializing Union without dataclass.
    """

    inner: T
    """ Union type e.g. Union[Foo,Bar] passed in from_obj. """
    tagging: Any
    """ Union Tagging """


class SerdeError(Exception):
    """
    Serde error class.
    """


@dataclasses.dataclass
class UserError(Exception):
    """
    Error from user code e.g. __post_init__.
    """

    inner: Exception


class SerdeSkip(Exception):
    """
    Skip a field in custom (de)serializer.
    """


def is_hashable(typ: Any) -> TypeGuard[Hashable]:
    """
    Test if an object is hashable.
    """
    try:
        hash(typ)
    except TypeError:
        return False
    return True


P = ParamSpec("P")


def cache(f: Callable[P, T]) -> Callable[P, T]:
    """
    Wrapper for `functools.cache` to avoid `Hashable` related type errors.

    `functools.wraps` copies `f`'s `__name__`, `__qualname__`, `__module__` and
    `__doc__` onto the inner wrapper (and therefore onto the cached function).
    Without it every decorated function would report itself as `wrapper` with
    no docstring, hiding its documentation and doctests.
    """

    @functools.wraps(f)
    def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
        return f(*args, **kwargs)

    return functools.cache(wrapper)  # type: ignore


@cache
def get_origin(typ: Any) -> Any | None:
    """
    Provide `get_origin` that works in all python versions.
    """
    return typing.get_origin(typ) or get_np_origin(typ)


@cache
def get_args(typ: type[Any]) -> tuple[Any, ...]:
    """
    Provide `get_args` that works in all python versions.
    """
    return typing.get_args(typ) or get_np_args(typ)


@cache
def typename(typ: Any, with_typing_module: bool = False) -> str:
    """
    Return a human-readable name for a type annotation.

    When `with_typing_module` is true, most generic names are prefixed with
    `typing.`; `deque`, `Counter`, `Literal`, generic classes and the final
    fallback branch are never prefixed.

    >>> from typing import Any
    >>> typename(int)
    'int'
    >>> class Foo: pass
    >>> typename(Foo)
    'Foo'
    >>> typename(list[Foo])
    'list[Foo]'
    >>> typename(dict[str, Foo])
    'dict[str, Foo]'
    >>> typename(tuple[int, str, Foo, list[int], dict[str, Foo]])
    'tuple[int, str, Foo, list[int], dict[str, Foo]]'
    >>> typename(Optional[list[Foo]])
    'Optional[list[Foo]]'
    >>> typename(Union[Optional[Foo], list[Foo], Union[str, int]])
    'Union[Optional[Foo], list[Foo], str, int]'
    >>> typename(set[Foo])
    'set[Foo]'
    >>> typename(Any)
    'Any'
    """
    mod = "typing." if with_typing_module else ""
    thisfunc = functools.partial(typename, with_typing_module=with_typing_module)
    if is_opt(typ):
        args = type_args(typ)
        if args:
            return f"{mod}Optional[{thisfunc(args[0])}]"
        else:
            return f"{mod}Optional"
    elif is_union(typ):
        args = union_args(typ)
        if args:
            return f'{mod}Union[{", ".join([thisfunc(e) for e in args])}]'
        else:
            return f"{mod}Union"
    elif is_list(typ):
        args = type_args(typ)
        if args:
            et = thisfunc(args[0])
            return f"{mod}list[{et}]"
        else:
            return f"{mod}list"
    elif is_set(typ):
        args = type_args(typ)
        if args:
            et = thisfunc(args[0])
            return f"{mod}set[{et}]"
        else:
            return f"{mod}set"
    elif is_dict(typ):
        args = type_args(typ)
        if args and len(args) == 2:
            kt = thisfunc(args[0])
            vt = thisfunc(args[1])
            return f"{mod}dict[{kt}, {vt}]"
        else:
            return f"{mod}dict"
    elif is_deque(typ):
        args = type_args(typ)
        if args:
            et = thisfunc(args[0])
            return f"deque[{et}]"
        else:
            return "deque"
    elif is_counter(typ):
        args = type_args(typ)
        if args:
            et = thisfunc(args[0])
            return f"Counter[{et}]"
        else:
            return "Counter"
    elif is_tuple(typ):
        args = type_args(typ)
        if args:
            return f'{mod}tuple[{", ".join([thisfunc(e) for e in args])}]'
        else:
            return f"{mod}tuple"
    elif is_generic(typ):
        origin = get_origin(typ)
        if origin is None:
            raise SerdeError("Could not extract origin class from generic class")

        if not isinstance(origin.__name__, str):
            raise SerdeError("Name of generic class is not string")

        return origin.__name__

    elif is_literal(typ):
        args = type_args(typ)
        if not args:
            raise TypeError("Literal type requires at least one literal argument")
        return f'Literal[{", ".join(stringify_literal(e) for e in args)}]'
    elif typ is Any:
        return f"{mod}Any"
    elif is_ellipsis(typ):
        return "..."
    else:
        # Get super type for NewType
        inner = getattr(typ, "__supertype__", None)
        if inner:
            # NOTE(review): `with_typing_module` is not propagated to the NewType's
            # supertype here; kept as-is since generated names may depend on it.
            return typename(inner)

        name: str | None = getattr(typ, "_name", None)
        if name:
            return name
        else:
            name = getattr(typ, "__name__", None)
            if isinstance(name, str):
                return name
            else:
                raise SerdeError(f"Could not get a type name from: {typ}")


def stringify_literal(v: Any) -> str:
    """
    Render a `Literal` value for `typename`: strings are single-quoted, other values use `str()`.

    >>> stringify_literal("a")
    "'a'"
    >>> stringify_literal(1)
    '1'
    """
    if isinstance(v, str):
        return f"'{v}'"
    else:
        return str(v)


def type_args(typ: Any) -> tuple[type[Any], ...]:
    """
    Wrapper to suppress type error for accessing private members.

    Returns `typ.__args__` when it is a tuple, otherwise falls back to `get_args`.
    """
    try:
        args: tuple[type[Any], ...] | None = typ.__args__  # type: ignore
    except AttributeError:
        return get_args(typ)

    # Some typing objects expose __args__ as a member_descriptor (e.g. typing.Union),
    # which isn't iterable. Fall back to typing.get_args in that case.
    if isinstance(args, tuple):
        return args
    return get_args(typ)


def union_args(typ: Any) -> tuple[type[Any], ...]:
    """
    Return the member types of a union.

    Arguments are walked two at a time; whenever the second element of a pair
    is `NoneType`, the pair is folded into `Optional[first]`.

    Raises:
        TypeError: if `typ` is not a union.
    """
    if not is_union(typ):
        raise TypeError(f"{typ} is not Union")
    args = type_args(typ)
    if len(args) == 1:
        return (args[0],)
    it = iter(args)
    members: list[Any] = []
    for i1, i2 in itertools.zip_longest(it, it):
        # zip_longest pads an odd trailing element with the `None` sentinel.
        # Compare by identity: a member type may itself be falsy (e.g. an Enum
        # class without members, whose metaclass defines `__len__`).
        if i2 is None:
            members.append(i1)
        elif is_none(i2):
            members.append(Optional[i1])
        else:
            members.extend((i1, i2))
    return tuple(members)


def dataclass_fields(cls: type[Any]) -> Iterator[dataclasses.Field]:  # type: ignore
    """
    Iterate the fields of dataclass `cls`, with `Field.type` replaced by the resolved type hint.

    Note that the `Field` objects returned by `dataclasses.fields` are updated in place.

    Raises:
        SerdeError: if the type hints of `cls` cannot be resolved.
    """
    raw_fields = dataclasses.fields(cls)

    try:
        # this resolves types when string forward reference
        # or PEP 563: "from __future__ import annotations" are used
        resolved_hints = typing.get_type_hints(cls)
    except Exception as e:
        raise SerdeError(
            f"Failed to resolve type hints for {typename(cls)}:\n"
            f"{e.__class__.__name__}: {e}\n\n"
            "If you are using forward references make sure you are calling deserialize & "
            "serialize after all classes are globally visible."
        ) from e

    for f in raw_fields:
        real_type = resolved_hints.get(f.name)
        if real_type is not None:
            f.type = real_type
            if is_generic(real_type) and _is_sqlalchemy_inspectable(cls):
                f.type = get_args(real_type)[0]

    return iter(raw_fields)


TypeLike = Union[type[Any], typing.Any]


def iter_types(cls: type[Any]) -> list[type[Any]]:
    """
    Collect every type reachable from `cls`, recursing into dataclass fields,
    optional/union members, container element types and PEP 695 alias values.

    The distinct types are gathered in a set, so the order of the returned
    list is not meaningful.
    """
    lst: set[Union[type[Any], Any]] = set()

    def recursive(cls: Union[type[Any], Any]) -> None:
        if cls in lst:
            return

        if is_dataclass(cls):
            lst.add(cls)
            if isinstance(cls, type):
                for f in dataclass_fields(cls):
                    recursive(f.type)
        elif isinstance(cls, str):
            lst.add(cls)
        elif is_opt(cls):
            lst.add(Optional)
            args = type_args(cls)
            if args:
                recursive(args[0])
        elif is_union(cls):
            lst.add(Union)
            for arg in type_args(cls):
                recursive(arg)
        elif is_list(cls):
            lst.add(list)
            args = type_args(cls)
            if args:
                recursive(args[0])
        elif is_set(cls):
            lst.add(set)
            args = type_args(cls)
            if args:
                recursive(args[0])
        elif is_deque(cls):
            lst.add(deque)
            args = type_args(cls)
            if args:
                recursive(args[0])
        elif is_counter(cls):
            lst.add(Counter)
            args = type_args(cls)
            if args:
                recursive(args[0])
        elif is_tuple(cls):
            lst.add(tuple)
            for arg in type_args(cls):
                recursive(arg)
        elif is_dict(cls):
            lst.add(dict)
            args = type_args(cls)
            if args and len(args) >= 2:
                recursive(args[0])
                recursive(args[1])
        elif is_pep695_type_alias(cls):
            recursive(cls.__value__)
        else:
            lst.add(cls)

    recursive(cls)
    return list(lst)


def iter_unions(cls: TypeLike) -> list[TypeLike]:
    """
    Iterate over all unions that are used in the dataclass.

    Unions are returned in discovery order. Unions nested in dataclass fields,
    container element types and PEP 695 alias values are included.
    """
    lst: list[TypeLike] = []
    stack: list[TypeLike] = []  # To prevent infinite recursion

    def recursive(cls: TypeLike) -> None:
        if cls in stack:
            return

        # A single if/elif chain: an Optional is also a union, so falling through
        # to the `is_opt` branch would traverse its argument a second time and
        # record the nested unions twice.
        if is_union(cls):
            lst.append(cls)
            for arg in type_args(cls):
                recursive(arg)
        elif is_pep695_type_alias(cls):
            recursive(cls.__value__)
        elif is_dataclass(cls):
            stack.append(cls)
            if isinstance(cls, type):
                for f in dataclass_fields(cls):
                    recursive(f.type)
            stack.pop()
        elif is_opt(cls):
            args = type_args(cls)
            if args:
                recursive(args[0])
        elif is_list(cls) or is_set(cls) or is_deque(cls) or is_counter(cls):
            args = type_args(cls)
            if args:
                recursive(args[0])
        elif is_tuple(cls):
            for arg in type_args(cls):
                recursive(arg)
        elif is_dict(cls):
            args = type_args(cls)
            if args and len(args) >= 2:
                recursive(args[0])
                recursive(args[1])

    recursive(cls)
    return lst


def iter_literals(cls: type[Any]) -> list[TypeLike]:
    """
    Iterate over all literals that are used in the dataclass.

    Literals are gathered in a set, so the order of the returned list is not meaningful.
    """
    lst: set[Union[type[Any], Any]] = set()
    stack: list[TypeLike] = []  # To prevent infinite recursion

    def recursive(cls: Union[type[Any], Any]) -> None:
        if cls in stack:
            return

        # A single if/elif chain so an Optional (which is also a union) is not
        # traversed twice.
        if is_literal(cls):
            lst.add(cls)
        elif is_union(cls):
            for arg in type_args(cls):
                recursive(arg)
        elif is_dataclass(cls):
            stack.append(cls)
            if isinstance(cls, type):
                for f in dataclass_fields(cls):
                    recursive(f.type)
            stack.pop()
        elif is_opt(cls):
            args = type_args(cls)
            if args:
                recursive(args[0])
        elif is_list(cls) or is_set(cls) or is_deque(cls) or is_counter(cls):
            args = type_args(cls)
            if args:
                recursive(args[0])
        elif is_tuple(cls):
            for arg in type_args(cls):
                recursive(arg)
        elif is_dict(cls):
            args = type_args(cls)
            if args and len(args) >= 2:
                recursive(args[0])
                recursive(args[1])

    recursive(cls)
    return list(lst)


@cache
def is_union(typ: Any) -> bool:
    """
    Test if the type is `typing.Union`.

    >>> is_union(Union[int, str])
    True
    """

    try:
        # When `_WithTagging` is received, it will check inner type.
        if isinstance(typ, _WithTagging):
            return is_union(typ.inner)
    except Exception:
        pass

    # Python 3.10+ Union operator e.g. str | int
    try:
        if isinstance(typ, types.UnionType):
            return True
    except Exception:
        pass

    # typing.Union
    return typing_inspect.is_union_type(typ)  # type: ignore


@cache
def is_opt(typ: Any) -> bool:
    """
    Test if the type is `typing.Optional`.

    >>> is_opt(Optional[int])
    True
    >>> is_opt(Optional)
    True
    >>> is_opt(None.__class__)
    False
    """

    # Python 3.10+ Union operator e.g. str | None
    is_union_type = False
    try:
        if isinstance(typ, types.UnionType):
            is_union_type = True
    except Exception:
        pass

    # typing.Optional
    is_typing_union = typing_inspect.is_optional_type(typ)

    args = type_args(typ)
    if args:
        return (
            (is_union_type or is_typing_union)
            and len(args) == 2
            and not is_none(args[0])
            and is_none(args[1])
        )
    else:
        return typ is Optional


@cache
def is_bare_opt(typ: Any) -> bool:
    """
    Test if the type is `typing.Optional` without type args.
    >>> is_bare_opt(Optional[int])
    False
    >>> is_bare_opt(Optional)
    True
    >>> is_bare_opt(None.__class__)
    False
    """
    return not type_args(typ) and typ is Optional


@cache
def is_opt_dataclass(typ: Any) -> bool:
    """
    Test if the type is optional dataclass.

    >>> is_opt_dataclass(Optional[int])
    False
    >>> @dataclasses.dataclass
    ... class Foo:
    ...     pass
    >>> is_opt_dataclass(Foo)
    False
    >>> is_opt_dataclass(Optional[Foo])
    True
    """
    args = get_args(typ)
    return is_opt(typ) and len(args) > 0 and is_dataclass(args[0])


@cache
def is_list(typ: type[Any]) -> bool:
    """
    Test if the type is `list`, `collections.abc.Sequence`, or `collections.abc.MutableSequence`.

    >>> is_list(list[int])
    True
    >>> is_list(list)
    True
    >>> is_list(Sequence[int])
    True
    >>> is_list(Sequence)
    True
    >>> is_list(MutableSequence[int])
    True
    >>> is_list(MutableSequence)
    True
    """
    origin = get_origin(typ)
    if origin is None:
        return typ in (list, Sequence, MutableSequence)
    return origin in (list, Sequence, MutableSequence)


@cache
def is_bare_list(typ: type[Any]) -> bool:
    """
    Test if the type is `list`/`collections.abc.Sequence`/`collections.abc.MutableSequence`
    without type args.

    >>> is_bare_list(list[int])
    False
    >>> is_bare_list(list)
    True
    >>> is_bare_list(Sequence[int])
    False
    >>> is_bare_list(Sequence)
    True
    >>> is_bare_list(MutableSequence[int])
    False
    >>> is_bare_list(MutableSequence)
    True
    """
    origin = get_origin(typ)
    if origin in (list, Sequence, MutableSequence):
        return not type_args(typ)
    return typ in (list, Sequence, MutableSequence)


@cache
def is_tuple(typ: Any) -> bool:
    """
    Test if the type is tuple.
    """
    try:
        return issubclass(get_origin(typ), tuple)  # type: ignore
    except TypeError:
        return typ is tuple


@cache
def is_bare_tuple(typ: type[Any]) -> bool:
    """
    Test if the type is tuple without type args.

    >>> is_bare_tuple(tuple[int, str])
    False
    >>> is_bare_tuple(tuple)
    True
    """
    return typ is tuple


@cache
def is_variable_tuple(typ: type[Any]) -> bool:
    """
    Test if the type is a variable-length tuple `tuple[T, ...]`.

    >>> is_variable_tuple(tuple[int, ...])
    True
    >>> is_variable_tuple(tuple[int, bool])
    False
    >>> is_variable_tuple(tuple[()])
    False
    """
    istuple = is_tuple(typ) and not is_bare_tuple(typ)
    args = get_args(typ)
    return istuple and len(args) == 2 and is_ellipsis(args[1])


@cache
def is_set(typ: type[Any]) -> bool:
    """
    Test if the type is set-like.

    >>> is_set(set[int])
    True
    >>> is_set(set)
    True
    >>> is_set(frozenset[int])
    True
    >>> from collections.abc import Set, MutableSet
    >>> is_set(Set[int])
    True
    >>> is_set(Set)
    True
    >>> is_set(MutableSet[int])
    True
    >>> is_set(MutableSet)
    True
    """
    try:
        return issubclass(get_origin(typ), (set, frozenset, Set, MutableSet))  # type: ignore[arg-type]
    except TypeError:
        return typ in (set, frozenset, Set, MutableSet)


@cache
def is_bare_set(typ: type[Any]) -> bool:
    """
    Test if the type is `set`/`frozenset`/`Set`/`MutableSet` without type args.

    >>> is_bare_set(set[int])
    False
    >>> is_bare_set(set)
    True
    >>> from collections.abc import Set, MutableSet
    >>> is_bare_set(Set)
    True
    >>> is_bare_set(MutableSet)
    True
    """
    origin = get_origin(typ)
    if origin in (set, frozenset, Set, MutableSet):
        return not type_args(typ)
    return typ in (set, frozenset, Set, MutableSet)


@cache
def is_frozen_set(typ: type[Any]) -> bool:
    """
    Test if the type is `frozenset`.

    >>> is_frozen_set(frozenset[int])
    True
    >>> is_frozen_set(set)
    False
    """
    try:
        return issubclass(get_origin(typ), frozenset)  # type: ignore
    except TypeError:
        return typ is frozenset


@cache
def is_dict(typ: type[Any]) -> bool:
    """
    Test if the type is dict-like.

    >>> is_dict(dict[int, int])
    True
    >>> is_dict(dict)
    True
    >>> is_dict(defaultdict[int, int])
    True
    >>> from collections.abc import Mapping, MutableMapping
    >>> is_dict(Mapping[str, int])
    True
    >>> is_dict(Mapping)
    True
    >>> is_dict(MutableMapping[str, int])
    True
    >>> is_dict(MutableMapping)
    True
    """
    try:
        return issubclass(
            get_origin(typ), (dict, defaultdict, Mapping, MutableMapping)  # type: ignore[arg-type]
        )
    except TypeError:
        return typ in (dict, defaultdict, Mapping, MutableMapping)


def is_flatten_dict(typ: Any) -> bool:
    """
    Test if the type is dict[str, Any] or bare dict suitable for flatten.

    >>> is_flatten_dict(dict[str, Any])
    True
    >>> is_flatten_dict(dict)
    True
    >>> is_flatten_dict(dict[str, int])
    False
    >>> is_flatten_dict(dict[int, str])
    False
    >>> is_flatten_dict(list[str])
    False
    """
    if not is_dict(typ):
        return False
    # Allow bare dict
    if is_bare_dict(typ):
        return True
    args = type_args(typ)
    if not args or len(args) != 2:
        return False
    # Key must be str, value must be Any
    return args[0] is str and is_any(args[1])


@cache
def is_bare_dict(typ: type[Any]) -> bool:
    """
    Test if the type is `dict`/`Mapping`/`MutableMapping` without type args.

    >>> is_bare_dict(dict[int, str])
    False
    >>> is_bare_dict(dict)
    True
    >>> from collections.abc import Mapping, MutableMapping
    >>> is_bare_dict(Mapping)
    True
    >>> is_bare_dict(MutableMapping)
    True
    """
    origin = get_origin(typ)
    if origin in (dict, Mapping, MutableMapping):
        return not type_args(typ)
    return typ in (dict, Mapping, MutableMapping)


@cache
def is_default_dict(typ: type[Any]) -> bool:
    """
    Test if the type is `defaultdict`.

    >>> is_default_dict(defaultdict[int, int])
    True
    >>> is_default_dict(dict[int, int])
    False
    """
    try:
        return issubclass(get_origin(typ), defaultdict)  # type: ignore
    except TypeError:
        return typ is defaultdict


@cache
def is_deque(typ: type[Any]) -> bool:
    """
    Test if the type is `collections.deque`.

    >>> is_deque(deque[int])
    True
    >>> is_deque(deque)
    True
    >>> is_deque(list[int])
    False
    """
    try:
        return issubclass(get_origin(typ), deque)  # type: ignore
    except TypeError:
        return typ is deque


@cache
def is_bare_deque(typ: type[Any]) -> bool:
    """
    Test if the type is `collections.deque` without type args.

    >>> is_bare_deque(deque[int])
    False
    >>> is_bare_deque(deque)
    True
    """
    origin = get_origin(typ)
    if origin is deque:
        return not type_args(typ)
    return typ is deque


@cache
def is_counter(typ: type[Any]) -> bool:
    """
    Test if the type is `collections.Counter`.

    >>> is_counter(Counter[str])
    True
    >>> is_counter(Counter)
    True
    >>> is_counter(dict[str, int])
    False
    """
    try:
        return issubclass(get_origin(typ), Counter)  # type: ignore
    except TypeError:
        return typ is Counter


@cache
def is_bare_counter(typ: type[Any]) -> bool:
    """
    Test if the type is `collections.Counter` without type args.

    >>> is_bare_counter(Counter[str])
    False
    >>> is_bare_counter(Counter)
    True
    """
    origin = get_origin(typ)
    if origin is Counter:
        return not type_args(typ)
    return typ is Counter


@cache
def is_none(typ: type[Any]) -> bool:
    """
    Test if the type is `NoneType` (the type of `None`, not `None` itself).

    >>> is_none(int)
    False
    >>> is_none(type(None))
    True
    >>> is_none(None)
    False
    """
    return typ is type(None)  # noqa


PRIMITIVES = [int, float, bool, str]


@cache
def is_enum(typ: type[Any]) -> TypeGuard[enum.Enum]:
    """
    Test if the type is `enum.Enum`.
    """
    try:
        return issubclass(typ, enum.Enum)
    except TypeError:
        return isinstance(typ, enum.Enum)


@cache
def is_primitive_subclass(typ: type[Any]) -> bool:
    """
    Test if the type is a subclass of primitive type.

    >>> is_primitive_subclass(str)
    False
    >>> class Str(str):
    ...     pass
    >>> is_primitive_subclass(Str)
    True
    """
    return is_primitive(typ) and typ not in PRIMITIVES and not is_new_type_primitive(typ)


@cache
def is_primitive(typ: type[Any] | NewType) -> bool:
    """
    Test if the type is primitive.

    >>> is_primitive(int)
    True
    >>> class CustomInt(int):
    ...     pass
    >>> is_primitive(CustomInt)
    True
    """
    try:
        return any(issubclass(typ, ty) for ty in PRIMITIVES)  # type: ignore
    except TypeError:
        return is_new_type_primitive(typ)


@cache
def is_new_type_primitive(typ: type[Any] | NewType) -> bool:
    """
    Test if the type is a NewType of primitives.
    """
    inner = getattr(typ, "__supertype__", None)
    if inner:
        return is_primitive(inner)
    else:
        return any(isinstance(typ, ty) for ty in PRIMITIVES)


@cache
def has_generic_base(typ: Any) -> bool:
    """
    Test if `typing.Generic` appears in the type's `__mro__` or `__bases__`.
    """
    return Generic in getattr(typ, "__mro__", ()) or Generic in getattr(typ, "__bases__", ())


@cache
def is_generic(typ: Any) -> bool:
    """
    Test if the type is derived from `typing.Generic`.

    >>> T = typing.TypeVar('T')
    >>> class GenericFoo(typing.Generic[T]):
    ...     pass
    >>> is_generic(GenericFoo[int])
    True
    >>> is_generic(GenericFoo)
    False
    """
    origin = get_origin(typ)
    return origin is not None and has_generic_base(origin)


@cache
def is_class_var(typ: type[Any]) -> bool:
    """
    Test if the type is `typing.ClassVar`.

    >>> is_class_var(ClassVar[int])
    True
    >>> is_class_var(ClassVar)
    True
    """
    return get_origin(typ) is ClassVar or typ is ClassVar  # type: ignore


@cache
def is_literal(typ: type[Any]) -> bool:
    """
    Test if the type is derived from `typing.Literal`.

    >>> is_literal(typing.Literal["a", "b"])
    True
    >>> is_literal(str)
    False
    """
    origin = get_origin(typ)
    return origin is not None and origin is typing.Literal


@cache
def is_any(typ: Any) -> bool:
    """
    Test if the type is `typing.Any`.
    """
    return typ is Any


@cache
def is_str_serializable(typ: type[Any]) -> bool:
    """
    Test if the type is serializable to `str`.
    """
    return typ in StrSerializableTypes or (
        type(typ) is type and issubclass(typ, StrSerializableTypes)
    )


def is_datetime(
    typ: type[Any],
) -> TypeGuard[datetime.date | datetime.time | datetime.datetime]:
    """
    Test if the type is any of the datetime types.
    """
    return typ in DateTimeTypes or (type(typ) is type and issubclass(typ, DateTimeTypes))


def is_str_serializable_instance(obj: Any) -> bool:
    """
    Test if the object is an instance of one of `StrSerializableTypes`.
    """
    return isinstance(obj, StrSerializableTypes)


def is_datetime_instance(obj: Any) -> bool:
    """
    Test if the object is an instance of one of `DateTimeTypes`.
    """
    return isinstance(obj, DateTimeTypes)


def is_ellipsis(typ: Any) -> bool:
    """
    Test if the value is the `...` (Ellipsis) singleton.
    """
    return typ is Ellipsis


def is_pep695_type_alias(typ: Any) -> bool:
    """
    Test if the type is of PEP695 type alias.
    """
    return isinstance(typ, _PEP695_TYPES)


@cache
def get_type_var_names(cls: type[Any]) -> list[str] | None:
    """
    Get type argument names of a generic class.

    >>> T = typing.TypeVar('T')
    >>> class GenericFoo(typing.Generic[T]):
    ...     pass
    >>> get_type_var_names(GenericFoo)
    ['T']
    >>> get_type_var_names(int)
    """
    bases = getattr(cls, "__orig_bases__", ())
    if not bases:
        return None

    type_arg_names: list[str] = []
    for base in bases:
        type_arg_names.extend(arg.__name__ for arg in get_args(base) if hasattr(arg, "__name__"))

    return type_arg_names


def find_generic_arg(cls: type[Any], field: TypeVar) -> int:
    """
    Find a type in generic parameters.

    Returns the position of `field` within the type arguments of the first
    `__orig_bases__` entry that mentions it (matched by `__name__`), or -1
    when no base mentions it.

    Raises:
        Exception: if `cls` has no `__orig_bases__`.

    >>> T = typing.TypeVar('T')
    >>> U = typing.TypeVar('U')
    >>> V = typing.TypeVar('V')
    >>> class GenericFoo(typing.Generic[T, U]):
    ...     pass
    >>> find_generic_arg(GenericFoo, T)
    0
    >>> find_generic_arg(GenericFoo, U)
    1
    >>> find_generic_arg(GenericFoo, V)
    -1
    """
    bases = getattr(cls, "__orig_bases__", ())
    if not bases:
        raise Exception(f'"__orig_bases__" property was not found: {cls}')

    for base in bases:
        for n, arg in enumerate(get_args(base)):
            # Concrete arguments may lack `__name__`; skip them instead of raising
            # (mirrors the `hasattr` guard in `get_type_var_names`).
            if getattr(arg, "__name__", None) == field.__name__:
                return n

    return -1


def get_generic_arg(
    typ: Any,
    maybe_generic_type_vars: list[str] | None,
    variable_type_args: list[str] | None,
    index: int,
) -> Any:
    """
    Get generic type argument.

    Returns `typing.Any` when `typ` is not generic, when either name list is
    `None`, or when the requested type var name is not among
    `maybe_generic_type_vars`.

    Raises:
        SerdeError: if the number of type args of `typ` differs from the
            number of generic type vars.

    >>> T = typing.TypeVar('T')
    >>> U = typing.TypeVar('U')
    >>> class GenericFoo(typing.Generic[T, U]):
    ...     pass
    >>> get_generic_arg(GenericFoo[int, str], ['T', 'U'], ['T', 'U'], 0).__name__
    'int'
    >>> get_generic_arg(GenericFoo[int, str], ['T', 'U'], ['T', 'U'], 1).__name__
    'str'
    >>> get_generic_arg(GenericFoo[int, str], ['T', 'U'], ['U'], 0).__name__
    'str'
    """
    if not is_generic(typ) or maybe_generic_type_vars is None or variable_type_args is None:
        return typing.Any

    args = get_args(typ)

    if len(args) != len(maybe_generic_type_vars):
        raise SerdeError(
            f"Number of type args for {typ} does not match number of generic type vars: "
            f"\n type args: {args}\n type_vars: {maybe_generic_type_vars}"
        )

    # Get the name of the type var used for this field in the parent class definition
    type_var_name = variable_type_args[index]

    try:
        # Find the slot of that type var in the original generic class definition
        orig_index = maybe_generic_type_vars.index(type_var_name)
    except ValueError:
        return typing.Any

    return args[orig_index]