Created December 1, 2020 00:15
Save danielhfrank/00e6b8556eed73fb4053450e602d2434 to your computer and use it in GitHub Desktop.
Pydantic with Numpy
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import Generic, TypeVar | |
import numpy as np | |
from pydantic.fields import ModelField | |
def _ndarray_to_list(array):
    """Convert a numpy array into (nested) Python lists so it can be JSON-encoded."""
    return array.tolist()


# Mapping suitable for a pydantic ``Config.json_encoders``: numpy arrays are
# serialized as nested lists.
JSON_ENCODERS = {np.ndarray: _ndarray_to_list}

# Type parameter of ``TypedArray``; ``TypedArray.validate`` reads the bound
# parameter (e.g. ``Literal['float32']``) as the dtype to coerce to.
DType = TypeVar('DType')
class TypedArray(np.ndarray, Generic[DType]):
    """Wrapper class for numpy arrays that stores and validates type information.

    This can be used in place of a numpy array, but when used in a pydantic BaseModel
    or with pydantic.validate_arguments, its dtype will be *coerced* at runtime to the
    declared type, e.g. ``TypedArray[Literal['float32']]`` yields a float32 array.

    Validation returns a plain ``np.ndarray`` built with ``np.array`` (not a
    ``TypedArray`` instance). When no usable dtype is declared (a bare
    ``TypedArray``, or a parameter pydantic resolves to ``Any``), numpy infers
    the dtype from the value instead of failing.
    """

    @classmethod
    def __get_validators__(cls):
        # pydantic (v1) hook: yields the callables used to validate/convert a value.
        yield cls.validate

    @classmethod
    def _declared_dtype(cls, field: ModelField):
        """Return the dtype spec given as the type parameter, or None if there is none.

        ``TypedArray[Literal['float32']]`` -> ``'float32'`` (first Literal value).
        A parameter without ``__args__`` (e.g. ``float``) is handed to numpy as-is.
        """
        from typing import Any

        sub_fields = getattr(field, 'sub_fields', None)
        if not sub_fields:
            # Bare ``TypedArray``: there is no sub-field for the type parameter,
            # so indexing ``sub_fields[0]`` would fail.
            return None
        dtype_type = sub_fields[0].type_
        if dtype_type is Any:
            # ``Any`` (what pydantic v1 substitutes for an unconstrained TypeVar)
            # has no ``__args__`` and is not a numpy dtype: let numpy infer.
            return None
        literal_args = getattr(dtype_type, '__args__', None)
        if literal_args:
            return literal_args[0]
        return dtype_type

    @classmethod
    def validate(cls, val, field: ModelField):
        """Convert ``val`` into a numpy array with the dtype declared on ``field``.

        Raises:
            ValueError / TypeError: from ``np.array`` when ``val`` cannot be converted
            or the declared dtype is not understood; pydantic reports these as a
            validation error.
        """
        # If numpy cannot create an array with the requested dtype, an error will be
        # raised and correctly bubbled up. ``dtype=None`` lets numpy infer the dtype.
        return np.array(val, dtype=cls._declared_dtype(field))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing_extensions import Literal | |
import numpy as np | |
import pydantic | |
import pytest | |
from .pydantic import TypedArray, JSON_ENCODERS | |
# Fixture model: ``x`` is coerced to a float32 numpy array on construction.
# (Comments rather than a docstring: pydantic would copy a docstring into the schema.)
class Model(pydantic.BaseModel):
    class Config:
        # Serialize numpy arrays as nested lists in ``.json()``.
        json_encoders = JSON_ENCODERS

    x: TypedArray[Literal['float32']]
# Fixture model declaring a dtype numpy does not understand ('asdfasdf'), so any
# construction attempt is expected to fail validation.
class InvalidModel(pydantic.BaseModel):
    x: TypedArray[Literal['asdfasdf']]
def test_array():
    # A plain list is accepted and coerced to the declared float32 dtype.
    coerced = Model(x=[1, 2])
    assert isinstance(coerced.x, np.ndarray)
    assert coerced.x.dtype == np.dtype('float32')

    # A string numpy cannot convert to float32 makes validation fail, which
    # pydantic surfaces as a ValidationError.
    with pytest.raises(pydantic.error_wrappers.ValidationError):
        Model(x='asdfa')
def test_invalid():
    # 'asdfasdf' is not a dtype numpy understands, so validation must fail.
    expected_error = pydantic.error_wrappers.ValidationError
    with pytest.raises(expected_error):
        InvalidModel(x='boom')
def test_serde():
    # Ints are coerced to float32, then JSON_ENCODERS emits them as a list of floats.
    expected_json = '{"x": [1.0, 2.0]}'
    serialized = Model(x=[1, 2]).json()
    assert serialized == expected_json
# Using validate_arguments here will _coerce_ an array into the correct dtype
@pydantic.validate_arguments
def square(arr: TypedArray[Literal['float32']]) -> np.ndarray:
    """Return the element-wise square of ``arr``.

    ``validate_arguments`` converts ``arr`` to a float32 numpy array before the
    body runs, so the result is a float32 array as well.
    """
    return arr ** 2
def test_validation_decorator():
    # int32 input must be coerced to float32 by validate_arguments before squaring.
    x = np.array([1, 2, 3], dtype='int32')
    y = square(x)
    assert y.dtype == np.dtype('float32')
    # Coercion must not change the values: the result is the element-wise square.
    np.testing.assert_array_equal(y, np.array([1, 4, 9], dtype='float32'))
@nstasino where does this error happen?
@nstasino, @daudrain, @danielhfrank, @danieljfarrell I've updated the code to support numpy>1.22 here https://github.com/cheind/pydantic-numpy
@nstasino, @daudrain, @danielhfrank, @danieljfarrell I've updated the code to support numpy>1.22 here https://github.com/cheind/pydantic-numpy
In the meantime, some great improvements have been added by @caniko that you might enjoy.
😄 Good job. Will check out your project.
Yes, this looks neat, thanks for picking it up! Sorry to have ghosted y'all, but I changed jobs and no longer use this stack as much. Glad to see that you're carrying the torch.
FOSS never sleeps ;)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
@cheind This gives me a `TypeError: Too few arguments for NDArray` in `NDArray[Literal["float32"]]` with:
python 3.9.7
pydantic 1.9.0
numpy 1.22.2