Skip to content

Instantly share code, notes, and snippets.

@jeffzi
Last active April 17, 2023 08:36
Show Gist options
  • Save jeffzi/8ccb1c8c216b0e66ad1fc9ce6b2fc1e2 to your computer and use it in GitHub Desktop.
Save jeffzi/8ccb1c8c216b0e66ad1fc9ce6b2fc1e2 to your computer and use it in GitHub Desktop.
Draft of Pandera dtypes
import functools
import inspect
import itertools
from abc import ABCMeta
from dataclasses import dataclass
from typing import Any, Callable, Dict, Optional, Tuple, Type, TypeVar, Union

import numpy as np
import pandas as pd
########################################################################################
# Base dtype hierarchy
########################################################################################
@dataclass(frozen=True)
class DataType:
    """Root of the dtype hierarchy.

    Instances are frozen dataclasses. Calling an instance is a shorthand for
    :meth:`coerce`, which subclasses are expected to override.
    """

    def coerce(self, obj: Any):
        """Coerce ``obj`` to this dtype.

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError

    def __call__(self, obj: Any):
        """Coerce ``obj`` to this dtype by delegating to :meth:`coerce`."""
        coerced = self.coerce(obj)
        return coerced
@dataclass(frozen=True)
class Number(DataType):
    """Numeric dtype described by two required boolean traits."""

    # Both fields have no default here; subclasses such as Integer supply them.
    continuous: bool
    exact: bool
@dataclass(frozen=True)
class _PhysicalNumber(Number):
    """Number that additionally carries a bit width."""

    # Bit width; required at this level, defaulted by subclasses.
    bits: int
@dataclass(frozen=True)
class Integer(_PhysicalNumber):
    """Integer dtype: a bit width plus a signedness flag.

    The string form is the lowercase alias, e.g. ``int64`` or ``uint64``.
    """

    bits: int = 64
    signed: bool = True
    continuous: bool = False
    exact: bool = True

    def __str__(self) -> str:
        """Return the alias ``int<bits>``, prefixed with ``u`` when unsigned."""
        sign_prefix = "" if self.signed else "u"
        return sign_prefix + f"int{self.bits}"
@dataclass(frozen=True)
class SignedInteger(Integer):
    """Integer whose ``signed`` default is ``True``."""

    signed: bool = True

    def __init__(self) -> None:
        """Accept no arguments, which forbids editing the default values.

        Because ``__init__`` is defined explicitly, the dataclass decorator does
        not generate one; attribute reads fall back to the class-level defaults.
        """
@dataclass(frozen=True)
class UnsignedInteger(Integer):
    """Integer whose ``signed`` default is ``False``."""

    signed: bool = False

    def __init__(self) -> None:
        """Accept no arguments, which forbids editing the default values.

        Because ``__init__`` is defined explicitly, the dataclass decorator does
        not generate one; attribute reads fall back to the class-level defaults.
        """
class Int64(SignedInteger):
    """Signed integer with ``bits`` fixed to 64 at class level."""

    bits: int = 64
class UInt64(UnsignedInteger):
    """Unsigned integer with ``bits`` fixed to 64 at class level."""

    bits: int = 64
@dataclass(frozen=True)
class Category(DataType):
    """Categorical dtype described by its categories and whether they are ordered.

    ``categories`` may be given as any iterable (or ``None``); it is converted to
    a tuple right after initialization.
    """

    # immutable sequence to ensure safe hash
    categories: Optional[Tuple[Any, ...]] = None
    ordered: bool = False

    def __post_init__(self) -> None:
        """Normalize ``categories`` to a tuple; ``None`` is left as is."""
        if self.categories is None:
            return
        # bypass frozen dataclass
        # see https://docs.python.org/3/library/dataclasses.html#frozen-instances
        object.__setattr__(self, "categories", tuple(self.categories))
class String(DataType):
    """String dtype; adds no fields or behavior of its own."""
########################################################################################
# base backend
########################################################################################
# Type variable standing for DataType or any of its subclasses.
AnyDtype = TypeVar("AnyDtype", bound="DataType")
# Signature of a function that turns an arbitrary object into a dtype.
Converter = Callable[[Any], AnyDtype]


@dataclass
class DtypeRegistry:
    """Pair of lookup tables held for a single backend."""

    # Callable that converts an object to a dtype.
    by_type: Converter
    # Direct mapping from a key object to a dtype.
    by_value: Dict[Any, Type[AnyDtype]]
class Backend(ABCMeta):
    """Base backend.

    Keep a registry of concrete backends (currently, only pandas). The registry serves
    to lookup how to convert from one dtype to another. 2 lookup modes are supported:

    * By value: Direct mapping between an object and a dtype.
    * By type: Map a type to a function with signature Callable[[Any], AnyDtype] that
      takes an instance of the type and converts it to a dtype. (Relies on singledispatch.)
    """

    # One DtypeRegistry per backend class created with this metaclass.
    _registry: Dict["Backend", DtypeRegistry] = {}

    def __new__(mcs, name, bases, namespace, **kwargs):
        """Create a backend class and give it an empty registry.

        The ``backend_name`` class keyword is required; it is stored on the new
        class as ``_name``. A missing keyword raises ``KeyError``.
        """
        namespace["_name"] = kwargs.pop("backend_name")
        cls = super().__new__(mcs, name, bases, namespace, **kwargs)

        # Fallback of the by-type dispatcher: no converter registered for the type.
        @functools.singledispatch
        def _dtype(obj: Any) -> AnyDtype:
            raise ValueError(f"data type '{obj}' not understood")

        mcs._registry[cls] = DtypeRegistry(by_type=_dtype, by_value={})
        return cls

    def _register_converter(cls, converter: Converter, *keys: Type[Any]) -> None:
        """Register ``converter`` as the by-type handler for each type in ``keys``."""
        for key in keys:
            cls._registry[cls].by_type.register(key, converter)

    def _register_lookup(
        cls, dtype: Union[DataType, Type[DataType]], *keys: Any
    ) -> None:
        """Map each key in ``keys`` to ``dtype`` in the by-value table.

        A class is instantiated without arguments first, so the table always
        stores instances.
        """
        value = dtype() if inspect.isclass(dtype) else dtype
        for key in keys:
            cls._registry[cls].by_value[key] = value

    def register(cls, *keys: Any, dtype: Any = None):
        """Register ``keys`` for a dtype, directly or as a decorator.

        With ``dtype`` given, the keys are added to the by-value table and
        ``None`` is returned. With ``dtype`` omitted, a decorator is returned:
        decorating a class registers it by value, decorating a function
        registers it as a by-type converter.

        Raises:
            ValueError: if the decorator is applied to something that is neither
                a class nor a function.
        """
        if dtype is None:

            def _wrapper(dtype):
                if inspect.isclass(dtype):
                    cls._register_lookup(dtype, *keys)
                elif inspect.isfunction(dtype):
                    cls._register_converter(dtype, *keys)
                else:
                    raise ValueError(
                        f"{cls.__name__}.register can only decorate "
                        + "a class or a function."
                    )
                # Hand the decorated object back unchanged.
                return dtype

            return _wrapper
        else:
            cls._register_lookup(dtype, *keys)

    def dtype(cls, obj: Any) -> AnyDtype:
        """Resolve ``obj`` to a registered dtype, by value first and then by type.

        Raises:
            ValueError: if neither lookup mode knows ``obj``.
        """
        backend_registry = cls._registry[cls]
        try:
            dtype = backend_registry.by_value.get(obj)
        except TypeError:
            # Unhashable objects can never be by-value keys; fall through to the
            # by-type lookup instead of leaking the dict's TypeError.
            dtype = None
        if dtype is not None:
            return dtype
        try:
            return backend_registry.by_type(obj)
        except KeyError:
            raise ValueError(f"Data type '{obj}' not understood.") from None

    @property
    def name(cls):
        """Backend name given through the ``backend_name`` class keyword."""
        return cls._name
########################################################################################
# Pandas backend and dtype hierarchy
########################################################################################
class PandasBackend(metaclass=Backend, backend_name="pandas"):
    """Backend registered under the name ``pandas``."""

    @classmethod
    def dtype(cls, obj: Any) -> AnyDtype:
        """Resolve ``obj`` to a registered dtype.

        The registry is consulted first. If it raises ``ValueError``, ``obj`` is
        normalized with ``pd.api.types.pandas_dtype`` and looked up once more.
        """
        try:
            resolved = Backend.dtype(cls, obj)
        except ValueError:
            # let pandas transform any acceptable value into a numpy or pandas dtype.
            native = pd.api.types.pandas_dtype(obj)
            resolved = Backend.dtype(cls, native)
        return resolved
# Any pandas container that can be coerced with ``astype``.
PandasObject = TypeVar("PandasObject", pd.Series, pd.Index, pd.DataFrame)


@dataclass(frozen=True)
class PandasDtype:
    """Mixin that coerces pandas objects through a wrapped native dtype."""

    native_dtype: Any = None  # maybe should be moved to DataType?

    def coerce(self, obj: PandasObject) -> PandasObject:
        """Return ``obj`` cast to ``native_dtype`` with ``astype``."""
        target = self.native_dtype
        return obj.astype(target)
@dataclass(frozen=True)
class PandasInt(PandasDtype, Integer):
    """Integer dtype for pandas, either nullable or non-nullable.

    ``native_dtype`` is derived from the string alias after initialization, so
    any value passed for it is overwritten.
    """

    nullable: bool = True

    def __post_init__(self) -> None:
        """Resolve the alias to a native dtype and store it on the frozen instance."""
        # bypass frozen dataclass
        object.__setattr__(self, "native_dtype", pd.api.types.pandas_dtype(str(self)))

    def __str__(self) -> str:
        """Return the dtype alias.

        Non-nullable dtypes keep the lowercase alias (``int64``, ``uint64``).
        Nullable dtypes use the capitalized alias (``Int64``, ``UInt64``).
        """
        alias = super().__str__()
        if not self.nullable:
            # Capitalizing only the "u" would yield "Uint64", which is not a
            # valid alias; the lowercase form must be returned untouched.
            return alias
        # "int64" -> "Int64" (1 capital), "uint64" -> "UInt64" (2 capitals)
        capitals = 1 if self.signed else 2
        return alias[:capitals].upper() + alias[capitals:]
def _register_pandasInts(all_bits):
    """Register pandas integer dtypes for every bit width in ``all_bits``.

    For each combination of bit width, signedness and nullability, the string
    alias is registered as a by-value key. The non-nullable variant is also
    registered for the base ``Integer``, for the matching ``Integer`` subclass
    (class and instance), and for the numpy dtype and its scalar type.
    """
    bools = (True, False)
    for bits, signed, nullable in itertools.product(all_bits, bools, bools):
        pandas_dtype = PandasInt(bits=bits, signed=signed, nullable=nullable)
        PandasBackend.register(str(pandas_dtype), dtype=pandas_dtype)
        if nullable:
            continue

        # register base Integer
        # Keywords are required: Integer's dataclass fields start with the ones
        # inherited from Number (continuous, exact), so positional arguments
        # would not land on bits/signed.
        pandera_dtype = Integer(bits=bits, signed=signed)
        PandasBackend.register(pandera_dtype, dtype=pandas_dtype)

        # register Integer subclass, e.g. Int32
        alias = str(  # get capitalized alias...
            PandasInt(
                bits=pandas_dtype.bits, signed=pandas_dtype.signed, nullable=True
            )
        )
        subint_cls = globals()[alias]  # dont use globals in production
        PandasBackend.register(subint_cls, dtype=pandas_dtype)
        PandasBackend.register(subint_cls(), dtype=pandas_dtype)

        # register numpy dtype
        np_dtype = np.dtype(str(pandas_dtype))
        PandasBackend.register(np_dtype, dtype=pandas_dtype)  # e.g dtype('int32')
        PandasBackend.register(np_dtype.type, dtype=pandas_dtype)  # e.g np.int32


_BITS = [64]  # add all bits in production
_register_pandasInts(_BITS)
@PandasBackend.register(
    Category, Category(), pd.CategoricalDtype, pd.CategoricalDtype()
)
@dataclass(frozen=True)
class PandasCategory(PandasDtype, Category):
    """Categorical dtype for pandas, wrapping a ``pd.CategoricalDtype``.

    ``native_dtype`` is rebuilt from ``categories`` and ``ordered`` after
    initialization, so any value passed for it is overwritten.
    """

    def __post_init__(self) -> None:
        """Normalize ``categories`` and build the native categorical dtype."""
        # Category.__post_init__ converts categories to a tuple.
        super().__post_init__()
        native = pd.CategoricalDtype(self.categories, self.ordered)
        # bypass frozen dataclass
        object.__setattr__(self, "native_dtype", native)
@PandasBackend.register(Category, pd.CategoricalDtype)
def _to_pandas_category(cat: pd.CategoricalDtype):
    """Build a ``PandasCategory`` from the ``categories`` and ``ordered`` of ``cat``.

    Registered as the by-type converter for both ``Category`` and
    ``pd.CategoricalDtype`` instances.
    """
    return PandasCategory(categories=cat.categories, ordered=cat.ordered)
@PandasBackend.register(String, String(), pd.StringDtype, pd.StringDtype())
@dataclass(frozen=True)
class PandasString(PandasDtype, String):
    """String dtype for pandas, wrapping ``pd.StringDtype``."""

    # Overrides the ``None`` default of PandasDtype.native_dtype so that
    # ``coerce`` casts to the pandas string dtype.
    native_dtype: pd.StringDtype = pd.StringDtype()
########################################################################################
# Tests
########################################################################################
# Module-level smoke tests: every key registered for a dtype must resolve to the
# same backend dtype.

# The Category class/instance and the pandas categorical class/instance are all
# registered as by-value keys for the default PandasCategory.
assert (
    PandasBackend.dtype(Category) # by value
    == PandasBackend.dtype(Category()) # by value
    == PandasBackend.dtype(pd.CategoricalDtype) # by value
    == PandasBackend.dtype(pd.CategoricalDtype()) # by value
    == PandasCategory()
)
# Non-default categoricals are not by-value keys; they go through the by-type
# converter _to_pandas_category.
assert (
    PandasBackend.dtype(pd.CategoricalDtype(["a", "b"], ordered=True)) # by type
    == PandasBackend.dtype(Category(["a", "b"], ordered=True)) # by type
    == PandasCategory(["a", "b"], ordered=True)
)
# String keys registered on PandasString.
assert (
    PandasBackend.dtype(String) # by value
    == PandasBackend.dtype(pd.StringDtype) # by value
    == PandasBackend.dtype(pd.StringDtype()) # by value
)
# Integer keys registered by _register_pandasInts for the non-nullable dtype.
assert (
    PandasBackend.dtype(Int64) # by value
    == PandasBackend.dtype(Int64()) # by value
    == PandasBackend.dtype("int64") # by value (the alias string itself is registered as a key)
    == PandasBackend.dtype(np.int64) # by value
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment