serde.core
pyserde core module.
1""" 2pyserde core module. 3""" 4 5from __future__ import annotations 6import dataclasses 7import enum 8import functools 9import logging 10import sys 11import re 12import casefy 13from dataclasses import dataclass 14 15from beartype.door import is_bearable 16from collections.abc import Mapping, Sequence, Callable, Hashable 17from typing import ( 18 overload, 19 TypeVar, 20 Generic, 21 Optional, 22 Any, 23 Protocol, 24 get_type_hints, 25 Union, 26) 27 28from .compat import ( 29 T, 30 SerdeError, 31 dataclass_fields, 32 get_origin, 33 is_bare_dict, 34 is_bare_list, 35 is_bare_set, 36 is_bare_tuple, 37 is_class_var, 38 is_dict, 39 is_generic, 40 is_list, 41 is_literal, 42 is_new_type_primitive, 43 is_any, 44 is_opt, 45 is_opt_dataclass, 46 is_set, 47 is_tuple, 48 is_union, 49 is_variable_tuple, 50 type_args, 51 typename, 52 _WithTagging, 53) 54 55__all__ = [ 56 "Scope", 57 "gen", 58 "add_func", 59 "Func", 60 "Field", 61 "fields", 62 "FlattenOpts", 63 "conv", 64 "union_func_name", 65] 66 67logger = logging.getLogger("serde") 68 69 70# name of the serde context key 71SERDE_SCOPE = "__serde__" 72 73# main function keys 74FROM_ITER = "from_iter" 75FROM_DICT = "from_dict" 76TO_ITER = "to_iter" 77TO_DICT = "to_dict" 78TYPE_CHECK = "typecheck" 79 80# prefixes used to distinguish the direction of a union function 81UNION_SE_PREFIX = "union_se" 82UNION_DE_PREFIX = "union_de" 83 84LITERAL_DE_PREFIX = "literal_de" 85 86SETTINGS = {"debug": False} 87 88 89@dataclass(frozen=True) 90class UnionCacheKey: 91 union: Hashable 92 tagging: Tagging 93 94 95def init(debug: bool = False) -> None: 96 SETTINGS["debug"] = debug 97 98 99@dataclass 100class Cache: 101 """ 102 Cache the generated code for non-dataclass classes. 103 104 for example, a type not bound in a dataclass is passed in from_json 105 106 ``` 107 from_json(Union[Foo, Bar], ...) 
108 ``` 109 110 It creates the following wrapper dataclass on the fly, 111 112 ``` 113 @serde 114 @dataclass 115 class Union_Foo_bar: 116 v: Union[Foo, Bar] 117 ``` 118 119 Then store this class in this cache. Whenever the same type is passed, 120 the class is retrieved from this cache. So the overhead of the codegen 121 should be only once. 122 """ 123 124 classes: dict[Hashable, type[Any]] = dataclasses.field(default_factory=dict) 125 126 def _get_class(self, cls: type[Any]) -> type[Any]: 127 """ 128 Get a wrapper class from the the cache. If not found, it will generate 129 the class and store it in the cache. 130 """ 131 wrapper = self.classes.get(cls) # type: ignore[call-overload] # mypy doesn't recognize type[Any] as Hashable 132 return wrapper or self._generate_class(cls) 133 134 def _generate_class(self, cls: type[Any]) -> type[Any]: 135 """ 136 Generate a wrapper dataclass then make the it (de)serializable using 137 @serde decorator. 138 """ 139 from . import serde 140 141 class_name = f"Wrapper{typename(cls)}" 142 logger.debug(f"Generating a wrapper class code for {class_name}") 143 144 wrapper = dataclasses.make_dataclass(class_name, [("v", cls)]) 145 146 serde(wrapper) 147 self.classes[cls] = wrapper # type: ignore[index] # mypy doesn't recognize type[Any] as Hashable 148 149 logger.debug(f"(de)serializing code for {class_name} was generated") 150 return wrapper 151 152 def serialize(self, cls: type[Any], obj: Any, **kwargs: Any) -> Any: 153 """ 154 Serialize the specified type of object into dict or tuple. 155 """ 156 wrapper = self._get_class(cls) 157 scope: Scope = getattr(wrapper, SERDE_SCOPE) 158 data = scope.funcs[TO_DICT](wrapper(obj), **kwargs) 159 160 logging.debug(f"Intermediate value: {data}") 161 162 return data["v"] 163 164 def deserialize(self, cls: type[T], obj: Any) -> T: 165 """ 166 Deserialize from dict or tuple into the specified type. 
167 """ 168 wrapper = self._get_class(cls) 169 scope: Scope = getattr(wrapper, SERDE_SCOPE) 170 return scope.funcs[FROM_DICT](data={"v": obj}).v # type: ignore 171 172 def _get_union_class(self, cls: type[Any]) -> Optional[type[Any]]: 173 """ 174 Get a wrapper class from the the cache. If not found, it will generate 175 the class and store it in the cache. 176 """ 177 union_cls, tagging = _extract_from_with_tagging(cls) 178 cache_key = UnionCacheKey(union=union_cls, tagging=tagging) 179 wrapper = self.classes.get(cache_key) 180 return wrapper or self._generate_union_class(cls) 181 182 def _generate_union_class(self, cls: type[Any]) -> type[Any]: 183 """ 184 Generate a wrapper dataclass then make the it (de)serializable using 185 @serde decorator. 186 """ 187 import serde 188 189 union_cls, tagging = _extract_from_with_tagging(cls) 190 cache_key = UnionCacheKey(union=union_cls, tagging=tagging) 191 class_name = union_func_name( 192 f"{tagging.produce_unique_class_name()}Union", list(type_args(union_cls)) 193 ) 194 wrapper = dataclasses.make_dataclass(class_name, [("v", union_cls)]) 195 serde.serde(wrapper, tagging=tagging) 196 self.classes[cache_key] = wrapper 197 return wrapper 198 199 def serialize_union(self, cls: type[Any], obj: Any) -> Any: 200 """ 201 Serialize the specified Union into dict or tuple. 202 """ 203 union_cls, _ = _extract_from_with_tagging(cls) 204 wrapper = self._get_union_class(cls) 205 scope: Scope = getattr(wrapper, SERDE_SCOPE) 206 func_name = union_func_name(UNION_SE_PREFIX, list(type_args(union_cls))) 207 return scope.funcs[func_name](obj, False, False) 208 209 def deserialize_union(self, cls: type[T], data: Any) -> T: 210 """ 211 Deserialize from dict or tuple into the specified Union. 
212 """ 213 union_cls, _ = _extract_from_with_tagging(cls) 214 wrapper = self._get_union_class(cls) 215 scope: Scope = getattr(wrapper, SERDE_SCOPE) 216 func_name = union_func_name(UNION_DE_PREFIX, list(type_args(union_cls))) 217 return scope.funcs[func_name](cls=union_cls, data=data) # type: ignore 218 219 220def _extract_from_with_tagging(maybe_with_tagging: Any) -> tuple[Any, Tagging]: 221 if isinstance(maybe_with_tagging, _WithTagging): 222 return maybe_with_tagging.inner, maybe_with_tagging.tagging 223 else: 224 return maybe_with_tagging, ExternalTagging 225 226 227CACHE = Cache() 228""" Global cache variable for non-dataclass classes """ 229 230 231@dataclass 232class Scope: 233 """ 234 Container to store types and functions used in code generation context. 235 """ 236 237 cls: type[Any] 238 """ The exact class this scope is for 239 (needed to distinguish scopes between inherited classes) """ 240 241 funcs: dict[str, Callable[..., Any]] = dataclasses.field(default_factory=dict) 242 """ Generated serialize and deserialize functions """ 243 244 defaults: dict[str, Union[Callable[..., Any], Any]] = dataclasses.field(default_factory=dict) 245 """ Default values of the dataclass fields (factories & normal values) """ 246 247 code: dict[str, str] = dataclasses.field(default_factory=dict) 248 """ Generated source code (only filled when debug is True) """ 249 250 union_se_args: dict[str, list[type[Any]]] = dataclasses.field(default_factory=dict) 251 """ The union serializing functions need references to their types """ 252 253 reuse_instances_default: bool = True 254 """ Default values for to_dict & from_dict arguments """ 255 256 convert_sets_default: bool = False 257 258 def __repr__(self) -> str: 259 res: list[str] = [] 260 261 res.append("==================================================") 262 res.append(self._justify(self.cls.__name__)) 263 res.append("==================================================") 264 res.append("") 265 266 if self.code: 267 
res.append("--------------------------------------------------") 268 res.append(self._justify("Functions generated by pyserde")) 269 res.append("--------------------------------------------------") 270 res.extend(list(self.code.values())) 271 res.append("") 272 273 if self.funcs: 274 res.append("--------------------------------------------------") 275 res.append(self._justify("Function references in scope")) 276 res.append("--------------------------------------------------") 277 for k, v in self.funcs.items(): 278 res.append(f"{k}: {v}") 279 res.append("") 280 281 if self.defaults: 282 res.append("--------------------------------------------------") 283 res.append(self._justify("Default values for the dataclass fields")) 284 res.append("--------------------------------------------------") 285 for k, v in self.defaults.items(): 286 res.append(f"{k}: {v}") 287 res.append("") 288 289 if self.union_se_args: 290 res.append("--------------------------------------------------") 291 res.append(self._justify("Type list by used for union serialize functions")) 292 res.append("--------------------------------------------------") 293 for k, lst in self.union_se_args.items(): 294 res.append(f"{k}: {list(lst)}") 295 res.append("") 296 297 return "\n".join(res) 298 299 def _justify(self, s: str, length: int = 50) -> str: 300 white_spaces = int((50 - len(s)) / 2) 301 return " " * (white_spaces if white_spaces > 0 else 0) + s 302 303 304def raise_unsupported_type(obj: Any) -> None: 305 # needed because we can not render a raise statement everywhere, e.g. as argument 306 raise SerdeError(f"Unsupported type: {typename(type(obj))}") 307 308 309def gen( 310 code: str, globals: Optional[dict[str, Any]] = None, locals: Optional[dict[str, Any]] = None 311) -> str: 312 """ 313 A wrapper of builtin `exec` function. 
314 """ 315 if SETTINGS["debug"]: 316 # black formatting is only important when debugging 317 try: 318 from black import FileMode, format_str 319 320 code = format_str(code, mode=FileMode(line_length=100)) 321 except Exception: 322 pass 323 exec(code, globals, locals) 324 return code 325 326 327def add_func(serde_scope: Scope, func_name: str, func_code: str, globals: dict[str, Any]) -> None: 328 """ 329 Generate a function and add it to a Scope's `funcs` dictionary. 330 331 * `serde_scope`: the Scope instance to modify 332 * `func_name`: the name of the function 333 * `func_code`: the source code of the function 334 * `globals`: global variables that should be accessible to the generated function 335 """ 336 337 code = gen(func_code, globals) 338 serde_scope.funcs[func_name] = globals[func_name] 339 340 if SETTINGS["debug"]: 341 serde_scope.code[func_name] = code 342 343 344def is_instance(obj: Any, typ: Any) -> bool: 345 """ 346 pyserde's own `isinstance` helper. It accepts subscripted generics e.g. `list[int]` and 347 deeply check object against declared type. 
348 """ 349 if dataclasses.is_dataclass(typ): 350 if not isinstance(typ, type): 351 raise SerdeError("expect dataclass class but dataclass instance received") 352 return isinstance(obj, typ) 353 elif is_opt(typ): 354 return is_opt_instance(obj, typ) 355 elif is_union(typ): 356 return is_union_instance(obj, typ) 357 elif is_list(typ): 358 return is_list_instance(obj, typ) 359 elif is_set(typ): 360 return is_set_instance(obj, typ) 361 elif is_tuple(typ): 362 return is_tuple_instance(obj, typ) 363 elif is_dict(typ): 364 return is_dict_instance(obj, typ) 365 elif is_generic(typ): 366 return is_generic_instance(obj, typ) 367 elif is_literal(typ): 368 return True 369 elif is_new_type_primitive(typ): 370 inner = getattr(typ, "__supertype__", None) 371 if type(inner) is type: 372 return isinstance(obj, inner) 373 else: 374 return False 375 elif is_any(typ): 376 return True 377 elif typ is Ellipsis: 378 return True 379 else: 380 return is_bearable(obj, typ) 381 382 383def is_opt_instance(obj: Any, typ: type[Any]) -> bool: 384 if obj is None: 385 return True 386 opt_arg = type_args(typ)[0] 387 return is_instance(obj, opt_arg) 388 389 390def is_union_instance(obj: Any, typ: type[Any]) -> bool: 391 for arg in type_args(typ): 392 if is_instance(obj, arg): 393 return True 394 return False 395 396 397def is_list_instance(obj: Any, typ: type[Any]) -> bool: 398 if not isinstance(obj, list): 399 return False 400 if len(obj) == 0 or is_bare_list(typ): 401 return True 402 list_arg = type_args(typ)[0] 403 # for speed reasons we just check the type of the 1st element 404 return is_instance(obj[0], list_arg) 405 406 407def is_set_instance(obj: Any, typ: type[Any]) -> bool: 408 if not isinstance(obj, (set, frozenset)): 409 return False 410 if len(obj) == 0 or is_bare_set(typ): 411 return True 412 set_arg = type_args(typ)[0] 413 # for speed reasons we just check the type of the 1st element 414 return is_instance(next(iter(obj)), set_arg) 415 416 417def is_tuple_instance(obj: Any, typ: 
type[Any]) -> bool: 418 args = type_args(typ) 419 420 if not isinstance(obj, tuple): 421 return False 422 423 # empty tuple 424 if len(args) == 0 and len(obj) == 0: 425 return True 426 427 # In the form of tuple[T, ...] 428 elif is_variable_tuple(typ): 429 # Get the first type arg. Since tuple[T, ...] is homogeneous tuple, 430 # all the elements should be of this type. 431 arg = type_args(typ)[0] 432 for v in obj: 433 if not is_instance(v, arg): 434 return False 435 return True 436 437 # bare tuple "tuple" is equivalent to tuple[Any, ...] 438 if is_bare_tuple(typ) and isinstance(obj, tuple): 439 return True 440 441 # All the other tuples e.g. tuple[int, str] 442 if len(obj) == len(args): 443 for element, arg in zip(obj, args): 444 if not is_instance(element, arg): 445 return False 446 else: 447 return False 448 449 return True 450 451 452def is_dict_instance(obj: Any, typ: type[Any]) -> bool: 453 if not isinstance(obj, dict): 454 return False 455 if len(obj) == 0 or is_bare_dict(typ): 456 return True 457 ktyp = type_args(typ)[0] 458 vtyp = type_args(typ)[1] 459 for k, v in obj.items(): 460 # for speed reasons we just check the type of the 1st element 461 return is_instance(k, ktyp) and is_instance(v, vtyp) 462 return False 463 464 465def is_generic_instance(obj: Any, typ: type[Any]) -> bool: 466 return is_instance(obj, get_origin(typ)) 467 468 469@dataclass 470class Func: 471 """ 472 Function wrapper that provides `mangled` optional field. 473 474 pyserde copies every function reference into global scope 475 for code generation. Mangling function names is needed in 476 order to avoid name conflict in the global scope when 477 multiple fields receives `skip_if` attribute. 
478 """ 479 480 inner: Callable[..., Any] 481 """ Function to wrap in """ 482 483 mangeld: str = "" 484 """ Mangled function name """ 485 486 def __call__(self, v: Any) -> None: 487 return self.inner(v) # type: ignore 488 489 @property 490 def name(self) -> str: 491 """ 492 Mangled function name 493 """ 494 return self.mangeld 495 496 497def skip_if_false(v: Any) -> Any: 498 return not bool(v) 499 500 501def skip_if_default(v: Any, default: Optional[Any] = None) -> Any: 502 return v == default # Why return type is deduced to be Any? 503 504 505@dataclass 506class FlattenOpts: 507 """ 508 Flatten options. Currently not used. 509 """ 510 511 512def field( 513 *args: Any, 514 rename: Optional[str] = None, 515 alias: Optional[list[str]] = None, 516 skip: Optional[bool] = None, 517 skip_if: Optional[Callable[[Any], Any]] = None, 518 skip_if_false: Optional[bool] = None, 519 skip_if_default: Optional[bool] = None, 520 serializer: Optional[Callable[..., Any]] = None, 521 deserializer: Optional[Callable[..., Any]] = None, 522 flatten: Optional[Union[FlattenOpts, bool]] = None, 523 metadata: Optional[dict[str, Any]] = None, 524 **kwargs: Any, 525) -> Any: 526 """ 527 Declare a field with parameters. 
528 """ 529 if not metadata: 530 metadata = {} 531 532 if rename is not None: 533 metadata["serde_rename"] = rename 534 if alias is not None: 535 metadata["serde_alias"] = alias 536 if skip is not None: 537 metadata["serde_skip"] = skip 538 if skip_if is not None: 539 metadata["serde_skip_if"] = skip_if 540 if skip_if_false is not None: 541 metadata["serde_skip_if_false"] = skip_if_false 542 if skip_if_default is not None: 543 metadata["serde_skip_if_default"] = skip_if_default 544 if serializer: 545 metadata["serde_serializer"] = serializer 546 if deserializer: 547 metadata["serde_deserializer"] = deserializer 548 if flatten is True: 549 metadata["serde_flatten"] = FlattenOpts() 550 elif flatten: 551 metadata["serde_flatten"] = flatten 552 553 return dataclasses.field(*args, metadata=metadata, **kwargs) 554 555 556@dataclass 557class Field(Generic[T]): 558 """ 559 Field class is similar to `dataclasses.Field`. It provides pyserde specific options. 560 561 `type`, `name`, `default` and `default_factory` are the same members as `dataclasses.Field`. 
562 """ 563 564 type: type[T] 565 """ Type of Field """ 566 name: Optional[str] 567 """ Name of Field """ 568 default: Any = field(default_factory=dataclasses._MISSING_TYPE) 569 """ Default value of Field """ 570 default_factory: Any = field(default_factory=dataclasses._MISSING_TYPE) 571 """ Default factory method of Field """ 572 init: bool = field(default_factory=dataclasses._MISSING_TYPE) 573 repr: Any = field(default_factory=dataclasses._MISSING_TYPE) 574 hash: Any = field(default_factory=dataclasses._MISSING_TYPE) 575 compare: Any = field(default_factory=dataclasses._MISSING_TYPE) 576 metadata: Mapping[str, Any] = field(default_factory=dict) 577 kw_only: bool = False 578 case: Optional[str] = None 579 alias: list[str] = field(default_factory=list) 580 rename: Optional[str] = None 581 skip: Optional[bool] = None 582 skip_if: Optional[Func] = None 583 skip_if_false: Optional[bool] = None 584 skip_if_default: Optional[bool] = None 585 serializer: Optional[Func] = None # Custom field serializer. 586 deserializer: Optional[Func] = None # Custom field deserializer. 587 flatten: Optional[FlattenOpts] = None 588 parent: Optional[Any] = None 589 type_args: Optional[list[str]] = None 590 591 @classmethod 592 def from_dataclass(cls, f: dataclasses.Field[T], parent: Optional[Any] = None) -> Field[T]: 593 """ 594 Create `Field` object from `dataclasses.Field`. 
595 """ 596 skip_if_false_func: Optional[Func] = None 597 if f.metadata.get("serde_skip_if_false"): 598 skip_if_false_func = Func(skip_if_false, cls.mangle(f, "skip_if_false")) 599 600 skip_if_default_func: Optional[Func] = None 601 if f.metadata.get("serde_skip_if_default"): 602 skip_if_def = functools.partial(skip_if_default, default=f.default) 603 skip_if_default_func = Func(skip_if_def, cls.mangle(f, "skip_if_default")) 604 605 skip_if: Optional[Func] = None 606 if f.metadata.get("serde_skip_if"): 607 func = f.metadata.get("serde_skip_if") 608 if callable(func): 609 skip_if = Func(func, cls.mangle(f, "skip_if")) 610 611 serializer: Optional[Func] = None 612 func = f.metadata.get("serde_serializer") 613 if func: 614 serializer = Func(func, cls.mangle(f, "serializer")) 615 616 deserializer: Optional[Func] = None 617 func = f.metadata.get("serde_deserializer") 618 if func: 619 deserializer = Func(func, cls.mangle(f, "deserializer")) 620 621 flatten = f.metadata.get("serde_flatten") 622 if flatten is True: 623 flatten = FlattenOpts() 624 if flatten and not (dataclasses.is_dataclass(f.type) or is_opt_dataclass(f.type)): 625 raise SerdeError(f"pyserde does not support flatten attribute for {typename(f.type)}") 626 627 kw_only = bool(f.kw_only) if sys.version_info >= (3, 10) else False 628 629 return cls( 630 f.type, # type: ignore 631 f.name, 632 default=f.default, 633 default_factory=f.default_factory, 634 init=f.init, 635 repr=f.repr, 636 hash=f.hash, 637 compare=f.compare, 638 metadata=f.metadata, 639 rename=f.metadata.get("serde_rename"), 640 alias=f.metadata.get("serde_alias", []), 641 skip=f.metadata.get("serde_skip"), 642 skip_if=skip_if or skip_if_false_func or skip_if_default_func, 643 serializer=serializer, 644 deserializer=deserializer, 645 flatten=flatten, 646 parent=parent, 647 kw_only=kw_only, 648 ) 649 650 def to_dataclass(self) -> dataclasses.Field[T]: 651 f = dataclasses.Field( 652 default=self.default, 653 default_factory=self.default_factory, 654 
init=self.init, 655 repr=self.repr, 656 hash=self.hash, 657 compare=self.compare, 658 metadata=self.metadata, 659 kw_only=self.kw_only, 660 ) 661 assert self.name 662 f.name = self.name 663 f.type = self.type 664 return f 665 666 def is_self_referencing(self) -> bool: 667 if self.type is None: 668 return False 669 if self.parent is None: 670 return False 671 return self.type == self.parent # type: ignore 672 673 @staticmethod 674 def mangle(field: dataclasses.Field[Any], name: str) -> str: 675 """ 676 Get mangled name based on field name. 677 """ 678 return f"{field.name}_{name}" 679 680 def conv_name(self, case: Optional[str] = None) -> str: 681 """ 682 Get an actual field name which `rename` and `rename_all` conversions 683 are made. Use `name` property to get a field name before conversion. 684 """ 685 return conv(self, case or self.case) 686 687 def supports_default(self) -> bool: 688 return not getattr(self, "iterbased", False) and ( 689 has_default(self) or has_default_factory(self) 690 ) 691 692 693F = TypeVar("F", bound=Field[Any]) 694 695 696def fields(field_cls: type[F], cls: type[Any], serialize_class_var: bool = False) -> list[F]: 697 """ 698 Iterate fields of the dataclass and returns `serde.core.Field`. 699 """ 700 fields = [field_cls.from_dataclass(f, parent=cls) for f in dataclass_fields(cls)] 701 702 if serialize_class_var: 703 for name, typ in get_type_hints(cls).items(): 704 if is_class_var(typ): 705 fields.append(field_cls(typ, name, default=getattr(cls, name))) 706 707 return fields # type: ignore 708 709 710def conv(f: Field[Any], case: Optional[str] = None) -> str: 711 """ 712 Convert field name. 713 """ 714 name = f.name 715 if case: 716 casef = getattr(casefy, case, None) 717 if not casef: 718 raise SerdeError( 719 f"Unkown case type: {f.case}. Pass the name of case supported by 'casefy' package." 
720 ) 721 name = casef(name) 722 if f.rename: 723 name = f.rename 724 if name is None: 725 raise SerdeError("Field name is None.") 726 return name 727 728 729def union_func_name(prefix: str, union_args: Sequence[Any]) -> str: 730 """ 731 Generate a function name that contains all union types 732 733 * `prefix` prefix to distinguish between serializing and deserializing 734 * `union_args`: type arguments of a Union 735 736 >>> from ipaddress import IPv4Address 737 >>> union_func_name("union_se", [int, list[str], IPv4Address]) 738 'union_se_int_list_str__IPv4Address' 739 """ 740 return re.sub(r"[^A-Za-z0-9]", "_", f"{prefix}_{'_'.join([typename(e) for e in union_args])}") 741 742 743def literal_func_name(literal_args: Sequence[Any]) -> str: 744 """ 745 Generate a function name with all literals and corresponding types specified with Literal[...] 746 747 748 * `literal_args`: arguments of a Literal 749 750 >>> literal_func_name(["r", "w", "a", "x", "r+", "w+", "a+", "x+"]) 751 'literal_de_r_str_w_str_a_str_x_str_r__str_w__str_a__str_x__str' 752 """ 753 return re.sub( 754 r"[^A-Za-z0-9]", 755 "_", 756 f"{LITERAL_DE_PREFIX}_{'_'.join(f'{a}_{typename(type(a))}' for a in literal_args)}", 757 ) 758 759 760@dataclass(frozen=True) 761class Tagging: 762 """ 763 Controls how union is (de)serialized. 
This is the same concept as in 764 https://serde.rs/enum-representations.html 765 """ 766 767 class Kind(enum.Enum): 768 External = enum.auto() 769 Internal = enum.auto() 770 Adjacent = enum.auto() 771 Untagged = enum.auto() 772 773 tag: Optional[str] = None 774 content: Optional[str] = None 775 kind: Kind = Kind.External 776 777 def is_external(self) -> bool: 778 return self.kind == self.Kind.External 779 780 def is_internal(self) -> bool: 781 return self.kind == self.Kind.Internal 782 783 def is_adjacent(self) -> bool: 784 return self.kind == self.Kind.Adjacent 785 786 def is_untagged(self) -> bool: 787 return self.kind == self.Kind.Untagged 788 789 @classmethod 790 def is_taggable(cls, typ: type[Any]) -> bool: 791 return dataclasses.is_dataclass(typ) 792 793 def check(self) -> None: 794 if self.is_internal() and self.tag is None: 795 raise SerdeError('"tag" must be specified in InternalTagging') 796 if self.is_adjacent() and (self.tag is None or self.content is None): 797 raise SerdeError('"tag" and "content" must be specified in AdjacentTagging') 798 799 def produce_unique_class_name(self) -> str: 800 """ 801 Produce a unique class name for this tagging. The name is used for generated 802 wrapper dataclass and stored in `Cache`. 
803 """ 804 if self.is_internal(): 805 tag = casefy.pascalcase(self.tag) # type: ignore 806 if not tag: 807 raise SerdeError('"tag" must be specified in InternalTagging') 808 return f"Internal{tag}" 809 elif self.is_adjacent(): 810 tag = casefy.pascalcase(self.tag) # type: ignore 811 content = casefy.pascalcase(self.content) # type: ignore 812 if not tag: 813 raise SerdeError('"tag" must be specified in AdjacentTagging') 814 if not content: 815 raise SerdeError('"content" must be specified in AdjacentTagging') 816 return f"Adjacent{tag}{content}" 817 else: 818 return self.kind.name 819 820 def __call__(self, cls: T) -> _WithTagging[T]: 821 return _WithTagging(cls, self) 822 823 824@overload 825def InternalTagging(tag: str) -> Tagging: ... 826 827 828@overload 829def InternalTagging(tag: str, cls: T) -> _WithTagging[T]: ... 830 831 832def InternalTagging(tag: str, cls: Optional[T] = None) -> Union[Tagging, _WithTagging[T]]: 833 tagging = Tagging(tag, kind=Tagging.Kind.Internal) 834 if cls: 835 return tagging(cls) 836 else: 837 return tagging 838 839 840@overload 841def AdjacentTagging(tag: str, content: str) -> Tagging: ... 842 843 844@overload 845def AdjacentTagging(tag: str, content: str, cls: T) -> _WithTagging[T]: ... 846 847 848def AdjacentTagging( 849 tag: str, content: str, cls: Optional[T] = None 850) -> Union[Tagging, _WithTagging[T]]: 851 tagging = Tagging(tag, content, kind=Tagging.Kind.Adjacent) 852 if cls: 853 return tagging(cls) 854 else: 855 return tagging 856 857 858ExternalTagging = Tagging() 859 860Untagged = Tagging(kind=Tagging.Kind.Untagged) 861 862 863DefaultTagging = ExternalTagging 864 865 866def ensure(expr: Any, description: str) -> None: 867 if not expr: 868 raise Exception(description) 869 870 871def should_impl_dataclass(cls: type[Any]) -> bool: 872 """ 873 Test if class doesn't have @dataclass. 874 875 `dataclasses.is_dataclass` returns True even Derived class doesn't actually @dataclass. 876 >>> @dataclasses.dataclass 877 ... 
class Base: 878 ... a: int 879 >>> class Derived(Base): 880 ... b: int 881 >>> dataclasses.is_dataclass(Derived) 882 True 883 884 This function tells the class actually have @dataclass or not. 885 >>> should_impl_dataclass(Base) 886 False 887 >>> should_impl_dataclass(Derived) 888 True 889 """ 890 if not dataclasses.is_dataclass(cls): 891 return True 892 893 # Checking is_dataclass is not enough in such a case that the class is inherited 894 # from another dataclass. To do it correctly, check if all fields in __annotations__ 895 # are present as dataclass fields. 896 annotations = getattr(cls, "__annotations__", {}) 897 if not annotations: 898 return False 899 900 field_names = [field.name for field in dataclass_fields(cls)] 901 for field_name, annotation in annotations.items(): 902 # Omit InitVar field because it doesn't appear in dataclass fields. 903 if is_instance(annotation, dataclasses.InitVar): 904 continue 905 # This field in __annotations__ is not a part of dataclass fields. 906 # This means this class does not implement dataclass directly. 907 if field_name not in field_names: 908 return True 909 910 # If all of the fields in __annotation__ are present as dataclass fields, 911 # the class already implemented dataclass, thus returns False. 912 return False 913 914 915@dataclass 916class TypeCheck: 917 """ 918 Specify type check flavors. 
919 """ 920 921 class Kind(enum.Enum): 922 Disabled = enum.auto() 923 """ No check performed """ 924 925 Coerce = enum.auto() 926 """ Value is coerced into the declared type """ 927 Strict = enum.auto() 928 """ Value are strictly checked against the declared type """ 929 930 kind: Kind 931 932 def is_strict(self) -> bool: 933 return self.kind == self.Kind.Strict 934 935 def is_coerce(self) -> bool: 936 return self.kind == self.Kind.Coerce 937 938 def __call__(self, **kwargs: Any) -> TypeCheck: 939 # TODO 940 return self 941 942 943disabled = TypeCheck(kind=TypeCheck.Kind.Disabled) 944 945coerce = TypeCheck(kind=TypeCheck.Kind.Coerce) 946 947strict = TypeCheck(kind=TypeCheck.Kind.Strict) 948 949 950def coerce_object(cls: str, field: str, typ: type[Any], obj: Any) -> Any: 951 try: 952 return typ(obj) if is_coercible(typ, obj) else obj 953 except Exception as e: 954 raise SerdeError( 955 f"failed to coerce the field {cls}.{field} value {obj} into {typename(typ)}: {e}" 956 ) 957 958 959def is_coercible(typ: type[Any], obj: Any) -> bool: 960 if obj is None: 961 return False 962 return True 963 964 965def has_default(field: Field[Any]) -> bool: 966 """ 967 Test if the field has default value. 968 969 >>> @dataclasses.dataclass 970 ... class C: 971 ... a: int 972 ... d: int = 10 973 >>> has_default(dataclasses.fields(C)[0]) 974 False 975 >>> has_default(dataclasses.fields(C)[1]) 976 True 977 """ 978 return not isinstance(field.default, dataclasses._MISSING_TYPE) 979 980 981def has_default_factory(field: Field[Any]) -> bool: 982 """ 983 Test if the field has default factory. 984 985 >>> @dataclasses.dataclass 986 ... class C: 987 ... a: int 988 ... 
d: dict = dataclasses.field(default_factory=dict) 989 >>> has_default_factory(dataclasses.fields(C)[0]) 990 False 991 >>> has_default_factory(dataclasses.fields(C)[1]) 992 True 993 """ 994 return not isinstance(field.default_factory, dataclasses._MISSING_TYPE) 995 996 997class ClassSerializer(Protocol): 998 """ 999 Interface for custom class serializer. 1000 1001 This protocol is intended to be used for custom class serializer. 1002 1003 >>> from datetime import datetime 1004 >>> from serde import serde 1005 >>> from plum import dispatch 1006 >>> class MySerializer(ClassSerializer): 1007 ... @dispatch 1008 ... def serialize(self, value: datetime) -> str: 1009 ... return value.strftime("%d/%m/%y") 1010 """ 1011 1012 def serialize(self, value: Any) -> Any: 1013 pass 1014 1015 1016class ClassDeserializer(Protocol): 1017 """ 1018 Interface for custom class deserializer. 1019 1020 This protocol is intended to be used for custom class deserializer. 1021 1022 >>> from datetime import datetime 1023 >>> from serde import serde 1024 >>> from plum import dispatch 1025 >>> class MyDeserializer(ClassDeserializer): 1026 ... @dispatch 1027 ... def deserialize(self, cls: type[datetime], value: Any) -> datetime: 1028 ... return datetime.strptime(value, "%d/%m/%y") 1029 """ 1030 1031 def deserialize(self, cls: Any, value: Any) -> Any: 1032 pass 1033 1034 1035GLOBAL_CLASS_SERIALIZER: list[ClassSerializer] = [] 1036 1037GLOBAL_CLASS_DESERIALIZER: list[ClassDeserializer] = [] 1038 1039 1040def add_serializer(serializer: ClassSerializer) -> None: 1041 """ 1042 Register custom global serializer. 1043 """ 1044 GLOBAL_CLASS_SERIALIZER.append(serializer) 1045 1046 1047def add_deserializer(deserializer: ClassDeserializer) -> None: 1048 """ 1049 Register custom global deserializer. 1050 """ 1051 GLOBAL_CLASS_DESERIALIZER.append(deserializer)
@dataclass
class Scope:
    """
    Container to store types and functions used in code generation context.

    `repr()` of a scope renders a human readable report made of a header with
    the class name followed by one section per non-empty dictionary.
    """

    cls: type[Any]
    """ The exact class this scope is for
    (needed to distinguish scopes between inherited classes) """

    funcs: dict[str, Callable[..., Any]] = dataclasses.field(default_factory=dict)
    """ Generated serialize and deserialize functions """

    defaults: dict[str, Union[Callable[..., Any], Any]] = dataclasses.field(default_factory=dict)
    """ Default values of the dataclass fields (factories & normal values) """

    code: dict[str, str] = dataclasses.field(default_factory=dict)
    """ Generated source code (only filled when debug is True) """

    union_se_args: dict[str, list[type[Any]]] = dataclasses.field(default_factory=dict)
    """ The union serializing functions need references to their types """

    reuse_instances_default: bool = True
    """ Default values for to_dict & from_dict arguments """

    convert_sets_default: bool = False

    def __repr__(self) -> str:
        """
        Render the scope as a multi-line report.

        The header is always present; each of `code`, `funcs`, `defaults` and
        `union_se_args` contributes a titled section only when it is non-empty.
        Every section, including the header, is followed by an empty line.
        """
        heavy_rule = "=================================================="
        light_rule = "--------------------------------------------------"

        res: list[str] = [heavy_rule, self._justify(self.cls.__name__), heavy_rule, ""]

        def section(title: str, body: list[str]) -> None:
            # A section is a centered title between two rules, its body lines
            # and a trailing empty line.
            res.extend((light_rule, self._justify(title), light_rule))
            res.extend(body)
            res.append("")

        if self.code:
            section("Functions generated by pyserde", list(self.code.values()))

        if self.funcs:
            section(
                "Function references in scope",
                [f"{k}: {v}" for k, v in self.funcs.items()],
            )

        if self.defaults:
            section(
                "Default values for the dataclass fields",
                [f"{k}: {v}" for k, v in self.defaults.items()],
            )

        if self.union_se_args:
            section(
                "Type list by used for union serialize functions",
                [f"{k}: {list(lst)}" for k, lst in self.union_se_args.items()],
            )

        return "\n".join(res)

    def _justify(self, s: str, length: int = 50) -> str:
        """
        Center `s` in a line of `length` columns by left padding it with spaces.

        No trailing padding is added, and a string that is already as wide as
        `length` (or wider) is returned unchanged.
        """
        # `length` used to be ignored in favor of a hard-coded 50.
        white_spaces = max((length - len(s)) // 2, 0)
        return " " * white_spaces + s
Container to store types and functions used in code generation context.
The exact class this scope is for (needed to distinguish scopes between inherited classes)
Default values of the dataclass fields (factories & normal values)
def gen(
    code: str, globals: Optional[dict[str, Any]] = None, locals: Optional[dict[str, Any]] = None
) -> str:
    """
    Execute generated source with the builtin `exec` and return what was executed.

    * `code`: the source code to execute
    * `globals`: passed to `exec` as the global namespace
    * `locals`: passed to `exec` as the local namespace

    When the `debug` setting is on, the source is first reformatted with `black`
    (line length 100) so that the stored code is readable. The returned string
    is the source that was actually executed.
    """
    source = code
    if SETTINGS["debug"]:
        # Pretty source only matters when somebody is going to read it, and a
        # missing or failing `black` must never prevent code generation.
        try:
            import black

            source = black.format_str(code, mode=black.FileMode(line_length=100))
        except Exception:
            source = code
    exec(source, globals, locals)
    return source
A wrapper of builtin `exec` function.
def add_func(serde_scope: Scope, func_name: str, func_code: str, globals: dict[str, Any]) -> None:
    """
    Compile a function from source and register it in a Scope's `funcs` dictionary.

    * `serde_scope`: the Scope instance to modify
    * `func_name`: the name of the function; `func_code` must define this name
    * `func_code`: the source code of the function
    * `globals`: global variables that should be accessible to the generated function.
      The generated function is also looked up in this dictionary after execution.
    """
    executed_source = gen(func_code, globals)

    # Executing the source bound `func_name` inside `globals`.
    generated = globals[func_name]
    serde_scope.funcs[func_name] = generated

    # The source is only kept around for inspection in debug mode.
    if SETTINGS["debug"]:
        serde_scope.code[func_name] = executed_source
Generate a function and add it to a Scope's `funcs` dictionary.

* `serde_scope`: the Scope instance to modify
* `func_name`: the name of the function
* `func_code`: the source code of the function
* `globals`: global variables that should be accessible to the generated function
@dataclass
class Func:
    """
    Function wrapper that provides `mangled` optional field.

    pyserde copies every function reference into global scope
    for code generation. Mangling function names is needed in
    order to avoid name conflict in the global scope when
    multiple fields receive `skip_if` attribute.
    """

    inner: Callable[..., Any]
    """ Function to wrap in """

    mangeld: str = ""
    """ Mangled function name """

    def __call__(self, v: Any) -> Any:
        """
        Call the wrapped function with `v` and return its result.
        """
        # The result of `inner` is forwarded as is, hence `Any` rather than `None`.
        return self.inner(v)

    @property
    def name(self) -> str:
        """
        Mangled function name
        """
        return self.mangeld
Function wrapper that provides `mangled` optional field.

pyserde copies every function reference into global scope for code generation. Mangling function names is needed in order to avoid name conflict in the global scope when multiple fields receive `skip_if` attribute.
@dataclass
class Field(Generic[T]):
    """
    Field class is similar to `dataclasses.Field`. It provides pyserde specific options.

    `type`, `name`, `default` and `default_factory` are the same members as `dataclasses.Field`.
    """

    type: type[T]
    """ Type of Field """
    name: Optional[str]
    """ Name of Field """
    default: Any = field(default_factory=dataclasses._MISSING_TYPE)
    """ Default value of Field """
    default_factory: Any = field(default_factory=dataclasses._MISSING_TYPE)
    """ Default factory method of Field """
    init: bool = field(default_factory=dataclasses._MISSING_TYPE)
    repr: Any = field(default_factory=dataclasses._MISSING_TYPE)
    hash: Any = field(default_factory=dataclasses._MISSING_TYPE)
    compare: Any = field(default_factory=dataclasses._MISSING_TYPE)
    metadata: Mapping[str, Any] = field(default_factory=dict)
    kw_only: bool = False
    case: Optional[str] = None
    alias: list[str] = field(default_factory=list)
    rename: Optional[str] = None
    skip: Optional[bool] = None
    skip_if: Optional[Func] = None
    skip_if_false: Optional[bool] = None
    skip_if_default: Optional[bool] = None
    serializer: Optional[Func] = None  # Custom field serializer.
    deserializer: Optional[Func] = None  # Custom field deserializer.
    flatten: Optional[FlattenOpts] = None
    parent: Optional[Any] = None
    type_args: Optional[list[str]] = None

    @classmethod
    def from_dataclass(cls, f: dataclasses.Field[T], parent: Optional[Any] = None) -> Field[T]:
        """
        Create `Field` object from `dataclasses.Field`.

        pyserde options are read from the `serde_*` keys of `f.metadata`.
        `parent` is stored on the created field unchanged.

        Raises `SerdeError` when `serde_flatten` is set on a field whose type
        is neither a dataclass nor an optional dataclass.
        """
        skip_if_false_func: Optional[Func] = None
        if f.metadata.get("serde_skip_if_false"):
            skip_if_false_func = Func(skip_if_false, cls.mangle(f, "skip_if_false"))

        skip_if_default_func: Optional[Func] = None
        if f.metadata.get("serde_skip_if_default"):
            skip_if_def = functools.partial(skip_if_default, default=f.default)
            skip_if_default_func = Func(skip_if_def, cls.mangle(f, "skip_if_default"))

        skip_if: Optional[Func] = None
        if f.metadata.get("serde_skip_if"):
            func = f.metadata.get("serde_skip_if")
            if callable(func):
                skip_if = Func(func, cls.mangle(f, "skip_if"))

        serializer: Optional[Func] = None
        func = f.metadata.get("serde_serializer")
        if func:
            serializer = Func(func, cls.mangle(f, "serializer"))

        deserializer: Optional[Func] = None
        func = f.metadata.get("serde_deserializer")
        if func:
            deserializer = Func(func, cls.mangle(f, "deserializer"))

        flatten = f.metadata.get("serde_flatten")
        if flatten is True:
            flatten = FlattenOpts()
        if flatten and not (dataclasses.is_dataclass(f.type) or is_opt_dataclass(f.type)):
            raise SerdeError(f"pyserde does not support flatten attribute for {typename(f.type)}")

        # `dataclasses.Field.kw_only` only exists on Python 3.10 and later.
        kw_only = bool(f.kw_only) if sys.version_info >= (3, 10) else False

        return cls(
            f.type,  # type: ignore
            f.name,
            default=f.default,
            default_factory=f.default_factory,
            init=f.init,
            repr=f.repr,
            hash=f.hash,
            compare=f.compare,
            metadata=f.metadata,
            rename=f.metadata.get("serde_rename"),
            alias=f.metadata.get("serde_alias", []),
            skip=f.metadata.get("serde_skip"),
            # An explicit `skip_if` wins over `skip_if_false`, which wins over
            # `skip_if_default`.
            skip_if=skip_if or skip_if_false_func or skip_if_default_func,
            serializer=serializer,
            deserializer=deserializer,
            flatten=flatten,
            parent=parent,
            kw_only=kw_only,
        )

    def to_dataclass(self) -> dataclasses.Field[T]:
        """
        Create a `dataclasses.Field` carrying the standard members of this field.

        pyserde specific options are not copied, except for whatever is stored
        in `metadata`. The field must have a name.
        """
        kwargs: dict[str, Any] = {
            "default": self.default,
            "default_factory": self.default_factory,
            "init": self.init,
            "repr": self.repr,
            "hash": self.hash,
            "compare": self.compare,
            "metadata": self.metadata,
        }
        # The constructor of `dataclasses.Field` differs between Python versions:
        # `kw_only` was added in 3.10 (same gate as in `from_dataclass`) and
        # `doc` became a required argument in 3.14.
        if sys.version_info >= (3, 10):
            kwargs["kw_only"] = self.kw_only
        if sys.version_info >= (3, 14):
            kwargs["doc"] = None
        f = dataclasses.Field(**kwargs)
        assert self.name
        f.name = self.name
        f.type = self.type
        return f

    def is_self_referencing(self) -> bool:
        """
        Test if the type of this field equals `parent`.
        """
        if self.type is None:
            return False
        if self.parent is None:
            return False
        return self.type == self.parent  # type: ignore

    @staticmethod
    def mangle(field: dataclasses.Field[Any], name: str) -> str:
        """
        Get mangled name based on field name.
        """
        return f"{field.name}_{name}"

    def conv_name(self, case: Optional[str] = None) -> str:
        """
        Get an actual field name which `rename` and `rename_all` conversions
        are made. Use `name` property to get a field name before conversion.
        """
        return conv(self, case or self.case)

    def supports_default(self) -> bool:
        """
        Test if a default can be used for this field: it must not be marked
        `iterbased` and must have either a default value or a default factory.
        """
        return not getattr(self, "iterbased", False) and (
            has_default(self) or has_default_factory(self)
        )
Field class is similar to `dataclasses.Field`. It provides pyserde specific options.

`type`, `name`, `default` and `default_factory` are the same members as `dataclasses.Field`.
@classmethod
def from_dataclass(cls, f: dataclasses.Field[T], parent: Optional[Any] = None) -> Field[T]:
    """
    Build a pyserde `Field` out of a `dataclasses.Field`.

    pyserde options are read from the `serde_*` keys of `f.metadata`; `parent`
    is stored on the created field unchanged.

    Raises `SerdeError` when `serde_flatten` is set on a field whose type is
    neither a dataclass nor an optional dataclass.
    """
    meta = f.metadata

    def wrapped(key: str, suffix: str) -> Optional[Func]:
        # Wrap the object stored under `key`, if truthy, under a mangled name.
        target = meta.get(key)
        return Func(target, cls.mangle(f, suffix)) if target else None

    serializer = wrapped("serde_serializer", "serializer")
    deserializer = wrapped("serde_deserializer", "deserializer")

    # The three ways of asking for a field to be skipped.
    custom_skip: Optional[Func] = None
    predicate = meta.get("serde_skip_if")
    if predicate and callable(predicate):
        custom_skip = Func(predicate, cls.mangle(f, "skip_if"))

    false_skip: Optional[Func] = None
    if meta.get("serde_skip_if_false"):
        false_skip = Func(skip_if_false, cls.mangle(f, "skip_if_false"))

    default_skip: Optional[Func] = None
    if meta.get("serde_skip_if_default"):
        default_skip = Func(
            functools.partial(skip_if_default, default=f.default),
            cls.mangle(f, "skip_if_default"),
        )

    flatten = meta.get("serde_flatten")
    if flatten is True:
        flatten = FlattenOpts()
    if flatten:
        if not (dataclasses.is_dataclass(f.type) or is_opt_dataclass(f.type)):
            raise SerdeError(f"pyserde does not support flatten attribute for {typename(f.type)}")

    # `dataclasses.Field.kw_only` only exists on Python 3.10 and later.
    if sys.version_info >= (3, 10):
        kw_only = bool(f.kw_only)
    else:
        kw_only = False

    return cls(
        f.type,  # type: ignore
        f.name,
        default=f.default,
        default_factory=f.default_factory,
        init=f.init,
        repr=f.repr,
        hash=f.hash,
        compare=f.compare,
        metadata=meta,
        rename=meta.get("serde_rename"),
        alias=meta.get("serde_alias", []),
        skip=meta.get("serde_skip"),
        # Explicit predicate first, then `skip_if_false`, then `skip_if_default`.
        skip_if=custom_skip or false_skip or default_skip,
        serializer=serializer,
        deserializer=deserializer,
        flatten=flatten,
        parent=parent,
        kw_only=kw_only,
    )
Create `Field` object from `dataclasses.Field`.
def to_dataclass(self) -> dataclasses.Field[T]:
    """
    Create a `dataclasses.Field` carrying the standard members of this field.

    pyserde specific options are not copied, except for whatever is stored
    in `metadata`. The field must have a name.
    """
    kwargs: dict[str, Any] = {
        "default": self.default,
        "default_factory": self.default_factory,
        "init": self.init,
        "repr": self.repr,
        "hash": self.hash,
        "compare": self.compare,
        "metadata": self.metadata,
    }
    # The constructor of `dataclasses.Field` differs between Python versions:
    # `kw_only` was added in 3.10 (the same gate `from_dataclass` applies) and
    # `doc` became a required argument in 3.14.
    if sys.version_info >= (3, 10):
        kwargs["kw_only"] = self.kw_only
    if sys.version_info >= (3, 14):
        kwargs["doc"] = None
    f = dataclasses.Field(**kwargs)
    assert self.name
    f.name = self.name
    f.type = self.type
    return f
674 @staticmethod 675 def mangle(field: dataclasses.Field[Any], name: str) -> str: 676 """ 677 Get mangled name based on field name. 678 """ 679 return f"{field.name}_{name}"
Get mangled name based on field name.
def fields(field_cls: type[F], cls: type[Any], serialize_class_var: bool = False) -> list[F]:
    """
    Collect the fields of the dataclass `cls` as `field_cls` objects.

    * `field_cls`: the field class to instantiate, through its `from_dataclass`
    * `cls`: the dataclass to inspect; it is passed as `parent` of every field
    * `serialize_class_var`: when true, class variables found in the type hints
      of `cls` are appended, using the current class attribute as default
    """
    collected = [field_cls.from_dataclass(f, parent=cls) for f in dataclass_fields(cls)]

    if not serialize_class_var:
        return collected  # type: ignore

    collected.extend(
        field_cls(typ, name, default=getattr(cls, name))
        for name, typ in get_type_hints(cls).items()
        if is_class_var(typ)
    )
    return collected  # type: ignore
Iterate over the fields of the dataclass and return them as `serde.core.Field` objects.
Flatten options. Currently not used.
def conv(f: Field[Any], case: Optional[str] = None) -> str:
    """
    Convert field name.

    * `f`: the field whose `name` is converted
    * `case`: name of a conversion function of the `casefy` package,
      e.g. `camelcase`. No case conversion is made when it is empty.

    `f.rename`, when set, takes precedence over the case conversion.

    Raises `SerdeError` when `case` is not a function provided by `casefy`,
    or when the resulting name is None.
    """
    name = f.name
    if case:
        casef = getattr(casefy, case, None)
        if not casef:
            # Report the case that was actually looked up. `f.case` may be
            # unset or different when the case is supplied by the caller.
            raise SerdeError(
                f"Unkown case type: {case}. Pass the name of case supported by 'casefy' package."
            )
        name = casef(name)
    if f.rename:
        name = f.rename
    if name is None:
        raise SerdeError("Field name is None.")
    return name
Convert field name.
def union_func_name(prefix: str, union_args: Sequence[Any]) -> str:
    """
    Build a function name out of a prefix and the type names of a union.

    * `prefix`: prefix to distinguish between serializing and deserializing
    * `union_args`: type arguments of a Union

    Every character that is not an ASCII letter or digit is replaced by an
    underscore.

    >>> from ipaddress import IPv4Address
    >>> union_func_name("union_se", [int, list[str], IPv4Address])
    'union_se_int_list_str__IPv4Address'
    """
    type_names = "_".join(typename(arg) for arg in union_args)
    raw_name = f"{prefix}_{type_names}"
    return re.sub(r"[^A-Za-z0-9]", "_", raw_name)
Generate a function name that contains all union types

* `prefix`: prefix to distinguish between serializing and deserializing
* `union_args`: type arguments of a Union

```
>>> from ipaddress import IPv4Address
>>> union_func_name("union_se", [int, list[str], IPv4Address])
'union_se_int_list_str__IPv4Address'
```