serde.core
pyserde core module.
1""" 2pyserde core module. 3""" 4 5from __future__ import annotations 6import dataclasses 7import enum 8import functools 9import logging 10import sys 11import re 12import casefy 13from dataclasses import dataclass 14from beartype.door import is_bearable 15from collections.abc import Mapping, Sequence, Callable 16from typing import ( 17 overload, 18 TypeVar, 19 Generic, 20 Optional, 21 Any, 22 Protocol, 23 get_type_hints, 24 Union, 25) 26 27from .compat import ( 28 T, 29 SerdeError, 30 dataclass_fields, 31 get_origin, 32 is_bare_dict, 33 is_bare_list, 34 is_bare_set, 35 is_bare_tuple, 36 is_class_var, 37 is_dict, 38 is_generic, 39 is_list, 40 is_literal, 41 is_new_type_primitive, 42 is_any, 43 is_opt, 44 is_set, 45 is_tuple, 46 is_union, 47 is_variable_tuple, 48 type_args, 49 typename, 50 _WithTagging, 51) 52 53__all__ = [ 54 "Scope", 55 "gen", 56 "add_func", 57 "Func", 58 "Field", 59 "fields", 60 "FlattenOpts", 61 "conv", 62 "union_func_name", 63] 64 65logger = logging.getLogger("serde") 66 67 68# name of the serde context key 69SERDE_SCOPE = "__serde__" 70 71# main function keys 72FROM_ITER = "from_iter" 73FROM_DICT = "from_dict" 74TO_ITER = "to_iter" 75TO_DICT = "to_dict" 76TYPE_CHECK = "typecheck" 77 78# prefixes used to distinguish the direction of a union function 79UNION_SE_PREFIX = "union_se" 80UNION_DE_PREFIX = "union_de" 81 82LITERAL_DE_PREFIX = "literal_de" 83 84SETTINGS = {"debug": False} 85 86 87def init(debug: bool = False) -> None: 88 SETTINGS["debug"] = debug 89 90 91@dataclass 92class Cache: 93 """ 94 Cache the generated code for non-dataclass classes. 95 96 for example, a type not bound in a dataclass is passed in from_json 97 98 ``` 99 from_json(Union[Foo, Bar], ...) 100 ``` 101 102 It creates the following wrapper dataclass on the fly, 103 104 ``` 105 @serde 106 @dataclass 107 class Union_Foo_bar: 108 v: Union[Foo, Bar] 109 ``` 110 111 Then store this class in this cache. 
Whenever the same type is passed, 112 the class is retrieved from this cache. So the overhead of the codegen 113 should be only once. 114 """ 115 116 classes: dict[str, type[Any]] = dataclasses.field(default_factory=dict) 117 118 def _get_class(self, cls: type[Any]) -> type[Any]: 119 """ 120 Get a wrapper class from the the cache. If not found, it will generate 121 the class and store it in the cache. 122 """ 123 class_name = f"Wrapper{typename(cls)}" 124 wrapper = self.classes.get(class_name) 125 return wrapper or self._generate_class(cls) 126 127 def _generate_class(self, cls: type[Any]) -> type[Any]: 128 """ 129 Generate a wrapper dataclass then make the it (de)serializable using 130 @serde decorator. 131 """ 132 from . import serde 133 134 class_name = f"Wrapper{typename(cls)}" 135 logger.debug(f"Generating a wrapper class code for {class_name}") 136 137 wrapper = dataclasses.make_dataclass(class_name, [("v", cls)]) 138 139 serde(wrapper) 140 self.classes[class_name] = wrapper 141 142 logger.debug(f"(de)serializing code for {class_name} was generated") 143 return wrapper 144 145 def serialize(self, cls: type[Any], obj: Any, **kwargs: Any) -> Any: 146 """ 147 Serialize the specified type of object into dict or tuple. 148 """ 149 wrapper = self._get_class(cls) 150 scope: Scope = getattr(wrapper, SERDE_SCOPE) 151 data = scope.funcs[TO_DICT](wrapper(obj), **kwargs) 152 153 logging.debug(f"Intermediate value: {data}") 154 155 return data["v"] 156 157 def deserialize(self, cls: type[T], obj: Any) -> T: 158 """ 159 Deserialize from dict or tuple into the specified type. 160 """ 161 wrapper = self._get_class(cls) 162 scope: Scope = getattr(wrapper, SERDE_SCOPE) 163 return scope.funcs[FROM_DICT](data={"v": obj}).v # type: ignore 164 165 def _get_union_class(self, cls: type[Any]) -> Optional[type[Any]]: 166 """ 167 Get a wrapper class from the the cache. If not found, it will generate 168 the class and store it in the cache. 
169 """ 170 union_cls, tagging = _extract_from_with_tagging(cls) 171 class_name = union_func_name( 172 f"{tagging.produce_unique_class_name()}Union", list(type_args(union_cls)) 173 ) 174 wrapper = self.classes.get(class_name) 175 return wrapper or self._generate_union_class(cls) 176 177 def _generate_union_class(self, cls: type[Any]) -> type[Any]: 178 """ 179 Generate a wrapper dataclass then make the it (de)serializable using 180 @serde decorator. 181 """ 182 import serde 183 184 union_cls, tagging = _extract_from_with_tagging(cls) 185 class_name = union_func_name( 186 f"{tagging.produce_unique_class_name()}Union", list(type_args(union_cls)) 187 ) 188 wrapper = dataclasses.make_dataclass(class_name, [("v", union_cls)]) 189 serde.serde(wrapper, tagging=tagging) 190 self.classes[class_name] = wrapper 191 return wrapper 192 193 def serialize_union(self, cls: type[Any], obj: Any) -> Any: 194 """ 195 Serialize the specified Union into dict or tuple. 196 """ 197 union_cls, _ = _extract_from_with_tagging(cls) 198 wrapper = self._get_union_class(cls) 199 scope: Scope = getattr(wrapper, SERDE_SCOPE) 200 func_name = union_func_name(UNION_SE_PREFIX, list(type_args(union_cls))) 201 return scope.funcs[func_name](obj, False, False) 202 203 def deserialize_union(self, cls: type[T], data: Any) -> T: 204 """ 205 Deserialize from dict or tuple into the specified Union. 
206 """ 207 union_cls, _ = _extract_from_with_tagging(cls) 208 wrapper = self._get_union_class(cls) 209 scope: Scope = getattr(wrapper, SERDE_SCOPE) 210 func_name = union_func_name(UNION_DE_PREFIX, list(type_args(union_cls))) 211 return scope.funcs[func_name](cls=union_cls, data=data) # type: ignore 212 213 214def _extract_from_with_tagging(maybe_with_tagging: Any) -> tuple[Any, Tagging]: 215 try: 216 if isinstance(maybe_with_tagging, _WithTagging): 217 union_cls = maybe_with_tagging.inner 218 tagging = maybe_with_tagging.tagging 219 else: 220 raise Exception() 221 except Exception: 222 union_cls = maybe_with_tagging 223 tagging = ExternalTagging 224 225 return (union_cls, tagging) 226 227 228CACHE = Cache() 229""" Global cache variable for non-dataclass classes """ 230 231 232@dataclass 233class Scope: 234 """ 235 Container to store types and functions used in code generation context. 236 """ 237 238 cls: type[Any] 239 """ The exact class this scope is for 240 (needed to distinguish scopes between inherited classes) """ 241 242 funcs: dict[str, Callable[..., Any]] = dataclasses.field(default_factory=dict) 243 """ Generated serialize and deserialize functions """ 244 245 defaults: dict[str, Union[Callable[..., Any], Any]] = dataclasses.field(default_factory=dict) 246 """ Default values of the dataclass fields (factories & normal values) """ 247 248 code: dict[str, str] = dataclasses.field(default_factory=dict) 249 """ Generated source code (only filled when debug is True) """ 250 251 union_se_args: dict[str, list[type[Any]]] = dataclasses.field(default_factory=dict) 252 """ The union serializing functions need references to their types """ 253 254 reuse_instances_default: bool = True 255 """ Default values for to_dict & from_dict arguments """ 256 257 convert_sets_default: bool = False 258 259 def __repr__(self) -> str: 260 res: list[str] = [] 261 262 res.append("==================================================") 263 res.append(self._justify(self.cls.__name__)) 
264 res.append("==================================================") 265 res.append("") 266 267 if self.code: 268 res.append("--------------------------------------------------") 269 res.append(self._justify("Functions generated by pyserde")) 270 res.append("--------------------------------------------------") 271 res.extend(list(self.code.values())) 272 res.append("") 273 274 if self.funcs: 275 res.append("--------------------------------------------------") 276 res.append(self._justify("Function references in scope")) 277 res.append("--------------------------------------------------") 278 for k, v in self.funcs.items(): 279 res.append(f"{k}: {v}") 280 res.append("") 281 282 if self.defaults: 283 res.append("--------------------------------------------------") 284 res.append(self._justify("Default values for the dataclass fields")) 285 res.append("--------------------------------------------------") 286 for k, v in self.defaults.items(): 287 res.append(f"{k}: {v}") 288 res.append("") 289 290 if self.union_se_args: 291 res.append("--------------------------------------------------") 292 res.append(self._justify("Type list by used for union serialize functions")) 293 res.append("--------------------------------------------------") 294 for k, lst in self.union_se_args.items(): 295 res.append(f"{k}: {list(lst)}") 296 res.append("") 297 298 return "\n".join(res) 299 300 def _justify(self, s: str, length: int = 50) -> str: 301 white_spaces = int((50 - len(s)) / 2) 302 return " " * (white_spaces if white_spaces > 0 else 0) + s 303 304 305def raise_unsupported_type(obj: Any) -> None: 306 # needed because we can not render a raise statement everywhere, e.g. as argument 307 raise SerdeError(f"Unsupported type: {typename(type(obj))}") 308 309 310def gen( 311 code: str, globals: Optional[dict[str, Any]] = None, locals: Optional[dict[str, Any]] = None 312) -> str: 313 """ 314 A wrapper of builtin `exec` function. 
315 """ 316 if SETTINGS["debug"]: 317 # black formatting is only important when debugging 318 try: 319 from black import FileMode, format_str 320 321 code = format_str(code, mode=FileMode(line_length=100)) 322 except Exception: 323 pass 324 exec(code, globals, locals) 325 return code 326 327 328def add_func(serde_scope: Scope, func_name: str, func_code: str, globals: dict[str, Any]) -> None: 329 """ 330 Generate a function and add it to a Scope's `funcs` dictionary. 331 332 * `serde_scope`: the Scope instance to modify 333 * `func_name`: the name of the function 334 * `func_code`: the source code of the function 335 * `globals`: global variables that should be accessible to the generated function 336 """ 337 338 code = gen(func_code, globals) 339 serde_scope.funcs[func_name] = globals[func_name] 340 341 if SETTINGS["debug"]: 342 serde_scope.code[func_name] = code 343 344 345def is_instance(obj: Any, typ: Any) -> bool: 346 """ 347 pyserde's own `isinstance` helper. It accepts subscripted generics e.g. `list[int]` and 348 deeply check object against declared type. 
349 """ 350 if dataclasses.is_dataclass(typ): 351 return isinstance(obj, typ) 352 elif is_opt(typ): 353 return is_opt_instance(obj, typ) 354 elif is_union(typ): 355 return is_union_instance(obj, typ) 356 elif is_list(typ): 357 return is_list_instance(obj, typ) 358 elif is_set(typ): 359 return is_set_instance(obj, typ) 360 elif is_tuple(typ): 361 return is_tuple_instance(obj, typ) 362 elif is_dict(typ): 363 return is_dict_instance(obj, typ) 364 elif is_generic(typ): 365 return is_generic_instance(obj, typ) 366 elif is_literal(typ): 367 return True 368 elif is_new_type_primitive(typ): 369 inner = getattr(typ, "__supertype__", None) 370 if type(inner) is type: 371 return isinstance(obj, inner) 372 else: 373 return False 374 elif is_any(typ): 375 return True 376 elif typ is Ellipsis: 377 return True 378 else: 379 return is_bearable(obj, typ) 380 381 382def is_opt_instance(obj: Any, typ: type[Any]) -> bool: 383 if obj is None: 384 return True 385 opt_arg = type_args(typ)[0] 386 return is_instance(obj, opt_arg) 387 388 389def is_union_instance(obj: Any, typ: type[Any]) -> bool: 390 for arg in type_args(typ): 391 if is_instance(obj, arg): 392 return True 393 return False 394 395 396def is_list_instance(obj: Any, typ: type[Any]) -> bool: 397 if not isinstance(obj, list): 398 return False 399 if len(obj) == 0 or is_bare_list(typ): 400 return True 401 list_arg = type_args(typ)[0] 402 # for speed reasons we just check the type of the 1st element 403 return is_instance(obj[0], list_arg) 404 405 406def is_set_instance(obj: Any, typ: type[Any]) -> bool: 407 if not isinstance(obj, (set, frozenset)): 408 return False 409 if len(obj) == 0 or is_bare_set(typ): 410 return True 411 set_arg = type_args(typ)[0] 412 # for speed reasons we just check the type of the 1st element 413 return is_instance(next(iter(obj)), set_arg) 414 415 416def is_tuple_instance(obj: Any, typ: type[Any]) -> bool: 417 args = type_args(typ) 418 419 if not isinstance(obj, tuple): 420 return False 421 422 # empty 
tuple 423 if len(args) == 0 and len(obj) == 0: 424 return True 425 426 # In the form of tuple[T, ...] 427 elif is_variable_tuple(typ): 428 # Get the first type arg. Since tuple[T, ...] is homogeneous tuple, 429 # all the elements should be of this type. 430 arg = type_args(typ)[0] 431 for v in obj: 432 if not is_instance(v, arg): 433 return False 434 return True 435 436 # bare tuple "tuple" is equivalent to tuple[Any, ...] 437 if is_bare_tuple(typ) and isinstance(obj, tuple): 438 return True 439 440 # All the other tuples e.g. tuple[int, str] 441 if len(obj) == len(args): 442 for element, arg in zip(obj, args): 443 if not is_instance(element, arg): 444 return False 445 else: 446 return False 447 448 return True 449 450 451def is_dict_instance(obj: Any, typ: type[Any]) -> bool: 452 if not isinstance(obj, dict): 453 return False 454 if len(obj) == 0 or is_bare_dict(typ): 455 return True 456 ktyp = type_args(typ)[0] 457 vtyp = type_args(typ)[1] 458 for k, v in obj.items(): 459 # for speed reasons we just check the type of the 1st element 460 return is_instance(k, ktyp) and is_instance(v, vtyp) 461 return False 462 463 464def is_generic_instance(obj: Any, typ: type[Any]) -> bool: 465 return is_instance(obj, get_origin(typ)) 466 467 468@dataclass 469class Func: 470 """ 471 Function wrapper that provides `mangled` optional field. 472 473 pyserde copies every function reference into global scope 474 for code generation. Mangling function names is needed in 475 order to avoid name conflict in the global scope when 476 multiple fields receives `skip_if` attribute. 
477 """ 478 479 inner: Callable[..., Any] 480 """ Function to wrap in """ 481 482 mangeld: str = "" 483 """ Mangled function name """ 484 485 def __call__(self, v: Any) -> None: 486 return self.inner(v) # type: ignore 487 488 @property 489 def name(self) -> str: 490 """ 491 Mangled function name 492 """ 493 return self.mangeld 494 495 496def skip_if_false(v: Any) -> Any: 497 return not bool(v) 498 499 500def skip_if_default(v: Any, default: Optional[Any] = None) -> Any: 501 return v == default # Why return type is deduced to be Any? 502 503 504@dataclass 505class FlattenOpts: 506 """ 507 Flatten options. Currently not used. 508 """ 509 510 511def field( 512 *args: Any, 513 rename: Optional[str] = None, 514 alias: Optional[list[str]] = None, 515 skip: Optional[bool] = None, 516 skip_if: Optional[Callable[[Any], Any]] = None, 517 skip_if_false: Optional[bool] = None, 518 skip_if_default: Optional[bool] = None, 519 serializer: Optional[Callable[..., Any]] = None, 520 deserializer: Optional[Callable[..., Any]] = None, 521 flatten: Optional[Union[FlattenOpts, bool]] = None, 522 metadata: Optional[dict[str, Any]] = None, 523 **kwargs: Any, 524) -> Any: 525 """ 526 Declare a field with parameters. 
527 """ 528 if not metadata: 529 metadata = {} 530 531 if rename is not None: 532 metadata["serde_rename"] = rename 533 if alias is not None: 534 metadata["serde_alias"] = alias 535 if skip is not None: 536 metadata["serde_skip"] = skip 537 if skip_if is not None: 538 metadata["serde_skip_if"] = skip_if 539 if skip_if_false is not None: 540 metadata["serde_skip_if_false"] = skip_if_false 541 if skip_if_default is not None: 542 metadata["serde_skip_if_default"] = skip_if_default 543 if serializer: 544 metadata["serde_serializer"] = serializer 545 if deserializer: 546 metadata["serde_deserializer"] = deserializer 547 if flatten is True: 548 metadata["serde_flatten"] = FlattenOpts() 549 elif flatten: 550 metadata["serde_flatten"] = flatten 551 552 return dataclasses.field(*args, metadata=metadata, **kwargs) 553 554 555@dataclass 556class Field(Generic[T]): 557 """ 558 Field class is similar to `dataclasses.Field`. It provides pyserde specific options. 559 560 `type`, `name`, `default` and `default_factory` are the same members as `dataclasses.Field`. 
561 """ 562 563 type: type[T] 564 """ Type of Field """ 565 name: Optional[str] 566 """ Name of Field """ 567 default: Any = field(default_factory=dataclasses._MISSING_TYPE) 568 """ Default value of Field """ 569 default_factory: Any = field(default_factory=dataclasses._MISSING_TYPE) 570 """ Default factory method of Field """ 571 init: bool = field(default_factory=dataclasses._MISSING_TYPE) 572 repr: Any = field(default_factory=dataclasses._MISSING_TYPE) 573 hash: Any = field(default_factory=dataclasses._MISSING_TYPE) 574 compare: Any = field(default_factory=dataclasses._MISSING_TYPE) 575 metadata: Mapping[str, Any] = field(default_factory=dict) 576 kw_only: bool = False 577 case: Optional[str] = None 578 alias: list[str] = field(default_factory=list) 579 rename: Optional[str] = None 580 skip: Optional[bool] = None 581 skip_if: Optional[Func] = None 582 skip_if_false: Optional[bool] = None 583 skip_if_default: Optional[bool] = None 584 serializer: Optional[Func] = None # Custom field serializer. 585 deserializer: Optional[Func] = None # Custom field deserializer. 586 flatten: Optional[FlattenOpts] = None 587 parent: Optional[Any] = None 588 type_args: Optional[list[str]] = None 589 590 @classmethod 591 def from_dataclass(cls, f: dataclasses.Field[T], parent: Optional[Any] = None) -> Field[T]: 592 """ 593 Create `Field` object from `dataclasses.Field`. 
594 """ 595 skip_if_false_func: Optional[Func] = None 596 if f.metadata.get("serde_skip_if_false"): 597 skip_if_false_func = Func(skip_if_false, cls.mangle(f, "skip_if_false")) 598 599 skip_if_default_func: Optional[Func] = None 600 if f.metadata.get("serde_skip_if_default"): 601 skip_if_def = functools.partial(skip_if_default, default=f.default) 602 skip_if_default_func = Func(skip_if_def, cls.mangle(f, "skip_if_default")) 603 604 skip_if: Optional[Func] = None 605 if f.metadata.get("serde_skip_if"): 606 func = f.metadata.get("serde_skip_if") 607 if callable(func): 608 skip_if = Func(func, cls.mangle(f, "skip_if")) 609 610 serializer: Optional[Func] = None 611 func = f.metadata.get("serde_serializer") 612 if func: 613 serializer = Func(func, cls.mangle(f, "serializer")) 614 615 deserializer: Optional[Func] = None 616 func = f.metadata.get("serde_deserializer") 617 if func: 618 deserializer = Func(func, cls.mangle(f, "deserializer")) 619 620 flatten = f.metadata.get("serde_flatten") 621 if flatten is True: 622 flatten = FlattenOpts() 623 624 kw_only = bool(f.kw_only) if sys.version_info >= (3, 10) else False 625 626 return cls( 627 f.type, 628 f.name, 629 default=f.default, 630 default_factory=f.default_factory, 631 init=f.init, 632 repr=f.repr, 633 hash=f.hash, 634 compare=f.compare, 635 metadata=f.metadata, 636 rename=f.metadata.get("serde_rename"), 637 alias=f.metadata.get("serde_alias", []), 638 skip=f.metadata.get("serde_skip"), 639 skip_if=skip_if or skip_if_false_func or skip_if_default_func, 640 serializer=serializer, 641 deserializer=deserializer, 642 flatten=flatten, 643 parent=parent, 644 kw_only=kw_only, 645 ) 646 647 def to_dataclass(self) -> dataclasses.Field[T]: 648 f = dataclasses.Field( 649 default=self.default, 650 default_factory=self.default_factory, 651 init=self.init, 652 repr=self.repr, 653 hash=self.hash, 654 compare=self.compare, 655 metadata=self.metadata, 656 kw_only=self.kw_only, 657 ) 658 assert self.name 659 f.name = self.name 660 
f.type = self.type 661 return f 662 663 def is_self_referencing(self) -> bool: 664 if self.type is None: 665 return False 666 if self.parent is None: 667 return False 668 return self.type == self.parent # type: ignore 669 670 @staticmethod 671 def mangle(field: dataclasses.Field[Any], name: str) -> str: 672 """ 673 Get mangled name based on field name. 674 """ 675 return f"{field.name}_{name}" 676 677 def conv_name(self, case: Optional[str] = None) -> str: 678 """ 679 Get an actual field name which `rename` and `rename_all` conversions 680 are made. Use `name` property to get a field name before conversion. 681 """ 682 return conv(self, case or self.case) 683 684 def supports_default(self) -> bool: 685 return not getattr(self, "iterbased", False) and ( 686 has_default(self) or has_default_factory(self) 687 ) 688 689 690F = TypeVar("F", bound=Field[Any]) 691 692 693def fields(field_cls: type[F], cls: type[Any], serialize_class_var: bool = False) -> list[F]: 694 """ 695 Iterate fields of the dataclass and returns `serde.core.Field`. 696 """ 697 fields = [field_cls.from_dataclass(f, parent=cls) for f in dataclass_fields(cls)] 698 699 if serialize_class_var: 700 for name, typ in get_type_hints(cls).items(): 701 if is_class_var(typ): 702 fields.append(field_cls(typ, name, default=getattr(cls, name))) 703 704 return fields # type: ignore 705 706 707def conv(f: Field[Any], case: Optional[str] = None) -> str: 708 """ 709 Convert field name. 710 """ 711 name = f.name 712 if case: 713 casef = getattr(casefy, case, None) 714 if not casef: 715 raise SerdeError( 716 f"Unkown case type: {f.case}. Pass the name of case supported by 'casefy' package." 
717 ) 718 name = casef(name) 719 if f.rename: 720 name = f.rename 721 if name is None: 722 raise SerdeError("Field name is None.") 723 return name 724 725 726def union_func_name(prefix: str, union_args: Sequence[Any]) -> str: 727 """ 728 Generate a function name that contains all union types 729 730 * `prefix` prefix to distinguish between serializing and deserializing 731 * `union_args`: type arguments of a Union 732 733 >>> from ipaddress import IPv4Address 734 >>> union_func_name("union_se", [int, list[str], IPv4Address]) 735 'union_se_int_list_str__IPv4Address' 736 """ 737 return re.sub(r"[^A-Za-z0-9]", "_", f"{prefix}_{'_'.join([typename(e) for e in union_args])}") 738 739 740def literal_func_name(literal_args: Sequence[Any]) -> str: 741 """ 742 Generate a function name with all literals and corresponding types specified with Literal[...] 743 744 745 * `literal_args`: arguments of a Literal 746 747 >>> literal_func_name(["r", "w", "a", "x", "r+", "w+", "a+", "x+"]) 748 'literal_de_r_str_w_str_a_str_x_str_r__str_w__str_a__str_x__str' 749 """ 750 return re.sub( 751 r"[^A-Za-z0-9]", 752 "_", 753 f"{LITERAL_DE_PREFIX}_{'_'.join(f'{a}_{typename(type(a))}' for a in literal_args)}", 754 ) 755 756 757@dataclass 758class Tagging: 759 """ 760 Controls how union is (de)serialized. 
This is the same concept as in 761 https://serde.rs/enum-representations.html 762 """ 763 764 class Kind(enum.Enum): 765 External = enum.auto() 766 Internal = enum.auto() 767 Adjacent = enum.auto() 768 Untagged = enum.auto() 769 770 tag: Optional[str] = None 771 content: Optional[str] = None 772 kind: Kind = Kind.External 773 774 def is_external(self) -> bool: 775 return self.kind == self.Kind.External 776 777 def is_internal(self) -> bool: 778 return self.kind == self.Kind.Internal 779 780 def is_adjacent(self) -> bool: 781 return self.kind == self.Kind.Adjacent 782 783 def is_untagged(self) -> bool: 784 return self.kind == self.Kind.Untagged 785 786 @classmethod 787 def is_taggable(cls, typ: type[Any]) -> bool: 788 return dataclasses.is_dataclass(typ) 789 790 def check(self) -> None: 791 if self.is_internal() and self.tag is None: 792 raise SerdeError('"tag" must be specified in InternalTagging') 793 if self.is_adjacent() and (self.tag is None or self.content is None): 794 raise SerdeError('"tag" and "content" must be specified in AdjacentTagging') 795 796 def produce_unique_class_name(self) -> str: 797 """ 798 Produce a unique class name for this tagging. The name is used for generated 799 wrapper dataclass and stored in `Cache`. 
800 """ 801 if self.is_internal(): 802 tag = casefy.pascalcase(self.tag) # type: ignore 803 if not tag: 804 raise SerdeError('"tag" must be specified in InternalTagging') 805 return f"Internal{tag}" 806 elif self.is_adjacent(): 807 tag = casefy.pascalcase(self.tag) # type: ignore 808 content = casefy.pascalcase(self.content) # type: ignore 809 if not tag: 810 raise SerdeError('"tag" must be specified in AdjacentTagging') 811 if not content: 812 raise SerdeError('"content" must be specified in AdjacentTagging') 813 return f"Adjacent{tag}{content}" 814 else: 815 return self.kind.name 816 817 def __call__(self, cls: T) -> _WithTagging[T]: 818 return _WithTagging(cls, self) 819 820 821@overload 822def InternalTagging(tag: str) -> Tagging: ... 823 824 825@overload 826def InternalTagging(tag: str, cls: T) -> _WithTagging[T]: ... 827 828 829def InternalTagging(tag: str, cls: Optional[T] = None) -> Union[Tagging, _WithTagging[T]]: 830 tagging = Tagging(tag, kind=Tagging.Kind.Internal) 831 if cls: 832 return tagging(cls) 833 else: 834 return tagging 835 836 837@overload 838def AdjacentTagging(tag: str, content: str) -> Tagging: ... 839 840 841@overload 842def AdjacentTagging(tag: str, content: str, cls: T) -> _WithTagging[T]: ... 843 844 845def AdjacentTagging( 846 tag: str, content: str, cls: Optional[T] = None 847) -> Union[Tagging, _WithTagging[T]]: 848 tagging = Tagging(tag, content, kind=Tagging.Kind.Adjacent) 849 if cls: 850 return tagging(cls) 851 else: 852 return tagging 853 854 855ExternalTagging = Tagging() 856 857Untagged = Tagging(kind=Tagging.Kind.Untagged) 858 859 860DefaultTagging = ExternalTagging 861 862 863def ensure(expr: Any, description: str) -> None: 864 if not expr: 865 raise Exception(description) 866 867 868def should_impl_dataclass(cls: type[Any]) -> bool: 869 """ 870 Test if class doesn't have @dataclass. 871 872 `dataclasses.is_dataclass` returns True even Derived class doesn't actually @dataclass. 873 >>> @dataclasses.dataclass 874 ... 
class Base: 875 ... a: int 876 >>> class Derived(Base): 877 ... b: int 878 >>> dataclasses.is_dataclass(Derived) 879 True 880 881 This function tells the class actually have @dataclass or not. 882 >>> should_impl_dataclass(Base) 883 False 884 >>> should_impl_dataclass(Derived) 885 True 886 """ 887 if not dataclasses.is_dataclass(cls): 888 return True 889 890 annotations = getattr(cls, "__annotations__", {}) 891 if not annotations: 892 return False 893 894 field_names = [field.name for field in dataclass_fields(cls)] 895 for field_name in annotations: 896 if field_name not in field_names: 897 return True 898 899 return False 900 901 902@dataclass 903class TypeCheck: 904 """ 905 Specify type check flavors. 906 """ 907 908 class Kind(enum.Enum): 909 Disabled = enum.auto() 910 """ No check performed """ 911 912 Coerce = enum.auto() 913 """ Value is coerced into the declared type """ 914 Strict = enum.auto() 915 """ Value are strictly checked against the declared type """ 916 917 kind: Kind 918 919 def is_strict(self) -> bool: 920 return self.kind == self.Kind.Strict 921 922 def is_coerce(self) -> bool: 923 return self.kind == self.Kind.Coerce 924 925 def __call__(self, **kwargs: Any) -> TypeCheck: 926 # TODO 927 return self 928 929 930disabled = TypeCheck(kind=TypeCheck.Kind.Disabled) 931 932coerce = TypeCheck(kind=TypeCheck.Kind.Coerce) 933 934strict = TypeCheck(kind=TypeCheck.Kind.Strict) 935 936 937def coerce_object(typ: type[Any], obj: Any) -> Any: 938 return typ(obj) if is_coercible(typ, obj) else obj 939 940 941def is_coercible(typ: type[Any], obj: Any) -> bool: 942 if obj is None: 943 return False 944 return True 945 946 947def has_default(field: Field[Any]) -> bool: 948 """ 949 Test if the field has default value. 950 951 >>> @dataclasses.dataclass 952 ... class C: 953 ... a: int 954 ... 
d: int = 10 955 >>> has_default(dataclasses.fields(C)[0]) 956 False 957 >>> has_default(dataclasses.fields(C)[1]) 958 True 959 """ 960 return not isinstance(field.default, dataclasses._MISSING_TYPE) 961 962 963def has_default_factory(field: Field[Any]) -> bool: 964 """ 965 Test if the field has default factory. 966 967 >>> @dataclasses.dataclass 968 ... class C: 969 ... a: int 970 ... d: dict = dataclasses.field(default_factory=dict) 971 >>> has_default_factory(dataclasses.fields(C)[0]) 972 False 973 >>> has_default_factory(dataclasses.fields(C)[1]) 974 True 975 """ 976 return not isinstance(field.default_factory, dataclasses._MISSING_TYPE) 977 978 979class ClassSerializer(Protocol): 980 """ 981 Interface for custom class serializer. 982 983 This protocol is intended to be used for custom class serializer. 984 985 >>> from datetime import datetime 986 >>> from serde import serde 987 >>> from plum import dispatch 988 >>> class MySerializer(ClassSerializer): 989 ... @dispatch 990 ... def serialize(self, value: datetime) -> str: 991 ... return value.strftime("%d/%m/%y") 992 """ 993 994 def serialize(self, value: Any) -> Any: 995 pass 996 997 998class ClassDeserializer(Protocol): 999 """ 1000 Interface for custom class deserializer. 1001 1002 This protocol is intended to be used for custom class deserializer. 1003 1004 >>> from datetime import datetime 1005 >>> from serde import serde 1006 >>> from plum import dispatch 1007 >>> class MyDeserializer(ClassDeserializer): 1008 ... @dispatch 1009 ... def deserialize(self, cls: type[datetime], value: Any) -> datetime: 1010 ... return datetime.strptime(value, "%d/%m/%y") 1011 """ 1012 1013 def deserialize(self, cls: Any, value: Any) -> Any: 1014 pass 1015 1016 1017GLOBAL_CLASS_SERIALIZER: list[ClassSerializer] = [] 1018 1019GLOBAL_CLASS_DESERIALIZER: list[ClassDeserializer] = [] 1020 1021 1022def add_serializer(serializer: ClassSerializer) -> None: 1023 """ 1024 Register custom global serializer. 
1025 """ 1026 GLOBAL_CLASS_SERIALIZER.append(serializer) 1027 1028 1029def add_deserializer(deserializer: ClassDeserializer) -> None: 1030 """ 1031 Register custom global deserializer. 1032 """ 1033 GLOBAL_CLASS_DESERIALIZER.append(deserializer)
@dataclass
class Scope:
    """
    Container to store types and functions used in code generation context.
    """

    cls: type[Any]
    """ The exact class this scope is for
    (needed to distinguish scopes between inherited classes) """

    funcs: dict[str, Callable[..., Any]] = dataclasses.field(default_factory=dict)
    """ Generated serialize and deserialize functions """

    defaults: dict[str, Union[Callable[..., Any], Any]] = dataclasses.field(default_factory=dict)
    """ Default values of the dataclass fields (factories & normal values) """

    code: dict[str, str] = dataclasses.field(default_factory=dict)
    """ Generated source code (only filled when debug is True) """

    union_se_args: dict[str, list[type[Any]]] = dataclasses.field(default_factory=dict)
    """ The union serializing functions need references to their types """

    reuse_instances_default: bool = True
    """ Default values for to_dict & from_dict arguments """

    convert_sets_default: bool = False

    def __repr__(self) -> str:
        """
        Render a human-readable dump of the scope for debugging.

        The class name header is always present. The generated code, function
        references, defaults and union type lists are emitted only when the
        corresponding dict is non-empty.
        """
        res: list[str] = []

        res.append("==================================================")
        res.append(self._justify(self.cls.__name__))
        res.append("==================================================")
        res.append("")

        if self.code:
            res.append("--------------------------------------------------")
            res.append(self._justify("Functions generated by pyserde"))
            res.append("--------------------------------------------------")
            res.extend(self.code.values())
            res.append("")

        if self.funcs:
            res.append("--------------------------------------------------")
            res.append(self._justify("Function references in scope"))
            res.append("--------------------------------------------------")
            for k, v in self.funcs.items():
                res.append(f"{k}: {v}")
            res.append("")

        if self.defaults:
            res.append("--------------------------------------------------")
            res.append(self._justify("Default values for the dataclass fields"))
            res.append("--------------------------------------------------")
            for k, v in self.defaults.items():
                res.append(f"{k}: {v}")
            res.append("")

        if self.union_se_args:
            res.append("--------------------------------------------------")
            res.append(self._justify("Type list by used for union serialize functions"))
            res.append("--------------------------------------------------")
            for k, lst in self.union_se_args.items():
                res.append(f"{k}: {list(lst)}")
            res.append("")

        return "\n".join(res)

    def _justify(self, s: str, length: int = 50) -> str:
        """
        Center `s` within `length` columns by left-padding it with spaces.

        No trailing padding is added. Strings at least `length` long are
        returned unchanged.
        """
        white_spaces = (length - len(s)) // 2
        return " " * max(white_spaces, 0) + s
Container to store types and functions used in code generation context.
The exact class this scope is for (needed to distinguish scopes between inherited classes)
Generated serialize and deserialize functions
Default values of the dataclass fields (factories & normal values)
def gen(
    code: str, globals: Optional[dict[str, Any]] = None, locals: Optional[dict[str, Any]] = None
) -> str:
    """
    A wrapper of builtin `exec` function.

    When `SETTINGS["debug"]` is enabled, the code is first formatted with
    `black` (if it is importable) so the stored source is readable. The
    formatting is best effort: any failure keeps the unformatted code.

    * `code`: Python source to execute
    * `globals`, `locals`: namespaces passed unchanged to `exec`

    Returns the (possibly formatted) source that was executed.
    """
    if SETTINGS["debug"]:
        # black formatting is only important when debugging
        try:
            from black import FileMode, format_str

            code = format_str(code, mode=FileMode(line_length=100))
        except Exception as e:
            # best effort: black may be missing or may reject the code;
            # keep the unformatted source but leave a trace for debugging
            logger.debug(f"Skipped formatting generated code with black: {e}")
    # `code` is generated internally by pyserde, not taken from user input
    exec(code, globals, locals)
    return code
A wrapper of the builtin `exec` function.
def add_func(serde_scope: Scope, func_name: str, func_code: str, globals: dict[str, Any]) -> None:
    """
    Execute `func_code` and register the function it defines in `serde_scope.funcs`.

    * `serde_scope`: the Scope instance to modify
    * `func_name`: the name of the function defined by `func_code`
    * `func_code`: the source code of the function
    * `globals`: namespace the code is executed in; it must end up holding `func_name`

    In debug mode the executed source is also kept in `serde_scope.code`.
    """
    executed_source = gen(func_code, globals)
    # exec() defined the function inside `globals`; pick it up by name
    generated = globals[func_name]
    serde_scope.funcs[func_name] = generated

    if not SETTINGS["debug"]:
        return
    serde_scope.code[func_name] = executed_source
Generate a function and add it to a Scope's `funcs` dictionary.
* `serde_scope`: the Scope instance to modify
* `func_name`: the name of the function
* `func_code`: the source code of the function
* `globals`: global variables that should be accessible to the generated function
@dataclass
class Func:
    """
    Function wrapper that provides `mangled` optional field.

    pyserde copies every function reference into global scope
    for code generation. Mangling function names is needed in
    order to avoid name conflict in the global scope when
    multiple fields receive `skip_if` attribute.
    """

    inner: Callable[..., Any]
    """ Function to wrap in """

    # NOTE: the field keeps its historical (misspelled) name `mangeld` for
    # backward compatibility; use the `name` property to read it.
    mangeld: str = ""
    """ Mangled function name """

    def __call__(self, v: Any) -> Any:
        """Call the wrapped function with `v` and return its result."""
        return self.inner(v)

    @property
    def name(self) -> str:
        """
        Mangled function name
        """
        return self.mangeld
Function wrapper that provides the optional `mangled` field.
pyserde copies every function reference into the global scope
for code generation. Mangling function names is needed in
order to avoid name conflicts in the global scope when
multiple fields receive the `skip_if` attribute.
@dataclass
class Field(Generic[T]):
    """
    Field class is similar to `dataclasses.Field`. It provides pyserde specific options.

    `type`, `name`, `default` and `default_factory` are the same members as `dataclasses.Field`.
    """

    type: type[T]
    """ Type of Field """
    name: Optional[str]
    """ Name of Field """
    default: Any = field(default_factory=dataclasses._MISSING_TYPE)
    """ Default value of Field """
    default_factory: Any = field(default_factory=dataclasses._MISSING_TYPE)
    """ Default factory method of Field """
    init: bool = field(default_factory=dataclasses._MISSING_TYPE)
    repr: Any = field(default_factory=dataclasses._MISSING_TYPE)
    hash: Any = field(default_factory=dataclasses._MISSING_TYPE)
    compare: Any = field(default_factory=dataclasses._MISSING_TYPE)
    metadata: Mapping[str, Any] = field(default_factory=dict)
    kw_only: bool = False
    case: Optional[str] = None
    alias: list[str] = field(default_factory=list)
    rename: Optional[str] = None
    skip: Optional[bool] = None
    skip_if: Optional[Func] = None
    skip_if_false: Optional[bool] = None
    skip_if_default: Optional[bool] = None
    serializer: Optional[Func] = None  # Custom field serializer.
    deserializer: Optional[Func] = None  # Custom field deserializer.
    flatten: Optional[FlattenOpts] = None
    parent: Optional[Any] = None
    type_args: Optional[list[str]] = None

    @classmethod
    def from_dataclass(cls, f: dataclasses.Field[T], parent: Optional[Any] = None) -> Field[T]:
        """
        Create `Field` object from `dataclasses.Field`.

        pyserde options are read from `f.metadata` (keys prefixed with
        `serde_`). At most one skip predicate is kept, in priority order
        `serde_skip_if` > `serde_skip_if_false` > `serde_skip_if_default`.
        """
        meta = f.metadata

        skip_if_false_func: Optional[Func] = None
        if meta.get("serde_skip_if_false"):
            skip_if_false_func = Func(skip_if_false, cls.mangle(f, "skip_if_false"))

        skip_if_default_func: Optional[Func] = None
        if meta.get("serde_skip_if_default"):
            # bind the field's declared default so the predicate can compare against it
            skip_if_def = functools.partial(skip_if_default, default=f.default)
            skip_if_default_func = Func(skip_if_def, cls.mangle(f, "skip_if_default"))

        skip_if: Optional[Func] = None
        user_skip_if = meta.get("serde_skip_if")
        if user_skip_if and callable(user_skip_if):
            skip_if = Func(user_skip_if, cls.mangle(f, "skip_if"))

        ser_func = meta.get("serde_serializer")
        serializer: Optional[Func] = (
            Func(ser_func, cls.mangle(f, "serializer")) if ser_func else None
        )

        de_func = meta.get("serde_deserializer")
        deserializer: Optional[Func] = (
            Func(de_func, cls.mangle(f, "deserializer")) if de_func else None
        )

        flatten = meta.get("serde_flatten")
        if flatten is True:
            flatten = FlattenOpts()

        # dataclasses.Field has `kw_only` only on Python 3.10+
        kw_only = bool(f.kw_only) if sys.version_info >= (3, 10) else False

        return cls(
            f.type,
            f.name,
            default=f.default,
            default_factory=f.default_factory,
            init=f.init,
            repr=f.repr,
            hash=f.hash,
            compare=f.compare,
            metadata=meta,
            rename=meta.get("serde_rename"),
            alias=meta.get("serde_alias", []),
            skip=meta.get("serde_skip"),
            skip_if=skip_if or skip_if_false_func or skip_if_default_func,
            serializer=serializer,
            deserializer=deserializer,
            flatten=flatten,
            parent=parent,
            kw_only=kw_only,
        )

    def to_dataclass(self) -> dataclasses.Field[T]:
        """
        Convert this `Field` back into a `dataclasses.Field`.

        Only the standard dataclass attributes, `metadata`, `name` and `type`
        are copied.
        """
        field_kwargs: dict[str, Any] = {
            "default": self.default,
            "default_factory": self.default_factory,
            "init": self.init,
            "repr": self.repr,
            "hash": self.hash,
            "compare": self.compare,
            "metadata": self.metadata,
        }
        # dataclasses.Field accepts `kw_only` only on Python 3.10+
        # (the same version check from_dataclass uses when reading it)
        if sys.version_info >= (3, 10):
            field_kwargs["kw_only"] = self.kw_only
        f = dataclasses.Field(**field_kwargs)
        assert self.name
        f.name = self.name
        f.type = self.type
        return f

    def is_self_referencing(self) -> bool:
        """Test if the field's type is the class that owns the field."""
        if self.type is None:
            return False
        if self.parent is None:
            return False
        return self.type == self.parent  # type: ignore

    @staticmethod
    def mangle(field: dataclasses.Field[Any], name: str) -> str:
        """
        Get mangled name based on field name.
        """
        return f"{field.name}_{name}"

    def conv_name(self, case: Optional[str] = None) -> str:
        """
        Get an actual field name which `rename` and `rename_all` conversions
        are made. Use `name` property to get a field name before conversion.
        """
        return conv(self, case or self.case)

    def supports_default(self) -> bool:
        """
        Test if the field has a default value or default factory.

        Always False when the instance has a truthy `iterbased` attribute
        (NOTE(review): presumably set by subclasses used for tuple/iterable
        based (de)serialization — confirm).
        """
        return not getattr(self, "iterbased", False) and (
            has_default(self) or has_default_factory(self)
        )
Field class is similar to `dataclasses.Field`. It provides pyserde-specific options.
`type`, `name`, `default` and `default_factory` are the same members as `dataclasses.Field`.
@classmethod
def from_dataclass(cls, f: dataclasses.Field[T], parent: Optional[Any] = None) -> Field[T]:
    """
    Create `Field` object from `dataclasses.Field`.

    pyserde options are read from `f.metadata` (keys prefixed with
    `serde_`). At most one skip predicate is kept, in priority order
    `serde_skip_if` > `serde_skip_if_false` > `serde_skip_if_default`.
    """
    meta = f.metadata

    skip_if_false_func: Optional[Func] = None
    if meta.get("serde_skip_if_false"):
        skip_if_false_func = Func(skip_if_false, cls.mangle(f, "skip_if_false"))

    skip_if_default_func: Optional[Func] = None
    if meta.get("serde_skip_if_default"):
        # bind the field's declared default so the predicate can compare against it
        skip_if_def = functools.partial(skip_if_default, default=f.default)
        skip_if_default_func = Func(skip_if_def, cls.mangle(f, "skip_if_default"))

    skip_if: Optional[Func] = None
    user_skip_if = meta.get("serde_skip_if")
    if user_skip_if and callable(user_skip_if):
        skip_if = Func(user_skip_if, cls.mangle(f, "skip_if"))

    ser_func = meta.get("serde_serializer")
    serializer: Optional[Func] = (
        Func(ser_func, cls.mangle(f, "serializer")) if ser_func else None
    )

    de_func = meta.get("serde_deserializer")
    deserializer: Optional[Func] = (
        Func(de_func, cls.mangle(f, "deserializer")) if de_func else None
    )

    flatten = meta.get("serde_flatten")
    if flatten is True:
        flatten = FlattenOpts()

    # dataclasses.Field has `kw_only` only on Python 3.10+
    kw_only = bool(f.kw_only) if sys.version_info >= (3, 10) else False

    return cls(
        f.type,
        f.name,
        default=f.default,
        default_factory=f.default_factory,
        init=f.init,
        repr=f.repr,
        hash=f.hash,
        compare=f.compare,
        metadata=meta,
        rename=meta.get("serde_rename"),
        alias=meta.get("serde_alias", []),
        skip=meta.get("serde_skip"),
        skip_if=skip_if or skip_if_false_func or skip_if_default_func,
        serializer=serializer,
        deserializer=deserializer,
        flatten=flatten,
        parent=parent,
        kw_only=kw_only,
    )
Create a `Field` object from a `dataclasses.Field`.
def to_dataclass(self) -> dataclasses.Field[T]:
    """
    Convert this `Field` back into a `dataclasses.Field`.

    Only the standard dataclass attributes, `metadata`, `name` and `type`
    are copied.
    """
    field_kwargs: dict[str, Any] = {
        "default": self.default,
        "default_factory": self.default_factory,
        "init": self.init,
        "repr": self.repr,
        "hash": self.hash,
        "compare": self.compare,
        "metadata": self.metadata,
    }
    # dataclasses.Field accepts `kw_only` only on Python 3.10+
    # (the same version check from_dataclass uses when reading it)
    if sys.version_info >= (3, 10):
        field_kwargs["kw_only"] = self.kw_only
    f = dataclasses.Field(**field_kwargs)
    assert self.name
    f.name = self.name
    f.type = self.type
    return f
@staticmethod
def mangle(field: dataclasses.Field[Any], name: str) -> str:
    """
    Get mangled name based on field name.

    The result is `<field name>_<name>`, e.g. `a_skip_if` for field `a`.
    """
    return "{}_{}".format(field.name, name)
Get mangled name based on field name.
def fields(field_cls: type[F], cls: type[Any], serialize_class_var: bool = False) -> list[F]:
    """
    Iterate fields of the dataclass and returns `serde.core.Field`.

    * `field_cls`: `Field` (sub)class used to build each entry
    * `cls`: the dataclass to inspect
    * `serialize_class_var`: also include `ClassVar` annotations, using the
      class attribute's current value as the default

    Dataclass fields come first, followed by any class variables.
    """
    collected = [field_cls.from_dataclass(dc_field, parent=cls) for dc_field in dataclass_fields(cls)]

    if serialize_class_var:
        collected.extend(
            field_cls(hint, attr_name, default=getattr(cls, attr_name))
            for attr_name, hint in get_type_hints(cls).items()
            if is_class_var(hint)
        )

    return collected  # type: ignore
Iterate over the fields of the dataclass and return them as `Field` objects.
Flatten options. Currently not used.
def conv(f: Field[Any], case: Optional[str] = None) -> str:
    """
    Convert field name.

    * `f`: the field whose name is converted
    * `case`: name of a `casefy` function (e.g. "pascalcase") applied to `f.name`

    `f.rename`, when set, takes precedence over the case conversion.
    Raises `SerdeError` if `case` is not a function of `casefy`, or if the
    field has no name to return.
    """
    casef = None
    if case:
        casef = getattr(casefy, case, None)
        if not casef:
            raise SerdeError(
                f"Unknown case type: {case}. Pass the name of case supported by 'casefy' package."
            )
    # an explicit rename always wins over the case conversion
    if f.rename:
        return f.rename
    name = f.name
    if name is None:
        raise SerdeError("Field name is None.")
    return casef(name) if casef else name
Convert field name.
def union_func_name(prefix: str, union_args: Sequence[Any]) -> str:
    """
    Generate a function name that contains all union types

    * `prefix`: prefix to distinguish between serializing and deserializing
    * `union_args`: type arguments of a Union

    Every character that is not an ASCII letter or digit is replaced with `_`,
    so the result is usable as an identifier.

    >>> from ipaddress import IPv4Address
    >>> union_func_name("union_se", [int, list[str], IPv4Address])
    'union_se_int_list_str__IPv4Address'
    """
    type_names = "_".join(typename(arg) for arg in union_args)
    raw_name = f"{prefix}_{type_names}"
    return re.sub(r"[^A-Za-z0-9]", "_", raw_name)
Generate a function name that contains all union types
* `prefix`: prefix to distinguish between serializing and deserializing
* `union_args`: type arguments of a Union
>>> from ipaddress import IPv4Address
>>> union_func_name("union_se", [int, list[str], IPv4Address])
'union_se_int_list_str__IPv4Address'