Driver (driver.py)

This module provides the main driver class for orchestrating the ABACUS Marketing Mix Modelling (MMM) workflow.

MMMBaseDriver

Base driver class responsible for managing the end-to-end MMM process. It handles loading configuration, ingesting and validating data, fitting the MMM model, generating visualisations, and saving results.

class MMMBaseDriver:
    """
    Base driver class for orchestrating the MMM workflow.

    Handles loading configuration, ingesting data, fitting the model,
    visualising results, and saving outputs.
    """

Initialisation

    def __init__(
        self,
        config_filename: str,
        input_filename: str,
        holidays_filename: str,
        results_filename: str = "results",
        results_dir: str = "./results",
        results_folder_creation_time: Optional[datetime] = None,
        model: Optional[mmm_model.DelayedSaturatedMMM] = None,
        run_id: Optional[Union[str, datetime]] = None
    ) -> None:

Initialises the MMMBaseDriver. During initialisation, it validates the input data using abacus.prepro.prepro.validate_data and splits the data into training and testing sets using abacus.prepro.prepro.split_data.

Parameters:

  • config_filename (str): Path to the YAML configuration file.

  • input_filename (str): Path to the input data file (CSV/Excel).

  • holidays_filename (str): Path to the holidays file (CSV/Excel).

  • results_filename (str, optional): Base name for the results folder. Defaults to "results".

  • results_dir (str, optional): Base directory for saving results. Defaults to "./results".

  • results_folder_creation_time (Optional[datetime], optional): Pre-defined creation time for results folder. Defaults to None.

  • model (Optional[mmm_model.DelayedSaturatedMMM], optional): Pre-fitted model instance. Defaults to None.

  • run_id (Optional[Union[str, datetime]], optional): Identifier for the run. Defaults to None.

Attributes Initialised:

  • config_raw (bytes): Raw configuration file content.

  • config (Dict[str, Any]): Parsed configuration dictionary.

  • processed_data (pd.DataFrame): Data after initial processing and validation.

  • input_data (Any): Input data structure (type depends on validate_data).

  • data_to_fit (Any): Data structure prepared for model fitting (type depends on validate_data).

  • per_observation_df (pd.DataFrame): DataFrame containing per-observation data.

  • X, y, X_train, y_train, X_test, y_test: DataFrames/Series resulting from the train/test split.

  • model: The MMM model instance (initially None or the provided pre-fitted model).

  • results_dir: Absolute path to the results directory.

  • correlations, variances, spend_fractions, variance_inflation_factors: Data quality check results (populated by check_quality).

  • corr_df: Correlation matrix DataFrame (populated by plot_correlation).
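
The train/test attributes listed above come from a split performed during initialisation. The following is a minimal sketch of what such a split can look like, assuming a chronological cut on a fraction of the rows. The function name chronological_split and the train_fraction parameter are hypothetical and are not taken from abacus.prepro.prepro.split_data, whose actual behaviour may differ.

```python
from typing import List, Sequence, Tuple


def chronological_split(
    rows: Sequence[dict], train_fraction: float = 0.8
) -> Tuple[List[dict], List[dict]]:
    """Split time-ordered rows into a leading train part and a trailing test part.

    Hypothetical helper for illustration; not the abacus implementation.
    """
    if not 0.0 < train_fraction < 1.0:
        raise ValueError("train_fraction must be strictly between 0 and 1")
    cut = int(len(rows) * train_fraction)
    return list(rows[:cut]), list(rows[cut:])


# Ten weekly observations, oldest first (illustrative values only).
weeks = [{"week": i, "tv_spend": 100.0 + i, "sales": 1000.0 + 5 * i} for i in range(10)]
train_rows, test_rows = chronological_split(weeks, train_fraction=0.8)
```

Keeping the test rows strictly after the training rows avoids leaking future information into the fit, which is the usual reason for splitting time series this way.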

Methods

check_quality

    def check_quality(self) -> Tuple[List[pd.DataFrame], pd.DataFrame, pd.DataFrame, List[pd.DataFrame]]:

Checks the quality of the input data using abacus.prepro.prepro.check_quality.

Returns:

  • Tuple[List[pd.DataFrame], pd.DataFrame, pd.DataFrame, List[pd.DataFrame]]: A tuple containing:

    • correlations: List of correlation DataFrames.

    • variances: DataFrame of feature variances.

    • spend_fractions: DataFrame of channel spend fractions.

    • variance_inflation_factors: List of VIF DataFrames.
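
As an illustration of one of these checks, the sketch below computes spend fractions, assuming a spend fraction means each channel's share of total media spend. The helper name spend_fractions and the plain-dictionary input are assumptions made for this example; abacus.prepro.prepro.check_quality works on the driver's own data structures and may define the quantity differently.

```python
from typing import Dict, List


def spend_fractions(spend_by_channel: Dict[str, List[float]]) -> Dict[str, float]:
    """Return each channel's share of total spend across all channels.

    Hypothetical helper for illustration; not the abacus implementation.
    """
    totals = {name: sum(values) for name, values in spend_by_channel.items()}
    grand_total = sum(totals.values())
    if grand_total == 0:
        raise ValueError("total spend is zero; fractions are undefined")
    return {name: total / grand_total for name, total in totals.items()}


# Two periods of spend per channel (illustrative values only).
spend = {"tv": [50.0, 50.0], "search": [30.0, 30.0], "radio": [10.0, 30.0]}
fractions = spend_fractions(spend)
```

A channel with a very small fraction carries little signal, which is presumably why low values are singled out by highlight_low_spend_fractions.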

init_output

    def init_output(self, data_dir: str = ".") -> str:

Creates the output directory structure for saving results.

Parameters:

  • data_dir (str, optional): The base directory within which to create the results folder. Defaults to "." (the current directory).

Returns:

  • str: The absolute path to the created results directory.
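
A minimal sketch of creating a timestamped results folder and returning its absolute path follows. The folder naming scheme (base name, underscore, timestamp) and the helper name make_results_dir are assumptions for illustration; the source does not state how init_output names the folder or which subfolders it creates.

```python
import os
import tempfile
from datetime import datetime
from typing import Optional


def make_results_dir(
    data_dir: str = ".",
    results_filename: str = "results",
    creation_time: Optional[datetime] = None,
) -> str:
    """Create <data_dir>/<results_filename>_<timestamp> and return its absolute path.

    Hypothetical helper; the naming scheme is an assumption, not the abacus one.
    """
    stamp = (creation_time or datetime.now()).strftime("%Y%m%d_%H%M%S")
    path = os.path.abspath(os.path.join(data_dir, f"{results_filename}_{stamp}"))
    os.makedirs(path, exist_ok=True)
    return path


# Use a throwaway base directory and a fixed time so the result is reproducible.
base_dir = tempfile.mkdtemp()
created = make_results_dir(base_dir, "results", datetime(2024, 1, 2, 3, 4, 5))
```

Passing a fixed creation time, as results_folder_creation_time allows in the constructor, makes the folder name deterministic, which helps when several steps must write to the same run folder.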

highlight_variances

    def highlight_variances(self, in_place: bool = False) -> pd.DataFrame:

Highlights variance values based on predefined thresholds using abacus.utils.highlight_threshold_values.

Parameters:

  • in_place (bool, optional): Whether to modify the internal variances DataFrame directly. Defaults to False.

Returns:

  • pd.DataFrame: A styled DataFrame with highlighted variance values.

highlight_low_spend_fractions

    def highlight_low_spend_fractions(self, in_place: bool = False) -> pd.DataFrame:

Highlights low spend fraction values using abacus.utils.highlight_threshold_values.

Parameters:

  • in_place (bool, optional): Whether to modify the internal spend_fractions DataFrame directly. Defaults to False.

Returns:

  • pd.DataFrame: A styled DataFrame with highlighted low spend fractions.

highlight_high_vif_values

    def highlight_high_vif_values(self, in_place: bool = False) -> pd.DataFrame:

Highlights high Variance Inflation Factor (VIF) values using abacus.utils.highlight_threshold_values.

Parameters:

  • in_place (bool, optional): Whether to modify the internal variance_inflation_factors attribute directly. Defaults to False.

Returns:

  • pd.DataFrame: A styled DataFrame with highlighted high VIF values.

run_feature_engineering

    def run_feature_engineering(self) -> Any:

Placeholder method for custom feature engineering steps. In the base implementation, it’s a no-op.

Returns:

  • Any: The processed input data (typically self.input_data).
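
Because the base method does nothing, custom steps are presumably added by overriding it in a subclass. The sketch below shows the pattern with a small stand-in class so that it runs without abacus installed. StandInDriver, PromoDriver, and the tv_above_mean feature are all hypothetical; a real subclass would derive from MMMBaseDriver and work on whatever structure self.input_data actually holds.

```python
from typing import Any, Dict, List


class StandInDriver:
    """Hypothetical stand-in for MMMBaseDriver; only the hook is modelled."""

    def __init__(self, input_data: Dict[str, List[float]]) -> None:
        self.input_data = input_data

    def run_feature_engineering(self) -> Any:
        # Base behaviour: no-op, return the data unchanged.
        return self.input_data


class PromoDriver(StandInDriver):
    """Adds a derived column before handing the data back."""

    def run_feature_engineering(self) -> Any:
        spend = self.input_data["tv_spend"]
        # Hypothetical feature: flag periods whose spend exceeds the mean.
        mean_spend = sum(spend) / len(spend)
        self.input_data["tv_above_mean"] = [float(s > mean_spend) for s in spend]
        return self.input_data


driver = PromoDriver({"tv_spend": [10.0, 20.0, 60.0]})
engineered = driver.run_feature_engineering()
```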

describe_data

    def describe_data(
        self,
        input_data_raw: Any,
        input_data_processed: Any,
        current_commit: str
    ) -> None:

Generates descriptive summaries and plots for both raw and processed input data, saving them to the results directory. Also saves configuration details. Uses functions from abacus.sketch.depict.

Parameters:

  • input_data_raw (Any): The raw input data structure.

  • input_data_processed (Any): The processed input data structure.

  • current_commit (str): The current Git commit hash (or relevant version identifier).

plot_correlation

    def plot_correlation(self) -> Tuple[matfig.Figure, pd.DataFrame]:

Plots the correlation matrix of the input features and target variable. Uses abacus.sketch.plot_correlation_matrix.

Returns:

  • Tuple[matfig.Figure, pd.DataFrame]: A tuple containing:

    • The matplotlib Figure object of the plot.

    • The DataFrame containing the correlation values.

visualize

    def visualize(self) -> None:

Generates and saves standard visualisations of the model training and prediction results using abacus.sketch.depict.describe_mmm_training and describe_mmm_prediction.

describe_all_media_spend

    def describe_all_media_spend(self) -> pd.DataFrame:

Generates a descriptive summary DataFrame of media spend across all channels using abacus.sketch.depict.describe_all_media_spend.

Returns:

  • pd.DataFrame: DataFrame containing media spend summary statistics.

plot_all_media_spend

    def plot_all_media_spend(self) -> matfig.Figure:

Plots the media spend over time for all channels using abacus.sketch.plot_all_media_spend.

Returns:

  • matfig.Figure: The matplotlib Figure object of the plot.

fit_model

    def fit_model(
        self,
        model_filename: Optional[str] = None
    ) -> mmm_model.DelayedSaturatedMMM:

Fits the DelayedSaturatedMMM model on the training data (X_train, y_train) using the configuration settings. If model_filename is given, loads a previously fitted model from that file instead of fitting. Performs a variance check on the training media data before fitting.

Parameters:

  • model_filename (Optional[str], optional): Path to a saved model file (.nc) to load instead of fitting. Defaults to None.

Returns:

  • mmm_model.DelayedSaturatedMMM: The fitted (or loaded) model instance.
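
The source does not say how the variance check is implemented or what happens when it fails. As a rough sketch, one way to detect channels that cannot inform the fit is to look for media columns with (near) zero variance in the training window. The helper name zero_variance_channels and the tolerance value are assumptions.

```python
from statistics import pvariance
from typing import Dict, List


def zero_variance_channels(
    media: Dict[str, List[float]], tolerance: float = 1e-12
) -> List[str]:
    """Return the names of channels whose training spend is (almost) constant.

    Hypothetical helper for illustration; not the abacus implementation.
    """
    return [name for name, values in media.items() if pvariance(values) <= tolerance]


# Training-window spend per channel (illustrative values only).
train_media = {
    "tv": [10.0, 12.0, 9.0],
    "radio": [0.0, 0.0, 0.0],
    "search": [5.0, 5.0, 5.0],
}
flat = zero_variance_channels(train_media)
```

A constant channel cannot be distinguished from the intercept, so flagging it before sampling saves a long run that would produce an uninformative posterior for that channel.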

plot_model_structure

    def plot_model_structure(self) -> gz.Digraph:

Generates a graphviz visualisation of the underlying PyMC model structure.

Returns:

  • graphviz.Digraph: The model structure graph.

plot_model_trace

    def plot_model_trace(self) -> matfig.Figure:

Plots the trace diagnostics (posterior distributions and chain history) for key model parameters using ArviZ.

Returns:

  • matfig.Figure: The matplotlib Figure object showing the trace plot.

predict_on_test

    def predict_on_test(self) -> Union[xr.DataArray, Any]:

Generates posterior predictive samples using the test dataset (X_test).

Returns:

  • Union[xr.DataArray, Any]: Posterior predictive samples for the test set (typically an xarray DataArray).

display_image

    def display_image(self, image_filename: str) -> Any:

Displays an image file located in the results directory. Primarily useful within Jupyter notebooks. Requires IPython.

Parameters:

  • image_filename (str): The filename of the image within the results directory.

Returns:

  • Any: The image object for display (typically IPython.core.display.Image).

calculate_train_r_squared

    def calculate_train_r_squared(self) -> float:

Calculates the R-squared metric for the model’s predictions on the training data. Samples the posterior predictive for the training set and compares the mean prediction against the actual y_train values.

Returns:

  • float: The R-squared value.
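
The computation described above can be sketched as follows, assuming the standard coefficient of determination, 1 - SS_res / SS_tot, applied to the per-observation mean of the posterior predictive draws. The helper names are hypothetical, and the driver may rely on a library routine rather than code like this.

```python
from typing import List, Sequence


def mean_prediction(draws: List[List[float]]) -> List[float]:
    """Average posterior predictive draws (one inner list per draw) per observation."""
    n_draws = len(draws)
    return [sum(column) / n_draws for column in zip(*draws)]


def r_squared(y_true: Sequence[float], y_pred: Sequence[float]) -> float:
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    mean_true = sum(y_true) / len(y_true)
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    ss_tot = sum((t - mean_true) ** 2 for t in y_true)
    if ss_tot == 0:
        raise ValueError("y_true is constant; R-squared is undefined")
    return 1.0 - ss_res / ss_tot


# Four observations and two posterior predictive draws (illustrative values only).
y_train = [1.0, 2.0, 3.0, 4.0]
draws = [[1.0, 2.0, 2.0, 4.0], [1.0, 2.0, 4.0, 5.0]]
y_hat = mean_prediction(draws)
train_r2 = r_squared(y_train, y_hat)
```

Here the mean prediction is [1.0, 2.0, 3.0, 4.5], so the residual sum of squares is 0.25 against a total of 5.0, giving an R-squared of 0.95.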

plot_posterior_predictive

    def plot_posterior_predictive(self) -> matfig.Figure:

Plots the posterior predictive distribution against the actual training data. Includes the calculated R-squared value in the title.

Returns:

  • matfig.Figure: The matplotlib Figure object of the posterior predictive plot.

plot_components_contributions

    def plot_components_contributions(self) -> matfig.Figure:

Plots the contribution of each model component (intercept, channels, control features) to the target variable over time. Uses the plot_components_contributions method of the fitted model.

Returns:

  • matfig.Figure: The matplotlib Figure object of the component contributions plot.

plot_posterior_predictions

    def plot_posterior_predictions(self) -> matfig.Figure:

Plots the in-sample and out-of-sample posterior predictions against actual values using abacus.sketch.plot_posterior_predictions.

Returns:

  • matfig.Figure: The matplotlib Figure object of the posterior predictions plot.

plot_waterfall_components_decomposition

    def plot_waterfall_components_decomposition(
        self,
        model: Optional[mmm_model.DelayedSaturatedMMM] = None,
        original_scale: bool = True,
        figsize: Tuple[int, int] = (14, 7),
        **kwargs: Any,
    ) -> matfig.Figure:

Plots a waterfall chart showing the decomposition of the model components’ contributions. Uses abacus.sketch.plot_waterfall_components_decomposition.

Parameters:

  • model (Optional[mmm_model.DelayedSaturatedMMM], optional): A specific model instance to plot. If None, uses the instance’s fitted model (self.model). Defaults to None.

  • original_scale (bool, optional): Whether to plot contributions on the original data scale. Defaults to True.

  • figsize (Tuple[int, int], optional): Figure size. Defaults to (14, 7).

  • **kwargs: Additional keyword arguments passed to the underlying plotting function.

Returns:

  • matfig.Figure: The matplotlib Figure object of the waterfall plot.

create_downloadable_zip

    def create_downloadable_zip(
        self,
        excluded_files: List[str],
        source_folder: str = '/content',
        zip_name: str = 'Model_files.zip'
    ) -> None:

Creates a zip archive of a specified folder, excluding certain files, and attempts to trigger a download. Primarily intended for Google Colab environments.

Parameters:

  • excluded_files (List[str]): List of file or folder names to exclude from the zip archive.

  • source_folder (str, optional): The path to the folder to be zipped. Defaults to '/content'.

  • zip_name (str, optional): The desired name for the output zip file. Defaults to 'Model_files.zip'.
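
The archiving step can be sketched with the standard library alone. The example below assumes that exclusion is by exact file or folder name, as the parameter description suggests. The helper name zip_folder is hypothetical, and the Colab download step is left out because it only works inside a Colab session.

```python
import os
import tempfile
import zipfile
from typing import List


def zip_folder(source_folder: str, zip_path: str, excluded_files: List[str]) -> List[str]:
    """Zip source_folder into zip_path, skipping any file or folder named in excluded_files.

    Hypothetical helper for illustration; returns the archive member names written.
    """
    excluded = set(excluded_files)
    written: List[str] = []
    with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as archive:
        for root, dirs, files in os.walk(source_folder):
            # Prune excluded directories in place so os.walk does not descend into them.
            dirs[:] = [d for d in dirs if d not in excluded]
            for name in sorted(files):
                full_path = os.path.join(root, name)
                # Skip excluded names and the archive itself if it lives in source_folder.
                if name in excluded or os.path.abspath(full_path) == os.path.abspath(zip_path):
                    continue
                arcname = os.path.relpath(full_path, source_folder)
                archive.write(full_path, arcname)
                written.append(arcname)
    return written


# Demonstration on a throwaway folder.
demo_dir = tempfile.mkdtemp()
os.makedirs(os.path.join(demo_dir, "sample_data"))
for relative in ("model.nc", "notes.txt", os.path.join("sample_data", "big.csv")):
    with open(os.path.join(demo_dir, relative), "w") as handle:
        handle.write("placeholder")

out_zip = os.path.join(tempfile.mkdtemp(), "Model_files.zip")
archived = zip_folder(demo_dir, out_zip, excluded_files=["sample_data", "notes.txt"])
```

Skipping the archive itself matters when the zip is written into the folder being zipped, which is the case if both defaults ('/content' and 'Model_files.zip') are combined in a single location.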

main

    def main(self) -> str:

The main execution method that orchestrates the typical MMM workflow: initialises output, runs feature engineering (if any), fits the model, saves the model, visualises results, and writes basic run information.

Returns:

  • str: The absolute path to the results directory for this run.
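
The order of steps described above can be pictured with a stand-in class that only records which step ran. Everything here is a sketch: WorkflowSketch is not part of abacus, and save_model and write_run_info are hypothetical names for the "saves the model" and "writes basic run information" steps, which the source does not name.

```python
from typing import List


class WorkflowSketch:
    """Hypothetical stand-in that records the order of the steps main() is described as running."""

    def __init__(self) -> None:
        self.steps: List[str] = []
        self.results_dir = ""

    def init_output(self) -> str:
        self.steps.append("init_output")
        return "/tmp/results_sketch"  # placeholder path; nothing is created on disk

    def run_feature_engineering(self) -> None:
        self.steps.append("run_feature_engineering")

    def fit_model(self) -> None:
        self.steps.append("fit_model")

    def save_model(self) -> None:  # hypothetical name for the save step
        self.steps.append("save_model")

    def visualize(self) -> None:
        self.steps.append("visualize")

    def write_run_info(self) -> None:  # hypothetical name for the run-info step
        self.steps.append("write_run_info")

    def main(self) -> str:
        # Same sequence as in the description: output, features, fit, save, plots, run info.
        self.results_dir = self.init_output()
        self.run_feature_engineering()
        self.fit_model()
        self.save_model()
        self.visualize()
        self.write_run_info()
        return self.results_dir


sketch = WorkflowSketch()
results_path = sketch.main()
```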

plot_posterior_distributions

    def plot_posterior_distributions(
        self,
        results_dir: str = '/content/results',
        filename_prefix: str = 'media_spend_posterior'
    ) -> pd.DataFrame:

Plots the posterior distributions (kernel density estimates) of the beta_channel parameters (media effectiveness), one per channel. Saves the plot and returns a DataFrame of summary statistics.

Parameters:

  • results_dir (str, optional): Directory to save the plot. Defaults to '/content/results'.

  • filename_prefix (str, optional): Prefix for the saved plot filename. Defaults to 'media_spend_posterior'.

Returns:

  • pd.DataFrame: DataFrame containing summary statistics (mean, median, std, HDI, skewness, kurtosis) for each channel’s beta_channel posterior.
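
To make the summary columns concrete, the sketch below computes the same kinds of statistics for one list of posterior draws using only the standard library. Several choices are assumptions, since the source does not specify them: a 94% HDI mass, the population standard deviation, moment-based skewness, and excess kurtosis. The helper names hdi and summarise are hypothetical.

```python
import math
from statistics import mean, median, pstdev
from typing import Dict, List, Tuple


def hdi(samples: List[float], prob: float = 0.94) -> Tuple[float, float]:
    """Narrowest interval that contains a fraction `prob` of the samples."""
    ordered = sorted(samples)
    n = len(ordered)
    window = max(1, math.ceil(prob * n))  # number of samples the interval must cover
    widths = [ordered[i + window - 1] - ordered[i] for i in range(n - window + 1)]
    start = widths.index(min(widths))
    return ordered[start], ordered[start + window - 1]


def summarise(samples: List[float]) -> Dict[str, float]:
    """Summary statistics for one parameter's posterior draws."""
    mu, sigma = mean(samples), pstdev(samples)
    if sigma == 0:
        raise ValueError("samples are constant; skewness and kurtosis are undefined")
    z = [(s - mu) / sigma for s in samples]
    low, high = hdi(samples)
    return {
        "mean": mu,
        "median": median(samples),
        "std": sigma,
        "hdi_low": low,
        "hdi_high": high,
        "skewness": mean([v ** 3 for v in z]),
        "kurtosis": mean([v ** 4 for v in z]) - 3.0,  # excess kurtosis
    }


# Draws for a single channel (illustrative values only).
posterior_draws = [1.0, 2.0, 3.0, 4.0, 5.0]
stats = summarise(posterior_draws)
```

Unlike an equal-tailed interval, the HDI is the narrowest interval with the requested mass, so for a skewed posterior it shifts towards the region where draws are densest.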