Abacus Sketch: Plot Input (plot_input.py)

This module provides functions for visualising the raw or processed input data used in the Abacus Marketing Mix Model (MMM). The plots support initial data exploration: they help reveal trends, correlations, and potential issues in the input variables.

Functions

plot_all_metrics

def plot_all_metrics(input_data, output_dir: str, suffix: str) -> None:

Generates a multi-panel plot showing the time series for each input metric: media channel volumes, media channel costs, extra features, and the target variable.

Each metric is plotted against time on its own subplot. The primary x-axis shows the observation index, and a secondary x-axis shows the corresponding date labels. The result is a quick overview of all input time series.

Parameters:

  • input_data: An instance of abacus.prepro.input_data.InputData containing the data arrays and metadata (names, dates, granularity).

  • output_dir (str): The directory path where the output plot file (metrics_{suffix}.png) will be saved.

  • suffix (str): A suffix (e.g., “raw”, “processed”) to append to the output filename.

Returns:

  • None: Saves the plot to the specified directory.
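
The layout described above (one subplot per metric, observation index on the primary x-axis, date labels on a secondary x-axis) can be sketched with plain Matplotlib. This is an illustration only, not the module's implementation: the helper name sketch_metric_panels, its n_labels parameter, and the sample metric names are hypothetical, and the data is passed as a plain dict rather than an InputData instance.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd


def sketch_metric_panels(metrics, dates, n_labels=4):
    """Hypothetical helper: one subplot per metric, index below, dates above."""
    fig, axes = plt.subplots(
        len(metrics), 1, figsize=(8, 2.5 * len(metrics)), squeeze=False
    )
    # Pick a few evenly spaced observations to label with dates.
    tick_pos = np.linspace(0, len(dates) - 1, n_labels).astype(int)
    for ax, (name, values) in zip(axes[:, 0], metrics.items()):
        ax.plot(np.arange(len(values)), values)
        ax.set_title(name)
        ax.set_xlabel("observation index")
        # Secondary x-axis on top, aligned with the primary one.
        date_ax = ax.twiny()
        date_ax.set_xlim(ax.get_xlim())
        date_ax.set_xticks(tick_pos)
        date_ax.set_xticklabels([dates[i].strftime("%Y-%m-%d") for i in tick_pos])
    fig.tight_layout()
    return fig


# Made-up sample data standing in for the arrays held by InputData.
dates = pd.date_range("2024-01-01", periods=12, freq="W")
metrics = {
    "tv_volume": np.arange(12.0),
    "tv_cost": np.arange(12.0) * 2,
    "sales": np.arange(12.0) + 5,
}
fig = sketch_metric_panels(metrics, dates)
```

To mirror the documented output name, such a figure could then be written with fig.savefig(os.path.join(output_dir, f"metrics_{suffix}.png")).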


plot_correlation_matrix

def plot_correlation_matrix(
    input_data, per_observation_df: pd.DataFrame
) -> tuple['plotly.graph_objs.Figure', pd.DataFrame]:

Calculates and plots a correlation matrix heatmap for media channel volumes and the target variable.

This function selects the target column and every column whose name contains ‘volume’ from per_observation_df, computes the pairwise Pearson correlation between them, and displays the result as a heatmap using Plotly Express. High absolute correlations between media channels (e.g., |r| >= 0.7) may indicate multicollinearity.

Parameters:

  • input_data: An instance of abacus.prepro.input_data.InputData (used to get the target column name).

  • per_observation_df (pd.DataFrame): A DataFrame containing the data, typically derived from DataToFit.to_data_frame(). Must contain the target column and at least one column whose name contains ‘volume’.

Returns:

  • tuple:

    • plotly.graph_objs.Figure: The Plotly figure object for the heatmap.

    • pd.DataFrame: The calculated correlation matrix.
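
The column selection and correlation step described above can be reproduced with pandas alone. The following is a minimal sketch under stated assumptions, not the module's code: sketch_volume_correlations, its threshold parameter, and the sample column names are hypothetical, and the target name is passed directly instead of being read from input_data.

```python
import pandas as pd


def sketch_volume_correlations(df, target_name, threshold=0.7):
    """Hypothetical helper: Pearson correlations for 'volume' columns and the target."""
    # Keep every column with 'volume' in its name, plus the target column.
    cols = [c for c in df.columns if "volume" in c]
    if target_name not in cols:
        cols.append(target_name)
    corr = df[cols].corr(method="pearson")
    # List each distinct pair whose absolute correlation reaches the threshold.
    flagged = [
        (a, b, corr.loc[a, b])
        for i, a in enumerate(cols)
        for b in cols[i + 1:]
        if abs(corr.loc[a, b]) >= threshold
    ]
    return corr, flagged


# Made-up sample data; radio_volume is an exact multiple of tv_volume.
df = pd.DataFrame({
    "tv_volume": [1.0, 2.0, 3.0, 4.0, 5.0],
    "radio_volume": [2.0, 4.0, 6.0, 8.0, 10.0],
    "tv_cost": [9.0, 1.0, 7.0, 3.0, 5.0],  # dropped: no 'volume' in the name
    "sales": [5.0, 3.0, 6.0, 2.0, 7.0],
})
corr, flagged = sketch_volume_correlations(df, "sales")
```

Here flagged holds the single collinear pair (tv_volume, radio_volume). A matrix like corr can be rendered as a heatmap with Plotly Express, for example px.imshow(corr, x=corr.columns, y=corr.index, zmin=-1, zmax=1); whether the module uses those exact arguments is not stated here.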


plot_all_media_spend

def plot_all_media_spend(
    input_data, per_observation_df: pd.DataFrame
) -> 'plotly.graph_objs.Figure':

Plots the target variable over time using Plotly Express.

Note: Despite the function name, this function plots the target variable specified by input_data.target_name against time (using the index of per_observation_df), not the media spend.

Parameters:

  • input_data: An instance of abacus.prepro.input_data.InputData (used to get the target column name).

  • per_observation_df (pd.DataFrame): A DataFrame containing the target variable column and a date/time index.

Returns:

  • plotly.graph_objs.Figure: The Plotly figure object showing the target variable’s time series.

Raises:

  • ValueError: If the target column (input_data.target_name) is not found in per_observation_df.


(Private helper function _plot_one_metric is used internally by plot_all_metrics to handle the plotting logic for a single metric time series on a given Matplotlib axis.)