Created November 28, 2023 20:18
-
-
Save thomasaarholt/e8aff4502bfe34861f8b54d81f411f64 to your computer and use it in GitHub Desktop.
Polars type serialization in a pydantic model
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import json
from typing import Any

import polars as pl
from pydantic import BaseModel

# Example expression to round-trip through pydantic: (2 * foo) == bar.
original_expression = (pl.col("foo") * 2) == pl.col("bar")
class Column(BaseModel):
    """Model that holds a column name next to a raw ``pl.Expr``.

    ``arbitrary_types_allowed`` lets the model accept the polars expression
    object as a field value. Instances are NOT JSON-serializable, because
    ``model_dump_json`` raises a serialization error on the ``pl.Expr`` field.
    """

    model_config = {"arbitrary_types_allowed": True}

    name: str
    expression: pl.Expr  # stored as the live polars object, not as JSON
col = Column(name="foo", expression=original_expression)
# col.model_dump_json()  # <- this raises a serialization error (pl.Expr field)

# Just to demonstrate: polars can write the expression as its own JSON string,
# and json.loads decodes that string into a plain Python dict.
expr_json = original_expression.meta.write_json()  # polars' JSON form of the expr
# JSON-decoded dict with string keys. Its values are arbitrary JSON values
# (e.g. nested objects describing the expression tree), not only strings.
expr_dict = json.loads(expr_json)
class TempColumn(BaseModel):
    """JSON-friendly stand-in for ``Column``.

    The expression is kept as the dict produced by ``json.loads`` on polars'
    JSON output rather than as a ``pl.Expr``. Every field is therefore a plain
    Python type, and the model WILL be serializable.
    """

    name: str
    expression: dict[str, Any]  # decoded polars expression JSON
# Build the serializable stand-in from the fields of the Column instance.
temp_col = TempColumn(
    name=col.name,
    expression=json.loads(col.expression.meta.write_json()),
)
temp_col_json = temp_col.model_dump_json()

# Read it back in under a separate name, so the original instance is still
# available to check that the pydantic JSON round trip lost nothing.
restored_col = TempColumn.model_validate_json(temp_col_json)
assert restored_col == temp_col

# Rebuild the polars expression from the restored dict.
# NOTE(review): newer polars releases deprecate meta.write_json / Expr.from_json
# in favour of meta.serialize / Expr.deserialize; confirm against the installed
# polars version.
expression_read_from_pydantic = pl.Expr.from_json(json.dumps(restored_col.expression))

# `==` between two expressions builds a new comparison expression instead of
# returning a bool, so compare the string representations instead.
assert str(expression_read_from_pydantic) == str(original_expression)
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment