Skip to content

Instantly share code, notes, and snippets.

@danielhfrank
Created December 1, 2020 00:15
Show Gist options
  • Save danielhfrank/00e6b8556eed73fb4053450e602d2434 to your computer and use it in GitHub Desktop.
Save danielhfrank/00e6b8556eed73fb4053450e602d2434 to your computer and use it in GitHub Desktop.
Pydantic with Numpy
from typing import Generic, TypeVar
import numpy as np
from pydantic.fields import ModelField
def _ndarray_to_list(array):
    """Return ``array`` as (nested) Python lists so it can be JSON-encoded."""
    return array.tolist()


# Pass as ``Config.json_encoders`` so ``.json()`` can serialize numpy arrays.
JSON_ENCODERS = {np.ndarray: _ndarray_to_list}

# Type variable used only to parameterize TypedArray with a dtype literal.
DType = TypeVar('DType')
class TypedArray(np.ndarray, Generic[DType]):
    """Wrapper class for numpy arrays that stores and validates type information.

    This can be used in place of a numpy array. When it is used in a pydantic
    BaseModel or with ``pydantic.validate_arguments``, its dtype is *coerced*
    at runtime to the declared type, e.g. ``TypedArray[Literal['float32']]``
    converts values with ``np.array(value, dtype='float32')``. Used without a
    type parameter (bare ``TypedArray``), numpy's default dtype inference is
    applied instead.
    """

    @classmethod
    def __get_validators__(cls):
        """Yield the pydantic (v1) validators for this type."""
        yield cls.validate

    @classmethod
    def validate(cls, val, field: ModelField):
        """Convert ``val`` into a new numpy array with the declared dtype.

        Args:
            val: Anything ``np.array`` accepts (list, scalar, ndarray, ...).
            field: The pydantic field being validated. When a dtype was
                declared, its first sub-field holds the ``Literal`` dtype.

        Returns:
            A new ``np.ndarray``.

        Raises:
            ValueError, TypeError: propagated from ``np.array`` when the value
                cannot be converted or the dtype string is not understood.
        """
        # A bare ``TypedArray`` annotation has no sub-fields. Indexing them
        # would fail every validation, so fall back to numpy's default dtype.
        if not field.sub_fields:
            return np.array(val)
        dtype_field = field.sub_fields[0]
        actual_dtype = dtype_field.type_.__args__[0]
        # If numpy cannot create an array with the requested dtype, an error
        # is raised and bubbles up to pydantic.
        return np.array(val, dtype=actual_dtype)
from typing_extensions import Literal
import numpy as np
import pydantic
import pytest
from .pydantic import TypedArray, JSON_ENCODERS
class Model(pydantic.BaseModel):
    """Test model holding a single array coerced to float32."""

    x: TypedArray[Literal['float32']]

    class Config:
        # Lets ``.json()`` serialize the numpy array field.
        json_encoders = JSON_ENCODERS
class InvalidModel(pydantic.BaseModel):
    """Test model whose declared dtype string is not a valid numpy dtype."""

    x: TypedArray[Literal['asdfasdf']]
def test_array():
    """A list is coerced to a float32 ndarray; non-numeric input is rejected."""
    model = Model(x=[1, 2])
    assert isinstance(model.x, np.ndarray)
    assert model.x.dtype == np.dtype('float32')
    # np.array('asdfa', dtype='float32') raises ValueError inside the
    # validator, which pydantic reports as a ValidationError.
    with pytest.raises(pydantic.error_wrappers.ValidationError):
        Model(x='asdfa')
def test_invalid():
    """Declaring a dtype numpy does not understand makes validation fail."""
    expected_error = pydantic.error_wrappers.ValidationError
    with pytest.raises(expected_error):
        InvalidModel(x='boom')
def test_serde():
    """The float32 array serializes to a JSON list of floats."""
    serialized = Model(x=[1, 2]).json()
    assert serialized == '{"x": [1.0, 2.0]}'
# Using validate_arguments here will _coerce_ an array into the correct dtype
@pydantic.validate_arguments
def square(arr: TypedArray[Literal['float32']]) -> np.ndarray:
    """Return the element-wise square of ``arr`` after coercing it to float32.

    The return annotation names the type ``np.ndarray``; the original used
    ``np.array``, which is a factory function, not a type. Presumably only the
    arguments are validated by ``validate_arguments`` here -- verify against
    the pydantic version in use.
    """
    return arr ** 2
def test_validation_decorator():
    """An int32 input is coerced by validate_arguments, so the result is float32."""
    int_values = np.array([1, 2, 3], dtype='int32')
    assert square(int_values).dtype == np.dtype('float32')
@tearf001
Copy link

really love it!

@daudrain
Copy link

Very cool gist!

I've added a test that includes a nested (recursive) model. Calling `json()` requires `json_encoders` to be defined on the class it is called on, even if that class has no ndarray field of its own.

I had to replace a few `ValidationError` checks with `FileNotFoundError`, because the latter is raised first.

from pathlib import Path
from typing import (
    Any,
    Dict,
    Generic,
    Literal,
    Mapping,
    Optional,
    TypeVar,
    get_origin,
)

import numpy as np
import pytest
from numpy.testing import assert_allclose
from pydantic import BaseModel, BaseSettings, ValidationError
from pydantic.fields import ModelField

# Type variable used only to parameterize NDArray with a dtype literal,
# e.g. NDArray[Literal["float32"]].
DType = TypeVar("DType")


class NPFileDesc(BaseModel):
    """Reference to an array stored on disk.

    ``key`` optionally names the array to read from an ``.npz`` archive.
    """

    path: Path
    key: Optional[str] = None


class NDArray(np.ndarray, Generic[DType]):
    """Numpy array annotation for pydantic (v1) models and settings.

    Accepts anything ``np.array`` understands, or a file reference given as an
    ``NPFileDesc`` (or a mapping with its fields) pointing to a ``.npy`` or
    ``.npz`` file. When parameterized, e.g. ``NDArray[Literal["float32"]]``,
    the result is converted to that dtype. A bare ``NDArray`` keeps numpy's
    default dtype inference.
    """

    @classmethod
    def __get_validators__(cls):
        """Yield the pydantic validators for this type."""
        yield cls.validate

    @classmethod
    def validate(cls, val: Any, field: ModelField):
        """Convert ``val`` (raw data or a file reference) into an ndarray.

        Args:
            val: Array-like data, an ``NPFileDesc``, or a mapping of
                ``NPFileDesc`` fields.
            field: The pydantic field being validated. When a dtype was
                declared, its first sub-field holds the ``Literal`` dtype.

        Returns:
            A new ``np.ndarray``.

        Raises:
            ValueError: if a file reference has an unsupported suffix, or the
                dtype parameter is not a ``Literal``.
            FileNotFoundError: if a referenced file does not exist (raised by
                ``np.load``).
            KeyError: if ``key`` does not name an array in the ``.npz`` file.
        """
        if isinstance(val, Mapping):
            val = NPFileDesc(**val)
        data = cls._load_file(val) if isinstance(val, NPFileDesc) else val

        if field.sub_fields is None:
            return np.array(data)
        dtype_type = field.sub_fields[0].type_
        if get_origin(dtype_type) is not Literal:
            raise ValueError("DType field is expected to be Literal[str]")
        return np.array(data, dtype=dtype_type.__args__[0])

    @staticmethod
    def _load_file(desc: NPFileDesc) -> np.ndarray:
        """Load the array referenced by ``desc`` from a .npy or .npz file."""
        path = desc.path
        suffix = path.suffix.lower()
        if suffix not in (".npz", ".npy"):
            raise ValueError("Expected npz or npy file.")
        filename = str(path.absolute())
        if suffix == ".npy":
            return np.load(filename)
        # For .npz, np.load returns an NpzFile that keeps the archive open.
        # Read the selected array, then close the file handle.
        with np.load(filename) as archive:
            key = desc.key or archive.files[0]
            return archive[key]


def _ndarray_to_list(array):
    """Return ``array`` as (nested) Python lists so it can be JSON-encoded."""
    return array.tolist()


# Pass as ``Config.json_encoders`` so ``.json()`` can serialize numpy arrays.
JSON_ENCODERS = {np.ndarray: _ndarray_to_list}


def test_numpy_field(tmpdir):
    """Exercise NDArray in settings and models: literals, files, dtypes, JSON."""
    base = Path(tmpdir)

    class MySettings(BaseSettings):
        K: NDArray[Literal["float32"]]

        class Config:
            json_encoders = JSON_ENCODERS

    # Directly specify values
    cfg = MySettings(K=[1, 2])
    assert_allclose(cfg.K, [1.0, 2.0])
    assert cfg.K.dtype == np.float32
    assert cfg.json()

    cfg = MySettings(K=np.eye(2))
    assert_allclose(cfg.K, [[1.0, 0], [0.0, 1.0]])
    assert cfg.K.dtype == np.float32

    # Load from npy, given as a plain mapping
    np.save(base / "data.npy", np.arange(5))
    cfg = MySettings(K={"path": base / "data.npy"})
    assert_allclose(cfg.K, [0.0, 1.0, 2.0, 3.0, 4.0])
    assert cfg.K.dtype == np.float32

    # Same file, given as an NPFileDesc instance
    cfg = MySettings(K=NPFileDesc(path=base / "data.npy"))
    assert_allclose(cfg.K, [0.0, 1.0, 2.0, 3.0, 4.0])
    assert cfg.K.dtype == np.float32

    # Load a named array from npz
    np.savez(base / "data.npz", values=np.arange(5))
    cfg = MySettings(K={"path": base / "data.npz", "key": "values"})
    assert_allclose(cfg.K, [0.0, 1.0, 2.0, 3.0, 4.0])
    assert cfg.K.dtype == np.float32

    # Missing files: np.load raises FileNotFoundError, which surfaces
    # directly rather than as a ValidationError.
    with pytest.raises(FileNotFoundError):
        MySettings(
            K={"path": base / "nosuchfile.npz", "key": "values"}
        )

    with pytest.raises(FileNotFoundError):
        MySettings(
            K={"path": base / "nosuchfile.npy", "key": "nosuchkey"}
        )

    with pytest.raises(FileNotFoundError):
        MySettings(K={"path": base / "nosuchfile.npy"})

    with pytest.raises(ValidationError):
        MySettings(K="absc")

    # Not specifying a dtype will use numpy default dtype resolver

    class MySettingsNoGeneric(BaseSettings):
        K: NDArray

        class Config:
            json_encoders = JSON_ENCODERS

    cfg = MySettingsNoGeneric(K=[1, 2])
    assert_allclose(cfg.K, [1, 2])
    assert cfg.K.dtype == int

    assert cfg.json()

    # Optional test: an unset optional array field defaults to None.
    # NOTE(review): assumes no K environment variable is set, since
    # BaseSettings would read it.

    class MySettingsOptional(BaseSettings):
        K: Optional[NDArray]

    cfg = MySettingsOptional()
    assert cfg.K is None

    class MyModelField(BaseModel):
        K: NDArray[Literal["float32"]]

        class Config:
            json_encoders = JSON_ENCODERS

    model_field = MyModelField(K=[1.0, 2.0])
    assert model_field.json()

    # The outer model also needs json_encoders for .json() to serialize the
    # nested arrays, even though it has no ndarray field itself.
    class MyModel(BaseModel):
        L: Dict[str, MyModelField]

        class Config:
            json_encoders = JSON_ENCODERS

    model = MyModel(L={"a": MyModelField(K=[1.0, 2.0])})
    assert model.L["a"].K.dtype == np.dtype("float32")
    assert model.json()

@nstasino
Copy link

@cheind This gives me a

TypeError: Too few arguments for NDArray in NDArray[Literal["float32"]]

python 3.9.7
pydantic 1.9.0
numpy 1.22.2

@cheind
Copy link

cheind commented Feb 24, 2022

@nstasino where does this error happen?

@cheind
Copy link

cheind commented Mar 25, 2022

@cheind
Copy link

cheind commented Aug 16, 2022

@nstasino, @daudrain, @danielhfrank, @danieljfarrell I've updated the code to support numpy>1.22 here https://github.com/cheind/pydantic-numpy

In the meantime, some great improvements have been added by @caniko that you might enjoy.

@danieljfarrell
Copy link

😄 Good job. Will check out your project.

@danielhfrank
Copy link
Author

Yes, this looks neat, thanks for picking it up! Sorry to have ghosted y'all, but I changed jobs and no longer use this stack as much. Glad to see that you're carrying the torch.

@caniko
Copy link

caniko commented Aug 16, 2022

FOSS never sleeps ;)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment