serde.core
pyserde core module.
1""" 2pyserde core module. 3""" 4 5from __future__ import annotations 6import dataclasses 7import enum 8import functools 9import logging 10import re 11import casefy 12from dataclasses import dataclass 13 14from beartype.door import is_bearable 15from collections.abc import Mapping, Sequence, MutableSequence, Set, Callable, Hashable 16from typing import ( 17 overload, 18 TypeVar, 19 Generic, 20 Any, 21 Protocol, 22 get_type_hints, 23 cast, 24) 25 26from .compat import ( 27 T, 28 SerdeError, 29 dataclass_fields, 30 get_origin, 31 is_bare_dict, 32 is_bare_list, 33 is_bare_set, 34 is_bare_tuple, 35 is_class_var, 36 is_dict, 37 is_generic, 38 is_list, 39 is_literal, 40 is_new_type_primitive, 41 is_any, 42 is_opt, 43 is_opt_dataclass, 44 is_set, 45 is_tuple, 46 is_union, 47 is_variable_tuple, 48 type_args, 49 typename, 50 _WithTagging, 51) 52 53__all__ = [ 54 "Scope", 55 "gen", 56 "add_func", 57 "Func", 58 "Field", 59 "fields", 60 "FlattenOpts", 61 "conv", 62 "union_func_name", 63] 64 65logger = logging.getLogger("serde") 66 67 68# name of the serde context key 69SERDE_SCOPE = "__serde__" 70 71# main function keys 72FROM_ITER = "from_iter" 73FROM_DICT = "from_dict" 74TO_ITER = "to_iter" 75TO_DICT = "to_dict" 76TYPE_CHECK = "typecheck" 77 78# prefixes used to distinguish the direction of a union function 79UNION_SE_PREFIX = "union_se" 80UNION_DE_PREFIX = "union_de" 81 82LITERAL_DE_PREFIX = "literal_de" 83 84SETTINGS = {"debug": False} 85 86 87@dataclass(frozen=True) 88class UnionCacheKey: 89 union: Hashable 90 tagging: Tagging 91 92 93def init(debug: bool = False) -> None: 94 SETTINGS["debug"] = debug 95 96 97@dataclass 98class Cache: 99 """ 100 Cache the generated code for non-dataclass classes. 101 102 for example, a type not bound in a dataclass is passed in from_json 103 104 ``` 105 from_json(Union[Foo, Bar], ...) 
106 ``` 107 108 It creates the following wrapper dataclass on the fly, 109 110 ``` 111 @serde 112 @dataclass 113 class Union_Foo_bar: 114 v: Union[Foo, Bar] 115 ``` 116 117 Then store this class in this cache. Whenever the same type is passed, 118 the class is retrieved from this cache. So the overhead of the codegen 119 should be only once. 120 """ 121 122 classes: dict[Hashable, type[Any]] = dataclasses.field(default_factory=dict) 123 124 def _get_class(self, cls: type[Any]) -> type[Any]: 125 """ 126 Get a wrapper class from the the cache. If not found, it will generate 127 the class and store it in the cache. 128 """ 129 wrapper = self.classes.get(cls) # type: ignore[call-overload] # mypy doesn't recognize type[Any] as Hashable 130 return wrapper or self._generate_class(cls) 131 132 def _generate_class(self, cls: type[Any]) -> type[Any]: 133 """ 134 Generate a wrapper dataclass then make the it (de)serializable using 135 @serde decorator. 136 """ 137 from . import serde 138 139 class_name = f"Wrapper{typename(cls)}" 140 logger.debug(f"Generating a wrapper class code for {class_name}") 141 142 wrapper = dataclasses.make_dataclass(class_name, [("v", cls)]) 143 144 serde(wrapper) 145 self.classes[cls] = wrapper # type: ignore[index] # mypy doesn't recognize type[Any] as Hashable 146 147 logger.debug(f"(de)serializing code for {class_name} was generated") 148 return wrapper 149 150 def serialize(self, cls: type[Any], obj: Any, **kwargs: Any) -> Any: 151 """ 152 Serialize the specified type of object into dict or tuple. 153 """ 154 wrapper = self._get_class(cls) 155 scope: Scope = getattr(wrapper, SERDE_SCOPE) 156 data = scope.funcs[TO_DICT](wrapper(obj), **kwargs) 157 158 logging.debug(f"Intermediate value: {data}") 159 160 return data["v"] 161 162 def deserialize(self, cls: type[T], obj: Any) -> T: 163 """ 164 Deserialize from dict or tuple into the specified type. 
165 """ 166 wrapper = self._get_class(cls) 167 scope: Scope = getattr(wrapper, SERDE_SCOPE) 168 return scope.funcs[FROM_DICT](data={"v": obj}).v # type: ignore 169 170 def _get_union_class(self, cls: type[Any]) -> type[Any] | None: 171 """ 172 Get a wrapper class from the the cache. If not found, it will generate 173 the class and store it in the cache. 174 """ 175 union_cls, tagging = _extract_from_with_tagging(cls) 176 cache_key = UnionCacheKey(union=union_cls, tagging=tagging) 177 wrapper = self.classes.get(cache_key) 178 return wrapper or self._generate_union_class(cls) 179 180 def _generate_union_class(self, cls: type[Any]) -> type[Any]: 181 """ 182 Generate a wrapper dataclass then make the it (de)serializable using 183 @serde decorator. 184 """ 185 import serde 186 187 union_cls, tagging = _extract_from_with_tagging(cls) 188 cache_key = UnionCacheKey(union=union_cls, tagging=tagging) 189 class_name = union_func_name( 190 f"{tagging.produce_unique_class_name()}Union", list(type_args(union_cls)) 191 ) 192 wrapper = dataclasses.make_dataclass(class_name, [("v", union_cls)]) 193 serde.serde(wrapper, tagging=tagging) 194 self.classes[cache_key] = wrapper 195 return wrapper 196 197 def serialize_union(self, cls: type[Any], obj: Any) -> Any: 198 """ 199 Serialize the specified Union into dict or tuple. 200 """ 201 union_cls, _ = _extract_from_with_tagging(cls) 202 wrapper = self._get_union_class(cls) 203 scope: Scope = getattr(wrapper, SERDE_SCOPE) 204 func_name = union_func_name(UNION_SE_PREFIX, list(type_args(union_cls))) 205 return scope.funcs[func_name](obj, False, False) 206 207 def deserialize_union( 208 self, 209 cls: type[T], 210 data: Any, 211 deserialize_numbers: Callable[[str | int], float] | None = None, 212 ) -> T: 213 """ 214 Deserialize from dict or tuple into the specified Union. 
215 """ 216 union_cls, _ = _extract_from_with_tagging(cls) 217 wrapper = self._get_union_class(cls) 218 scope: Scope = getattr(wrapper, SERDE_SCOPE) 219 func_name = union_func_name(UNION_DE_PREFIX, list(type_args(union_cls))) 220 return cast( 221 T, 222 scope.funcs[func_name]( 223 cls=union_cls, data=data, deserialize_numbers=deserialize_numbers 224 ), 225 ) 226 227 228def _extract_from_with_tagging(maybe_with_tagging: Any) -> tuple[Any, Tagging]: 229 if isinstance(maybe_with_tagging, _WithTagging): 230 return maybe_with_tagging.inner, maybe_with_tagging.tagging 231 else: 232 return maybe_with_tagging, ExternalTagging 233 234 235CACHE = Cache() 236""" Global cache variable for non-dataclass classes """ 237 238 239@dataclass 240class Scope: 241 """ 242 Container to store types and functions used in code generation context. 243 """ 244 245 cls: type[Any] 246 """ The exact class this scope is for 247 (needed to distinguish scopes between inherited classes) """ 248 249 funcs: dict[str, Callable[..., Any]] = dataclasses.field(default_factory=dict) 250 """ Generated serialize and deserialize functions """ 251 252 defaults: dict[str, Callable[..., Any] | Any] = dataclasses.field(default_factory=dict) 253 """ Default values of the dataclass fields (factories & normal values) """ 254 255 code: dict[str, str] = dataclasses.field(default_factory=dict) 256 """ Generated source code (only filled when debug is True) """ 257 258 union_se_args: dict[str, list[type[Any]]] = dataclasses.field(default_factory=dict) 259 """ The union serializing functions need references to their types """ 260 261 reuse_instances_default: bool = True 262 """ Default values for to_dict & from_dict arguments """ 263 264 convert_sets_default: bool = False 265 266 def __repr__(self) -> str: 267 res: list[str] = [] 268 269 res.append("==================================================") 270 res.append(self._justify(self.cls.__name__)) 271 res.append("==================================================") 272 
res.append("") 273 274 if self.code: 275 res.append("--------------------------------------------------") 276 res.append(self._justify("Functions generated by pyserde")) 277 res.append("--------------------------------------------------") 278 res.extend(list(self.code.values())) 279 res.append("") 280 281 if self.funcs: 282 res.append("--------------------------------------------------") 283 res.append(self._justify("Function references in scope")) 284 res.append("--------------------------------------------------") 285 for k, v in self.funcs.items(): 286 res.append(f"{k}: {v}") 287 res.append("") 288 289 if self.defaults: 290 res.append("--------------------------------------------------") 291 res.append(self._justify("Default values for the dataclass fields")) 292 res.append("--------------------------------------------------") 293 for k, v in self.defaults.items(): 294 res.append(f"{k}: {v}") 295 res.append("") 296 297 if self.union_se_args: 298 res.append("--------------------------------------------------") 299 res.append(self._justify("Type list by used for union serialize functions")) 300 res.append("--------------------------------------------------") 301 for k, lst in self.union_se_args.items(): 302 res.append(f"{k}: {list(lst)}") 303 res.append("") 304 305 return "\n".join(res) 306 307 def _justify(self, s: str, length: int = 50) -> str: 308 white_spaces = int((50 - len(s)) / 2) 309 return " " * (white_spaces if white_spaces > 0 else 0) + s 310 311 312def raise_unsupported_type(obj: Any) -> None: 313 # needed because we can not render a raise statement everywhere, e.g. as argument 314 raise SerdeError(f"Unsupported type: {typename(type(obj))}") 315 316 317def gen( 318 code: str, globals: dict[str, Any] | None = None, locals: dict[str, Any] | None = None 319) -> str: 320 """ 321 A wrapper of builtin `exec` function. 
322 """ 323 if SETTINGS["debug"]: 324 # black formatting is only important when debugging 325 try: 326 from black import FileMode, format_str 327 328 code = format_str(code, mode=FileMode(line_length=100)) 329 except Exception: 330 pass 331 exec(code, globals, locals) 332 return code 333 334 335def add_func(serde_scope: Scope, func_name: str, func_code: str, globals: dict[str, Any]) -> None: 336 """ 337 Generate a function and add it to a Scope's `funcs` dictionary. 338 339 * `serde_scope`: the Scope instance to modify 340 * `func_name`: the name of the function 341 * `func_code`: the source code of the function 342 * `globals`: global variables that should be accessible to the generated function 343 """ 344 345 code = gen(func_code, globals) 346 serde_scope.funcs[func_name] = globals[func_name] 347 348 if SETTINGS["debug"]: 349 serde_scope.code[func_name] = code 350 351 352def is_instance(obj: Any, typ: Any) -> bool: 353 """ 354 pyserde's own `isinstance` helper. It accepts subscripted generics e.g. `list[int]` and 355 deeply check object against declared type. 
356 """ 357 if dataclasses.is_dataclass(typ): 358 if not isinstance(typ, type): 359 raise SerdeError("expect dataclass class but dataclass instance received") 360 return isinstance(obj, typ) 361 elif is_opt(typ): 362 return is_opt_instance(obj, typ) 363 elif is_union(typ): 364 return is_union_instance(obj, typ) 365 elif is_list(typ): 366 return is_list_instance(obj, typ) 367 elif is_set(typ): 368 return is_set_instance(obj, typ) 369 elif is_tuple(typ): 370 return is_tuple_instance(obj, typ) 371 elif is_dict(typ): 372 return is_dict_instance(obj, typ) 373 elif is_generic(typ): 374 return is_generic_instance(obj, typ) 375 elif is_literal(typ): 376 return True 377 elif is_new_type_primitive(typ): 378 inner = getattr(typ, "__supertype__", None) 379 if type(inner) is type: 380 return isinstance(obj, inner) 381 else: 382 return False 383 elif is_any(typ): 384 return True 385 elif typ is Ellipsis: 386 return True 387 else: 388 return is_bearable(obj, typ) 389 390 391def is_opt_instance(obj: Any, typ: type[Any]) -> bool: 392 if obj is None: 393 return True 394 opt_arg = type_args(typ)[0] 395 return is_instance(obj, opt_arg) 396 397 398def is_union_instance(obj: Any, typ: type[Any]) -> bool: 399 for arg in type_args(typ): 400 if is_instance(obj, arg): 401 return True 402 return False 403 404 405def is_list_instance(obj: Any, typ: type[Any]) -> bool: 406 origin = get_origin(typ) or typ 407 if origin is list: 408 if not isinstance(obj, list): 409 return False 410 elif origin is MutableSequence: 411 if isinstance(obj, (str, bytes, bytearray)): 412 return False 413 if not isinstance(obj, MutableSequence): 414 return False 415 else: 416 if isinstance(obj, (str, bytes, bytearray)): 417 return False 418 if not isinstance(obj, Sequence): 419 return False 420 if len(obj) == 0 or is_bare_list(typ): 421 return True 422 list_arg = type_args(typ)[0] 423 # for speed reasons we just check the type of the 1st element 424 return is_instance(obj[0], list_arg) 425 426 427def 
is_set_instance(obj: Any, typ: type[Any]) -> bool: 428 if not isinstance(obj, Set): 429 return False 430 if len(obj) == 0 or is_bare_set(typ): 431 return True 432 set_arg = type_args(typ)[0] 433 # for speed reasons we just check the type of the 1st element 434 return is_instance(next(iter(obj)), set_arg) 435 436 437def is_tuple_instance(obj: Any, typ: type[Any]) -> bool: 438 args = type_args(typ) 439 440 if not isinstance(obj, tuple): 441 return False 442 443 # empty tuple 444 if len(args) == 0 and len(obj) == 0: 445 return True 446 447 # In the form of tuple[T, ...] 448 elif is_variable_tuple(typ): 449 # Get the first type arg. Since tuple[T, ...] is homogeneous tuple, 450 # all the elements should be of this type. 451 arg = type_args(typ)[0] 452 for v in obj: 453 if not is_instance(v, arg): 454 return False 455 return True 456 457 # bare tuple "tuple" is equivalent to tuple[Any, ...] 458 if is_bare_tuple(typ) and isinstance(obj, tuple): 459 return True 460 461 # All the other tuples e.g. tuple[int, str] 462 if len(obj) == len(args): 463 for element, arg in zip(obj, args, strict=False): 464 if not is_instance(element, arg): 465 return False 466 else: 467 return False 468 469 return True 470 471 472def is_dict_instance(obj: Any, typ: type[Any]) -> bool: 473 if not isinstance(obj, Mapping): 474 return False 475 if len(obj) == 0 or is_bare_dict(typ): 476 return True 477 ktyp = type_args(typ)[0] 478 vtyp = type_args(typ)[1] 479 for k, v in obj.items(): 480 # for speed reasons we just check the type of the 1st element 481 return is_instance(k, ktyp) and is_instance(v, vtyp) 482 return False 483 484 485def is_generic_instance(obj: Any, typ: type[Any]) -> bool: 486 return is_instance(obj, get_origin(typ)) 487 488 489@dataclass 490class Func: 491 """ 492 Function wrapper that provides `mangled` optional field. 493 494 pyserde copies every function reference into global scope 495 for code generation. 
Mangling function names is needed in 496 order to avoid name conflict in the global scope when 497 multiple fields receives `skip_if` attribute. 498 """ 499 500 inner: Callable[..., Any] 501 """ Function to wrap in """ 502 503 mangeld: str = "" 504 """ Mangled function name """ 505 506 def __call__(self, v: Any) -> None: 507 return self.inner(v) # type: ignore 508 509 @property 510 def name(self) -> str: 511 """ 512 Mangled function name 513 """ 514 return self.mangeld 515 516 517def skip_if_false(v: Any) -> Any: 518 return not bool(v) 519 520 521def skip_if_default(v: Any, default: Any | None = None) -> Any: 522 return v == default # Why return type is deduced to be Any? 523 524 525@dataclass 526class FlattenOpts: 527 """ 528 Flatten options. Currently not used. 529 """ 530 531 532def field( 533 *args: Any, 534 rename: str | None = None, 535 alias: list[str] | None = None, 536 skip: bool | None = None, 537 skip_if: Callable[[Any], Any] | None = None, 538 skip_if_false: bool | None = None, 539 skip_if_default: bool | None = None, 540 serializer: Callable[..., Any] | None = None, 541 deserializer: Callable[..., Any] | None = None, 542 flatten: FlattenOpts | bool | None = None, 543 metadata: dict[str, Any] | None = None, 544 **kwargs: Any, 545) -> Any: 546 """ 547 Declare a field with parameters. 
548 """ 549 if not metadata: 550 metadata = {} 551 552 if rename is not None: 553 metadata["serde_rename"] = rename 554 if alias is not None: 555 metadata["serde_alias"] = alias 556 if skip is not None: 557 metadata["serde_skip"] = skip 558 if skip_if is not None: 559 metadata["serde_skip_if"] = skip_if 560 if skip_if_false is not None: 561 metadata["serde_skip_if_false"] = skip_if_false 562 if skip_if_default is not None: 563 metadata["serde_skip_if_default"] = skip_if_default 564 if serializer: 565 metadata["serde_serializer"] = serializer 566 if deserializer: 567 metadata["serde_deserializer"] = deserializer 568 if flatten is True: 569 metadata["serde_flatten"] = FlattenOpts() 570 elif flatten: 571 metadata["serde_flatten"] = flatten 572 573 return dataclasses.field(*args, metadata=metadata, **kwargs) 574 575 576@dataclass 577class Field(Generic[T]): 578 """ 579 Field class is similar to `dataclasses.Field`. It provides pyserde specific options. 580 581 `type`, `name`, `default` and `default_factory` are the same members as `dataclasses.Field`. 
582 """ 583 584 type: type[T] 585 """ Type of Field """ 586 name: str | None 587 """ Name of Field """ 588 default: Any = field(default_factory=dataclasses._MISSING_TYPE) 589 """ Default value of Field """ 590 default_factory: Any = field(default_factory=dataclasses._MISSING_TYPE) 591 """ Default factory method of Field """ 592 init: bool = field(default_factory=dataclasses._MISSING_TYPE) 593 repr: Any = field(default_factory=dataclasses._MISSING_TYPE) 594 hash: Any = field(default_factory=dataclasses._MISSING_TYPE) 595 compare: Any = field(default_factory=dataclasses._MISSING_TYPE) 596 metadata: Mapping[str, Any] = field(default_factory=dict) 597 kw_only: bool = False 598 case: str | None = None 599 alias: list[str] = field(default_factory=list) 600 rename: str | None = None 601 skip: bool | None = None 602 skip_if: Func | None = None 603 skip_if_false: bool | None = None 604 skip_if_default: bool | None = None 605 serializer: Func | None = None # Custom field serializer. 606 deserializer: Func | None = None # Custom field deserializer. 607 flatten: FlattenOpts | None = None 608 parent: Any | None = None 609 type_args: list[str] | None = None 610 611 @classmethod 612 def from_dataclass(cls, f: dataclasses.Field[T], parent: Any | None = None) -> Field[T]: 613 """ 614 Create `Field` object from `dataclasses.Field`. 
615 """ 616 skip_if_false_func: Func | None = None 617 if f.metadata.get("serde_skip_if_false"): 618 skip_if_false_func = Func(skip_if_false, cls.mangle(f, "skip_if_false")) 619 620 skip_if_default_func: Func | None = None 621 if f.metadata.get("serde_skip_if_default"): 622 skip_if_def = functools.partial(skip_if_default, default=f.default) 623 skip_if_default_func = Func(skip_if_def, cls.mangle(f, "skip_if_default")) 624 625 skip_if: Func | None = None 626 if f.metadata.get("serde_skip_if"): 627 func = f.metadata.get("serde_skip_if") 628 if callable(func): 629 skip_if = Func(func, cls.mangle(f, "skip_if")) 630 631 serializer: Func | None = None 632 func = f.metadata.get("serde_serializer") 633 if func: 634 serializer = Func(func, cls.mangle(f, "serializer")) 635 636 deserializer: Func | None = None 637 func = f.metadata.get("serde_deserializer") 638 if func: 639 deserializer = Func(func, cls.mangle(f, "deserializer")) 640 641 flatten = f.metadata.get("serde_flatten") 642 if flatten is True: 643 flatten = FlattenOpts() 644 if flatten and not (dataclasses.is_dataclass(f.type) or is_opt_dataclass(f.type)): 645 raise SerdeError(f"pyserde does not support flatten attribute for {typename(f.type)}") 646 647 kw_only = bool(f.kw_only) 648 649 return cls( 650 f.type, # type: ignore 651 f.name, 652 default=f.default, 653 default_factory=f.default_factory, 654 init=f.init, 655 repr=f.repr, 656 hash=f.hash, 657 compare=f.compare, 658 metadata=f.metadata, 659 rename=f.metadata.get("serde_rename"), 660 alias=f.metadata.get("serde_alias", []), 661 skip=f.metadata.get("serde_skip"), 662 skip_if=skip_if or skip_if_false_func or skip_if_default_func, 663 serializer=serializer, 664 deserializer=deserializer, 665 flatten=flatten, 666 parent=parent, 667 kw_only=kw_only, 668 ) 669 670 def to_dataclass(self) -> dataclasses.Field[T]: 671 base_kwargs = { 672 "init": self.init, 673 "repr": self.repr, 674 "hash": self.hash, 675 "compare": self.compare, 676 "metadata": self.metadata, 677 
"kw_only": self.kw_only, 678 "default": self.default, 679 "default_factory": self.default_factory, 680 } 681 682 f = cast(dataclasses.Field[T], dataclasses.field(**base_kwargs)) 683 assert self.name 684 f.name = self.name 685 f.type = self.type 686 return f 687 688 def is_self_referencing(self) -> bool: 689 if self.type is None: 690 return False 691 if self.parent is None: 692 return False 693 return self.type == self.parent # type: ignore 694 695 @staticmethod 696 def mangle(field: dataclasses.Field[Any], name: str) -> str: 697 """ 698 Get mangled name based on field name. 699 """ 700 return f"{field.name}_{name}" 701 702 def conv_name(self, case: str | None = None) -> str: 703 """ 704 Get an actual field name which `rename` and `rename_all` conversions 705 are made. Use `name` property to get a field name before conversion. 706 """ 707 return conv(self, case or self.case) 708 709 def supports_default(self) -> bool: 710 return not getattr(self, "iterbased", False) and ( 711 has_default(self) or has_default_factory(self) 712 ) 713 714 715F = TypeVar("F", bound=Field[Any]) 716 717 718def fields(field_cls: type[F], cls: type[Any], serialize_class_var: bool = False) -> list[F]: 719 """ 720 Iterate fields of the dataclass and returns `serde.core.Field`. 721 """ 722 fields = [field_cls.from_dataclass(f, parent=cls) for f in dataclass_fields(cls)] 723 724 if serialize_class_var: 725 for name, typ in get_type_hints(cls).items(): 726 if is_class_var(typ): 727 fields.append(field_cls(typ, name, default=getattr(cls, name))) 728 729 return fields # type: ignore 730 731 732def conv(f: Field[Any], case: str | None = None) -> str: 733 """ 734 Convert field name. 735 """ 736 name = f.name 737 if case: 738 casef = getattr(casefy, case, None) 739 if not casef: 740 raise SerdeError( 741 f"Unkown case type: {f.case}. Pass the name of case supported by 'casefy' package." 
742 ) 743 name = casef(name) 744 if f.rename: 745 name = f.rename 746 if name is None: 747 raise SerdeError("Field name is None.") 748 return name 749 750 751def union_func_name(prefix: str, union_args: Sequence[Any]) -> str: 752 """ 753 Generate a function name that contains all union types 754 755 * `prefix` prefix to distinguish between serializing and deserializing 756 * `union_args`: type arguments of a Union 757 758 >>> from ipaddress import IPv4Address 759 >>> union_func_name("union_se", [int, list[str], IPv4Address]) 760 'union_se_int_list_str__IPv4Address' 761 """ 762 return re.sub(r"[^A-Za-z0-9]", "_", f"{prefix}_{'_'.join([typename(e) for e in union_args])}") 763 764 765def literal_func_name(literal_args: Sequence[Any]) -> str: 766 """ 767 Generate a function name with all literals and corresponding types specified with Literal[...] 768 769 770 * `literal_args`: arguments of a Literal 771 772 >>> literal_func_name(["r", "w", "a", "x", "r+", "w+", "a+", "x+"]) 773 'literal_de_r_str_w_str_a_str_x_str_r__str_w__str_a__str_x__str' 774 """ 775 return re.sub( 776 r"[^A-Za-z0-9]", 777 "_", 778 f"{LITERAL_DE_PREFIX}_{'_'.join(f'{a}_{typename(type(a))}' for a in literal_args)}", 779 ) 780 781 782@dataclass(frozen=True) 783class Tagging: 784 """ 785 Controls how union is (de)serialized. 
This is the same concept as in 786 https://serde.rs/enum-representations.html 787 """ 788 789 class Kind(enum.Enum): 790 External = enum.auto() 791 Internal = enum.auto() 792 Adjacent = enum.auto() 793 Untagged = enum.auto() 794 795 tag: str | None = None 796 content: str | None = None 797 kind: Kind = Kind.External 798 799 def is_external(self) -> bool: 800 return self.kind == self.Kind.External 801 802 def is_internal(self) -> bool: 803 return self.kind == self.Kind.Internal 804 805 def is_adjacent(self) -> bool: 806 return self.kind == self.Kind.Adjacent 807 808 def is_untagged(self) -> bool: 809 return self.kind == self.Kind.Untagged 810 811 @classmethod 812 def is_taggable(cls, typ: type[Any]) -> bool: 813 return dataclasses.is_dataclass(typ) 814 815 def check(self) -> None: 816 if self.is_internal() and self.tag is None: 817 raise SerdeError('"tag" must be specified in InternalTagging') 818 if self.is_adjacent() and (self.tag is None or self.content is None): 819 raise SerdeError('"tag" and "content" must be specified in AdjacentTagging') 820 821 def produce_unique_class_name(self) -> str: 822 """ 823 Produce a unique class name for this tagging. The name is used for generated 824 wrapper dataclass and stored in `Cache`. 825 """ 826 if self.is_internal(): 827 tag = casefy.pascalcase(self.tag) # type: ignore 828 if not tag: 829 raise SerdeError('"tag" must be specified in InternalTagging') 830 return f"Internal{tag}" 831 elif self.is_adjacent(): 832 tag = casefy.pascalcase(self.tag) # type: ignore 833 content = casefy.pascalcase(self.content) # type: ignore 834 if not tag: 835 raise SerdeError('"tag" must be specified in AdjacentTagging') 836 if not content: 837 raise SerdeError('"content" must be specified in AdjacentTagging') 838 return f"Adjacent{tag}{content}" 839 else: 840 return self.kind.name 841 842 def __call__(self, cls: T) -> _WithTagging[T]: 843 return _WithTagging(cls, self) 844 845 846@overload 847def InternalTagging(tag: str) -> Tagging: ... 
848 849 850@overload 851def InternalTagging(tag: str, cls: T) -> _WithTagging[T]: ... 852 853 854def InternalTagging(tag: str, cls: T | None = None) -> Tagging | _WithTagging[T]: 855 tagging = Tagging(tag, kind=Tagging.Kind.Internal) 856 if cls: 857 return tagging(cls) 858 else: 859 return tagging 860 861 862@overload 863def AdjacentTagging(tag: str, content: str) -> Tagging: ... 864 865 866@overload 867def AdjacentTagging(tag: str, content: str, cls: T) -> _WithTagging[T]: ... 868 869 870def AdjacentTagging(tag: str, content: str, cls: T | None = None) -> Tagging | _WithTagging[T]: 871 tagging = Tagging(tag, content, kind=Tagging.Kind.Adjacent) 872 if cls: 873 return tagging(cls) 874 else: 875 return tagging 876 877 878ExternalTagging = Tagging() 879 880Untagged = Tagging(kind=Tagging.Kind.Untagged) 881 882 883DefaultTagging = ExternalTagging 884 885 886def ensure(expr: Any, description: str) -> None: 887 if not expr: 888 raise Exception(description) 889 890 891def should_impl_dataclass(cls: type[Any]) -> bool: 892 """ 893 Test if class doesn't have @dataclass. 894 895 `dataclasses.is_dataclass` returns True even Derived class doesn't actually @dataclass. 896 >>> @dataclasses.dataclass 897 ... class Base: 898 ... a: int 899 >>> class Derived(Base): 900 ... b: int 901 >>> dataclasses.is_dataclass(Derived) 902 True 903 904 This function tells the class actually have @dataclass or not. 905 >>> should_impl_dataclass(Base) 906 False 907 >>> should_impl_dataclass(Derived) 908 True 909 """ 910 if not dataclasses.is_dataclass(cls): 911 return True 912 913 # Checking is_dataclass is not enough in such a case that the class is inherited 914 # from another dataclass. To do it correctly, check if all fields in __annotations__ 915 # are present as dataclass fields. 
916 annotations = getattr(cls, "__annotations__", {}) 917 if not annotations: 918 return False 919 920 field_names = [field.name for field in dataclass_fields(cls)] 921 for field_name, annotation in annotations.items(): 922 # Omit InitVar field because it doesn't appear in dataclass fields. 923 if is_instance(annotation, dataclasses.InitVar): 924 continue 925 # This field in __annotations__ is not a part of dataclass fields. 926 # This means this class does not implement dataclass directly. 927 if field_name not in field_names: 928 return True 929 930 # If all of the fields in __annotation__ are present as dataclass fields, 931 # the class already implemented dataclass, thus returns False. 932 return False 933 934 935@dataclass 936class TypeCheck: 937 """ 938 Specify type check flavors. 939 """ 940 941 class Kind(enum.Enum): 942 Disabled = enum.auto() 943 """ No check performed """ 944 945 Coerce = enum.auto() 946 """ Value is coerced into the declared type """ 947 Strict = enum.auto() 948 """ Value are strictly checked against the declared type """ 949 950 kind: Kind 951 952 def is_strict(self) -> bool: 953 return self.kind == self.Kind.Strict 954 955 def is_coerce(self) -> bool: 956 return self.kind == self.Kind.Coerce 957 958 def __call__(self, **kwargs: Any) -> TypeCheck: 959 # TODO 960 return self 961 962 963disabled = TypeCheck(kind=TypeCheck.Kind.Disabled) 964 965coerce = TypeCheck(kind=TypeCheck.Kind.Coerce) 966 967strict = TypeCheck(kind=TypeCheck.Kind.Strict) 968 969 970def coerce_object(cls: str, field: str, typ: type[Any], obj: Any) -> Any: 971 try: 972 return typ(obj) if is_coercible(typ, obj) else obj 973 except Exception as e: 974 raise SerdeError( 975 f"failed to coerce the field {cls}.{field} value {obj} into {typename(typ)}: {e}" 976 ) 977 978 979def is_coercible(typ: type[Any], obj: Any) -> bool: 980 if obj is None: 981 return False 982 return True 983 984 985def has_default(field: Field[Any]) -> bool: 986 """ 987 Test if the field has default 
value. 988 989 >>> @dataclasses.dataclass 990 ... class C: 991 ... a: int 992 ... d: int = 10 993 >>> has_default(dataclasses.fields(C)[0]) 994 False 995 >>> has_default(dataclasses.fields(C)[1]) 996 True 997 """ 998 return not isinstance(field.default, dataclasses._MISSING_TYPE) 999 1000 1001def has_default_factory(field: Field[Any]) -> bool: 1002 """ 1003 Test if the field has default factory. 1004 1005 >>> @dataclasses.dataclass 1006 ... class C: 1007 ... a: int 1008 ... d: dict = dataclasses.field(default_factory=dict) 1009 >>> has_default_factory(dataclasses.fields(C)[0]) 1010 False 1011 >>> has_default_factory(dataclasses.fields(C)[1]) 1012 True 1013 """ 1014 return not isinstance(field.default_factory, dataclasses._MISSING_TYPE) 1015 1016 1017class ClassSerializer(Protocol): 1018 """ 1019 Interface for custom class serializer. 1020 1021 This protocol is intended to be used for custom class serializer. 1022 1023 >>> from datetime import datetime 1024 >>> from serde import serde 1025 >>> from plum import dispatch 1026 >>> class MySerializer(ClassSerializer): 1027 ... @dispatch 1028 ... def serialize(self, value: datetime) -> str: 1029 ... return value.strftime("%d/%m/%y") 1030 """ 1031 1032 def serialize(self, value: Any) -> Any: 1033 pass 1034 1035 1036class ClassDeserializer(Protocol): 1037 """ 1038 Interface for custom class deserializer. 1039 1040 This protocol is intended to be used for custom class deserializer. 1041 1042 >>> from datetime import datetime 1043 >>> from serde import serde 1044 >>> from plum import dispatch 1045 >>> class MyDeserializer(ClassDeserializer): 1046 ... @dispatch 1047 ... def deserialize(self, cls: type[datetime], value: Any) -> datetime: 1048 ... 
return datetime.strptime(value, "%d/%m/%y") 1049 """ 1050 1051 def deserialize(self, cls: Any, value: Any) -> Any: 1052 pass 1053 1054 1055GLOBAL_CLASS_SERIALIZER: list[ClassSerializer] = [] 1056 1057GLOBAL_CLASS_DESERIALIZER: list[ClassDeserializer] = [] 1058 1059 1060def add_serializer(serializer: ClassSerializer) -> None: 1061 """ 1062 Register custom global serializer. 1063 """ 1064 GLOBAL_CLASS_SERIALIZER.append(serializer) 1065 1066 1067def add_deserializer(deserializer: ClassDeserializer) -> None: 1068 """ 1069 Register custom global deserializer. 1070 """ 1071 GLOBAL_CLASS_DESERIALIZER.append(deserializer)
@dataclass
class Scope:
    """
    Container to store types and functions used in code generation context.
    """

    cls: type[Any]
    """ The exact class this scope is for
    (needed to distinguish scopes between inherited classes) """

    funcs: dict[str, Callable[..., Any]] = dataclasses.field(default_factory=dict)
    """ Generated serialize and deserialize functions """

    defaults: dict[str, Callable[..., Any] | Any] = dataclasses.field(default_factory=dict)
    """ Default values of the dataclass fields (factories & normal values) """

    code: dict[str, str] = dataclasses.field(default_factory=dict)
    """ Generated source code (only filled when debug is True) """

    union_se_args: dict[str, list[type[Any]]] = dataclasses.field(default_factory=dict)
    """ The union serializing functions need references to their types """

    reuse_instances_default: bool = True
    """ Default values for to_dict & from_dict arguments """

    convert_sets_default: bool = False

    def __repr__(self) -> str:
        """
        Render a human readable report of the scope.

        The report starts with a banner holding the class name, followed by
        one section per non-empty container (`code`, `funcs`, `defaults`,
        `union_se_args`). Empty containers are omitted entirely.
        """
        heavy_rule = "=================================================="
        light_rule = "--------------------------------------------------"

        def section(title: str, body: list[str]) -> list[str]:
            # Every section is a ruled, centered title, its body lines and a
            # trailing blank line separating it from the next section.
            return [light_rule, self._justify(title), light_rule, *body, ""]

        res: list[str] = [heavy_rule, self._justify(self.cls.__name__), heavy_rule, ""]

        if self.code:
            res += section("Functions generated by pyserde", list(self.code.values()))

        if self.funcs:
            res += section(
                "Function references in scope",
                [f"{k}: {v}" for k, v in self.funcs.items()],
            )

        if self.defaults:
            res += section(
                "Default values for the dataclass fields",
                [f"{k}: {v}" for k, v in self.defaults.items()],
            )

        if self.union_se_args:
            res += section(
                "Type list by used for union serialize functions",
                [f"{k}: {list(lst)}" for k, lst in self.union_se_args.items()],
            )

        return "\n".join(res)

    def _justify(self, s: str, length: int = 50) -> str:
        """
        Center `s` within `length` columns by left-padding it with spaces.

        Only leading padding is added. A string that is already `length`
        characters or longer is returned unchanged.
        """
        # Honor the `length` argument; it used to be ignored in favor of a
        # hard-coded 50.
        return " " * max((length - len(s)) // 2, 0) + s
Container to store types and functions used in code generation context.
The exact class this scope is for (needed to distinguish scopes between inherited classes)
Default values of the dataclass fields (factories & normal values)
def gen(
    code: str, globals: dict[str, Any] | None = None, locals: dict[str, Any] | None = None
) -> str:
    """
    A wrapper of builtin `exec` function.

    Executes `code` with the given `globals` and `locals` and returns the
    source that was executed. In debug mode the source is first reformatted
    with `black` when that is possible, so the returned text may differ from
    the input.
    """
    debug_enabled = SETTINGS["debug"]
    if debug_enabled:
        # Pretty source only matters when somebody is going to read it, and
        # black is optional: any failure to import or format is ignored and
        # the unformatted code is used instead.
        try:
            import black

            code = black.format_str(code, mode=black.FileMode(line_length=100))
        except Exception:
            pass

    exec(code, globals, locals)
    return code
A wrapper of builtin exec function.
def add_func(serde_scope: Scope, func_name: str, func_code: str, globals: dict[str, Any]) -> None:
    """
    Generate a function and add it to a Scope's `funcs` dictionary.

    * `serde_scope`: the Scope instance to modify
    * `func_name`: the name of the function
    * `func_code`: the source code of the function
    * `globals`: global variables that should be accessible to the generated function
    """
    # Executing the source defines `func_name` inside `globals`; the function
    # object is then picked up from there.
    generated_source = gen(func_code, globals)
    serde_scope.funcs[func_name] = globals[func_name]

    if not SETTINGS["debug"]:
        return

    # The source text is only retained in debug mode.
    serde_scope.code[func_name] = generated_source
Generate a function and add it to a Scope's funcs dictionary.
* `serde_scope`: the Scope instance to modify
* `func_name`: the name of the function
* `func_code`: the source code of the function
* `globals`: global variables that should be accessible to the generated function
@dataclass
class Func:
    """
    Function wrapper that provides `mangled` optional field.

    pyserde copies every function reference into global scope
    for code generation. Mangling function names is needed in
    order to avoid name conflict in the global scope when
    multiple fields receives `skip_if` attribute.
    """

    inner: Callable[..., Any]
    """ Function to wrap in """

    # NOTE: the attribute name is misspelled ("mangeld"), but it is part of
    # the constructor signature, so it is kept as is for compatibility.
    mangeld: str = ""
    """ Mangled function name """

    def __call__(self, v: Any) -> Any:
        """
        Call the wrapped function with `v` and return whatever it returns.
        """
        return self.inner(v)

    @property
    def name(self) -> str:
        """
        Mangled function name
        """
        return self.mangeld
Function wrapper that provides mangled optional field.
pyserde copies every function reference into global scope
for code generation. Mangling function names is needed in
order to avoid name conflict in the global scope when
multiple fields receive the skip_if attribute.
@dataclass
class Field(Generic[T]):
    """
    Field class is similar to `dataclasses.Field`. It provides pyserde specific options.

    `type`, `name`, `default` and `default_factory` are the same members as `dataclasses.Field`.
    """

    type: type[T]
    """ Type of Field """
    name: str | None
    """ Name of Field """
    default: Any = field(default_factory=dataclasses._MISSING_TYPE)
    """ Default value of Field """
    default_factory: Any = field(default_factory=dataclasses._MISSING_TYPE)
    """ Default factory method of Field """
    init: bool = field(default_factory=dataclasses._MISSING_TYPE)
    repr: Any = field(default_factory=dataclasses._MISSING_TYPE)
    hash: Any = field(default_factory=dataclasses._MISSING_TYPE)
    compare: Any = field(default_factory=dataclasses._MISSING_TYPE)
    metadata: Mapping[str, Any] = field(default_factory=dict)
    kw_only: bool = False
    case: str | None = None
    alias: list[str] = field(default_factory=list)
    rename: str | None = None
    skip: bool | None = None
    skip_if: Func | None = None
    skip_if_false: bool | None = None
    skip_if_default: bool | None = None
    serializer: Func | None = None  # Custom field serializer.
    deserializer: Func | None = None  # Custom field deserializer.
    flatten: FlattenOpts | None = None
    parent: Any | None = None
    type_args: list[str] | None = None

    @classmethod
    def from_dataclass(cls, f: dataclasses.Field[T], parent: Any | None = None) -> Field[T]:
        """
        Create `Field` object from `dataclasses.Field`.

        pyserde options are read from the `serde_*` keys of `f.metadata`.
        When several skip conditions are present, `serde_skip_if` wins over
        `serde_skip_if_false`, which wins over `serde_skip_if_default`.

        Raises `SerdeError` when flatten is requested for a field whose type
        is neither a dataclass nor an optional dataclass.
        """
        meta = f.metadata

        def mangled(target: Callable[..., Any], suffix: str) -> Func:
            # Wrap `target` under a name derived from this field's name.
            return Func(target, cls.mangle(f, suffix))

        skip_if_false_func: Func | None = None
        if meta.get("serde_skip_if_false"):
            skip_if_false_func = mangled(skip_if_false, "skip_if_false")

        skip_if_default_func: Func | None = None
        if meta.get("serde_skip_if_default"):
            skip_if_default_func = mangled(
                functools.partial(skip_if_default, default=f.default), "skip_if_default"
            )

        skip_if: Func | None = None
        user_skip_if = meta.get("serde_skip_if")
        if user_skip_if and callable(user_skip_if):
            skip_if = mangled(user_skip_if, "skip_if")

        user_serializer = meta.get("serde_serializer")
        serializer = mangled(user_serializer, "serializer") if user_serializer else None

        user_deserializer = meta.get("serde_deserializer")
        deserializer = mangled(user_deserializer, "deserializer") if user_deserializer else None

        flatten = meta.get("serde_flatten")
        if flatten is True:
            flatten = FlattenOpts()
        if flatten:
            if not (dataclasses.is_dataclass(f.type) or is_opt_dataclass(f.type)):
                raise SerdeError(
                    f"pyserde does not support flatten attribute for {typename(f.type)}"
                )

        return cls(
            f.type,  # type: ignore
            f.name,
            default=f.default,
            default_factory=f.default_factory,
            init=f.init,
            repr=f.repr,
            hash=f.hash,
            compare=f.compare,
            metadata=f.metadata,
            rename=meta.get("serde_rename"),
            alias=meta.get("serde_alias", []),
            skip=meta.get("serde_skip"),
            skip_if=skip_if or skip_if_false_func or skip_if_default_func,
            serializer=serializer,
            deserializer=deserializer,
            flatten=flatten,
            parent=parent,
            kw_only=bool(f.kw_only),
        )

    def to_dataclass(self) -> dataclasses.Field[T]:
        """
        Build a `dataclasses.Field` carrying this field's name, type and options.

        `default` and `default_factory` are only forwarded when they hold a
        real value. `dataclasses.field` still raises `ValueError` when both
        are real values.
        """
        kwargs: dict[str, Any] = {
            "init": self.init,
            "repr": self.repr,
            "hash": self.hash,
            "compare": self.compare,
            "metadata": self.metadata,
            "kw_only": self.kw_only,
        }

        # An unset default on this class is a fresh `_MISSING_TYPE()` instance,
        # not the `dataclasses.MISSING` singleton that `dataclasses.field`
        # compares against by identity, so test by type instead.
        if not isinstance(self.default, dataclasses._MISSING_TYPE):
            kwargs["default"] = self.default
        if not isinstance(self.default_factory, dataclasses._MISSING_TYPE):
            kwargs["default_factory"] = self.default_factory

        f = cast("dataclasses.Field[T]", dataclasses.field(**kwargs))
        assert self.name
        f.name = self.name
        f.type = self.type
        return f

    def is_self_referencing(self) -> bool:
        """
        Tell whether the field's type equals the `parent` it was created for.
        """
        if self.type is None:
            return False
        if self.parent is None:
            return False
        return self.type == self.parent  # type: ignore

    @staticmethod
    def mangle(field: dataclasses.Field[Any], name: str) -> str:
        """
        Get mangled name based on field name.
        """
        return f"{field.name}_{name}"

    def conv_name(self, case: str | None = None) -> str:
        """
        Get an actual field name which `rename` and `rename_all` conversions
        are made. Use `name` property to get a field name before conversion.
        """
        return conv(self, case or self.case)

    def supports_default(self) -> bool:
        """
        Tell whether a default value or factory can be used for this field.

        Always false when the instance has a truthy `iterbased` attribute.
        """
        return not getattr(self, "iterbased", False) and (
            has_default(self) or has_default_factory(self)
        )
Field class is similar to dataclasses.Field. It provides pyserde specific options.
type, name, default and default_factory are the same members as dataclasses.Field.
@classmethod
def from_dataclass(cls, f: dataclasses.Field[T], parent: Any | None = None) -> Field[T]:
    """
    Create `Field` object from `dataclasses.Field`.

    pyserde options are read from the `serde_*` keys of `f.metadata`.
    When several skip conditions are present, `serde_skip_if` wins over
    `serde_skip_if_false`, which wins over `serde_skip_if_default`.
    """
    meta = f.metadata

    def mangled(target: Callable[..., Any], suffix: str) -> Func:
        # Wrap `target` under a name derived from this field's name.
        return Func(target, cls.mangle(f, suffix))

    false_skipper: Func | None = None
    if meta.get("serde_skip_if_false"):
        false_skipper = mangled(skip_if_false, "skip_if_false")

    default_skipper: Func | None = None
    if meta.get("serde_skip_if_default"):
        default_skipper = mangled(
            functools.partial(skip_if_default, default=f.default), "skip_if_default"
        )

    custom_skipper: Func | None = None
    user_skip_if = meta.get("serde_skip_if")
    if user_skip_if and callable(user_skip_if):
        custom_skipper = mangled(user_skip_if, "skip_if")

    user_serializer = meta.get("serde_serializer")
    serializer = mangled(user_serializer, "serializer") if user_serializer else None

    user_deserializer = meta.get("serde_deserializer")
    deserializer = mangled(user_deserializer, "deserializer") if user_deserializer else None

    flatten = meta.get("serde_flatten")
    if flatten is True:
        flatten = FlattenOpts()
    if flatten:
        # Flattening is only accepted for (optional) dataclass fields.
        if not (dataclasses.is_dataclass(f.type) or is_opt_dataclass(f.type)):
            raise SerdeError(f"pyserde does not support flatten attribute for {typename(f.type)}")

    return cls(
        f.type,  # type: ignore
        f.name,
        default=f.default,
        default_factory=f.default_factory,
        init=f.init,
        repr=f.repr,
        hash=f.hash,
        compare=f.compare,
        metadata=f.metadata,
        rename=meta.get("serde_rename"),
        alias=meta.get("serde_alias", []),
        skip=meta.get("serde_skip"),
        skip_if=custom_skipper or false_skipper or default_skipper,
        serializer=serializer,
        deserializer=deserializer,
        flatten=flatten,
        parent=parent,
        kw_only=bool(f.kw_only),
    )
Create Field object from dataclasses.Field.
671 def to_dataclass(self) -> dataclasses.Field[T]: 672 base_kwargs = { 673 "init": self.init, 674 "repr": self.repr, 675 "hash": self.hash, 676 "compare": self.compare, 677 "metadata": self.metadata, 678 "kw_only": self.kw_only, 679 "default": self.default, 680 "default_factory": self.default_factory, 681 } 682 683 f = cast(dataclasses.Field[T], dataclasses.field(**base_kwargs)) 684 assert self.name 685 f.name = self.name 686 f.type = self.type 687 return f
696 @staticmethod 697 def mangle(field: dataclasses.Field[Any], name: str) -> str: 698 """ 699 Get mangled name based on field name. 700 """ 701 return f"{field.name}_{name}"
Get mangled name based on field name.
def fields(field_cls: type[F], cls: type[Any], serialize_class_var: bool = False) -> list[F]:
    """
    Iterate fields of the dataclass and returns `serde.core.Field`.

    * `field_cls`: the field class used to build every returned item
    * `cls`: the dataclass to inspect; it is recorded as `parent` of each dataclass field
    * `serialize_class_var`: when true, class variables found in the type hints
      of `cls` are appended, using the current class attribute as default
    """
    collected = [field_cls.from_dataclass(f, parent=cls) for f in dataclass_fields(cls)]

    if not serialize_class_var:
        return collected  # type: ignore

    collected.extend(
        field_cls(hint, attr, default=getattr(cls, attr))
        for attr, hint in get_type_hints(cls).items()
        if is_class_var(hint)
    )
    return collected  # type: ignore
Iterate over the fields of the dataclass and return them as a list of serde.core.Field.
Flatten options. Currently not used.
def conv(f: Field[Any], case: str | None = None) -> str:
    """
    Convert field name.

    * `f`: the field whose `name` and `rename` are used
    * `case`: name of the `casefy` function to apply to the field name

    `rename` takes precedence over the case conversion. Raises `SerdeError`
    when `case` is not provided by `casefy`, or when the resulting name is
    `None`.
    """
    name = f.name
    if case:
        casef = getattr(casefy, case, None)
        if not casef:
            # Report the `case` that was actually looked up; `f.case` can be
            # unset or different when the caller passes `case` explicitly.
            raise SerdeError(
                f"Unknown case type: {case}. Pass the name of case supported by 'casefy' package."
            )
        name = casef(name)
    if f.rename:
        name = f.rename
    if name is None:
        raise SerdeError("Field name is None.")
    return name
Convert field name.
def union_func_name(prefix: str, union_args: Sequence[Any]) -> str:
    """
    Generate a function name that contains all union types

    * `prefix`: prefix to distinguish between serializing and deserializing
    * `union_args`: type arguments of a Union

    >>> from ipaddress import IPv4Address
    >>> union_func_name("union_se", [int, list[str], IPv4Address])
    'union_se_int_list_str__IPv4Address'
    """
    joined_types = "_".join(typename(arg) for arg in union_args)
    # Anything that is not an ASCII letter or digit becomes an underscore.
    return re.sub(r"[^A-Za-z0-9]", "_", f"{prefix}_{joined_types}")
Generate a function name that contains all union types
* `prefix`: prefix to distinguish between serializing and deserializing
* `union_args`: type arguments of a Union
>>> from ipaddress import IPv4Address
>>> union_func_name("union_se", [int, list[str], IPv4Address])
'union_se_int_list_str__IPv4Address'