Abacus Core Mixins: Plotting Predictive (plotting_predictive.py)
This module provides the PredictivePlottingMixin class, designed to be inherited by Marketing Mix Model (MMM) classes in Abacus. It offers methods for visualising prior and posterior predictive checks, which are crucial steps in evaluating model fit and identifying potential model misspecification.
PredictivePlottingMixin Class
This mixin assumes the inheriting class provides the following attributes and methods:
prior_predictive: An ArviZ InferenceData group containing prior predictive samples.
posterior_predictive: An ArviZ InferenceData group containing posterior predictive samples.
output_var: The name of the target variable in the model.
X: The original input DataFrame used for training.
y: The original target variable Series used for training.
date_column: The name of the date column in X.
preprocessed_data: A dictionary containing the preprocessed data, including preprocessed_data['y'].
get_target_transformer(): A method that returns the target transformer, if the target was scaled.
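The requirements above can be written down as a structural interface. Below is a minimal sketch using `typing.Protocol`. The names `PredictivePlottingHost` and `DummyModel` are hypothetical and do not appear in Abacus. The attribute types are loosened to `Any` so the sketch does not depend on ArviZ or pandas.

```python
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class PredictivePlottingHost(Protocol):
    """Hypothetical protocol listing what PredictivePlottingMixin is documented to expect.

    Types are left as Any here; in a real model they would be ArviZ
    InferenceData groups, a pandas DataFrame/Series, and so on.
    """

    prior_predictive: Any
    posterior_predictive: Any
    output_var: str
    X: Any
    y: Any
    date_column: str
    preprocessed_data: dict

    def get_target_transformer(self) -> Any:
        ...


class DummyModel:
    """Hypothetical stand-in that sets every attribute the protocol lists."""

    def __init__(self) -> None:
        self.prior_predictive = None
        self.posterior_predictive = None
        self.output_var = "y"
        self.X = None
        self.y = None
        self.date_column = "date"
        self.preprocessed_data = {"y": None}

    def get_target_transformer(self) -> Any:
        return None


# isinstance() on a runtime_checkable protocol only checks that the members exist.
print(isinstance(DummyModel(), PredictivePlottingHost))
print(isinstance(object(), PredictivePlottingHost))
```

The runtime check is shallow. It confirms that the names exist but does not validate their types or contents.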
Methods
plot_prior_predictive
def plot_prior_predictive(self, samples: int = 1_000, **plt_kwargs) -> plt.Figure:
Plots the prior predictive check, comparing the model’s predictions before seeing the data against the actual observed data.
This plot helps assess whether the chosen priors are reasonable and generate plausible data patterns. It shows the 50% and 95% Highest Density Intervals (HDIs) of the prior predictive distribution for the target variable over time, overlaid with the actual observed target variable (from self.preprocessed_data['y']).
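The source does not show how the 50% and 95% HDIs are computed. The sketch below is a standalone NumPy version of the usual shortest-interval definition of an HDI. At each time step it picks the narrowest interval that covers the requested share of draws. It assumes the prior predictive samples have already been flattened into an `(n_samples, n_dates)` array. The function name `hdi_per_timestep` is hypothetical.

```python
import numpy as np


def hdi_per_timestep(samples: np.ndarray, prob: float) -> np.ndarray:
    """Shortest interval containing `prob` of the draws, computed per time step.

    samples: array of shape (n_samples, n_dates).
    Returns an array of shape (n_dates, 2) holding lower and upper bounds.
    """
    sorted_s = np.sort(samples, axis=0)        # sort the draws within each time step
    n = sorted_s.shape[0]
    width_idx = int(np.floor(prob * n))        # index offset spanning `prob` of the draws
    n_intervals = n - width_idx                # number of candidate intervals
    widths = sorted_s[width_idx:] - sorted_s[:n_intervals]
    start = np.argmin(widths, axis=0)          # narrowest candidate for each time step
    cols = np.arange(sorted_s.shape[1])
    lower = sorted_s[start, cols]
    upper = sorted_s[start + width_idx, cols]
    return np.stack([lower, upper], axis=1)


# Synthetic draws standing in for prior predictive samples: 4000 draws x 52 weeks.
rng = np.random.default_rng(0)
draws = rng.normal(loc=100.0, scale=10.0, size=(4000, 52))
hdi_50 = hdi_per_timestep(draws, 0.50)
hdi_95 = hdi_per_timestep(draws, 0.95)
print(hdi_50.shape, hdi_95.shape)
```

For unimodal draws like these, the 50% band sits inside the 95% band. For multimodal predictive distributions, a single shortest interval can hide gaps between modes.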
Parameters:
samples (int, optional): The number of prior predictive samples to consider. This parameter is currently unused: the HDI calculation uses all available samples. Defaults to 1000.
**plt_kwargs: Additional keyword arguments passed to matplotlib.pyplot.subplots.
Returns:
plt.Figure: The matplotlib Figure object containing the prior predictive check plot.
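The parameter and return descriptions above imply a simple pattern. Extra keyword arguments go to `matplotlib.pyplot.subplots`, and the resulting `Figure` is returned to the caller. The helper below is a hypothetical sketch of that pattern, not the mixin's code. It assumes the HDI bounds are already computed as `(n_dates, 2)` arrays, ordered from widest to narrowest.

```python
import matplotlib

matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np


def plot_bands(x, bands, observed, **plt_kwargs) -> plt.Figure:
    """Hypothetical helper: shade HDI bands and overlay the observed series.

    bands maps a label (e.g. "95% HDI") to an (n_dates, 2) array of lower/upper
    bounds, ordered widest first so narrower bands are drawn on top.
    """
    fig, ax = plt.subplots(**plt_kwargs)  # e.g. figsize=(8, 3) is forwarded unchanged
    alphas = np.linspace(0.2, 0.5, len(bands))  # darker shading for narrower bands
    for (label, bounds), alpha in zip(bands.items(), alphas):
        ax.fill_between(x, bounds[:, 0], bounds[:, 1], alpha=alpha, color="C0", label=label)
    ax.plot(x, observed, color="black", label="observed")
    ax.legend()
    return fig


# Usage with made-up bounds around a straight-line "observed" series.
x = np.arange(10)
observed = np.linspace(0.0, 1.0, 10)
bands = {
    "95% HDI": np.column_stack([observed - 0.3, observed + 0.3]),
    "50% HDI": np.column_stack([observed - 0.1, observed + 0.1]),
}
fig = plot_bands(x, bands, observed, figsize=(8, 3))
```

Returning the `Figure` instead of calling `plt.show()` lets the caller save, embed, or further modify the plot.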
Raises:
RuntimeError: If the model hasn’t been fitted with X and y data yet, or if prior predictive samples are missing.
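The source lists these error conditions but does not show how they are checked. One plausible guard is sketched below. The class `FittedCheckSketch` and its method `_require_prior_predictive` are hypothetical, and treating "missing" as `None` is an assumption.

```python
class FittedCheckSketch:
    """Hypothetical illustration of the documented RuntimeError conditions."""

    def __init__(self, X=None, y=None, prior_predictive=None):
        self.X = X
        self.y = y
        self.prior_predictive = prior_predictive

    def _require_prior_predictive(self) -> None:
        # Assumption: an unset attribute is represented by None.
        if self.X is None or self.y is None:
            raise RuntimeError("The model hasn't been fitted with X and y data yet.")
        if self.prior_predictive is None:
            raise RuntimeError("Prior predictive samples are missing.")


try:
    FittedCheckSketch()._require_prior_predictive()
except RuntimeError as err:
    print(err)
```

Checking the preconditions first means the method fails fast with a readable message, instead of failing later with an attribute or indexing error inside the plotting code.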
plot_posterior_predictive
def plot_posterior_predictive(
self, original_scale: bool = False, **plt_kwargs
) -> plt.Figure:
Plots the posterior predictive check, comparing the model’s predictions after fitting to the data against the actual observed data.
This plot assesses how well the fitted model captures the patterns in the observed data. It shows the 50% and 90% Highest Density Intervals (HDIs) of the posterior predictive distribution for the target variable over time, overlaid with the actual observed target variable. The plot can be shown in the original data scale or the potentially transformed scale used during modelling.
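The 90% band in this plot can also be read numerically by counting how many observed points fall inside it. This coverage check is an illustrative addition, not something the mixin is documented to do. The sketch assumes the posterior predictive array uses ArviZ's `(chain, draw, date)` layout and pools chains and draws before summarising. To stay short, it uses equal-tailed quantiles rather than an HDI. The function name `pooled_coverage` is hypothetical.

```python
import numpy as np


def pooled_coverage(pp: np.ndarray, observed: np.ndarray, prob: float = 0.90) -> float:
    """Fraction of observed points inside a central `prob` interval of the posterior predictive.

    pp: array with ArviZ-style dimensions (chain, draw, date).
    observed: array of shape (date,).
    Uses equal-tailed quantiles, not an HDI.
    """
    n_chain, n_draw, n_date = pp.shape
    pooled = pp.reshape(n_chain * n_draw, n_date)  # merge chains and draws into one sample axis
    tail = (1.0 - prob) / 2.0
    lower, upper = np.quantile(pooled, [tail, 1.0 - tail], axis=0)  # per-date bounds
    inside = (observed >= lower) & (observed <= upper)
    return float(inside.mean())


# Synthetic example: 4 chains x 500 draws x 52 dates, observed data from the same distribution.
rng = np.random.default_rng(1)
pp = rng.normal(size=(4, 500, 52))
observed = rng.normal(size=52)
print(f"coverage of the 90% interval: {pooled_coverage(pp, observed):.2f}")
```

For a well-calibrated model, coverage should land near the nominal 90%. Much lower coverage suggests that the predictive intervals are too narrow or that the model misses structure in the data.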
Parameters:
original_scale (bool, optional): If True, the predictive distributions and the observed data (self.y) are plotted in their original scale, inverse-transformed using self.get_target_transformer(). If False (default), they are plotted in the potentially transformed scale, with self.preprocessed_data['y'] as the observed data.
**plt_kwargs: Additional keyword arguments passed to matplotlib.pyplot.subplots.
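The source says only that original_scale=True inverse-transforms the predictions with self.get_target_transformer(). The sketch below assumes that the transformer has a scikit-learn-style `inverse_transform` that expects 2-D `(n, 1)` input. The classes and functions shown (`MaxAbsScalerSketch`, `to_original_scale`) are hypothetical stand-ins.

```python
import numpy as np


class MaxAbsScalerSketch:
    """Hypothetical target transformer with a scikit-learn-style interface."""

    def fit(self, y: np.ndarray) -> "MaxAbsScalerSketch":
        self.scale_ = np.max(np.abs(y))
        return self

    def transform(self, y: np.ndarray) -> np.ndarray:
        return y / self.scale_

    def inverse_transform(self, y: np.ndarray) -> np.ndarray:
        return y * self.scale_


def to_original_scale(samples: np.ndarray, transformer) -> np.ndarray:
    """Inverse-transform (n_samples, n_dates) predictive samples element-wise.

    scikit-learn-style transformers expect input shaped (n, n_features), so the
    samples are flattened into a single column and reshaped back afterwards.
    """
    flat = samples.reshape(-1, 1)
    return transformer.inverse_transform(flat).reshape(samples.shape)


# Usage: scale a toy target, fake two predictive draws, then map them back.
y = np.array([120.0, 80.0, 200.0, 160.0])
scaler = MaxAbsScalerSketch().fit(y)
y_scaled = scaler.transform(y)
samples_scaled = np.vstack([y_scaled, y_scaled * 1.1])
samples_original = to_original_scale(samples_scaled, scaler)
print(samples_original)
```

Flattening into a single column suits a transformer fitted on one target variable. A transformer fitted on several columns would need the samples reshaped to match its feature layout.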
Returns:
plt.Figure: The matplotlib Figure object containing the posterior predictive check plot.
Raises:
RuntimeError: If the model hasn’t been fitted yet (missing X, y, or posterior predictive samples).