# Source code for pipeline_builder.functions

"""
Functions interface protocol for the framework.

This module provides a protocol-based interface for PySpark functions that can
be injected into framework components. This design allows for better testability,
flexibility, and engine abstraction.

**Key Features:**
    - **Protocol-Based**: Uses Python Protocol for type safety and duck typing
    - **Engine Abstraction**: Works with both real PySpark and mock implementations
    - **Injection Support**: Functions can be injected via engine configuration
    - **Type Safety**: Provides type hints for all function signatures

**Usage:**
    The functions protocol is typically accessed through the compat module:

    >>> from pipeline_builder.compat import F
    >>> df.select(F.col("id"), F.lit("test"))

    Or via the get_default_functions() helper:

    >>> from pipeline_builder.functions import get_default_functions
    >>> F = get_default_functions()
    >>> df.select(F.col("id"))

**Supported Functions:**
    The protocol defines common PySpark functions including:
    - Column operations: col, expr, lit, when
    - Aggregations: count, countDistinct, sum, max, min, avg
    - String functions: length
    - Date functions: date_trunc, dayofweek, current_timestamp

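Dependencies:
    - compat: Compatibility layer for engine detection; supplies the injected `F`
    - protocols: Provides `ColumnProtocol`, the return type of every function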

Example:
    >>> from pipeline_builder.functions import FunctionsProtocol, get_default_functions
    >>> from pipeline_builder.compat import F
    >>>
    >>> # Get functions from compat module
    >>> functions = get_default_functions()
    >>> # Use functions for DataFrame operations
    >>> df.select(functions.col("id"), functions.lit("value"))
"""

from __future__ import annotations

from typing import Optional, Protocol, Union, cast

from .protocols import ColumnProtocol


class FunctionsProtocol(Protocol):
    """Structural interface for a PySpark-style ``functions`` namespace.

    Any object (class instance or module) that exposes the methods below
    with compatible signatures satisfies this protocol. Every method returns
    a ``ColumnProtocol``; the bodies here are placeholders (``...``) because
    the protocol only declares the shape of the interface.

    **Implementation Requirements:**
        An implementation must provide all of the methods defined here with
        matching signatures. The protocol is meant to cover both the real
        PySpark functions and mock implementations used for testing.

    **Common Implementations:**
        - PySpark `pyspark.sql.functions` module
        - Mock functions for testing (see test utilities)
        - Custom function wrappers for specific engines

    Example:
        >>> from pipeline_builder.functions import FunctionsProtocol
        >>> from pipeline_builder.compat import F
        >>>
        >>> # F implements FunctionsProtocol
        >>> def use_functions(f: FunctionsProtocol):
        ...     return f.col("id")
        >>>
        >>> result = use_functions(F)
    """

    def col(self, col_name: str) -> ColumnProtocol:
        """Create a column reference.

        Args:
            col_name: Name of the column to reference.

        Returns:
            Column expression representing the column reference.

        Example:
            >>> F.col("user_id")
        """
        ...

    def expr(self, expr: str) -> ColumnProtocol:
        """Create a column expression from a string.

        Args:
            expr: Expression text to turn into a column expression.

        Returns:
            Column expression built from the given string.
        """
        ...

    # NOTE: the literal type is spelled as one flat ``typing.Union`` (no
    # ``|`` operator) so the annotation can still be evaluated, e.g. by
    # ``typing.get_type_hints()``, on Python versions older than 3.10.
    def lit(self, value: Union[str, int, float, bool, None]) -> ColumnProtocol:
        """Create a literal column.

        Args:
            value: Literal value to wrap: a string, int, float, bool or
                ``None``.

        Returns:
            Column expression representing the literal value.

        Example:
            >>> F.lit("test")
        """
        ...

    def when(
        self,
        condition: ColumnProtocol,
        value: Union[str, int, float, bool, None],
    ) -> ColumnProtocol:
        """Create a conditional expression.

        Args:
            condition: Column expression used as the condition.
            value: Literal value associated with the condition: a string,
                int, float, bool or ``None``.

        Returns:
            Column expression representing the conditional.
        """
        ...

    def count(self, col: Union[str, ColumnProtocol] = "*") -> ColumnProtocol:
        """Create a count aggregation.

        Args:
            col: Column name or column expression to count. Defaults to
                ``"*"``.

        Returns:
            Column expression representing the count.

        Example:
            >>> F.count("*")
        """
        ...

    def countDistinct(self, *cols: Union[str, ColumnProtocol]) -> ColumnProtocol:
        """Create a count distinct aggregation.

        Args:
            *cols: Column names or column expressions to count distinct
                values over.

        Returns:
            Column expression representing the distinct count.
        """
        ...

    def sum(self, col: Union[str, ColumnProtocol]) -> ColumnProtocol:
        """Create a sum aggregation.

        Args:
            col: Column name or column expression to sum.

        Returns:
            Column expression representing the sum.
        """
        ...

    def max(self, col: Union[str, ColumnProtocol]) -> ColumnProtocol:
        """Create a max aggregation.

        Args:
            col: Column name or column expression to take the maximum of.

        Returns:
            Column expression representing the maximum.
        """
        ...

    def min(self, col: Union[str, ColumnProtocol]) -> ColumnProtocol:
        """Create a min aggregation.

        Args:
            col: Column name or column expression to take the minimum of.

        Returns:
            Column expression representing the minimum.
        """
        ...

    def avg(self, col: Union[str, ColumnProtocol]) -> ColumnProtocol:
        """Create an average aggregation.

        Args:
            col: Column name or column expression to average.

        Returns:
            Column expression representing the average.
        """
        ...

    def length(self, col: Union[str, ColumnProtocol]) -> ColumnProtocol:
        """Create a length function.

        Args:
            col: Column name or column expression to measure.

        Returns:
            Column expression representing the length.
        """
        ...

    # ``format`` shadows the builtin, but it is part of the public signature
    # (callers may pass it by keyword), so the name is kept as-is.
    def date_trunc(
        self, format: str, col: Union[str, ColumnProtocol]
    ) -> ColumnProtocol:
        """Create a date truncation function.

        Args:
            format: Format string naming the truncation level; the accepted
                values are defined by the underlying engine.
            col: Column name or column expression to truncate.

        Returns:
            Column expression representing the truncated value.
        """
        ...

    def dayofweek(self, col: Union[str, ColumnProtocol]) -> ColumnProtocol:
        """Create a day of week function.

        Args:
            col: Column name or column expression to extract the day of
                week from.

        Returns:
            Column expression representing the day of week.
        """
        ...

    def current_timestamp(self) -> ColumnProtocol:
        """Create a current timestamp function.

        Returns:
            Column expression representing the current timestamp.
        """
        ...


def get_default_functions() -> FunctionsProtocol:
    """Get the injected functions implementation.

    Returns the functions module (F) from the configured engine. This is
    the same as accessing `F` directly from the compat module, but provides
    a typed interface for dependency injection.

    Returns:
        FunctionsProtocol instance from the configured engine. This is
        typically the PySpark functions module or a mock equivalent.

    Example:
        >>> from pipeline_builder.functions import get_default_functions
        >>> F = get_default_functions()
        >>> # Use F for DataFrame operations
        >>> df.select(F.col("id"), F.count("*"))
    """
    # Function-scope import: ``F`` is looked up on each call rather than at
    # module import time.
    # NOTE(review): presumably this avoids an import cycle with ``compat``
    # and picks up the currently configured engine -- confirm against compat.
    from .compat import F

    # cast() is a no-op at runtime; it only tells the type checker to treat
    # the returned object as a FunctionsProtocol.
    return cast(FunctionsProtocol, F)