Abacus Core Mixins: MMM Analysis (mmm_analysis.py)
This module provides the MMMAnalysisMixin class, designed to be inherited by Marketing Mix Model (MMM) classes in Abacus. It offers methods for analysing the relationship between channel spend and contribution, and for visualising the expected impact of hypothetical spend scenarios over time.
MMMAnalysisMixin Class
This mixin assumes the inheriting class provides several attributes and methods, including:
X: The original input DataFrame.
channel_columns: List of channel names.
date_column: Name of the date column in X.
preprocessed_data: A dictionary containing the preprocessed data used for fitting (e.g., preprocessed_data['X']).
target_transformer: The scaler used for the target variable (if applicable).
channel_contributions_forward_pass(channel_data): A method (likely from another mixin or the base class) that calculates channel contributions for the given channel input data, returning results in the original target scale.
new_spend_contributions(spend, one_time, ...): A method (likely from MMMPredictMixin) that calculates the contribution trajectory resulting from a new spend scenario.
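The host-class contract above can be written down as a typing.Protocol, which makes the mixin's assumptions checkable at runtime. This is a sketch for illustration only: the protocol name MMMAnalysisHost and the ToyHost class are hypothetical, the loose Any types are assumptions, and the new_spend_contributions signature beyond spend and one_time is a guess.

```python
from typing import Any, Dict, List, Protocol, runtime_checkable


@runtime_checkable
class MMMAnalysisHost(Protocol):
    """Hypothetical protocol for what MMMAnalysisMixin expects from its host
    class. Types are left loose (Any) where the docs do not pin them down."""

    X: Any                              # original input DataFrame
    channel_columns: List[str]          # channel names
    date_column: str                    # name of the date column in X
    preprocessed_data: Dict[str, Any]   # e.g. preprocessed_data["X"]
    target_transformer: Any             # target scaler, if applicable

    def channel_contributions_forward_pass(self, channel_data: Any) -> Any:
        ...

    # Assumed signature: only `spend` and `one_time` are named in the docs.
    def new_spend_contributions(
        self, spend: Any, one_time: bool = True, **kwargs: Any
    ) -> Any:
        ...


class ToyHost:
    """Minimal hypothetical host used only to exercise the protocol check."""

    def __init__(self) -> None:
        self.X = None
        self.channel_columns = ["tv", "search"]
        self.date_column = "date"
        self.preprocessed_data = {"X": None}
        self.target_transformer = None

    def channel_contributions_forward_pass(self, channel_data: Any) -> Any:
        return channel_data

    def new_spend_contributions(
        self, spend: Any, one_time: bool = True, **kwargs: Any
    ) -> Any:
        return spend


# isinstance only checks that the members exist, not their types or behaviour.
print(isinstance(ToyHost(), MMMAnalysisHost))  # True
```

Note that a runtime_checkable protocol only verifies that the attributes and methods are present; it does not validate their types or signatures.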
Methods
get_channel_contributions_forward_pass_grid
def get_channel_contributions_forward_pass_grid(
self, start: float, stop: float, num: int
) -> DataArray:
Generates channel contributions across a grid of relative spend multipliers applied to the preprocessed channel data.
This method simulates varying levels of spend by multiplying the preprocessed channel input data by factors ranging from start to stop. For each factor (delta), it calculates the resulting channel contributions using the model’s channel_contributions_forward_pass method. The results are returned in the original target variable’s scale.
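A minimal sketch of the grid computation, assuming plain NumPy arrays in place of the model's xarray objects. toy_forward_pass and contributions_grid are hypothetical names; the toy forward pass (1 - exp(-x)) is only a stand-in for the model's fitted channel_contributions_forward_pass and has no chain/draw dimensions, so the result has shape (delta, date, channel) rather than (delta, chain, draw, date, channel).

```python
import numpy as np


def toy_forward_pass(channel_data: np.ndarray) -> np.ndarray:
    """Hypothetical stand-in for channel_contributions_forward_pass.

    Applies a simple saturating response 1 - exp(-x) element-wise; the real
    model applies its own fitted transformations and returns posterior draws.
    """
    return 1.0 - np.exp(-channel_data)


def contributions_grid(channel_data, start, stop, num, forward_pass=toy_forward_pass):
    """Scale the inputs by each multiplier (delta) and run the forward pass,
    stacking the results along a new leading 'delta' axis."""
    if start < 0:
        raise ValueError("start must be >= 0")
    deltas = np.linspace(start, stop, num)
    grid = np.stack([forward_pass(delta * channel_data) for delta in deltas])
    return deltas, grid


# Toy preprocessed channel data with shape (date=4, channel=2).
channel_data = np.array([[1.0, 0.5], [0.0, 0.5], [2.0, 0.0], [1.0, 1.0]])
deltas, grid = contributions_grid(channel_data, start=0.0, stop=2.0, num=5)
print(grid.shape)  # (5, 4, 2)
```

With start=0.0 the first slice of the grid is all zeros, which gives the response curve its anchor at zero spend.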
Parameters:
start (float): The starting multiplier for the spend grid (e.g., 0.0 for zero spend). Must be >= 0.
stop (float): The ending multiplier for the spend grid (e.g., 2.0 for double spend). Must be > start.
num (int): The number of points (spend levels) to calculate within the grid.
Returns:
xr.DataArray: An xarray DataArray containing the channel contributions, with dimensions (delta, chain, draw, date, channel), where delta corresponds to the spend multipliers in the grid. Values are in the original target scale.
Raises:
ValueError: If start < 0.
RuntimeError: If the model hasn't been fitted, preprocessed_data is missing, or the required channel_contributions_forward_pass method is not found.
plot_channel_contributions_grid
def plot_channel_contributions_grid(
self,
start: float,
stop: float,
num: int,
absolute_xrange: bool = False,
**plt_kwargs: Any,
) -> plt.Figure:
Plots the total channel contribution (summed over time) against varying spend levels, based on the grid generated by get_channel_contributions_forward_pass_grid.
This visualisation helps you understand the response curve of each channel. It shows the mean contribution and a 95% Highest Density Interval (HDI) across the posterior samples. The x-axis shows either the relative spend multiplier (delta) or the absolute total spend.
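The summary behind this plot (total contribution over time, then its mean and a 95% HDI across samples) can be sketched with NumPy. The helpers hdi_1d and summarise_grid are hypothetical; hdi_1d uses the common "narrowest interval containing 95% of the sorted samples" estimator, which may differ in detail from what the mixin actually calls. The sample axis here stands in for the flattened chain and draw dimensions.

```python
import numpy as np


def hdi_1d(samples, prob=0.95):
    """Narrowest interval containing a `prob` fraction of the sorted samples
    (the usual empirical HDI estimator for a unimodal distribution)."""
    x = np.sort(np.asarray(samples, dtype=float))
    n = x.size
    k = int(np.floor(prob * n))       # index offset spanned by each interval
    widths = x[k:] - x[: n - k]       # width of every candidate interval
    i = int(np.argmin(widths))
    return x[i], x[i + k]


def summarise_grid(grid, prob=0.95):
    """grid has shape (delta, sample, date, channel). Returns the mean and the
    HDI bounds of the total (time-summed) contribution, each (delta, channel)."""
    totals = grid.sum(axis=2)         # sum over date -> (delta, sample, channel)
    mean = totals.mean(axis=1)        # mean over samples -> (delta, channel)
    n_delta, _, n_channel = totals.shape
    lo = np.empty((n_delta, n_channel))
    hi = np.empty((n_delta, n_channel))
    for d in range(n_delta):
        for c in range(n_channel):
            lo[d, c], hi[d, c] = hdi_1d(totals[d, :, c], prob)
    return mean, lo, hi


rng = np.random.default_rng(0)
toy_grid = rng.gamma(2.0, 1.0, size=(5, 200, 4, 2))  # hypothetical samples
mean, lo, hi = summarise_grid(toy_grid)
print(mean.shape, lo.shape, hi.shape)  # (5, 2) (5, 2) (5, 2)
```

Each (delta, channel) pair then gives one point on the mean line and one vertical slice of the shaded HDI band.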
Parameters:
start (float): Start multiplier for the spend grid (passed to get_channel_contributions_forward_pass_grid). Must be >= 0.
stop (float): End multiplier for the spend grid (passed to get_channel_contributions_forward_pass_grid). Must be > start.
num (int): Number of points in the grid (passed to get_channel_contributions_forward_pass_grid).
absolute_xrange (bool, optional): If True, the x-axis shows the absolute total spend, calculated as delta * (sum of the channel's original spend). If False (default), the x-axis shows the relative spend multiplier delta.
**plt_kwargs: Additional keyword arguments passed to matplotlib.pyplot.subplots.
Returns:
plt.Figure: The matplotlib Figure object containing the plot.
Raises:
RuntimeError: If the model hasn't been fitted or the original input data X is missing (required for calculating absolute spend).
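When absolute_xrange is True, the x-axis values follow directly from the multipliers and each channel's total original spend. A short sketch, assuming the spend columns of X have already been extracted into a NumPy array of shape (date, channel); absolute_spend_axis is a hypothetical helper name.

```python
import numpy as np


def absolute_spend_axis(deltas: np.ndarray, original_spend: np.ndarray) -> np.ndarray:
    """Map spend multipliers to absolute total spend, one column per channel.

    original_spend has shape (date, channel) and is assumed to hold the raw
    spend from X for each channel column. The result has shape
    (delta, channel), with entry [i, c] = deltas[i] * total spend of channel c.
    """
    total_spend = np.asarray(original_spend, dtype=float).sum(axis=0)  # (channel,)
    return np.asarray(deltas, dtype=float)[:, None] * total_spend[None, :]


deltas = np.linspace(0.0, 2.0, 3)                         # multipliers 0, 1, 2
original_spend = np.array([[100.0, 10.0], [50.0, 30.0]])  # hypothetical raw spend
x_abs = absolute_spend_axis(deltas, original_spend)
# x_abs rows: [0, 0], [150, 40], [300, 80]
```

Because each channel's total spend differs, each channel's curve spans its own x-range in this mode, which is why the original X (not just the preprocessed data) is needed.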
plot_new_spend_contributions
def plot_new_spend_contributions(
self,
spend_amount: float,
one_time: bool = True,
lower: float = 0.025,
upper: float = 0.975,
ylabel: str = "Contribution",
idx: Optional[slice] = None,
channels: Optional[List[str]] = None,
prior: bool = False,
original_scale: bool = True,
ax: Optional[plt.Axes] = None,
**sample_posterior_predictive_kwargs: Any,
) -> plt.Axes:
Plots the expected contribution trajectory over time for selected channels resulting from a specific hypothetical spend scenario.
This method simulates applying a spend_amount to the specified channels (either as a one-time pulse or continuous spend) and plots the resulting mean contribution and a credible interval (defined by lower and upper quantiles) over time since the spend occurred. It relies on the new_spend_contributions method (expected from MMMPredictMixin).
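To make the one_time distinction concrete, the sketch below builds both spend scenarios and passes them through a toy geometric adstock. Both helpers are hypothetical, and the adstock (with an assumed decay alpha = 0.5) is only an illustrative stand-in for the model's fitted response as computed by new_spend_contributions.

```python
import numpy as np


def spend_scenario(spend_amount: float, n_periods: int, one_time: bool = True) -> np.ndarray:
    """Spend per period starting at t=0: a single pulse if one_time is True,
    otherwise the same amount in every period."""
    spend = np.zeros(n_periods)
    if one_time:
        spend[0] = spend_amount
    else:
        spend[:] = spend_amount
    return spend


def toy_geometric_adstock(spend: np.ndarray, alpha: float) -> np.ndarray:
    """Hypothetical carry-over y[t] = spend[t] + alpha * y[t-1]; an illustrative
    stand-in, not the model's fitted response."""
    out = np.empty_like(spend, dtype=float)
    carry = 0.0
    for t, s in enumerate(spend):
        carry = s + alpha * carry
        out[t] = carry
    return out


pulse = toy_geometric_adstock(spend_scenario(100.0, 5, one_time=True), alpha=0.5)
steady = toy_geometric_adstock(spend_scenario(100.0, 5, one_time=False), alpha=0.5)
# pulse decays:     100, 50, 25, 12.5, 6.25
# steady builds up: 100, 150, 175, 187.5, 193.75 (towards 100 / (1 - 0.5) = 200)
```

A one-time pulse therefore produces a decaying trajectory, while continuous spend accumulates towards a plateau; the real curves also depend on the model's saturation and posterior uncertainty.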
Parameters:
spend_amount (float): The amount of spend to simulate for each selected channel.
one_time (bool, optional): If True (default), the spend is applied only at the first time step. If False, the spend is applied continuously.
lower (float, optional): The lower quantile for the credible interval (default: 0.025).
upper (float, optional): The upper quantile for the credible interval (default: 0.975).
ylabel (str, optional): Label for the y-axis (default: "Contribution").
idx (slice, optional): A time slice (relative to the spend start time, t=0) to display on the plot. Defaults to showing all time points from t=0 onwards.
channels (Optional[List[str]], optional): A list of channel names to include in the plot. If None (default), all channels in self.channel_columns are plotted.
prior (bool, optional): If True, uses samples from the prior predictive distribution. If False (default), uses samples from the posterior predictive distribution.
original_scale (bool, optional): If True (default), plots contributions in the original target variable's scale.
ax (Optional[plt.Axes], optional): A matplotlib Axes object to plot on. If None, new axes are created.
**sample_posterior_predictive_kwargs: Additional keyword arguments passed to the underlying sampling method (e.g., sample_posterior_predictive).
Returns:
plt.Axes: The matplotlib Axes object containing the plot.
Raises:
ValueError: If lower or upper are invalid quantiles, if lower > upper, or if a specified channel is not found in self.channel_columns.
RuntimeError: If the required new_spend_contributions method is not found in the class.
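The ValueError conditions above and the mean/quantile band drawn on the plot can be sketched as follows. validate_plot_args and contribution_band are hypothetical helpers; treating a quantile as valid when it lies in [0, 1] is an assumption, and the samples array with shape (sample, time_since_spend, channel) stands in for the output of new_spend_contributions.

```python
from typing import List, Optional, Sequence, Tuple

import numpy as np


def validate_plot_args(
    lower: float,
    upper: float,
    channels: Optional[Sequence[str]],
    channel_columns: Sequence[str],
) -> List[str]:
    """Hypothetical checks mirroring the documented ValueError conditions.

    Assumes a quantile is valid if it lies in [0, 1]. Returns the channels to
    plot: all of channel_columns when channels is None.
    """
    for name, q in (("lower", lower), ("upper", upper)):
        if not 0.0 <= q <= 1.0:
            raise ValueError(f"{name} must lie in [0, 1], got {q}")
    if lower > upper:
        raise ValueError(f"lower ({lower}) must not exceed upper ({upper})")
    selected = list(channel_columns) if channels is None else list(channels)
    missing = [c for c in selected if c not in channel_columns]
    if missing:
        raise ValueError(f"channels not found in channel_columns: {missing}")
    return selected


def contribution_band(
    samples: np.ndarray,
    lower: float = 0.025,
    upper: float = 0.975,
    idx: Optional[slice] = None,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """samples has shape (sample, time_since_spend, channel). Returns the mean
    and the lower/upper quantile curves, each (time_since_spend, channel)."""
    if idx is not None:
        samples = samples[:, idx, :]  # idx is relative to t=0
    mean = samples.mean(axis=0)
    lo, hi = np.quantile(samples, [lower, upper], axis=0)
    return mean, lo, hi


channel_columns = ["tv", "search", "social"]
selected = validate_plot_args(0.025, 0.975, ["tv", "social"], channel_columns)
cols = [channel_columns.index(c) for c in selected]

rng = np.random.default_rng(1)
samples = rng.normal(size=(500, 10, len(channel_columns)))  # hypothetical draws
mean, lo, hi = contribution_band(samples[:, :, cols], idx=slice(0, 4))
print(mean.shape)  # (4, 2)
```

Validating before sampling means a typo in a channel name fails fast instead of after an expensive predictive sampling run.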