Skip to content

Instantly share code, notes, and snippets.

@danielhfrank
Created December 1, 2020 00:15
Show Gist options
  • Save danielhfrank/00e6b8556eed73fb4053450e602d2434 to your computer and use it in GitHub Desktop.
Save danielhfrank/00e6b8556eed73fb4053450e602d2434 to your computer and use it in GitHub Desktop.
Pydantic with NumPy
from typing import Generic, TypeVar
import numpy as np
from pydantic.fields import ModelField
# Mapping for a pydantic model's ``Config.json_encoders``: numpy arrays are
# turned into plain Python lists/scalars via their own ``tolist()`` method so
# that ``BaseModel.json()`` can serialise them.
JSON_ENCODERS = {np.ndarray: lambda array: array.tolist()}

# Type variable that carries the dtype parameter of TypedArray,
# e.g. ``TypedArray[Literal['float32']]``.
DType = TypeVar("DType")
class TypedArray(np.ndarray, Generic[DType]):
    """Wrapper class for numpy arrays that stores and validates type information.

    This can be used in place of a numpy array, but when used in a pydantic BaseModel
    or with pydantic.validate_arguments, its dtype will be *coerced* at runtime to the
    declared type. The dtype is declared as a ``Literal`` parameter, for example
    ``TypedArray[Literal['float32']]``.
    """

    @classmethod
    def __get_validators__(cls):
        # pydantic (v1) hook: each yielded callable is applied to incoming values.
        yield cls.validate

    @classmethod
    def validate(cls, val, field: ModelField):
        """Convert ``val`` into a numpy array with the declared dtype.

        Args:
            val: Any value ``np.array`` can convert (lists, scalars, arrays, ...).
            field: The pydantic field being validated. Its first sub-field holds the
                ``Literal[...]`` dtype parameter. If the Literal lists several values,
                only the first is used.

        Returns:
            A new ``np.ndarray`` (``np.array`` copies by default) with the requested
            dtype. Note that this is a plain ndarray, not a ``TypedArray`` instance.

        Raises:
            TypeError: if ``TypedArray`` was used without a ``Literal`` dtype
                parameter, or if numpy does not understand the dtype name.
            ValueError: if numpy cannot convert ``val`` to the requested dtype.
        """
        # A bare ``TypedArray`` annotation leaves pydantic with no sub-fields. Fail
        # with an actionable message instead of "'NoneType' object is not
        # subscriptable". (pydantic v1 wraps TypeError into a ValidationError.)
        if not field.sub_fields:
            raise TypeError(
                "TypedArray must be parameterized with a dtype, "
                "e.g. TypedArray[Literal['float32']]"
            )
        dtype_type = field.sub_fields[0].type_
        # ``Literal['float32'].__args__ == ('float32',)``. Using getattr turns a
        # non-Literal parameter into a TypeError rather than an AttributeError,
        # which pydantic would not convert into a ValidationError.
        dtype_args = getattr(dtype_type, "__args__", None)
        if not dtype_args:
            raise TypeError(
                f"TypedArray dtype must be given as a Literal, got {dtype_type!r}"
            )
        # If numpy cannot create an array with the requested dtype, an error will be
        # raised and correctly bubbled up.
        return np.array(val, dtype=dtype_args[0])
from typing_extensions import Literal
import numpy as np
import pydantic
import pytest
from .pydantic import TypedArray, JSON_ENCODERS
# Test fixture: a model whose ``x`` field is coerced to a float32 numpy array.
# (Documented with comments rather than a docstring, because pydantic reports a
# model's docstring as the schema description.)
class Model(pydantic.BaseModel):
    x: TypedArray[Literal['float32']]

    class Config:
        # Lets ``Model.json()`` serialise the ndarray stored in ``x``.
        json_encoders = JSON_ENCODERS
# Test fixture: 'asdfasdf' is not a dtype name numpy understands, so validating
# any value for ``x`` is expected to fail (exercised by test_invalid).
class InvalidModel(pydantic.BaseModel):
    x: TypedArray[Literal['asdfasdf']]
def test_array():
    """Lists are coerced to float32 arrays; unconvertible input is rejected."""
    model = Model(x=[1, 2])
    assert isinstance(model.x, np.ndarray)
    assert model.x.dtype == np.dtype('float32')
    np.testing.assert_array_equal(model.x, [1.0, 2.0])
    # numpy cannot convert the string 'asdfa' to float32. Its ValueError is
    # wrapped by pydantic into a ValidationError.
    with pytest.raises(pydantic.ValidationError):
        Model(x='asdfa')
def test_invalid():
    # InvalidModel declares a dtype numpy does not know, so construction must be
    # reported as a validation failure.
    bad_input = 'boom'
    with pytest.raises(pydantic.ValidationError):
        InvalidModel(x=bad_input)
def test_serde():
    # Integer inputs are coerced to float32, so they serialise as floats.
    expected_json = '{"x": [1.0, 2.0]}'
    serialized = Model(x=[1, 2]).json()
    assert serialized == expected_json
# Using validate_arguments here will _coerce_ an array into the correct dtype
@pydantic.validate_arguments
def square(arr: TypedArray[Literal['float32']]) -> np.ndarray:
    """Return the element-wise square of ``arr``.

    ``pydantic.validate_arguments`` first runs ``TypedArray.validate`` on ``arr``.
    Any convertible input (e.g. an int32 array or a list) therefore reaches this
    body as a float32 ndarray.
    """
    # The return annotation is ``np.ndarray`` (the array type); the original
    # ``np.array`` named the factory function, not a type.
    return arr ** 2
def test_validation_decorator():
    # An int32 input must come back as float32 once validate_arguments has
    # coerced it on the way into ``square``.
    int_input = np.array([1, 2, 3], dtype='int32')
    assert square(int_input).dtype == np.dtype('float32')
@nstasino
Copy link

@cheind This gives me a

TypeError: Too few arguments for NDArray in NDArray[Literal["float32"]]

python 3.9.7
pydantic 1.9.0
numpy 1.22.2

@cheind
Copy link

cheind commented Feb 24, 2022

@nstasino where does this error happen?

@cheind
Copy link

cheind commented Mar 25, 2022

@cheind
Copy link

cheind commented Aug 16, 2022

@nstasino, @daudrain, @danielhfrank, @danieljfarrell I've updated the code to support numpy>1.22 here https://github.com/cheind/pydantic-numpy

In the meantime, some great improvements have been added by @caniko that you might enjoy.

@danieljfarrell
Copy link

😄 good job. Will checkout your project.

@danielhfrank
Copy link
Author

Yes, this looks neat, thanks for picking it up! Sorry to have ghosted y'all, but I changed jobs and no longer use this stack as much. Glad to see that you're carrying the torch.

@caniko
Copy link

caniko commented Aug 16, 2022

FOSS never sleeps ;)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment