Mise à jour de Monitor.py et autres scripts

This commit is contained in:
Debian
2025-07-23 10:46:27 +02:00
parent 7081418ce0
commit 7de3e0fb50
8604 changed files with 2789953 additions and 295 deletions

View File

@@ -0,0 +1,111 @@
from __future__ import annotations
from narwhals._compliant.dataframe import (
CompliantDataFrame,
CompliantLazyFrame,
EagerDataFrame,
)
from narwhals._compliant.expr import (
CompliantExpr,
DepthTrackingExpr,
EagerExpr,
LazyExpr,
LazyExprNamespace,
)
from narwhals._compliant.group_by import (
CompliantGroupBy,
DepthTrackingGroupBy,
EagerGroupBy,
LazyGroupBy,
)
from narwhals._compliant.namespace import (
CompliantNamespace,
DepthTrackingNamespace,
EagerNamespace,
LazyNamespace,
)
from narwhals._compliant.selectors import (
CompliantSelector,
CompliantSelectorNamespace,
EagerSelectorNamespace,
LazySelectorNamespace,
)
from narwhals._compliant.series import (
CompliantSeries,
EagerSeries,
EagerSeriesCatNamespace,
EagerSeriesDateTimeNamespace,
EagerSeriesListNamespace,
EagerSeriesNamespace,
EagerSeriesStringNamespace,
EagerSeriesStructNamespace,
)
from narwhals._compliant.typing import (
CompliantExprT,
CompliantFrameT,
CompliantSeriesOrNativeExprT_co,
CompliantSeriesT,
EagerDataFrameT,
EagerSeriesT,
EvalNames,
EvalSeries,
IntoCompliantExpr,
NativeFrameT_co,
NativeSeriesT_co,
)
from narwhals._compliant.when_then import (
CompliantThen,
CompliantWhen,
EagerWhen,
LazyThen,
LazyWhen,
)
from narwhals._compliant.window import WindowInputs
__all__ = [
"CompliantDataFrame",
"CompliantExpr",
"CompliantExprT",
"CompliantFrameT",
"CompliantGroupBy",
"CompliantLazyFrame",
"CompliantNamespace",
"CompliantSelector",
"CompliantSelectorNamespace",
"CompliantSeries",
"CompliantSeriesOrNativeExprT_co",
"CompliantSeriesT",
"CompliantThen",
"CompliantWhen",
"DepthTrackingExpr",
"DepthTrackingGroupBy",
"DepthTrackingNamespace",
"EagerDataFrame",
"EagerDataFrameT",
"EagerExpr",
"EagerGroupBy",
"EagerNamespace",
"EagerSelectorNamespace",
"EagerSeries",
"EagerSeriesCatNamespace",
"EagerSeriesDateTimeNamespace",
"EagerSeriesListNamespace",
"EagerSeriesNamespace",
"EagerSeriesStringNamespace",
"EagerSeriesStructNamespace",
"EagerSeriesT",
"EagerWhen",
"EvalNames",
"EvalSeries",
"IntoCompliantExpr",
"LazyExpr",
"LazyExprNamespace",
"LazyGroupBy",
"LazyNamespace",
"LazySelectorNamespace",
"LazyThen",
"LazyWhen",
"NativeFrameT_co",
"NativeSeriesT_co",
"WindowInputs",
]

View File

@@ -0,0 +1,89 @@
"""`Expr` and `Series` namespace accessor protocols."""
from __future__ import annotations
from typing import TYPE_CHECKING, Protocol
from narwhals._utils import CompliantT_co, _StoresCompliant
if TYPE_CHECKING:
from typing import Callable
from narwhals.typing import TimeUnit
__all__ = [
"CatNamespace",
"DateTimeNamespace",
"ListNamespace",
"NameNamespace",
"StringNamespace",
"StructNamespace",
]
class CatNamespace(_StoresCompliant[CompliantT_co], Protocol[CompliantT_co]):
def get_categories(self) -> CompliantT_co: ...
class DateTimeNamespace(_StoresCompliant[CompliantT_co], Protocol[CompliantT_co]):
def to_string(self, format: str) -> CompliantT_co: ...
def replace_time_zone(self, time_zone: str | None) -> CompliantT_co: ...
def convert_time_zone(self, time_zone: str) -> CompliantT_co: ...
def timestamp(self, time_unit: TimeUnit) -> CompliantT_co: ...
def date(self) -> CompliantT_co: ...
def year(self) -> CompliantT_co: ...
def month(self) -> CompliantT_co: ...
def day(self) -> CompliantT_co: ...
def hour(self) -> CompliantT_co: ...
def minute(self) -> CompliantT_co: ...
def second(self) -> CompliantT_co: ...
def millisecond(self) -> CompliantT_co: ...
def microsecond(self) -> CompliantT_co: ...
def nanosecond(self) -> CompliantT_co: ...
def ordinal_day(self) -> CompliantT_co: ...
def weekday(self) -> CompliantT_co: ...
def total_minutes(self) -> CompliantT_co: ...
def total_seconds(self) -> CompliantT_co: ...
def total_milliseconds(self) -> CompliantT_co: ...
def total_microseconds(self) -> CompliantT_co: ...
def total_nanoseconds(self) -> CompliantT_co: ...
def truncate(self, every: str) -> CompliantT_co: ...
def offset_by(self, by: str) -> CompliantT_co: ...
class ListNamespace(_StoresCompliant[CompliantT_co], Protocol[CompliantT_co]):
def len(self) -> CompliantT_co: ...
class NameNamespace(_StoresCompliant[CompliantT_co], Protocol[CompliantT_co]):
def keep(self) -> CompliantT_co: ...
def map(self, function: Callable[[str], str]) -> CompliantT_co: ...
def prefix(self, prefix: str) -> CompliantT_co: ...
def suffix(self, suffix: str) -> CompliantT_co: ...
def to_lowercase(self) -> CompliantT_co: ...
def to_uppercase(self) -> CompliantT_co: ...
class StringNamespace(_StoresCompliant[CompliantT_co], Protocol[CompliantT_co]):
def len_chars(self) -> CompliantT_co: ...
def replace(
self, pattern: str, value: str, *, literal: bool, n: int
) -> CompliantT_co: ...
def replace_all(
self, pattern: str, value: str, *, literal: bool
) -> CompliantT_co: ...
def strip_chars(self, characters: str | None) -> CompliantT_co: ...
def starts_with(self, prefix: str) -> CompliantT_co: ...
def ends_with(self, suffix: str) -> CompliantT_co: ...
def contains(self, pattern: str, *, literal: bool) -> CompliantT_co: ...
def slice(self, offset: int, length: int | None) -> CompliantT_co: ...
def split(self, by: str) -> CompliantT_co: ...
def to_datetime(self, format: str | None) -> CompliantT_co: ...
def to_date(self, format: str | None) -> CompliantT_co: ...
def to_lowercase(self) -> CompliantT_co: ...
def to_uppercase(self) -> CompliantT_co: ...
def zfill(self, width: int) -> CompliantT_co: ...
class StructNamespace(_StoresCompliant[CompliantT_co], Protocol[CompliantT_co]):
def field(self, name: str) -> CompliantT_co: ...

View File

@@ -0,0 +1,501 @@
from __future__ import annotations
from collections.abc import Iterator, Mapping, Sequence, Sized
from itertools import chain
from typing import TYPE_CHECKING, Any, Literal, Protocol, TypeVar, overload
from narwhals._compliant.typing import (
CompliantDataFrameAny,
CompliantExprT_contra,
CompliantLazyFrameAny,
CompliantSeriesT,
EagerExprT,
EagerSeriesT,
NativeExprT,
NativeFrameT,
NativeSeriesT,
)
from narwhals._translate import (
ArrowConvertible,
DictConvertible,
FromNative,
NumpyConvertible,
ToNarwhals,
ToNarwhalsT_co,
)
from narwhals._typing_compat import assert_never, deprecated
from narwhals._utils import (
ValidateBackendVersion,
Version,
_StoresNative,
check_columns_exist,
is_compliant_series,
is_index_selector,
is_range,
is_sequence_like,
is_sized_multi_index_selector,
is_slice_index,
is_slice_none,
)
if TYPE_CHECKING:
from io import BytesIO
from pathlib import Path
import pandas as pd
import polars as pl
import pyarrow as pa
from typing_extensions import Self, TypeAlias
from narwhals._compliant.expr import LazyExpr
from narwhals._compliant.group_by import CompliantGroupBy, DataFrameGroupBy
from narwhals._compliant.namespace import EagerNamespace
from narwhals._compliant.window import WindowInputs
from narwhals._translate import IntoArrowTable
from narwhals._utils import Implementation, _LimitedContext
from narwhals.dataframe import DataFrame
from narwhals.dtypes import DType
from narwhals.exceptions import ColumnNotFoundError
from narwhals.schema import Schema
from narwhals.typing import (
AsofJoinStrategy,
JoinStrategy,
LazyUniqueKeepStrategy,
MultiColSelector,
MultiIndexSelector,
PivotAgg,
SingleIndexSelector,
SizedMultiIndexSelector,
SizedMultiNameSelector,
SizeUnit,
UniqueKeepStrategy,
_2DArray,
_SliceIndex,
_SliceName,
)
Incomplete: TypeAlias = Any
__all__ = ["CompliantDataFrame", "CompliantLazyFrame", "EagerDataFrame"]
T = TypeVar("T")
_ToDict: TypeAlias = "dict[str, CompliantSeriesT] | dict[str, list[Any]]" # noqa: PYI047
class CompliantDataFrame(
NumpyConvertible["_2DArray", "_2DArray"],
DictConvertible["_ToDict[CompliantSeriesT]", Mapping[str, Any]],
ArrowConvertible["pa.Table", "IntoArrowTable"],
_StoresNative[NativeFrameT],
FromNative[NativeFrameT],
ToNarwhals[ToNarwhalsT_co],
Sized,
Protocol[CompliantSeriesT, CompliantExprT_contra, NativeFrameT, ToNarwhalsT_co],
):
_native_frame: NativeFrameT
_implementation: Implementation
_version: Version
def __narwhals_dataframe__(self) -> Self: ...
def __narwhals_namespace__(self) -> Any: ...
@classmethod
def from_arrow(cls, data: IntoArrowTable, /, *, context: _LimitedContext) -> Self: ...
@classmethod
def from_dict(
cls,
data: Mapping[str, Any],
/,
*,
context: _LimitedContext,
schema: Mapping[str, DType] | Schema | None,
) -> Self: ...
@classmethod
def from_native(cls, data: NativeFrameT, /, *, context: _LimitedContext) -> Self: ...
@classmethod
def from_numpy(
cls,
data: _2DArray,
/,
*,
context: _LimitedContext,
schema: Mapping[str, DType] | Schema | Sequence[str] | None,
) -> Self: ...
def __array__(self, dtype: Any, *, copy: bool | None) -> _2DArray: ...
def __getitem__(
self,
item: tuple[
SingleIndexSelector | MultiIndexSelector[CompliantSeriesT],
MultiColSelector[CompliantSeriesT],
],
) -> Self: ...
def simple_select(self, *column_names: str) -> Self:
"""`select` where all args are column names."""
...
def aggregate(self, *exprs: CompliantExprT_contra) -> Self:
"""`select` where all args are aggregations or literals.
(so, no broadcasting is necessary).
"""
# NOTE: Ignore is to avoid an intermittent false positive
return self.select(*exprs) # pyright: ignore[reportArgumentType]
def _with_version(self, version: Version) -> Self: ...
@property
def native(self) -> NativeFrameT:
return self._native_frame
@property
def columns(self) -> Sequence[str]: ...
@property
def schema(self) -> Mapping[str, DType]: ...
@property
def shape(self) -> tuple[int, int]: ...
def clone(self) -> Self: ...
def collect(
self, backend: Implementation | None, **kwargs: Any
) -> CompliantDataFrameAny: ...
def collect_schema(self) -> Mapping[str, DType]: ...
def drop(self, columns: Sequence[str], *, strict: bool) -> Self: ...
def drop_nulls(self, subset: Sequence[str] | None) -> Self: ...
def estimated_size(self, unit: SizeUnit) -> int | float: ...
def explode(self, columns: Sequence[str]) -> Self: ...
def filter(self, predicate: CompliantExprT_contra | Incomplete) -> Self: ...
def gather_every(self, n: int, offset: int) -> Self: ...
def get_column(self, name: str) -> CompliantSeriesT: ...
def group_by(
self,
keys: Sequence[str] | Sequence[CompliantExprT_contra],
*,
drop_null_keys: bool,
) -> DataFrameGroupBy[Self, Any]: ...
def head(self, n: int) -> Self: ...
def item(self, row: int | None, column: int | str | None) -> Any: ...
def iter_columns(self) -> Iterator[CompliantSeriesT]: ...
def iter_rows(
self, *, named: bool, buffer_size: int
) -> Iterator[tuple[Any, ...]] | Iterator[Mapping[str, Any]]: ...
def is_unique(self) -> CompliantSeriesT: ...
def join(
self,
other: Self,
*,
how: JoinStrategy,
left_on: Sequence[str] | None,
right_on: Sequence[str] | None,
suffix: str,
) -> Self: ...
def join_asof(
self,
other: Self,
*,
left_on: str,
right_on: str,
by_left: Sequence[str] | None,
by_right: Sequence[str] | None,
strategy: AsofJoinStrategy,
suffix: str,
) -> Self: ...
def lazy(self, *, backend: Implementation | None) -> CompliantLazyFrameAny: ...
def pivot(
self,
on: Sequence[str],
*,
index: Sequence[str] | None,
values: Sequence[str] | None,
aggregate_function: PivotAgg | None,
sort_columns: bool,
separator: str,
) -> Self: ...
def rename(self, mapping: Mapping[str, str]) -> Self: ...
def row(self, index: int) -> tuple[Any, ...]: ...
def rows(
self, *, named: bool
) -> Sequence[tuple[Any, ...]] | Sequence[Mapping[str, Any]]: ...
def sample(
self,
n: int | None,
*,
fraction: float | None,
with_replacement: bool,
seed: int | None,
) -> Self: ...
def select(self, *exprs: CompliantExprT_contra) -> Self: ...
def sort(
self, *by: str, descending: bool | Sequence[bool], nulls_last: bool
) -> Self: ...
def tail(self, n: int) -> Self: ...
def to_arrow(self) -> pa.Table: ...
def to_pandas(self) -> pd.DataFrame: ...
def to_polars(self) -> pl.DataFrame: ...
@overload
def to_dict(self, *, as_series: Literal[True]) -> dict[str, CompliantSeriesT]: ...
@overload
def to_dict(self, *, as_series: Literal[False]) -> dict[str, list[Any]]: ...
def to_dict(
self, *, as_series: bool
) -> dict[str, CompliantSeriesT] | dict[str, list[Any]]: ...
def unique(
self,
subset: Sequence[str] | None,
*,
keep: UniqueKeepStrategy,
maintain_order: bool | None = None,
) -> Self: ...
def unpivot(
self,
on: Sequence[str] | None,
index: Sequence[str] | None,
variable_name: str,
value_name: str,
) -> Self: ...
def with_columns(self, *exprs: CompliantExprT_contra) -> Self: ...
def with_row_index(self, name: str, order_by: Sequence[str] | None) -> Self: ...
@overload
def write_csv(self, file: None) -> str: ...
@overload
def write_csv(self, file: str | Path | BytesIO) -> None: ...
def write_csv(self, file: str | Path | BytesIO | None) -> str | None: ...
def write_parquet(self, file: str | Path | BytesIO) -> None: ...
def _evaluate_aliases(self, *exprs: CompliantExprT_contra) -> list[str]:
it = (expr._evaluate_aliases(self) for expr in exprs)
return list(chain.from_iterable(it))
def _check_columns_exist(self, subset: Sequence[str]) -> ColumnNotFoundError | None:
return check_columns_exist(subset, available=self.columns)
class CompliantLazyFrame(
_StoresNative[NativeFrameT],
FromNative[NativeFrameT],
ToNarwhals[ToNarwhalsT_co],
Protocol[CompliantExprT_contra, NativeFrameT, ToNarwhalsT_co],
):
_native_frame: NativeFrameT
_implementation: Implementation
_version: Version
def __narwhals_lazyframe__(self) -> Self: ...
def __narwhals_namespace__(self) -> Any: ...
@classmethod
def from_native(cls, data: NativeFrameT, /, *, context: _LimitedContext) -> Self: ...
def simple_select(self, *column_names: str) -> Self:
"""`select` where all args are column names."""
...
def aggregate(self, *exprs: CompliantExprT_contra) -> Self:
"""`select` where all args are aggregations or literals.
(so, no broadcasting is necessary).
"""
...
def _with_version(self, version: Version) -> Self: ...
@property
def native(self) -> NativeFrameT:
return self._native_frame
@property
def columns(self) -> Sequence[str]: ...
@property
def schema(self) -> Mapping[str, DType]: ...
def _iter_columns(self) -> Iterator[Any]: ...
def collect(
self, backend: Implementation | None, **kwargs: Any
) -> CompliantDataFrameAny: ...
def collect_schema(self) -> Mapping[str, DType]: ...
def drop(self, columns: Sequence[str], *, strict: bool) -> Self: ...
def drop_nulls(self, subset: Sequence[str] | None) -> Self: ...
def explode(self, columns: Sequence[str]) -> Self: ...
def filter(self, predicate: CompliantExprT_contra | Incomplete) -> Self: ...
@deprecated(
"`LazyFrame.gather_every` is deprecated and will be removed in a future version."
)
def gather_every(self, n: int, offset: int) -> Self: ...
def group_by(
self,
keys: Sequence[str] | Sequence[CompliantExprT_contra],
*,
drop_null_keys: bool,
) -> CompliantGroupBy[Self, CompliantExprT_contra]: ...
def head(self, n: int) -> Self: ...
def join(
self,
other: Self,
*,
how: JoinStrategy,
left_on: Sequence[str] | None,
right_on: Sequence[str] | None,
suffix: str,
) -> Self: ...
def join_asof(
self,
other: Self,
*,
left_on: str,
right_on: str,
by_left: Sequence[str] | None,
by_right: Sequence[str] | None,
strategy: AsofJoinStrategy,
suffix: str,
) -> Self: ...
def rename(self, mapping: Mapping[str, str]) -> Self: ...
def select(self, *exprs: CompliantExprT_contra) -> Self: ...
def sort(
self, *by: str, descending: bool | Sequence[bool], nulls_last: bool
) -> Self: ...
@deprecated("`LazyFrame.tail` is deprecated and will be removed in a future version.")
def tail(self, n: int) -> Self: ...
def unique(
self, subset: Sequence[str] | None, *, keep: LazyUniqueKeepStrategy
) -> Self: ...
def unpivot(
self,
on: Sequence[str] | None,
index: Sequence[str] | None,
variable_name: str,
value_name: str,
) -> Self: ...
def with_columns(self, *exprs: CompliantExprT_contra) -> Self: ...
def with_row_index(self, name: str, order_by: Sequence[str]) -> Self: ...
def _evaluate_expr(self, expr: CompliantExprT_contra, /) -> Any:
result = expr(self)
assert len(result) == 1 # debug assertion # noqa: S101
return result[0]
def _evaluate_window_expr(
self,
expr: LazyExpr[Self, NativeExprT],
/,
window_inputs: WindowInputs[NativeExprT],
) -> NativeExprT:
result = expr.window_function(self, window_inputs)
assert len(result) == 1 # debug assertion # noqa: S101
return result[0]
def _evaluate_aliases(self, *exprs: CompliantExprT_contra) -> list[str]:
it = (expr._evaluate_aliases(self) for expr in exprs)
return list(chain.from_iterable(it))
def _check_columns_exist(self, subset: Sequence[str]) -> ColumnNotFoundError | None:
return check_columns_exist(subset, available=self.columns)
class EagerDataFrame(
CompliantDataFrame[EagerSeriesT, EagerExprT, NativeFrameT, "DataFrame[NativeFrameT]"],
CompliantLazyFrame[EagerExprT, NativeFrameT, "DataFrame[NativeFrameT]"],
ValidateBackendVersion,
Protocol[EagerSeriesT, EagerExprT, NativeFrameT, NativeSeriesT],
):
@property
def _backend_version(self) -> tuple[int, ...]:
return self._implementation._backend_version()
def __narwhals_namespace__(
self,
) -> EagerNamespace[Self, EagerSeriesT, EagerExprT, NativeFrameT, NativeSeriesT]: ...
def to_narwhals(self) -> DataFrame[NativeFrameT]:
return self._version.dataframe(self, level="full")
def _with_native(
self, df: NativeFrameT, *, validate_column_names: bool = True
) -> Self: ...
def _evaluate_expr(self, expr: EagerExprT, /) -> EagerSeriesT:
"""Evaluate `expr` and ensure it has a **single** output."""
result: Sequence[EagerSeriesT] = expr(self)
assert len(result) == 1 # debug assertion # noqa: S101
return result[0]
def _evaluate_into_exprs(self, *exprs: EagerExprT) -> Sequence[EagerSeriesT]:
# NOTE: Ignore is to avoid an intermittent false positive
return list(chain.from_iterable(self._evaluate_into_expr(expr) for expr in exprs)) # pyright: ignore[reportArgumentType]
def _evaluate_into_expr(self, expr: EagerExprT, /) -> Sequence[EagerSeriesT]:
"""Return list of raw columns.
For eager backends we alias operations at each step.
As a safety precaution, here we can check that the expected result names match those
we were expecting from the various `evaluate_output_names` / `alias_output_names` calls.
Note that for PySpark / DuckDB, we are less free to liberally set aliases whenever we want.
"""
aliases = expr._evaluate_aliases(self)
result = expr(self)
if list(aliases) != (
result_aliases := [s.name for s in result]
): # pragma: no cover
msg = f"Safety assertion failed, expected {aliases}, got {result_aliases}"
raise AssertionError(msg)
return result
def _extract_comparand(self, other: EagerSeriesT, /) -> Any:
"""Extract native Series, broadcasting to `len(self)` if necessary."""
...
@staticmethod
def _numpy_column_names(
data: _2DArray, columns: Sequence[str] | None, /
) -> list[str]:
return list(columns or (f"column_{x}" for x in range(data.shape[1])))
def _gather(self, rows: SizedMultiIndexSelector[NativeSeriesT]) -> Self: ...
def _gather_slice(self, rows: _SliceIndex | range) -> Self: ...
def _select_multi_index(
self, columns: SizedMultiIndexSelector[NativeSeriesT]
) -> Self: ...
def _select_multi_name(
self, columns: SizedMultiNameSelector[NativeSeriesT]
) -> Self: ...
def _select_slice_index(self, columns: _SliceIndex | range) -> Self: ...
def _select_slice_name(self, columns: _SliceName) -> Self: ...
def __getitem__( # noqa: C901, PLR0912
self,
item: tuple[
SingleIndexSelector | MultiIndexSelector[EagerSeriesT],
MultiColSelector[EagerSeriesT],
],
) -> Self:
rows, columns = item
compliant = self
if not is_slice_none(columns):
if isinstance(columns, Sized) and len(columns) == 0:
return compliant.select()
if is_index_selector(columns):
if is_slice_index(columns) or is_range(columns):
compliant = compliant._select_slice_index(columns)
elif is_compliant_series(columns):
compliant = self._select_multi_index(columns.native)
else:
compliant = compliant._select_multi_index(columns)
elif isinstance(columns, slice):
compliant = compliant._select_slice_name(columns)
elif is_compliant_series(columns):
compliant = self._select_multi_name(columns.native)
elif is_sequence_like(columns):
compliant = self._select_multi_name(columns)
else:
assert_never(columns)
if not is_slice_none(rows):
if isinstance(rows, int):
compliant = compliant._gather([rows])
elif isinstance(rows, (slice, range)):
compliant = compliant._gather_slice(rows)
elif is_compliant_series(rows):
compliant = compliant._gather(rows.native)
elif is_sized_multi_index_selector(rows):
compliant = compliant._gather(rows)
else:
assert_never(rows)
return compliant

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,214 @@
from __future__ import annotations
import re
from typing import TYPE_CHECKING, Any, Callable, ClassVar, TypeVar
from narwhals._compliant.typing import (
CompliantDataFrameAny,
CompliantDataFrameT_co,
CompliantExprT_contra,
CompliantFrameT_co,
CompliantLazyFrameAny,
CompliantLazyFrameT_co,
DepthTrackingExprAny,
DepthTrackingExprT_contra,
EagerExprT_contra,
LazyExprT_contra,
NarwhalsAggregation,
NativeExprT_co,
)
from narwhals._typing_compat import Protocol38
from narwhals._utils import is_sequence_of
if TYPE_CHECKING:
from collections.abc import Iterable, Iterator, Mapping, Sequence
_SameFrameT = TypeVar("_SameFrameT", CompliantDataFrameAny, CompliantLazyFrameAny)
__all__ = ["CompliantGroupBy", "DepthTrackingGroupBy", "EagerGroupBy", "LazyGroupBy"]
NativeAggregationT_co = TypeVar(
"NativeAggregationT_co", bound="str | Callable[..., Any]", covariant=True
)
_RE_LEAF_NAME: re.Pattern[str] = re.compile(r"(\w+->)")
class CompliantGroupBy(Protocol38[CompliantFrameT_co, CompliantExprT_contra]):
_compliant_frame: Any
@property
def compliant(self) -> CompliantFrameT_co:
return self._compliant_frame # type: ignore[no-any-return]
def __init__(
self,
compliant_frame: CompliantFrameT_co,
keys: Sequence[CompliantExprT_contra] | Sequence[str],
/,
*,
drop_null_keys: bool,
) -> None: ...
def agg(self, *exprs: CompliantExprT_contra) -> CompliantFrameT_co: ...
class DataFrameGroupBy(
CompliantGroupBy[CompliantDataFrameT_co, CompliantExprT_contra],
Protocol38[CompliantDataFrameT_co, CompliantExprT_contra],
):
def __iter__(self) -> Iterator[tuple[Any, CompliantDataFrameT_co]]: ...
class ParseKeysGroupBy(
CompliantGroupBy[CompliantFrameT_co, CompliantExprT_contra],
Protocol38[CompliantFrameT_co, CompliantExprT_contra],
):
def _parse_keys(
self,
compliant_frame: _SameFrameT,
keys: Sequence[CompliantExprT_contra] | Sequence[str],
) -> tuple[_SameFrameT, list[str], list[str]]:
if is_sequence_of(keys, str):
keys_str = list(keys)
return compliant_frame, keys_str, keys_str.copy()
else:
return self._parse_expr_keys(compliant_frame, keys=keys)
@staticmethod
def _parse_expr_keys(
compliant_frame: _SameFrameT, keys: Sequence[CompliantExprT_contra]
) -> tuple[_SameFrameT, list[str], list[str]]:
"""Parses key expressions to set up `.agg` operation with correct information.
Since keys are expressions, it's possible to alias any such key to match
other dataframe column names.
In order to match polars behavior and not overwrite columns when evaluating keys:
- We evaluate what the output key names should be, in order to remap temporary column
names to the expected ones, and to exclude those from unnamed expressions in
`.agg(...)` context (see https://github.com/narwhals-dev/narwhals/pull/2325#issuecomment-2800004520)
- Create temporary names for evaluated key expressions that are guaranteed to have
no overlap with any existing column name.
- Add these temporary columns to the compliant dataframe.
"""
tmp_name_length = max(len(str(c)) for c in compliant_frame.columns) + 1
def _temporary_name(key: str) -> str:
# 5 is the length of `__tmp`
key_str = str(key) # pandas allows non-string column names :sob:
return f"_{key_str}_tmp{'_' * (tmp_name_length - len(key_str) - 5)}"
output_names = compliant_frame._evaluate_aliases(*keys)
safe_keys = [
# multi-output expression cannot have duplicate names, hence it's safe to suffix
key.name.map(_temporary_name)
if (metadata := key._metadata) and metadata.expansion_kind.is_multi_output()
# otherwise it's single named and we can use Expr.alias
else key.alias(_temporary_name(new_name))
for key, new_name in zip(keys, output_names)
]
return (
compliant_frame.with_columns(*safe_keys),
compliant_frame._evaluate_aliases(*safe_keys),
output_names,
)
class DepthTrackingGroupBy(
ParseKeysGroupBy[CompliantFrameT_co, DepthTrackingExprT_contra],
Protocol38[CompliantFrameT_co, DepthTrackingExprT_contra, NativeAggregationT_co],
):
"""`CompliantGroupBy` variant, deals with `Eager` and other backends that utilize `CompliantExpr._depth`."""
_REMAP_AGGS: ClassVar[Mapping[NarwhalsAggregation, Any]]
"""Mapping from `narwhals` to native representation.
Note:
- `Dask` *may* return a `Callable` instead of a `str` referring to one.
"""
def _ensure_all_simple(self, exprs: Sequence[DepthTrackingExprT_contra]) -> None:
for expr in exprs:
if not self._is_simple(expr):
name = self.compliant._implementation.name.lower()
msg = (
f"Non-trivial complex aggregation found.\n\n"
f"Hint: you were probably trying to apply a non-elementary aggregation with a"
f"{name!r} table.\n"
"Please rewrite your query such that group-by aggregations "
"are elementary. For example, instead of:\n\n"
" df.group_by('a').agg(nw.col('b').round(2).mean())\n\n"
"use:\n\n"
" df.with_columns(nw.col('b').round(2)).group_by('a').agg(nw.col('b').mean())\n\n"
)
raise ValueError(msg)
@classmethod
def _is_simple(cls, expr: DepthTrackingExprAny, /) -> bool:
"""Return `True` is we can efficiently use `expr` in a native `group_by` context."""
return expr._is_elementary() and cls._leaf_name(expr) in cls._REMAP_AGGS
@classmethod
def _remap_expr_name(
cls, name: NarwhalsAggregation | Any, /
) -> NativeAggregationT_co:
"""Replace `name`, with some native representation.
Arguments:
name: Name of a `nw.Expr` aggregation method.
Returns:
A native compatible representation.
"""
return cls._REMAP_AGGS.get(name, name)
@classmethod
def _leaf_name(cls, expr: DepthTrackingExprAny, /) -> NarwhalsAggregation | Any:
"""Return the last function name in the chain defined by `expr`."""
return _RE_LEAF_NAME.sub("", expr._function_name)
class EagerGroupBy(
DepthTrackingGroupBy[
CompliantDataFrameT_co, EagerExprT_contra, NativeAggregationT_co
],
DataFrameGroupBy[CompliantDataFrameT_co, EagerExprT_contra],
Protocol38[CompliantDataFrameT_co, EagerExprT_contra, NativeAggregationT_co],
): ...
class LazyGroupBy(
ParseKeysGroupBy[CompliantLazyFrameT_co, LazyExprT_contra],
CompliantGroupBy[CompliantLazyFrameT_co, LazyExprT_contra],
Protocol38[CompliantLazyFrameT_co, LazyExprT_contra, NativeExprT_co],
):
_keys: list[str]
_output_key_names: list[str]
def _evaluate_expr(self, expr: LazyExprT_contra, /) -> Iterator[NativeExprT_co]:
output_names = expr._evaluate_output_names(self.compliant)
aliases = (
expr._alias_output_names(output_names)
if expr._alias_output_names
else output_names
)
native_exprs = expr(self.compliant)
if expr._is_multi_output_unnamed():
exclude = {*self._keys, *self._output_key_names}
for native_expr, name, alias in zip(native_exprs, output_names, aliases):
if name not in exclude:
yield expr._alias_native(native_expr, alias)
else:
for native_expr, alias in zip(native_exprs, aliases):
yield expr._alias_native(native_expr, alias)
def _evaluate_exprs(
self, exprs: Iterable[LazyExprT_contra], /
) -> Iterator[NativeExprT_co]:
for expr in exprs:
yield from self._evaluate_expr(expr)

View File

@@ -0,0 +1,211 @@
from __future__ import annotations
from functools import partial
from typing import TYPE_CHECKING, Any, Protocol, overload
from narwhals._compliant.typing import (
CompliantExprT,
CompliantFrameT,
CompliantLazyFrameT,
DepthTrackingExprT,
EagerDataFrameT,
EagerExprT,
EagerSeriesT,
LazyExprT,
NativeFrameT,
NativeFrameT_co,
NativeSeriesT,
)
from narwhals._utils import (
exclude_column_names,
get_column_names,
passthrough_column_names,
)
from narwhals.dependencies import is_numpy_array_2d
if TYPE_CHECKING:
from collections.abc import Container, Iterable, Mapping, Sequence
from typing_extensions import TypeAlias
from narwhals._compliant.selectors import CompliantSelectorNamespace
from narwhals._compliant.when_then import CompliantWhen, EagerWhen
from narwhals._utils import Implementation, Version
from narwhals.dtypes import DType
from narwhals.schema import Schema
from narwhals.typing import (
ConcatMethod,
Into1DArray,
IntoDType,
NonNestedLiteral,
_2DArray,
)
Incomplete: TypeAlias = Any
__all__ = [
"CompliantNamespace",
"DepthTrackingNamespace",
"EagerNamespace",
"LazyNamespace",
]
class CompliantNamespace(Protocol[CompliantFrameT, CompliantExprT]):
_implementation: Implementation
_version: Version
def all(self) -> CompliantExprT:
return self._expr.from_column_names(get_column_names, context=self)
def col(self, *column_names: str) -> CompliantExprT:
return self._expr.from_column_names(
passthrough_column_names(column_names), context=self
)
def exclude(self, excluded_names: Container[str]) -> CompliantExprT:
return self._expr.from_column_names(
partial(exclude_column_names, names=excluded_names), context=self
)
def nth(self, *column_indices: int) -> CompliantExprT:
return self._expr.from_column_indices(*column_indices, context=self)
def len(self) -> CompliantExprT: ...
def lit(self, value: NonNestedLiteral, dtype: IntoDType | None) -> CompliantExprT: ...
def all_horizontal(
self, *exprs: CompliantExprT, ignore_nulls: bool
) -> CompliantExprT: ...
def any_horizontal(
self, *exprs: CompliantExprT, ignore_nulls: bool
) -> CompliantExprT: ...
def sum_horizontal(self, *exprs: CompliantExprT) -> CompliantExprT: ...
def mean_horizontal(self, *exprs: CompliantExprT) -> CompliantExprT: ...
def min_horizontal(self, *exprs: CompliantExprT) -> CompliantExprT: ...
def max_horizontal(self, *exprs: CompliantExprT) -> CompliantExprT: ...
def concat(
self, items: Iterable[CompliantFrameT], *, how: ConcatMethod
) -> CompliantFrameT: ...
def when(
self, predicate: CompliantExprT
) -> CompliantWhen[CompliantFrameT, Incomplete, CompliantExprT]: ...
def concat_str(
self, *exprs: CompliantExprT, separator: str, ignore_nulls: bool
) -> CompliantExprT: ...
@property
def selectors(self) -> CompliantSelectorNamespace[Any, Any]: ...
@property
def _expr(self) -> type[CompliantExprT]: ...
def coalesce(self, *exprs: CompliantExprT) -> CompliantExprT: ...
class DepthTrackingNamespace(
CompliantNamespace[CompliantFrameT, DepthTrackingExprT],
Protocol[CompliantFrameT, DepthTrackingExprT],
):
def all(self) -> DepthTrackingExprT:
return self._expr.from_column_names(
get_column_names, function_name="all", context=self
)
def col(self, *column_names: str) -> DepthTrackingExprT:
return self._expr.from_column_names(
passthrough_column_names(column_names), function_name="col", context=self
)
def exclude(self, excluded_names: Container[str]) -> DepthTrackingExprT:
return self._expr.from_column_names(
partial(exclude_column_names, names=excluded_names),
function_name="exclude",
context=self,
)
class LazyNamespace(
CompliantNamespace[CompliantLazyFrameT, LazyExprT],
Protocol[CompliantLazyFrameT, LazyExprT, NativeFrameT_co],
):
@property
def _backend_version(self) -> tuple[int, ...]:
return self._implementation._backend_version()
@property
def _lazyframe(self) -> type[CompliantLazyFrameT]: ...
def from_native(self, data: NativeFrameT_co | Any, /) -> CompliantLazyFrameT:
if self._lazyframe._is_native(data):
return self._lazyframe.from_native(data, context=self)
else: # pragma: no cover
msg = f"Unsupported type: {type(data).__name__!r}"
raise TypeError(msg)
class EagerNamespace(
DepthTrackingNamespace[EagerDataFrameT, EagerExprT],
Protocol[EagerDataFrameT, EagerSeriesT, EagerExprT, NativeFrameT, NativeSeriesT],
):
@property
def _backend_version(self) -> tuple[int, ...]:
return self._implementation._backend_version()
@property
def _dataframe(self) -> type[EagerDataFrameT]: ...
@property
def _series(self) -> type[EagerSeriesT]: ...
def when(
self, predicate: EagerExprT
) -> EagerWhen[EagerDataFrameT, EagerSeriesT, EagerExprT, NativeSeriesT]: ...
@overload
def from_native(self, data: NativeFrameT, /) -> EagerDataFrameT: ...
@overload
def from_native(self, data: NativeSeriesT, /) -> EagerSeriesT: ...
def from_native(
self, data: NativeFrameT | NativeSeriesT | Any, /
) -> EagerDataFrameT | EagerSeriesT:
if self._dataframe._is_native(data):
return self._dataframe.from_native(data, context=self)
elif self._series._is_native(data):
return self._series.from_native(data, context=self)
msg = f"Unsupported type: {type(data).__name__!r}"
raise TypeError(msg)
@overload
def from_numpy(self, data: Into1DArray, /, schema: None = ...) -> EagerSeriesT: ...
@overload
def from_numpy(
self,
data: _2DArray,
/,
schema: Mapping[str, DType] | Schema | Sequence[str] | None,
) -> EagerDataFrameT: ...
def from_numpy(
self,
data: Into1DArray | _2DArray,
/,
schema: Mapping[str, DType] | Schema | Sequence[str] | None = None,
) -> EagerDataFrameT | EagerSeriesT:
if is_numpy_array_2d(data):
return self._dataframe.from_numpy(data, schema=schema, context=self)
return self._series.from_numpy(data, context=self)
def _concat_diagonal(self, dfs: Sequence[NativeFrameT], /) -> NativeFrameT: ...
def _concat_horizontal(
self, dfs: Sequence[NativeFrameT | Any], /
) -> NativeFrameT: ...
def _concat_vertical(self, dfs: Sequence[NativeFrameT], /) -> NativeFrameT: ...
def concat(
self, items: Iterable[EagerDataFrameT], *, how: ConcatMethod
) -> EagerDataFrameT:
dfs = [item.native for item in items]
if how == "horizontal":
native = self._concat_horizontal(dfs)
elif how == "vertical":
native = self._concat_vertical(dfs)
elif how == "diagonal":
native = self._concat_diagonal(dfs)
else: # pragma: no cover
raise NotImplementedError
return self._dataframe.from_native(native, context=self)

View File

@@ -0,0 +1,320 @@
"""Almost entirely complete, generic `selectors` implementation."""
from __future__ import annotations
import re
from functools import partial
from typing import TYPE_CHECKING, Protocol, TypeVar, overload
from narwhals._compliant.expr import CompliantExpr
from narwhals._typing_compat import Protocol38
from narwhals._utils import (
_parse_time_unit_and_time_zone,
dtype_matches_time_unit_and_time_zone,
get_column_names,
is_compliant_dataframe,
)
if TYPE_CHECKING:
from collections.abc import Collection, Iterable, Iterator, Sequence
from datetime import timezone
from typing_extensions import Self, TypeAlias, TypeIs
from narwhals._compliant.expr import NativeExpr
from narwhals._compliant.typing import (
CompliantDataFrameAny,
CompliantExprAny,
CompliantFrameAny,
CompliantLazyFrameAny,
CompliantSeriesAny,
CompliantSeriesOrNativeExprAny,
EvalNames,
EvalSeries,
ScalarKwargs,
)
from narwhals._utils import Implementation, Version, _LimitedContext
from narwhals.dtypes import DType
from narwhals.typing import TimeUnit
__all__ = [
"CompliantSelector",
"CompliantSelectorNamespace",
"EagerSelectorNamespace",
"LazySelectorNamespace",
]
SeriesOrExprT = TypeVar("SeriesOrExprT", bound="CompliantSeriesOrNativeExprAny")
SeriesT = TypeVar("SeriesT", bound="CompliantSeriesAny")
ExprT = TypeVar("ExprT", bound="NativeExpr")
FrameT = TypeVar("FrameT", bound="CompliantFrameAny")
DataFrameT = TypeVar("DataFrameT", bound="CompliantDataFrameAny")
LazyFrameT = TypeVar("LazyFrameT", bound="CompliantLazyFrameAny")
SelectorOrExpr: TypeAlias = (
"CompliantSelector[FrameT, SeriesOrExprT] | CompliantExpr[FrameT, SeriesOrExprT]"
)
class CompliantSelectorNamespace(Protocol[FrameT, SeriesOrExprT]):
_implementation: Implementation
_version: Version
@classmethod
def from_namespace(cls, context: _LimitedContext, /) -> Self:
obj = cls.__new__(cls)
obj._implementation = context._implementation
obj._version = context._version
return obj
@property
def _selector(self) -> type[CompliantSelector[FrameT, SeriesOrExprT]]: ...
def _iter_columns(self, df: FrameT, /) -> Iterator[SeriesOrExprT]: ...
def _iter_schema(self, df: FrameT, /) -> Iterator[tuple[str, DType]]: ...
def _iter_columns_dtypes(
self, df: FrameT, /
) -> Iterator[tuple[SeriesOrExprT, DType]]: ...
def _iter_columns_names(self, df: FrameT, /) -> Iterator[tuple[SeriesOrExprT, str]]:
yield from zip(self._iter_columns(df), df.columns)
def _is_dtype(
self: CompliantSelectorNamespace[FrameT, SeriesOrExprT], dtype: type[DType], /
) -> CompliantSelector[FrameT, SeriesOrExprT]:
def series(df: FrameT) -> Sequence[SeriesOrExprT]:
return [
ser for ser, tp in self._iter_columns_dtypes(df) if isinstance(tp, dtype)
]
def names(df: FrameT) -> Sequence[str]:
return [name for name, tp in self._iter_schema(df) if isinstance(tp, dtype)]
return self._selector.from_callables(series, names, context=self)
def by_dtype(
self, dtypes: Collection[DType | type[DType]]
) -> CompliantSelector[FrameT, SeriesOrExprT]:
def series(df: FrameT) -> Sequence[SeriesOrExprT]:
return [ser for ser, tp in self._iter_columns_dtypes(df) if tp in dtypes]
def names(df: FrameT) -> Sequence[str]:
return [name for name, tp in self._iter_schema(df) if tp in dtypes]
return self._selector.from_callables(series, names, context=self)
def matches(self, pattern: str) -> CompliantSelector[FrameT, SeriesOrExprT]:
p = re.compile(pattern)
def series(df: FrameT) -> Sequence[SeriesOrExprT]:
if (
is_compliant_dataframe(df)
and not self._implementation.is_duckdb()
and not self._implementation.is_ibis()
):
return [df.get_column(col) for col in df.columns if p.search(col)]
return [ser for ser, name in self._iter_columns_names(df) if p.search(name)]
def names(df: FrameT) -> Sequence[str]:
return [col for col in df.columns if p.search(col)]
return self._selector.from_callables(series, names, context=self)
def numeric(self) -> CompliantSelector[FrameT, SeriesOrExprT]:
def series(df: FrameT) -> Sequence[SeriesOrExprT]:
return [ser for ser, tp in self._iter_columns_dtypes(df) if tp.is_numeric()]
def names(df: FrameT) -> Sequence[str]:
return [name for name, tp in self._iter_schema(df) if tp.is_numeric()]
return self._selector.from_callables(series, names, context=self)
def categorical(self) -> CompliantSelector[FrameT, SeriesOrExprT]:
return self._is_dtype(self._version.dtypes.Categorical)
def string(self) -> CompliantSelector[FrameT, SeriesOrExprT]:
return self._is_dtype(self._version.dtypes.String)
def boolean(self) -> CompliantSelector[FrameT, SeriesOrExprT]:
return self._is_dtype(self._version.dtypes.Boolean)
def all(self) -> CompliantSelector[FrameT, SeriesOrExprT]:
def series(df: FrameT) -> Sequence[SeriesOrExprT]:
return list(self._iter_columns(df))
return self._selector.from_callables(series, get_column_names, context=self)
def datetime(
self,
time_unit: TimeUnit | Iterable[TimeUnit] | None,
time_zone: str | timezone | Iterable[str | timezone | None] | None,
) -> CompliantSelector[FrameT, SeriesOrExprT]:
time_units, time_zones = _parse_time_unit_and_time_zone(time_unit, time_zone)
matches = partial(
dtype_matches_time_unit_and_time_zone,
dtypes=self._version.dtypes,
time_units=time_units,
time_zones=time_zones,
)
def series(df: FrameT) -> Sequence[SeriesOrExprT]:
return [ser for ser, tp in self._iter_columns_dtypes(df) if matches(tp)]
def names(df: FrameT) -> Sequence[str]:
return [name for name, tp in self._iter_schema(df) if matches(tp)]
return self._selector.from_callables(series, names, context=self)
class EagerSelectorNamespace(
CompliantSelectorNamespace[DataFrameT, SeriesT], Protocol[DataFrameT, SeriesT]
):
def _iter_schema(self, df: DataFrameT, /) -> Iterator[tuple[str, DType]]:
for ser in self._iter_columns(df):
yield ser.name, ser.dtype
def _iter_columns(self, df: DataFrameT, /) -> Iterator[SeriesT]:
yield from df.iter_columns()
def _iter_columns_dtypes(self, df: DataFrameT, /) -> Iterator[tuple[SeriesT, DType]]:
for ser in self._iter_columns(df):
yield ser, ser.dtype
class LazySelectorNamespace(
CompliantSelectorNamespace[LazyFrameT, ExprT], Protocol[LazyFrameT, ExprT]
):
def _iter_schema(self, df: LazyFrameT) -> Iterator[tuple[str, DType]]:
yield from df.schema.items()
def _iter_columns(self, df: LazyFrameT) -> Iterator[ExprT]:
yield from df._iter_columns()
def _iter_columns_dtypes(self, df: LazyFrameT, /) -> Iterator[tuple[ExprT, DType]]:
yield from zip(self._iter_columns(df), df.schema.values())
class CompliantSelector(
CompliantExpr[FrameT, SeriesOrExprT], Protocol38[FrameT, SeriesOrExprT]
):
_call: EvalSeries[FrameT, SeriesOrExprT]
_window_function: None
_function_name: str
_depth: int
_implementation: Implementation
_version: Version
_scalar_kwargs: ScalarKwargs
@classmethod
def from_callables(
cls,
call: EvalSeries[FrameT, SeriesOrExprT],
evaluate_output_names: EvalNames[FrameT],
*,
context: _LimitedContext,
) -> Self:
obj = cls.__new__(cls)
obj._call = call
obj._window_function = None
obj._depth = 0
obj._function_name = "selector"
obj._evaluate_output_names = evaluate_output_names
obj._alias_output_names = None
obj._implementation = context._implementation
obj._version = context._version
obj._scalar_kwargs = {}
return obj
@property
def selectors(self) -> CompliantSelectorNamespace[FrameT, SeriesOrExprT]:
return self.__narwhals_namespace__().selectors
def _to_expr(self) -> CompliantExpr[FrameT, SeriesOrExprT]: ...
def _is_selector(
self, other: Self | CompliantExpr[FrameT, SeriesOrExprT]
) -> TypeIs[CompliantSelector[FrameT, SeriesOrExprT]]:
return isinstance(other, type(self))
@overload
def __sub__(self, other: Self) -> Self: ...
@overload
def __sub__(
self, other: CompliantExpr[FrameT, SeriesOrExprT]
) -> CompliantExpr[FrameT, SeriesOrExprT]: ...
def __sub__(
self, other: SelectorOrExpr[FrameT, SeriesOrExprT]
) -> SelectorOrExpr[FrameT, SeriesOrExprT]:
if self._is_selector(other):
def series(df: FrameT) -> Sequence[SeriesOrExprT]:
lhs_names, rhs_names = _eval_lhs_rhs(df, self, other)
return [
x for x, name in zip(self(df), lhs_names) if name not in rhs_names
]
def names(df: FrameT) -> Sequence[str]:
lhs_names, rhs_names = _eval_lhs_rhs(df, self, other)
return [x for x in lhs_names if x not in rhs_names]
return self.selectors._selector.from_callables(series, names, context=self)
return self._to_expr() - other
@overload
def __or__(self, other: Self) -> Self: ...
@overload
def __or__(
self, other: CompliantExpr[FrameT, SeriesOrExprT]
) -> CompliantExpr[FrameT, SeriesOrExprT]: ...
def __or__(
self, other: SelectorOrExpr[FrameT, SeriesOrExprT]
) -> SelectorOrExpr[FrameT, SeriesOrExprT]:
if self._is_selector(other):
def series(df: FrameT) -> Sequence[SeriesOrExprT]:
lhs_names, rhs_names = _eval_lhs_rhs(df, self, other)
return [
*(x for x, name in zip(self(df), lhs_names) if name not in rhs_names),
*other(df),
]
def names(df: FrameT) -> Sequence[str]:
lhs_names, rhs_names = _eval_lhs_rhs(df, self, other)
return [*(x for x in lhs_names if x not in rhs_names), *rhs_names]
return self.selectors._selector.from_callables(series, names, context=self)
return self._to_expr() | other
@overload
def __and__(self, other: Self) -> Self: ...
@overload
def __and__(
self, other: CompliantExpr[FrameT, SeriesOrExprT]
) -> CompliantExpr[FrameT, SeriesOrExprT]: ...
def __and__(
self, other: SelectorOrExpr[FrameT, SeriesOrExprT]
) -> SelectorOrExpr[FrameT, SeriesOrExprT]:
if self._is_selector(other):
def series(df: FrameT) -> Sequence[SeriesOrExprT]:
lhs_names, rhs_names = _eval_lhs_rhs(df, self, other)
return [x for x, name in zip(self(df), lhs_names) if name in rhs_names]
def names(df: FrameT) -> Sequence[str]:
lhs_names, rhs_names = _eval_lhs_rhs(df, self, other)
return [x for x in lhs_names if x in rhs_names]
return self.selectors._selector.from_callables(series, names, context=self)
return self._to_expr() & other
def __invert__(self) -> CompliantSelector[FrameT, SeriesOrExprT]:
return self.selectors.all() - self
def _eval_lhs_rhs(
df: CompliantFrameAny, lhs: CompliantExprAny, rhs: CompliantExprAny
) -> tuple[Sequence[str], Sequence[str]]:
return lhs._evaluate_output_names(df), rhs._evaluate_output_names(df)

View File

@@ -0,0 +1,433 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Generic, Protocol
from narwhals._compliant.any_namespace import (
CatNamespace,
DateTimeNamespace,
ListNamespace,
StringNamespace,
StructNamespace,
)
from narwhals._compliant.typing import (
CompliantSeriesT_co,
EagerSeriesT_co,
NativeSeriesT,
NativeSeriesT_co,
)
from narwhals._translate import FromIterable, FromNative, NumpyConvertible, ToNarwhals
from narwhals._typing_compat import assert_never
from narwhals._utils import (
_StoresCompliant,
_StoresNative,
is_compliant_series,
is_sized_multi_index_selector,
unstable,
)
if TYPE_CHECKING:
from collections.abc import Iterable, Iterator, Mapping, Sequence
from types import ModuleType
import pandas as pd
import polars as pl
import pyarrow as pa
from typing_extensions import Self
from narwhals._compliant.dataframe import CompliantDataFrame
from narwhals._compliant.expr import CompliantExpr, EagerExpr
from narwhals._compliant.namespace import CompliantNamespace, EagerNamespace
from narwhals._utils import Implementation, Version, _LimitedContext
from narwhals.dtypes import DType
from narwhals.series import Series
from narwhals.typing import (
ClosedInterval,
FillNullStrategy,
Into1DArray,
IntoDType,
MultiIndexSelector,
NonNestedLiteral,
NumericLiteral,
RankMethod,
RollingInterpolationMethod,
SizedMultiIndexSelector,
TemporalLiteral,
_1DArray,
_SliceIndex,
)
__all__ = [
"CompliantSeries",
"EagerSeries",
"EagerSeriesCatNamespace",
"EagerSeriesDateTimeNamespace",
"EagerSeriesListNamespace",
"EagerSeriesNamespace",
"EagerSeriesStringNamespace",
"EagerSeriesStructNamespace",
]
class CompliantSeries(
NumpyConvertible["_1DArray", "Into1DArray"],
FromIterable,
FromNative[NativeSeriesT],
ToNarwhals["Series[NativeSeriesT]"],
Protocol[NativeSeriesT],
):
_implementation: Implementation
_version: Version
@property
def dtype(self) -> DType: ...
@property
def name(self) -> str: ...
@property
def native(self) -> NativeSeriesT: ...
def __narwhals_series__(self) -> Self:
return self
def __narwhals_namespace__(self) -> CompliantNamespace[Any, Any]: ...
def __native_namespace__(self) -> ModuleType: ...
def __array__(self, dtype: Any, *, copy: bool | None) -> _1DArray: ...
def __contains__(self, other: Any) -> bool: ...
def __getitem__(self, item: MultiIndexSelector[Self]) -> Any: ...
def __iter__(self) -> Iterator[Any]: ...
def __len__(self) -> int:
return len(self.native)
def _with_native(self, series: Any) -> Self: ...
def _with_version(self, version: Version) -> Self: ...
def _to_expr(self) -> CompliantExpr[Any, Self]: ...
@classmethod
def from_native(cls, data: NativeSeriesT, /, *, context: _LimitedContext) -> Self: ...
@classmethod
def from_numpy(cls, data: Into1DArray, /, *, context: _LimitedContext) -> Self: ...
@classmethod
def from_iterable(
cls,
data: Iterable[Any],
/,
*,
context: _LimitedContext,
name: str = "",
dtype: IntoDType | None = None,
) -> Self: ...
def to_narwhals(self) -> Series[NativeSeriesT]:
return self._version.series(self, level="full")
# Operators
def __add__(self, other: Any) -> Self: ...
def __and__(self, other: Any) -> Self: ...
def __eq__(self, other: object) -> Self: ... # type: ignore[override]
def __floordiv__(self, other: Any) -> Self: ...
def __ge__(self, other: Any) -> Self: ...
def __gt__(self, other: Any) -> Self: ...
def __invert__(self) -> Self: ...
def __le__(self, other: Any) -> Self: ...
def __lt__(self, other: Any) -> Self: ...
def __mod__(self, other: Any) -> Self: ...
def __mul__(self, other: Any) -> Self: ...
def __ne__(self, other: object) -> Self: ... # type: ignore[override]
def __or__(self, other: Any) -> Self: ...
def __pow__(self, other: Any) -> Self: ...
def __radd__(self, other: Any) -> Self: ...
def __rand__(self, other: Any) -> Self: ...
def __rfloordiv__(self, other: Any) -> Self: ...
def __rmod__(self, other: Any) -> Self: ...
def __rmul__(self, other: Any) -> Self: ...
def __ror__(self, other: Any) -> Self: ...
def __rpow__(self, other: Any) -> Self: ...
def __rsub__(self, other: Any) -> Self: ...
def __rtruediv__(self, other: Any) -> Self: ...
def __sub__(self, other: Any) -> Self: ...
def __truediv__(self, other: Any) -> Self: ...
def abs(self) -> Self: ...
def alias(self, name: str) -> Self: ...
def all(self) -> bool: ...
def any(self) -> bool: ...
def arg_max(self) -> int: ...
def arg_min(self) -> int: ...
def arg_true(self) -> Self: ...
def cast(self, dtype: IntoDType) -> Self: ...
def clip(
self,
lower_bound: Self | NumericLiteral | TemporalLiteral | None,
upper_bound: Self | NumericLiteral | TemporalLiteral | None,
) -> Self: ...
def count(self) -> int: ...
def cum_count(self, *, reverse: bool) -> Self: ...
def cum_max(self, *, reverse: bool) -> Self: ...
def cum_min(self, *, reverse: bool) -> Self: ...
def cum_prod(self, *, reverse: bool) -> Self: ...
def cum_sum(self, *, reverse: bool) -> Self: ...
def diff(self) -> Self: ...
def drop_nulls(self) -> Self: ...
def ewm_mean(
self,
*,
com: float | None,
span: float | None,
half_life: float | None,
alpha: float | None,
adjust: bool,
min_samples: int,
ignore_nulls: bool,
) -> Self: ...
def exp(self) -> Self: ...
def sqrt(self) -> Self: ...
def fill_null(
self,
value: Self | NonNestedLiteral,
strategy: FillNullStrategy | None,
limit: int | None,
) -> Self: ...
def filter(self, predicate: Any) -> Self: ...
def gather_every(self, n: int, offset: int) -> Self: ...
@unstable
def hist(
self,
bins: list[float | int] | None,
*,
bin_count: int | None,
include_breakpoint: bool,
) -> CompliantDataFrame[Self, Any, Any, Any]: ...
def head(self, n: int) -> Self: ...
def is_between(
self, lower_bound: Any, upper_bound: Any, closed: ClosedInterval
) -> Self: ...
def is_finite(self) -> Self: ...
def is_first_distinct(self) -> Self: ...
def is_in(self, other: Any) -> Self: ...
def is_last_distinct(self) -> Self: ...
def is_nan(self) -> Self: ...
def is_null(self) -> Self: ...
def is_sorted(self, *, descending: bool) -> bool: ...
def is_unique(self) -> Self: ...
def item(self, index: int | None) -> Any: ...
def kurtosis(self) -> float | None: ...
def len(self) -> int: ...
def log(self, base: float) -> Self: ...
def max(self) -> Any: ...
def mean(self) -> float: ...
def median(self) -> float: ...
def min(self) -> Any: ...
def mode(self) -> Self: ...
def n_unique(self) -> int: ...
def null_count(self) -> int: ...
def quantile(
self, quantile: float, interpolation: RollingInterpolationMethod
) -> float: ...
def rank(self, method: RankMethod, *, descending: bool) -> Self: ...
def replace_strict(
self,
old: Sequence[Any] | Mapping[Any, Any],
new: Sequence[Any],
*,
return_dtype: IntoDType | None,
) -> Self: ...
def rolling_mean(
self, window_size: int, *, min_samples: int, center: bool
) -> Self: ...
def rolling_std(
self, window_size: int, *, min_samples: int, center: bool, ddof: int
) -> Self: ...
def rolling_sum(
self, window_size: int, *, min_samples: int, center: bool
) -> Self: ...
def rolling_var(
self, window_size: int, *, min_samples: int, center: bool, ddof: int
) -> Self: ...
def round(self, decimals: int) -> Self: ...
def sample(
self,
n: int | None,
*,
fraction: float | None,
with_replacement: bool,
seed: int | None,
) -> Self: ...
def scatter(self, indices: int | Sequence[int], values: Any) -> Self: ...
def shift(self, n: int) -> Self: ...
def skew(self) -> float | None: ...
def sort(self, *, descending: bool, nulls_last: bool) -> Self: ...
def std(self, *, ddof: int) -> float: ...
def sum(self) -> float: ...
def tail(self, n: int) -> Self: ...
def to_arrow(self) -> pa.Array[Any]: ...
def to_dummies(
self, *, separator: str, drop_first: bool
) -> CompliantDataFrame[Self, Any, Any, Any]: ...
def to_frame(self) -> CompliantDataFrame[Self, Any, Any, Any]: ...
def to_list(self) -> list[Any]: ...
def to_pandas(self) -> pd.Series[Any]: ...
def to_polars(self) -> pl.Series: ...
def unique(self, *, maintain_order: bool) -> Self: ...
def value_counts(
self, *, sort: bool, parallel: bool, name: str | None, normalize: bool
) -> CompliantDataFrame[Self, Any, Any, Any]: ...
def var(self, *, ddof: int) -> float: ...
def zip_with(self, mask: Any, other: Any) -> Self: ...
@property
def str(self) -> Any: ...
@property
def dt(self) -> Any: ...
@property
def cat(self) -> Any: ...
@property
def list(self) -> Any: ...
@property
def struct(self) -> Any: ...
class EagerSeries(CompliantSeries[NativeSeriesT], Protocol[NativeSeriesT]):
_native_series: Any
_implementation: Implementation
_version: Version
_broadcast: bool
@property
def _backend_version(self) -> tuple[int, ...]:
return self._implementation._backend_version()
@classmethod
def _align_full_broadcast(cls, *series: Self) -> Sequence[Self]:
"""Ensure all of `series` have the same length (and index if `pandas`).
Scalars get broadcasted to the full length of the longest Series.
This is useful when you need to construct a full Series anyway, such as:
DataFrame.select(...)
It should not be used in binary operations, such as:
nw.col("a") - nw.col("a").mean()
because then it's more efficient to extract the right-hand-side's single element as a scalar.
"""
...
def _from_scalar(self, value: Any) -> Self:
return self.from_iterable([value], name=self.name, context=self)
def _with_native(
self, series: NativeSeriesT, *, preserve_broadcast: bool = False
) -> Self:
"""Return a new `CompliantSeries`, wrapping the native `series`.
In cases when operations are known to not affect whether a result should
be broadcast, we can pass `preserve_broadcast=True`.
Set this with care - it should only be set for unary expressions which don't
change length or order, such as `.alias` or `.fill_null`. If in doubt, don't
set it, you probably don't need it.
"""
...
def __narwhals_namespace__(
self,
) -> EagerNamespace[Any, Self, Any, Any, NativeSeriesT]: ...
def _to_expr(self) -> EagerExpr[Any, Any]:
return self.__narwhals_namespace__()._expr._from_series(self) # type: ignore[no-any-return]
def _gather(self, rows: SizedMultiIndexSelector[NativeSeriesT]) -> Self: ...
def _gather_slice(self, rows: _SliceIndex | range) -> Self: ...
def __getitem__(self, item: MultiIndexSelector[Self]) -> Self:
if isinstance(item, (slice, range)):
return self._gather_slice(item)
elif is_compliant_series(item):
return self._gather(item.native)
elif is_sized_multi_index_selector(item):
return self._gather(item)
else:
assert_never(item)
@property
def str(self) -> EagerSeriesStringNamespace[Self, NativeSeriesT]: ...
@property
def dt(self) -> EagerSeriesDateTimeNamespace[Self, NativeSeriesT]: ...
@property
def cat(self) -> EagerSeriesCatNamespace[Self, NativeSeriesT]: ...
@property
def list(self) -> EagerSeriesListNamespace[Self, NativeSeriesT]: ...
@property
def struct(self) -> EagerSeriesStructNamespace[Self, NativeSeriesT]: ...
class _SeriesNamespace( # type: ignore[misc]
_StoresCompliant[CompliantSeriesT_co],
_StoresNative[NativeSeriesT_co],
Protocol[CompliantSeriesT_co, NativeSeriesT_co],
):
_compliant_series: CompliantSeriesT_co
@property
def compliant(self) -> CompliantSeriesT_co:
return self._compliant_series
@property
def implementation(self) -> Implementation:
return self.compliant._implementation
@property
def backend_version(self) -> tuple[int, ...]:
return self.implementation._backend_version()
@property
def version(self) -> Version:
return self.compliant._version
@property
def native(self) -> NativeSeriesT_co:
return self._compliant_series.native # type: ignore[no-any-return]
def with_native(self, series: Any, /) -> CompliantSeriesT_co:
return self.compliant._with_native(series)
class EagerSeriesNamespace(
_SeriesNamespace[EagerSeriesT_co, NativeSeriesT_co],
Generic[EagerSeriesT_co, NativeSeriesT_co],
):
_compliant_series: EagerSeriesT_co
def __init__(self, series: EagerSeriesT_co, /) -> None:
self._compliant_series = series
class EagerSeriesCatNamespace( # type: ignore[misc]
_SeriesNamespace[EagerSeriesT_co, NativeSeriesT_co],
CatNamespace[EagerSeriesT_co],
Protocol[EagerSeriesT_co, NativeSeriesT_co],
): ...
class EagerSeriesDateTimeNamespace( # type: ignore[misc]
_SeriesNamespace[EagerSeriesT_co, NativeSeriesT_co],
DateTimeNamespace[EagerSeriesT_co],
Protocol[EagerSeriesT_co, NativeSeriesT_co],
): ...
class EagerSeriesListNamespace( # type: ignore[misc]
_SeriesNamespace[EagerSeriesT_co, NativeSeriesT_co],
ListNamespace[EagerSeriesT_co],
Protocol[EagerSeriesT_co, NativeSeriesT_co],
): ...
class EagerSeriesStringNamespace( # type: ignore[misc]
_SeriesNamespace[EagerSeriesT_co, NativeSeriesT_co],
StringNamespace[EagerSeriesT_co],
Protocol[EagerSeriesT_co, NativeSeriesT_co],
): ...
class EagerSeriesStructNamespace( # type: ignore[misc]
_SeriesNamespace[EagerSeriesT_co, NativeSeriesT_co],
StructNamespace[EagerSeriesT_co],
Protocol[EagerSeriesT_co, NativeSeriesT_co],
): ...

View File

@@ -0,0 +1,194 @@
from __future__ import annotations
from collections.abc import Sequence
from typing import TYPE_CHECKING, Any, Callable, Literal, TypedDict, TypeVar
if TYPE_CHECKING:
from typing_extensions import TypeAlias
from narwhals._compliant.dataframe import (
CompliantDataFrame,
CompliantLazyFrame,
EagerDataFrame,
)
from narwhals._compliant.expr import (
CompliantExpr,
DepthTrackingExpr,
EagerExpr,
LazyExpr,
NativeExpr,
)
from narwhals._compliant.namespace import CompliantNamespace, EagerNamespace
from narwhals._compliant.series import CompliantSeries, EagerSeries
from narwhals._compliant.window import WindowInputs
from narwhals.typing import (
FillNullStrategy,
NativeFrame,
NativeSeries,
RankMethod,
RollingInterpolationMethod,
)
class ScalarKwargs(TypedDict, total=False):
"""Non-expressifiable args which we may need to reuse in `agg` or `over`."""
adjust: bool
alpha: float | None
center: int
com: float | None
ddof: int
descending: bool
half_life: float | None
ignore_nulls: bool
interpolation: RollingInterpolationMethod
limit: int | None
method: RankMethod
min_samples: int
n: int
quantile: float
reverse: bool
span: float | None
strategy: FillNullStrategy | None
window_size: int
__all__ = [
"AliasName",
"AliasNames",
"CompliantDataFrameT",
"CompliantFrameT",
"CompliantLazyFrameT",
"CompliantSeriesT",
"EvalNames",
"EvalSeries",
"IntoCompliantExpr",
"NarwhalsAggregation",
"NativeFrameT_co",
"NativeSeriesT_co",
]
CompliantExprAny: TypeAlias = "CompliantExpr[Any, Any]"
CompliantSeriesAny: TypeAlias = "CompliantSeries[Any]"
CompliantSeriesOrNativeExprAny: TypeAlias = "CompliantSeriesAny | NativeExpr"
CompliantDataFrameAny: TypeAlias = "CompliantDataFrame[Any, Any, Any, Any]"
CompliantLazyFrameAny: TypeAlias = "CompliantLazyFrame[Any, Any, Any]"
CompliantFrameAny: TypeAlias = "CompliantDataFrameAny | CompliantLazyFrameAny"
CompliantNamespaceAny: TypeAlias = "CompliantNamespace[Any, Any]"
DepthTrackingExprAny: TypeAlias = "DepthTrackingExpr[Any, Any]"
EagerDataFrameAny: TypeAlias = "EagerDataFrame[Any, Any, Any, Any]"
EagerSeriesAny: TypeAlias = "EagerSeries[Any]"
EagerExprAny: TypeAlias = "EagerExpr[Any, Any]"
EagerNamespaceAny: TypeAlias = "EagerNamespace[EagerDataFrameAny, EagerSeriesAny, EagerExprAny, NativeFrame, NativeSeries]"
LazyExprAny: TypeAlias = "LazyExpr[Any, Any]"
NativeExprT = TypeVar("NativeExprT", bound="NativeExpr")
NativeExprT_co = TypeVar("NativeExprT_co", bound="NativeExpr", covariant=True)
NativeSeriesT = TypeVar("NativeSeriesT", bound="NativeSeries")
NativeSeriesT_co = TypeVar("NativeSeriesT_co", bound="NativeSeries", covariant=True)
NativeSeriesT_contra = TypeVar(
"NativeSeriesT_contra", bound="NativeSeries", contravariant=True
)
NativeFrameT = TypeVar("NativeFrameT", bound="NativeFrame")
NativeFrameT_co = TypeVar("NativeFrameT_co", bound="NativeFrame", covariant=True)
NativeFrameT_contra = TypeVar(
"NativeFrameT_contra", bound="NativeFrame", contravariant=True
)
CompliantExprT = TypeVar("CompliantExprT", bound=CompliantExprAny)
CompliantExprT_co = TypeVar("CompliantExprT_co", bound=CompliantExprAny, covariant=True)
CompliantExprT_contra = TypeVar(
"CompliantExprT_contra", bound=CompliantExprAny, contravariant=True
)
CompliantSeriesT = TypeVar("CompliantSeriesT", bound=CompliantSeriesAny)
CompliantSeriesT_co = TypeVar(
"CompliantSeriesT_co", bound=CompliantSeriesAny, covariant=True
)
CompliantSeriesOrNativeExprT = TypeVar(
"CompliantSeriesOrNativeExprT", bound=CompliantSeriesOrNativeExprAny
)
CompliantSeriesOrNativeExprT_co = TypeVar(
"CompliantSeriesOrNativeExprT_co",
bound=CompliantSeriesOrNativeExprAny,
covariant=True,
)
CompliantFrameT = TypeVar("CompliantFrameT", bound=CompliantFrameAny)
CompliantFrameT_co = TypeVar(
"CompliantFrameT_co", bound=CompliantFrameAny, covariant=True
)
CompliantDataFrameT = TypeVar("CompliantDataFrameT", bound=CompliantDataFrameAny)
CompliantDataFrameT_co = TypeVar(
"CompliantDataFrameT_co", bound=CompliantDataFrameAny, covariant=True
)
CompliantLazyFrameT = TypeVar("CompliantLazyFrameT", bound=CompliantLazyFrameAny)
CompliantLazyFrameT_co = TypeVar(
"CompliantLazyFrameT_co", bound=CompliantLazyFrameAny, covariant=True
)
CompliantNamespaceT = TypeVar("CompliantNamespaceT", bound=CompliantNamespaceAny)
CompliantNamespaceT_co = TypeVar(
"CompliantNamespaceT_co", bound=CompliantNamespaceAny, covariant=True
)
IntoCompliantExpr: TypeAlias = "CompliantExpr[CompliantFrameT, CompliantSeriesOrNativeExprT_co] | CompliantSeriesOrNativeExprT_co"
DepthTrackingExprT = TypeVar("DepthTrackingExprT", bound=DepthTrackingExprAny)
DepthTrackingExprT_contra = TypeVar(
"DepthTrackingExprT_contra", bound=DepthTrackingExprAny, contravariant=True
)
EagerExprT = TypeVar("EagerExprT", bound=EagerExprAny)
EagerExprT_contra = TypeVar("EagerExprT_contra", bound=EagerExprAny, contravariant=True)
EagerSeriesT = TypeVar("EagerSeriesT", bound=EagerSeriesAny)
EagerSeriesT_co = TypeVar("EagerSeriesT_co", bound=EagerSeriesAny, covariant=True)
# NOTE: `pyright` gives false (8) positives if this uses `EagerDataFrameAny`?
EagerDataFrameT = TypeVar("EagerDataFrameT", bound="EagerDataFrame[Any, Any, Any, Any]")
LazyExprT = TypeVar("LazyExprT", bound=LazyExprAny)
LazyExprT_contra = TypeVar("LazyExprT_contra", bound=LazyExprAny, contravariant=True)
AliasNames: TypeAlias = Callable[[Sequence[str]], Sequence[str]]
"""A function aliasing a *sequence* of column names."""
AliasName: TypeAlias = Callable[[str], str]
"""A function aliasing a *single* column name."""
EvalSeries: TypeAlias = Callable[
[CompliantFrameT], Sequence[CompliantSeriesOrNativeExprT]
]
"""A function from a `Frame` to a sequence of `Series`*.
See [underwater unicorn magic](https://narwhals-dev.github.io/narwhals/how_it_works/).
"""
EvalNames: TypeAlias = Callable[[CompliantFrameT], Sequence[str]]
"""A function from a `Frame` to a sequence of columns names *before* any aliasing takes place."""
WindowFunction: TypeAlias = (
"Callable[[CompliantFrameT, WindowInputs[NativeExprT]], Sequence[NativeExprT]]"
)
"""A function evaluated with `over(partition_by=..., order_by=...)`."""
NarwhalsAggregation: TypeAlias = Literal[
"sum",
"mean",
"median",
"max",
"min",
"std",
"var",
"len",
"n_unique",
"count",
"quantile",
]
"""`Expr` methods we aim to support in `DepthTrackingGroupBy`.
Be sure to update me if you're working on one of these:
- https://github.com/narwhals-dev/narwhals/issues/981
- https://github.com/narwhals-dev/narwhals/issues/2385
- https://github.com/narwhals-dev/narwhals/issues/2484
- https://github.com/narwhals-dev/narwhals/issues/2526
- https://github.com/narwhals-dev/narwhals/issues/2660
"""

View File

@@ -0,0 +1,231 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Callable, TypeVar, cast
from narwhals._compliant.expr import CompliantExpr
from narwhals._compliant.typing import (
CompliantExprAny,
CompliantFrameAny,
CompliantLazyFrameT,
CompliantSeriesOrNativeExprAny,
EagerDataFrameT,
EagerExprT,
EagerSeriesT,
LazyExprAny,
NativeExprT,
NativeSeriesT,
WindowFunction,
)
from narwhals._typing_compat import Protocol38
if TYPE_CHECKING:
from collections.abc import Sequence
from typing_extensions import Self, TypeAlias
from narwhals._compliant.typing import EvalSeries, ScalarKwargs
from narwhals._compliant.window import WindowInputs
from narwhals._utils import Implementation, Version, _LimitedContext
from narwhals.typing import NonNestedLiteral
__all__ = ["CompliantThen", "CompliantWhen", "EagerWhen", "LazyThen", "LazyWhen"]
ExprT = TypeVar("ExprT", bound=CompliantExprAny)
LazyExprT = TypeVar("LazyExprT", bound=LazyExprAny)
SeriesT = TypeVar("SeriesT", bound=CompliantSeriesOrNativeExprAny)
FrameT = TypeVar("FrameT", bound=CompliantFrameAny)
Scalar: TypeAlias = Any
"""A native literal value."""
IntoExpr: TypeAlias = "SeriesT | ExprT | NonNestedLiteral | Scalar"
"""Anything that is convertible into a `CompliantExpr`."""
class CompliantWhen(Protocol38[FrameT, SeriesT, ExprT]):
_condition: ExprT
_then_value: IntoExpr[SeriesT, ExprT]
_otherwise_value: IntoExpr[SeriesT, ExprT] | None
_implementation: Implementation
_version: Version
@property
def _then(self) -> type[CompliantThen[FrameT, SeriesT, ExprT]]: ...
def __call__(self, compliant_frame: FrameT, /) -> Sequence[SeriesT]: ...
def _window_function(
self, compliant_frame: FrameT, window_inputs: WindowInputs[Any]
) -> Sequence[SeriesT]: ...
def then(
self, value: IntoExpr[SeriesT, ExprT], /
) -> CompliantThen[FrameT, SeriesT, ExprT]:
return self._then.from_when(self, value)
@classmethod
def from_expr(cls, condition: ExprT, /, *, context: _LimitedContext) -> Self:
obj = cls.__new__(cls)
obj._condition = condition
obj._then_value = None
obj._otherwise_value = None
obj._implementation = context._implementation
obj._version = context._version
return obj
class CompliantThen(CompliantExpr[FrameT, SeriesT], Protocol38[FrameT, SeriesT, ExprT]):
_call: EvalSeries[FrameT, SeriesT]
_when_value: CompliantWhen[FrameT, SeriesT, ExprT]
_function_name: str
_depth: int
_implementation: Implementation
_version: Version
_scalar_kwargs: ScalarKwargs
@classmethod
def from_when(
cls,
when: CompliantWhen[FrameT, SeriesT, ExprT],
then: IntoExpr[SeriesT, ExprT],
/,
) -> Self:
when._then_value = then
obj = cls.__new__(cls)
obj._call = when
obj._when_value = when
obj._depth = 0
obj._function_name = "whenthen"
obj._evaluate_output_names = getattr(
then, "_evaluate_output_names", lambda _df: ["literal"]
)
obj._alias_output_names = getattr(then, "_alias_output_names", None)
obj._implementation = when._implementation
obj._version = when._version
obj._scalar_kwargs = {}
return obj
def otherwise(self, otherwise: IntoExpr[SeriesT, ExprT], /) -> ExprT:
self._when_value._otherwise_value = otherwise
self._function_name = "whenotherwise"
return cast("ExprT", self)
class LazyThen(
CompliantThen[CompliantLazyFrameT, NativeExprT, LazyExprT],
Protocol38[CompliantLazyFrameT, NativeExprT, LazyExprT],
):
_window_function: WindowFunction[CompliantLazyFrameT, NativeExprT] | None
@classmethod
def from_when(
cls,
when: CompliantWhen[CompliantLazyFrameT, NativeExprT, LazyExprT],
then: IntoExpr[NativeExprT, LazyExprT],
/,
) -> Self:
when._then_value = then
obj = cls.__new__(cls)
obj._call = when
obj._window_function = when._window_function
obj._when_value = when
obj._depth = 0
obj._function_name = "whenthen"
obj._evaluate_output_names = getattr(
then, "_evaluate_output_names", lambda _df: ["literal"]
)
obj._alias_output_names = getattr(then, "_alias_output_names", None)
obj._implementation = when._implementation
obj._version = when._version
obj._scalar_kwargs = {}
return obj
class EagerWhen(
CompliantWhen[EagerDataFrameT, EagerSeriesT, EagerExprT],
Protocol38[EagerDataFrameT, EagerSeriesT, EagerExprT, NativeSeriesT],
):
def _if_then_else(
self,
when: NativeSeriesT,
then: NativeSeriesT,
otherwise: NativeSeriesT | NonNestedLiteral | Scalar,
/,
) -> NativeSeriesT: ...
def __call__(self, df: EagerDataFrameT, /) -> Sequence[EagerSeriesT]:
is_expr = self._condition._is_expr
when: EagerSeriesT = self._condition(df)[0]
then: EagerSeriesT
align = when._align_full_broadcast
if is_expr(self._then_value):
then = self._then_value(df)[0]
else:
then = when.alias("literal")._from_scalar(self._then_value)
then._broadcast = True
if is_expr(self._otherwise_value):
otherwise = self._otherwise_value(df)[0]
when, then, otherwise = align(when, then, otherwise)
result = self._if_then_else(when.native, then.native, otherwise.native)
else:
when, then = align(when, then)
result = self._if_then_else(when.native, then.native, self._otherwise_value)
return [then._with_native(result)]
class LazyWhen(
CompliantWhen[CompliantLazyFrameT, NativeExprT, LazyExprT],
Protocol38[CompliantLazyFrameT, NativeExprT, LazyExprT],
):
when: Callable[..., NativeExprT]
lit: Callable[..., NativeExprT]
def __call__(self, df: CompliantLazyFrameT) -> Sequence[NativeExprT]:
is_expr = self._condition._is_expr
when = self.when
lit = self.lit
condition = df._evaluate_expr(self._condition)
then_ = self._then_value
then = df._evaluate_expr(then_) if is_expr(then_) else lit(then_)
other_ = self._otherwise_value
if other_ is None:
result = when(condition, then)
else:
otherwise = df._evaluate_expr(other_) if is_expr(other_) else lit(other_)
result = when(condition, then).otherwise(otherwise) # type: ignore # noqa: PGH003
return [result]
@classmethod
def from_expr(cls, condition: LazyExprT, /, *, context: _LimitedContext) -> Self:
obj = cls.__new__(cls)
obj._condition = condition
obj._then_value = None
obj._otherwise_value = None
obj._implementation = context._implementation
obj._version = context._version
return obj
def _window_function(
self, df: CompliantLazyFrameT, window_inputs: WindowInputs[NativeExprT]
) -> Sequence[NativeExprT]:
is_expr = self._condition._is_expr
condition = self._condition.window_function(df, window_inputs)[0]
then_ = self._then_value
then = (
then_.window_function(df, window_inputs)[0]
if is_expr(then_)
else self.lit(then_)
)
other_ = self._otherwise_value
if other_ is None:
result = self.when(condition, then)
else:
other = (
other_.window_function(df, window_inputs)[0]
if is_expr(other_)
else self.lit(other_)
)
result = self.when(condition, then).otherwise(other) # type: ignore # noqa: PGH003
return [result]

View File

@@ -0,0 +1,20 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Generic
from narwhals._compliant.typing import NativeExprT_co
if TYPE_CHECKING:
from collections.abc import Sequence
__all__ = ["WindowInputs"]
class WindowInputs(Generic[NativeExprT_co]):
__slots__ = ("order_by", "partition_by")
def __init__(
self, partition_by: Sequence[str | NativeExprT_co], order_by: Sequence[str]
) -> None:
self.partition_by = partition_by
self.order_by = order_by