MMM Base (mmm_base.py)

This module defines the BaseDelayedSaturatedMMM class, a specialized base class for Marketing Mix Models (MMMs) within ABACUS. It implements the core logic for delayed (geometric) adstock and logistic saturation effects, and it inherits from abacus.core.model.MMM.

BaseDelayedSaturatedMMM

class BaseDelayedSaturatedMMM(model.MMM):
    """Base class for Media Mix Models with delayed adstock and logistic saturation."""

This class provides the specific implementation for the build_model method, defining the PyMC model structure that includes:

  • Geometric adstock transformation (abacus.core.transformers.geometric_adstock).

  • Logistic saturation transformation (abacus.core.transformers.logistic_saturation).

  • Optional control variables.

  • Optional Fourier modes for yearly seasonality.

  • Optional integration of lift test data via abacus.core.lift_test.add_lift_measurements_to_likelihood_from_saturation.

It inherits the overall structure, fitting, prediction, saving/loading, and mixin integration from abacus.core.model.MMM and abacus.core.base.BaseMMM.

Initialization

    def __init__(
        self,
        date_column: str,
        channel_columns: List[str],
        adstock_max_lag: int,
        model_config: Optional[Dict] = None,
        sampler_config: Optional[Dict] = None,
        validate_data: bool = True,
        control_columns: Optional[List[str]] = None,
        yearly_seasonality: Optional[int] = None,
        df_lift_test: Optional[pd.DataFrame] = None,
        **kwargs,
    ) -> None:

Initialises the BaseDelayedSaturatedMMM model.

Parameters:

  • date_column (str): Column name of the date variable.

  • channel_columns (List[str]): Column names of the media channel variables.

  • adstock_max_lag (int): Number of lags to consider in the adstock transformation.

  • model_config (Optional[Dict], optional): Dictionary of parameters that initialise model configuration (priors, likelihood). Uses default_model_config if None. Defaults to None.

  • sampler_config (Optional[Dict], optional): Dictionary of parameters that initialise sampler configuration. Uses default_sampler_config if None. Defaults to None.

  • validate_data (bool, optional): Whether to validate the data before fitting the model. Defaults to True.

  • control_columns (Optional[List[str]], optional): Column names of control variables. Defaults to None.

  • yearly_seasonality (Optional[int], optional): Number of Fourier modes to model yearly seasonality. Defaults to None.

  • df_lift_test (Optional[pd.DataFrame], optional): DataFrame containing lift test data for calibration. Defaults to None.

  • **kwargs: Additional keyword arguments passed to the parent class (model.MMM) __init__.

Attributes Initialised:

  • control_columns, adstock_max_lag, yearly_seasonality, date_column, validate_data, df_lift_test: Store the corresponding initialization parameters.

  • Inherited attributes from model.MMM, BaseMMM, and ModelBuilder.

Properties

default_sampler_config

    @property
    def default_sampler_config(self) -> Dict:

Returns an empty dictionary, indicating no specific sampler defaults are set at this level.

output_var

    @property
    def output_var(self) -> str:

Returns "y", the default name for the target variable in the model.

default_model_config

    @property
    def default_model_config(self) -> Dict:

Returns the default configuration for the model’s priors and likelihood. Defines default distributions and parameters for:

  • intercept (LogNormal).

  • beta_channel (HalfNormal): Channel effectiveness coefficients.

  • alpha (Beta): Adstock retention rate.

  • lam (Gamma): Saturation rate (lambda).

  • likelihood (Normal with HalfNormal sigma): Observation model.

  • gamma_control (Normal): Control variable coefficients.

  • gamma_fourier (Laplace): Fourier mode coefficients.
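For illustration, a configuration of this kind might look like the dictionary below. The key layout (`dist`, `kwargs`) and every numeric value are assumptions made for this sketch, not values taken from ABACUS; inspect default_model_config on an instance for the actual defaults.

```python
# Hypothetical layout: each entry names a distribution and its parameters.
# All numeric values are placeholders, not the ABACUS defaults.
example_model_config = {
    "intercept": {"dist": "LogNormal", "kwargs": {"mu": 0.0, "sigma": 1.0}},
    "beta_channel": {"dist": "HalfNormal", "kwargs": {"sigma": 2.0}},
    "alpha": {"dist": "Beta", "kwargs": {"alpha": 1.0, "beta": 3.0}},
    "lam": {"dist": "Gamma", "kwargs": {"alpha": 3.0, "beta": 1.0}},
    "likelihood": {
        "dist": "Normal",
        # The observation noise carries its own nested prior.
        "kwargs": {"sigma": {"dist": "HalfNormal", "kwargs": {"sigma": 2.0}}},
    },
    "gamma_control": {"dist": "Normal", "kwargs": {"mu": 0.0, "sigma": 2.0}},
    "gamma_fourier": {"dist": "Laplace", "kwargs": {"mu": 0.0, "b": 1.0}},
}
```

A dictionary shaped like the real default can be passed as model_config to override individual priors.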

_serializable_model_config

    @property
    def _serializable_model_config(self) -> Dict:

Returns a JSON-serializable version of the model_config by converting NumPy arrays to lists.
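The conversion can be sketched as a small recursive function. The name to_serializable and the example configuration are hypothetical; the actual property may handle additional types.

```python
import json

import numpy as np


def to_serializable(config):
    """Recursively replace NumPy arrays with plain Python lists."""
    if isinstance(config, dict):
        return {key: to_serializable(value) for key, value in config.items()}
    if isinstance(config, np.ndarray):
        return config.tolist()
    return config


# Hypothetical config holding a per-channel prior parameter as an array.
config = {
    "beta_channel": {"dist": "HalfNormal", "kwargs": {"sigma": np.array([1.0, 2.0])}}
}
serializable = to_serializable(config)
encoded = json.dumps(serializable)  # json.dumps(config) would raise TypeError
```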

Core Methods

build_model

    def build_model(self, X: pd.DataFrame, y: Union[pd.Series, np.ndarray], **kwargs: Any) -> None:

Defines the PyMC model structure within a pm.Model context. This is the core implementation for this specific MMM type.

  1. Sets up mutable data containers for channel, target, control (if any), and Fourier (if any) data using preprocessed data stored in self.preprocessed_data.

  2. Defines priors for intercept, beta_channel, alpha (adstock), lam (saturation), gamma_control, and gamma_fourier based on model_config.

  3. Applies geometric adstock (tm.geometric_adstock) to channel data.

  4. Applies logistic saturation (tm.logistic_saturation) to the adstocked channel data.

  5. Calculates channel contributions (channel_adstock_saturated * beta_channel).

  6. Calculates the model mean (mu) by summing intercept, channel contributions, control contributions (if any), and Fourier contributions (if any).

  7. Defines the likelihood distribution using _create_likelihood_distribution based on model_config["likelihood"], linking mu to the observed target data.

  8. Optionally adds lift test constraints to the likelihood using add_lift_measurements_to_likelihood_from_saturation if self.df_lift_test was provided during initialization.
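Steps 3 to 6 can be sketched in plain NumPy with fixed parameter values in place of priors. This assumes the commonly used forms of both transformations; the functions below are local stand-ins, not the abacus.core.transformers implementations, which may differ (for example, by normalizing the adstock weights). Controls, Fourier terms, the likelihood, and lift tests are omitted.

```python
import numpy as np


def geometric_adstock(x, alpha, l_max):
    """Carry-over effect: out[t] = sum over i < l_max of alpha**i * x[t - i]."""
    weights = alpha ** np.arange(l_max)
    # Zero-pad so that no spend is assumed before the first observation.
    padded = np.concatenate([np.zeros(l_max - 1), x])
    return np.convolve(padded, weights, mode="valid")


def logistic_saturation(x, lam):
    """Diminishing returns: maps non-negative input into [0, 1)."""
    return (1 - np.exp(-lam * x)) / (1 + np.exp(-lam * x))


rng = np.random.default_rng(0)
n_dates, n_channels = 6, 2
channel_data = rng.uniform(0, 1, size=(n_dates, n_channels))

# Fixed illustrative values; in the model these are random variables.
intercept = 0.3
alpha = np.array([0.5, 0.2])
lam = np.array([2.0, 4.0])
beta_channel = np.array([1.0, 0.5])

# Step 3: adstock each channel column with its own retention rate.
adstocked = np.column_stack(
    [geometric_adstock(channel_data[:, c], alpha[c], l_max=3) for c in range(n_channels)]
)
# Step 4: saturate the adstocked series.
saturated = logistic_saturation(adstocked, lam)
# Step 5: scale by the channel coefficients.
channel_contributions = saturated * beta_channel
# Step 6: model mean, here without control or Fourier terms.
mu = intercept + channel_contributions.sum(axis=1)
```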

_generate_and_preprocess_model_data

    def _generate_and_preprocess_model_data(self, X: pd.DataFrame, y: Union[pd.Series, np.ndarray]) -> None:

Prepares the data for build_model.

  1. Separates date, channel, and control columns from the input X.

  2. Generates Fourier features using _get_fourier_models_data if yearly_seasonality is set.

  3. Concatenates all relevant features (date, channels, controls, Fourier) into X_data.

  4. Sets model coordinates (self.model_coords, self.coords_mutable).

  5. Optionally calls validation methods (inherited from mixins via model.MMM).

  6. Calls preprocessing methods (inherited from mixins via model.MMM) on X_data and y.

  7. Stores the final preprocessed features and target in self.preprocessed_data.

  8. Stores the original (or potentially partially processed) X_data and y in self.X and self.y.
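A rough pandas sketch of steps 1, 3, and 4 follows. The column names, the toy data, and the coordinate keys ("channel", "control", "date") are assumptions for illustration; validation, preprocessing, and Fourier features are left out.

```python
import pandas as pd

date_column = "date"
channel_columns = ["tv", "search"]
control_columns = ["promo"]

# Toy input; "unused" stands for any column the model does not reference.
X = pd.DataFrame(
    {
        "date": pd.date_range("2024-01-01", periods=4, freq="W-MON"),
        "tv": [10.0, 0.0, 5.0, 8.0],
        "search": [3.0, 4.0, 2.0, 6.0],
        "promo": [0, 1, 0, 0],
        "unused": [9, 9, 9, 9],
    }
)

# Steps 1 and 3: keep only the columns the model uses, in a fixed order.
X_data = pd.concat(
    [X[[date_column]], X[channel_columns], X[control_columns]], axis=1
)

# Step 4: coordinates label the model dimensions (key names are hypothetical).
model_coords = {"channel": channel_columns, "control": control_columns}
coords_mutable = {"date": X_data[date_column]}
```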

_data_setter

    def _data_setter(self, X: pd.DataFrame, y: Optional[Union[np.ndarray, pd.Series]] = None) -> None:

Updates the mutable data containers within an existing PyMC model (self.model) with new data provided in X and y. It applies the necessary preprocessing steps (using fitted transformers stored on the instance, typically by mixins) to the new data before setting it with pm.set_data. Handles channel, control, Fourier, and target data updates.

channel_contributions_forward_pass

    def channel_contributions_forward_pass(self, channel_data: np.ndarray) -> np.ndarray:

Calculates the estimated channel contributions based on the fitted model parameters (posterior means or samples). It applies the adstock and saturation transformations using the fitted alpha and lam parameters and multiplies by the fitted beta_channel coefficients.

Note: This base implementation returns contributions in the scaled target space. The overriding method in mmm_model.DelayedSaturatedMMM handles the inverse transformation back to the original target scale.
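To make the distinction between scaled and original target space concrete, the sketch below assumes the target was scaled by its maximum absolute value. That scaler and the numbers are assumptions for illustration; the actual inverse transformation depends on the target transformer fitted by the model.

```python
import numpy as np

y = np.array([120.0, 200.0, 160.0])
scale = np.abs(y).max()  # assumed max-abs scaling of the target
y_scaled = y / scale

# Stand-in for the base forward pass output: shape (date, channel), scaled space.
contributions_scaled = np.array([[0.10, 0.05], [0.30, 0.10], [0.20, 0.05]])

# Undo the scaling to express contributions in the units of y.
contributions_original = contributions_scaled * scale
```

Multiplying by a single factor is only valid for a purely multiplicative scaler; a transformer that also shifts the target would need a different treatment.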

Helper Methods

_save_input_params

    def _save_input_params(self, idata: az.InferenceData) -> None:

Saves key initialization parameters (date_column, control_columns, channel_columns, adstock_max_lag, validate_data, yearly_seasonality) to the attrs of the InferenceData object during the save process.

_create_likelihood_distribution

    def _create_likelihood_distribution(...) -> pm.Distribution:

Internal helper to construct the likelihood part of the PyMC model based on the configuration provided in model_config["likelihood"]. Handles nested parameter distributions (e.g., for sigma).

_get_fourier_models_data

    def _get_fourier_models_data(self, X: pd.DataFrame) -> pd.DataFrame:

Internal helper to generate Fourier sine and cosine features based on the date column in X and the specified yearly_seasonality order. Uses abacus.core.utils.generate_fourier_modes.
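A minimal sketch of yearly Fourier features is shown below. The function name, the column naming scheme, and the use of day-of-year divided by 365.25 are assumptions; generate_fourier_modes may differ in these details.

```python
import numpy as np
import pandas as pd


def fourier_features(dates: pd.Series, n_order: int) -> pd.DataFrame:
    """Sine and cosine features with a period of one year."""
    # Position within the year as a fraction of roughly 0 to 1.
    periods = dates.dt.dayofyear.to_numpy() / 365.25
    features = {}
    for k in range(1, n_order + 1):
        features[f"sin_order_{k}"] = np.sin(2 * np.pi * k * periods)
        features[f"cos_order_{k}"] = np.cos(2 * np.pi * k * periods)
    return pd.DataFrame(features, index=dates.index)


dates = pd.Series(pd.date_range("2024-01-01", periods=5, freq="W-MON"))
features = fourier_features(dates, n_order=2)
```

With yearly_seasonality set to n, a scheme like this yields 2n feature columns, one sine and one cosine per order.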

_model_config_formatting (classmethod)

    @classmethod
    def _model_config_formatting(cls, model_config: Dict) -> Dict:

Internal helper used by load to correctly format the model configuration dictionary after loading it from JSON, converting lists back to NumPy arrays or tuples where appropriate.
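The reverse conversion might look like the following sketch. The function name and the rule that a key called "dims" becomes a tuple are hypothetical; the source only states that lists are turned back into arrays or tuples where appropriate.

```python
import json

import numpy as np


def format_loaded_config(config):
    """Recursively restore arrays and tuples in a config decoded from JSON."""
    formatted = {}
    for key, value in config.items():
        if isinstance(value, dict):
            formatted[key] = format_loaded_config(value)
        elif isinstance(value, list):
            # Hypothetical rule: dimension names become tuples, other lists arrays.
            formatted[key] = tuple(value) if key == "dims" else np.array(value)
        else:
            formatted[key] = value
    return formatted


raw = json.loads(
    '{"beta_channel": {"dist": "HalfNormal", "dims": ["channel"],'
    ' "kwargs": {"sigma": [1.0, 2.0]}}}'
)
restored_config = format_loaded_config(raw)
```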

load (classmethod)

    @classmethod
    def load(cls, fname: str) -> "BaseDelayedSaturatedMMM":

Overrides the load method from ModelBuilder so that the attributes saved by _save_input_params are passed back to the constructor when the class is reinstantiated. It loads the InferenceData, extracts the configurations, reinstantiates the class (cls), rebuilds the model, and sets the idata.
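The round trip between _save_input_params and load can be illustrated with a stand-in object. JSON-encoding each value is an assumption of this sketch, as are the example values; SimpleNamespace merely imitates the attrs dictionary of an InferenceData object.

```python
import json
from types import SimpleNamespace

# Stand-in for az.InferenceData; only the attrs dictionary matters here.
idata = SimpleNamespace(attrs={})

init_params = {
    "date_column": "date",
    "channel_columns": ["tv", "search"],
    "control_columns": None,
    "adstock_max_lag": 4,
    "validate_data": True,
    "yearly_seasonality": 2,
}

# Saving: store each parameter as a string (JSON encoding is assumed here).
for name, value in init_params.items():
    idata.attrs[name] = json.dumps(value)

# Loading: decode the attrs and reuse them as constructor keyword arguments.
restored = {name: json.loads(idata.attrs[name]) for name in init_params}
```

In a real load, a dictionary like restored would be unpacked into cls(...) before the model is rebuilt and the idata attached.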

Note: Methods related to fitting (fit), prediction (predict, predict_proba, predict_posterior, sample_posterior_predictive, sample_prior_predictive), saving (save), parameter access (get_params, set_params), and various plotting/analysis functions are inherited from model.MMM, BaseMMM, ModelBuilder, and the various mixin classes.