-
-
Save danielhfrank/00e6b8556eed73fb4053450e602d2434 to your computer and use it in GitHub Desktop.
from typing import Generic, TypeVar | |
import numpy as np | |
from pydantic.fields import ModelField | |
# Encoder table for a pydantic ``Config.json_encoders`` setting: numpy arrays
# are not JSON serializable, so they are emitted as (nested) Python lists.
JSON_ENCODERS = {
    np.ndarray: lambda arr: arr.tolist(),
}
DType = TypeVar('DType')


# ``Generic`` is listed before ``np.ndarray`` so that ``TypedArray[...]`` is
# resolved by ``Generic.__class_getitem__``. Some numpy releases (e.g. 1.22.2)
# give ``ndarray`` its own ``__class_getitem__`` that rejects a single type
# argument with "TypeError: Too few arguments".
class TypedArray(Generic[DType], np.ndarray):
    """Wrapper class for numpy arrays that stores and validates type information.

    This can be used in place of a numpy array, but when used in a pydantic
    BaseModel or with pydantic.validate_arguments, its dtype will be *coerced*
    at runtime to the declared type, e.g. ``TypedArray[Literal['float32']]``.
    """

    @classmethod
    def __get_validators__(cls):
        """Yield the validators pydantic should run for fields of this type."""
        yield cls.validate

    @classmethod
    def validate(cls, val, field: "ModelField"):
        """Convert *val* to a numpy array of the dtype declared on *field*.

        Args:
            val: Any value accepted by ``np.array``.
            field: The pydantic field being validated. Its first sub-field is
                expected to hold the type parameter (such as
                ``Literal['float32']``) whose first argument names the dtype.

        Returns:
            A plain ``np.ndarray``. When the field carries no type parameter
            (bare ``TypedArray``), numpy infers the dtype from *val*.

        Raises:
            ValueError, TypeError: Whatever ``np.array`` raises when *val*
                cannot be converted or the dtype is not understood.
        """
        sub_fields = getattr(field, 'sub_fields', None)
        if not sub_fields:
            # Unparameterized annotation: no dtype to coerce to.
            return np.array(val)
        dtype_field = sub_fields[0]
        actual_dtype = dtype_field.type_.__args__[0]
        # If numpy cannot create an array with the requested dtype, an error
        # will be raised and correctly bubbled up.
        return np.array(val, dtype=actual_dtype)
from typing_extensions import Literal | |
import numpy as np | |
import pydantic | |
import pytest | |
from .pydantic import TypedArray, JSON_ENCODERS | |
class Model(pydantic.BaseModel):
    """Test model with one array field declared as float32."""

    x: TypedArray[Literal['float32']]

    class Config:
        # Needed so that .json() can serialize the ndarray field.
        json_encoders = JSON_ENCODERS
class InvalidModel(pydantic.BaseModel):
    """Test model whose array field declares the bogus dtype name 'asdfasdf'."""

    x: TypedArray[Literal['asdfasdf']]
def test_array():
    """A list input becomes a float32 ndarray; unparseable input is rejected."""
    model = Model(x=[1, 2])
    assert isinstance(model.x, np.ndarray)
    assert model.x.dtype == np.dtype('float32')
    # A string that cannot be converted to float32 must surface as a
    # pydantic validation error rather than being stored as-is.
    with pytest.raises(pydantic.error_wrappers.ValidationError):
        Model(x='asdfa')
def test_invalid():
    """Instantiating a model that declares the bogus dtype fails validation."""
    with pytest.raises(pydantic.error_wrappers.ValidationError):
        InvalidModel(x='boom')
def test_serde():
    """The array field serializes to a JSON list of floats."""
    model = Model(x=[1, 2])
    assert model.json() == '{"x": [1.0, 2.0]}'
# Using validate_arguments here will _coerce_ an array into the correct dtype
@pydantic.validate_arguments
def square(arr: TypedArray[Literal['float32']]) -> np.ndarray:
    """Return the element-wise square of *arr*."""
    return arr ** 2
def test_validation_decorator():
    """An int32 array passed to ``square`` comes back as float32."""
    x = np.array([1, 2, 3], dtype='int32')
    y = square(x)
    assert y.dtype == np.dtype('float32')
@cheind This gives me a
TypeError: Too few arguments for NDArray
in NDArray[Literal["float32"]]
python 3.9.7
pydantic 1.9.0
numpy 1.22.2
@nstasino where does this error happen?
@nstasino, @daudrain, @danielhfrank, @danieljfarrell I've updated the code to support numpy>1.22 here https://github.com/cheind/pydantic-numpy
@nstasino, @daudrain, @danielhfrank, @danieljfarrell I've updated the code to support numpy>1.22 here https://github.com/cheind/pydantic-numpy
In the meantime, some great improvements have been added by @caniko that you might enjoy.
😄 good job. Will checkout your project.
Yes, this looks neat, thanks for picking it up! Sorry to have ghosted y'all, but I changed jobs and no longer use this stack as much. Glad to see that you're carrying the torch.
FOSS never sleeps ;)
Very cool gist!
I've added a test including a recursive model.
The `json()` invocation requires having the `json_encoders` defined on the calling class, even if that class doesn't have any ndarray field.
I had to replace a few `ValidationError` with `FileNotFoundError`, as the latter were raising first.