MMM Base (mmm_base.py)¶
This module defines the BaseDelayedSaturatedMMM class, which serves as a more specialized base class for Marketing Mix Models within ABACUS, specifically implementing the core logic for delayed adstock and logistic saturation effects. It inherits from abacus.core.model.MMM.
BaseDelayedSaturatedMMM¶
class BaseDelayedSaturatedMMM(model.MMM):
"""Base class for Media Mix Models with delayed adstock and logistic saturation."""
This class provides the specific implementation for the build_model method, defining the PyMC model structure that includes:
- Geometric adstock transformation (abacus.core.transformers.geometric_adstock).
- Logistic saturation transformation (abacus.core.transformers.logistic_saturation).
- Optional control variables.
- Optional Fourier modes for yearly seasonality.
- Optional integration of lift test data via abacus.core.lift_test.add_lift_measurements_to_likelihood_from_saturation.
It inherits the overall structure, fitting, prediction, saving/loading, and mixin integration from abacus.core.model.MMM and abacus.core.base.BaseMMM.
Initialization¶
def __init__(
self,
date_column: str,
channel_columns: List[str],
adstock_max_lag: int,
model_config: Optional[Dict] = None,
sampler_config: Optional[Dict] = None,
validate_data: bool = True,
control_columns: Optional[List[str]] = None,
yearly_seasonality: Optional[int] = None,
df_lift_test: Optional[pd.DataFrame] = None,
**kwargs,
) -> None:
Initialises the BaseDelayedSaturatedMMM model.
Parameters:
- date_column (str): Column name of the date variable.
- channel_columns (List[str]): Column names of the media channel variables.
- adstock_max_lag (int): Number of lags to consider in the adstock transformation.
- model_config (Optional[Dict], optional): Dictionary of parameters that initialise model configuration (priors, likelihood). Uses default_model_config if None. Defaults to None.
- sampler_config (Optional[Dict], optional): Dictionary of parameters that initialise sampler configuration. Uses default_sampler_config if None. Defaults to None.
- validate_data (bool, optional): Whether to validate the data before fitting the model. Defaults to True.
- control_columns (Optional[List[str]], optional): Column names of control variables. Defaults to None.
- yearly_seasonality (Optional[int], optional): Number of Fourier modes to model yearly seasonality. Defaults to None.
- df_lift_test (Optional[pd.DataFrame], optional): DataFrame containing lift test data for calibration. Defaults to None.
- **kwargs: Additional keyword arguments passed to the parent class (model.MMM) __init__.
Attributes Initialised:
- control_columns, adstock_max_lag, yearly_seasonality, date_column, validate_data, df_lift_test: Store the corresponding initialization parameters.
- Inherited attributes from model.MMM, BaseMMM, and ModelBuilder.
Properties¶
default_sampler_config¶
@property
def default_sampler_config(self) -> Dict:
Returns an empty dictionary, indicating no specific sampler defaults are set at this level.
output_var¶
@property
def output_var(self) -> str:
Returns "y", the default name for the target variable in the model.
default_model_config¶
@property
def default_model_config(self) -> Dict:
Returns the default configuration for the model’s priors and likelihood. Defines default distributions and parameters for:
- intercept (LogNormal)
- beta_channel (HalfNormal) - Channel effectiveness coefficients
- alpha (Beta) - Adstock retention rate
- lam (Gamma) - Saturation rate (lambda)
- likelihood (Normal with HalfNormal sigma) - Observation model
- gamma_control (Normal) - Control variable coefficients
- gamma_fourier (Laplace) - Fourier mode coefficients
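As an illustration only, a configuration of this kind could be laid out as a nested dictionary. The "dist"/"kwargs" key layout and every numeric value below are assumptions chosen for the sketch, not the library's actual defaults; only the parameter names and distribution families come from the list above.

```python
# Hypothetical layout: the "dist"/"kwargs" keys and all numbers are placeholders.
example_model_config = {
    "intercept": {"dist": "LogNormal", "kwargs": {"mu": 0.0, "sigma": 1.0}},
    "beta_channel": {"dist": "HalfNormal", "kwargs": {"sigma": 1.0}},
    "alpha": {"dist": "Beta", "kwargs": {"alpha": 1.0, "beta": 3.0}},
    "lam": {"dist": "Gamma", "kwargs": {"alpha": 3.0, "beta": 1.0}},
    "likelihood": {
        "dist": "Normal",
        # sigma is itself described by a nested distribution specification
        "kwargs": {"sigma": {"dist": "HalfNormal", "kwargs": {"sigma": 1.0}}},
    },
    "gamma_control": {"dist": "Normal", "kwargs": {"mu": 0.0, "sigma": 1.0}},
    "gamma_fourier": {"dist": "Laplace", "kwargs": {"mu": 0.0, "b": 1.0}},
}
```

A dictionary shaped like this could be passed as model_config to override selected priors, provided it matches the structure the class actually expects.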
_serializable_model_config¶
@property
def _serializable_model_config(self) -> Dict:
Returns a JSON-serializable version of the model_config by converting NumPy arrays to lists.
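The conversion can be sketched as a recursive walk over the configuration. This is a minimal illustration, not the class's actual code: the helper name to_serializable is hypothetical, and array.array from the standard library stands in for a NumPy array because both expose a tolist() method.

```python
import json
from array import array


def to_serializable(value):
    """Recursively replace array-like values (anything with .tolist()) by lists."""
    if isinstance(value, dict):
        return {key: to_serializable(item) for key, item in value.items()}
    if isinstance(value, (list, tuple)):
        return [to_serializable(item) for item in value]
    if hasattr(value, "tolist"):
        # numpy.ndarray exposes tolist(); array.array stands in for it here.
        return value.tolist()
    return value


# Hypothetical config fragment holding one array-valued prior parameter.
config = {"beta_channel": {"kwargs": {"sigma": array("d", [1.0, 2.5])}}}
serializable = to_serializable(config)
print(json.dumps(serializable))
```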
Core Methods¶
build_model¶
def build_model(self, X: pd.DataFrame, y: Union[pd.Series, np.ndarray], **kwargs: Any) -> None:
Defines the PyMC model structure within a pm.Model context. This is the core implementation for this specific MMM type.
- Sets up mutable data containers for channel, target, control (if any), and Fourier (if any) data using preprocessed data stored in self.preprocessed_data.
- Defines priors for intercept, beta_channel, alpha (adstock), lam (saturation), gamma_control, and gamma_fourier based on model_config.
- Applies geometric adstock (tm.geometric_adstock) to channel data.
- Applies logistic saturation (tm.logistic_saturation) to the adstocked channel data.
- Calculates channel contributions (channel_adstock_saturated * beta_channel).
- Calculates the model mean (mu) by summing intercept, channel contributions, control contributions (if any), and Fourier contributions (if any).
- Defines the likelihood distribution using _create_likelihood_distribution based on model_config["likelihood"], linking mu to the observed target data.
- Optionally adds lift test constraints to the likelihood using add_lift_measurements_to_likelihood_from_saturation if self.df_lift_test was provided during initialization.
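The deterministic part of these steps can be sketched in plain Python for a single channel with fixed parameter values. This sketch assumes the commonly used forms of the two transformations (an unnormalised geometric adstock and the logistic curve (1 - exp(-lam x)) / (1 + exp(-lam x))); the actual functions in abacus.core.transformers may differ in normalisation or argument names, and the function names here are illustrative. Controls, Fourier terms, priors, and the likelihood are omitted.

```python
import math


def geometric_adstock(x, alpha, l_max):
    """Carry over spend: out[t] = sum over lag < l_max of alpha**lag * x[t - lag]."""
    return [
        sum(alpha**lag * x[t - lag] for lag in range(l_max) if t - lag >= 0)
        for t in range(len(x))
    ]


def logistic_saturation(x, lam):
    """Diminishing returns: (1 - exp(-lam * v)) / (1 + exp(-lam * v))."""
    return [(1 - math.exp(-lam * v)) / (1 + math.exp(-lam * v)) for v in x]


def model_mean(channel, intercept, beta, alpha, lam, l_max):
    """mu[t] = intercept + beta * saturation(adstock(channel))[t]."""
    saturated = logistic_saturation(geometric_adstock(channel, alpha, l_max), lam)
    return [intercept + beta * s for s in saturated]


# A single spend pulse decays over the following periods.
spend = [1.0, 0.0, 0.0, 0.0]
adstocked = geometric_adstock(spend, alpha=0.5, l_max=3)
mu = model_mean(spend, intercept=0.2, beta=1.5, alpha=0.5, lam=2.0, l_max=3)
```

With l_max=3 the pulse stops contributing after three periods, so the last value of mu falls back to the intercept.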
_generate_and_preprocess_model_data¶
def _generate_and_preprocess_model_data(self, X: pd.DataFrame, y: Union[pd.Series, np.ndarray]) -> None:
Prepares the data for build_model.
- Separates date, channel, and control columns from the input X.
- Generates Fourier features using _get_fourier_models_data if yearly_seasonality is set.
- Concatenates all relevant features (date, channels, controls, Fourier) into X_data.
- Sets model coordinates (self.model_coords, self.coords_mutable).
- Optionally calls validation methods (inherited from mixins via model.MMM).
- Calls preprocessing methods (inherited from mixins via model.MMM) on X_data and y.
- Stores the final preprocessed features and target in self.preprocessed_data.
- Stores the original (or potentially partially processed X_data) and y in self.X and self.y.
_data_setter¶
def _data_setter(self, X: pd.DataFrame, y: Optional[Union[np.ndarray, pd.Series]] = None) -> None:
Updates the mutable data containers within an existing PyMC model (self.model) with new data provided in X and y. It applies the necessary preprocessing steps (using fitted transformers stored on the instance, typically by mixins) to the new data before setting it with pm.set_data. Handles channel, control, Fourier, and target data updates.
channel_contributions_forward_pass¶
def channel_contributions_forward_pass(self, channel_data: np.ndarray) -> np.ndarray:
Calculates the estimated channel contributions based on the fitted model parameters (posterior means or samples). It applies the adstock and saturation transformations using the fitted alpha and lam parameters and multiplies by the fitted beta_channel coefficients. Note: This base implementation returns contributions in the scaled target space. The overriding method in mmm_model.DelayedSaturatedMMM handles the inverse transformation back to the original target scale.
Helper Methods¶
_save_input_params¶
def _save_input_params(self, idata: az.InferenceData) -> None:
Saves key initialization parameters (date_column, control_columns, channel_columns, adstock_max_lag, validate_data, yearly_seasonality) to the attrs of the InferenceData object during the save process.
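One way to picture this step is as a round trip through string-valued attributes. The sketch below assumes each parameter is encoded with json.dumps, which is an assumption rather than something this page states; a plain dictionary stands in for the InferenceData attrs, and the column names are invented.

```python
import json

# Hypothetical stand-in for the `attrs` mapping of an InferenceData object.
attrs = {}

# Example values only; the column names are made up for this sketch.
input_params = {
    "date_column": "date",
    "control_columns": None,
    "channel_columns": ["tv", "search"],
    "adstock_max_lag": 8,
    "validate_data": True,
    "yearly_seasonality": 2,
}

# Encode each value as a JSON string so None, booleans, and lists survive storage.
for name, value in input_params.items():
    attrs[name] = json.dumps(value)

# Decoding restores the original Python values, as a loader would need to do.
restored = {name: json.loads(text) for name, text in attrs.items()}
```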
_create_likelihood_distribution¶
def _create_likelihood_distribution(...) -> pm.Distribution:
Internal helper to construct the likelihood part of the PyMC model based on the configuration provided in model_config["likelihood"]. Handles nested parameter distributions (e.g., for sigma).
_get_fourier_models_data¶
def _get_fourier_models_data(self, X: pd.DataFrame) -> pd.DataFrame:
Internal helper to generate Fourier sine and cosine features based on the date column in X and the specified yearly_seasonality order. Uses abacus.core.utils.generate_fourier_modes.
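The idea behind such features can be sketched with the standard library alone. This is an illustration under stated assumptions, not the behaviour of generate_fourier_modes: the period of 365.25 days, the use of day-of-year as the phase, and the column names sin_order_k / cos_order_k are all choices made for this example.

```python
import math
from datetime import date

DAYS_IN_YEAR = 365.25  # assumed period for yearly seasonality


def fourier_features(dates, n_order):
    """Return one dict per date with sin/cos terms for orders 1..n_order."""
    rows = []
    for d in dates:
        # Position within the year, expressed as a fraction of one period.
        phase = d.timetuple().tm_yday / DAYS_IN_YEAR
        row = {}
        for k in range(1, n_order + 1):
            row[f"sin_order_{k}"] = math.sin(2 * math.pi * k * phase)
            row[f"cos_order_{k}"] = math.cos(2 * math.pi * k * phase)
        rows.append(row)
    return rows


features = fourier_features([date(2023, 1, 1), date(2023, 7, 2)], n_order=2)
```

Each order contributes two columns, so yearly_seasonality=2 would yield four seasonal features per row in this sketch.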
_model_config_formatting (classmethod)¶
@classmethod
def _model_config_formatting(cls, model_config: Dict) -> Dict:
Internal helper used by load to correctly format the model configuration dictionary after loading it from JSON, converting lists back to NumPy arrays or tuples where appropriate.
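A minimal sketch of this kind of post-load formatting follows. It assumes that tuple-valued entries live under a key named "dims" and that every other list should become an array; both are assumptions made for illustration. To stay dependency-free, the array constructor is a parameter (array_factory, a hypothetical name) that defaults to list; passing numpy.asarray would produce NumPy arrays instead.

```python
import json


def format_model_config(config, array_factory=list):
    """Restore types lost in JSON: "dims" lists -> tuples, other lists -> arrays."""
    formatted = {}
    for key, value in config.items():
        if isinstance(value, dict):
            # Recurse into nested distribution specifications.
            formatted[key] = format_model_config(value, array_factory)
        elif isinstance(value, list) and key == "dims":
            formatted[key] = tuple(value)
        elif isinstance(value, list):
            formatted[key] = array_factory(value)
        else:
            formatted[key] = value
    return formatted


# JSON has no tuple or array type, so both arrive as plain lists.
raw = json.loads('{"beta_channel": {"dims": ["channel"], "kwargs": {"sigma": [1.0, 2.5]}}}')
formatted = format_model_config(raw)
```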
load (classmethod)¶
@classmethod
def load(cls, fname: str) -> "BaseDelayedSaturatedMMM":
Overrides the load method from ModelBuilder to handle the specific attributes saved by _save_input_params for this class during instantiation. It loads the InferenceData, extracts configurations, reinstantiates the class (cls), rebuilds the model, and sets the idata.
Note: Methods related to fitting (fit), prediction (predict, predict_proba, predict_posterior, sample_posterior_predictive, sample_prior_predictive), saving (save), parameter access (get_params, set_params), and various plotting/analysis functions are inherited from model.MMM, BaseMMM, ModelBuilder, and the various mixin classes.