Data To Fit (data_to_fit.py)

This module defines the DataToFit class, which represents the final data structure prepared and scaled for input into the ABACUS MMM fitting process. It also includes a serialization-related utility function, switch_to_msgpack_numpy.

Functions

switch_to_msgpack_numpy

def switch_to_msgpack_numpy() -> None:

Configures the msgpack library to use msgpack_numpy for serialization, enabling efficient serialization of NumPy arrays. This is typically called before saving or loading DataToFit objects using msgpack.
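
For orientation, the switch typically amounts to msgpack_numpy's standard patch mechanism; the following is a minimal sketch of the effect, not necessarily the exact implementation:

    import msgpack
    import msgpack_numpy
    import numpy as np

    # Roughly equivalent behaviour: patch msgpack's default encoder/decoder so
    # NumPy arrays round-trip through packb()/unpackb() transparently.
    msgpack_numpy.patch()

    payload = msgpack.packb({"media": np.arange(6.0).reshape(3, 2)})
    restored = msgpack.unpackb(payload)  # "media" comes back as a NumPy array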

DataToFit

This class encapsulates the input data after it has been processed, split into training and testing sets, and scaled. This scaled data is directly used by the model fitting procedures.

class DataToFit:
    """
    Represents input data transformed to be suitable for fitting a model.
    The data undergoes the following transformations:
    - Split into train and test data sets
    - Scaled to smaller values for better accuracy from the Bayesian model
    """

Attributes

  • date_strs (Union[List[str], np.ndarray, jnp.ndarray]): List or array of date strings corresponding to the observations.

  • time_granularity (str): Time granularity (e.g., cnst.GRANULARITY_DAILY, cnst.GRANULARITY_WEEKLY).

  • has_test_dataset (bool): Flag indicating whether a train/test split was performed (i.e., if test data exists).

  • media_data_train_scaled (jnp.ndarray): Scaled media data for the training set [time, channel].

  • media_data_test_scaled (jnp.ndarray): Scaled media data for the test set [time, channel]. Empty if has_test_dataset is False.

  • media_scaler (SerializableScaler): The scaler object (from abacus.prepro.scaler) used to scale/unscale media_data.

  • media_costs_scaled (jnp.ndarray): Scaled total media costs [channel]. Scaled using media_costs_scaler.

  • media_cost_priors_scaled (jnp.ndarray): Scaled media cost priors [channel]. Scaled using media_costs_scaler.

  • learned_media_priors (jnp.ndarray): Learned media priors [channel] (unscaled).

  • media_costs_by_row_train_scaled (jnp.ndarray): Scaled media costs per observation for the training set [time, channel]. Scaled using media_costs_scaler.

  • media_costs_by_row_test_scaled (jnp.ndarray): Scaled media costs per observation for the test set [time, channel]. Empty if has_test_dataset is False. Scaled using media_costs_scaler.

  • media_costs_scaler (SerializableScaler): The scaler object used to scale/unscale media_costs, media_cost_priors, and media_costs_by_row. Fitted on the original media_cost_priors.

  • media_names (List[str]): List of media channel names.

  • extra_features_train_scaled (jnp.ndarray): Scaled extra features data for the training set [time, feature].

  • extra_features_test_scaled (jnp.ndarray): Scaled extra features data for the test set [time, feature]. Empty if has_test_dataset is False.

  • extra_features_scaler (SerializableScaler): The scaler object used to scale/unscale extra_features_data.

  • extra_features_names (List[str]): List of extra feature names.

  • target_train_scaled (jnp.ndarray): Scaled target variable data for the training set [time].

  • target_test_scaled (jnp.ndarray): Scaled target variable data for the test set [time]. Empty if has_test_dataset is False.

  • target_is_log_scale (bool): Flag indicating if the original target data was log-scaled before scaling by target_scaler.

  • target_scaler (SerializableScaler): The scaler object used to scale/unscale target_data.

  • target_name (str): Name of the target variable.
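
The scaled arrays and their companion scalers are meant to be used together. A brief sketch, assuming an existing data_to_fit instance and that SerializableScaler exposes the inverse_transform method referenced under to_data_frame below:

    # Shapes follow the conventions above: [time, channel] for media data,
    # [time, feature] for extra features, [time] for the target.
    n_train_obs, n_channels = data_to_fit.media_data_train_scaled.shape

    # Invert the stored scalers to get back to (approximately) original units.
    media_train = data_to_fit.media_scaler.inverse_transform(
        data_to_fit.media_data_train_scaled
    )
    target_train = data_to_fit.target_scaler.inverse_transform(
        data_to_fit.target_train_scaled
    )
    # If target_is_log_scale is True, target_train is presumably still on the
    # log scale after inverse_transform and would need exponentiating.

    # Test-set arrays are empty when has_test_dataset is False.
    if data_to_fit.has_test_dataset:
        n_test_obs = data_to_fit.target_test_scaled.shape[0]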

Static Methods

from_input_data

    @staticmethod
    def from_input_data(
        input_data: InputData,
        config: Dict[str, Any]
    ) -> "DataToFit":

Factory method to create a DataToFit instance from a raw InputData object and a configuration dictionary. It performs the train/test split based on config["train_test_ratio"] and fits scalers (SerializableScaler) to the full dataset before transforming the train and test splits.

Parameters:

  • input_data (InputData): The raw input data object.

  • config (Dict[str, Any]): Configuration dictionary, primarily used to get the train_test_ratio.

Returns:

  • DataToFit: A new instance containing the split and scaled data.
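
A typical call, assuming an existing InputData instance and treating the train_test_ratio value below as purely illustrative:

    config = {"train_test_ratio": 0.9}  # illustrative; other keys may also be required
    data_to_fit = DataToFit.from_input_data(input_data=input_data, config=config)

    print(data_to_fit.has_test_dataset)
    print(data_to_fit.media_data_train_scaled.shape)  # [time, channel]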

from_dict

    @staticmethod
    def from_dict(input_dict: Dict[str, Any]) -> "DataToFit":

Factory method to recreate a DataToFit object from a dictionary representation, typically one produced by to_dict. It reconstructs the scaler objects from their dictionary form.

Parameters:

  • input_dict (Dict[str, Any]): The dictionary containing the data, display info, scalers, and config.

Returns:

  • DataToFit: The reconstructed DataToFit object.
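
Since from_dict mirrors to_dict, a dictionary round trip should yield an equivalent object; a sketch:

    as_dict = data_to_fit.to_dict()
    restored = DataToFit.from_dict(as_dict)

    assert restored.media_names == data_to_fit.media_names
    assert restored.target_name == data_to_fit.target_name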

from_file

    @staticmethod
    def from_file(input_file: str) -> "DataToFit":

Factory method to load a DataToFit object from a file, typically a .gz file saved using msgpack and gzip via the dump method.

Parameters:

  • input_file (str): Path to the input file (e.g., data_to_fit.gz).

Returns:

  • DataToFit: The loaded DataToFit object.
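
A brief loading sketch, pairing from_file with dump (documented below); the path is illustrative, and switch_to_msgpack_numpy is called first in line with the note under that function:

    switch_to_msgpack_numpy()  # ensure msgpack can decode NumPy arrays
    data_to_fit = DataToFit.from_file("results/data_to_fit.gz")  # hypothetical path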

Instance Methods

__init__

    def __init__(
        self,
        date_strs: Union[List[str], np.ndarray, jnp.ndarray],
        time_granularity: str,
        has_test_dataset: bool,
        media_data_train_scaled: jnp.ndarray,
        media_data_test_scaled: jnp.ndarray,
        media_scaler: SerializableScaler,
        media_costs_scaled: jnp.ndarray,
        media_cost_priors_scaled: jnp.ndarray,
        learned_media_priors: jnp.ndarray,
        media_costs_by_row_train_scaled: jnp.ndarray,
        media_costs_by_row_test_scaled: jnp.ndarray,
        media_costs_scaler: SerializableScaler,
        media_names: List[str],
        extra_features_train_scaled: jnp.ndarray,
        extra_features_test_scaled: jnp.ndarray,
        extra_features_scaler: SerializableScaler,
        extra_features_names: List[str],
        target_train_scaled: jnp.ndarray,
        target_test_scaled: jnp.ndarray,
        target_is_log_scale: bool,
        target_scaler: SerializableScaler,
        target_name: str,
    ) -> None:

Initializes a DataToFit object directly with pre-split, pre-scaled data and fitted scaler objects. Typically used internally by the factory methods (from_input_data, from_dict, from_file).

(Parameters match the Attributes described above)

to_dict

    def to_dict(self) -> Dict[str, Any]:

Converts the DataToFit object into a dictionary representation suitable for serialization (e.g., with msgpack). JAX arrays are converted to standard NumPy arrays, and scaler objects are converted using their to_dict methods.

Returns:

  • Dict[str, Any]: A dictionary containing nested dictionaries for data, display, scalers, and config.
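
A sketch of inspecting and packing the dictionary; the exact key layout is implementation-defined, so only the documented top-level grouping is assumed:

    import msgpack

    d = data_to_fit.to_dict()
    print(sorted(d.keys()))  # expect groups such as data, display, scalers, config

    # Because JAX arrays are already converted to NumPy and scalers to dicts,
    # the result can be packed directly once msgpack_numpy is enabled.
    switch_to_msgpack_numpy()
    packed = msgpack.packb(d)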

dump

    def dump(self, results_dir: Union[str, Path]) -> None:

Serializes the DataToFit object using msgpack (with msgpack_numpy enabled) and saves it to a compressed gzip file named data_to_fit.gz within the specified directory.

Parameters:

  • results_dir (Union[str, Path]): The directory where data_to_fit.gz will be saved.
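
A save/load round trip, using a hypothetical results directory:

    from pathlib import Path

    results_dir = Path("results/run_001")  # hypothetical output location
    results_dir.mkdir(parents=True, exist_ok=True)

    data_to_fit.dump(results_dir)  # writes results/run_001/data_to_fit.gz
    reloaded = DataToFit.from_file(str(results_dir / "data_to_fit.gz"))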

to_data_frame

    def to_data_frame(self, unscaled: bool = False) -> Tuple[pd.DataFrame, pd.DataFrame]:

Converts the scaled data back into pandas DataFrames for easier viewing and analysis. Can optionally inverse-transform the data back to its original scale.

Parameters:

  • unscaled (bool, optional): If True, applies the inverse_transform method of the stored scalers to return data in the original scale. Defaults to False (returns scaled data).

Returns:

  • Tuple[pd.DataFrame, pd.DataFrame]: A tuple containing:

    • per_observation_df: DataFrame indexed by datetime, containing time-series data (media impressions, media costs per row, extra features, target).

    • per_channel_df: DataFrame indexed by media channel name, containing channel-level data (total cost, cost prior, learned prior).
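
A brief usage sketch; column names are implementation-defined, so only the documented index structure is assumed here:

    per_observation_df, per_channel_df = data_to_fit.to_data_frame(unscaled=True)

    # Time-series view: indexed by datetime, one column group per data type.
    print(per_observation_df.head())

    # Channel-level view: indexed by media channel name.
    print(per_channel_df.loc[data_to_fit.media_names[0]])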