Edit on GitHub

serde.numpy

  1from collections.abc import Callable
  2from typing import Any, Optional
  3
  4from serde.compat import get_args, get_origin
  5
  6
  7def fullname(klass):
  8    module = klass.__module__
  9    if module == "builtins":
 10        return klass.__qualname__  # avoid outputs like 'builtins.str'
 11    return module + "." + klass.__qualname__
 12
 13
 14def is_numpy_type(typ) -> bool:
 15    return is_bare_numpy_array(typ) or is_numpy_scalar(typ) or is_numpy_array(typ)
 16
 17
 18def is_numpy_available() -> bool:
 19    return encode_numpy is not None
 20
 21
 22try:
 23    import numpy as np
 24    import numpy.typing as npt
 25
 26    encode_numpy: Optional[Callable[[Any], Any]]
 27
 28    def encode_numpy(obj: Any):
 29        if isinstance(obj, np.ndarray):
 30            return obj.tolist()
 31        if isinstance(obj, np.datetime64):
 32            return obj.item().isoformat()
 33        if isinstance(obj, np.generic):
 34            return obj.item()
 35        raise TypeError(f"Object of type {fullname(type(obj))} is not serializable")
 36
 37    def is_bare_numpy_array(typ) -> bool:
 38        """
 39        Test if the type is `np.ndarray` or `npt.NDArray` without type args.
 40
 41        >>> import numpy as np
 42        >>> import numpy.typing as npt
 43        >>> is_bare_numpy_array(npt.NDArray[np.int64])
 44        False
 45        >>> is_bare_numpy_array(npt.NDArray)
 46        True
 47        >>> is_bare_numpy_array(np.ndarray)
 48        True
 49        """
 50        return typ in (np.ndarray, npt.NDArray)
 51
 52    def is_numpy_scalar(typ) -> bool:
 53        try:
 54            return issubclass(typ, np.generic)
 55        except TypeError:
 56            return False
 57
 58    def is_numpy_datetime(typ) -> bool:
 59        try:
 60            return issubclass(typ, np.datetime64)
 61        except TypeError:
 62            return False
 63
 64    def serialize_numpy_scalar(arg) -> str:
 65        return f"{arg.varname}.item()"
 66
 67    def deserialize_numpy_scalar(arg: Any) -> str:
 68        return f"{fullname(arg.type)}({arg.data})"
 69
 70    def is_numpy_array(typ) -> bool:
 71        origin = get_origin(typ)
 72        if origin is not None:
 73            typ = origin
 74        return typ is np.ndarray
 75
 76    def is_numpy_jaxtyping(typ) -> bool:
 77        try:
 78            origin = get_origin(typ)
 79            if origin is not None:
 80                typ = origin
 81            return typ is not np.ndarray and issubclass(typ, np.ndarray)
 82        except TypeError:
 83            return False
 84
 85    def serialize_numpy_array(arg) -> str:
 86        return f"{arg.varname}.tolist()"
 87
 88    def serialize_numpy_datetime(arg) -> str:
 89        return f"{arg.varname}.item().isoformat()"
 90
 91    def deserialize_numpy_array(arg) -> str:
 92        if is_bare_numpy_array(arg.type):
 93            return f"numpy.array({arg.data})"
 94
 95        dtype = fullname(arg[1][0].type)
 96        return f"numpy.array({arg.data}, dtype={dtype})"
 97
 98    def deserialize_numpy_jaxtyping_array(arg) -> str:
 99        dtype = f"numpy.{arg.type.dtypes[-1]}"
100        return f"numpy.array({arg.data}, dtype={dtype})"
101
102    def deserialize_numpy_array_direct(typ: Any, arg: Any) -> Any:
103        if is_bare_numpy_array(typ):
104            return np.array(arg)
105
106        dtype = get_args(get_args(typ)[1])[0]
107        return np.array(arg, dtype=dtype)
108
109except ImportError:
110    encode_numpy = None
111
112    def is_numpy_scalar(typ) -> bool:
113        return False
114
115    def is_numpy_datetime(typ) -> bool:
116        return False
117
118    def serialize_numpy_scalar(arg) -> str:
119        return ""
120
121    def deserialize_numpy_scalar(arg):
122        return ""
123
124    def is_numpy_array(typ) -> bool:
125        return False
126
127    def is_numpy_jaxtyping(typ) -> bool:
128        return False
129
130    def serialize_numpy_array(arg) -> str:
131        return ""
132
133    def serialize_numpy_datetime(arg) -> str:
134        return ""
135
136    def deserialize_numpy_array(arg) -> str:
137        return ""
138
139    def deserialize_numpy_jaxtyping_array(arg) -> str:
140        return ""
141
142    def deserialize_numpy_array_direct(typ: Any, arg: Any) -> Any:
143        return arg
def fullname(klass):
    """
    Return the dotted name ``module.QualName`` of ``klass``.

    Builtin types are returned bare (``str`` rather than ``builtins.str``).
    """
    module = klass.__module__
    # Builtins are returned bare to avoid outputs like 'builtins.str'.
    return klass.__qualname__ if module == "builtins" else ".".join((module, klass.__qualname__))
def is_numpy_type(typ) -> bool:
    """
    Return True if ``typ`` is any numpy type this module knows how to handle.

    That is a bare array annotation, a numpy scalar type, or a (possibly
    parameterized) ``np.ndarray`` annotation. The checks short-circuit in that order.
    """
    checks = (is_bare_numpy_array, is_numpy_scalar, is_numpy_array)
    return any(check(typ) for check in checks)
def is_numpy_available() -> bool:
    """
    Report whether numpy could be imported.

    ``encode_numpy`` is bound to ``None`` only when importing numpy fails.
    """
    numpy_missing = encode_numpy is None
    return not numpy_missing
def encode_numpy(obj: Any):
    """
    Convert a numpy object into plain Python data.

    ``np.ndarray`` becomes a (nested) list, ``np.datetime64`` becomes an
    ISO 8601 string, and any other ``np.generic`` scalar becomes the matching
    Python scalar.

    Raises:
        TypeError: if ``obj`` is none of the above.
    """
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.datetime64):
        value = obj.item()
        # ``item()`` returns ``datetime.date``/``datetime.datetime`` for units
        # down to microseconds. For finer units (ns, ps, ...) it returns a plain
        # ``int``, and for NaT it returns ``None``; neither has ``isoformat``,
        # so let numpy format those values.
        if hasattr(value, "isoformat"):
            return value.isoformat()
        return str(np.datetime_as_string(obj))
    if isinstance(obj, np.generic):
        return obj.item()
    raise TypeError(f"Object of type {fullname(type(obj))} is not serializable")
def is_bare_numpy_array(typ) -> bool:
    """
    Return True when ``typ`` is ``np.ndarray`` or ``npt.NDArray`` with no type arguments.

    >>> import numpy as np
    >>> import numpy.typing as npt
    >>> is_bare_numpy_array(npt.NDArray[np.int64])
    False
    >>> is_bare_numpy_array(npt.NDArray)
    True
    >>> is_bare_numpy_array(np.ndarray)
    True
    """
    bare_forms = (np.ndarray, npt.NDArray)
    return typ in bare_forms

Test whether the type is `np.ndarray` or `npt.NDArray` without type arguments.

>>> import numpy as np
>>> import numpy.typing as npt
>>> is_bare_numpy_array(npt.NDArray[np.int64])
False
>>> is_bare_numpy_array(npt.NDArray)
True
>>> is_bare_numpy_array(np.ndarray)
True
def is_numpy_scalar(typ) -> bool:
    """Return True if ``typ`` is a numpy scalar type, i.e. a subclass of ``np.generic``."""
    try:
        result = issubclass(typ, np.generic)
    except TypeError:
        # ``typ`` is not a class (e.g. an instance or a parameterized alias).
        result = False
    return result
def is_numpy_datetime(typ) -> bool:
    """Return True if ``typ`` is ``np.datetime64`` or a subclass of it."""
    try:
        result = issubclass(typ, np.datetime64)
    except TypeError:
        # Non-class inputs cannot be datetime types.
        result = False
    return result
def serialize_numpy_scalar(arg) -> str:
    """Return source code that converts the scalar named ``arg.varname`` to a Python scalar."""
    return "{}.item()".format(arg.varname)
def deserialize_numpy_scalar(arg: Any) -> str:
    """Return source code that calls the scalar type ``arg.type`` (by its dotted name) on ``arg.data``."""
    type_name = fullname(arg.type)
    return "{}({})".format(type_name, arg.data)
def is_numpy_array(typ) -> bool:
    """Return True for ``np.ndarray`` or any parameterization of it (e.g. ``npt.NDArray[np.int64]``)."""
    origin = get_origin(typ)
    # Unwrap parameterized aliases to their runtime class before comparing.
    base = typ if origin is None else origin
    return base is np.ndarray
def is_numpy_jaxtyping(typ) -> bool:
    """
    Return True for strict subclasses of ``np.ndarray``.

    ``np.ndarray`` itself is excluded. The name suggests this targets
    jaxtyping array annotations.
    """
    try:
        origin = get_origin(typ)
        candidate = typ if origin is None else origin
        if candidate is np.ndarray:
            return False
        return issubclass(candidate, np.ndarray)
    except TypeError:
        # Non-class inputs are never array subclasses.
        return False
def serialize_numpy_array(arg) -> str:
    """Return source code that converts the array named ``arg.varname`` to nested Python lists."""
    return "{}.tolist()".format(arg.varname)
def serialize_numpy_datetime(arg) -> str:
    """Return source code that renders the ``datetime64`` named ``arg.varname`` via ``.item().isoformat()``."""
    # NOTE(review): the generated code fails when ``item()`` does not return a
    # date/datetime (sub-microsecond units give an int, NaT gives None).
    return "{}.item().isoformat()".format(arg.varname)
def deserialize_numpy_array(arg) -> str:
    """
    Return source code that rebuilds an array from ``arg.data``.

    Bare array annotations carry no dtype, so numpy infers one. Otherwise the
    dtype is the dotted name of ``arg[1][0].type``, presumably the scalar type
    inside ``np.dtype[...]``; confirm against the field object's ``__getitem__``.
    """
    if not is_bare_numpy_array(arg.type):
        dtype_name = fullname(arg[1][0].type)
        return "numpy.array({}, dtype={})".format(arg.data, dtype_name)
    return "numpy.array({})".format(arg.data)
def deserialize_numpy_jaxtyping_array(arg) -> str:
    """Return source code rebuilding an array with the last entry of ``arg.type.dtypes`` as its numpy dtype."""
    dtype_name = arg.type.dtypes[-1]
    return "numpy.array({}, dtype=numpy.{})".format(arg.data, dtype_name)
def deserialize_numpy_array_direct(typ: Any, arg: Any) -> Any:
    """
    Build an array from ``arg`` at runtime according to annotation ``typ``.

    For a parameterized annotation such as ``npt.NDArray[np.int64]``
    (``np.ndarray[shape, np.dtype[np.int64]]``), the dtype is the scalar type
    inside the second type argument.
    """
    if not is_bare_numpy_array(typ):
        dtype_alias = get_args(typ)[1]
        return np.array(arg, dtype=get_args(dtype_alias)[0])
    return np.array(arg)