Data To Fit (data_to_fit.py)¶
This module defines the DataToFit class, which represents the final data structure prepared and scaled for input into the ABACUS MMM fitting process. It also includes utility functions related to data serialization.
Functions¶
switch_to_msgpack_numpy¶
def switch_to_msgpack_numpy() -> None:
Configures the msgpack library to use msgpack_numpy for serialization, enabling efficient serialization of NumPy arrays. This is typically called before saving or loading DataToFit objects using msgpack.
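The function body is not reproduced here, but a minimal sketch of what it most likely wraps is the standard msgpack_numpy patch call (this is an assumption about the implementation, not a copy of it):

import msgpack_numpy

def switch_to_msgpack_numpy() -> None:
    # Monkey-patch msgpack's global encode/decode hooks so that NumPy arrays
    # (and JAX arrays converted to NumPy) round-trip through msgpack without
    # custom encoders.
    msgpack_numpy.patch()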
DataToFit¶
This class encapsulates the input data after it has been processed, split into training and testing sets, and scaled. This scaled data is directly used by the model fitting procedures.
class DataToFit:
"""
Represents input data transformed to be suitable for fitting a model.
The data undergoes the following transformations:
- Split into train and test data sets
- Scaled to smaller values for better accuracy from the Bayesian model
"""
Attributes¶
date_strs (Union[List[str], np.ndarray, jnp.ndarray]): List or array of date strings corresponding to the observations.
time_granularity (str): Time granularity (e.g., cnst.GRANULARITY_DAILY, cnst.GRANULARITY_WEEKLY).
has_test_dataset (bool): Flag indicating whether a train/test split was performed (i.e., whether test data exists).
media_data_train_scaled (jnp.ndarray): Scaled media data for the training set [time, channel].
media_data_test_scaled (jnp.ndarray): Scaled media data for the test set [time, channel]. Empty if has_test_dataset is False.
media_scaler (SerializableScaler): The scaler object (from abacus.prepro.scaler) used to scale/unscale media_data.
media_costs_scaled (jnp.ndarray): Scaled total media costs [channel]. Scaled using media_costs_scaler.
media_cost_priors_scaled (jnp.ndarray): Scaled media cost priors [channel]. Scaled using media_costs_scaler.
learned_media_priors (jnp.ndarray): Learned media priors [channel] (unscaled).
media_costs_by_row_train_scaled (jnp.ndarray): Scaled media costs per observation for the training set [time, channel]. Scaled using media_costs_scaler.
media_costs_by_row_test_scaled (jnp.ndarray): Scaled media costs per observation for the test set [time, channel]. Empty if has_test_dataset is False. Scaled using media_costs_scaler.
media_costs_scaler (SerializableScaler): The scaler object used to scale/unscale media_costs, media_cost_priors, and media_costs_by_row. Fitted on the original media_cost_priors.
media_names (List[str]): List of media channel names.
extra_features_train_scaled (jnp.ndarray): Scaled extra features data for the training set [time, feature].
extra_features_test_scaled (jnp.ndarray): Scaled extra features data for the test set [time, feature]. Empty if has_test_dataset is False.
extra_features_scaler (SerializableScaler): The scaler object used to scale/unscale extra_features_data.
extra_features_names (List[str]): List of extra feature names.
target_train_scaled (jnp.ndarray): Scaled target variable data for the training set [time].
target_test_scaled (jnp.ndarray): Scaled target variable data for the test set [time]. Empty if has_test_dataset is False.
target_is_log_scale (bool): Flag indicating whether the original target data was log-scaled before scaling by target_scaler.
target_scaler (SerializableScaler): The scaler object used to scale/unscale target_data.
target_name (str): Name of the target variable.
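The short sketch below illustrates how the scaled arrays and their scalers relate. It assumes an existing DataToFit instance named data_to_fit and relies only on the inverse_transform method of the stored scalers mentioned under to_data_frame; the exact scaler call signature is an assumption:

# Assumes `data_to_fit` is an existing DataToFit instance.
n_time, n_channels = data_to_fit.media_data_train_scaled.shape  # [time, channel]
assert len(data_to_fit.media_names) == n_channels

# Each *_scaled array can be mapped back to original units with its scaler.
target_train_unscaled = data_to_fit.target_scaler.inverse_transform(
    data_to_fit.target_train_scaled
)  # shape [time]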
Static Methods¶
from_input_data¶
@staticmethod
def from_input_data(
input_data: InputData,
config: Dict[str, Any]
) -> "DataToFit":
Factory method to create a DataToFit instance from a raw InputData object and a configuration dictionary. It performs the train/test split based on config["train_test_ratio"] and fits scalers (SerializableScaler) to the full dataset before transforming the train and test splits.
Parameters:
input_data (InputData): The raw input data object.
config (Dict[str, Any]): Configuration dictionary, primarily used to get the train_test_ratio.
Returns:
DataToFit: A new instance containing the split and scaled data.
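A usage sketch, assuming an InputData object named input_data is already built, that train_test_ratio is the only config key needed here (other keys may apply in practice), and an assumed module import path:

from abacus.prepro.data_to_fit import DataToFit  # module path assumed

config = {"train_test_ratio": 0.9}  # fraction assigned to the train split (assumed semantics)
data_to_fit = DataToFit.from_input_data(input_data=input_data, config=config)

print(data_to_fit.has_test_dataset)               # True when a test split exists
print(data_to_fit.media_data_train_scaled.shape)  # [time, channel]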
from_dict¶
@staticmethod
def from_dict(input_dict: Dict[str, Any]) -> "DataToFit":
Factory method to recreate a DataToFit object from a dictionary representation (typically produced by to_dict). It reconstructs the scaler objects from their dictionary form.
Parameters:
input_dict (Dict[str, Any]): The dictionary containing the data, display info, scalers, and config.
Returns:
DataToFit: The reconstructed DataToFit object.
from_file¶
@staticmethod
def from_file(input_file: str) -> "DataToFit":
Factory method to load a DataToFit object from a file, typically a .gz file saved using msgpack and gzip via the dump method.
Parameters:
input_file (str): Path to the input file (e.g., data_to_fit.gz).
Returns:
DataToFit: The loaded DataToFit object.
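A loading sketch; calling switch_to_msgpack_numpy first follows the guidance above, although from_file may already enable it internally (assumption), and the import path is assumed:

from abacus.prepro.data_to_fit import DataToFit, switch_to_msgpack_numpy  # path assumed

switch_to_msgpack_numpy()  # make msgpack numpy-aware before decoding
data_to_fit = DataToFit.from_file("results/data_to_fit.gz")
print(data_to_fit.target_name)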
Instance Methods¶
__init__¶
def __init__(
self,
date_strs: Union[List[str], np.ndarray, jnp.ndarray],
time_granularity: str,
has_test_dataset: bool,
media_data_train_scaled: jnp.ndarray,
media_data_test_scaled: jnp.ndarray,
media_scaler: SerializableScaler,
media_costs_scaled: jnp.ndarray,
media_cost_priors_scaled: jnp.ndarray,
learned_media_priors: jnp.ndarray,
media_costs_by_row_train_scaled: jnp.ndarray,
media_costs_by_row_test_scaled: jnp.ndarray,
media_costs_scaler: SerializableScaler,
media_names: List[str],
extra_features_train_scaled: jnp.ndarray,
extra_features_test_scaled: jnp.ndarray,
extra_features_scaler: SerializableScaler,
extra_features_names: List[str],
target_train_scaled: jnp.ndarray,
target_test_scaled: jnp.ndarray,
target_is_log_scale: bool,
target_scaler: SerializableScaler,
target_name: str,
) -> None:
Initializes a DataToFit object directly with pre-split, pre-scaled data and fitted scaler objects. Typically used internally by the factory methods (from_input_data, from_dict, from_file).
(Parameters match the Attributes described above)
to_dict¶
def to_dict(self) -> Dict[str, Any]:
Converts the DataToFit object into a dictionary representation suitable for serialization (e.g., with msgpack). JAX arrays are converted to standard NumPy arrays, and scaler objects are converted using their to_dict methods.
Returns:
Dict[str, Any]: A dictionary containing nested dictionaries for data, display, scalers, and config.
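A round-trip sketch pairing to_dict with from_dict, assuming an existing instance named data_to_fit and the assumed import path used earlier; the expected top-level keys follow the description above:

from abacus.prepro.data_to_fit import DataToFit  # module path assumed

d = data_to_fit.to_dict()
print(sorted(d.keys()))  # expected: ['config', 'data', 'display', 'scalers']

# Rebuild an equivalent object; scalers are reconstructed from their dictionary form.
restored = DataToFit.from_dict(d)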
dump¶
def dump(self, results_dir: Union[str, Path]) -> None:
Serializes the DataToFit object using msgpack (with msgpack_numpy enabled) and saves it to a compressed gzip file named data_to_fit.gz within the specified directory.
Parameters:
results_dir (Union[str, Path]): The directory where data_to_fit.gz will be saved.
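A saving sketch, assuming an existing instance named data_to_fit. Creating the directory beforehand and calling switch_to_msgpack_numpy explicitly are defensive assumptions rather than documented requirements:

from pathlib import Path
from abacus.prepro.data_to_fit import switch_to_msgpack_numpy  # module path assumed

results_dir = Path("results/run_001")
results_dir.mkdir(parents=True, exist_ok=True)  # defensive; dump may not create it

switch_to_msgpack_numpy()      # numpy-aware msgpack before packing
data_to_fit.dump(results_dir)  # writes results/run_001/data_to_fit.gz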
to_data_frame¶
def to_data_frame(self, unscaled: bool = False) -> Tuple[pd.DataFrame, pd.DataFrame]:
Converts the scaled data back into pandas DataFrames for easier viewing and analysis. Can optionally inverse-transform the data back to its original scale.
Parameters:
unscaled (bool, optional): If True, applies the inverse_transform method of the stored scalers to return data in the original scale. Defaults to False (returns scaled data).
Returns:
Tuple[pd.DataFrame, pd.DataFrame]: A tuple containing:
per_observation_df: DataFrame indexed by datetime, containing time-series data (media impressions, media costs per row, extra features, target).
per_channel_df: DataFrame indexed by media channel name, containing channel-level data (total cost, cost prior, learned prior).
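A viewing sketch, assuming an existing instance named data_to_fit:

# Scaled view (the values the model actually fits on).
per_observation_df, per_channel_df = data_to_fit.to_data_frame()

# Original units, recovered via each scaler's inverse_transform.
per_observation_raw, per_channel_raw = data_to_fit.to_data_frame(unscaled=True)

print(per_observation_raw.head())  # datetime-indexed media, costs, extra features, target
print(per_channel_raw)             # per-channel total cost, cost prior, learned prior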