pydantic.py (the module imported by the tests below as `.pydantic`):

from typing import Generic, TypeVar

import numpy as np
from pydantic.fields import ModelField

JSON_ENCODERS = {
    np.ndarray: lambda arr: arr.tolist()
}

DType = TypeVar('DType')


class TypedArray(np.ndarray, Generic[DType]):
    """Wrapper class for numpy arrays that stores and validates type information.

    This can be used in place of a numpy array, but when used in a pydantic BaseModel
    or with pydantic.validate_arguments, its dtype will be *coerced* at runtime to the
    declared type.
    """

    @classmethod
    def __get_validators__(cls):
        yield cls.validate

    @classmethod
    def validate(cls, val, field: ModelField):
        dtype_field = field.sub_fields[0]
        actual_dtype = dtype_field.type_.__args__[0]
        # If numpy cannot create an array with the requested dtype, an error will be
        # raised and correctly bubbled up.
        np_array = np.array(val, dtype=actual_dtype)
        return np_array
Test file (imports the module above):

from typing_extensions import Literal

import numpy as np
import pydantic
import pytest

from .pydantic import TypedArray, JSON_ENCODERS


class Model(pydantic.BaseModel):
    x: TypedArray[Literal['float32']]

    class Config:
        json_encoders = JSON_ENCODERS


class InvalidModel(pydantic.BaseModel):
    x: TypedArray[Literal['asdfasdf']]


def test_array():
    model = Model(x=[1, 2])
    assert isinstance(model.x, np.ndarray)
    assert model.x.dtype == np.dtype('float32')
    # but I think this will not work yet
    with pytest.raises(pydantic.error_wrappers.ValidationError):
        Model(x='asdfa')


def test_invalid():
    with pytest.raises(pydantic.error_wrappers.ValidationError):
        InvalidModel(x='boom')


def test_serde():
    model = Model(x=[1, 2])
    assert model.json() == '{"x": [1.0, 2.0]}'


# Using validate_arguments here will _coerce_ an array into the correct dtype
@pydantic.validate_arguments
def square(arr: TypedArray[Literal['float32']]) -> np.ndarray:
    return arr ** 2


def test_validation_decorator():
    x = np.array([1, 2, 3], dtype='int32')
    y = square(x)
    assert y.dtype == np.dtype('float32')
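A small follow-up that is not part of the original gist: test_serde only checks serialization, but a JSON round trip should work as well, since the TypedArray validator coerces the parsed list back to float32. A minimal sketch, assuming the Model class defined above:

def test_json_roundtrip():
    # Illustrative only: re-parse the serialized JSON and check the dtype coercion.
    model = Model(x=[1, 2])
    restored = Model.parse_raw(model.json())
    assert isinstance(restored.x, np.ndarray)
    assert restored.x.dtype == np.dtype('float32')
    assert (restored.x == model.x).all()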
Hey, great gist! I extended your idea in the following ways:
- loading data from NPY and NPZ files
- skipping the Generic argument (i.e. Any), letting numpy itself determine the dtype
from pathlib import Path
from typing import Any, Generic, Literal, Mapping, Optional, TypeVar, get_origin

import numpy as np
import pydantic
import pytest
from numpy.testing import assert_allclose
from pydantic import BaseModel, BaseSettings, FilePath
from pydantic.fields import ModelField

DType = TypeVar("DType")


class NPFileDesc(BaseModel):
    path: FilePath
    key: Optional[str]


class NDArray(np.ndarray, Generic[DType]):
    @classmethod
    def __get_validators__(cls):
        yield cls.validate

    @classmethod
    def validate(cls, val: Any, field: ModelField):
        if isinstance(val, Mapping):
            val = NPFileDesc(**val)

        if isinstance(val, NPFileDesc):
            val: NPFileDesc
            path = val.path
            key = val.key

            if path.suffix.lower() not in [".npz", ".npy"]:
                raise ValueError("Expected npz or npy file.")

            content = np.load(str(val.path.absolute()))
            if path.suffix.lower() == ".npz":
                key = key or content.files[0]
                data = content[key]
            else:
                data = content
        else:
            data = val

        if field.sub_fields is not None:
            dtype_field = field.sub_fields[0]
            if not get_origin(dtype_field.type_) == Literal:
                raise ValueError("DType field is expected to be Literal[str]")
            actual_dtype_lit = dtype_field.type_.__args__[0]
            return np.array(data, dtype=actual_dtype_lit)
        else:
            return np.array(data)


def test_numpy_field(tmpdir):
    class MySettings(BaseSettings):
        K: NDArray[Literal["float32"]]

    # Directly specify values
    cfg = MySettings(K=[1, 2])
    assert_allclose(cfg.K, [1.0, 2.0])
    assert cfg.K.dtype == np.float32

    cfg = MySettings(K=np.eye(2))
    assert_allclose(cfg.K, [[1.0, 0], [0.0, 1.0]])
    assert cfg.K.dtype == np.float32

    # Load from npy
    np.save(Path(tmpdir) / "data.npy", np.arange(5))
    cfg = MySettings(K={"path": Path(tmpdir) / "data.npy"})
    assert_allclose(cfg.K, [0.0, 1.0, 2.0, 3.0, 4.0])
    assert cfg.K.dtype == np.float32

    np.save(Path(tmpdir) / "data.npy", np.arange(5))
    cfg = MySettings(K=NPFileDesc(path=Path(tmpdir) / "data.npy"))
    assert_allclose(cfg.K, [0.0, 1.0, 2.0, 3.0, 4.0])
    assert cfg.K.dtype == np.float32

    np.savez(Path(tmpdir) / "data.npz", values=np.arange(5))
    cfg = MySettings(K={"path": Path(tmpdir) / "data.npz", "key": "values"})
    assert_allclose(cfg.K, [0.0, 1.0, 2.0, 3.0, 4.0])
    assert cfg.K.dtype == np.float32

    with pytest.raises(pydantic.ValidationError):
        MySettings(K={"path": Path(tmpdir) / "nosuchfile.npz", "key": "values"})

    with pytest.raises(pydantic.ValidationError):
        MySettings(K={"path": Path(tmpdir) / "nosuchfile.npy", "key": "nosuchkey"})

    with pytest.raises(pydantic.ValidationError):
        MySettings(K={"path": Path(tmpdir) / "nosuchfile.npy"})

    with pytest.raises(pydantic.ValidationError):
        MySettings(K="absc")

    # Not specifying a dtype will use numpy's default dtype resolver
    class MySettingsNoGeneric(BaseSettings):
        K: NDArray

    cfg = MySettingsNoGeneric(K=[1, 2])
    assert_allclose(cfg.K, [1, 2])
    assert cfg.K.dtype == int

    # Optional test
    class MySettingsOptional(BaseSettings):
        K: Optional[NDArray]

    cfg = MySettingsOptional()
really love it!
Very cool gist! I've added a test including a recursive model. json() invocation requires having the json_encoders defined on the calling class even if it doesn't have any ndarray field. I had to replace a few ValidationError with FileNotFoundError, as the latter were raised first.
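To show that json_encoders behaviour in isolation before the full listing below, here is a minimal sketch (illustrative class names, reusing TypedArray and JSON_ENCODERS from the original gist) of a nested model whose outer class must also declare json_encoders even though it has no ndarray field of its own:

class Inner(pydantic.BaseModel):
    x: TypedArray[Literal['float32']]

    class Config:
        json_encoders = JSON_ENCODERS


class Outer(pydantic.BaseModel):
    inner: Inner

    class Config:
        # Needed on the calling class as well; without it, Outer(...).json()
        # reportedly fails to encode the nested ndarray.
        json_encoders = JSON_ENCODERS


Outer(inner=Inner(x=[1, 2])).json()  # '{"inner": {"x": [1.0, 2.0]}}'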
from pathlib import Path
from typing import (
    Any,
    Dict,
    Generic,
    Literal,
    Mapping,
    Optional,
    TypeVar,
    get_origin,
)

import numpy as np
import pytest
from numpy.testing import assert_allclose
from pydantic import BaseModel, BaseSettings, ValidationError
from pydantic.fields import ModelField

DType = TypeVar("DType")


class NPFileDesc(BaseModel):
    path: Path
    key: Optional[str]


class NDArray(np.ndarray, Generic[DType]):
    @classmethod
    def __get_validators__(cls):
        yield cls.validate

    @classmethod
    def validate(cls, val: Any, field: ModelField):
        if isinstance(val, Mapping):
            val = NPFileDesc(**val)

        if isinstance(val, NPFileDesc):
            val: NPFileDesc
            path = val.path
            key = val.key

            if path.suffix.lower() not in [".npz", ".npy"]:
                raise ValueError("Expected npz or npy file.")

            content = np.load(str(val.path.absolute()))
            if path.suffix.lower() == ".npz":
                key = key or content.files[0]
                data = content[key]
            else:
                data = content
        else:
            data = val

        if field.sub_fields is not None:
            dtype_field = field.sub_fields[0]
            if not get_origin(dtype_field.type_) == Literal:
                raise ValueError("DType field is expected to be Literal[str]")
            actual_dtype_lit = dtype_field.type_.__args__[0]
            return np.array(data, dtype=actual_dtype_lit)
        else:
            return np.array(data)


JSON_ENCODERS = {np.ndarray: lambda arr: arr.tolist()}


def test_numpy_field(tmpdir):
    class MySettings(BaseSettings):
        K: NDArray[Literal["float32"]]

        class Config:
            json_encoders = {np.ndarray: lambda arr: arr.tolist()}

    # Directly specify values
    cfg = MySettings(K=[1, 2])
    assert_allclose(cfg.K, [1.0, 2.0])
    assert cfg.K.dtype == np.float32
    assert cfg.json()

    cfg = MySettings(K=np.eye(2))
    assert_allclose(cfg.K, [[1.0, 0], [0.0, 1.0]])
    assert cfg.K.dtype == np.float32

    # Load from npy
    np.save(Path(tmpdir) / "data.npy", np.arange(5))
    cfg = MySettings(K={"path": Path(tmpdir) / "data.npy"})
    assert_allclose(cfg.K, [0.0, 1.0, 2.0, 3.0, 4.0])
    assert cfg.K.dtype == np.float32

    np.save(Path(tmpdir) / "data.npy", np.arange(5))
    cfg = MySettings(K=NPFileDesc(path=Path(tmpdir) / "data.npy"))
    assert_allclose(cfg.K, [0.0, 1.0, 2.0, 3.0, 4.0])
    assert cfg.K.dtype == np.float32

    np.savez(Path(tmpdir) / "data.npz", values=np.arange(5))
    cfg = MySettings(K={"path": Path(tmpdir) / "data.npz", "key": "values"})
    assert_allclose(cfg.K, [0.0, 1.0, 2.0, 3.0, 4.0])
    assert cfg.K.dtype == np.float32

    with pytest.raises(FileNotFoundError):
        MySettings(K={"path": Path(tmpdir) / "nosuchfile.npz", "key": "values"})

    with pytest.raises(FileNotFoundError):
        MySettings(K={"path": Path(tmpdir) / "nosuchfile.npy", "key": "nosuchkey"})

    with pytest.raises(FileNotFoundError):
        MySettings(K={"path": Path(tmpdir) / "nosuchfile.npy"})

    with pytest.raises(ValidationError):
        MySettings(K="absc")

    # Not specifying a dtype will use numpy default dtype resolver
    class MySettingsNoGeneric(BaseSettings):
        K: NDArray

        class Config:
            json_encoders = {np.ndarray: lambda arr: arr.tolist()}

    cfg = MySettingsNoGeneric(K=[1, 2])
    assert_allclose(cfg.K, [1, 2])
    assert cfg.K.dtype == int
    assert cfg.json()

    # Optional test
    class MySettingsOptional(BaseSettings):
        K: Optional[NDArray]

    cfg = MySettingsOptional()

    class MyModelField(BaseModel):
        K: NDArray[Literal["float32"]]

        class Config:
            json_encoders = JSON_ENCODERS

    model_field = MyModelField(K=[1.0, 2.0])
    assert model_field.json()

    class MyModel(BaseModel):
        L: Dict[str, MyModelField]

        class Config:
            json_encoders = JSON_ENCODERS

    model = MyModel(L={"a": MyModelField(K=[1.0, 2.0])})
    assert model.L["a"].K.dtype == np.dtype("float32")
    assert model.json()
@cheind This gives me a TypeError: Too few arguments for NDArray in NDArray[Literal["float32"]].

Python 3.9.7, pydantic 1.9.0, numpy 1.22.2
@nstasino where does this error happen?
@nstasino, @daudrain, @danielhfrank, @danieljfarrell I've updated the code to support numpy>1.22 here: https://github.com/cheind/pydantic-numpy
In the meantime, some great improvements have been added by @caniko that you might enjoy.
😄 Good job. Will check out your project.
Yes, this looks neat, thanks for picking it up! Sorry to have ghosted y'all, but I changed jobs and no longer use this stack as much. Glad to see that you're carrying the torch.
FOSS never sleeps ;)
Oops! I didn’t actually run that python code before commenting. Thanks for the suggestion.
I’m away from my laptop again, but I think this will work.