Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 50 additions & 5 deletions dataframely/_base_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@

# Name of the class attribute under which the collected columns are stored.
_COLUMN_ATTR = "__dataframely_columns__"
# Name of the class attribute from which a schema's validation rules are read.
_RULE_ATTR = "__dataframely_rules__"
# Name of the class attribute holding the attribute-name -> alias mapping.
_ATTR_TO_ALIAS = "__dataframely_attr_to_alias__"
# Name of the class attribute holding the resolved `use_attribute_names` flag.
_USE_ATTR_NAMES = "__dataframely_use_attribute_names__"

# NOTE(review): presumably a prefix for column names that preserve original
# values; its usage is not visible in this part of the file — confirm.
ORIGINAL_COLUMN_PREFIX = "__DATAFRAMELY_ORIGINAL__"

Expand Down Expand Up @@ -82,10 +84,12 @@ class Metadata:

columns: dict[str, Column] = field(default_factory=dict)
rules: dict[str, RuleFactory] = field(default_factory=dict)
attr_to_alias: dict[str, str | None] = field(default_factory=dict)

def update(self, other: Self) -> None:
self.columns.update(other.columns)
self.rules.update(other.rules)
self.attr_to_alias.update(other.attr_to_alias)


class SchemaMeta(ABCMeta):
Expand All @@ -95,13 +99,31 @@ def __new__(
bases: tuple[type[object], ...],
namespace: dict[str, Any],
*args: Any,
use_attribute_names: bool | None = None,
**kwargs: Any,
) -> SchemaMeta:
result = Metadata()

# Inherit use_attribute_names from parent if not explicitly set
inherited_use_attr_names = False
for base in bases:
result.update(mcs._get_metadata_recursively(base))
result.update(mcs._get_metadata(namespace))
if hasattr(base, _USE_ATTR_NAMES):
inherited_use_attr_names = getattr(base, _USE_ATTR_NAMES)

Comment thread
gab23r marked this conversation as resolved.
Outdated
# Explicit setting takes precedence over inheritance
final_use_attr_names = (
use_attribute_names
if use_attribute_names is not None
else inherited_use_attr_names
)

Comment thread
gab23r marked this conversation as resolved.
result.update(
mcs._get_metadata(namespace, use_attribute_names=final_use_attr_names)
)
namespace[_COLUMN_ATTR] = result.columns
namespace[_ATTR_TO_ALIAS] = result.attr_to_alias
namespace[_USE_ATTR_NAMES] = final_use_attr_names
cls = super().__new__(mcs, name, bases, namespace, *args, **kwargs)

# Assign rules retroactively as we only encounter rule factories in the result
Expand Down Expand Up @@ -185,25 +207,34 @@ def __getattribute__(cls, name: str) -> Any:
val = super().__getattribute__(name)
# Dynamically set the name of the column if it is a `Column` instance.
if isinstance(val, Column):
val._name = val.alias or name
use_attr_names = getattr(cls, _USE_ATTR_NAMES, False)
val._name = name if use_attr_names else (val.alias or name)
return val

@staticmethod
def _get_metadata_recursively(kls: type[object]) -> Metadata:
    """Collect the metadata of ``kls`` and of all of its ancestors.

    Base classes are visited first, so entries declared on ``kls`` itself
    overwrite inherited entries stored under the same key.

    Args:
        kls: The class whose metadata should be gathered.

    Returns:
        The merged metadata of the full inheritance chain of ``kls``.
    """
    result = Metadata()
    for base in kls.__bases__:
        result.update(SchemaMeta._get_metadata_recursively(base))
    # Read the class's own members exactly once, using the naming flag stored on
    # that class. Collecting them a second time without the flag would register
    # aliased columns under both their alias and their attribute name.
    use_attr_names = getattr(kls, _USE_ATTR_NAMES, False)
    result.update(
        SchemaMeta._get_metadata(kls.__dict__, use_attribute_names=use_attr_names)  # type: ignore
    )
    return result

@staticmethod
def _get_metadata(source: dict[str, Any]) -> Metadata:
def _get_metadata(
source: dict[str, Any], *, use_attribute_names: bool = False
) -> Metadata:
result = Metadata()
for attr, value in {
k: v for k, v in source.items() if not k.startswith("__")
}.items():
if isinstance(value, Column):
result.columns[value.alias or attr] = value
# When use_attribute_names=True, use attr as key; otherwise use alias or attr
col_name = attr if use_attribute_names else (value.alias or attr)
result.columns[col_name] = value
result.attr_to_alias[attr] = value.alias
if isinstance(value, RuleFactory):
# We must ensure that custom rules do not clash with internal rules.
if attr == "primary_key":
Expand Down Expand Up @@ -258,3 +289,17 @@ def _validation_rules(cls, *, with_cast: bool) -> dict[str, Rule]:
@classmethod
def _schema_validation_rules(cls) -> dict[str, Rule]:
    """Return the rules stored on the class under ``_RULE_ATTR``.

    Raises:
        AttributeError: If the class does not carry that attribute.
    """
    stored_rules: dict[str, Rule] = getattr(cls, _RULE_ATTR)
    return stored_rules

@classmethod
def _alias_mapping(cls) -> dict[str, str]:
    """Build the lookup from column alias to attribute name.

    Attributes that have no alias, or whose alias equals the attribute name,
    are left out of the result.
    """
    alias_to_attribute: dict[str, str] = {}
    for attribute, alias in getattr(cls, _ATTR_TO_ALIAS).items():
        if alias is None or alias == attribute:
            continue
        alias_to_attribute[alias] = attribute
    return alias_to_attribute

@classmethod
def _uses_attribute_names(cls) -> bool:
    """Return the ``use_attribute_names`` flag stored on the class.

    Falls back to ``False`` when the class does not carry the flag.
    """
    flag: bool = getattr(cls, _USE_ATTR_NAMES, False)
    return flag
Comment thread
gab23r marked this conversation as resolved.
Outdated
26 changes: 26 additions & 0 deletions dataframely/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -820,6 +820,32 @@ def cast(
return lf.collect() # type: ignore
return lf # type: ignore

@overload
@classmethod
def undo_aliases(cls, df: pl.DataFrame, /) -> pl.DataFrame: ...

@overload
@classmethod
def undo_aliases(cls, df: pl.LazyFrame, /) -> pl.LazyFrame: ...

@classmethod
def undo_aliases(
    cls, df: pl.DataFrame | pl.LazyFrame, /
) -> pl.DataFrame | pl.LazyFrame:
    """Rename aliased columns back to their attribute names.

    Every column whose name is an alias known to this schema (e.g.
    "price ($)") is renamed to the attribute it was declared under
    (e.g. "price").

    Args:
        df: The eager or lazy data frame whose columns should be renamed.

    Returns:
        A frame of the same kind as ``df`` with aliases replaced by attribute
        names. Columns that have no alias keep their name.
    """
    alias_to_attribute = cls._alias_mapping()
    return df.rename(alias_to_attribute)
Comment thread
gab23r marked this conversation as resolved.
Outdated

# --------------------------------- SERIALIZATION -------------------------------- #

@classmethod
Expand Down
66 changes: 66 additions & 0 deletions tests/columns/test_alias.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,69 @@ def test_alias_unset() -> None:
no_alias_col = dy.Int32()
assert no_alias_col.alias is None
assert no_alias_col.name == ""


def test_alias_use_attribute_names() -> None:
    """Each class names its own columns according to its effective setting."""

    # Explicitly enabled: the attribute name is used.
    class MySchema1(dy.Schema, use_attribute_names=True):
        price = dy.Int64(alias="price ($)")

    # Explicitly disabled: the alias is used.
    class MySchema2(MySchema1, use_attribute_names=False):
        price2 = dy.Int64(alias="price2 ($)")

    # Not specified: inherits the disabled setting from MySchema2.
    class MySchema3(MySchema2):
        price3 = dy.Int64(alias="price3 ($)")

    # Explicitly enabled again.
    class MySchema4(MySchema3, use_attribute_names=True):
        price4 = dy.Int64(alias="price4 ($)")

    # Not specified: inherits the enabled setting from MySchema4.
    class MySchema5(MySchema4):
        price5 = dy.Int64(alias="price5 ($)")

    expected_names = ["price", "price2 ($)", "price3 ($)", "price4", "price5"]
    assert MySchema5.column_names() == expected_names
Comment thread
gab23r marked this conversation as resolved.


def test_alias_mapping() -> None:
    """Only aliased columns appear in the alias -> attribute mapping."""

    class MySchema(dy.Schema):
        price = dy.Int64(alias="price ($)")
        production_rank = dy.Int64(alias="Production rank")
        no_alias = dy.Int64()

    expected_mapping = {
        "price ($)": "price",
        "Production rank": "production_rank",
    }
    # `no_alias` has no alias and must therefore be absent from the mapping.
    assert MySchema._alias_mapping() == expected_mapping


def test_alias_mapping_empty() -> None:
    """A schema without any aliases yields an empty mapping."""

    class NoAliasSchema(dy.Schema):
        a = dy.Int64()
        b = dy.String()

    alias_mapping = NoAliasSchema._alias_mapping()
    assert alias_mapping == {}


def test_undo_aliases() -> None:
    """Aliased columns of an eager frame are renamed to their attribute names."""

    class MySchema(dy.Schema):
        price = dy.Int64(alias="price ($)")
        production_rank = dy.Int64(alias="Production rank")

    aliased_frame = pl.DataFrame({"price ($)": [100], "Production rank": [1]})
    renamed_frame = MySchema.undo_aliases(aliased_frame)
    assert renamed_frame.columns == ["price", "production_rank"]


def test_undo_aliases_lazy() -> None:
    """Aliased columns of a lazy frame are renamed to their attribute names."""

    class MySchema(dy.Schema):
        price = dy.Int64(alias="price ($)")

    aliased_lazy = pl.LazyFrame({"price ($)": [100]})
    collected = MySchema.undo_aliases(aliased_lazy).collect()
    assert collected.columns == ["price"]
Loading