Last active
April 17, 2023 08:36
-
-
Save jeffzi/8ccb1c8c216b0e66ad1fc9ce6b2fc1e2 to your computer and use it in GitHub Desktop.
Draft of Pandera dtypes
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import functools
import inspect
import itertools
from abc import ABCMeta
from dataclasses import dataclass
from typing import Any, Callable, Dict, Optional, Tuple, Type, TypeVar, Union

import numpy as np
import pandas as pd
########################################################################################
# Base dtype hierarchy
########################################################################################
@dataclass(frozen=True)
class DataType:
    """Root of the dtype hierarchy.

    Instances are frozen dataclasses. Calling an instance is shorthand for
    :meth:`coerce`.
    """

    def coerce(self, obj: Any) -> Any:
        """Coerce ``obj`` to this dtype.

        The base implementation always raises ``NotImplementedError``;
        concrete dtypes are expected to override it.
        """
        raise NotImplementedError()

    def __call__(self, obj: Any) -> Any:
        """Coerce ``obj`` to this dtype by delegating to :meth:`coerce`."""
        coerced = self.coerce(obj)
        return coerced
@dataclass(frozen=True)
class Number(DataType):
    """Base class of the numeric dtypes.

    Both fields are required here (no defaults); subclasses such as
    ``Integer`` supply defaults for them.
    """

    # NOTE: declaration order is the positional order of the generated
    # ``__init__`` for this class and every dataclass derived from it.
    continuous: bool  # presumably: values form a continuous domain -- TODO confirm
    exact: bool  # presumably: values are represented exactly -- TODO confirm
@dataclass(frozen=True)
class _PhysicalNumber(Number):
    """Private base for numbers that carry a bit width."""

    # Required here; ``Integer`` overrides it with a default of 64.
    bits: int
@dataclass(frozen=True)
class Integer(_PhysicalNumber):
    """Integer dtype described by its bit width and signedness.

    ``continuous``, ``exact`` and ``bits`` are inherited fields that only
    receive defaults here, so they keep their inherited position in the
    generated ``__init__``: ``(continuous, exact, bits, signed)``. Pass
    ``bits`` and ``signed`` by keyword to avoid binding the wrong field.
    """

    bits: int = 64
    signed: bool = True
    continuous: bool = False
    exact: bool = True

    def __str__(self) -> str:
        """Return the lowercase alias, e.g. ``"int64"`` or ``"uint64"``."""
        prefix = "" if self.signed else "u"
        return f"{prefix}int{self.bits}"
@dataclass(frozen=True)
class SignedInteger(Integer):
    """Signed integer whose field values cannot be customised.

    The explicit ``__init__`` below is kept by ``dataclass`` instead of the
    generated one, so instances always expose the class-level defaults.
    """

    signed: bool = True

    def __init__(self) -> None:
        """Accept no arguments: forbid editing the default values."""
        return None
@dataclass(frozen=True)
class UnsignedInteger(Integer):
    """Unsigned integer whose field values cannot be customised.

    The explicit ``__init__`` below is kept by ``dataclass`` instead of the
    generated one, so instances always expose the class-level defaults.
    """

    signed: bool = False

    def __init__(self) -> None:
        """Accept no arguments: forbid editing the default values."""
        return None
class Int64(SignedInteger):
    """Signed integer dtype with a fixed width of 64 bits."""

    bits: int = 64
class UInt64(UnsignedInteger):
    """Unsigned integer dtype with a fixed width of 64 bits."""

    bits: int = 64
@dataclass(frozen=True)
class Category(DataType):
    """Categorical dtype described by its categories and their ordering.

    ``categories`` may be given as any iterable; it is normalised to a tuple
    after initialisation. ``None`` means that no categories were specified.
    """

    # Immutable sequence to ensure a safe hash.
    categories: Optional[Tuple[Any, ...]] = None
    ordered: bool = False

    def __post_init__(self) -> None:
        """Normalise ``categories`` to a tuple (``None`` is left untouched)."""
        if self.categories is None:
            return
        # Bypass the frozen dataclass, see
        # https://docs.python.org/3/library/dataclasses.html#frozen-instances
        object.__setattr__(self, "categories", tuple(self.categories))
class String(DataType):
    """String dtype; adds no fields or behaviour to ``DataType``."""
########################################################################################
# Base backend
########################################################################################
# Any concrete dtype of the hierarchy rooted at ``DataType``.
AnyDtype = TypeVar("AnyDtype", bound="DataType")
# Function that turns an arbitrary object into a dtype.
Converter = Callable[[Any], AnyDtype]


@dataclass
class DtypeRegistry:
    """Lookup tables of one backend."""

    # Dispatching function: called with an object, returns the matching dtype.
    by_type: Callable[[Any], AnyDtype]
    # Direct mapping from a key object to a dtype *instance* (not a class).
    by_value: Dict[Any, AnyDtype]
class Backend(ABCMeta):
    """Base backend (metaclass of the concrete backend classes).

    Keep a registry of concrete backends (currently, only pandas). The registry
    serves to look up how to convert from one dtype to another. Two lookup modes
    are supported:

    * By value: direct mapping between an object and a dtype.
    * By type: map a type to a function with signature
      ``Callable[[Any], AnyDtype]`` that takes an instance of the type and
      converts it to a dtype (relies on ``functools.singledispatch``).
    """

    # One registry per backend class, shared through the metaclass.
    _registry: Dict["Backend", DtypeRegistry] = {}

    def __new__(mcs, name, bases, namespace, **kwargs):
        """Create a backend class and attach an empty registry to it.

        ``backend_name`` is a mandatory class keyword argument (a ``KeyError``
        is raised when it is missing); it is stored on the class as ``_name``.
        """
        namespace["_name"] = kwargs.pop("backend_name")
        cls = super().__new__(mcs, name, bases, namespace, **kwargs)

        # Each backend gets its own dispatcher; the default implementation is
        # reached when no converter is registered for the object's type.
        @functools.singledispatch
        def _dtype(obj: Any) -> AnyDtype:
            raise ValueError(f"data type '{obj}' not understood")

        mcs._registry[cls] = DtypeRegistry(by_type=_dtype, by_value={})
        return cls

    def _register_converter(cls, converter: Converter, *keys: Type[Any]) -> None:
        """Register ``converter`` as the by-type handler of every type in ``keys``."""
        dispatcher = cls._registry[cls].by_type
        for key in keys:
            dispatcher.register(key, converter)

    def _register_lookup(
        cls, dtype: Union[DataType, Type[DataType]], *keys: Any
    ) -> None:
        """Map every key in ``keys`` to ``dtype`` in the by-value table.

        A class is instantiated without arguments first, so the table only
        holds dtype instances.
        """
        value = dtype() if inspect.isclass(dtype) else dtype
        by_value = cls._registry[cls].by_value
        for key in keys:
            by_value[key] = value

    def register(cls, *keys: Any, dtype: Any = None):
        """Register ``keys`` for this backend.

        With ``dtype`` given, the keys are mapped by value to it and ``None``
        is returned. Without it, a decorator is returned: decorating a class
        registers a by-value lookup, decorating a function registers a by-type
        converter, and anything else raises ``ValueError``.
        """
        if dtype is not None:
            cls._register_lookup(dtype, *keys)
            return None

        def _wrapper(dtype):
            if inspect.isclass(dtype):
                cls._register_lookup(dtype, *keys)
            elif inspect.isfunction(dtype):
                cls._register_converter(dtype, *keys)
            else:
                raise ValueError(
                    f"{cls.__name__}.register can only decorate "
                    "a class or a function."
                )
            return dtype

        return _wrapper

    def dtype(cls, obj: Any) -> AnyDtype:
        """Return the backend dtype registered for ``obj``.

        The by-value table is consulted first, then the by-type dispatcher.

        Raises:
            ValueError: if ``obj`` is not understood by this backend.
        """
        backend_registry = cls._registry[cls]
        try:
            dtype = backend_registry.by_value.get(obj)
        except TypeError:
            # Unhashable objects (e.g. a list) can never be by-value keys;
            # fall through to the by-type dispatch instead of crashing.
            dtype = None
        if dtype is not None:
            return dtype
        try:
            return backend_registry.by_type(obj)
        except KeyError:
            raise ValueError(f"Data type '{obj}' not understood.") from None

    @property
    def name(cls):
        """Name given through the ``backend_name`` class keyword."""
        return cls._name
########################################################################################
# Pandas backend and dtype hierarchy
########################################################################################
class PandasBackend(metaclass=Backend, backend_name="pandas"):
    """Backend that resolves objects to pandas-flavoured dtypes."""

    @classmethod
    def dtype(cls, obj: Any) -> AnyDtype:
        """Return the pandas dtype registered for ``obj``.

        When the registry does not know ``obj``, pandas is asked to turn it
        into a numpy or pandas dtype and the lookup is retried with that
        result. Errors raised by ``pandas_dtype`` or by the second lookup are
        propagated unchanged.
        """
        try:
            resolved = Backend.dtype(cls, obj)
        except ValueError:
            # Let pandas transform any acceptable value into a numpy or pandas
            # dtype, then look that up instead.
            native = pd.api.types.pandas_dtype(obj)
            resolved = Backend.dtype(cls, native)
        return resolved
# The pandas containers that ``PandasDtype.coerce`` is annotated to accept.
PandasObject = TypeVar("PandasObject", pd.Series, pd.Index, pd.DataFrame)


@dataclass(frozen=True)
class PandasDtype:
    """Mixin holding the native (numpy or pandas) dtype used for coercion."""

    native_dtype: Any = None  # maybe should be moved to DataType?

    def coerce(self, obj: PandasObject) -> PandasObject:
        """Return ``obj`` cast to ``native_dtype`` through ``obj.astype``."""
        target = self.native_dtype
        return obj.astype(target)
@dataclass(frozen=True)
class PandasInt(PandasDtype, Integer):
    """Pandas integer dtype, either numpy-backed or nullable.

    ``native_dtype`` is always derived from the string alias of the instance
    (see :meth:`__str__`); a value passed to the constructor is overwritten.
    """

    nullable: bool = True

    def __post_init__(self) -> None:
        """Resolve the alias into the native numpy/pandas dtype."""
        # Bypass the frozen dataclass to store the derived value.
        object.__setattr__(self, "native_dtype", pd.api.types.pandas_dtype(str(self)))

    def __str__(self) -> str:
        """Return the pandas alias of this dtype.

        Non-nullable dtypes use the lowercase numpy alias (``"int64"``,
        ``"uint64"``). Nullable dtypes use the capitalised pandas alias
        (``"Int64"``, ``"UInt64"``).
        """
        alias = super().__str__()
        if not self.nullable:
            # Must stay lowercase: capitalising only the "u" would produce
            # "Uint64", which is not a valid pandas or numpy alias.
            return alias
        # Capitalise "I" for signed, "UI" for unsigned.
        capitals = 1 if self.signed else 2
        return alias[:capitals].upper() + alias[capitals:]
def _register_pandasInts(all_bits):
    """Register the pandas integer dtypes for every bit width in ``all_bits``.

    For each combination of width, signedness and nullability the string alias
    is registered. The non-nullable flavour is additionally registered for the
    matching ``Integer`` instance, the ``Integer`` subclass (class and
    instance) and the numpy dtype (dtype object and scalar type).
    """
    bools = (True, False)
    for bits, signed, nullable in itertools.product(all_bits, bools, bools):
        pandas_dtype = PandasInt(bits=bits, signed=signed, nullable=nullable)
        PandasBackend.register(str(pandas_dtype), dtype=pandas_dtype)
        if nullable:
            continue

        # Register the base Integer. Keywords are required: positionally the
        # first two dataclass fields are ``continuous`` and ``exact``
        # (inherited from Number), not ``bits`` and ``signed``.
        pandera_dtype = Integer(bits=bits, signed=signed)
        PandasBackend.register(pandera_dtype, dtype=pandas_dtype)

        # Register the Integer subclass, e.g. Int64; its name is the
        # capitalised alias of the nullable flavour.
        alias = str(PandasInt(bits=bits, signed=signed, nullable=True))
        subint_cls = globals()[alias]  # dont use globals in production
        PandasBackend.register(subint_cls, subint_cls(), dtype=pandas_dtype)

        # Register the numpy dtype, e.g. dtype('int64') and np.int64.
        np_dtype = np.dtype(str(pandas_dtype))
        PandasBackend.register(np_dtype, np_dtype.type, dtype=pandas_dtype)


_BITS = [64]  # add all bits in production
_register_pandasInts(_BITS)
@PandasBackend.register(
    Category, Category(), pd.CategoricalDtype, pd.CategoricalDtype()
)
@dataclass(frozen=True)
class PandasCategory(PandasDtype, Category):
    """Pandas categorical dtype.

    Registered by value for the parameter-less keys listed in the decorator.
    ``native_dtype`` is always rebuilt from ``categories`` and ``ordered``.
    """

    def __post_init__(self) -> None:
        """Normalise ``categories``, then derive the native ``CategoricalDtype``."""
        # Category.__post_init__ turns ``categories`` into a tuple.
        super().__post_init__()
        native = pd.CategoricalDtype(categories=self.categories, ordered=self.ordered)
        # Bypass the frozen dataclass to store the derived value.
        object.__setattr__(self, "native_dtype", native)
@PandasBackend.register(Category, pd.CategoricalDtype)
def _to_pandas_category(cat: pd.CategoricalDtype):
    """Build a ``PandasCategory`` with the categories and ordering of ``cat``.

    Registered as the by-type converter for ``Category`` and
    ``pd.CategoricalDtype`` instances; only ``cat.categories`` and
    ``cat.ordered`` are read.
    """
    categories, ordered = cat.categories, cat.ordered
    return PandasCategory(categories, ordered)
@PandasBackend.register(String, pd.StringDtype, pd.StringDtype())
@dataclass(frozen=True)
class PandasString(PandasDtype, String):
    """Pandas string dtype backed by ``pd.StringDtype``.

    Registered by value for the keys listed in the decorator.
    """

    def __post_init__(self) -> None:
        """Default ``native_dtype`` to ``pd.StringDtype()`` when none is given."""
        if self.native_dtype is None:
            # Bypass the frozen dataclass. Without this, ``coerce`` would call
            # ``astype(None)`` instead of casting to the string dtype.
            object.__setattr__(self, "native_dtype", pd.StringDtype())
########################################################################################
# Tests
########################################################################################
# Parameter-less category keys are registered by value and all resolve to the
# default PandasCategory.
assert all(
    PandasBackend.dtype(key) == PandasCategory()
    for key in (Category, Category(), pd.CategoricalDtype, pd.CategoricalDtype())
)

# Parametrised category instances are not by-value keys: they go through the
# by-type converter.
assert all(
    PandasBackend.dtype(key) == PandasCategory(["a", "b"], ordered=True)
    for key in (
        pd.CategoricalDtype(["a", "b"], ordered=True),
        Category(["a", "b"], ordered=True),
    )
)

# Every string key is registered by value and resolves to the same dtype.
assert all(
    PandasBackend.dtype(key) == PandasBackend.dtype(String)
    for key in (pd.StringDtype, pd.StringDtype())
)

# Int64 class, Int64 instance, the "int64" string alias and np.int64 are all
# registered by value for the non-nullable signed 64-bit dtype.
assert all(
    PandasBackend.dtype(key) == PandasBackend.dtype(Int64)
    for key in (Int64(), "int64", np.int64)
)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment