serde.numpy
from collections.abc import Callable
from typing import Any, Optional

from serde.compat import get_args, get_origin


def fullname(klass) -> str:
    """
    Return the dotted name of ``klass`` for use in generated source.

    Types whose ``__module__`` is ``"builtins"`` are returned by bare
    qualified name (``"str"`` rather than ``"builtins.str"``); every other
    type is returned as ``"<module>.<qualname>"``.
    """
    module = klass.__module__
    if module == "builtins":
        return klass.__qualname__  # avoid outputs like 'builtins.str'
    return module + "." + klass.__qualname__


def is_numpy_type(typ) -> bool:
    """
    Return True if ``typ`` is a bare NumPy array type, a NumPy scalar type,
    or a type whose origin is ``np.ndarray``.

    Always False when NumPy could not be imported.
    """
    return is_bare_numpy_array(typ) or is_numpy_scalar(typ) or is_numpy_array(typ)


def is_numpy_available() -> bool:
    """Return True if NumPy was importable when this module was loaded."""
    # ``encode_numpy`` is bound to None by the ImportError fallback below.
    return encode_numpy is not None


try:
    import numpy as np
    import numpy.typing as npt

    encode_numpy: Optional[Callable[[Any], Any]]

    def encode_numpy(obj: Any):
        """
        Convert a NumPy object into plain Python data.

        ``np.ndarray`` values become (nested) lists via ``tolist()``,
        ``np.datetime64`` values become ``.item().isoformat()`` strings, and
        any other ``np.generic`` scalar becomes ``.item()``.

        Raises:
            TypeError: if ``obj`` is not a NumPy array or NumPy scalar.
        """
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        # Must be tested before np.generic, since np.datetime64 is a np.generic.
        if isinstance(obj, np.datetime64):
            # NOTE(review): ``.item()`` yields an int for sub-microsecond units
            # and None for NaT, so ``.isoformat()`` would raise AttributeError
            # for those values -- confirm whether they need handling.
            return obj.item().isoformat()
        if isinstance(obj, np.generic):
            return obj.item()
        raise TypeError(f"Object of type {fullname(type(obj))} is not serializable")

    def is_bare_numpy_array(typ) -> bool:
        """
        Test if the type is `np.ndarray` or `npt.NDArray` without type args.

        >>> import numpy as np
        >>> import numpy.typing as npt
        >>> is_bare_numpy_array(npt.NDArray[np.int64])
        False
        >>> is_bare_numpy_array(npt.NDArray)
        True
        >>> is_bare_numpy_array(np.ndarray)
        True
        """
        return typ in (np.ndarray, npt.NDArray)

    def is_numpy_scalar(typ) -> bool:
        """Return True if ``typ`` is a class derived from ``np.generic``."""
        try:
            return issubclass(typ, np.generic)
        except TypeError:
            # ``typ`` is not a class, so issubclass() rejects it.
            return False

    def is_numpy_datetime(typ) -> bool:
        """Return True if ``typ`` is ``np.datetime64`` or a subclass of it."""
        try:
            return issubclass(typ, np.datetime64)
        except TypeError:
            # ``typ`` is not a class, so issubclass() rejects it.
            return False

    def serialize_numpy_scalar(arg) -> str:
        """Return source text calling ``.item()`` on the variable ``arg.varname``."""
        return f"{arg.varname}.item()"

    def deserialize_numpy_scalar(arg: Any) -> str:
        """Return source text constructing ``arg.type`` from ``arg.data``."""
        return f"{fullname(arg.type)}({arg.data})"

    def is_numpy_array(typ) -> bool:
        """Return True if ``typ``, or its origin as reported by ``get_origin``, is ``np.ndarray``."""
        origin = get_origin(typ)
        if origin is not None:
            typ = origin
        return typ is np.ndarray

    def is_numpy_jaxtyping(typ) -> bool:
        """
        Return True if ``typ`` (or its origin) is a strict subclass of ``np.ndarray``.

        NOTE(review): presumably this is how jaxtyping-style annotated array
        types are recognised -- confirm against the jaxtyping integration.
        """
        try:
            origin = get_origin(typ)
            if origin is not None:
                typ = origin
            return typ is not np.ndarray and issubclass(typ, np.ndarray)
        except TypeError:
            # ``typ`` is not a class, so issubclass() rejects it.
            return False

    def serialize_numpy_array(arg) -> str:
        """Return source text calling ``.tolist()`` on the variable ``arg.varname``."""
        return f"{arg.varname}.tolist()"

    def serialize_numpy_datetime(arg) -> str:
        """Return source text calling ``.item().isoformat()`` on ``arg.varname``."""
        return f"{arg.varname}.item().isoformat()"

    def deserialize_numpy_array(arg) -> str:
        """
        Return source text that builds an array from ``arg.data``.

        Bare array types produce ``numpy.array(<data>)``; otherwise a
        ``dtype=`` derived from ``arg[1][0].type`` is passed as well.
        """
        if is_bare_numpy_array(arg.type):
            return f"numpy.array({arg.data})"

        # NOTE(review): presumably ``arg[1][0]`` is the scalar type nested in
        # the dtype type argument (mirroring ``get_args(get_args(typ)[1])[0]``
        # in deserialize_numpy_array_direct) -- confirm against the arg class.
        dtype = fullname(arg[1][0].type)
        return f"numpy.array({arg.data}, dtype={dtype})"

    def deserialize_numpy_jaxtyping_array(arg) -> str:
        """Return source text that builds an array from ``arg.data`` with ``dtype=numpy.<last entry of arg.type.dtypes>``."""
        dtype = f"numpy.{arg.type.dtypes[-1]}"
        return f"numpy.array({arg.data}, dtype={dtype})"

    def deserialize_numpy_array_direct(typ: Any, arg: Any) -> Any:
        """
        Build an ``np.ndarray`` from ``arg`` at runtime (no code generation).

        Bare array types use NumPy's inferred dtype. For a parameterised type
        the dtype is the first type argument of ``typ``'s second type argument
        (for ``npt.NDArray[np.int64]`` that is ``np.int64``).
        """
        if is_bare_numpy_array(typ):
            return np.array(arg)

        dtype = get_args(get_args(typ)[1])[0]
        return np.array(arg, dtype=dtype)

except ImportError:
    # NumPy is not installed: define inert stand-ins under the same names so
    # the rest of serde can call them unconditionally. Every predicate returns
    # False, every source-generating helper returns an empty string, and
    # deserialize_numpy_array_direct returns its input unchanged.
    encode_numpy = None

    def is_bare_numpy_array(typ) -> bool:
        # Required by is_numpy_type(); without this stub it raises NameError
        # whenever NumPy is unavailable.
        return False

    def is_numpy_scalar(typ) -> bool:
        return False

    def is_numpy_datetime(typ) -> bool:
        return False

    def serialize_numpy_scalar(arg) -> str:
        return ""

    def deserialize_numpy_scalar(arg) -> str:
        return ""

    def is_numpy_array(typ) -> bool:
        return False

    def is_numpy_jaxtyping(typ) -> bool:
        return False

    def serialize_numpy_array(arg) -> str:
        return ""

    def serialize_numpy_datetime(arg) -> str:
        return ""

    def deserialize_numpy_array(arg) -> str:
        return ""

    def deserialize_numpy_jaxtyping_array(arg) -> str:
        return ""

    def deserialize_numpy_array_direct(typ: Any, arg: Any) -> Any:
        return arg
def fullname(klass):
def is_numpy_type(typ) -> bool:
def is_numpy_available() -> bool:
def encode_numpy(obj: Any):
def is_bare_numpy_array(typ) -> bool:
def is_bare_numpy_array(typ) -> bool:
    """
    Report whether ``typ`` is an unparameterised NumPy array type.

    Only the bare ``np.ndarray`` class and the bare ``npt.NDArray`` alias
    qualify; subscripted forms such as ``npt.NDArray[np.int64]`` do not.

    >>> import numpy as np
    >>> import numpy.typing as npt
    >>> is_bare_numpy_array(npt.NDArray[np.int64])
    False
    >>> is_bare_numpy_array(npt.NDArray)
    True
    >>> is_bare_numpy_array(np.ndarray)
    True
    """
    bare_forms = (np.ndarray, npt.NDArray)
    # Same test as tuple membership: identity first, then element == typ.
    return any(form is typ or form == typ for form in bare_forms)
Test if the type is `np.ndarray` or `npt.NDArray` without type args.
>>> import numpy as np
>>> import numpy.typing as npt
>>> is_bare_numpy_array(npt.NDArray[np.int64])
False
>>> is_bare_numpy_array(npt.NDArray)
True
>>> is_bare_numpy_array(np.ndarray)
True
def is_numpy_scalar(typ) -> bool:
def is_numpy_datetime(typ) -> bool:
def serialize_numpy_scalar(arg) -> str:
def deserialize_numpy_scalar(arg: Any) -> str:
def is_numpy_array(typ) -> bool:
def is_numpy_jaxtyping(typ) -> bool:
def serialize_numpy_array(arg) -> str:
def serialize_numpy_datetime(arg) -> str:
def deserialize_numpy_array(arg) -> str:
def deserialize_numpy_jaxtyping_array(arg) -> str:
def deserialize_numpy_array_direct(typ: Any, arg: Any) -> Any: