MMM Model (mmm_model.py)

This module defines the DelayedSaturatedMMM class, the primary concrete implementation of the Marketing Mix Model within the ABACUS library.

DelayedSaturatedMMM

class DelayedSaturatedMMM(
    # Order: Validation, Preprocessing, Base Model, Functionality Mixins
    vd.ValidateControlColumns, # Validation first
    pr.MaxAbsScaleTarget,      # Target scaling
    pr.MaxAbsScaleChannels,    # Channel scaling
    BaseDelayedSaturatedMMM,   # Core model logic
    MMMAnalysisMixin,          # Plotting and grid analysis
    MMMPredictMixin,           # Prediction and scenarios
    MMMCalibrateMixin,         # Lift test calibration
):
    """Media Mix Model with delayed adstock and logistic saturation (see [1]_).

    Combines geometric adstock, logistic saturation, optional control variables,
    and optional Fourier modes for seasonality. Includes data validation and
    preprocessing mixins (MaxAbs scaling for target and channels). Provides
    methods for analysis, plotting, prediction, and calibration.

    Inherits core logic from `BaseDelayedSaturatedMMM` and functionality from
    various mixin classes.

    Parameters:
        date_column (str): Column name of the date variable.
        channel_columns (List[str]): Column names of the media channel variables.
        adstock_max_lag (int): Maximum lag for the adstock transformation.
        model_config (Optional[Dict], optional): Configuration for priors and likelihood.
            Uses defaults if None. Defaults to None.
        sampler_config (Optional[Dict], optional): Configuration for the sampler.
            Uses defaults if None. Defaults to None.
        validate_data (bool, optional): Whether to validate input data. Defaults to True.
        control_columns (Optional[List[str]], optional): Column names for control variables.
            Defaults to None.
        yearly_seasonality (Optional[int], optional): Number of Fourier modes used to model
            yearly seasonality. Defaults to None.
        **kwargs: Additional keyword arguments passed to `BaseDelayedSaturatedMMM`.

    Notes:
        - Target variable and media channels are scaled using MaxAbsScaler by default.
        - Control variables are validated but not scaled: this class includes no scaling
          mixin for them.
        - Allows calibration via custom priors (`model_config`) and lift tests
          (`add_lift_test_measurements`).

    References:
        .. [1] Jin, Yuxue, et al. “Bayesian methods for media mix modeling with carryover
               and shape effects.” (2017).
        .. [2] Orduz, J. `"Media Effect Estimation with PyMC: Adstock, Saturation & Diminishing Returns"
               <https://juanitorduz.github.io/pymc_mmm/>`_.

    Examples:
        >>> import numpy as np
        >>> import pandas as pd
        >>> from abacus.core.mmm_model import DelayedSaturatedMMM
        >>> # Load data...
        >>> data = pd.read_csv("your_data.csv", parse_dates=["date_column_name"])
        >>> X = data.drop("target_column_name", axis=1)
        >>> y = data["target_column_name"]
        >>> mmm = DelayedSaturatedMMM(
        ...     date_column="date_column_name",
        ...     channel_columns=["channel1", "channel2"],
        ...     control_columns=["control1"],
        ...     adstock_max_lag=4,
        ...     yearly_seasonality=2
        ... )
        >>> idata = mmm.fit(X, y, draws=1000, tune=1000) # Pass sampler args to fit
        >>> # Use methods from mixins:
        >>> fig = mmm.plot_channel_contributions_grid(start=0, stop=2, num=11)
        >>> pred_contrib = mmm.new_spend_contributions(spend=np.array([100, 200]))
    """

This class is the standard MMM implementation in ABACUS and models delayed (adstock) and saturating media effects. It extends BaseDelayedSaturatedMMM with mixins for data validation, scaling, analysis, prediction, and calibration.
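The docstring describes yearly_seasonality as a number of Fourier modes. The following minimal sketch shows what such features commonly look like. It assumes one sine/cosine pair per mode over a 365.25-day period, with day of year as input. fourier_features is a hypothetical helper, not part of the ABACUS API, and the exact way ABACUS builds its seasonality columns is an assumption.

```python
import numpy as np


def fourier_features(day_of_year: np.ndarray, n_order: int, period: float = 365.25) -> np.ndarray:
    """Hypothetical helper: sin/cos pairs for modes 1..n_order.

    Returns an array of shape (dates, 2 * n_order): all sine columns first,
    then all cosine columns.
    """
    t = np.asarray(day_of_year, dtype=float)[:, None]  # (dates, 1)
    orders = np.arange(1, n_order + 1)[None, :]  # (1, n_order)
    angles = 2 * np.pi * orders * t / period  # (dates, n_order)
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)


day_of_year = np.array([1, 91, 182, 274])
features = fourier_features(day_of_year, n_order=2)
print(features.shape)  # (4, 4)
```

With yearly_seasonality=2, as in the docstring example, this convention yields four seasonal regressors. A model would then learn one coefficient for each regressor.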

Inheritance Structure:

  1. Validation (abacus.prepro.valid):

    • ValidateControlColumns: Ensures that the specified control columns exist in the input data.

  2. Preprocessing (abacus.prepro.prepro):

    • MaxAbsScaleTarget: Scales the target variable using MaxAbsScaler.

    • MaxAbsScaleChannels: Scales the media channel variables using MaxAbsScaler.

  3. Base Model (abacus.core.mmm_base):

    • BaseDelayedSaturatedMMM: Provides the core PyMC model definition (build_model) incorporating delayed adstock and logistic saturation, along with fundamental fitting and prediction logic inherited from ModelBuilder.

  4. Functionality Mixins (abacus.core.mixins):

    • MMMAnalysisMixin: Adds methods for analysing and plotting model results, such as contribution grids (plot_channel_contributions_grid).

    • MMMPredictMixin: Adds methods for generating predictions under different scenarios (new_spend_contributions).

    • MMMCalibrateMixin: Adds methods for incorporating lift test data for model calibration (add_lift_test_measurements).
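The list above follows the order of the class statement. Python resolves attribute lookups along the method resolution order (MRO), which is computed by C3 linearization. The sketch below uses empty stub classes as hypothetical stand-ins for the real ABACUS classes. The BaseMMM and ModelBuilder ancestry is assumed from the chain described under Initialization.

```python
# Empty stand-ins for the real ABACUS classes (hypothetical: only the names and
# inheritance relationships matter for computing the MRO).
class ModelBuilder: ...
class BaseMMM(ModelBuilder): ...
class BaseDelayedSaturatedMMM(BaseMMM): ...

class ValidateControlColumns: ...
class MaxAbsScaleTarget: ...
class MaxAbsScaleChannels: ...
class MMMAnalysisMixin: ...
class MMMPredictMixin: ...
class MMMCalibrateMixin: ...


# Same base-class order as the DelayedSaturatedMMM class statement.
class DelayedSaturatedMMM(
    ValidateControlColumns,
    MaxAbsScaleTarget,
    MaxAbsScaleChannels,
    BaseDelayedSaturatedMMM,
    MMMAnalysisMixin,
    MMMPredictMixin,
    MMMCalibrateMixin,
):
    pass


mro_names = [cls.__name__ for cls in DelayedSaturatedMMM.__mro__]
print(" -> ".join(mro_names))
```

With this hierarchy, the lookup order is the validation and scaling mixins, then BaseDelayedSaturatedMMM, BaseMMM and ModelBuilder, and only then the three functionality mixins. The consequence is that a method defined in both a scaling mixin and the base chain resolves to the mixin. A method defined in both the base chain and a functionality mixin resolves to the base chain.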

Initialization

The __init__ method is inherited from BaseDelayedSaturatedMMM (and ultimately BaseMMM and ModelBuilder). Refer to the documentation for BaseDelayedSaturatedMMM and BaseMMM for details on initialization parameters like date_column, channel_columns, control_columns, adstock_max_lag, yearly_seasonality, model_config, sampler_config, etc.

Overridden Methods

channel_contributions_forward_pass

    def channel_contributions_forward_pass(self, channel_data: np.ndarray) -> np.ndarray:

This method overrides the implementation in BaseDelayedSaturatedMMM. It estimates each channel's contribution by applying the adstock and saturation transformations defined in the base model to channel_data. It then inverse-transforms the result with the target scaler provided by the MaxAbsScaleTarget mixin, so the returned contributions are on the original scale of the target variable.

Parameters:

  • channel_data (np.ndarray): Input channel data, potentially preprocessed. Shape should be compatible with model coordinates.

Returns:

  • np.ndarray: Estimated channel contributions in the original target scale. Shape: (chains, draws, dates, channels).
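The following minimal NumPy sketch illustrates the computation described above. It assumes forms common in PyMC-based MMMs: unnormalised geometric adstock and the logistic saturation (1 - exp(-lam*x)) / (1 + exp(-lam*x)). The exact parameterisation ABACUS uses (for example, whether adstock weights are normalised) is an assumption here. All function names are hypothetical, and the posterior arrays are synthetic stand-ins rather than draws from a fitted model.

```python
import numpy as np


def geometric_adstock(x: np.ndarray, alpha: np.ndarray, l_max: int) -> np.ndarray:
    """Unnormalised geometric adstock: out[t] = sum_{l < l_max} alpha**l * x[t - l].

    x: (dates, channels); alpha: (chains, draws, channels).
    Returns shape (chains, draws, dates, channels).
    """
    x = np.asarray(x, dtype=float)
    decay = alpha[..., None, :]  # (chains, draws, 1, channels)
    out = np.zeros(alpha.shape[:-1] + x.shape)
    for lag in range(min(l_max, x.shape[0])):
        shifted = np.zeros_like(x)
        shifted[lag:] = x[: x.shape[0] - lag]  # delay by `lag` dates, zero-padded
        out += decay**lag * shifted
    return out


def logistic_saturation(x: np.ndarray, lam) -> np.ndarray:
    """Maps x >= 0 into [0, 1); larger lam saturates faster."""
    return (1 - np.exp(-lam * x)) / (1 + np.exp(-lam * x))


def channel_contributions_sketch(channel_data_scaled, alpha, lam, beta, l_max, target_max_abs):
    """Hypothetical forward pass returning (chains, draws, dates, channels)."""
    adstocked = geometric_adstock(channel_data_scaled, alpha, l_max)
    saturated = logistic_saturation(adstocked, lam[..., None, :])
    contrib_scaled = beta[..., None, :] * saturated  # on the MaxAbs-scaled target
    # Inverse MaxAbs transform of the target: multiply by max |y| seen in training.
    return contrib_scaled * target_max_abs


rng = np.random.default_rng(0)
n_dates, n_channels, chains, draws = 8, 2, 2, 5
spend = rng.uniform(0, 100, size=(n_dates, n_channels))
y = rng.uniform(50, 500, size=n_dates)

# MaxAbs scaling, as MaxAbsScaleChannels / MaxAbsScaleTarget are described to do.
spend_scaled = spend / np.abs(spend).max(axis=0)
target_max_abs = np.abs(y).max()

# Synthetic stand-ins for posterior draws of the model parameters.
alpha = rng.uniform(0.1, 0.9, size=(chains, draws, n_channels))
lam = rng.uniform(0.5, 3.0, size=(chains, draws, n_channels))
beta = rng.uniform(0.1, 1.0, size=(chains, draws, n_channels))

contrib = channel_contributions_sketch(
    spend_scaled, alpha, lam, beta, l_max=4, target_max_abs=target_max_abs
)
print(contrib.shape)  # (2, 5, 8, 2)
```

Multiplying by the target's maximum absolute value is exactly what inverting a MaxAbs scaler does. The returned array therefore shares the target's units while keeping the (chains, draws, dates, channels) layout.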

Note: Most other functionalities (fitting, prediction, plotting, optimization, diagnostics) are provided by the inherited base classes and mixins. Refer to their respective documentation (BaseDelayedSaturatedMMM, BaseMMM, ModelBuilder, and the various mixin classes) for details.