serde.compat
Compatibility layer that mostly handles differences in the `typing`
module between Python versions.
"""
Compatibility layer which mostly handles differences of the `typing` module between
Python versions.
"""

import dataclasses
import datetime
import decimal
import enum
import functools
import ipaddress
import itertools
import pathlib
import sys
import types
import uuid
import typing
from collections import defaultdict
from collections.abc import Iterator
from dataclasses import is_dataclass
from typing import TypeVar, Generic, Any, ClassVar, Optional, NewType, Union, Hashable, Callable

import typing_inspect
from typing_extensions import TypeGuard, ParamSpec, TypeAliasType


# Lazy SQLAlchemy import wrapper to improve startup time: `.sqlalchemy` is imported
# on the first call instead of when this module is imported.
def _is_sqlalchemy_inspectable(subject: Any) -> bool:
    """
    Return whether `subject` is SQLAlchemy-inspectable (delegates to `.sqlalchemy`).
    """
    from .sqlalchemy import is_sqlalchemy_inspectable

    return is_sqlalchemy_inspectable(subject)


def get_np_origin(tp: type[Any]) -> Optional[Any]:
    """
    Fallback origin lookup consulted by `get_origin`; always returns `None` here.
    """
    return None


def get_np_args(tp: type[Any]) -> tuple[Any, ...]:
    """
    Fallback argument lookup consulted by `get_args`; always returns `()` here.
    """
    return ()


__all__: list[str] = []

T = TypeVar("T")


StrSerializableTypes = (
    decimal.Decimal,
    pathlib.Path,
    pathlib.PosixPath,
    pathlib.WindowsPath,
    pathlib.PurePath,
    pathlib.PurePosixPath,
    pathlib.PureWindowsPath,
    uuid.UUID,
    ipaddress.IPv4Address,
    ipaddress.IPv6Address,
    ipaddress.IPv4Network,
    ipaddress.IPv6Network,
    ipaddress.IPv4Interface,
    ipaddress.IPv6Interface,
)
""" Tuple of standard types (de)serializable to str """

DateTimeTypes = (datetime.date, datetime.time, datetime.datetime)
""" Tuple of datetime types """


@dataclasses.dataclass(unsafe_hash=True)
class _WithTagging(Generic[T]):
    """
    Intermediate data structure for (de)serializing Union without dataclass.
    """

    inner: T
    """ Union type e.g. Union[Foo,Bar] passed in from_obj. """
    tagging: Any
    """ Union Tagging """


class SerdeError(Exception):
    """
    Serde error class.
    """


@dataclasses.dataclass
class UserError(Exception):
    """
    Error from user code e.g. __post_init__.
    """

    inner: Exception


class SerdeSkip(Exception):
    """
    Skip a field in custom (de)serializer.
    """


def is_hashable(typ: Any) -> TypeGuard[Hashable]:
    """
    Test if an object is hashable, i.e. `hash(typ)` does not raise `TypeError`.
    """
    try:
        hash(typ)
    except TypeError:
        return False
    return True


P = ParamSpec("P")


def cache(f: Callable[P, T]) -> Callable[P, T]:
    """
    Wrapper for `functools.cache` to avoid `Hashable` related type errors.

    The inner wrapper is decorated with `functools.wraps`, so the cached function keeps
    the `__name__`, `__qualname__`, `__module__` and `__doc__` of `f` (this also lets
    doctest discover the docstrings of cached functions).
    """

    @functools.wraps(f)
    def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
        return f(*args, **kwargs)

    return functools.cache(wrapper)  # type: ignore


@cache
def get_origin(typ: Any) -> Optional[Any]:
    """
    Provide `get_origin` that works in all python versions.
    """
    return typing.get_origin(typ) or get_np_origin(typ)


@cache
def get_args(typ: type[Any]) -> tuple[Any, ...]:
    """
    Provide `get_args` that works in all python versions.
    """
    return typing.get_args(typ) or get_np_args(typ)


@cache
def typename(typ: Any, with_typing_module: bool = False) -> str:
    """
    Return a human readable name of `typ`.

    When `with_typing_module` is true, the `Optional`, `Union`, `list`, `set`, `dict`,
    `tuple` and `Any` names produced here are prefixed with "typing.". The flag is
    propagated to nested type arguments and to the supertype of a `NewType`.

    Raises:
        SerdeError: if no name can be derived from `typ`.
        TypeError: for a `Literal` without arguments.

    >>> from typing import Any
    >>> typename(int)
    'int'
    >>> class Foo: pass
    >>> typename(Foo)
    'Foo'
    >>> typename(list[Foo])
    'list[Foo]'
    >>> typename(dict[str, Foo])
    'dict[str, Foo]'
    >>> typename(tuple[int, str, Foo, list[int], dict[str, Foo]])
    'tuple[int, str, Foo, list[int], dict[str, Foo]]'
    >>> typename(Optional[list[Foo]])
    'Optional[list[Foo]]'
    >>> typename(Union[Optional[Foo], list[Foo], Union[str, int]])
    'Union[Optional[Foo], list[Foo], str, int]'
    >>> typename(set[Foo])
    'set[Foo]'
    >>> typename(Any)
    'Any'
    """
    mod = "typing." if with_typing_module else ""
    thisfunc = functools.partial(typename, with_typing_module=with_typing_module)
    if is_opt(typ):
        args = type_args(typ)
        if args:
            return f"{mod}Optional[{thisfunc(args[0])}]"
        else:
            return f"{mod}Optional"
    elif is_union(typ):
        args = union_args(typ)
        if args:
            return f'{mod}Union[{", ".join([thisfunc(e) for e in args])}]'
        else:
            return f"{mod}Union"
    elif is_list(typ):
        args = type_args(typ)
        if args:
            et = thisfunc(args[0])
            return f"{mod}list[{et}]"
        else:
            return f"{mod}list"
    elif is_set(typ):
        args = type_args(typ)
        if args:
            et = thisfunc(args[0])
            return f"{mod}set[{et}]"
        else:
            return f"{mod}set"
    elif is_dict(typ):
        args = type_args(typ)
        if args and len(args) == 2:
            kt = thisfunc(args[0])
            vt = thisfunc(args[1])
            return f"{mod}dict[{kt}, {vt}]"
        else:
            return f"{mod}dict"
    elif is_tuple(typ):
        args = type_args(typ)
        if args:
            return f'{mod}tuple[{", ".join([thisfunc(e) for e in args])}]'
        else:
            return f"{mod}tuple"
    elif is_generic(typ):
        origin = get_origin(typ)
        if origin is None:
            raise SerdeError("Could not extract origin class from generic class")

        if not isinstance(origin.__name__, str):
            raise SerdeError("Name of generic class is not string")

        # Only the generic class name is returned; its type arguments are dropped.
        return origin.__name__

    elif is_literal(typ):
        args = type_args(typ)
        if not args:
            raise TypeError("Literal type requires at least one literal argument")
        return f'Literal[{", ".join(stringify_literal(e) for e in args)}]'
    elif typ is Any:
        return f"{mod}Any"
    elif is_ellipsis(typ):
        return "..."
    else:
        # Get super type for NewType; keep honouring `with_typing_module`.
        inner = getattr(typ, "__supertype__", None)
        if inner:
            return thisfunc(inner)

        name: Optional[str] = getattr(typ, "_name", None)
        if name:
            return name
        else:
            name = getattr(typ, "__name__", None)
            if isinstance(name, str):
                return name
            else:
                raise SerdeError(f"Could not get a type name from: {typ}")


def stringify_literal(v: Any) -> str:
    """
    Render a `Literal` argument: strings are wrapped in single quotes, anything else
    goes through `str`.
    """
    if isinstance(v, str):
        return f"'{v}'"
    else:
        return str(v)


def type_args(typ: Any) -> tuple[type[Any], ...]:
    """
    Wrapper to suppress type error for accessing private members.

    Return `typ.__args__` (or `()` when it is `None`), falling back to `get_args`
    when `typ` has no `__args__` attribute.
    """
    try:
        args: Optional[tuple[type[Any], ...]] = typ.__args__  # type: ignore
    except AttributeError:
        return get_args(typ)
    if args is None:
        return ()
    return args


def union_args(typ: Any) -> tuple[type[Any], ...]:
    """
    Return the member types of the union `typ`, folding `NoneType` into `Optional`.

    Members are walked in consecutive pairs; when the second member of a pair is
    `NoneType`, the pair becomes `Optional[first]`.

    Raises:
        TypeError: if `typ` is not a union.
    """
    if not is_union(typ):
        raise TypeError(f"{typ} is not Union")
    args = type_args(typ)
    if len(args) == 1:
        return (args[0],)
    it = iter(args)
    flattened: list[Any] = []
    # zip_longest over the same iterator twice yields (a0, a1), (a2, a3), ... and pads
    # an odd-length tail with None.
    for i1, i2 in itertools.zip_longest(it, it):
        if i2 is None:
            flattened.append(i1)
        elif is_none(i2):
            flattened.append(Optional[i1])
        else:
            flattened.extend((i1, i2))
    return tuple(flattened)


def dataclass_fields(cls: type[Any]) -> Iterator[dataclasses.Field]:  # type: ignore
    """
    Return an iterator over the fields of dataclass `cls` with resolved types.

    The `type` of each `Field` returned by `dataclasses.fields` is overwritten in
    place with the hint resolved by `typing.get_type_hints`. For classes that are
    SQLAlchemy-inspectable, a generic annotation is replaced by its first type
    argument (presumably to unwrap SQLAlchemy's `Mapped[...]` — confirm).

    Raises:
        SerdeError: if the type hints of `cls` cannot be resolved.
    """
    raw_fields = dataclasses.fields(cls)

    try:
        # this resolves types when string forward reference
        # or PEP 563: "from __future__ import annotations" are used
        resolved_hints = typing.get_type_hints(cls)
    except Exception as e:
        raise SerdeError(
            f"Failed to resolve type hints for {typename(cls)}:\n"
            f"{e.__class__.__name__}: {e}\n\n"
            f"If you are using forward references make sure you are calling deserialize & "
            "serialize after all classes are globally visible."
        ) from e

    for f in raw_fields:
        real_type = resolved_hints.get(f.name)
        if real_type is not None:
            f.type = real_type
            if is_generic(real_type) and _is_sqlalchemy_inspectable(cls):
                f.type = get_args(real_type)[0]

    return iter(raw_fields)


# A concrete class or any typing construct (Optional, Union, aliases, ...).
TypeLike = Union[type[Any], typing.Any]


def iter_types(cls: type[Any]) -> list[type[Any]]:
    """
    Iterate field types recursively.

    Returns every type reachable from `cls` through dataclass fields, container
    element types and PEP 695 type alias values. The list can also contain typing
    special forms (`Optional`, `Union`) and string forward references, so in
    practice its element type is `Any`. Order is unspecified (collected in a set).
    """
    lst: set[Union[type[Any], Any]] = set()

    def recursive(cls: Union[type[Any], Any]) -> None:
        if cls in lst:
            return

        if is_dataclass(cls):
            lst.add(cls)
            if isinstance(cls, type):
                for f in dataclass_fields(cls):
                    recursive(f.type)
        elif isinstance(cls, str):
            lst.add(cls)
        elif is_opt(cls):
            lst.add(Optional)
            args = type_args(cls)
            if args:
                recursive(args[0])
        elif is_union(cls):
            lst.add(Union)
            for arg in type_args(cls):
                recursive(arg)
        elif is_list(cls):
            lst.add(list)
            args = type_args(cls)
            if args:
                recursive(args[0])
        elif is_set(cls):
            lst.add(set)
            args = type_args(cls)
            if args:
                recursive(args[0])
        elif is_tuple(cls):
            lst.add(tuple)
            for arg in type_args(cls):
                recursive(arg)
        elif is_dict(cls):
            lst.add(dict)
            args = type_args(cls)
            if args and len(args) >= 2:
                recursive(args[0])
                recursive(args[1])
        elif is_pep695_type_alias(cls):
            recursive(cls.__value__)
        else:
            lst.add(cls)

    recursive(cls)
    return list(lst)


def iter_unions(cls: TypeLike) -> list[TypeLike]:
    """
    Iterate over all unions that are used in the dataclass.

    Walks `cls` recursively (dataclass fields, container element types, PEP 695 type
    alias values) and returns each union encountered, in discovery order.
    """
    lst: list[TypeLike] = []
    stack: list[TypeLike] = []  # To prevent infinite recursion on recursive dataclasses

    def recursive(cls: TypeLike) -> None:
        if cls in stack:
            return

        # One if/elif chain: a union is fully handled by its own branch. Falling
        # through to the `is_opt` branch as well would visit the inner type of an
        # `Optional` a second time and append its nested unions to `lst` twice.
        if is_union(cls):
            lst.append(cls)
            for arg in type_args(cls):
                recursive(arg)
        elif is_pep695_type_alias(cls):
            recursive(cls.__value__)
        elif is_dataclass(cls):
            stack.append(cls)
            if isinstance(cls, type):
                for f in dataclass_fields(cls):
                    recursive(f.type)
            stack.pop()
        elif is_opt(cls):
            args = type_args(cls)
            if args:
                recursive(args[0])
        elif is_list(cls) or is_set(cls):
            args = type_args(cls)
            if args:
                recursive(args[0])
        elif is_tuple(cls):
            for arg in type_args(cls):
                recursive(arg)
        elif is_dict(cls):
            args = type_args(cls)
            if args and len(args) >= 2:
                recursive(args[0])
                recursive(args[1])

    recursive(cls)
    return lst


def iter_literals(cls: type[Any]) -> list[TypeLike]:
    """
    Iterate over all literals that are used in the dataclass.

    Note: visited dataclasses are recorded in the same set that collects the
    literals, so the result also contains `cls` itself (when it is a dataclass) and
    every nested dataclass. Order is unspecified (collected in a set).
    """
    lst: set[Union[type[Any], Any]] = set()

    def recursive(cls: Union[type[Any], Any]) -> None:
        if cls in lst:
            return

        # One if/elif chain, mirroring `iter_unions`: a union must not also fall
        # through to the `is_opt` branch and be traversed twice.
        if is_literal(cls):
            lst.add(cls)
        elif is_union(cls):
            for arg in type_args(cls):
                recursive(arg)
        elif is_dataclass(cls):
            lst.add(cls)
            if isinstance(cls, type):
                for f in dataclass_fields(cls):
                    recursive(f.type)
        elif is_opt(cls):
            args = type_args(cls)
            if args:
                recursive(args[0])
        elif is_list(cls) or is_set(cls):
            args = type_args(cls)
            if args:
                recursive(args[0])
        elif is_tuple(cls):
            for arg in type_args(cls):
                recursive(arg)
        elif is_dict(cls):
            args = type_args(cls)
            if args and len(args) >= 2:
                recursive(args[0])
                recursive(args[1])

    recursive(cls)
    return list(lst)


@cache
def is_union(typ: Any) -> bool:
    """
    Test if the type is `typing.Union` (or a PEP 604 `X | Y` union).

    A `_WithTagging` wrapper is unwrapped and its `inner` type is tested.

    >>> is_union(Union[int, str])
    True
    """

    try:
        # When `_WithTagging` is received, it will check inner type.
        if isinstance(typ, _WithTagging):
            return is_union(typ.inner)
    except Exception:
        pass

    # Python 3.10 Union operator e.g. str | int
    if sys.version_info[:2] >= (3, 10):
        try:
            if isinstance(typ, types.UnionType):
                return True
        except Exception:
            pass

    # typing.Union
    return typing_inspect.is_union_type(typ)  # type: ignore


@cache
def is_opt(typ: Any) -> bool:
    """
    Test if the type is `typing.Optional`.

    A parameterized type counts only when it is a two-member union whose second
    member is `NoneType` (e.g. `Optional[int]`); an unparameterized type counts only
    when it is bare `Optional`.

    >>> is_opt(Optional[int])
    True
    >>> is_opt(Optional)
    True
    >>> is_opt(None.__class__)
    False
    """

    # Python 3.10 Union operator e.g. str | None
    is_union_type = False
    if sys.version_info[:2] >= (3, 10):
        try:
            if isinstance(typ, types.UnionType):
                is_union_type = True
        except Exception:
            pass

    # typing.Optional
    is_typing_union = typing_inspect.is_optional_type(typ)

    args = type_args(typ)
    if args:
        return (
            (is_union_type or is_typing_union)
            and len(args) == 2
            and not is_none(args[0])
            and is_none(args[1])
        )
    else:
        return typ is Optional


@cache
def is_bare_opt(typ: Any) -> bool:
    """
    Test if the type is `typing.Optional` without type args.

    >>> is_bare_opt(Optional[int])
    False
    >>> is_bare_opt(Optional)
    True
    >>> is_bare_opt(None.__class__)
    False
    """
    return not type_args(typ) and typ is Optional


@cache
def is_opt_dataclass(typ: Any) -> bool:
    """
    Test if the type is optional dataclass.

    >>> is_opt_dataclass(Optional[int])
    False
    >>> @dataclasses.dataclass
    ... class Foo:
    ...     pass
    >>> is_opt_dataclass(Foo)
    False
    >>> is_opt_dataclass(Optional[Foo])
    True
    """
    args = get_args(typ)
    return is_opt(typ) and len(args) > 0 and is_dataclass(args[0])


@cache
def is_list(typ: type[Any]) -> bool:
    """
    Test if the type is `list`.

    >>> is_list(list[int])
    True
    >>> is_list(list)
    True
    """
    try:
        return issubclass(get_origin(typ), list)  # type: ignore
    except TypeError:
        # Non-generic types have no origin (issubclass(None, ...) raises).
        return typ is list


@cache
def is_bare_list(typ: type[Any]) -> bool:
    """
    Test if the type is `list` without type args.

    >>> is_bare_list(list[int])
    False
    >>> is_bare_list(list)
    True
    """
    return typ is list


@cache
def is_tuple(typ: Any) -> bool:
    """
    Test if the type is tuple.
    """
    try:
        return issubclass(get_origin(typ), tuple)  # type: ignore
    except TypeError:
        return typ is tuple


@cache
def is_bare_tuple(typ: type[Any]) -> bool:
    """
    Test if the type is tuple without type args.

    >>> is_bare_tuple(tuple[int, str])
    False
    >>> is_bare_tuple(tuple)
    True
    """
    return typ is tuple


@cache
def is_variable_tuple(typ: type[Any]) -> bool:
    """
    Test if the type is a variable length of tuple `tuple[T, ...]`.

    >>> is_variable_tuple(tuple[int, ...])
    True
    >>> is_variable_tuple(tuple[int, bool])
    False
    >>> is_variable_tuple(tuple[()])
    False
    """
    istuple = is_tuple(typ) and not is_bare_tuple(typ)
    args = get_args(typ)
    return istuple and len(args) == 2 and is_ellipsis(args[1])


@cache
def is_set(typ: type[Any]) -> bool:
    """
    Test if the type is `set` or `frozenset`.

    >>> is_set(set[int])
    True
    >>> is_set(set)
    True
    >>> is_set(frozenset[int])
    True
    """
    try:
        return issubclass(get_origin(typ), (set, frozenset))  # type: ignore
    except TypeError:
        return typ in (set, frozenset)


@cache
def is_bare_set(typ: type[Any]) -> bool:
    """
    Test if the type is `set` or `frozenset` without type args.

    >>> is_bare_set(set[int])
    False
    >>> is_bare_set(set)
    True
    """
    return typ in (set, frozenset)


@cache
def is_frozen_set(typ: type[Any]) -> bool:
    """
    Test if the type is `frozenset`.

    >>> is_frozen_set(frozenset[int])
    True
    >>> is_frozen_set(set)
    False
    """
    try:
        return issubclass(get_origin(typ), frozenset)  # type: ignore
    except TypeError:
        return typ is frozenset


@cache
def is_dict(typ: type[Any]) -> bool:
    """
    Test if the type is dict.

    >>> is_dict(dict[int, int])
    True
    >>> is_dict(dict)
    True
    >>> is_dict(defaultdict[int, int])
    True
    """
    try:
        return issubclass(get_origin(typ), (dict, defaultdict))  # type: ignore
    except TypeError:
        return typ in (dict, defaultdict)


@cache
def is_bare_dict(typ: type[Any]) -> bool:
    """
    Test if the type is `dict` without type args.

    >>> is_bare_dict(dict[int, str])
    False
    >>> is_bare_dict(dict)
    True
    """
    return typ is dict


@cache
def is_default_dict(typ: type[Any]) -> bool:
    """
    Test if the type is `defaultdict`.

    >>> is_default_dict(defaultdict[int, int])
    True
    >>> is_default_dict(dict[int, int])
    False
    """
    try:
        return issubclass(get_origin(typ), defaultdict)  # type: ignore
    except TypeError:
        return typ is defaultdict


@cache
def is_none(typ: type[Any]) -> bool:
    """
    Test if the type is `NoneType` (the value `None` itself is not).

    >>> is_none(int)
    False
    >>> is_none(type(None))
    True
    >>> is_none(None)
    False
    """
    return typ is type(None)  # noqa


PRIMITIVES = [int, float, bool, str]


@cache
def is_enum(typ: type[Any]) -> TypeGuard[enum.Enum]:
    """
    Test if the type is `enum.Enum` (a subclass, or an enum member for non-classes).
    """
    try:
        return issubclass(typ, enum.Enum)
    except TypeError:
        return isinstance(typ, enum.Enum)


@cache
def is_primitive_subclass(typ: type[Any]) -> bool:
    """
    Test if the type is a subclass of primitive type.

    >>> is_primitive_subclass(str)
    False
    >>> class Str(str):
    ...     pass
    >>> is_primitive_subclass(Str)
    True
    """
    return is_primitive(typ) and typ not in PRIMITIVES and not is_new_type_primitive(typ)


@cache
def is_primitive(typ: Union[type[Any], NewType]) -> bool:
    """
    Test if the type is primitive.

    >>> is_primitive(int)
    True
    >>> class CustomInt(int):
    ...     pass
    >>> is_primitive(CustomInt)
    True
    """
    try:
        return any(issubclass(typ, ty) for ty in PRIMITIVES)  # type: ignore
    except TypeError:
        # Not a class (e.g. a NewType): check its supertype instead.
        return is_new_type_primitive(typ)


@cache
def is_new_type_primitive(typ: Union[type[Any], NewType]) -> bool:
    """
    Test if the type is a NewType of primitives.
    """
    inner = getattr(typ, "__supertype__", None)
    if inner:
        return is_primitive(inner)
    else:
        return any(isinstance(typ, ty) for ty in PRIMITIVES)


@cache
def has_generic_base(typ: Any) -> bool:
    """
    Test if `typing.Generic` appears in the MRO or direct bases of `typ`.
    """
    return Generic in getattr(typ, "__mro__", ()) or Generic in getattr(typ, "__bases__", ())


@cache
def is_generic(typ: Any) -> bool:
    """
    Test if the type is derived from `typing.Generic`.

    >>> T = typing.TypeVar('T')
    >>> class GenericFoo(typing.Generic[T]):
    ...     pass
    >>> is_generic(GenericFoo[int])
    True
    >>> is_generic(GenericFoo)
    False
    """
    origin = get_origin(typ)
    return origin is not None and has_generic_base(origin)


@cache
def is_class_var(typ: type[Any]) -> bool:
    """
    Test if the type is `typing.ClassVar`.

    >>> is_class_var(ClassVar[int])
    True
    >>> is_class_var(ClassVar)
    True
    """
    return get_origin(typ) is ClassVar or typ is ClassVar  # type: ignore


@cache
def is_literal(typ: type[Any]) -> bool:
    """
    Test if the type is derived from `typing.Literal`.

    >>> is_literal(typing.Literal["a", "b"])
    True
    >>> is_literal(str)
    False
    """
    origin = get_origin(typ)
    return origin is not None and origin is typing.Literal


@cache
def is_any(typ: type[Any]) -> bool:
    """
    Test if the type is `typing.Any`.
    """
    return typ is Any  # type: ignore


@cache
def is_str_serializable(typ: type[Any]) -> bool:
    """
    Test if the type is serializable to `str`.
    """
    return typ in StrSerializableTypes or (
        type(typ) is type and issubclass(typ, StrSerializableTypes)
    )


def is_datetime(
    typ: type[Any],
) -> TypeGuard[Union[datetime.date, datetime.time, datetime.datetime]]:
    """
    Test if the type is any of the datetime types.
    """
    return typ in DateTimeTypes or (type(typ) is type and issubclass(typ, DateTimeTypes))


def is_str_serializable_instance(obj: Any) -> bool:
    """
    Test if `obj` is an instance of one of `StrSerializableTypes`.
    """
    return isinstance(obj, StrSerializableTypes)


def is_datetime_instance(obj: Any) -> bool:
    """
    Test if `obj` is an instance of one of `DateTimeTypes`.
    """
    return isinstance(obj, DateTimeTypes)


def is_ellipsis(typ: Any) -> bool:
    """
    Test if `typ` is `...` (as in `tuple[T, ...]`).
    """
    return typ is Ellipsis


def is_pep695_type_alias(typ: Any) -> bool:
    """
    Test if the type is of PEP695 type alias.
    """
    return isinstance(typ, TypeAliasType)


@cache
def get_type_var_names(cls: type[Any]) -> Optional[list[str]]:
    """
    Get type argument names of a generic class.

    Returns `None` when `cls` has no `__orig_bases__`.

    >>> T = typing.TypeVar('T')
    >>> class GenericFoo(typing.Generic[T]):
    ...     pass
    >>> get_type_var_names(GenericFoo)
    ['T']
    >>> get_type_var_names(int)
    """
    bases = getattr(cls, "__orig_bases__", ())
    if not bases:
        return None

    type_arg_names: list[str] = []
    for base in bases:
        type_arg_names.extend(arg.__name__ for arg in get_args(base) if hasattr(arg, "__name__"))

    return type_arg_names


def find_generic_arg(cls: type[Any], field: TypeVar) -> int:
    """
    Find a type in generic parameters.

    Returns the position of the argument named like `field` within the type
    arguments of the first base in `cls.__orig_bases__` that contains it, or -1
    when no base does.

    Raises:
        SerdeError: if `cls` has no `__orig_bases__`.

    >>> T = typing.TypeVar('T')
    >>> U = typing.TypeVar('U')
    >>> V = typing.TypeVar('V')
    >>> class GenericFoo(typing.Generic[T, U]):
    ...     pass
    >>> find_generic_arg(GenericFoo, T)
    0
    >>> find_generic_arg(GenericFoo, U)
    1
    >>> find_generic_arg(GenericFoo, V)
    -1
    """
    bases = getattr(cls, "__orig_bases__", ())
    if not bases:
        raise SerdeError(f'"__orig_bases__" property was not found: {cls}')

    for base in bases:
        for n, arg in enumerate(get_args(base)):
            # Match by name, as `get_type_var_names` does; args without a `__name__`
            # can never be the type variable we are looking for.
            if getattr(arg, "__name__", None) == field.__name__:
                return n

    return -1


def get_generic_arg(
    typ: Any,
    maybe_generic_type_vars: Optional[list[str]],
    variable_type_args: Optional[list[str]],
    index: int,
) -> Any:
    """
    Get generic type argument.

    Returns `typing.Any` when `typ` is not generic, when either name list is `None`,
    or when the type variable at `variable_type_args[index]` is not among
    `maybe_generic_type_vars`.

    Raises:
        SerdeError: if the number of type args of `typ` differs from the number of
            generic type vars.

    >>> T = typing.TypeVar('T')
    >>> U = typing.TypeVar('U')
    >>> class GenericFoo(typing.Generic[T, U]):
    ...     pass
    >>> get_generic_arg(GenericFoo[int, str], ['T', 'U'], ['T', 'U'], 0).__name__
    'int'
    >>> get_generic_arg(GenericFoo[int, str], ['T', 'U'], ['T', 'U'], 1).__name__
    'str'
    >>> get_generic_arg(GenericFoo[int, str], ['T', 'U'], ['U'], 0).__name__
    'str'
    """
    if not is_generic(typ) or maybe_generic_type_vars is None or variable_type_args is None:
        return typing.Any

    args = get_args(typ)

    if len(args) != len(maybe_generic_type_vars):
        raise SerdeError(
            f"Number of type args for {typ} does not match number of generic type vars: "
            f"\n type args: {args}\n type_vars: {maybe_generic_type_vars}"
        )

    # Get the name of the type var used for this field in the parent class definition
    type_var_name = variable_type_args[index]

    try:
        # Find the slot of that type var in the original generic class definition
        orig_index = maybe_generic_type_vars.index(type_var_name)
    except ValueError:
        return typing.Any

    return args[orig_index]