"""Grouped :class:`DataFrame` helper."""
from __future__ import annotations
from dataclasses import dataclass
from typing import Dict, Optional, Sequence, Union
from ...expressions.column import Column, col
from ...logical import operators
from ...logical.plan import LogicalPlan
from ..core.dataframe import DataFrame
# [docs]  (Sphinx "view source" link residue; not part of the module's code)
@dataclass(frozen=True)
class GroupedDataFrame:
    """Represents a :class:`DataFrame` grouped by one or more columns.

    Instances are returned by ``DataFrame.group_by()`` and provide
    aggregation methods.

    Attributes:
        plan: Logical plan the grouping is applied to.
        keys: Grouping column expressions.
        parent: :class:`DataFrame` whose ``database`` handle is reused for
            every result built by this object.
    """

    plan: LogicalPlan
    keys: tuple[Column, ...]
    parent: DataFrame

    def agg(self, *aggregations: Union[Column, str, Dict[str, str]]) -> DataFrame:
        """Apply aggregation functions to the grouped data.

        Args:
            *aggregations: One or more aggregation expressions. Can be:

                - :class:`Column` expressions (e.g., sum(col("amount")))
                - String column names (e.g., "amount" - defaults to sum())
                - Dictionary mapping column names to aggregation functions
                  (e.g., {"amount": "sum", "price": "avg"})

        Returns:
            :class:`DataFrame` with aggregated results.

        Raises:
            ValueError: If no aggregations are provided or if invalid
                aggregation expressions are used.

        Example:
            >>> from moltres import connect, col
            >>> from moltres.expressions import functions as F
            >>> from moltres.table.schema import column
            >>> db = connect("sqlite:///:memory:")
            >>> db.create_table("sales", [column("category", "TEXT"), column("amount", "REAL"), column("price", "REAL")]).collect()
            >>> from moltres.io.records import Records
            >>> Records(_data=[{"category": "A", "amount": 100.0, "price": 10.0}, {"category": "A", "amount": 200.0, "price": 20.0}, {"category": "B", "amount": 150.0, "price": 15.0}], _database=db).insert_into("sales")
            >>> # Using Column expressions
            >>> df = db.table("sales").select()
            >>> result = df.group_by("category").agg(F.sum(col("amount")).alias("total"), F.avg(col("price")).alias("avg_price"))
            >>> results = result.collect()
            >>> len(results)
            2
            >>> results[0]["total"]
            300.0
            >>> # Using string column names (defaults to sum)
            >>> result2 = df.group_by("category").agg("amount")
            >>> results2 = result2.collect()
            >>> results2[0]["amount"]
            300.0
            >>> # Using dictionary syntax
            >>> result3 = df.group_by("category").agg({"amount": "sum", "price": "avg"})
            >>> results3 = result3.collect()
            >>> results3[0]["amount"]
            300.0
            >>> db.close()
        """
        if not aggregations:
            raise ValueError("agg requires at least one aggregation expression")

        # Imported at call time rather than at module level, as in the
        # original code (presumably to avoid an import cycle -- TODO confirm).
        from ..helpers.groupby_helpers import normalize_aggregations, validate_aggregation

        # Normalize every argument to a Column expression. ``aggregations``
        # itself is non-empty here, but the helper is called with
        # ``allow_empty=True`` and may therefore still return nothing
        # (presumably for arguments such as an empty dict -- TODO confirm).
        normalized_aggs = normalize_aggregations(
            aggregations, alias_with_column_name=True, allow_empty=True
        )

        if not normalized_aggs:
            # Nothing to aggregate: project onto the grouping columns and
            # de-duplicate (the dropDuplicates-style special case).
            keys_only = operators.project(self.plan, tuple(self.keys))
            distinct_plan = operators.distinct(keys_only)
            return DataFrame(plan=distinct_plan, database=self.parent.database)

        validated = tuple(validate_aggregation(expr) for expr in normalized_aggs)
        aggregate_plan = operators.aggregate(self.plan, self.keys, validated)
        return DataFrame(plan=aggregate_plan, database=self.parent.database)

    @staticmethod
    def _create_aggregation_from_string(column_name: str, func_name: str) -> Column:
        """Create an aggregation :class:`Column` from a column and function name.

        Thin wrapper around
        ``helpers.groupby_helpers.create_aggregation_from_string``.

        Args:
            column_name: Name of the column to aggregate.
            func_name: Name of the aggregation function
                (e.g., "sum", "avg", "min", "max", "count").

        Returns:
            :class:`Column` expression for the aggregation.

        Raises:
            ValueError: If the function name is not recognized.
        """
        from ..helpers.groupby_helpers import create_aggregation_from_string

        return create_aggregation_from_string(column_name, func_name)

    def pivot(
        self, pivot_col: str, values: Optional[Sequence[str]] = None
    ) -> "PivotedGroupedDataFrame":
        """Pivot the grouped data on a column.

        Args:
            pivot_col: Name of the column to pivot on (its values become
                column headers).
            values: Optional list of specific values to pivot. When omitted
                or empty, ``pivot_values`` on the returned object is ``None``.

        Returns:
            PivotedGroupedDataFrame that can be aggregated.

        Example:
            >>> df.group_by("category").pivot("status").agg("amount")
            >>> df.group_by("category").pivot("status", values=["active", "inactive"]).agg("amount")
        """
        return PivotedGroupedDataFrame(
            plan=self.plan,
            keys=self.keys,
            pivot_column=pivot_col,
            # A falsy ``values`` (None or an empty sequence) is stored as None.
            pivot_values=tuple(values) if values else None,
            parent=self.parent,
        )

    @staticmethod
    def _validate_aggregation(expr: Column) -> Column:
        """Validate that an expression is a valid aggregation.

        An expression is accepted when its ``op`` attribute starts with
        ``"agg_"``.

        Args:
            expr: :class:`Column` expression to validate.

        Returns:
            The same expression, unchanged.

        Raises:
            ValueError: If the expression is not a valid aggregation.
        """
        if not expr.op.startswith("agg_"):
            raise ValueError(
                "Aggregation expressions must be created with moltres aggregate helpers "
                "(e.g., sum(), avg(), count(), min(), max())"
            )
        return expr
@dataclass(frozen=True)
class PivotedGroupedDataFrame:
    """Represents a :class:`DataFrame` grouped by columns with a pivot applied.

    Instances are returned by ``GroupedDataFrame.pivot()`` and provide
    aggregation methods that create pivoted columns.

    Attributes:
        plan: Logical plan the grouped pivot is applied to.
        keys: Grouping column expressions.
        pivot_column: Name of the column whose values become column headers.
        pivot_values: Explicit pivot values, or ``None`` to infer them from
            the data when :meth:`agg` is called.
        parent: :class:`DataFrame` whose ``database`` handle is reused for
            queries and results built by this object.
    """

    plan: LogicalPlan
    keys: tuple[Column, ...]
    pivot_column: str
    pivot_values: Optional[tuple[str, ...]]
    parent: DataFrame

    def agg(self, *aggregations: Union[Column, str, Dict[str, str]]) -> DataFrame:
        """Apply an aggregation function to the pivoted grouped data.

        Args:
            *aggregations: Aggregation expressions; after normalization
                exactly one is supported. Can be:

                - :class:`Column` expressions (e.g., sum(col("amount")))
                - String column names (e.g., "amount" - defaults to sum())
                - Dictionary mapping column names to aggregation functions
                  (e.g., {"amount": "sum"})

        Returns:
            :class:`DataFrame` with pivoted aggregated results.

        Raises:
            ValueError: If no aggregations are provided, if more than one
                aggregation results from normalization, if an invalid
                aggregation expression is used, or if pivot values must be
                inferred and the pivot column has no non-null values.

        Example:
            >>> from moltres import col
            >>> from moltres.expressions import functions as F
            >>> # Using string column name
            >>> df.group_by("category").pivot("status").agg("amount")
            >>> # Using Column expression
            >>> df.group_by("category").pivot("status").agg(F.sum(col("amount")))
            >>> # With specific pivot values
            >>> df.group_by("category").pivot("status", values=["active", "inactive"]).agg("amount")
        """
        if not aggregations:
            raise ValueError("agg requires at least one aggregation expression")

        # Imported at call time rather than at module level, as in the
        # original code (presumably to avoid an import cycle -- TODO confirm).
        from ..helpers.groupby_helpers import (
            normalize_aggregations,
            validate_aggregation,
            extract_value_column,
            extract_agg_func,
        )

        # Normalize all aggregations to Column expressions.
        normalized_aggs = normalize_aggregations(aggregations, alias_with_column_name=False)

        # Guard the ``[0]`` lookup below: should normalization ever yield
        # nothing (whether the helper permits that is not visible here), fail
        # with the documented ValueError rather than an IndexError.
        if not normalized_aggs:
            raise ValueError("agg requires at least one aggregation expression")

        # For pivoted grouped data, only one column can be aggregated at a
        # time (PySpark behavior - pivot with multiple aggregations requires
        # different syntax).
        if len(normalized_aggs) > 1:
            raise ValueError(
                "Pivoted grouped aggregation supports only one aggregation expression. "
                "Multiple aggregations are not supported with pivot."
            )

        agg_expr = normalized_aggs[0]
        validate_aggregation(agg_expr)

        # For sum(col("amount")) the value column is "amount" and the
        # aggregation function name is "sum".
        value_column = extract_value_column(agg_expr)
        agg_func = extract_agg_func(agg_expr)

        plan = operators.grouped_pivot(
            self.plan,
            grouping=self.keys,
            pivot_column=self.pivot_column,
            value_column=value_column,
            agg_func=agg_func,
            pivot_values=self._resolve_pivot_values(),
        )
        return DataFrame(plan=plan, database=self.parent.database)

    def _resolve_pivot_values(self) -> tuple[str, ...]:
        """Return the explicit pivot values, or infer them from the data.

        When ``pivot_values`` is ``None`` this runs a query that collects the
        distinct values of the pivot column, drops ``None`` and converts the
        rest with ``str()``. The resulting order is the order in which the
        distinct rows are returned by the query.

        Raises:
            ValueError: If inference finds no non-null value.
        """
        if self.pivot_values is not None:
            return self.pivot_values

        # Query distinct values from the pivot column (PySpark behavior).
        source_df = DataFrame(plan=self.plan, database=self.parent.database)
        distinct_rows = source_df.select(col(self.pivot_column)).distinct().collect()
        inferred = tuple(
            str(row[self.pivot_column])
            for row in distinct_rows
            if row[self.pivot_column] is not None
        )
        if not inferred:
            raise ValueError(
                f"No distinct values found in pivot column '{self.pivot_column}'. "
                "Please provide pivot_values explicitly."
            )
        return inferred

    @staticmethod
    def _extract_value_column(agg_expr: Column) -> str:
        """Extract the column name from an aggregation expression.

        Thin wrapper around ``helpers.groupby_helpers.extract_value_column``.

        Args:
            agg_expr: Aggregation :class:`Column` expression
                (e.g., sum(col("amount"))).

        Returns:
            Column name string (e.g., "amount").

        Raises:
            ValueError: If the column cannot be extracted.
        """
        from ..helpers.groupby_helpers import extract_value_column

        return extract_value_column(agg_expr)

    @staticmethod
    def _extract_agg_func(agg_expr: Column) -> str:
        """Extract the aggregation function name from an aggregation expression.

        Thin wrapper around ``helpers.groupby_helpers.extract_agg_func``.

        Args:
            agg_expr: Aggregation :class:`Column` expression
                (e.g., sum(col("amount"))).

        Returns:
            Aggregation function name (e.g., "sum").
        """
        from ..helpers.groupby_helpers import extract_agg_func

        return extract_agg_func(agg_expr)

    @staticmethod
    def _create_aggregation_from_string(column_name: str, func_name: str) -> Column:
        """Create an aggregation :class:`Column` from a column and function name.

        Thin wrapper around
        ``helpers.groupby_helpers.create_aggregation_from_string``.

        Args:
            column_name: Name of the column to aggregate.
            func_name: Name of the aggregation function
                (e.g., "sum", "avg", "min", "max", "count").

        Returns:
            :class:`Column` expression for the aggregation.

        Raises:
            ValueError: If the function name is not recognized.
        """
        from ..helpers.groupby_helpers import create_aggregation_from_string

        return create_aggregation_from_string(column_name, func_name)

    @staticmethod
    def _validate_aggregation(expr: Column) -> Column:
        """Validate that an expression is a valid aggregation.

        Thin wrapper around ``helpers.groupby_helpers.validate_aggregation``.

        Args:
            expr: :class:`Column` expression to validate.

        Returns:
            The validated column expression.

        Raises:
            ValueError: If the expression is not a valid aggregation.
        """
        from ..helpers.groupby_helpers import validate_aggregation

        return validate_aggregation(expr)