serde.compat
Compatibility layer which mostly handles differences in the `typing`
module between Python versions.
"""
Compatibility layer which mostly handles differences of the `typing` module between Python versions.
"""

import dataclasses
import datetime
import decimal
import enum
import functools
import ipaddress
import itertools
import pathlib
import sys
import types
import uuid
import typing
from collections import defaultdict
from collections.abc import Iterator
from dataclasses import is_dataclass
from typing import TypeVar, Generic, Any, ClassVar, Optional, NewType, Union, Hashable, Callable

import typing_inspect
from typing_extensions import TypeGuard, ParamSpec, TypeAliasType

from .sqlalchemy import is_sqlalchemy_inspectable


def get_np_origin(tp: type[Any]) -> Optional[Any]:
    """
    Hook consulted by `get_origin` when `typing.get_origin` finds nothing.

    This implementation always returns ``None`` (presumably a stub for
    numpy-specific type support, judging by the name — confirm).
    """
    return None


def get_np_args(tp: type[Any]) -> tuple[Any, ...]:
    """
    Hook consulted by `get_args` when `typing.get_args` finds nothing.

    This implementation always returns an empty tuple (presumably a stub for
    numpy-specific type support, judging by the name — confirm).
    """
    return ()


__all__: list[str] = []

T = TypeVar("T")


StrSerializableTypes = (
    decimal.Decimal,
    pathlib.Path,
    pathlib.PosixPath,
    pathlib.WindowsPath,
    pathlib.PurePath,
    pathlib.PurePosixPath,
    pathlib.PureWindowsPath,
    uuid.UUID,
    ipaddress.IPv4Address,
    ipaddress.IPv6Address,
    ipaddress.IPv4Network,
    ipaddress.IPv6Network,
    ipaddress.IPv4Interface,
    ipaddress.IPv6Interface,
)
""" List of standard types (de)serializable to str """

DateTimeTypes = (datetime.date, datetime.time, datetime.datetime)
""" List of datetime types """


@dataclasses.dataclass(unsafe_hash=True)
class _WithTagging(Generic[T]):
    """
    Intermediate data structure for (de)serializing a Union without a dataclass.
    """

    inner: T
    """ Union type e.g. Union[Foo,Bar] passed in from_obj. """
    tagging: Any
    """ Union Tagging """


class SerdeError(Exception):
    """
    Serde error class.
    """


@dataclasses.dataclass
class UserError(Exception):
    """
    Error from user code e.g. __post_init__.
    """

    inner: Exception


class SerdeSkip(Exception):
    """
    Skip a field in custom (de)serializer.
    """


def is_hashable(typ: Any) -> TypeGuard[Hashable]:
    """
    Test if an object is hashable, i.e. `hash(typ)` does not raise `TypeError`.
    """
    try:
        hash(typ)
    except TypeError:
        return False
    return True


P = ParamSpec("P")


def cache(f: Callable[P, T]) -> Callable[P, T]:
    """
    Wrapper for `functools.cache` to avoid `Hashable` related type errors.

    `functools.wraps` copies ``__name__``, ``__qualname__``, ``__doc__`` and
    ``__module__`` of `f` onto the cached callable and sets ``__wrapped__``.
    Without it every cached function would be named ``wrapper`` and lose its
    docstring (and therefore its doctests).
    """

    @functools.wraps(f)
    def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
        return f(*args, **kwargs)

    return functools.cache(wrapper)  # type: ignore


@cache
def get_origin(typ: Any) -> Optional[Any]:
    """
    Provide `get_origin` that works in all python versions.

    Falls back to `get_np_origin` when `typing.get_origin` returns a falsy value.
    """
    return typing.get_origin(typ) or get_np_origin(typ)


@cache
def get_args(typ: type[Any]) -> tuple[Any, ...]:
    """
    Provide `get_args` that works in all python versions.

    Falls back to `get_np_args` when `typing.get_args` returns an empty tuple.
    """
    return typing.get_args(typ) or get_np_args(typ)


@cache
def typename(typ: Any, with_typing_module: bool = False) -> str:
    """
    Return a readable name for `typ`, recursing into its type arguments.

    When `with_typing_module` is true, the names emitted for Optional, Union,
    list, set, dict, tuple and Any are prefixed with ``typing.``; the flag is
    propagated to every nested argument (including NewType supertypes).

    Raises:
        SerdeError: when the origin or name of the type cannot be determined.
        TypeError: for a ``Literal`` without arguments.

    >>> from typing import Any
    >>> typename(int)
    'int'
    >>> class Foo: pass
    >>> typename(Foo)
    'Foo'
    >>> typename(list[Foo])
    'list[Foo]'
    >>> typename(dict[str, Foo])
    'dict[str, Foo]'
    >>> typename(tuple[int, str, Foo, list[int], dict[str, Foo]])
    'tuple[int, str, Foo, list[int], dict[str, Foo]]'
    >>> typename(Optional[list[Foo]])
    'Optional[list[Foo]]'
    >>> typename(Union[Optional[Foo], list[Foo], Union[str, int]])
    'Union[Optional[Foo], list[Foo], str, int]'
    >>> typename(set[Foo])
    'set[Foo]'
    >>> typename(Any)
    'Any'
    """
    mod = "typing." if with_typing_module else ""
    # Recursive calls keep the caller's `with_typing_module` choice.
    thisfunc = functools.partial(typename, with_typing_module=with_typing_module)
    if is_opt(typ):
        args = type_args(typ)
        if args:
            return f"{mod}Optional[{thisfunc(args[0])}]"
        else:
            return f"{mod}Optional"
    elif is_union(typ):
        args = union_args(typ)
        if args:
            return f'{mod}Union[{", ".join([thisfunc(e) for e in args])}]'
        else:
            return f"{mod}Union"
    elif is_list(typ):
        args = type_args(typ)
        if args:
            et = thisfunc(args[0])
            return f"{mod}list[{et}]"
        else:
            return f"{mod}list"
    elif is_set(typ):
        args = type_args(typ)
        if args:
            et = thisfunc(args[0])
            return f"{mod}set[{et}]"
        else:
            return f"{mod}set"
    elif is_dict(typ):
        args = type_args(typ)
        if args and len(args) == 2:
            kt = thisfunc(args[0])
            vt = thisfunc(args[1])
            return f"{mod}dict[{kt}, {vt}]"
        else:
            return f"{mod}dict"
    elif is_tuple(typ):
        args = type_args(typ)
        if args:
            return f'{mod}tuple[{", ".join([thisfunc(e) for e in args])}]'
        else:
            return f"{mod}tuple"
    elif is_generic(typ):
        # Parameterized user generics are reported by their bare class name.
        origin = get_origin(typ)
        if origin is None:
            raise SerdeError("Could not extract origin class from generic class")

        if not isinstance(origin.__name__, str):
            raise SerdeError("Name of generic class is not string")

        return origin.__name__

    elif is_literal(typ):
        args = type_args(typ)
        if not args:
            raise TypeError("Literal type requires at least one literal argument")
        return f'Literal[{", ".join(stringify_literal(e) for e in args)}]'
    elif typ is Any:
        return f"{mod}Any"
    elif is_ellipsis(typ):
        return "..."
    else:
        # NewType: report the name of the wrapped supertype.
        inner = getattr(typ, "__supertype__", None)
        if inner:
            return thisfunc(inner)

        name: Optional[str] = getattr(typ, "_name", None)
        if name:
            return name
        else:
            name = getattr(typ, "__name__", None)
            if isinstance(name, str):
                return name
            else:
                raise SerdeError(f"Could not get a type name from: {typ}")


def stringify_literal(v: Any) -> str:
    """
    Render a Literal argument: strings are wrapped in single quotes, anything
    else goes through `str`.
    """
    if isinstance(v, str):
        return f"'{v}'"
    else:
        return str(v)


def type_args(typ: Any) -> tuple[type[Any], ...]:
    """
    Wrapper to suppress type error for accessing private members.

    Reads ``typ.__args__`` (``None`` becomes ``()``) and falls back to
    `get_args` when the attribute does not exist.
    """
    try:
        args: tuple[type[Any], ...] = typ.__args__  # type: ignore
        if args is None:
            return ()
        else:
            return args
    except AttributeError:
        return get_args(typ)


def union_args(typ: Any) -> tuple[type[Any], ...]:
    """
    Return the members of a Union, folding ``X, None`` pairs into ``Optional[X]``.

    Arguments are consumed in consecutive pairs; when the second element of a
    pair is ``NoneType`` the pair becomes ``Optional[first]``, otherwise both
    are kept. A trailing unpaired element is kept as-is.

    NOTE: because of the pairing, only a ``None`` at an odd position is
    folded into the preceding member.

    Raises:
        TypeError: if `typ` is not a Union.
    """
    if not is_union(typ):
        raise TypeError(f"{typ} is not Union")
    args = type_args(typ)
    if len(args) == 1:
        return (args[0],)
    it = iter(args)
    types = []
    # zip_longest over the same iterator yields consecutive pairs; the last
    # pair is padded with None when the number of args is odd.
    for i1, i2 in itertools.zip_longest(it, it):
        if i2 is None:
            types.append(i1)
        elif is_none(i2):
            types.append(Optional[i1])
        else:
            types.extend((i1, i2))
    return tuple(types)


def dataclass_fields(cls: type[Any]) -> Iterator[dataclasses.Field]:  # type: ignore
    """
    Return an iterator over the dataclass fields of `cls` with resolved types.

    String forward references and PEP 563 annotations are resolved through
    `typing.get_type_hints`. For SQLAlchemy-inspectable classes, generic field
    types are replaced by their first type argument.

    NOTE: the `Field` objects returned by `dataclasses.fields` are updated in
    place (their ``type`` attribute is overwritten).

    Raises:
        SerdeError: when the type hints cannot be resolved.
    """
    raw_fields = dataclasses.fields(cls)

    try:
        # this resolves types when string forward reference
        # or PEP 563: "from __future__ import annotations" are used
        resolved_hints = typing.get_type_hints(cls)
    except Exception as e:
        raise SerdeError(
            f"Failed to resolve type hints for {typename(cls)}:\n"
            f"{e.__class__.__name__}: {e}\n\n"
            f"If you are using forward references make sure you are calling deserialize & "
            "serialize after all classes are globally visible."
        ) from e

    for f in raw_fields:
        real_type = resolved_hints.get(f.name)
        if real_type is not None:
            f.type = real_type
            if is_generic(real_type) and is_sqlalchemy_inspectable(cls):
                f.type = get_args(real_type)[0]

    return iter(raw_fields)


TypeLike = Union[type[Any], typing.Any]


def iter_types(cls: type[Any]) -> list[type[Any]]:
    """
    Iterate field types recursively.

    Dataclasses are walked field by field. Containers contribute their bare
    origin (`Optional`, `Union`, `list`, `set`, `tuple`, `dict`) plus their
    arguments. PEP 695 aliases are replaced by their value. String forward
    references and all other types are collected as-is.

    The element type is annotated loosely (`Any`) because special forms such
    as `Optional` and `Union` are not classes.
    """
    lst: set[Union[type[Any], Any]] = set()

    def recursive(cls: Union[type[Any], Any]) -> None:
        # Already-collected entries (including dataclasses) stop recursion,
        # which also protects against self-referencing dataclasses.
        if cls in lst:
            return

        if is_dataclass(cls):
            lst.add(cls)
            for f in dataclass_fields(cls):
                recursive(f.type)
        elif isinstance(cls, str):
            lst.add(cls)
        elif is_opt(cls):
            lst.add(Optional)
            args = type_args(cls)
            if args:
                recursive(args[0])
        elif is_union(cls):
            lst.add(Union)
            for arg in type_args(cls):
                recursive(arg)
        elif is_list(cls):
            lst.add(list)
            args = type_args(cls)
            if args:
                recursive(args[0])
        elif is_set(cls):
            lst.add(set)
            args = type_args(cls)
            if args:
                recursive(args[0])
        elif is_tuple(cls):
            lst.add(tuple)
            for arg in type_args(cls):
                recursive(arg)
        elif is_dict(cls):
            lst.add(dict)
            args = type_args(cls)
            if args and len(args) >= 2:
                recursive(args[0])
                recursive(args[1])
        elif is_pep695_type_alias(cls):
            recursive(cls.__value__)
        else:
            lst.add(cls)

    recursive(cls)
    return list(lst)


def iter_unions(cls: TypeLike) -> list[TypeLike]:
    """
    Iterate over all unions that are used in the dataclass.

    Unions are listed in depth-first order of discovery. Dataclasses that are
    currently being walked are kept on a stack so that self-referencing
    dataclasses do not recurse forever.
    """
    lst: list[TypeLike] = []
    stack: list[TypeLike] = []  # To prevent infinite recursion

    def recursive(cls: TypeLike) -> None:
        if cls in stack:
            return

        # A single if/elif chain: the union branch already walks every
        # argument, so an Optional union must not be walked a second time by
        # the `is_opt` branch (that would duplicate entries and double the
        # work at every Optional nesting level).
        if is_union(cls):
            lst.append(cls)
            for arg in type_args(cls):
                recursive(arg)
        elif is_pep695_type_alias(cls):
            recursive(cls.__value__)
        elif is_dataclass(cls):
            stack.append(cls)
            for f in dataclass_fields(cls):
                recursive(f.type)
            stack.pop()
        elif is_opt(cls):
            args = type_args(cls)
            if args:
                recursive(args[0])
        elif is_list(cls) or is_set(cls):
            args = type_args(cls)
            if args:
                recursive(args[0])
        elif is_tuple(cls):
            for arg in type_args(cls):
                recursive(arg)
        elif is_dict(cls):
            args = type_args(cls)
            if args and len(args) >= 2:
                recursive(args[0])
                recursive(args[1])

    recursive(cls)
    return lst


def iter_literals(cls: type[Any]) -> list[TypeLike]:
    """
    Iterate over all literals that are used in the dataclass.

    NOTE: dataclasses that are visited are also added to the result set (they
    double as the recursion guard).
    """
    lst: set[Union[type[Any], Any]] = set()

    def recursive(cls: Union[type[Any], Any]) -> None:
        if cls in lst:
            return

        # One if/elif chain so that an Optional union is not walked twice.
        if is_literal(cls):
            lst.add(cls)
        elif is_union(cls):
            for arg in type_args(cls):
                recursive(arg)
        elif is_dataclass(cls):
            lst.add(cls)
            for f in dataclass_fields(cls):
                recursive(f.type)
        elif is_opt(cls):
            args = type_args(cls)
            if args:
                recursive(args[0])
        elif is_list(cls) or is_set(cls):
            args = type_args(cls)
            if args:
                recursive(args[0])
        elif is_tuple(cls):
            for arg in type_args(cls):
                recursive(arg)
        elif is_dict(cls):
            args = type_args(cls)
            if args and len(args) >= 2:
                recursive(args[0])
                recursive(args[1])

    recursive(cls)
    return list(lst)


@cache
def is_union(typ: Any) -> bool:
    """
    Test if the type is `typing.Union`.

    Also recognizes the PEP 604 ``X | Y`` form on Python 3.10+, and unwraps a
    `_WithTagging` to test its inner type.

    >>> is_union(Union[int, str])
    True
    """

    try:
        # When `_WithTagging` is received, it will check inner type.
        if isinstance(typ, _WithTagging):
            return is_union(typ.inner)
    except Exception:
        pass

    # Python 3.10 Union operator e.g. str | int
    if sys.version_info[:2] >= (3, 10):
        try:
            if isinstance(typ, types.UnionType):
                return True
        except Exception:
            pass

    # typing.Union
    return typing_inspect.is_union_type(typ)  # type: ignore


@cache
def is_opt(typ: Any) -> bool:
    """
    Test if the type is `typing.Optional`.

    A parameterized type counts only when it is a two-member union whose
    second member is ``NoneType`` and whose first is not. The bare
    `Optional` special form also counts.

    >>> is_opt(Optional[int])
    True
    >>> is_opt(Optional)
    True
    >>> is_opt(None.__class__)
    False
    """

    # Python 3.10 Union operator e.g. str | None
    is_union_type = False
    if sys.version_info[:2] >= (3, 10):
        try:
            if isinstance(typ, types.UnionType):
                is_union_type = True
        except Exception:
            pass

    # typing.Optional
    is_typing_union = typing_inspect.is_optional_type(typ)

    args = type_args(typ)
    if args:
        return (
            (is_union_type or is_typing_union)
            and len(args) == 2
            and not is_none(args[0])
            and is_none(args[1])
        )
    else:
        return typ is Optional


@cache
def is_bare_opt(typ: Any) -> bool:
    """
    Test if the type is `typing.Optional` without type args.

    >>> is_bare_opt(Optional[int])
    False
    >>> is_bare_opt(Optional)
    True
    >>> is_bare_opt(None.__class__)
    False
    """
    return not type_args(typ) and typ is Optional


@cache
def is_opt_dataclass(typ: Any) -> bool:
    """
    Test if the type is an Optional whose first argument is a dataclass.

    >>> is_opt_dataclass(Optional[int])
    False
    >>> @dataclasses.dataclass
    ... class Foo:
    ...     pass
    >>> is_opt_dataclass(Foo)
    False
    >>> is_opt_dataclass(Optional[Foo])
    True
    """
    args = get_args(typ)
    return is_opt(typ) and len(args) > 0 and is_dataclass(args[0])


@cache
def is_list(typ: type[Any]) -> bool:
    """
    Test if the type is `list`.

    >>> is_list(list[int])
    True
    >>> is_list(list)
    True
    """
    try:
        return issubclass(get_origin(typ), list)  # type: ignore
    except TypeError:
        # No class origin (e.g. the bare `list` itself): compare directly.
        return typ is list


@cache
def is_bare_list(typ: type[Any]) -> bool:
    """
    Test if the type is `list` without type args.

    >>> is_bare_list(list[int])
    False
    >>> is_bare_list(list)
    True
    """
    return typ is list


@cache
def is_tuple(typ: Any) -> bool:
    """
    Test if the type is tuple.
    """
    try:
        return issubclass(get_origin(typ), tuple)  # type: ignore
    except TypeError:
        return typ is tuple


@cache
def is_bare_tuple(typ: type[Any]) -> bool:
    """
    Test if the type is tuple without type args.

    >>> is_bare_tuple(tuple[int, str])
    False
    >>> is_bare_tuple(tuple)
    True
    """
    return typ is tuple


@cache
def is_variable_tuple(typ: type[Any]) -> bool:
    """
    Test if the type is a variable-length tuple, i.e. `tuple[T, ...]`.

    >>> is_variable_tuple(tuple[int, ...])
    True
    >>> is_variable_tuple(tuple[int, bool])
    False
    >>> is_variable_tuple(tuple[()])
    False
    """
    istuple = is_tuple(typ) and not is_bare_tuple(typ)
    args = get_args(typ)
    return istuple and len(args) == 2 and is_ellipsis(args[1])


@cache
def is_set(typ: type[Any]) -> bool:
    """
    Test if the type is `set` or `frozenset`.

    >>> is_set(set[int])
    True
    >>> is_set(set)
    True
    >>> is_set(frozenset[int])
    True
    """
    try:
        return issubclass(get_origin(typ), (set, frozenset))  # type: ignore
    except TypeError:
        return typ in (set, frozenset)


@cache
def is_bare_set(typ: type[Any]) -> bool:
    """
    Test if the type is `set` or `frozenset` without type args.

    >>> is_bare_set(set[int])
    False
    >>> is_bare_set(set)
    True
    """
    return typ in (set, frozenset)


@cache
def is_frozen_set(typ: type[Any]) -> bool:
    """
    Test if the type is `frozenset`.

    >>> is_frozen_set(frozenset[int])
    True
    >>> is_frozen_set(set)
    False
    """
    try:
        return issubclass(get_origin(typ), frozenset)  # type: ignore
    except TypeError:
        return typ is frozenset


@cache
def is_dict(typ: type[Any]) -> bool:
    """
    Test if the type is dict (including `defaultdict`).

    >>> is_dict(dict[int, int])
    True
    >>> is_dict(dict)
    True
    >>> is_dict(defaultdict[int, int])
    True
    """
    try:
        return issubclass(get_origin(typ), (dict, defaultdict))  # type: ignore
    except TypeError:
        return typ in (dict, defaultdict)


@cache
def is_bare_dict(typ: type[Any]) -> bool:
    """
    Test if the type is `dict` without type args.

    >>> is_bare_dict(dict[int, str])
    False
    >>> is_bare_dict(dict)
    True
    """
    return typ is dict


@cache
def is_default_dict(typ: type[Any]) -> bool:
    """
    Test if the type is `defaultdict`.

    >>> is_default_dict(defaultdict[int, int])
    True
    >>> is_default_dict(dict[int, int])
    False
    """
    try:
        return issubclass(get_origin(typ), defaultdict)  # type: ignore
    except TypeError:
        return typ is defaultdict


@cache
def is_none(typ: type[Any]) -> bool:
    """
    Test if the type is ``NoneType`` (the value ``None`` itself is not).

    >>> is_none(int)
    False
    >>> is_none(type(None))
    True
    >>> is_none(None)
    False
    """
    return typ is type(None)  # noqa


PRIMITIVES = [int, float, bool, str]


@cache
def is_enum(typ: type[Any]) -> TypeGuard[enum.Enum]:
    """
    Test if the type is `enum.Enum`.

    For non-class arguments, tests whether the value is an enum member.
    """
    try:
        return issubclass(typ, enum.Enum)
    except TypeError:
        return isinstance(typ, enum.Enum)


@cache
def is_primitive_subclass(typ: type[Any]) -> bool:
    """
    Test if the type is a subclass of primitive type.

    The primitives themselves and NewTypes of primitives do not count.

    >>> is_primitive_subclass(str)
    False
    >>> class Str(str):
    ...     pass
    >>> is_primitive_subclass(Str)
    True
    """
    return is_primitive(typ) and typ not in PRIMITIVES and not is_new_type_primitive(typ)


@cache
def is_primitive(typ: Union[type[Any], NewType]) -> bool:
    """
    Test if the type is primitive (a subclass of int, float, bool or str).

    Non-class arguments are delegated to `is_new_type_primitive`.

    >>> is_primitive(int)
    True
    >>> class CustomInt(int):
    ...     pass
    >>> is_primitive(CustomInt)
    True
    """
    try:
        return any(issubclass(typ, ty) for ty in PRIMITIVES)  # type: ignore
    except TypeError:
        return is_new_type_primitive(typ)


@cache
def is_new_type_primitive(typ: Union[type[Any], NewType]) -> bool:
    """
    Test if the type is a NewType of primitives.

    Objects without ``__supertype__`` are tested as primitive *instances*.
    """
    inner = getattr(typ, "__supertype__", None)
    if inner:
        return is_primitive(inner)
    else:
        return any(isinstance(typ, ty) for ty in PRIMITIVES)


@cache
def has_generic_base(typ: Any) -> bool:
    """
    Test if `typing.Generic` appears in the MRO or direct bases of `typ`.
    """
    return Generic in getattr(typ, "__mro__", ()) or Generic in getattr(typ, "__bases__", ())


@cache
def is_generic(typ: Any) -> bool:
    """
    Test if the type is derived from `typing.Generic`.

    Only parameterized generics qualify: the type must have an origin.

    >>> T = typing.TypeVar('T')
    >>> class GenericFoo(typing.Generic[T]):
    ...     pass
    >>> is_generic(GenericFoo[int])
    True
    >>> is_generic(GenericFoo)
    False
    """
    origin = get_origin(typ)
    return origin is not None and has_generic_base(origin)


@cache
def is_class_var(typ: type[Any]) -> bool:
    """
    Test if the type is `typing.ClassVar`.

    >>> is_class_var(ClassVar[int])
    True
    >>> is_class_var(ClassVar)
    True
    """
    return get_origin(typ) is ClassVar or typ is ClassVar  # type: ignore


@cache
def is_literal(typ: type[Any]) -> bool:
    """
    Test if the type is a parameterized `typing.Literal`.

    >>> is_literal(typing.Literal[1, "a"])
    True
    >>> is_literal(int)
    False
    """
    origin = get_origin(typ)
    return origin is not None and origin is typing.Literal


@cache
def is_any(typ: type[Any]) -> bool:
    """
    Test if the type is `typing.Any`.
    """
    return typ is Any  # type: ignore


@cache
def is_str_serializable(typ: type[Any]) -> bool:
    """
    Test if the type is serializable to `str`.

    The subclass check is restricted to plain classes (metaclass exactly
    `type`) by the ``type(typ) is type`` guard.
    """
    return typ in StrSerializableTypes or (
        type(typ) is type and issubclass(typ, StrSerializableTypes)
    )


def is_datetime(
    typ: type[Any],
) -> TypeGuard[Union[datetime.date, datetime.time, datetime.datetime]]:
    """
    Test if the type is any of the datetime types (or a plain subclass of one).
    """
    return typ in DateTimeTypes or (type(typ) is type and issubclass(typ, DateTimeTypes))


def is_str_serializable_instance(obj: Any) -> bool:
    """
    Test if `obj` is an instance of one of `StrSerializableTypes`.
    """
    return isinstance(obj, StrSerializableTypes)


def is_datetime_instance(obj: Any) -> bool:
    """
    Test if `obj` is an instance of one of `DateTimeTypes`.
    """
    return isinstance(obj, DateTimeTypes)


def is_ellipsis(typ: Any) -> bool:
    """
    Test if `typ` is the `...` singleton.
    """
    return typ is Ellipsis


def is_pep695_type_alias(typ: Any) -> bool:
    """
    Test if the type is of PEP695 type alias.
    """
    return isinstance(typ, TypeAliasType)


@cache
def get_type_var_names(cls: type[Any]) -> Optional[list[str]]:
    """
    Get type argument names of a generic class.

    Returns ``None`` when `cls` has no ``__orig_bases__``.

    >>> T = typing.TypeVar('T')
    >>> class GenericFoo(typing.Generic[T]):
    ...     pass
    >>> get_type_var_names(GenericFoo)
    ['T']
    >>> get_type_var_names(int)
    """
    bases = getattr(cls, "__orig_bases__", ())
    if not bases:
        return None

    type_arg_names: list[str] = []
    for base in bases:
        type_arg_names.extend(arg.__name__ for arg in get_args(base) if hasattr(arg, "__name__"))

    return type_arg_names


def find_generic_arg(cls: type[Any], field: TypeVar) -> int:
    """
    Find a type in generic parameters.

    Returns the position of `field` (matched by ``__name__``) among the type
    arguments of the first base in ``__orig_bases__`` that contains it, or
    ``-1`` when it is not found. Arguments without a ``__name__`` are skipped.

    Raises:
        Exception: when `cls` has no ``__orig_bases__``.

    >>> T = typing.TypeVar('T')
    >>> U = typing.TypeVar('U')
    >>> V = typing.TypeVar('V')
    >>> class GenericFoo(typing.Generic[T, U]):
    ...     pass
    >>> find_generic_arg(GenericFoo, T)
    0
    >>> find_generic_arg(GenericFoo, U)
    1
    >>> find_generic_arg(GenericFoo, V)
    -1
    """
    bases = getattr(cls, "__orig_bases__", ())
    if not bases:
        raise Exception(f'"__orig_bases__" property was not found: {cls}')

    for base in bases:
        for n, arg in enumerate(get_args(base)):
            if getattr(arg, "__name__", None) == field.__name__:
                return n

    return -1


def get_generic_arg(
    typ: Any,
    maybe_generic_type_vars: Optional[list[str]],
    variable_type_args: Optional[list[str]],
    index: int,
) -> Any:
    """
    Get generic type argument.

    `maybe_generic_type_vars` are the type var names of the generic class,
    `variable_type_args` the type var names used by a field; the concrete
    argument of `typ` bound to ``variable_type_args[index]`` is returned.
    Falls back to `typing.Any` when `typ` is not generic, either name list is
    ``None``, or the type var name is unknown.

    Raises:
        SerdeError: when the number of type args and type vars differ.

    >>> T = typing.TypeVar('T')
    >>> U = typing.TypeVar('U')
    >>> class GenericFoo(typing.Generic[T, U]):
    ...     pass
    >>> get_generic_arg(GenericFoo[int, str], ['T', 'U'], ['T', 'U'], 0).__name__
    'int'
    >>> get_generic_arg(GenericFoo[int, str], ['T', 'U'], ['T', 'U'], 1).__name__
    'str'
    >>> get_generic_arg(GenericFoo[int, str], ['T', 'U'], ['U'], 0).__name__
    'str'
    """
    if not is_generic(typ) or maybe_generic_type_vars is None or variable_type_args is None:
        return typing.Any

    args = get_args(typ)

    if len(args) != len(maybe_generic_type_vars):
        raise SerdeError(
            f"Number of type args for {typ} does not match number of generic type vars: "
            f"\n type args: {args}\n type_vars: {maybe_generic_type_vars}"
        )

    # Get the name of the type var used for this field in the parent class definition
    type_var_name = variable_type_args[index]

    try:
        # Find the slot of that type var in the original generic class definition
        orig_index = maybe_generic_type_vars.index(type_var_name)
    except ValueError:
        return typing.Any

    return args[orig_index]