serde.core
pyserde core module.
1""" 2pyserde core module. 3""" 4 5from __future__ import annotations 6import dataclasses 7import enum 8import functools 9import logging 10import sys 11import re 12import casefy 13from dataclasses import dataclass 14 15from beartype.door import is_bearable 16from collections.abc import Mapping, Sequence, Callable 17from typing import ( 18 overload, 19 TypeVar, 20 Generic, 21 Optional, 22 Any, 23 Protocol, 24 get_type_hints, 25 Union, 26) 27 28from .compat import ( 29 T, 30 SerdeError, 31 dataclass_fields, 32 get_origin, 33 is_bare_dict, 34 is_bare_list, 35 is_bare_set, 36 is_bare_tuple, 37 is_class_var, 38 is_dict, 39 is_generic, 40 is_list, 41 is_literal, 42 is_new_type_primitive, 43 is_any, 44 is_opt, 45 is_opt_dataclass, 46 is_set, 47 is_tuple, 48 is_union, 49 is_variable_tuple, 50 type_args, 51 typename, 52 _WithTagging, 53) 54 55__all__ = [ 56 "Scope", 57 "gen", 58 "add_func", 59 "Func", 60 "Field", 61 "fields", 62 "FlattenOpts", 63 "conv", 64 "union_func_name", 65] 66 67logger = logging.getLogger("serde") 68 69 70# name of the serde context key 71SERDE_SCOPE = "__serde__" 72 73# main function keys 74FROM_ITER = "from_iter" 75FROM_DICT = "from_dict" 76TO_ITER = "to_iter" 77TO_DICT = "to_dict" 78TYPE_CHECK = "typecheck" 79 80# prefixes used to distinguish the direction of a union function 81UNION_SE_PREFIX = "union_se" 82UNION_DE_PREFIX = "union_de" 83 84LITERAL_DE_PREFIX = "literal_de" 85 86SETTINGS = {"debug": False} 87 88 89def init(debug: bool = False) -> None: 90 SETTINGS["debug"] = debug 91 92 93@dataclass 94class Cache: 95 """ 96 Cache the generated code for non-dataclass classes. 97 98 for example, a type not bound in a dataclass is passed in from_json 99 100 ``` 101 from_json(Union[Foo, Bar], ...) 102 ``` 103 104 It creates the following wrapper dataclass on the fly, 105 106 ``` 107 @serde 108 @dataclass 109 class Union_Foo_bar: 110 v: Union[Foo, Bar] 111 ``` 112 113 Then store this class in this cache. 
Whenever the same type is passed, 114 the class is retrieved from this cache. So the overhead of the codegen 115 should be only once. 116 """ 117 118 classes: dict[str, type[Any]] = dataclasses.field(default_factory=dict) 119 120 def _get_class(self, cls: type[Any]) -> type[Any]: 121 """ 122 Get a wrapper class from the the cache. If not found, it will generate 123 the class and store it in the cache. 124 """ 125 class_name = f"Wrapper{typename(cls)}" 126 wrapper = self.classes.get(class_name) 127 return wrapper or self._generate_class(cls) 128 129 def _generate_class(self, cls: type[Any]) -> type[Any]: 130 """ 131 Generate a wrapper dataclass then make the it (de)serializable using 132 @serde decorator. 133 """ 134 from . import serde 135 136 class_name = f"Wrapper{typename(cls)}" 137 logger.debug(f"Generating a wrapper class code for {class_name}") 138 139 wrapper = dataclasses.make_dataclass(class_name, [("v", cls)]) 140 141 serde(wrapper) 142 self.classes[class_name] = wrapper 143 144 logger.debug(f"(de)serializing code for {class_name} was generated") 145 return wrapper 146 147 def serialize(self, cls: type[Any], obj: Any, **kwargs: Any) -> Any: 148 """ 149 Serialize the specified type of object into dict or tuple. 150 """ 151 wrapper = self._get_class(cls) 152 scope: Scope = getattr(wrapper, SERDE_SCOPE) 153 data = scope.funcs[TO_DICT](wrapper(obj), **kwargs) 154 155 logging.debug(f"Intermediate value: {data}") 156 157 return data["v"] 158 159 def deserialize(self, cls: type[T], obj: Any) -> T: 160 """ 161 Deserialize from dict or tuple into the specified type. 162 """ 163 wrapper = self._get_class(cls) 164 scope: Scope = getattr(wrapper, SERDE_SCOPE) 165 return scope.funcs[FROM_DICT](data={"v": obj}).v # type: ignore 166 167 def _get_union_class(self, cls: type[Any]) -> Optional[type[Any]]: 168 """ 169 Get a wrapper class from the the cache. If not found, it will generate 170 the class and store it in the cache. 
171 """ 172 union_cls, tagging = _extract_from_with_tagging(cls) 173 class_name = union_func_name( 174 f"{tagging.produce_unique_class_name()}Union", list(type_args(union_cls)) 175 ) 176 wrapper = self.classes.get(class_name) 177 return wrapper or self._generate_union_class(cls) 178 179 def _generate_union_class(self, cls: type[Any]) -> type[Any]: 180 """ 181 Generate a wrapper dataclass then make the it (de)serializable using 182 @serde decorator. 183 """ 184 import serde 185 186 union_cls, tagging = _extract_from_with_tagging(cls) 187 class_name = union_func_name( 188 f"{tagging.produce_unique_class_name()}Union", list(type_args(union_cls)) 189 ) 190 wrapper = dataclasses.make_dataclass(class_name, [("v", union_cls)]) 191 serde.serde(wrapper, tagging=tagging) 192 self.classes[class_name] = wrapper 193 return wrapper 194 195 def serialize_union(self, cls: type[Any], obj: Any) -> Any: 196 """ 197 Serialize the specified Union into dict or tuple. 198 """ 199 union_cls, _ = _extract_from_with_tagging(cls) 200 wrapper = self._get_union_class(cls) 201 scope: Scope = getattr(wrapper, SERDE_SCOPE) 202 func_name = union_func_name(UNION_SE_PREFIX, list(type_args(union_cls))) 203 return scope.funcs[func_name](obj, False, False) 204 205 def deserialize_union(self, cls: type[T], data: Any) -> T: 206 """ 207 Deserialize from dict or tuple into the specified Union. 
208 """ 209 union_cls, _ = _extract_from_with_tagging(cls) 210 wrapper = self._get_union_class(cls) 211 scope: Scope = getattr(wrapper, SERDE_SCOPE) 212 func_name = union_func_name(UNION_DE_PREFIX, list(type_args(union_cls))) 213 return scope.funcs[func_name](cls=union_cls, data=data) # type: ignore 214 215 216def _extract_from_with_tagging(maybe_with_tagging: Any) -> tuple[Any, Tagging]: 217 try: 218 if isinstance(maybe_with_tagging, _WithTagging): 219 union_cls = maybe_with_tagging.inner 220 tagging = maybe_with_tagging.tagging 221 else: 222 raise Exception() 223 except Exception: 224 union_cls = maybe_with_tagging 225 tagging = ExternalTagging 226 227 return (union_cls, tagging) 228 229 230CACHE = Cache() 231""" Global cache variable for non-dataclass classes """ 232 233 234@dataclass 235class Scope: 236 """ 237 Container to store types and functions used in code generation context. 238 """ 239 240 cls: type[Any] 241 """ The exact class this scope is for 242 (needed to distinguish scopes between inherited classes) """ 243 244 funcs: dict[str, Callable[..., Any]] = dataclasses.field(default_factory=dict) 245 """ Generated serialize and deserialize functions """ 246 247 defaults: dict[str, Union[Callable[..., Any], Any]] = dataclasses.field(default_factory=dict) 248 """ Default values of the dataclass fields (factories & normal values) """ 249 250 code: dict[str, str] = dataclasses.field(default_factory=dict) 251 """ Generated source code (only filled when debug is True) """ 252 253 union_se_args: dict[str, list[type[Any]]] = dataclasses.field(default_factory=dict) 254 """ The union serializing functions need references to their types """ 255 256 reuse_instances_default: bool = True 257 """ Default values for to_dict & from_dict arguments """ 258 259 convert_sets_default: bool = False 260 261 def __repr__(self) -> str: 262 res: list[str] = [] 263 264 res.append("==================================================") 265 res.append(self._justify(self.cls.__name__)) 
266 res.append("==================================================") 267 res.append("") 268 269 if self.code: 270 res.append("--------------------------------------------------") 271 res.append(self._justify("Functions generated by pyserde")) 272 res.append("--------------------------------------------------") 273 res.extend(list(self.code.values())) 274 res.append("") 275 276 if self.funcs: 277 res.append("--------------------------------------------------") 278 res.append(self._justify("Function references in scope")) 279 res.append("--------------------------------------------------") 280 for k, v in self.funcs.items(): 281 res.append(f"{k}: {v}") 282 res.append("") 283 284 if self.defaults: 285 res.append("--------------------------------------------------") 286 res.append(self._justify("Default values for the dataclass fields")) 287 res.append("--------------------------------------------------") 288 for k, v in self.defaults.items(): 289 res.append(f"{k}: {v}") 290 res.append("") 291 292 if self.union_se_args: 293 res.append("--------------------------------------------------") 294 res.append(self._justify("Type list by used for union serialize functions")) 295 res.append("--------------------------------------------------") 296 for k, lst in self.union_se_args.items(): 297 res.append(f"{k}: {list(lst)}") 298 res.append("") 299 300 return "\n".join(res) 301 302 def _justify(self, s: str, length: int = 50) -> str: 303 white_spaces = int((50 - len(s)) / 2) 304 return " " * (white_spaces if white_spaces > 0 else 0) + s 305 306 307def raise_unsupported_type(obj: Any) -> None: 308 # needed because we can not render a raise statement everywhere, e.g. as argument 309 raise SerdeError(f"Unsupported type: {typename(type(obj))}") 310 311 312def gen( 313 code: str, globals: Optional[dict[str, Any]] = None, locals: Optional[dict[str, Any]] = None 314) -> str: 315 """ 316 A wrapper of builtin `exec` function. 
317 """ 318 if SETTINGS["debug"]: 319 # black formatting is only important when debugging 320 try: 321 from black import FileMode, format_str 322 323 code = format_str(code, mode=FileMode(line_length=100)) 324 except Exception: 325 pass 326 exec(code, globals, locals) 327 return code 328 329 330def add_func(serde_scope: Scope, func_name: str, func_code: str, globals: dict[str, Any]) -> None: 331 """ 332 Generate a function and add it to a Scope's `funcs` dictionary. 333 334 * `serde_scope`: the Scope instance to modify 335 * `func_name`: the name of the function 336 * `func_code`: the source code of the function 337 * `globals`: global variables that should be accessible to the generated function 338 """ 339 340 code = gen(func_code, globals) 341 serde_scope.funcs[func_name] = globals[func_name] 342 343 if SETTINGS["debug"]: 344 serde_scope.code[func_name] = code 345 346 347def is_instance(obj: Any, typ: Any) -> bool: 348 """ 349 pyserde's own `isinstance` helper. It accepts subscripted generics e.g. `list[int]` and 350 deeply check object against declared type. 
351 """ 352 if dataclasses.is_dataclass(typ): 353 if not isinstance(typ, type): 354 raise SerdeError("expect dataclass class but dataclass instance received") 355 return isinstance(obj, typ) 356 elif is_opt(typ): 357 return is_opt_instance(obj, typ) 358 elif is_union(typ): 359 return is_union_instance(obj, typ) 360 elif is_list(typ): 361 return is_list_instance(obj, typ) 362 elif is_set(typ): 363 return is_set_instance(obj, typ) 364 elif is_tuple(typ): 365 return is_tuple_instance(obj, typ) 366 elif is_dict(typ): 367 return is_dict_instance(obj, typ) 368 elif is_generic(typ): 369 return is_generic_instance(obj, typ) 370 elif is_literal(typ): 371 return True 372 elif is_new_type_primitive(typ): 373 inner = getattr(typ, "__supertype__", None) 374 if type(inner) is type: 375 return isinstance(obj, inner) 376 else: 377 return False 378 elif is_any(typ): 379 return True 380 elif typ is Ellipsis: 381 return True 382 else: 383 return is_bearable(obj, typ) 384 385 386def is_opt_instance(obj: Any, typ: type[Any]) -> bool: 387 if obj is None: 388 return True 389 opt_arg = type_args(typ)[0] 390 return is_instance(obj, opt_arg) 391 392 393def is_union_instance(obj: Any, typ: type[Any]) -> bool: 394 for arg in type_args(typ): 395 if is_instance(obj, arg): 396 return True 397 return False 398 399 400def is_list_instance(obj: Any, typ: type[Any]) -> bool: 401 if not isinstance(obj, list): 402 return False 403 if len(obj) == 0 or is_bare_list(typ): 404 return True 405 list_arg = type_args(typ)[0] 406 # for speed reasons we just check the type of the 1st element 407 return is_instance(obj[0], list_arg) 408 409 410def is_set_instance(obj: Any, typ: type[Any]) -> bool: 411 if not isinstance(obj, (set, frozenset)): 412 return False 413 if len(obj) == 0 or is_bare_set(typ): 414 return True 415 set_arg = type_args(typ)[0] 416 # for speed reasons we just check the type of the 1st element 417 return is_instance(next(iter(obj)), set_arg) 418 419 420def is_tuple_instance(obj: Any, typ: 
type[Any]) -> bool: 421 args = type_args(typ) 422 423 if not isinstance(obj, tuple): 424 return False 425 426 # empty tuple 427 if len(args) == 0 and len(obj) == 0: 428 return True 429 430 # In the form of tuple[T, ...] 431 elif is_variable_tuple(typ): 432 # Get the first type arg. Since tuple[T, ...] is homogeneous tuple, 433 # all the elements should be of this type. 434 arg = type_args(typ)[0] 435 for v in obj: 436 if not is_instance(v, arg): 437 return False 438 return True 439 440 # bare tuple "tuple" is equivalent to tuple[Any, ...] 441 if is_bare_tuple(typ) and isinstance(obj, tuple): 442 return True 443 444 # All the other tuples e.g. tuple[int, str] 445 if len(obj) == len(args): 446 for element, arg in zip(obj, args): 447 if not is_instance(element, arg): 448 return False 449 else: 450 return False 451 452 return True 453 454 455def is_dict_instance(obj: Any, typ: type[Any]) -> bool: 456 if not isinstance(obj, dict): 457 return False 458 if len(obj) == 0 or is_bare_dict(typ): 459 return True 460 ktyp = type_args(typ)[0] 461 vtyp = type_args(typ)[1] 462 for k, v in obj.items(): 463 # for speed reasons we just check the type of the 1st element 464 return is_instance(k, ktyp) and is_instance(v, vtyp) 465 return False 466 467 468def is_generic_instance(obj: Any, typ: type[Any]) -> bool: 469 return is_instance(obj, get_origin(typ)) 470 471 472@dataclass 473class Func: 474 """ 475 Function wrapper that provides `mangled` optional field. 476 477 pyserde copies every function reference into global scope 478 for code generation. Mangling function names is needed in 479 order to avoid name conflict in the global scope when 480 multiple fields receives `skip_if` attribute. 
481 """ 482 483 inner: Callable[..., Any] 484 """ Function to wrap in """ 485 486 mangeld: str = "" 487 """ Mangled function name """ 488 489 def __call__(self, v: Any) -> None: 490 return self.inner(v) # type: ignore 491 492 @property 493 def name(self) -> str: 494 """ 495 Mangled function name 496 """ 497 return self.mangeld 498 499 500def skip_if_false(v: Any) -> Any: 501 return not bool(v) 502 503 504def skip_if_default(v: Any, default: Optional[Any] = None) -> Any: 505 return v == default # Why return type is deduced to be Any? 506 507 508@dataclass 509class FlattenOpts: 510 """ 511 Flatten options. Currently not used. 512 """ 513 514 515def field( 516 *args: Any, 517 rename: Optional[str] = None, 518 alias: Optional[list[str]] = None, 519 skip: Optional[bool] = None, 520 skip_if: Optional[Callable[[Any], Any]] = None, 521 skip_if_false: Optional[bool] = None, 522 skip_if_default: Optional[bool] = None, 523 serializer: Optional[Callable[..., Any]] = None, 524 deserializer: Optional[Callable[..., Any]] = None, 525 flatten: Optional[Union[FlattenOpts, bool]] = None, 526 metadata: Optional[dict[str, Any]] = None, 527 **kwargs: Any, 528) -> Any: 529 """ 530 Declare a field with parameters. 
531 """ 532 if not metadata: 533 metadata = {} 534 535 if rename is not None: 536 metadata["serde_rename"] = rename 537 if alias is not None: 538 metadata["serde_alias"] = alias 539 if skip is not None: 540 metadata["serde_skip"] = skip 541 if skip_if is not None: 542 metadata["serde_skip_if"] = skip_if 543 if skip_if_false is not None: 544 metadata["serde_skip_if_false"] = skip_if_false 545 if skip_if_default is not None: 546 metadata["serde_skip_if_default"] = skip_if_default 547 if serializer: 548 metadata["serde_serializer"] = serializer 549 if deserializer: 550 metadata["serde_deserializer"] = deserializer 551 if flatten is True: 552 metadata["serde_flatten"] = FlattenOpts() 553 elif flatten: 554 metadata["serde_flatten"] = flatten 555 556 return dataclasses.field(*args, metadata=metadata, **kwargs) 557 558 559@dataclass 560class Field(Generic[T]): 561 """ 562 Field class is similar to `dataclasses.Field`. It provides pyserde specific options. 563 564 `type`, `name`, `default` and `default_factory` are the same members as `dataclasses.Field`. 
565 """ 566 567 type: type[T] 568 """ Type of Field """ 569 name: Optional[str] 570 """ Name of Field """ 571 default: Any = field(default_factory=dataclasses._MISSING_TYPE) 572 """ Default value of Field """ 573 default_factory: Any = field(default_factory=dataclasses._MISSING_TYPE) 574 """ Default factory method of Field """ 575 init: bool = field(default_factory=dataclasses._MISSING_TYPE) 576 repr: Any = field(default_factory=dataclasses._MISSING_TYPE) 577 hash: Any = field(default_factory=dataclasses._MISSING_TYPE) 578 compare: Any = field(default_factory=dataclasses._MISSING_TYPE) 579 metadata: Mapping[str, Any] = field(default_factory=dict) 580 kw_only: bool = False 581 case: Optional[str] = None 582 alias: list[str] = field(default_factory=list) 583 rename: Optional[str] = None 584 skip: Optional[bool] = None 585 skip_if: Optional[Func] = None 586 skip_if_false: Optional[bool] = None 587 skip_if_default: Optional[bool] = None 588 serializer: Optional[Func] = None # Custom field serializer. 589 deserializer: Optional[Func] = None # Custom field deserializer. 590 flatten: Optional[FlattenOpts] = None 591 parent: Optional[Any] = None 592 type_args: Optional[list[str]] = None 593 594 @classmethod 595 def from_dataclass(cls, f: dataclasses.Field[T], parent: Optional[Any] = None) -> Field[T]: 596 """ 597 Create `Field` object from `dataclasses.Field`. 
598 """ 599 skip_if_false_func: Optional[Func] = None 600 if f.metadata.get("serde_skip_if_false"): 601 skip_if_false_func = Func(skip_if_false, cls.mangle(f, "skip_if_false")) 602 603 skip_if_default_func: Optional[Func] = None 604 if f.metadata.get("serde_skip_if_default"): 605 skip_if_def = functools.partial(skip_if_default, default=f.default) 606 skip_if_default_func = Func(skip_if_def, cls.mangle(f, "skip_if_default")) 607 608 skip_if: Optional[Func] = None 609 if f.metadata.get("serde_skip_if"): 610 func = f.metadata.get("serde_skip_if") 611 if callable(func): 612 skip_if = Func(func, cls.mangle(f, "skip_if")) 613 614 serializer: Optional[Func] = None 615 func = f.metadata.get("serde_serializer") 616 if func: 617 serializer = Func(func, cls.mangle(f, "serializer")) 618 619 deserializer: Optional[Func] = None 620 func = f.metadata.get("serde_deserializer") 621 if func: 622 deserializer = Func(func, cls.mangle(f, "deserializer")) 623 624 flatten = f.metadata.get("serde_flatten") 625 if flatten is True: 626 flatten = FlattenOpts() 627 if flatten and not (dataclasses.is_dataclass(f.type) or is_opt_dataclass(f.type)): 628 raise SerdeError(f"pyserde does not support flatten attribute for {typename(f.type)}") 629 630 kw_only = bool(f.kw_only) if sys.version_info >= (3, 10) else False 631 632 return cls( 633 f.type, # type: ignore 634 f.name, 635 default=f.default, 636 default_factory=f.default_factory, 637 init=f.init, 638 repr=f.repr, 639 hash=f.hash, 640 compare=f.compare, 641 metadata=f.metadata, 642 rename=f.metadata.get("serde_rename"), 643 alias=f.metadata.get("serde_alias", []), 644 skip=f.metadata.get("serde_skip"), 645 skip_if=skip_if or skip_if_false_func or skip_if_default_func, 646 serializer=serializer, 647 deserializer=deserializer, 648 flatten=flatten, 649 parent=parent, 650 kw_only=kw_only, 651 ) 652 653 def to_dataclass(self) -> dataclasses.Field[T]: 654 f = dataclasses.Field( 655 default=self.default, 656 default_factory=self.default_factory, 657 
init=self.init, 658 repr=self.repr, 659 hash=self.hash, 660 compare=self.compare, 661 metadata=self.metadata, 662 kw_only=self.kw_only, 663 ) 664 assert self.name 665 f.name = self.name 666 f.type = self.type 667 return f 668 669 def is_self_referencing(self) -> bool: 670 if self.type is None: 671 return False 672 if self.parent is None: 673 return False 674 return self.type == self.parent # type: ignore 675 676 @staticmethod 677 def mangle(field: dataclasses.Field[Any], name: str) -> str: 678 """ 679 Get mangled name based on field name. 680 """ 681 return f"{field.name}_{name}" 682 683 def conv_name(self, case: Optional[str] = None) -> str: 684 """ 685 Get an actual field name which `rename` and `rename_all` conversions 686 are made. Use `name` property to get a field name before conversion. 687 """ 688 return conv(self, case or self.case) 689 690 def supports_default(self) -> bool: 691 return not getattr(self, "iterbased", False) and ( 692 has_default(self) or has_default_factory(self) 693 ) 694 695 696F = TypeVar("F", bound=Field[Any]) 697 698 699def fields(field_cls: type[F], cls: type[Any], serialize_class_var: bool = False) -> list[F]: 700 """ 701 Iterate fields of the dataclass and returns `serde.core.Field`. 702 """ 703 fields = [field_cls.from_dataclass(f, parent=cls) for f in dataclass_fields(cls)] 704 705 if serialize_class_var: 706 for name, typ in get_type_hints(cls).items(): 707 if is_class_var(typ): 708 fields.append(field_cls(typ, name, default=getattr(cls, name))) 709 710 return fields # type: ignore 711 712 713def conv(f: Field[Any], case: Optional[str] = None) -> str: 714 """ 715 Convert field name. 716 """ 717 name = f.name 718 if case: 719 casef = getattr(casefy, case, None) 720 if not casef: 721 raise SerdeError( 722 f"Unkown case type: {f.case}. Pass the name of case supported by 'casefy' package." 
723 ) 724 name = casef(name) 725 if f.rename: 726 name = f.rename 727 if name is None: 728 raise SerdeError("Field name is None.") 729 return name 730 731 732def union_func_name(prefix: str, union_args: Sequence[Any]) -> str: 733 """ 734 Generate a function name that contains all union types 735 736 * `prefix` prefix to distinguish between serializing and deserializing 737 * `union_args`: type arguments of a Union 738 739 >>> from ipaddress import IPv4Address 740 >>> union_func_name("union_se", [int, list[str], IPv4Address]) 741 'union_se_int_list_str__IPv4Address' 742 """ 743 return re.sub(r"[^A-Za-z0-9]", "_", f"{prefix}_{'_'.join([typename(e) for e in union_args])}") 744 745 746def literal_func_name(literal_args: Sequence[Any]) -> str: 747 """ 748 Generate a function name with all literals and corresponding types specified with Literal[...] 749 750 751 * `literal_args`: arguments of a Literal 752 753 >>> literal_func_name(["r", "w", "a", "x", "r+", "w+", "a+", "x+"]) 754 'literal_de_r_str_w_str_a_str_x_str_r__str_w__str_a__str_x__str' 755 """ 756 return re.sub( 757 r"[^A-Za-z0-9]", 758 "_", 759 f"{LITERAL_DE_PREFIX}_{'_'.join(f'{a}_{typename(type(a))}' for a in literal_args)}", 760 ) 761 762 763@dataclass(unsafe_hash=True) 764class Tagging: 765 """ 766 Controls how union is (de)serialized. 
This is the same concept as in 767 https://serde.rs/enum-representations.html 768 """ 769 770 class Kind(enum.Enum): 771 External = enum.auto() 772 Internal = enum.auto() 773 Adjacent = enum.auto() 774 Untagged = enum.auto() 775 776 tag: Optional[str] = None 777 content: Optional[str] = None 778 kind: Kind = Kind.External 779 780 def is_external(self) -> bool: 781 return self.kind == self.Kind.External 782 783 def is_internal(self) -> bool: 784 return self.kind == self.Kind.Internal 785 786 def is_adjacent(self) -> bool: 787 return self.kind == self.Kind.Adjacent 788 789 def is_untagged(self) -> bool: 790 return self.kind == self.Kind.Untagged 791 792 @classmethod 793 def is_taggable(cls, typ: type[Any]) -> bool: 794 return dataclasses.is_dataclass(typ) 795 796 def check(self) -> None: 797 if self.is_internal() and self.tag is None: 798 raise SerdeError('"tag" must be specified in InternalTagging') 799 if self.is_adjacent() and (self.tag is None or self.content is None): 800 raise SerdeError('"tag" and "content" must be specified in AdjacentTagging') 801 802 def produce_unique_class_name(self) -> str: 803 """ 804 Produce a unique class name for this tagging. The name is used for generated 805 wrapper dataclass and stored in `Cache`. 
806 """ 807 if self.is_internal(): 808 tag = casefy.pascalcase(self.tag) # type: ignore 809 if not tag: 810 raise SerdeError('"tag" must be specified in InternalTagging') 811 return f"Internal{tag}" 812 elif self.is_adjacent(): 813 tag = casefy.pascalcase(self.tag) # type: ignore 814 content = casefy.pascalcase(self.content) # type: ignore 815 if not tag: 816 raise SerdeError('"tag" must be specified in AdjacentTagging') 817 if not content: 818 raise SerdeError('"content" must be specified in AdjacentTagging') 819 return f"Adjacent{tag}{content}" 820 else: 821 return self.kind.name 822 823 def __call__(self, cls: T) -> _WithTagging[T]: 824 return _WithTagging(cls, self) 825 826 827@overload 828def InternalTagging(tag: str) -> Tagging: ... 829 830 831@overload 832def InternalTagging(tag: str, cls: T) -> _WithTagging[T]: ... 833 834 835def InternalTagging(tag: str, cls: Optional[T] = None) -> Union[Tagging, _WithTagging[T]]: 836 tagging = Tagging(tag, kind=Tagging.Kind.Internal) 837 if cls: 838 return tagging(cls) 839 else: 840 return tagging 841 842 843@overload 844def AdjacentTagging(tag: str, content: str) -> Tagging: ... 845 846 847@overload 848def AdjacentTagging(tag: str, content: str, cls: T) -> _WithTagging[T]: ... 849 850 851def AdjacentTagging( 852 tag: str, content: str, cls: Optional[T] = None 853) -> Union[Tagging, _WithTagging[T]]: 854 tagging = Tagging(tag, content, kind=Tagging.Kind.Adjacent) 855 if cls: 856 return tagging(cls) 857 else: 858 return tagging 859 860 861ExternalTagging = Tagging() 862 863Untagged = Tagging(kind=Tagging.Kind.Untagged) 864 865 866DefaultTagging = ExternalTagging 867 868 869def ensure(expr: Any, description: str) -> None: 870 if not expr: 871 raise Exception(description) 872 873 874def should_impl_dataclass(cls: type[Any]) -> bool: 875 """ 876 Test if class doesn't have @dataclass. 877 878 `dataclasses.is_dataclass` returns True even Derived class doesn't actually @dataclass. 879 >>> @dataclasses.dataclass 880 ... 
class Base: 881 ... a: int 882 >>> class Derived(Base): 883 ... b: int 884 >>> dataclasses.is_dataclass(Derived) 885 True 886 887 This function tells the class actually have @dataclass or not. 888 >>> should_impl_dataclass(Base) 889 False 890 >>> should_impl_dataclass(Derived) 891 True 892 """ 893 if not dataclasses.is_dataclass(cls): 894 return True 895 896 # Checking is_dataclass is not enough in such a case that the class is inherited 897 # from another dataclass. To do it correctly, check if all fields in __annotations__ 898 # are present as dataclass fields. 899 annotations = getattr(cls, "__annotations__", {}) 900 if not annotations: 901 return False 902 903 field_names = [field.name for field in dataclass_fields(cls)] 904 for field_name, annotation in annotations.items(): 905 # Omit InitVar field because it doesn't appear in dataclass fields. 906 if is_instance(annotation, dataclasses.InitVar): 907 continue 908 # This field in __annotations__ is not a part of dataclass fields. 909 # This means this class does not implement dataclass directly. 910 if field_name not in field_names: 911 return True 912 913 # If all of the fields in __annotation__ are present as dataclass fields, 914 # the class already implemented dataclass, thus returns False. 915 return False 916 917 918@dataclass 919class TypeCheck: 920 """ 921 Specify type check flavors. 
922 """ 923 924 class Kind(enum.Enum): 925 Disabled = enum.auto() 926 """ No check performed """ 927 928 Coerce = enum.auto() 929 """ Value is coerced into the declared type """ 930 Strict = enum.auto() 931 """ Value are strictly checked against the declared type """ 932 933 kind: Kind 934 935 def is_strict(self) -> bool: 936 return self.kind == self.Kind.Strict 937 938 def is_coerce(self) -> bool: 939 return self.kind == self.Kind.Coerce 940 941 def __call__(self, **kwargs: Any) -> TypeCheck: 942 # TODO 943 return self 944 945 946disabled = TypeCheck(kind=TypeCheck.Kind.Disabled) 947 948coerce = TypeCheck(kind=TypeCheck.Kind.Coerce) 949 950strict = TypeCheck(kind=TypeCheck.Kind.Strict) 951 952 953def coerce_object(cls: str, field: str, typ: type[Any], obj: Any) -> Any: 954 try: 955 return typ(obj) if is_coercible(typ, obj) else obj 956 except Exception as e: 957 raise SerdeError( 958 f"failed to coerce the field {cls}.{field} value {obj} into {typename(typ)}: {e}" 959 ) 960 961 962def is_coercible(typ: type[Any], obj: Any) -> bool: 963 if obj is None: 964 return False 965 return True 966 967 968def has_default(field: Field[Any]) -> bool: 969 """ 970 Test if the field has default value. 971 972 >>> @dataclasses.dataclass 973 ... class C: 974 ... a: int 975 ... d: int = 10 976 >>> has_default(dataclasses.fields(C)[0]) 977 False 978 >>> has_default(dataclasses.fields(C)[1]) 979 True 980 """ 981 return not isinstance(field.default, dataclasses._MISSING_TYPE) 982 983 984def has_default_factory(field: Field[Any]) -> bool: 985 """ 986 Test if the field has default factory. 987 988 >>> @dataclasses.dataclass 989 ... class C: 990 ... a: int 991 ... 
d: dict = dataclasses.field(default_factory=dict) 992 >>> has_default_factory(dataclasses.fields(C)[0]) 993 False 994 >>> has_default_factory(dataclasses.fields(C)[1]) 995 True 996 """ 997 return not isinstance(field.default_factory, dataclasses._MISSING_TYPE) 998 999 1000class ClassSerializer(Protocol): 1001 """ 1002 Interface for custom class serializer. 1003 1004 This protocol is intended to be used for custom class serializer. 1005 1006 >>> from datetime import datetime 1007 >>> from serde import serde 1008 >>> from plum import dispatch 1009 >>> class MySerializer(ClassSerializer): 1010 ... @dispatch 1011 ... def serialize(self, value: datetime) -> str: 1012 ... return value.strftime("%d/%m/%y") 1013 """ 1014 1015 def serialize(self, value: Any) -> Any: 1016 pass 1017 1018 1019class ClassDeserializer(Protocol): 1020 """ 1021 Interface for custom class deserializer. 1022 1023 This protocol is intended to be used for custom class deserializer. 1024 1025 >>> from datetime import datetime 1026 >>> from serde import serde 1027 >>> from plum import dispatch 1028 >>> class MyDeserializer(ClassDeserializer): 1029 ... @dispatch 1030 ... def deserialize(self, cls: type[datetime], value: Any) -> datetime: 1031 ... return datetime.strptime(value, "%d/%m/%y") 1032 """ 1033 1034 def deserialize(self, cls: Any, value: Any) -> Any: 1035 pass 1036 1037 1038GLOBAL_CLASS_SERIALIZER: list[ClassSerializer] = [] 1039 1040GLOBAL_CLASS_DESERIALIZER: list[ClassDeserializer] = [] 1041 1042 1043def add_serializer(serializer: ClassSerializer) -> None: 1044 """ 1045 Register custom global serializer. 1046 """ 1047 GLOBAL_CLASS_SERIALIZER.append(serializer) 1048 1049 1050def add_deserializer(deserializer: ClassDeserializer) -> None: 1051 """ 1052 Register custom global deserializer. 1053 """ 1054 GLOBAL_CLASS_DESERIALIZER.append(deserializer)
@dataclass
class Scope:
    """
    Container to store types and functions used in code generation context.
    """

    cls: type[Any]
    """ The exact class this scope is for
    (needed to distinguish scopes between inherited classes) """

    funcs: dict[str, Callable[..., Any]] = dataclasses.field(default_factory=dict)
    """ Generated serialize and deserialize functions """

    defaults: dict[str, Union[Callable[..., Any], Any]] = dataclasses.field(default_factory=dict)
    """ Default values of the dataclass fields (factories & normal values) """

    code: dict[str, str] = dataclasses.field(default_factory=dict)
    """ Generated source code (only filled when debug is True) """

    union_se_args: dict[str, list[type[Any]]] = dataclasses.field(default_factory=dict)
    """ The union serializing functions need references to their types """

    reuse_instances_default: bool = True
    """ Default values for to_dict & from_dict arguments """

    convert_sets_default: bool = False

    def __repr__(self) -> str:
        """
        Render a human-readable dump of this scope for debugging.

        A centered header with the class name is always emitted. The sections for
        generated code, function references, default values and union type lists
        are only included when the corresponding mapping is non-empty.
        """
        res: list[str] = []

        res.append("==================================================")
        res.append(self._justify(self.cls.__name__))
        res.append("==================================================")
        res.append("")

        if self.code:
            res.append("--------------------------------------------------")
            res.append(self._justify("Functions generated by pyserde"))
            res.append("--------------------------------------------------")
            res.extend(list(self.code.values()))
            res.append("")

        if self.funcs:
            res.append("--------------------------------------------------")
            res.append(self._justify("Function references in scope"))
            res.append("--------------------------------------------------")
            for k, v in self.funcs.items():
                res.append(f"{k}: {v}")
            res.append("")

        if self.defaults:
            res.append("--------------------------------------------------")
            res.append(self._justify("Default values for the dataclass fields"))
            res.append("--------------------------------------------------")
            for k, v in self.defaults.items():
                res.append(f"{k}: {v}")
            res.append("")

        if self.union_se_args:
            res.append("--------------------------------------------------")
            res.append(self._justify("Type list by used for union serialize functions"))
            res.append("--------------------------------------------------")
            for k, lst in self.union_se_args.items():
                res.append(f"{k}: {list(lst)}")
            res.append("")

        return "\n".join(res)

    def _justify(self, s: str, length: int = 50) -> str:
        """
        Center ``s`` within ``length`` columns by left-padding it with spaces.

        Only leading padding is added. When ``s`` is too wide to be padded, it is
        returned unchanged.
        """
        # Honor `length` (it used to be ignored in favor of a hard-coded 50).
        white_spaces = max((length - len(s)) // 2, 0)
        return " " * white_spaces + s
Container that stores the types and functions used in the code generation context.
The exact class this scope is for (needed to distinguish scopes between inherited classes)
Default values of the dataclass fields (factories & normal values)
313def gen( 314 code: str, globals: Optional[dict[str, Any]] = None, locals: Optional[dict[str, Any]] = None 315) -> str: 316 """ 317 A wrapper of builtin `exec` function. 318 """ 319 if SETTINGS["debug"]: 320 # black formatting is only important when debugging 321 try: 322 from black import FileMode, format_str 323 324 code = format_str(code, mode=FileMode(line_length=100)) 325 except Exception: 326 pass 327 exec(code, globals, locals) 328 return code
A wrapper of the builtin `exec` function.
331def add_func(serde_scope: Scope, func_name: str, func_code: str, globals: dict[str, Any]) -> None: 332 """ 333 Generate a function and add it to a Scope's `funcs` dictionary. 334 335 * `serde_scope`: the Scope instance to modify 336 * `func_name`: the name of the function 337 * `func_code`: the source code of the function 338 * `globals`: global variables that should be accessible to the generated function 339 """ 340 341 code = gen(func_code, globals) 342 serde_scope.funcs[func_name] = globals[func_name] 343 344 if SETTINGS["debug"]: 345 serde_scope.code[func_name] = code
Generate a function and add it to a Scope's `funcs` dictionary.
* `serde_scope`: the Scope instance to modify
* `func_name`: the name of the function
* `func_code`: the source code of the function
* `globals`: global variables that should be accessible to the generated function
473@dataclass 474class Func: 475 """ 476 Function wrapper that provides `mangled` optional field. 477 478 pyserde copies every function reference into global scope 479 for code generation. Mangling function names is needed in 480 order to avoid name conflict in the global scope when 481 multiple fields receives `skip_if` attribute. 482 """ 483 484 inner: Callable[..., Any] 485 """ Function to wrap in """ 486 487 mangeld: str = "" 488 """ Mangled function name """ 489 490 def __call__(self, v: Any) -> None: 491 return self.inner(v) # type: ignore 492 493 @property 494 def name(self) -> str: 495 """ 496 Mangled function name 497 """ 498 return self.mangeld
Function wrapper that provides the optional `mangled` field.
pyserde copies every function reference into the global scope for code generation. Mangling function names is needed to avoid name conflicts in the global scope when multiple fields receive the `skip_if` attribute.
@dataclass
class Field(Generic[T]):
    """
    Field class is similar to `dataclasses.Field`. It provides pyserde specific options.

    `type`, `name`, `default` and `default_factory` are the same members as `dataclasses.Field`.
    """

    type: type[T]
    """ Type of Field """
    name: Optional[str]
    """ Name of Field """
    default: Any = field(default_factory=dataclasses._MISSING_TYPE)
    """ Default value of Field """
    default_factory: Any = field(default_factory=dataclasses._MISSING_TYPE)
    """ Default factory method of Field """
    init: bool = field(default_factory=dataclasses._MISSING_TYPE)
    repr: Any = field(default_factory=dataclasses._MISSING_TYPE)
    hash: Any = field(default_factory=dataclasses._MISSING_TYPE)
    compare: Any = field(default_factory=dataclasses._MISSING_TYPE)
    metadata: Mapping[str, Any] = field(default_factory=dict)
    kw_only: bool = False
    case: Optional[str] = None
    alias: list[str] = field(default_factory=list)
    rename: Optional[str] = None
    skip: Optional[bool] = None
    skip_if: Optional[Func] = None
    skip_if_false: Optional[bool] = None
    skip_if_default: Optional[bool] = None
    serializer: Optional[Func] = None  # Custom field serializer.
    deserializer: Optional[Func] = None  # Custom field deserializer.
    flatten: Optional[FlattenOpts] = None
    parent: Optional[Any] = None
    type_args: Optional[list[str]] = None

    @classmethod
    def from_dataclass(cls, f: dataclasses.Field[T], parent: Optional[Any] = None) -> Field[T]:
        """
        Create `Field` object from `dataclasses.Field`.

        pyserde options are read from the ``serde_*`` keys of ``f.metadata``.
        When several skip predicates are configured, the custom ``serde_skip_if``
        wins, then ``serde_skip_if_false``, then ``serde_skip_if_default``.

        Raises SerdeError if ``serde_flatten`` is set on a field whose type is
        neither a dataclass nor an optional dataclass.
        """
        skip_if_false_func: Optional[Func] = None
        if f.metadata.get("serde_skip_if_false"):
            skip_if_false_func = Func(skip_if_false, cls.mangle(f, "skip_if_false"))

        skip_if_default_func: Optional[Func] = None
        if f.metadata.get("serde_skip_if_default"):
            # Bind the field's declared default so the predicate can compare against it.
            skip_if_def = functools.partial(skip_if_default, default=f.default)
            skip_if_default_func = Func(skip_if_def, cls.mangle(f, "skip_if_default"))

        skip_if: Optional[Func] = None
        if f.metadata.get("serde_skip_if"):
            func = f.metadata.get("serde_skip_if")
            if callable(func):
                skip_if = Func(func, cls.mangle(f, "skip_if"))

        serializer: Optional[Func] = None
        func = f.metadata.get("serde_serializer")
        if func:
            serializer = Func(func, cls.mangle(f, "serializer"))

        deserializer: Optional[Func] = None
        func = f.metadata.get("serde_deserializer")
        if func:
            deserializer = Func(func, cls.mangle(f, "deserializer"))

        flatten = f.metadata.get("serde_flatten")
        if flatten is True:
            flatten = FlattenOpts()
        if flatten and not (dataclasses.is_dataclass(f.type) or is_opt_dataclass(f.type)):
            raise SerdeError(f"pyserde does not support flatten attribute for {typename(f.type)}")

        # `dataclasses.Field.kw_only` only exists on Python 3.10+.
        kw_only = bool(f.kw_only) if sys.version_info >= (3, 10) else False

        return cls(
            f.type,  # type: ignore
            f.name,
            default=f.default,
            default_factory=f.default_factory,
            init=f.init,
            repr=f.repr,
            hash=f.hash,
            compare=f.compare,
            metadata=f.metadata,
            rename=f.metadata.get("serde_rename"),
            alias=f.metadata.get("serde_alias", []),
            skip=f.metadata.get("serde_skip"),
            skip_if=skip_if or skip_if_false_func or skip_if_default_func,
            serializer=serializer,
            deserializer=deserializer,
            flatten=flatten,
            parent=parent,
            kw_only=kw_only,
        )

    def to_dataclass(self) -> dataclasses.Field[T]:
        """
        Convert this field back into a `dataclasses.Field`.

        Only keyword arguments accepted by the running Python's `dataclasses.Field`
        are passed. `kw_only` is omitted before Python 3.10, matching the guard
        in `from_dataclass`. `name` and `type` are assigned after construction.
        """
        kwargs: dict[str, Any] = {
            "default": self.default,
            "default_factory": self.default_factory,
            "init": self.init,
            "repr": self.repr,
            "hash": self.hash,
            "compare": self.compare,
            "metadata": self.metadata,
        }
        # Passing `kw_only` unconditionally raised TypeError on Python < 3.10.
        if sys.version_info >= (3, 10):
            kwargs["kw_only"] = self.kw_only
        # NOTE(review): newer Pythons add a `doc` slot to `dataclasses.Field`;
        # None is assumed to be the neutral value there — confirm per version.
        if "doc" in getattr(dataclasses.Field, "__slots__", ()):
            kwargs["doc"] = None
        f = dataclasses.Field(**kwargs)  # type: ignore
        assert self.name
        f.name = self.name
        f.type = self.type
        return f

    def is_self_referencing(self) -> bool:
        """
        Return True when this field's type equals its parent class.
        """
        if self.type is None:
            return False
        if self.parent is None:
            return False
        return self.type == self.parent  # type: ignore

    @staticmethod
    def mangle(field: dataclasses.Field[Any], name: str) -> str:
        """
        Get mangled name based on field name.
        """
        return f"{field.name}_{name}"

    def conv_name(self, case: Optional[str] = None) -> str:
        """
        Get an actual field name which `rename` and `rename_all` conversions
        are made. Use `name` property to get a field name before conversion.
        """
        return conv(self, case or self.case)

    def supports_default(self) -> bool:
        """
        Return True when the field has a default value or default factory.

        Instances with a truthy ``iterbased`` attribute never support defaults.
        """
        return not getattr(self, "iterbased", False) and (
            has_default(self) or has_default_factory(self)
        )
The Field class is similar to `dataclasses.Field`. It provides pyserde-specific options.
`type`, `name`, `default` and `default_factory` are the same members as in `dataclasses.Field`.
@classmethod
def from_dataclass(cls, f: dataclasses.Field[T], parent: Optional[Any] = None) -> Field[T]:
    """
    Build a pyserde `Field` from a standard `dataclasses.Field`.

    pyserde options are taken from the ``serde_*`` keys of ``f.metadata``. When
    several skip predicates are configured, the custom ``serde_skip_if`` wins,
    then ``serde_skip_if_false``, then ``serde_skip_if_default``.

    * `f`: the dataclass field to convert
    * `parent`: the class that owns the field; stored on the result

    Raises SerdeError if ``serde_flatten`` is set on a field whose type is
    neither a dataclass nor an optional dataclass.
    """
    meta = f.metadata

    def _wrap(fn: Any, suffix: str) -> Func:
        # Give each per-field helper a unique, field-specific name.
        return Func(fn, cls.mangle(f, suffix))

    skip_if_false_func = (
        _wrap(skip_if_false, "skip_if_false") if meta.get("serde_skip_if_false") else None
    )
    skip_if_default_func = (
        _wrap(functools.partial(skip_if_default, default=f.default), "skip_if_default")
        if meta.get("serde_skip_if_default")
        else None
    )
    custom_skip = meta.get("serde_skip_if")
    user_skip_if = (
        _wrap(custom_skip, "skip_if") if custom_skip and callable(custom_skip) else None
    )

    ser_fn = meta.get("serde_serializer")
    serializer = _wrap(ser_fn, "serializer") if ser_fn else None
    de_fn = meta.get("serde_deserializer")
    deserializer = _wrap(de_fn, "deserializer") if de_fn else None

    flatten = meta.get("serde_flatten")
    if flatten is True:
        flatten = FlattenOpts()
    if flatten and not (dataclasses.is_dataclass(f.type) or is_opt_dataclass(f.type)):
        raise SerdeError(f"pyserde does not support flatten attribute for {typename(f.type)}")

    kw_only = False
    if sys.version_info >= (3, 10):
        kw_only = bool(f.kw_only)

    return cls(
        f.type,  # type: ignore
        f.name,
        default=f.default,
        default_factory=f.default_factory,
        init=f.init,
        repr=f.repr,
        hash=f.hash,
        compare=f.compare,
        metadata=meta,
        rename=meta.get("serde_rename"),
        alias=meta.get("serde_alias", []),
        skip=meta.get("serde_skip"),
        skip_if=user_skip_if or skip_if_false_func or skip_if_default_func,
        serializer=serializer,
        deserializer=deserializer,
        flatten=flatten,
        parent=parent,
        kw_only=kw_only,
    )
Create a `Field` object from a `dataclasses.Field`.
654 def to_dataclass(self) -> dataclasses.Field[T]: 655 f = dataclasses.Field( 656 default=self.default, 657 default_factory=self.default_factory, 658 init=self.init, 659 repr=self.repr, 660 hash=self.hash, 661 compare=self.compare, 662 metadata=self.metadata, 663 kw_only=self.kw_only, 664 ) 665 assert self.name 666 f.name = self.name 667 f.type = self.type 668 return f
677 @staticmethod 678 def mangle(field: dataclasses.Field[Any], name: str) -> str: 679 """ 680 Get mangled name based on field name. 681 """ 682 return f"{field.name}_{name}"
Get a mangled name based on the field name.
700def fields(field_cls: type[F], cls: type[Any], serialize_class_var: bool = False) -> list[F]: 701 """ 702 Iterate fields of the dataclass and returns `serde.core.Field`. 703 """ 704 fields = [field_cls.from_dataclass(f, parent=cls) for f in dataclass_fields(cls)] 705 706 if serialize_class_var: 707 for name, typ in get_type_hints(cls).items(): 708 if is_class_var(typ): 709 fields.append(field_cls(typ, name, default=getattr(cls, name))) 710 711 return fields # type: ignore
Iterate over the fields of the dataclass and return them as `serde.core.Field` objects.
Flatten options. Currently not used.
714def conv(f: Field[Any], case: Optional[str] = None) -> str: 715 """ 716 Convert field name. 717 """ 718 name = f.name 719 if case: 720 casef = getattr(casefy, case, None) 721 if not casef: 722 raise SerdeError( 723 f"Unkown case type: {f.case}. Pass the name of case supported by 'casefy' package." 724 ) 725 name = casef(name) 726 if f.rename: 727 name = f.rename 728 if name is None: 729 raise SerdeError("Field name is None.") 730 return name
Convert a field name, applying the requested case conversion and then any `rename` override.
733def union_func_name(prefix: str, union_args: Sequence[Any]) -> str: 734 """ 735 Generate a function name that contains all union types 736 737 * `prefix` prefix to distinguish between serializing and deserializing 738 * `union_args`: type arguments of a Union 739 740 >>> from ipaddress import IPv4Address 741 >>> union_func_name("union_se", [int, list[str], IPv4Address]) 742 'union_se_int_list_str__IPv4Address' 743 """ 744 return re.sub(r"[^A-Za-z0-9]", "_", f"{prefix}_{'_'.join([typename(e) for e in union_args])}")
Generate a function name that contains all union types.
* `prefix`: prefix to distinguish between serializing and deserializing
* `union_args`: type arguments of a Union
>>> from ipaddress import IPv4Address
>>> union_func_name("union_se", [int, list[str], IPv4Address])
'union_se_int_list_str__IPv4Address'