from __future__ import annotations
import functools
import itertools
import json
import warnings
from collections.abc import Callable
from inspect import signature
from pathlib import Path
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
import polars as pl
import genome_kit as gk
from genome_kit._optional import require_polars
from .gk_structs import CURRENT_VERSION, CellType, ColumnInfo, GkDfType, GkDfVersion
from .registry import GK_TO_GKDF_TYPE, get_registry
def _map_batches_safe(fn: Callable) -> Callable:
"""Helper function to wrap a UDF and run safely with polars map_batches.
Polars has a bug in map_batches that incorrectly forwards the return_dtype argument
to the UDF. See https://github.com/pola-rs/polars/issues/24840.
Args:
fn: The user defined function to wrap.
Returns:
A wrapped version of the UDF that can be safely used with map_batches.
"""
sig = signature(fn)
@functools.wraps(fn)
def wrapper(*args, **kwargs):
accepted = sig.parameters
filtered_kwargs = {k: v for k, v in kwargs.items() if k in accepted}
return fn(*args, **filtered_kwargs)
return wrapper
def _detect_gk_cols(
    lf: pl.LazyFrame, infer_schema_length: int = 100
) -> dict[str, ColumnInfo]:
    """Detect columns in the LazyFrame that contain GenomeKit objects.

    Args:
        lf: The LazyFrame to inspect.
        infer_schema_length: The number of rows to use for schema inference when
            detecting GenomeKit columns.

    Returns:
        A dictionary mapping column names to the ColumnInfo dataclass containing the
        GkDfType and CellType for the column.

    Raises:
        ValueError: If a column mixes data types within the inspected rows.
    """
    pl = require_polars()
    lf_cols = lf.collect_schema().names()
    target_cols = {}
    # datatype inference done on first n=infer_schema_length rows. Follows inference
    # logic from Polars DataFrames when rows are provided.
    # see https://github.com/pola-rs/polars/blob/1cd236c60c01572c5ec6fdd252d8b20218d7b440/py-polars/src/polars/dataframe/frame.py#L248-L251
    head = lf.head(infer_schema_length).collect()
    for col in lf_cols:
        # remove nulls for type inference, list/scalar cols depend on first non-null value
        vals = head.get_column(col).drop_nulls()  # removes scalar nulls
        # column only contains null values in the first infer_schema_length rows
        if len(vals) == 0:
            warnings.warn(
                f"Column {col} contains only null values in the first {infer_schema_length} rows, "
                "unable to infer type for serialization. Please ensure this column "
                "contains non-null values for accurate serialization."
            )
            continue
        first = vals[0]
        head_types = {type(v) for v in vals}
        if isinstance(first, list):
            if head_types != {list}:
                raise ValueError(
                    f"Column {col} contains mixed data types: {list(itertools.islice(head_types, 3))}.\n"
                    "Please ensure all cells are the same type before serialization."
                )
            cell_type = CellType.LIST
            col_types = {type(item) for v in vals for item in v if item is not None}
            # Every inspected list is empty or all-null: no element type can be
            # inferred. This is the same situation as an all-null scalar column,
            # so warn and skip instead of raising a misleading "mixed data
            # types" error with an empty type list.
            if not col_types:
                warnings.warn(
                    f"Column {col} contains only empty or all-null lists in the first "
                    f"{infer_schema_length} rows, unable to infer type for serialization. "
                    "Please ensure this column contains non-null values for accurate serialization."
                )
                continue
        else:
            cell_type = CellType.SCALAR
            col_types = set(vals.map_elements(type, return_dtype=pl.Object))
        if len(col_types) != 1:
            raise ValueError(
                f"Column {col} contains mixed data types: {list(itertools.islice(col_types, 3))}.\n"
                "Please ensure all cells are the same type before serialization."
            )
        col_type = GK_TO_GKDF_TYPE.get(col_types.pop(), None)
        if col_type is None:
            # column is not a genomekit type, so no serialization needed
            continue
        target_cols[col] = ColumnInfo(cell_type=cell_type, gkdf_type=col_type)
    return target_cols
def _list_serializer(
    serializer: Callable[[pl.Series], pl.Series], return_dtype: Any
) -> Callable[[pl.Series], pl.Series]:
    """Lift a scalar-series serializer to one that handles series of lists.

    Args:
        serializer: A serializer function for a series of GenomeKit objects
        return_dtype: The return data type for the serialized series

    Returns:
        A serializer function for a series of lists of GenomeKit objects.
    """
    pl = require_polars()

    def _serialize_list(s: pl.Series) -> pl.Series:
        # Serialize each inner list independently; null lists stay null.
        serialized = []
        for inner in s:
            if inner is None:
                serialized.append(None)
            else:
                serialized.append(serializer(pl.Series(values=inner)).to_list())
        return pl.Series(name=s.name, values=serialized, dtype=return_dtype)

    return _serialize_list
def _init_gk_annotations(
    lf: pl.LazyFrame, target_cols: dict[str, dict]
) -> list[gk.Genome]:
    """Initialize GenomeKit annotations for all unique genomes in the LazyFrame.

    Prevents race conditions when opening dganno files during polars operations.
    Objects are returned in a list to keep weak references alive.

    Args:
        lf: The LazyFrame containing the serialized GenomeKit objects.
        target_cols: A dictionary mapping column names to their column information.
            Each value is a dictionary representation of the ColumnInfo dataclass.

    Returns:
        A list of initialized gene tables for the unique genomes in the LazyFrame.
    """
    pl = require_polars()

    def genome_str_field(col_info: dict) -> str:
        # Name of the struct field holding the genome identifier; the field
        # name differs per serialized type (presumably matching the registry's
        # struct definitions — verify against the registry module).
        gkdf_type = col_info["gkdf_type"]
        if gkdf_type == GkDfType.GENOME:
            return "genome_name"
        elif gkdf_type in (GkDfType.INTERVAL, GkDfType.VARIANT):
            return "refg"
        else:
            return "anno"

    anno_strong_refs = []
    # extract genome_str field from every column
    genomes_exprs = []  # scalar columns: one struct per cell
    genomes_list_exprs = []  # list columns: exploded to one struct per item
    for c in target_cols.keys():
        genome_field = genome_str_field(target_cols[c])
        if target_cols[c]["cell_type"] == CellType.SCALAR:
            genomes_exprs.append(pl.col(c).struct.field(genome_field))
        else:
            genomes_list_exprs.append(pl.col(c).explode().struct.field(genome_field))
    # expressions to extract genome_str must be run separately since exploded lists
    # may have more rows than the original dataframe
    plans = []
    if genomes_exprs:
        # Scalar columns all share the frame's row count, so they can be zipped
        # row-wise into one list column and flattened in a single pass.
        plans.append(
            lf.select(
                pl.concat_list(genomes_exprs)
                .explode()
                .drop_nulls()
                .unique()
                .alias("genome_str")
            )
        )
    if genomes_list_exprs:
        # Exploded list columns may differ in length, so they are concatenated
        # vertically (pl.concat) rather than zipped row-wise (pl.concat_list).
        plans.append(
            lf.select(
                pl.concat(genomes_list_exprs)
                .explode()
                .drop_nulls()
                .unique()
                .alias("genome_str")
            )
        )
    genomes = pl.concat(plans).unique().collect()["genome_str"].to_list()
    # warms annotations for all unique annotation genomes in the file.
    # all annotations available for serialization are contained in dganno file
    for genome_str in genomes:
        genome = gk.Genome(genome_str)
        try:
            # Touching .genes opens/loads the annotation; keep a strong ref so
            # it stays alive for the duration of deserialization.
            anno_strong_refs.append(genome.genes)
        except ValueError:
            # reference genomes don't have annotations
            continue
    return anno_strong_refs
def _validate_gkdf_metadata(metadata: dict[str, str]) -> None:
    """Validate the parquet metadata for a gkdf parquet file.

    Args:
        metadata: The parquet metadata to validate.

    Raises:
        ValueError: If the gkdf version is missing or not the current version,
            or if the target_cols or gk_version entries are missing.
    """
    # gkdf version
    metadata_version = metadata.get("gkdf_version")
    try:
        version = GkDfVersion(metadata_version) if metadata_version is not None else None
    except ValueError:
        # Unknown version string: previously this let the raw enum ValueError
        # escape; route it through the informative error below instead.
        version = metadata_version
    if version != CURRENT_VERSION:
        raise ValueError(
            f"Invalid or missing gkdf_version in Parquet metadata, unable to deserialize GenomeKit objects. "
            f"Expected GkDfVersion {CURRENT_VERSION}, but found {version}."
        )
    # target cols
    if metadata.get("target_cols") is None:
        raise ValueError(
            "Missing target_cols in Parquet metadata, unable to deserialize GenomeKit objects."
        )
    # gk version
    gk_version = metadata.get("gk_version")
    if gk_version is None:
        raise ValueError("Missing gk_version in Parquet metadata.")
    elif gk_version != gk.__version__:
        # Not fatal: deserialization may still work, but warn about possible
        # cross-version inconsistencies.
        warnings.warn(
            f"Parquet file was written with GenomeKit version {gk_version}, but current version is {gk.__version__}. "
            "Deserializing GenomeKit objects may not be consistent across versions."
        )
def _list_deserializer(
    deserializer: Callable[[pl.Series], pl.Series],
) -> Callable[[pl.Series], pl.Series]:
    """Lift a scalar-series deserializer to one that handles series of lists.

    Args:
        deserializer: A deserializer function for a series of serialized GenomeKit objects

    Returns:
        A deserializer function for a series of lists of serialized GenomeKit objects.
    """
    pl = require_polars()

    def _deserialize_list(s: pl.Series) -> pl.Series:
        # Deserialize each inner list independently; null lists stay null.
        deserialized = []
        for inner in s:
            if inner is None:
                deserialized.append(None)
            else:
                deserialized.append(deserializer(pl.Series(values=inner)).to_list())
        return pl.Series(name=s.name, values=deserialized, dtype=pl.Object)

    return _deserialize_list
def _deserialize_gk_cols(
    lf: pl.LazyFrame, target_cols: dict[str, dict]
) -> pl.LazyFrame:
    """Deserialize columns containing GenomeKit objects.

    Args:
        lf: The LazyFrame containing the serialized GenomeKit objects.
        target_cols: A dictionary mapping column names to their column information.
            Each value is a dictionary representation of the ColumnInfo dataclass.

    Returns:
        A LazyFrame with deserialized GenomeKit objects in the target columns.
    """
    pl = require_polars()
    registry = get_registry()

    def _expr_for(col_name: str) -> pl.Expr:
        # target_cols values are dict representations of ColumnInfo
        info = target_cols[col_name]
        base = registry[CURRENT_VERSION][info["gkdf_type"]].deserializer
        if info["cell_type"] == CellType.LIST:
            fn = _list_deserializer(base)
        else:
            fn = base
        expr = pl.col(col_name).map_batches(
            _map_batches_safe(fn),
            return_dtype=pl.Object,
        )
        return expr.alias(col_name)

    # with_columns_seq provides a 2x speedup here over with_columns
    return lf.with_columns_seq(_expr_for(c) for c in target_cols)
# TODO: add union of pd.DataFrame
# [docs]  (stray Sphinx "view source" link captured during HTML extraction; not valid Python)
def write_parquet(
    df: pl.DataFrame | pl.LazyFrame, path: str | Path, infer_schema_length: int = 100
) -> None:
    """Serialize a DataFrame with GenomeKit objects to a Parquet file.

    Args:
        df: A Polars DataFrame or LazyFrame with columns containing GenomeKit objects.
        path: The file path to write the Parquet file to.
        infer_schema_length: The number of rows to use for schema inference when writing the Parquet file.
    """
    pl = require_polars()
    path = Path(path)
    if isinstance(df, pl.DataFrame):
        df = df.lazy()
    # mapping from column name to ColumnInfo dataclass
    target_cols = _detect_gk_cols(df, infer_schema_length=infer_schema_length)
    if not target_cols:
        warnings.warn(
            "No GenomeKit columns detected for serialization, writing DataFrame as is."
        )
        df.sink_parquet(path)
        return
    registry = get_registry()

    def _expr_for(col_name: str) -> pl.Expr:
        # target_cols values are ColumnInfo dataclasses
        info = target_cols[col_name]
        entry = registry[CURRENT_VERSION][info.gkdf_type]
        if info.cell_type == CellType.LIST:
            dtype = pl.List(inner=entry.struct)
            fn = _list_serializer(entry.serializer, return_dtype=dtype)
        else:
            dtype = entry.struct
            fn = entry.serializer
        expr = pl.col(col_name).map_batches(
            _map_batches_safe(fn),
            return_dtype=dtype,
        )
        return expr.alias(col_name)

    df = df.with_columns(_expr_for(c) for c in target_cols)
    # convert ColumnInfo dataclass to a serializable format
    target_col_metadata = {name: info.to_dict() for name, info in target_cols.items()}
    df.sink_parquet(
        path,
        metadata={
            "gkdf_version": CURRENT_VERSION.value,
            "gk_version": gk.__version__,
            "target_cols": json.dumps(target_col_metadata),
        },
    )
# [docs]  (stray Sphinx "view source" link captured during HTML extraction; not valid Python)
def read_parquet(path: str | Path, lazy: bool = False) -> pl.DataFrame | pl.LazyFrame:
    """Deserialize a Parquet file containing GenomeKit objects into a Polars DataFrame or LazyFrame.

    Args:
        path: The file path to read the Parquet file from.
        lazy: If True, return a LazyFrame. Otherwise, return a DataFrame.

    Returns:
        A Polars DataFrame or LazyFrame with deserialized GenomeKit objects.
    """
    pl = require_polars()
    path = Path(path)
    metadata = pl.read_parquet_metadata(path)
    _validate_gkdf_metadata(metadata)
    target_cols = json.loads(metadata.get("target_cols"))
    lf = pl.scan_parquet(path)
    # collect unique genome strings in the file and initialize, prevents race conditions
    # on opening dganno files.
    # genomes returned in dummy variable to keep weak reference alive for deserialization
    anno_refs = _init_gk_annotations(lf, target_cols)  # noqa: F841
    result = _deserialize_gk_cols(lf, target_cols)
    if lazy:
        return result
    return result.collect()