Edit on GitHub

serde.numpy

  1from collections.abc import Callable
  2from typing import Any, Optional
  3
  4from serde.compat import get_args, get_origin
  5
  6
  7def fullname(klass):
  8    module = klass.__module__
  9    if module == "builtins":
 10        return klass.__qualname__  # avoid outputs like 'builtins.str'
 11    return module + "." + klass.__qualname__
 12
 13
 14def is_numpy_type(typ) -> bool:
 15    return is_bare_numpy_array(typ) or is_numpy_scalar(typ) or is_numpy_array(typ)
 16
 17
 18def is_numpy_available() -> bool:
 19    return encode_numpy is not None
 20
 21
 22try:
 23    import numpy as np
 24    import numpy.typing as npt
 25
 26    encode_numpy: Optional[Callable[[Any], Any]]
 27
 28    def encode_numpy(obj: Any):
 29        if isinstance(obj, np.ndarray):
 30            return obj.tolist()
 31        if isinstance(obj, np.datetime64):
 32            return obj.item().isoformat()
 33        if isinstance(obj, np.generic):
 34            return obj.item()
 35        raise TypeError(f"Object of type {fullname(type(obj))} is not serializable")
 36
 37    def is_bare_numpy_array(typ) -> bool:
 38        """
 39        Test if the type is `np.ndarray` or `npt.NDArray` without type args.
 40
 41        >>> import numpy as np
 42        >>> import numpy.typing as npt
 43        >>> is_bare_numpy_array(npt.NDArray[np.int64])
 44        False
 45        >>> is_bare_numpy_array(npt.NDArray)
 46        True
 47        >>> is_bare_numpy_array(np.ndarray)
 48        True
 49        """
 50        return typ in (np.ndarray, npt.NDArray)
 51
 52    def is_numpy_scalar(typ) -> bool:
 53        try:
 54            return issubclass(typ, np.generic)
 55        except TypeError:
 56            return False
 57
 58    def is_numpy_datetime(typ) -> bool:
 59        try:
 60            return issubclass(typ, np.datetime64)
 61        except TypeError:
 62            return False
 63
 64    def serialize_numpy_scalar(arg) -> str:
 65        return f"{arg.varname}.item()"
 66
 67    def deserialize_numpy_scalar(arg: Any) -> str:
 68        return f"{fullname(arg.type)}({arg.data})"
 69
 70    def is_numpy_array(typ) -> bool:
 71        origin = get_origin(typ)
 72        if origin is not None:
 73            typ = origin
 74        return typ is np.ndarray
 75
 76    def serialize_numpy_array(arg) -> str:
 77        return f"{arg.varname}.tolist()"
 78
 79    def serialize_numpy_datetime(arg) -> str:
 80        return f"{arg.varname}.item().isoformat()"
 81
 82    def deserialize_numpy_array(arg) -> str:
 83        if is_bare_numpy_array(arg.type):
 84            return f"numpy.array({arg.data})"
 85
 86        dtype = fullname(arg[1][0].type)
 87        return f"numpy.array({arg.data}, dtype={dtype})"
 88
 89    def deserialize_numpy_array_direct(typ: Any, arg: Any) -> Any:
 90        if is_bare_numpy_array(typ):
 91            return np.array(arg)
 92
 93        dtype = get_args(get_args(typ)[1])[0]
 94        return np.array(arg, dtype=dtype)
 95
 96except ImportError:
 97    encode_numpy = None
 98
 99    def is_numpy_scalar(typ) -> bool:
100        return False
101
102    def is_numpy_datetime(typ) -> bool:
103        return False
104
105    def serialize_numpy_scalar(arg) -> str:
106        return ""
107
108    def deserialize_numpy_scalar(arg):
109        return ""
110
111    def is_numpy_array(typ) -> bool:
112        return False
113
114    def serialize_numpy_array(arg) -> str:
115        return ""
116
117    def serialize_numpy_datetime(arg) -> str:
118        return ""
119
120    def deserialize_numpy_array(arg) -> str:
121        return ""
122
123    def deserialize_numpy_array_direct(typ: Any, arg: Any) -> Any:
124        return arg
def fullname(klass):
    """
    Return the dotted name of a class.

    Classes whose module is `builtins` are reported by their qualified name
    alone (`str`, not `builtins.str`); every other class gets its module
    path as a prefix.
    """
    owner = klass.__module__
    if owner != "builtins":
        return ".".join((owner, klass.__qualname__))
    # Builtins read better without the module prefix.
    return klass.__qualname__
def is_numpy_type(typ) -> bool:
    """
    Test if the type is a numpy array (bare or parameterised) or a numpy scalar.

    Returns False when numpy could not be imported. `is_bare_numpy_array` is
    only defined by the numpy-backed branch of this module, so calling it
    without numpy installed would raise `NameError` instead of answering.
    """
    if not is_numpy_available():
        return False
    return is_bare_numpy_array(typ) or is_numpy_scalar(typ) or is_numpy_array(typ)
def is_numpy_available() -> bool:
    """
    Report whether the numpy-backed helpers of this module are active.

    `encode_numpy` is bound to None when importing numpy fails, so it doubles
    as the availability flag.
    """
    numpy_missing = encode_numpy is None
    return not numpy_missing
def encode_numpy(obj: Any):
    """
    Convert a numpy object into a plain Python value.

    * `np.ndarray` becomes a (nested) list via `tolist()`.
    * `np.datetime64` becomes an ISO 8601 string.
    * Any other `np.generic` scalar becomes its Python scalar via `item()`.

    Raises `TypeError` for every other object.
    """
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    # datetime64 is itself an np.generic, so it has to be tested first.
    if isinstance(obj, np.datetime64):
        value = obj.item()
        if hasattr(value, "isoformat"):
            return value.isoformat()
        # item() yields an int for units finer than microseconds and None
        # for NaT; neither has isoformat(), so fall back to numpy's own
        # ISO 8601 rendering ("NaT" for a missing value).
        return str(obj)
    if isinstance(obj, np.generic):
        return obj.item()
    raise TypeError(f"Object of type {fullname(type(obj))} is not serializable")
def is_bare_numpy_array(typ) -> bool:
    """
    Test if the type is `np.ndarray` or `npt.NDArray` without type args.

    >>> import numpy as np
    >>> import numpy.typing as npt
    >>> is_bare_numpy_array(npt.NDArray[np.int64])
    False
    >>> is_bare_numpy_array(npt.NDArray)
    True
    >>> is_bare_numpy_array(np.ndarray)
    True
    """
    # The two spellings of an array annotation that carry no dtype.
    bare_forms = (np.ndarray, npt.NDArray)
    return typ in bare_forms

Test if the type is np.ndarray or npt.NDArray without type args.

>>> import numpy as np
>>> import numpy.typing as npt
>>> is_bare_numpy_array(npt.NDArray[np.int64])
False
>>> is_bare_numpy_array(npt.NDArray)
True
>>> is_bare_numpy_array(np.ndarray)
True
def is_numpy_scalar(typ) -> bool:
    """
    Test if the type is a subclass of `np.generic`.

    Anything `issubclass` rejects with `TypeError` (i.e. a non-class
    argument) is reported as False.
    """
    try:
        verdict = issubclass(typ, np.generic)
    except TypeError:
        verdict = False
    return verdict
def is_numpy_datetime(typ) -> bool:
    """
    Test if the type is a subclass of `np.datetime64`.

    Anything `issubclass` rejects with `TypeError` (i.e. a non-class
    argument) is reported as False.
    """
    try:
        verdict = issubclass(typ, np.datetime64)
    except TypeError:
        verdict = False
    return verdict
def serialize_numpy_scalar(arg) -> str:
    """
    Build the source expression that calls `.item()` on `arg.varname`.
    """
    template = "{}.item()"
    return template.format(arg.varname)
def deserialize_numpy_scalar(arg: Any) -> str:
    """
    Build the source expression that calls the scalar type on the data.

    The callee is the dotted name of `arg.type` as produced by `fullname`;
    the single argument is `arg.data`.
    """
    constructor = fullname(arg.type)
    return "{}({})".format(constructor, arg.data)
def is_numpy_array(typ) -> bool:
    """
    Test if the type is `np.ndarray`, either directly or as the origin of a
    parameterised annotation.
    """
    base = get_origin(typ)
    # With no origin, the type itself is the candidate.
    candidate = typ if base is None else base
    return candidate is np.ndarray
def serialize_numpy_array(arg) -> str:
    """
    Build the source expression that calls `.tolist()` on `arg.varname`.
    """
    template = "{}.tolist()"
    return template.format(arg.varname)
def serialize_numpy_datetime(arg) -> str:
    """
    Build the source expression `<varname>.item().isoformat()`.
    """
    # NOTE(review): numpy's item() appears to return an int for
    # nanosecond-precision values and None for NaT, in which case the
    # generated expression would fail - confirm whether callers can pass
    # such values.
    template = "{}.item().isoformat()"
    return template.format(arg.varname)
def deserialize_numpy_array(arg) -> str:
    """
    Build the source expression that rebuilds an array from `arg.data`.

    A bare array annotation yields `numpy.array(<data>)`; a parameterised
    one also passes `dtype=`, named after the type found at `arg[1][0]`.
    """
    if not is_bare_numpy_array(arg.type):
        # presumably arg[1][0] is the element type nested inside the dtype
        # argument - verify against the arg wrapper's indexing.
        dtype_name = fullname(arg[1][0].type)
        return "numpy.array({}, dtype={})".format(arg.data, dtype_name)
    return "numpy.array({})".format(arg.data)
def deserialize_numpy_array_direct(typ: Any, arg: Any) -> Any:
    """
    Build an `np.ndarray` from `arg` according to the annotation `typ`.

    A bare array annotation lets numpy infer the dtype; otherwise the dtype
    is the first type argument of the annotation's second argument.
    """
    if not is_bare_numpy_array(typ):
        dtype_alias = get_args(typ)[1]
        element_type = get_args(dtype_alias)[0]
        return np.array(arg, dtype=element_type)
    return np.array(arg)