bst-script.py
{"metadata":{"kernelspec":{"name":"python3","display_name":"Python 3","language":"python"},"language_info":{"name":"python","version":"3.10.12","mimetype":"text/x-python","codemirror_mode":{"name":"ipython","version":3},"pygments_lexer":"ipython3","nbconvert_exporter":"python","file_extension":".py"},"kaggle":{"accelerator":"nvidiaTeslaT4","dataSources":[{"sourceId":10523408,"sourceType":"datasetVersion","datasetId":6142016}],"dockerImageVersionId":30919,"isInternetEnabled":true,"language":"python","sourceType":"script","isGpuEnabled":true}},"nbformat_minor":4,"nbformat":4,"cells":[{"cell_type":"code","source":"# %% [markdown]\n# # INSTALL REQUIREMENTS\n\n# %% [code] {\"execution\":{\"iopub.status.busy\":\"2025-03-19T10:14:09.498442Z\",\"iopub.execute_input\":\"2025-03-19T10:14:09.498695Z\",\"iopub.status.idle\":\"2025-03-19T10:15:34.925158Z\",\"shell.execute_reply.started\":\"2025-03-19T10:14:09.498667Z\",\"shell.execute_reply\":\"2025-03-19T10:15:34.924139Z\"}}\n!pip install -r /kaggle/input/requirements1.txt\n\n# %% [markdown]\n# # DATA PROCESSING UTILS\n\n# %% [code] {\"execution\":{\"iopub.status.busy\":\"2025-03-19T10:15:34.926240Z\",\"iopub.execute_input\":\"2025-03-19T10:15:34.926496Z\",\"iopub.status.idle\":\"2025-03-19T10:15:34.932229Z\",\"shell.execute_reply.started\":\"2025-03-19T10:15:34.926464Z\",\"shell.execute_reply\":\"2025-03-19T10:15:34.931560Z\"}}\nimport numpy as np\n\n\nclass StatelessRandomGenerator:\n def __init__(self, seed=42):\n self.seed = seed\n\n def set_seed(self, new_seed):\n self.seed = new_seed\n\n def random(self, size=None):\n rng = np.random.default_rng(self.seed)\n return rng.random(size)\n\n def integers(self, low, high=None, size=None):\n rng = np.random.default_rng(self.seed)\n return rng.integers(low, high, size)\n\n def choice(self, a, size=None, replace=True, p=None):\n rng = np.random.default_rng(self.seed)\n return rng.choice(a, size, replace, p)\n\n\nglobal_rng = StatelessRandomGenerator(42)\n\n\ndef set_global_seed(new_seed):\n global_rng.set_seed(new_seed)\n\n# %% [code] {\"execution\":{\"iopub.status.busy\":\"2025-03-19T10:15:34.933031Z\",\"iopub.execute_input\":\"2025-03-19T10:15:34.933344Z\",\"iopub.status.idle\":\"2025-03-19T10:15:36.562512Z\",\"shell.execute_reply.started\":\"2025-03-19T10:15:34.933314Z\",\"shell.execute_reply\":\"2025-03-19T10:15:36.561645Z\"}}\nimport torch\n\n\ndef wmape_metric(pred: torch.Tensor, true: torch.Tensor) -> torch.Tensor:\n return torch.sum(torch.abs(pred - true), dim=0) / torch.sum(true, dim=0)\n\n# %% [markdown]\n# # DATA PROCESSING\n\n# %% [code] {\"execution\":{\"iopub.status.busy\":\"2025-03-19T10:15:36.564784Z\",\"iopub.execute_input\":\"2025-03-19T10:15:36.565098Z\",\"iopub.status.idle\":\"2025-03-19T10:15:36.750394Z\",\"shell.execute_reply.started\":\"2025-03-19T10:15:36.565079Z\",\"shell.execute_reply\":\"2025-03-19T10:15:36.748944Z\"}}\nfrom datetime import datetime\nimport json\nfrom pathlib import Path\nimport polars as pl\n#from data_processing.utils.stateless_rng import global_rng\n\ndef filter_purchases_purchases_per_month_pl(\n df_pl: pl.DataFrame, train_end: datetime.date, group_by_channel_id: bool = False\n):\n \"\"\"Filters extreme customers and groups purchases by date and optionally by sales channel.\n\n This function:\n 1. Groups transactions by customer, date, and optionally sales channel\n 2. Identifies extreme customers based on the 99th percentile of total items purchased\n 3. 
# %% [markdown]
# # DATA PROCESSING

# %% [code]
from datetime import datetime
import json
from pathlib import Path
import polars as pl
#from data_processing.utils.stateless_rng import global_rng

def filter_purchases_purchases_per_month_pl(
    df_pl: pl.DataFrame, train_end: datetime.date, group_by_channel_id: bool = False
):
    """Filters extreme customers and groups purchases by date and optionally by sales channel.

    This function:
    1. Groups transactions by customer, date, and optionally sales channel
    2. Identifies extreme customers based on the 99th percentile of total items purchased
    3. Removes these customers from the dataset

    Args:
        df_pl (pl.DataFrame): Input transaction dataframe containing:
            - customer_id: Customer identifier
            - date: Transaction date
            - article_id: Product identifier
            - price: Transaction price
            - sales_channel_id: Sales channel identifier
        train_end (datetime.date): End date for training period.
        group_by_channel_id (bool, optional): Whether to group transactions by sales channel. Defaults to False.

    Returns:
        tuple[pl.DataFrame, pl.DataFrame]: Tuple containing:
            - grouped_df: Grouped transaction data with columns:
                customer_id, date, [sales_channel_id], article_ids, total_price, prices, num_items
            - extreme_customers: DataFrame of customers identified as outliers based on purchase behavior

    Notes:
        Extreme customers are identified using the 99th percentile of total items purchased
        during the training period.
    """
    # Used for multivariate time series
    if group_by_channel_id:
        grouped_df = (
            df_pl.lazy()
            .group_by(["customer_id", "date", "sales_channel_id"])
            .agg(
                [
                    pl.col("article_id").explode().alias("article_ids"),
                    pl.col("price").sum().round(2).alias("total_price"),
                    pl.col("price").explode().alias("prices"),
                ]
            )
            .with_columns(pl.col("article_ids").list.len().alias("num_items"))
        )
    else:
        grouped_df = (
            df_pl.lazy()
            .group_by(["customer_id", "date"])
            .agg(
                [
                    pl.col("article_id").explode().alias("article_ids"),
                    pl.col("price").sum().round(2).alias("total_price"),
                    pl.col("sales_channel_id").explode().alias("sales_channel_ids"),
                    pl.col("price").explode().alias("prices"),
                ]
            )
            .with_columns(pl.col("article_ids").list.len().alias("num_items"))
        )

    # Only remove customers with extreme purchases in train period
    customers_summary = (
        df_pl.lazy()
        .filter(pl.col("date") < train_end)
        .group_by("customer_id")
        .agg(
            [
                pl.col("date").n_unique().alias("total_purchases"),
                pl.col("price").sum().round(2).alias("total_spent"),
                pl.col("article_id").flatten().alias("flattened_ids"),
            ]
        )
        .with_columns(pl.col("flattened_ids").list.len().alias("total_items"))
    )

    quantile = 0.99
    total_purchases_99, total_spending_99, total_items_99 = (
        customers_summary.select(
            [
                pl.col("total_purchases").quantile(quantile),
                pl.col("total_spent").quantile(quantile),
                pl.col("total_items").quantile(quantile),
            ]
        )
        .collect()
        .to_numpy()
        .flatten()
    )

    # Currently only remove customers with a very large number of total items purchased
    extreme_customers = customers_summary.filter(
        (pl.col("total_items") >= total_items_99)
        # | (pl.col("total_purchases") >= total_purchases_99)
        # | (pl.col("total_spent") >= total_spending_99)
    )

    extreme_customers = extreme_customers.select("customer_id").unique()
    extreme_customers = extreme_customers.collect()

    print(
        f"""
        Cutoff Values for {quantile*100}th Percentiles:
        -----------------------------------
        Total items bought: {total_items_99:.0f} items

        -----------------------------------
        Removed Customers: {len(extreme_customers):,}
        """
    )

    return grouped_df.collect(), extreme_customers
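# %% [markdown]
# A minimal sketch of what the grouping step produces, on toy data (not from the
# notebook): one row per customer and day, with the purchased articles collected
# into a list and `num_items` derived from its length.

# %% [code]
_toy_tx = pl.DataFrame(
    {
        "customer_id": ["a", "a", "b"],
        "date": ["2019-01-01", "2019-01-01", "2019-01-02"],
        "article_id": [1, 2, 3],
        "price": [10.0, 5.0, 7.0],
    }
).with_columns(pl.col("date").str.strptime(pl.Date, "%Y-%m-%d"))

print(
    _toy_tx.group_by(["customer_id", "date"])
    .agg(
        pl.col("article_id").alias("article_ids"),
        pl.col("price").sum().round(2).alias("total_price"),
    )
    .with_columns(pl.col("article_ids").list.len().alias("num_items"))
)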
def train_test_split(
    train_df: pl.DataFrame,
    test_df: pl.DataFrame,
    subset: int = None,
    train_subsample_percentage: float = None,
) -> tuple[pl.DataFrame, pl.DataFrame, pl.DataFrame]:
    """Splits data into train, validation, and test sets with optional subsampling.

    The function performs the following operations:
    1. Optional subsampling of both train and test data
    2. Optional percentage-based subsampling of training data
    3. Creates a validation set from 10% of the training data

    Args:
        train_df (pl.DataFrame): Training dataset.
        test_df (pl.DataFrame): Test dataset.
        subset (int, optional): If provided, limits both train and test sets to first n rows.
            Defaults to None.
        train_subsample_percentage (float, optional): If provided, randomly samples this percentage
            of training data. Defaults to None.

    Returns:
        tuple[pl.DataFrame, pl.DataFrame, pl.DataFrame]: Tuple containing:
            - train_df: Final training dataset (90% of training data after subsampling)
            - val_df: Validation dataset (10% of training data)
            - test_df: Test dataset (potentially subsampled)

    Notes:
        If both subset and train_subsample_percentage are provided, only subset is
        applied (the two options are mutually exclusive branches below).
        The validation set is always 10% of the remaining training data after any subsampling.
    """

    if subset is not None:
        train_df = train_df[:subset]
        test_df = test_df[:subset]
    elif train_subsample_percentage is not None:
        sampled_indices = global_rng.choice(
            len(train_df),
            size=int(train_subsample_percentage * len(train_df)),
            replace=False,
        )
        train_df = train_df[sampled_indices]

    # Train-val split: hold out 10% of the (possibly subsampled) training rows
    sampled_indices = global_rng.choice(
        len(train_df), size=int(0.1 * len(train_df)), replace=False
    )
    val_df = train_df[sampled_indices]
    train_df = train_df.filter(~pl.arange(0, pl.count()).is_in(sampled_indices))

    return train_df, val_df, test_df
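# %% [markdown]
# A hedged toy check of the split (toy frames, not notebook data): with 100
# training rows, 10 are held out for validation and 90 remain; the stateless
# RNG makes the sampled indices identical across repeated runs.

# %% [code]
_toy_train = pl.DataFrame({"x": list(range(100))})
_toy_test = pl.DataFrame({"x": list(range(20))})
_tr, _va, _te = train_test_split(_toy_train, _toy_test)
print(len(_tr), len(_va), len(_te))  # 90 10 20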
def map_article_ids(df: pl.DataFrame, data_path: Path) -> pl.DataFrame:
    """Maps article IDs to new running IDs using a mapping dictionary from JSON.

    Args:
        df (pl.DataFrame): DataFrame with 'article_id' column to be mapped.
        data_path (Path): Path to directory with 'running_id_dict.json' containing ID mappings.

    Returns:
        pl.DataFrame: DataFrame with mapped article IDs, sorted by new IDs.
            Non-mapped articles are removed.
    """
    with open(data_path / "running_id_dict.json", "r") as f:
        data = json.load(f)
    article_id_dict = data["combined"]

    mapping_df = pl.DataFrame(
        {
            "old_id": list(article_id_dict.keys()),
            "new_id": list(article_id_dict.values()),
        },
        schema_overrides={"old_id": pl.Int32, "new_id": pl.Int32},
    )

    # Join and select
    df = df.join(
        mapping_df, left_on="article_id", right_on="old_id", how="inner"
    ).select(
        pl.col("new_id").alias("article_id"),
        pl.all().exclude(["article_id", "old_id", "new_id"]),
    )
    df = df.sort("article_id")

    return df

# %% [code]
#from pathlib import Path
#from data_processing.customer_df.customer_df import get_customer_df_benchmarks
#from data_processing.transaction_df.transaction_df import get_tx_article_dfs
import polars as pl


def expand_list_columns(
    df: pl.DataFrame, date_col: str = "days_before_lst", num_col: str = "num_items_lst"
) -> pl.DataFrame:
    """
    Expand a Polars DataFrame by repeating each element in a list column according to
    the counts specified in another list column.

    Args:
        df: Input Polars DataFrame with list columns
        date_col: Name of the column containing the lists to be expanded
        num_col: Name of the column containing lists of counts

    Returns:
        A new Polars DataFrame where the list elements in date_col have been expanded
    """
    expanded = df.with_columns(
        pl.struct([date_col, num_col])
        .map_elements(
            lambda x: [
                date
                for date, count in zip(x[date_col], x[num_col])
                for _ in range(count)
            ]
        )
        .alias(date_col)
    )

    return expanded
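# %% [markdown]
# A toy illustration of `expand_list_columns` (hedged, not from the notebook):
# a purchase of 2 items 5 days ago and 1 item 3 days ago expands to one day
# entry per item, i.e. [5, 3] with counts [2, 1] becomes [5, 5, 3].

# %% [code]
_toy_lists = pl.DataFrame({"days_before_lst": [[5, 3]], "num_items_lst": [[2, 1]]})
print(expand_list_columns(_toy_lists)["days_before_lst"].to_list())  # [[5, 5, 3]]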
def add_benchmark_tx_features(df: pl.DataFrame) -> pl.DataFrame:
    """Creates benchmark transaction features from aggregated customer transaction data.

    Args:
        df: A Polars DataFrame containing aggregated transaction data with list columns
            including total_price_lst, num_items_lst, days_before_lst, price_lst,
            and CLV_label.

    Returns:
        pl.DataFrame: A DataFrame with derived features including:
            - total_spent: Sum of all transaction amounts
            - total_purchases: Count of transactions
            - total_items: Sum of items purchased
            - days_since_last_purchase: Days since most recent transaction
            - days_since_first_purchase: Days since first transaction
            - avg_spent_per_transaction: Mean transaction amount
            - avg_items_per_transaction: Mean items per transaction
            - avg_days_between: Mean days between transactions
            - regression_label: CLV label for regression
            - classification_label: Binary CLV label (>0)

    Note:
        The avg_days_between calculation may return None for customers with single
        transactions, which is handled by tree-based algorithms.
    """
    return df.select(
        "customer_id",
        pl.col("total_price_lst").list.sum().alias("total_spent"),
        pl.col("total_price_lst").list.len().alias("total_purchases"),
        pl.col("num_items_lst").list.sum().alias("total_items"),
        pl.col("days_before_lst").list.get(-1).alias("days_since_last_purchase"),
        pl.col("days_before_lst").list.get(0).alias("days_since_first_purchase"),
        pl.col("price_lst").list.mean().alias("avg_spent_per_transaction"),
        (
            pl.col("num_items_lst")
            .list.mean()
            .cast(pl.Float32)
            .alias("avg_items_per_transaction")
        ),
        # Code below returns None values for customers with single Tx
        # Tree algos should be able to handle this
        (
            pl.col("days_before_lst")
            .list.diff(null_behavior="drop")
            .list.mean()
            .mul(-1)
            .cast(pl.Float32)
            .alias("avg_days_between")
        ),
        pl.col("CLV_label").alias("regression_label"),
        pl.col("CLV_label").gt(0).cast(pl.Int32).alias("classification_label"),
    )


def process_dataframe(df: pl.DataFrame, max_length: int = 20) -> pl.DataFrame:
    """Processes a polars DataFrame by expanding list columns and selecting specific columns with transformations.

    This function performs several operations on the input DataFrame:
    1. Expands list columns using the expand_list_columns function
    2. Selects and renames specific columns
    3. Truncates list columns to a maximum length

    Args:
        df: A polars DataFrame containing customer transaction data
        max_length: Maximum number of elements to keep in list columns (default: 20)

    Returns:
        A processed polars DataFrame with the following columns:
            - customer_id: Customer identifier
            - days_before_lst: Truncated list of days before some reference date
            - articles_ids_lst: Truncated list of article identifiers
            - regression_label: CLV label for regression tasks
            - classification_label: Binary classification label derived from CLV
    """
    df = expand_list_columns(df, date_col="days_before_lst", num_col="num_items_lst")
    return df.select(
        "customer_id",
        "days_before_lst",
        "articles_ids_lst",
        pl.col("CLV_label").alias("regression_label"),
        pl.col("CLV_label").gt(0).cast(pl.Int32).alias("classification_label"),
    ).with_columns(
        pl.col("days_before_lst").list.tail(max_length),
        pl.col("articles_ids_lst").list.tail(max_length),
    )


def get_benchmark_dfs(
    data_path: Path, config: dict
) -> tuple[pl.DataFrame, pl.DataFrame, pl.DataFrame]:
    """Creates benchmark train, validation, and test datasets with transaction and customer features.

    Args:
        data_path: Path object pointing to the data directory
        config: Dictionary containing configuration parameters for data processing

    Returns:
        tuple[pl.DataFrame, pl.DataFrame, pl.DataFrame]: A tuple containing:
            - train_df: Training dataset with benchmark features
            - val_df: Validation dataset with benchmark features
            - test_df: Test dataset with benchmark features

    Each DataFrame contains transaction-derived features joined with customer features.
    """
    train_article, val_article, test_article = get_tx_article_dfs(
        data_path=data_path,
        config=config,
        cols_to_aggregate=[
            "date",
            "days_before",
            "article_ids",
            "sales_channel_ids",
            "total_price",
            "prices",
            "num_items",
        ],
        keep_customer_id=True,
    )

    customer_df = get_customer_df_benchmarks(data_path=data_path, config=config)

    train_df = process_dataframe(
        df=train_article, max_length=config["max_length"]
    ).join(customer_df, on="customer_id", how="left")
    val_df = process_dataframe(df=val_article, max_length=config["max_length"]).join(
        customer_df, on="customer_id", how="left"
    )
    test_df = process_dataframe(df=test_article, max_length=config["max_length"]).join(
        customer_df, on="customer_id", how="left"
    )

    return train_df, val_df, test_df
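# %% [markdown]
# A small hedged check of the `avg_days_between` expression: `days_before_lst`
# counts down toward the prediction date, so consecutive diffs are negative and
# `mul(-1)` flips them. For [30, 20, 5] the gaps are 10 and 15 days, mean 12.5.

# %% [code]
_toy_days = pl.DataFrame({"days_before_lst": [[30, 20, 5]]})
print(
    _toy_days.select(
        pl.col("days_before_lst")
        .list.diff(null_behavior="drop")
        .list.mean()
        .mul(-1)
        .alias("avg_days_between")
    )
)  # 12.5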
{\"execution\":{\"iopub.status.busy\":\"2025-03-19T10:15:36.770134Z\",\"iopub.execute_input\":\"2025-03-19T10:15:36.770338Z\",\"iopub.status.idle\":\"2025-03-19T10:15:36.794388Z\",\"shell.execute_reply.started\":\"2025-03-19T10:15:36.770320Z\",\"shell.execute_reply\":\"2025-03-19T10:15:36.793698Z\"}}\nimport polars as pl\n#from pathlib import Path\n\n\ndef get_customer_df_benchmarks(data_path: Path, config: dict):\n \"\"\"Processes customer data with age grouping and zip code mapping.\n\n Args:\n data_path (Path): Path to directory containing 'customers.csv' and 'zip_code_count.csv'.\n config (dict): Configuration with 'min_zip_code_count'. Updated with 'num_age_groups' and 'num_zip_codes'.\n\n Returns:\n pl.DataFrame: Processed DataFrame with customer_id, age_group (0-6), and mapped zip_code_id.\n \"\"\"\n file_path = data_path / \"customers.csv\"\n df = pl.scan_csv(file_path).select(\n (\n \"customer_id\",\n pl.col(\"age\").fill_null(strategy=\"mean\"),\n \"postal_code\",\n )\n )\n\n # df = df.with_columns(\n # [\n # pl.when(pl.col(\"age\").is_null())\n # .then(0)\n # .when(pl.col(\"age\") < 25)\n # .then(1)\n # .when(pl.col(\"age\").is_between(25, 34))\n # .then(2)\n # .when(pl.col(\"age\").is_between(35, 44))\n # .then(3)\n # .when(pl.col(\"age\").is_between(45, 54))\n # .then(4)\n # .when(pl.col(\"age\").is_between(55, 64))\n # .then(5)\n # .otherwise(6)\n # .alias(\"age_group\")\n # ]\n # )\n # config[\"num_age_groups\"] = 7\n\n return df.collect()\n\n# %% [code] {\"execution\":{\"iopub.status.busy\":\"2025-03-19T10:15:36.795291Z\",\"iopub.execute_input\":\"2025-03-19T10:15:36.795576Z\",\"iopub.status.idle\":\"2025-03-19T10:15:36.819759Z\",\"shell.execute_reply.started\":\"2025-03-19T10:15:36.795548Z\",\"shell.execute_reply\":\"2025-03-19T10:15:36.819046Z\"}}\n#from datetime import datetime\n#from pathlib import Path\n#import polars as pl\n\n#from data_processing.utils.utils_transaction_df import (\n # filter_purchases_purchases_per_month_pl,\n # map_article_ids,\n # train_test_split,\n#)\n\n\ndef generate_clv_data_pl(\n df: pl.DataFrame,\n agg_df: pl.DataFrame,\n label_threshold: datetime.date,\n pred_end: datetime.date,\n clv_periods: list,\n log_clv: bool = False,\n):\n \"\"\"Generates Customer Lifetime Value (CLV) data from transaction dataframe.\n\n Args:\n df (pl.DataFrame): Input transaction dataframe containing customer purchases.\n agg_df (pl.DataFrame): Aggregated dataframe containing customer data.\n label_threshold (datetime.date): Start date for CLV calculation period.\n pred_end (datetime.date): End date for CLV calculation period.\n clv_periods (list): List of periods for CLV calculation (currently supports single period only).\n log_clv (bool, optional): Whether to apply log1p transformation to CLV values. Defaults to False.\n\n Returns:\n pl.DataFrame: Aggregated dataframe with added CLV calculations.\n\n Raises:\n ValueError: If more than one CLV period is provided.\n \"\"\"\n if len(clv_periods) > 1:\n raise ValueError(\"CLV periods should be a single number for now.\")\n\n # Filter transactions between label_threshold and end_date for each period\n filtered_df = df.filter(\n (pl.col(\"date\") >= label_threshold) & (pl.col(\"date\") <= pred_end)\n )\n\n # Sum total_price for the filtered transactions by customer_id. 
# %% [code]
#from datetime import datetime
#from pathlib import Path
#import polars as pl

#from data_processing.utils.utils_transaction_df import (
#    filter_purchases_purchases_per_month_pl,
#    map_article_ids,
#    train_test_split,
#)


def generate_clv_data_pl(
    df: pl.DataFrame,
    agg_df: pl.DataFrame,
    label_threshold: datetime.date,
    pred_end: datetime.date,
    clv_periods: list,
    log_clv: bool = False,
):
    """Generates Customer Lifetime Value (CLV) data from transaction dataframe.

    Args:
        df (pl.DataFrame): Input transaction dataframe containing customer purchases.
        agg_df (pl.DataFrame): Aggregated dataframe containing customer data.
        label_threshold (datetime.date): Start date for CLV calculation period.
        pred_end (datetime.date): End date for CLV calculation period.
        clv_periods (list): List of periods for CLV calculation (currently supports single period only).
        log_clv (bool, optional): Whether to apply log1p transformation to CLV values. Defaults to False.

    Returns:
        pl.DataFrame: Aggregated dataframe with added CLV calculations.

    Raises:
        ValueError: If more than one CLV period is provided.
    """
    if len(clv_periods) > 1:
        raise ValueError("CLV periods should be a single number for now.")

    # Filter transactions between label_threshold and pred_end
    filtered_df = df.filter(
        (pl.col("date") >= label_threshold) & (pl.col("date") <= pred_end)
    )

    # Sum total_price for the filtered transactions by customer_id. This is the CLV.
    summed_period_df = filtered_df.group_by("customer_id").agg(
        pl.sum("total_price").round(2).alias("CLV_label")
    )
    if log_clv:
        summed_period_df = summed_period_df.with_columns(
            pl.col("CLV_label").log1p().round(2).alias("CLV_label")
        )

    agg_df = agg_df.join(summed_period_df, on="customer_id", how="left")

    # Customers with no purchases in the label window get CLV 0
    agg_df = agg_df.fill_null(0)
    return agg_df


def group_and_convert_df_pl(
    df: pl.DataFrame,
    label_start_date: datetime.date,
    pred_end: datetime.date,
    clv_periods: list,
    cols_to_aggregate: list = [
        "date",
        "days_before",
        "num_items",
        "article_ids",
        "sales_channel_ids",
        "total_price",
        "prices",
    ],
    keep_customer_id: bool = True,
    log_clv: bool = False,
) -> pl.DataFrame:
    """Groups and converts transaction data into aggregated customer-level features.

    Args:
        df (pl.DataFrame): Input transaction dataframe.
        label_start_date (datetime.date): Start date for CLV label period.
        pred_end (datetime.date): End date for prediction period.
        clv_periods (list): List of periods for CLV calculation.
        cols_to_aggregate (list, optional): Columns to include in aggregation. Defaults to standard transaction columns.
        keep_customer_id (bool, optional): Whether to retain customer_id in output. Defaults to True.
        log_clv (bool, optional): Whether to apply log1p transformation to CLV values. Defaults to False.

    Returns:
        pl.DataFrame: Aggregated customer-level dataframe.

    Raises:
        ValueError: If required columns (days_before, article_ids, num_items) are missing from cols_to_aggregate.
    """

    if any(
        col not in cols_to_aggregate
        for col in ["days_before", "article_ids", "num_items"]
    ):
        raise ValueError(
            "The columns days_before, article_ids, and num_items are required "
            "for the aggregation"
        )

    mapping = {
        "date": "date_lst",
        "days_before": "days_before_lst",
        "article_ids": "articles_ids_lst",
        "sales_channel_ids": "sales_channel_id_lst",
        "total_price": "total_price_lst",
        "prices": "price_lst",
        "num_items": "num_items_lst",
    }

    agg_df = (
        df.filter(pl.col("date") < label_start_date)
        .with_columns(
            (label_start_date - pl.col("date"))
            .dt.total_days()
            .cast(pl.Int32)
            .alias("days_before"),
            (
                pl.col("sales_channel_ids")
                .cast(pl.List(pl.Int32))
                .alias("sales_channel_ids")
            ),
            pl.col("article_ids").cast(pl.List(pl.Int32)).alias("article_ids"),
        )
        .sort("customer_id", "date")
        .group_by("customer_id")
        .agg(
            pl.col("date").explode().alias("date_lst"),
            pl.col("days_before").explode().alias("days_before_lst"),
            pl.col("article_ids").explode().alias("articles_ids_lst"),
            pl.concat_list(pl.col("sales_channel_ids")).alias("sales_channel_id_lst"),
            pl.col("total_price").explode().alias("total_price_lst"),
            pl.col("prices").explode().alias("price_lst"),
            pl.col("num_items").explode().alias("num_items_lst"),
        )
    )

    if clv_periods is not None:
        agg_df = generate_clv_data_pl(
            df=df,
            agg_df=agg_df,
            label_threshold=label_start_date,
            pred_end=pred_end,
            clv_periods=clv_periods,
            log_clv=log_clv,
        )

    # Drop columns which are not to be aggregated
    cols_to_drop = [v for k, v in mapping.items() if k not in cols_to_aggregate]
    if not keep_customer_id:
        cols_to_drop.append("customer_id")
    agg_df = agg_df.drop(*cols_to_drop)

    return agg_df
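# %% [markdown]
# A hedged toy example of the CLV label: purchases strictly before the label
# window feed the features, while spend inside [label_threshold, pred_end] is
# summed into CLV_label; customers absent from the window get 0.

# %% [code]
from datetime import date

_toy_clv_tx = pl.DataFrame(
    {
        "customer_id": ["a", "a", "b"],
        "date": [date(2019, 9, 1), date(2019, 10, 1), date(2019, 8, 1)],
        "total_price": [10.0, 25.0, 5.0],
    }
)
_toy_agg = pl.DataFrame({"customer_id": ["a", "b"]})
print(
    generate_clv_data_pl(
        _toy_clv_tx,
        _toy_agg,
        label_threshold=date(2019, 9, 20),
        pred_end=date(2020, 3, 17),
        clv_periods=[6],
    )
)  # a -> 25.0 (only the 2019-10-01 purchase falls in the window), b -> 0.0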
def split_df_and_group_pl(
    df: pl.DataFrame,
    clv_periods: list,
    config: dict,
    cols_to_aggregate: list = [
        "date",
        "days_before",
        "article_ids",
        "sales_channel_ids",
        "total_price",
        "prices",
        "num_items",
    ],
    keep_customer_id: bool = True,
    log_clv: bool = False,
) -> tuple[pl.DataFrame, pl.DataFrame]:
    """Splits transaction data into training and test sets and performs aggregation.

    Args:
        df (pl.DataFrame): Input transaction dataframe.
        clv_periods (list): List of periods for CLV calculation.
        config (dict): Configuration dictionary containing the window boundaries
            'train_begin', 'train_label_begin', 'train_end', 'test_begin',
            'test_label_begin', and 'test_end' as '%Y-%m-%d' strings.
        cols_to_aggregate (list, optional): Columns to include in aggregation. Defaults to standard transaction columns.
        keep_customer_id (bool, optional): Whether to retain customer_id in output. Defaults to True.
        log_clv (bool, optional): Whether to apply log1p transformation to CLV values. Defaults to False.

    Returns:
        tuple[pl.DataFrame, pl.DataFrame]: Tuple containing:
            - train_df: Aggregated training dataset
            - test_df: Aggregated test dataset
    """

    train_begin = datetime.strptime(config.get("train_begin"), "%Y-%m-%d")
    train_label_start = datetime.strptime(config.get("train_label_begin"), "%Y-%m-%d")
    train_end = datetime.strptime(config.get("train_end"), "%Y-%m-%d")
    test_begin = datetime.strptime(config.get("test_begin"), "%Y-%m-%d")
    test_label_start = datetime.strptime(config.get("test_label_begin"), "%Y-%m-%d")
    test_end = datetime.strptime(config.get("test_end"), "%Y-%m-%d")

    # Creating the training DataFrame by filtering dates up to `train_end`
    train_df = df.filter(
        (pl.col("date") <= train_end) & (pl.col("date") >= train_begin)
    )

    train_df = group_and_convert_df_pl(
        df=train_df,
        label_start_date=train_label_start,
        pred_end=train_end,
        clv_periods=clv_periods,
        cols_to_aggregate=cols_to_aggregate,
        keep_customer_id=keep_customer_id,
        log_clv=log_clv,
    )

    # Creating the test DataFrame by filtering dates after `test_begin`
    test_df = df.filter((pl.col("date") >= test_begin) & (pl.col("date") <= test_end))

    test_df = group_and_convert_df_pl(
        df=test_df,
        label_start_date=test_label_start,
        pred_end=test_end,
        clv_periods=clv_periods,
        cols_to_aggregate=cols_to_aggregate,
        keep_customer_id=keep_customer_id,
        log_clv=log_clv,
    )

    return train_df, test_df
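# %% [markdown]
# For orientation, a quick computation (a hedged illustration using the window
# boundaries from the config defined further below, not notebook output): the
# training history covers 365 days before the label window opens, and the label
# window itself spans 179 days up to train_end; the test windows follow the same
# pattern shifted forward in time.

# %% [code]
_begin = datetime.strptime("2018-09-20", "%Y-%m-%d")
_label = datetime.strptime("2019-09-20", "%Y-%m-%d")
_end = datetime.strptime("2020-03-17", "%Y-%m-%d")
print((_label - _begin).days, "history days,", (_end - _label).days, "label days")
# 365 history days, 179 label days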
def load_data_rem_outlier_pl(
    data_path: Path, train_end: datetime.date, group_by_channel_id: bool = False
):
    """Loads transaction data, applies price scaling, and removes outliers.

    Args:
        data_path (Path): Path to directory containing transaction data parquet file.
        train_end (datetime.date): End date for training period.
        group_by_channel_id (bool, optional): Whether to group data by sales channel ID.
            Defaults to False.

    Returns:
        tuple[pl.DataFrame, pl.DataFrame]: Tuple containing:
            - grouped_df: Processed transaction dataframe
            - extreme_customers: Dataframe of customers identified as outliers
    """
    file_path = data_path / "transactions_polars.parquet"
    df_pl = pl.read_parquet(file_path)

    df_pl = df_pl.with_columns(
        pl.col("t_dat").alias("date").cast(pl.Date), pl.col("article_id").cast(pl.Int32)
    )

    df_pl = df_pl.with_columns(
        pl.col("price").mul(590).cast(pl.Float32).round(2).alias("price")
    )

    # Map article ids to running ids so that they match with feature matrix
    df_pl = map_article_ids(df=df_pl, data_path=data_path)

    grouped_df, extreme_customers = filter_purchases_purchases_per_month_pl(
        df_pl, train_end=train_end, group_by_channel_id=group_by_channel_id
    )

    return grouped_df, extreme_customers


def get_customer_train_test_articles_pl(
    data_path: Path,
    config: dict,
    clv_periods: list = None,
    cols_to_aggregate: list = [
        "date",
        "days_before",
        "article_ids",
        "sales_channel_ids",
        "total_price",
        "prices",
        "num_items",
    ],
    keep_customer_id: bool = True,
):
    """Processes customer transaction data into train and test sets with article information.

    Args:
        data_path (Path): Path to directory containing transaction data.
        config (dict): Configuration dictionary for data processing parameters.
        clv_periods (list, optional): List of periods for CLV calculation. Defaults to None.
        cols_to_aggregate (list, optional): Columns to include in aggregation. Defaults to standard transaction columns.
        keep_customer_id (bool, optional): Whether to retain customer_id in output. Defaults to True.

    Returns:
        tuple[pl.DataFrame, pl.DataFrame]: Tuple containing:
            - train_df: Processed training dataset with article information
            - test_df: Processed test dataset with article information
    """
    train_end = datetime.strptime(config.get("train_end"), "%Y-%m-%d")
    grouped_df, extreme_customers = load_data_rem_outlier_pl(
        data_path=data_path, train_end=train_end
    )

    train_df, test_df = split_df_and_group_pl(
        df=grouped_df,
        clv_periods=clv_periods,
        config=config,
        cols_to_aggregate=cols_to_aggregate,
        keep_customer_id=True,
        log_clv=config.get("log_clv", False),
    )

    train_df = train_df.join(extreme_customers, on="customer_id", how="anti")
    test_df = test_df.join(extreme_customers, on="customer_id", how="anti")

    if not keep_customer_id:
        train_df = train_df.drop("customer_id")
        test_df = test_df.drop("customer_id")

    return train_df, test_df
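# %% [markdown]
# A hedged toy illustration of the outlier removal above: an anti-join keeps
# only the customers that do NOT appear in `extreme_customers`.

# %% [code]
_toy_all = pl.DataFrame({"customer_id": ["a", "b", "c"], "clv": [1.0, 2.0, 3.0]})
_toy_extreme = pl.DataFrame({"customer_id": ["b"]})
print(_toy_all.join(_toy_extreme, on="customer_id", how="anti"))  # rows a and c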
def get_tx_article_dfs(
    data_path: Path,
    config: dict,
    cols_to_aggregate: list = [
        "date",
        "days_before",
        "article_ids",
        "sales_channel_ids",
        "total_price",
        "prices",
        "num_items",
    ],
    keep_customer_id: bool = True,
):
    """Creates train, validation, and test datasets with optional subsampling.

    Args:
        data_path (Path): Path to directory containing transaction data files.
        config (dict): Configuration dictionary (date windows, 'clv_periods',
            'subset', 'train_subsample_percentage').
        cols_to_aggregate (list, optional): Transaction columns to include in output.
        keep_customer_id (bool, optional): Whether to retain customer_id column.

    Returns:
        tuple[pl.DataFrame, pl.DataFrame, pl.DataFrame]: Tuple containing:
            - train_df: Final training dataset (subset of original training data)
            - val_df: Validation dataset (10% of original training data)
            - test_df: Test dataset (optionally subsampled)

    Columns of the returned dfs:
        - customer_id
        - date_lst (list[date]): Dates of each transaction
        - days_before_lst (list[int]): Number of days between start of prediction and date of transaction
        - articles_ids_lst (list[int]): Flattened list of all items a customer purchased
        - sales_channel_id_lst (list[list[int]]): Sales channel of a transaction (repeated for each item within a transaction)
        - total_price_lst (list[float]): Value of each transaction
        - price_lst (list[float]): Flattened list of prices of all items customer purchased
        - num_items_lst (list[int]): Number of items in each transaction
        - CLV_label (float): Sales in prediction period (label to be used)
    """
    train_df, test_df = get_customer_train_test_articles_pl(
        data_path=data_path,
        config=config,
        clv_periods=config.get("clv_periods", [6]),
        cols_to_aggregate=cols_to_aggregate,
        keep_customer_id=keep_customer_id,
    )
    train_df, val_df, test_df = train_test_split(
        train_df=train_df,
        test_df=test_df,
        subset=config.get("subset"),
        train_subsample_percentage=config.get("train_subsample_percentage"),
    )
    return train_df, val_df, test_df

# %% [code]
#from pathlib import Path
#from data_processing.get_data import get_benchmark_dfs
#import polars as pl

# %% [code]
config = {
    "train_begin": "2018-09-20",
    "train_label_begin": "2019-09-20",
    "train_end": "2020-03-17",
    "test_begin": "2019-03-19",
    "test_label_begin": "2020-03-18",
    "test_end": "2020-09-13",
    "min_zip_code_count": 3,
    "date_aggregation": "daily",
    "group_by_channel_id": False,
    "log_clv": False,
    "clv_periods": [6],
    "subset": None,
    "train_subsample_percentage": None,
    "max_length": 20,  # DEFINE HOW MANY ITEMS ARE TO BE CONSIDERED IN TRANSFORMER SEQUENCE
}
# data_path = Path("/kaggle/input/hm-dataset/data/data")
data_path = Path("/kaggle/input/data/data/")

print(10 * "#", " Loading data ", 10 * "#")
train_df, val_df, test_df = get_benchmark_dfs(data_path, config)

# %% [code]
test_df
# %% [markdown]
# # BST TRAINING AND TESTING (FINAL VERSION)

# %% [code]
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
import pytorch_lightning as pl  # NOTE: shadows the earlier `import polars as pl`
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from typing import List, Dict, Optional


# Custom Transformer classes

class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-8) -> None:
        super().__init__()
        self.eps = eps
        self.gain = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        norm = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return norm * self.gain

class CustomMultiheadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, dropout=0.0, bias=True):
        super().__init__()
        assert embed_dim % num_heads == 0
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.dropout = dropout

    def forward(self, query, key, value, attn_mask=None, key_padding_mask=None):
        B, T, _ = query.size()
        qkv = self.in_proj(query)
        q, k, v = qkv.chunk(3, dim=-1)

        # Reshape for multi-head attention
        q = q.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)

        attn_mask = self.merge_masks(attn_mask, key_padding_mask, query)
        attn_output = F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=attn_mask,
            dropout_p=self.dropout if self.training else 0.0,
            is_causal=False,
        )

        attn_output = attn_output.transpose(1, 2).contiguous().view(B, T, self.embed_dim)
        output = self.out_proj(attn_output)
        return output

    def merge_masks(self,
                    attn_mask: Optional[torch.Tensor],
                    key_padding_mask: Optional[torch.Tensor],
                    query: torch.Tensor) -> Optional[torch.Tensor]:
        merged_mask = None
        batch_size, seq_len, _ = query.shape

        def convert_to_float_mask(mask):
            if mask.dtype == torch.bool:
                return mask.float().masked_fill(mask, float("-inf"))
            return mask

        if key_padding_mask is not None:
            key_padding_mask = key_padding_mask.view(batch_size, 1, 1, seq_len).expand(
                -1, self.num_heads, -1, -1)
            merged_mask = convert_to_float_mask(key_padding_mask)

        if attn_mask is not None:
            if attn_mask.dim() == 2:
                correct_2d_size = (seq_len, seq_len)
                if attn_mask.shape != correct_2d_size:
                    raise RuntimeError(f"The shape of the 2D attn_mask is {attn_mask.shape}, "
                                       f"but should be {correct_2d_size}.")
                attn_mask = attn_mask.unsqueeze(0).expand(batch_size, self.num_heads, -1, -1)
            elif attn_mask.dim() == 3:
                correct_3d_size = (batch_size * self.num_heads, seq_len, seq_len)
                if attn_mask.shape != correct_3d_size:
                    raise RuntimeError(f"The shape of the 3D attn_mask is {attn_mask.shape}, "
                                       f"but should be {correct_3d_size}.")
                attn_mask = attn_mask.view(batch_size, self.num_heads, seq_len, seq_len)
            else:
                raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported")
            attn_mask = convert_to_float_mask(attn_mask)
            if merged_mask is None:
                merged_mask = attn_mask
            else:
                merged_mask = merged_mask + attn_mask
        return merged_mask
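# %% [markdown]
# A hedged smoke test (not in the original notebook): RMSNorm should produce
# unit RMS per position at initialization (gain = 1), and the attention block
# should preserve the input shape while ignoring padded positions via
# `key_padding_mask` (True = padding).

# %% [code]
_x = torch.randn(2, 8, 24)  # (batch, seq_len, embed_dim)
assert RMSNorm(24)(_x).pow(2).mean(-1).sqrt().allclose(torch.ones(2, 8), atol=1e-4)

_attn = CustomMultiheadAttention(embed_dim=24, num_heads=2)
_pad = torch.zeros(2, 8, dtype=torch.bool)
_pad[:, 6:] = True  # last two positions are padding
print(_attn(_x, _x, _x, key_padding_mask=_pad).shape)  # torch.Size([2, 8, 24])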
class TransformerEncoderLayer(nn.Module):
    def __init__(self, config: dict):
        super().__init__()
        embed_dim = config["embedding_dim"]
        num_heads = config.get("heads", 8)
        dropout = config["transformer_dropout"]
        dim_feedforward = config["dim_feedforward"]
        self.norm_first = config.get("norm_first", False)

        self.self_attn = CustomMultiheadAttention(embed_dim, num_heads, dropout=dropout)
        self.linear1 = nn.Linear(embed_dim, dim_feedforward)
        self.dropout1 = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, embed_dim)
        self.dropout2 = nn.Dropout(dropout)
        self.norm1 = RMSNorm(embed_dim)
        self.norm2 = RMSNorm(embed_dim)
        self.activation = nn.GELU()

    def _sa_block(self, src, attn_mask=None, key_padding_mask=None):
        src2 = self.self_attn(src, src, src,
                              key_padding_mask=key_padding_mask,
                              attn_mask=attn_mask)
        return self.dropout1(src2)

    def _ff_block(self, src):
        src2 = self.linear2(self.dropout2(self.activation(self.linear1(src))))
        return self.dropout2(src2)

    def forward(self,
                src: torch.Tensor,
                src_key_padding_mask: torch.Tensor = None,
                src_mask: torch.Tensor = None):
        if self.norm_first:
            src = src + self._sa_block(self.norm1(src),
                                       attn_mask=src_mask,
                                       key_padding_mask=src_key_padding_mask)
            src = src + self._ff_block(self.norm2(src))
        else:
            src = self.norm1(src + self._sa_block(src,
                                                  attn_mask=src_mask,
                                                  key_padding_mask=src_key_padding_mask))
            src = self.norm2(src + self._ff_block(src))
        return src

class TransformerEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.encoder = nn.ModuleList([
            TransformerEncoderLayer(config)
            for _ in range(config["num_transformer_layers"])
        ])

    def forward(self,
                src,
                src_key_padding_mask=None,
                src_mask=None):
        for layer in self.encoder:
            src = layer(src,
                        src_key_padding_mask=src_key_padding_mask,
                        src_mask=src_mask)
        return src


# Preparing vocabularies and the dataset (without customer_id)


def prepare_vocabularies(train_df, val_df, test_df):
    """
    1) Ensures each df is a pandas DataFrame.
    2) Builds a dictionary for postal codes.
    3) Finds max article ID, max day, and max age.
    """
    def to_pandas_if_polars(df):
        return df.to_pandas() if not hasattr(df, "iloc") else df

    train_pd = to_pandas_if_polars(train_df)
    val_pd = to_pandas_if_polars(val_df)
    test_pd = to_pandas_if_polars(test_df)

    combined = pd.concat([train_pd, val_pd, test_pd], ignore_index=True)

    unique_postals = combined['postal_code'].unique()
    postal2idx = {p: i for i, p in enumerate(unique_postals)}
    num_postal = len(postal2idx)

    all_articles = []
    for df_pd in [train_pd, val_pd, test_pd]:
        for lst in df_pd['articles_ids_lst']:
            all_articles.extend(lst)
    max_article_id = max(all_articles)
    num_articles = max_article_id + 1

    all_days = []
    for df_pd in [train_pd, val_pd, test_pd]:
        for lst in df_pd['days_before_lst']:
            all_days.extend(lst)
    max_day = max(all_days)

    # age is float after the mean fill, so cast before sizing the embedding
    max_age = combined['age'].max()
    num_age = int(max_age) + 1

    return postal2idx, num_postal, num_articles, max_day, num_age
class CustomerDataset(Dataset):
    """
    Expects columns:
      - postal_code (str)
      - days_before_lst (list[int])
      - articles_ids_lst (list[int])
      - regression_label (float)
      - classification_label (int) (0 means churn, 1 means not churn)
      - age (int)
    Note: customer_id is no longer used.
    """
    def __init__(self, df, postal2idx: Dict[str, int]):
        if not hasattr(df, "iloc"):
            df = df.to_pandas()
        self.data = df
        self.postal2idx = postal2idx

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        row = self.data.iloc[idx]
        postal_id = self.postal2idx[row['postal_code']]
        age = int(row['age'])
        articles = torch.tensor(row['articles_ids_lst'], dtype=torch.long)
        days = torch.tensor(row['days_before_lst'], dtype=torch.long)
        regression_label = torch.tensor(float(row['regression_label']), dtype=torch.float)
        classification_label = torch.tensor(int(row['classification_label']), dtype=torch.long)
        return (articles, days, age, postal_id, regression_label, classification_label)

# Custom collate function for variable-length sequences
def fixed_length_collate_fn(batch: list[tuple[torch.Tensor, torch.Tensor, int, int, torch.Tensor, torch.Tensor]],
                            sequence_length: int = 8, padding_value: int = 0) -> tuple[torch.Tensor, ...]:
    """
    Efficiently pads sequences using PyTorch's pad_sequence and then brings them to a fixed length.

    Args:
        batch: List of tuples where each tuple contains
            (articles, days, age, postal_id, regression_label, classification_label)
        sequence_length: Desired length of the sequences
        padding_value: Value to use for padding sequences

    Returns:
        Tuple of tensors:
            (article_seqs_tensor, day_seqs_tensor, ages_tensor, postal_ids_tensor, reg_labels_tensor, class_labels_tensor)
    """
    # Unpack the batch
    article_seqs, day_seqs, ages, postal_ids, reg_labels, class_labels = zip(*batch)

    # Use pad_sequence for efficient padding (batch_first gives shape: [B, L, ...])
    article_seqs_tensor = pad_sequence(article_seqs, batch_first=True, padding_value=padding_value)
    day_seqs_tensor = pad_sequence(day_seqs, batch_first=True, padding_value=padding_value)

    # Truncate padded tensors to the desired sequence length
    article_seqs_tensor = article_seqs_tensor[:, :sequence_length]
    day_seqs_tensor = day_seqs_tensor[:, :sequence_length]

    # pad_sequence only pads to the longest sequence in the batch; if every
    # sequence is shorter than sequence_length, pad out to the fixed length so
    # the flattened transformer output always matches the regressor input dim.
    if article_seqs_tensor.size(1) < sequence_length:
        missing = sequence_length - article_seqs_tensor.size(1)
        article_seqs_tensor = F.pad(article_seqs_tensor, (0, missing), value=padding_value)
        day_seqs_tensor = F.pad(day_seqs_tensor, (0, missing), value=padding_value)

    # Convert scalar values to tensors
    ages_tensor = torch.tensor(ages, dtype=torch.long)
    postal_ids_tensor = torch.tensor(postal_ids, dtype=torch.long)
    reg_labels_tensor = torch.stack(reg_labels, dim=0)
    class_labels_tensor = torch.stack(class_labels, dim=0)

    return (
        article_seqs_tensor,
        day_seqs_tensor,
        ages_tensor,
        postal_ids_tensor,
        reg_labels_tensor,
        class_labels_tensor
    )
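# %% [markdown]
# A hedged toy check of the collate function: a ragged batch of lengths 3 and 1
# comes out as fixed (B, sequence_length) tensors, zero-padded on the right.

# %% [code]
_toy_batch = [
    (torch.tensor([1, 2, 3]), torch.tensor([9, 8, 7]), 30, 0,
     torch.tensor(12.5), torch.tensor(1)),
    (torch.tensor([4]), torch.tensor([6]), 40, 1,
     torch.tensor(0.0), torch.tensor(0)),
]
_arts, _days, _ages, _postals, _regs, _cls = fixed_length_collate_fn(_toy_batch, sequence_length=4)
print(_arts)        # tensor([[1, 2, 3, 0], [4, 0, 0, 0]])
print(_arts.shape)  # torch.Size([2, 4])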
# BST model WITHOUT customer_id input

class BST(pl.LightningModule):
    def __init__(
        self,
        num_articles,
        max_day,
        num_age,
        num_postal,
        sequence_length,
        train_df,
        val_df,
        test_df,
        postal2idx,
        # embedding dims
        article_emb_dim=16,
        day_emb_dim=8,
        age_emb_dim=4,
        postal_emb_dim=4,
        # transformer config
        transformer_nhead=2,
        transformer_ff_dim=64,
        num_transformer_layers=1,
        # multi-task config
        predict_churn=False,  # Flag to enable/disable churn prediction
        # training
        learning_rate=0.0005
    ):
        super().__init__()
        self.save_hyperparameters(ignore=['train_df', 'val_df', 'test_df', 'postal2idx'])
        self.learning_rate = learning_rate
        self.predict_churn = predict_churn

        # DataFrames and mapping
        self.train_df = train_df
        self.val_df = val_df
        self.test_df = test_df
        self.postal2idx = postal2idx

        # Embeddings (customer embedding removed)
        self.embeddings_age = nn.Embedding(num_age, age_emb_dim)
        self.embeddings_postal = nn.Embedding(num_postal, postal_emb_dim)
        self.embeddings_article = nn.Embedding(num_articles, article_emb_dim)
        self.embeddings_day = nn.Embedding(max_day + 1, day_emb_dim)

        # Sequence features: concatenation of article and day embeddings
        self.seq_feature_dim = article_emb_dim + day_emb_dim

        # Custom Transformer setup
        config = {
            "embedding_dim": self.seq_feature_dim,
            "heads": transformer_nhead,
            "transformer_dropout": 0.2,
            "dim_feedforward": transformer_ff_dim,
            "norm_first": False,
            "num_transformer_layers": num_transformer_layers,
        }
        self.transformer = TransformerEncoder(config)
        self.transformer_output_dim = sequence_length * self.seq_feature_dim

        # User features: only age and postal embeddings are used
        user_feature_dim = age_emb_dim + postal_emb_dim

        combined_dim = self.transformer_output_dim + user_feature_dim

        # Separate heads for regression and (optional) classification
        self.regressor_head = nn.Sequential(
            nn.Linear(combined_dim, 512),
            nn.LeakyReLU(),
            nn.Linear(512, 256),
            nn.LeakyReLU(),
            nn.Linear(256, 1)
        )

        if self.predict_churn:
            self.classifier_head = nn.Sequential(
                nn.Linear(combined_dim, 128),
                nn.LeakyReLU(),
                nn.Linear(128, 1)  # single logit for binary classification
            )
            self.classification_criterion = nn.BCEWithLogitsLoss()

        self.regression_criterion = nn.MSELoss()

    def encode_input(self, batch):
        # Expected tuple: (articles, days, age, postal_id, regression_label, classification_label)
        articles, days, age, postal_id, regression_label, classification_label = batch

        article_embeds = self.embeddings_article(articles)  # (B, L, article_emb_dim)
        day_embeds = self.embeddings_day(days)  # (B, L, day_emb_dim)
        # Concatenate to form the sequence features; day_embeds serve as the time/position signal.
        transformer_input = torch.cat([article_embeds, day_embeds], dim=-1)  # (B, L, seq_feature_dim)

        transformer_output = self.transformer(transformer_input)  # (B, L, seq_feature_dim)
        transformer_output_flat = transformer_output.reshape(transformer_output.size(0), -1)

        age_embed = self.embeddings_age(age)
        postal_embed = self.embeddings_postal(postal_id)
        user_features = torch.cat([age_embed, postal_embed], dim=1)

        combined_features = torch.cat([transformer_output_flat, user_features], dim=1)
        return combined_features, regression_label, classification_label

    def forward(self, batch):
        features, reg_label, class_label = self.encode_input(batch)
        reg_output = self.regressor_head(features).squeeze(dim=-1)
        if self.predict_churn:
            class_output = self.classifier_head(features).squeeze(dim=-1)
        else:
            class_output = None
        return reg_output, class_output, reg_label, class_label

    def training_step(self, batch, batch_idx):
        reg_output, class_output, reg_label, class_label = self(batch)
        reg_loss = self.regression_criterion(reg_output, reg_label)
        loss = reg_loss
        self.log("train_reg_loss", reg_loss)
        if self.predict_churn:
            class_loss = self.classification_criterion(class_output, class_label.float())
            self.log("train_class_loss", class_loss)
            loss = reg_loss + class_loss
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        reg_output, class_output, reg_label, class_label = self(batch)
        reg_loss = self.regression_criterion(reg_output, reg_label)
        self.log("val_reg_loss", reg_loss)
        loss = reg_loss
        if self.predict_churn:
            class_loss = self.classification_criterion(class_output, class_label.float())
            self.log("val_class_loss", class_loss)
            loss = reg_loss + class_loss
        self.log("val_loss", loss)
        return loss

    def test_step(self, batch, batch_idx):
        reg_output, class_output, reg_label, class_label = self(batch)
        reg_loss = self.regression_criterion(reg_output, reg_label)
        self.log("test_reg_loss", reg_loss)
        loss = reg_loss
        if self.predict_churn:
            class_loss = self.classification_criterion(class_output, class_label.float())
            self.log("test_class_loss", class_loss)
            loss = reg_loss + class_loss
        self.log("test_loss", loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=self.learning_rate)

    def setup(self, stage=None):
        if stage == "fit" or stage is None:
            self.train_dataset = CustomerDataset(self.train_df, self.postal2idx)
            self.val_dataset = CustomerDataset(self.val_df, self.postal2idx)
        if stage == "test" or stage is None:
            self.test_dataset = CustomerDataset(self.test_df, self.postal2idx)

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=128,
            shuffle=True,
            num_workers=4,
            collate_fn=lambda batch: fixed_length_collate_fn(batch, sequence_length=self.hparams.sequence_length)
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_dataset,
            batch_size=128,
            shuffle=False,
            num_workers=4,
            collate_fn=lambda batch: fixed_length_collate_fn(batch, sequence_length=self.hparams.sequence_length)
        )

    def test_dataloader(self):
        return DataLoader(
            self.test_dataset,
            batch_size=128,
            shuffle=False,
            num_workers=4,
            collate_fn=lambda batch: fixed_length_collate_fn(batch, sequence_length=self.hparams.sequence_length)
        )


# TRAIN AND TEST

# Build vocabularies and fix the sequence length before constructing the model.
postal2idx, num_postal, num_articles, max_day, num_age = prepare_vocabularies(
    train_df, val_df, test_df
)
sequence_length = 8

#
# 1. for regression only
# model = BST(
#     num_articles=num_articles,
#     max_day=max_day,
#     num_age=num_age,
#     num_postal=num_postal,
#     sequence_length=sequence_length,
#     train_df=train_df,
#     val_df=val_df,
#     test_df=test_df,
#     postal2idx=postal2idx,
#     article_emb_dim=16,
#     day_emb_dim=8,
#     age_emb_dim=4,
#     postal_emb_dim=4,
#     transformer_nhead=2,
#     transformer_ff_dim=64,
#     num_transformer_layers=1,
#     predict_churn=False,
#     learning_rate=0.0005
# )


# 2. with churn
model = BST(
    #num_customers=num_customers,
    num_articles=num_articles,
    max_day=max_day,
    num_age=num_age,
    num_postal=num_postal,
    sequence_length=sequence_length,
    train_df=train_df,
    val_df=val_df,
    test_df=test_df,
    #user2idx=user2idx,
    postal2idx=postal2idx,
    article_emb_dim=16,
    day_emb_dim=8,
    #customer_emb_dim=16,
    age_emb_dim=4,
    postal_emb_dim=4,
    transformer_nhead=2,
    transformer_ff_dim=64,
    num_transformer_layers=1,
    predict_churn=True,  # <--- Enable churn
    learning_rate=0.0005
)

trainer = pl.Trainer(accelerator="gpu", devices="auto", max_epochs=1)
trainer.fit(model)
trainer.test(model)
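# %% [markdown]
# Dimension bookkeeping for the model above (a hedged back-of-the-envelope with
# the default hyperparameters): each sequence position carries
# article_emb_dim + day_emb_dim = 16 + 8 = 24 features, the flattened
# transformer output is 8 * 24 = 192, and adding the age (4) and postal (4)
# embeddings gives a regressor input of 192 + 8 = 200 features.

# %% [code]
_seq_feature_dim = 16 + 8
print(8 * _seq_feature_dim + 4 + 4)  # 200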
# %% [markdown]
# # BASE VERSION

# %% [code]
import math
import torch
import torch.nn as nn
import pytorch_lightning as pl
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from typing import List, Dict


def prepare_vocabularies(train_df, val_df, test_df):
    """
    1) Ensures each df is a pandas DataFrame (for easy indexing).
    2) Builds dictionaries to map string IDs (customer_id, postal_code) to integer indices.
    3) Finds max article ID, max day, and max age so we can define embedding sizes.
    """

    def to_pandas_if_polars(df):
        return df.to_pandas() if not hasattr(df, "iloc") else df

    train_pd = to_pandas_if_polars(train_df)
    val_pd = to_pandas_if_polars(val_df)
    test_pd = to_pandas_if_polars(test_df)

    # Combine for global vocabularies
    combined = pd.concat([train_pd, val_pd, test_pd], ignore_index=True)

    # Map string-based customer_id -> integer
    unique_users = combined['customer_id'].unique()
    user2idx = {u: i for i, u in enumerate(unique_users)}
    num_customers = len(user2idx)

    # Map string-based postal_code -> integer
    unique_postals = combined['postal_code'].unique()
    postal2idx = {p: i for i, p in enumerate(unique_postals)}
    num_postal = len(postal2idx)

    # Determine max article ID
    all_articles = []
    for df_pd in [train_pd, val_pd, test_pd]:
        for lst in df_pd['articles_ids_lst']:
            all_articles.extend(lst)  # 'lst' is a list of ints
    max_article_id = max(all_articles)
    num_articles = max_article_id + 1  # for embedding

    # Determine max day
    all_days = []
    for df_pd in [train_pd, val_pd, test_pd]:
        for lst in df_pd['days_before_lst']:
            all_days.extend(lst)
    max_day = max(all_days)

    # Determine max age if we treat age as discrete
    # (age is float after the mean fill, so cast before sizing the embedding)
    max_age = combined['age'].max()
    num_age = int(max_age) + 1

    return user2idx, postal2idx, num_customers, num_postal, num_articles, max_day, num_age
class CustomerDataset(Dataset):
    """
    Expects columns:
      - customer_id (str)
      - days_before_lst (list[int])
      - articles_ids_lst (list[int])
      - regression_label (float)
      - classification_label (int) (not used here)
      - age (int)
      - postal_code (str)
    """
    def __init__(self, df, user2idx: Dict[str, int], postal2idx: Dict[str, int]):
        # Convert to Pandas if Polars
        if not hasattr(df, "iloc"):
            df = df.to_pandas()
        self.data = df

        self.user2idx = user2idx
        self.postal2idx = postal2idx

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        row = self.data.iloc[idx]

        # Convert string-based IDs to integer indices
        user_id = self.user2idx[row['customer_id']]
        postal_id = self.postal2idx[row['postal_code']]

        age = int(row['age'])  # embedding or numeric

        # articles_ids_lst and days_before_lst are lists of ints
        articles = torch.tensor(row['articles_ids_lst'], dtype=torch.long)
        days = torch.tensor(row['days_before_lst'], dtype=torch.long)

        regression_label = torch.tensor(float(row['regression_label']), dtype=torch.float)

        return (
            user_id,
            articles,
            days,
            age,
            postal_id,
            regression_label
        )


# CUSTOM COLLATE FUNCTION FOR VARIABLE-LENGTH SEQUENCES


def fixed_length_collate_fn(batch, sequence_length=8):
    """
    Pads or truncates the 'articles' and 'days' sequences to 'sequence_length'.
    Each item in the batch is a tuple:
        (user_id, articles, days, age, postal_id, regression_label)
    """
    user_ids = []
    article_seqs = []
    day_seqs = []
    ages = []
    postal_ids = []
    labels = []

    # 1) Unpack
    for item in batch:
        (user_id, articles, days, age, postal_id, label) = item
        user_ids.append(user_id)
        article_seqs.append(articles)
        day_seqs.append(days)
        ages.append(age)
        postal_ids.append(postal_id)
        labels.append(label)

    # 2) Pad or truncate each sequence
    def pad_or_trunc(seq, desired_length):
        length = seq.size(0)
        if length > desired_length:
            return seq[:desired_length]
        elif length < desired_length:
            pad_size = desired_length - length
            pad = torch.zeros(pad_size, dtype=seq.dtype)
            return torch.cat([seq, pad], dim=0)
        else:
            return seq

    for i in range(len(article_seqs)):
        article_seqs[i] = pad_or_trunc(article_seqs[i], sequence_length)
        day_seqs[i] = pad_or_trunc(day_seqs[i], sequence_length)

    # 3) Stack everything
    user_ids_tensor = torch.tensor(user_ids, dtype=torch.long)
    article_seqs_tensor = torch.stack(article_seqs, dim=0)  # shape: (batch_size, sequence_length)
    day_seqs_tensor = torch.stack(day_seqs, dim=0)          # shape: (batch_size, sequence_length)
    ages_tensor = torch.tensor(ages, dtype=torch.long)
    postal_ids_tensor = torch.tensor(postal_ids, dtype=torch.long)
    labels_tensor = torch.stack(labels, dim=0)              # shape: (batch_size,)

    return (
        user_ids_tensor,
        article_seqs_tensor,
        day_seqs_tensor,
        ages_tensor,
        postal_ids_tensor,
        labels_tensor
    )
# BST MODEL
class PositionalEmbedding(nn.Module):
    """
    Simple positional embedding that learns a unique embedding per position (0..max_len-1).
    """
    def __init__(self, max_len, d_model):
        super().__init__()
        self.pe = nn.Embedding(max_len, d_model)

    def forward(self, x):
        # x: (batch_size, seq_length, d_model)
        batch_size, seq_length, _ = x.size()
        positions = torch.arange(seq_length, device=x.device).unsqueeze(0).expand(batch_size, seq_length)
        return self.pe(positions)  # (batch_size, seq_length, d_model)

class BST(pl.LightningModule):
    def __init__(
        self,
        num_customers,
        num_articles,
        max_day,
        num_age,
        num_postal,
        sequence_length,
        train_df,
        val_df,
        test_df,
        user2idx,
        postal2idx,
        article_emb_dim=16,
        day_emb_dim=8,
        customer_emb_dim=16,
        age_emb_dim=4,
        postal_emb_dim=4,
        transformer_nhead=2,
        learning_rate=0.0005
    ):
        super().__init__()
        self.save_hyperparameters(ignore=['train_df', 'val_df', 'test_df', 'user2idx', 'postal2idx'])
        self.learning_rate = learning_rate

        # DataFrames + Mappings
        self.train_df = train_df
        self.val_df = val_df
        self.test_df = test_df
        self.user2idx = user2idx
        self.postal2idx = postal2idx

        # Embeddings
        self.embeddings_customer = nn.Embedding(num_customers, customer_emb_dim)
        self.embeddings_age = nn.Embedding(num_age, age_emb_dim)
        self.embeddings_postal = nn.Embedding(num_postal, postal_emb_dim)

        self.embeddings_article = nn.Embedding(num_articles, article_emb_dim)
        self.embeddings_day = nn.Embedding(max_day + 1, day_emb_dim)

        # Sequence dimension
        self.seq_feature_dim = article_emb_dim + day_emb_dim
        self.positional_embedding = PositionalEmbedding(sequence_length, self.seq_feature_dim)

        # Transformer
        self.transformer_layer = nn.TransformerEncoderLayer(
            d_model=self.seq_feature_dim,
            nhead=transformer_nhead,
            dropout=0.2
        )

        # Flattened dimension after transformer
        transformer_output_dim = sequence_length * self.seq_feature_dim

        # User features dimension
        user_feature_dim = customer_emb_dim + age_emb_dim + postal_emb_dim

        # Combined dimension
        combined_dim = transformer_output_dim + user_feature_dim

        # Final regressor
        self.linear = nn.Sequential(
            nn.Linear(combined_dim, 512),
            nn.LeakyReLU(),
            nn.Linear(512, 256),
            nn.LeakyReLU(),
            nn.Linear(256, 1)
        )

        self.criterion = nn.MSELoss()

    def encode_input(self, batch):
        user_id, articles, days, age, postal_id, regression_label = batch

        # Sequence embeddings
        article_embeds = self.embeddings_article(articles)  # (B, L, article_emb_dim)
        day_embeds = self.embeddings_day(days)              # (B, L, day_emb_dim)
        sequence_features = torch.cat([article_embeds, day_embeds], dim=-1)  # (B, L, seq_feature_dim)

        # Positional embeddings
        pos_embeds = self.positional_embedding(sequence_features)
        transformer_input = sequence_features + pos_embeds

        # Transformer expects (L, B, d_model)
        transformer_input = transformer_input.transpose(0, 1)  # (L, B, seq_feature_dim)
        transformer_output = self.transformer_layer(transformer_input)
        transformer_output = transformer_output.transpose(0, 1)  # (B, L, seq_feature_dim)

        # Flatten
        transformer_output_flat = transformer_output.reshape(transformer_output.size(0), -1)

        # User features
        customer_embed = self.embeddings_customer(user_id)
        age_embed = self.embeddings_age(age)
        postal_embed = self.embeddings_postal(postal_id)
        user_features = torch.cat([customer_embed, age_embed, postal_embed], dim=1)

        # Combine
        combined_features = torch.cat([transformer_output_flat, user_features], dim=1)
        return combined_features, regression_label

    def forward(self, batch):
        features, target = self.encode_input(batch)
        output = self.linear(features)
        # squeeze only the feature dim so a batch of size 1 keeps its batch dimension
        return output.squeeze(-1), target

    def training_step(self, batch, batch_idx):
        output, target = self(batch)
        loss = self.criterion(output, target)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        output, target = self(batch)
        loss = self.criterion(output, target)
        self.log("val_loss", loss)
        return loss

    def test_step(self, batch, batch_idx):
        output, target = self(batch)
        loss = self.criterion(output, target)
        self.log("test_loss", loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=self.learning_rate)

    def setup(self, stage=None):
        if stage == "fit" or stage is None:
            self.train_dataset = CustomerDataset(self.train_df, self.user2idx, self.postal2idx)
            self.val_dataset = CustomerDataset(self.val_df, self.user2idx, self.postal2idx)
        if stage == "test" or stage is None:
            self.test_dataset = CustomerDataset(self.test_df, self.user2idx, self.postal2idx)

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=128,
            shuffle=True,
            num_workers=4,
            collate_fn=lambda b: fixed_length_collate_fn(b, sequence_length=self.hparams.sequence_length)
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_dataset,
            batch_size=128,
            shuffle=False,
            num_workers=4,
            collate_fn=lambda b: fixed_length_collate_fn(b, sequence_length=self.hparams.sequence_length)
        )

    def test_dataloader(self):
        return DataLoader(
            self.test_dataset,
            batch_size=128,
            shuffle=False,
            num_workers=4,
            collate_fn=lambda b: fixed_length_collate_fn(b, sequence_length=self.hparams.sequence_length)
        )
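# %% [markdown]
# A hedged toy check of the learned positional embedding above: it maps
# positions 0..seq_length-1 to d_model-dim vectors, broadcast across the batch.

# %% [code]
_pe = PositionalEmbedding(max_len=8, d_model=24)
print(_pe(torch.zeros(2, 8, 24)).shape)  # torch.Size([2, 8, 24])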
torch.cat([customer_embed, age_embed, postal_embed], dim=1)\n\n        # Combine\n        combined_features = torch.cat([transformer_output_flat, user_features], dim=1)\n        return combined_features, regression_label\n\n    def forward(self, batch):\n        features, target = self.encode_input(batch)\n        output = self.linear(features)\n        return output.squeeze(dim=-1), target\n\n    def training_step(self, batch, batch_idx):\n        output, target = self(batch)\n        loss = self.criterion(output, target)\n        self.log(\"train_loss\", loss)\n        return loss\n\n    def validation_step(self, batch, batch_idx):\n        output, target = self(batch)\n        loss = self.criterion(output, target)\n        self.log(\"val_loss\", loss)\n        return loss\n\n    def test_step(self, batch, batch_idx):\n        output, target = self(batch)\n        loss = self.criterion(output, target)\n        self.log(\"test_loss\", loss)\n        return loss\n\n    def configure_optimizers(self):\n        return torch.optim.AdamW(self.parameters(), lr=self.learning_rate)\n\n    def setup(self, stage=None):\n        if stage == \"fit\" or stage is None:\n            self.train_dataset = CustomerDataset(self.train_df, self.user2idx, self.postal2idx)\n            self.val_dataset = CustomerDataset(self.val_df, self.user2idx, self.postal2idx)\n        if stage == \"test\" or stage is None:\n            self.test_dataset = CustomerDataset(self.test_df, self.user2idx, self.postal2idx)\n\n    def train_dataloader(self):\n        return DataLoader(\n            self.train_dataset,\n            batch_size=128,\n            shuffle=True,\n            num_workers=4,\n            collate_fn=lambda b: fixed_length_collate_fn(b, sequence_length=self.hparams.sequence_length)\n        )\n\n    def val_dataloader(self):\n        return DataLoader(\n            self.val_dataset,\n            batch_size=128,\n            shuffle=False,\n            num_workers=4,\n            collate_fn=lambda b: fixed_length_collate_fn(b, sequence_length=self.hparams.sequence_length)\n        )\n\n    def test_dataloader(self):\n        return DataLoader(\n            self.test_dataset,\n            batch_size=128,\n            shuffle=False,\n            num_workers=4,\n            collate_fn=lambda b: fixed_length_collate_fn(b, sequence_length=self.hparams.sequence_length)\n        )\n\n\n# TRAIN AND TEST\n\n\nuser2idx, postal2idx, num_customers, num_postal, num_articles, max_day, num_age = prepare_vocabularies(\n    train_df, val_df, test_df\n)\n\n\nsequence_length = 8\n\n\nmodel = BST(\n    num_customers=num_customers,\n    num_articles=num_articles,\n    max_day=max_day,\n    num_age=num_age,\n    num_postal=num_postal,\n    sequence_length=sequence_length,\n    train_df=train_df,\n    val_df=val_df,\n    test_df=test_df,\n    user2idx=user2idx,\n    postal2idx=postal2idx,\n    article_emb_dim=16,\n    day_emb_dim=8,\n    customer_emb_dim=16,\n    age_emb_dim=4,\n    postal_emb_dim=4,\n    transformer_nhead=2,\n    learning_rate=0.0005\n)\n\n\ntrainer = pl.Trainer(accelerator=\"gpu\", devices=\"auto\", max_epochs=1)\ntrainer.fit(model)\ntrainer.test(model)\n\n# %% [markdown]\n# # USING ALTERNATIVE TRANSFORMER\n\n# %% [code] {\"execution\":{\"iopub.status.busy\":\"2025-03-19T10:17:30.352672Z\",\"iopub.status.idle\":\"2025-03-19T10:17:30.353030Z\",\"shell.execute_reply\":\"2025-03-19T10:17:30.352852Z\"}}\nimport math\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport pytorch_lightning as pl\nimport pandas as pd\nfrom torch.utils.data import Dataset, DataLoader\nfrom typing import List, Dict, Optional\n\n\n# Custom Transformer classes\n\n\nclass RMSNorm(nn.Module):\n    def __init__(self, dim: int, eps: float = 1e-8) -> None:\n        super().__init__()\n        self.eps = eps\n        self.gain = nn.Parameter(torch.ones(dim))\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        norm = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)\n        return norm * 
self.gain\n\n\nclass CustomMultiheadAttention(nn.Module):\n def __init__(self, embed_dim, num_heads, dropout=0.0, bias=True):\n super().__init__()\n assert embed_dim % num_heads == 0\n self.embed_dim = embed_dim\n self.num_heads = num_heads\n self.head_dim = embed_dim // num_heads\n self.in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=bias)\n self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)\n self.dropout = dropout\n \n def forward(self, query, key, value, attn_mask=None, key_padding_mask=None):\n B, T, _ = query.size()\n qkv = self.in_proj(query)\n q, k, v = qkv.chunk(3, dim=-1)\n\n # Reshape for multi-head\n q = q.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)\n k = k.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)\n v = v.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)\n\n # Merge masks if present\n attn_mask = self.merge_masks(attn_mask, key_padding_mask, query)\n\n # Use PyTorch's scaled dot-product attention\n attn_output = F.scaled_dot_product_attention(\n q,\n k,\n v,\n attn_mask=attn_mask,\n dropout_p=self.dropout if self.training else 0.0,\n is_causal=False,\n )\n\n # Reshape back\n attn_output = attn_output.transpose(1, 2).contiguous().view(B, T, self.embed_dim)\n output = self.out_proj(attn_output)\n return output\n\n def merge_masks(\n self,\n attn_mask: Optional[torch.Tensor],\n key_padding_mask: Optional[torch.Tensor],\n query: torch.Tensor,\n ) -> Optional[torch.Tensor]:\n merged_mask = None\n batch_size, seq_len, _ = query.shape\n\n def convert_to_float_mask(mask):\n if mask.dtype == torch.bool:\n return mask.float().masked_fill(mask, float(\"-inf\"))\n return mask\n\n # key_padding_mask -> float mask\n if key_padding_mask is not None:\n # shape (B, T) -> (B, 1, 1, T) -> expand to (B, num_heads, 1, T)\n key_padding_mask = key_padding_mask.view(batch_size, 1, 1, seq_len).expand(\n -1, self.num_heads, -1, -1\n )\n merged_mask = convert_to_float_mask(key_padding_mask)\n\n # attn_mask -> float mask\n if attn_mask is not None:\n if attn_mask.dim() == 2:\n # shape (T, T) -> (B, num_heads, T, T)\n correct_2d_size = (seq_len, seq_len)\n if attn_mask.shape != correct_2d_size:\n raise RuntimeError(\n f\"The shape of the 2D attn_mask is {attn_mask.shape}, \"\n f\"but should be {correct_2d_size}.\"\n )\n attn_mask = attn_mask.unsqueeze(0).expand(\n batch_size, self.num_heads, -1, -1\n )\n elif attn_mask.dim() == 3:\n # shape (B*num_heads, T, T) -> (B, num_heads, T, T)\n correct_3d_size = (batch_size * self.num_heads, seq_len, seq_len)\n if attn_mask.shape != correct_3d_size:\n raise RuntimeError(\n f\"The shape of the 3D attn_mask is {attn_mask.shape}, \"\n f\"but should be {correct_3d_size}.\"\n )\n attn_mask = attn_mask.view(batch_size, self.num_heads, seq_len, seq_len)\n else:\n raise RuntimeError(\n f\"attn_mask's dimension {attn_mask.dim()} is not supported\"\n )\n\n attn_mask = convert_to_float_mask(attn_mask)\n\n if merged_mask is None:\n merged_mask = attn_mask\n else:\n merged_mask = merged_mask + attn_mask\n\n return merged_mask\n\n\nclass TransformerEncoderLayer(nn.Module):\n def __init__(self, config: dict):\n super().__init__()\n embed_dim = config[\"embedding_dim\"]\n num_heads = config.get(\"heads\", 8)\n dropout = config[\"transformer_dropout\"]\n dim_feedforward = config[\"dim_feedforward\"]\n self.norm_first = config.get(\"norm_first\", False)\n\n self.self_attn = CustomMultiheadAttention(embed_dim, num_heads, dropout=dropout)\n self.linear1 = nn.Linear(embed_dim, dim_feedforward)\n self.dropout1 = nn.Dropout(dropout)\n 
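# linear1 -> activation -> linear2 form the position-wise feed-forward block; dropout1 is applied to the attention output in _sa_block\n        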
self.linear2 = nn.Linear(dim_feedforward, embed_dim)\n self.dropout2 = nn.Dropout(dropout)\n\n self.norm1 = RMSNorm(embed_dim)\n self.norm2 = RMSNorm(embed_dim)\n self.activation = nn.GELU()\n\n def _sa_block(self, src, attn_mask=None, key_padding_mask=None):\n src2 = self.self_attn(\n src, src, src, key_padding_mask=key_padding_mask, attn_mask=attn_mask\n )\n return self.dropout1(src2)\n\n def _ff_block(self, src):\n src2 = self.linear2(self.dropout2(self.activation(self.linear1(src))))\n return self.dropout2(src2)\n\n def forward(\n self,\n src: torch.Tensor,\n src_key_padding_mask: torch.Tensor = None,\n src_mask: torch.Tensor = None,\n ):\n if self.norm_first:\n # Pre-norm\n src = src + self._sa_block(\n self.norm1(src),\n attn_mask=src_mask,\n key_padding_mask=src_key_padding_mask,\n )\n src = src + self._ff_block(self.norm2(src))\n else:\n # Post-norm\n src = self.norm1(\n src\n + self._sa_block(\n src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask\n )\n )\n src = self.norm2(src + self._ff_block(src))\n return src\n\n\nclass TransformerEncoder(nn.Module):\n def __init__(self, config):\n super().__init__()\n self.encoder = nn.ModuleList(\n [\n TransformerEncoderLayer(config)\n for _ in range(config[\"num_transformer_layers\"])\n ]\n )\n\n def forward(\n self,\n src,\n src_key_padding_mask=None,\n src_mask=None,\n ):\n \"\"\"\n src: shape (B, T, E)\n \"\"\"\n for layer in self.encoder:\n src = layer(\n src, src_key_padding_mask=src_key_padding_mask, src_mask=src_mask\n )\n return src\n\n\n\n# Preparing vocabularies and the dataset classes\n\n\ndef prepare_vocabularies(train_df, val_df, test_df):\n \"\"\"\n 1) Ensures each df is a pandas DataFrame (for easy indexing).\n 2) Builds dictionaries to map string IDs (customer_id, postal_code) to integer indices.\n 3) Finds max article ID, max day, and max age so we can define embedding sizes.\n \"\"\"\n\n def to_pandas_if_polars(df):\n return df.to_pandas() if not hasattr(df, \"iloc\") else df\n\n train_pd = to_pandas_if_polars(train_df)\n val_pd = to_pandas_if_polars(val_df)\n test_pd = to_pandas_if_polars(test_df)\n\n # Combine for global vocabularies\n combined = pd.concat([train_pd, val_pd, test_pd], ignore_index=True)\n\n # Map string-based customer_id -> integer\n unique_users = combined['customer_id'].unique()\n user2idx = {u: i for i, u in enumerate(unique_users)}\n num_customers = len(user2idx)\n\n # Map string-based postal_code -> integer\n unique_postals = combined['postal_code'].unique()\n postal2idx = {p: i for i, p in enumerate(unique_postals)}\n num_postal = len(postal2idx)\n\n # Determine max article ID\n all_articles = []\n for df_pd in [train_pd, val_pd, test_pd]:\n for lst in df_pd['articles_ids_lst']:\n all_articles.extend(lst) # 'lst' is a list of ints\n max_article_id = max(all_articles)\n num_articles = max_article_id + 1 # for embedding\n\n # Determine max day\n all_days = []\n for df_pd in [train_pd, val_pd, test_pd]:\n for lst in df_pd['days_before_lst']:\n all_days.extend(lst)\n max_day = max(all_days)\n\n # Determine max age if we treat age as discrete\n max_age = combined['age'].max()\n num_age = max_age + 1\n\n return user2idx, postal2idx, num_customers, num_postal, num_articles, max_day, num_age\n\n\nclass CustomerDataset(Dataset):\n \"\"\"\n Expects columns:\n - customer_id (str)\n - days_before_lst (list[int])\n - articles_ids_lst (list[int])\n - regression_label (float)\n - classification_label (int) (not used here)\n - age (int)\n - postal_code (str)\n \"\"\"\n def __init__(self, df, user2idx: 
Dict[str,int], postal2idx: Dict[str,int]):\n # Convert to Pandas if Polars\n if not hasattr(df, \"iloc\"):\n df = df.to_pandas()\n self.data = df\n\n self.user2idx = user2idx\n self.postal2idx = postal2idx\n\n def __len__(self):\n return len(self.data)\n \n def __getitem__(self, idx):\n row = self.data.iloc[idx]\n\n # Convert string-based IDs to integer indices\n user_id = self.user2idx[row['customer_id']]\n postal_id = self.postal2idx[row['postal_code']]\n\n age = int(row['age']) # embedding or numeric\n\n # articles_ids_lst and days_before_lst are lists of ints\n articles = torch.tensor(row['articles_ids_lst'], dtype=torch.long)\n days = torch.tensor(row['days_before_lst'], dtype=torch.long)\n\n regression_label = torch.tensor(float(row['regression_label']), dtype=torch.float)\n\n return (\n user_id,\n articles,\n days,\n age,\n postal_id,\n regression_label\n )\n\n\n# Custom collate function for variable-length sequences\ndef fixed_length_collate_fn(batch, sequence_length=8):\n \"\"\"\n Pads or truncates the 'articles' and 'days' sequences to 'sequence_length'.\n Each item in the batch is a tuple:\n (user_id, articles, days, age, postal_id, regression_label)\n \"\"\"\n user_ids = []\n article_seqs = []\n day_seqs = []\n ages = []\n postal_ids = []\n labels = []\n\n # 1) Unpack\n for item in batch:\n (user_id, articles, days, age, postal_id, label) = item\n user_ids.append(user_id)\n article_seqs.append(articles)\n day_seqs.append(days)\n ages.append(age)\n postal_ids.append(postal_id)\n labels.append(label)\n\n # 2) Pad or truncate each sequence\n def pad_or_trunc(seq, desired_length):\n length = seq.size(0)\n if length > desired_length:\n return seq[:desired_length]\n elif length < desired_length:\n pad_size = desired_length - length\n pad = torch.zeros(pad_size, dtype=seq.dtype)\n return torch.cat([seq, pad], dim=0)\n else:\n return seq\n\n for i in range(len(article_seqs)):\n article_seqs[i] = pad_or_trunc(article_seqs[i], sequence_length)\n day_seqs[i] = pad_or_trunc(day_seqs[i], sequence_length)\n\n # 3) Stack everything\n user_ids_tensor = torch.tensor(user_ids, dtype=torch.long)\n article_seqs_tensor = torch.stack(article_seqs, dim=0) # shape: (batch_size, sequence_length)\n day_seqs_tensor = torch.stack(day_seqs, dim=0) # shape: (batch_size, sequence_length)\n ages_tensor = torch.tensor(ages, dtype=torch.long)\n postal_ids_tensor = torch.tensor(postal_ids, dtype=torch.long)\n labels_tensor = torch.stack(labels, dim=0) # shape: (batch_size,)\n\n return (\n user_ids_tensor,\n article_seqs_tensor,\n day_seqs_tensor,\n ages_tensor,\n postal_ids_tensor,\n labels_tensor\n )\n\n\n\n# BST model with the custom Transformer\n\n\nclass PositionalEmbedding(nn.Module):\n \"\"\"\n Simple positional embedding that learns a unique embedding per position (0..max_len-1).\n \"\"\"\n def __init__(self, max_len, d_model):\n super().__init__()\n self.pe = nn.Embedding(max_len, d_model)\n\n def forward(self, x):\n # x: (batch_size, seq_length, d_model)\n batch_size, seq_length, _ = x.size()\n positions = torch.arange(seq_length, device=x.device).unsqueeze(0).expand(batch_size, seq_length)\n return self.pe(positions) # (batch_size, seq_length, d_model)\n\n\nclass BST(pl.LightningModule):\n def __init__(\n self,\n num_customers,\n num_articles,\n max_day,\n num_age,\n num_postal,\n sequence_length,\n train_df,\n val_df,\n test_df,\n user2idx,\n postal2idx,\n article_emb_dim=16,\n day_emb_dim=8,\n customer_emb_dim=16,\n age_emb_dim=4,\n postal_emb_dim=4,\n transformer_nhead=2,\n transformer_ff_dim=64, # 
<-- new hyperparam for the feed-forward layer\n num_transformer_layers=1, # <-- how many layers in the custom transformer\n learning_rate=0.0005\n ):\n super().__init__()\n self.save_hyperparameters(ignore=['train_df','val_df','test_df','user2idx','postal2idx'])\n self.learning_rate = learning_rate\n\n # DataFrames + Mappings\n self.train_df = train_df\n self.val_df = val_df\n self.test_df = test_df\n self.user2idx = user2idx\n self.postal2idx = postal2idx\n\n # Embeddings\n self.embeddings_customer = nn.Embedding(num_customers, customer_emb_dim)\n self.embeddings_age = nn.Embedding(num_age, age_emb_dim)\n self.embeddings_postal = nn.Embedding(num_postal, postal_emb_dim)\n\n self.embeddings_article = nn.Embedding(num_articles, article_emb_dim)\n self.embeddings_day = nn.Embedding(max_day + 1, day_emb_dim)\n\n # Sequence dimension\n self.seq_feature_dim = article_emb_dim + day_emb_dim\n self.positional_embedding = PositionalEmbedding(sequence_length, self.seq_feature_dim)\n\n # -------------------------\n # Custom Transformer setup\n # -------------------------\n config = {\n \"embedding_dim\": self.seq_feature_dim,\n \"heads\": transformer_nhead,\n \"transformer_dropout\": 0.2,\n \"dim_feedforward\": transformer_ff_dim,\n \"norm_first\": False,\n \"num_transformer_layers\": num_transformer_layers,\n }\n self.transformer = TransformerEncoder(config)\n\n # Flattened dimension after transformer\n self.transformer_output_dim = sequence_length * self.seq_feature_dim\n\n # User features dimension\n user_feature_dim = customer_emb_dim + age_emb_dim + postal_emb_dim\n\n # Combined dimension\n combined_dim = self.transformer_output_dim + user_feature_dim\n\n # Final regressor\n self.linear = nn.Sequential(\n nn.Linear(combined_dim, 512),\n nn.LeakyReLU(),\n nn.Linear(512, 256),\n nn.LeakyReLU(),\n nn.Linear(256, 1)\n )\n\n self.criterion = nn.MSELoss()\n\n def encode_input(self, batch):\n user_id, articles, days, age, postal_id, regression_label = batch\n\n # Sequence embeddings\n article_embeds = self.embeddings_article(articles) # (B, L, article_emb_dim)\n day_embeds = self.embeddings_day(days) # (B, L, day_emb_dim)\n sequence_features = torch.cat([article_embeds, day_embeds], dim=-1) # (B, L, seq_feature_dim)\n\n # Positional embeddings\n pos_embeds = self.positional_embedding(sequence_features) # (B, L, seq_feature_dim)\n transformer_input = sequence_features + pos_embeds # (B, L, seq_feature_dim)\n\n # Pass through our custom Transformer (B, L, d_model)\n transformer_output = self.transformer(transformer_input) # (B, L, seq_feature_dim)\n\n # Flatten\n transformer_output_flat = transformer_output.reshape(transformer_output.size(0), -1)\n\n # User features\n customer_embed = self.embeddings_customer(user_id)\n age_embed = self.embeddings_age(age)\n postal_embed = self.embeddings_postal(postal_id)\n user_features = torch.cat([customer_embed, age_embed, postal_embed], dim=1)\n\n # Combine\n combined_features = torch.cat([transformer_output_flat, user_features], dim=1)\n return combined_features, regression_label\n\n def forward(self, batch):\n features, target = self.encode_input(batch)\n output = self.linear(features)\n return output.squeeze(), target\n\n def training_step(self, batch, batch_idx):\n output, target = self(batch)\n loss = self.criterion(output, target)\n self.log(\"train_loss\", loss)\n return loss\n\n def validation_step(self, batch, batch_idx):\n output, target = self(batch)\n loss = self.criterion(output, target)\n self.log(\"val_loss\", loss)\n return loss\n\n def 
test_step(self, batch, batch_idx):\n output, target = self(batch)\n loss = self.criterion(output, target)\n self.log(\"test_loss\", loss)\n return loss\n\n def configure_optimizers(self):\n return torch.optim.AdamW(self.parameters(), lr=self.learning_rate)\n\n def setup(self, stage=None):\n if stage == \"fit\" or stage is None:\n self.train_dataset = CustomerDataset(self.train_df, self.user2idx, self.postal2idx)\n self.val_dataset = CustomerDataset(self.val_df, self.user2idx, self.postal2idx)\n if stage == \"test\" or stage is None:\n self.test_dataset = CustomerDataset(self.test_df, self.user2idx, self.postal2idx)\n\n def train_dataloader(self):\n return DataLoader(\n self.train_dataset,\n batch_size=128,\n shuffle=True,\n num_workers=4,\n collate_fn=lambda b: fixed_length_collate_fn(b, sequence_length=self.hparams.sequence_length)\n )\n\n def val_dataloader(self):\n return DataLoader(\n self.val_dataset,\n batch_size=128,\n shuffle=False,\n num_workers=4,\n collate_fn=lambda b: fixed_length_collate_fn(b, sequence_length=self.hparams.sequence_length)\n )\n\n def test_dataloader(self):\n return DataLoader(\n self.test_dataset,\n batch_size=128,\n shuffle=False,\n num_workers=4,\n collate_fn=lambda b: fixed_length_collate_fn(b, sequence_length=self.hparams.sequence_length)\n )\n\n\n\n# Example usage (Train and Test)\n\n\nsequence_length = 8\nmodel = BST(\n num_customers=num_customers,\n num_articles=num_articles,\n max_day=max_day,\n num_age=num_age,\n num_postal=num_postal,\n sequence_length=sequence_length,\n train_df=train_df,\n val_df=val_df,\n test_df=test_df,\n user2idx=user2idx,\n postal2idx=postal2idx,\n article_emb_dim=16,\n day_emb_dim=8,\n customer_emb_dim=16,\n age_emb_dim=4,\n postal_emb_dim=4,\n transformer_nhead=2,\n transformer_ff_dim=64,\n num_transformer_layers=1,\n learning_rate=0.0005\n )\n\ntrainer = pl.Trainer(accelerator=\"gpu\", devices=\"auto\", max_epochs=1)\ntrainer.fit(model)\ntrainer.test(model)\n\n\n# %% [markdown]\n# # ADDING CHURN PREDICTION\n\n# %% [code] {\"execution\":{\"iopub.status.busy\":\"2025-03-19T10:17:30.354118Z\",\"iopub.status.idle\":\"2025-03-19T10:17:30.354604Z\",\"shell.execute_reply\":\"2025-03-19T10:17:30.354460Z\"}}\nimport math\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport pytorch_lightning as pl\nimport pandas as pd\nfrom torch.utils.data import Dataset, DataLoader\nfrom typing import List, Dict, Optional\n\n\n# Custom Transformer classes \n\nclass RMSNorm(nn.Module):\n def __init__(self, dim: int, eps: float = 1e-8) -> None:\n super().__init__()\n self.eps = eps\n self.gain = nn.Parameter(torch.ones(dim))\n\n def forward(self, x: torch.Tensor) -> torch.Tensor:\n norm = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)\n return norm * self.gain\n\nclass CustomMultiheadAttention(nn.Module):\n def __init__(self, embed_dim, num_heads, dropout=0.0, bias=True):\n super().__init__()\n assert embed_dim % num_heads == 0\n self.embed_dim = embed_dim\n self.num_heads = num_heads\n self.head_dim = embed_dim // num_heads\n self.in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=bias)\n self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)\n self.dropout = dropout\n \n def forward(self, query, key, value, attn_mask=None, key_padding_mask=None):\n B, T, _ = query.size()\n qkv = self.in_proj(query)\n q, k, v = qkv.chunk(3, dim=-1)\n\n # Reshape for multi-head\n q = q.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)\n k = k.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)\n v = 
v.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)\n\n # Merge masks if present\n attn_mask = self.merge_masks(attn_mask, key_padding_mask, query)\n\n # Use PyTorch's scaled dot-product attention\n attn_output = F.scaled_dot_product_attention(\n q,\n k,\n v,\n attn_mask=attn_mask,\n dropout_p=self.dropout if self.training else 0.0,\n is_causal=False,\n )\n\n # Reshape back\n attn_output = attn_output.transpose(1, 2).contiguous().view(B, T, self.embed_dim)\n output = self.out_proj(attn_output)\n return output\n\n def merge_masks(\n self,\n attn_mask: Optional[torch.Tensor],\n key_padding_mask: Optional[torch.Tensor],\n query: torch.Tensor,\n ) -> Optional[torch.Tensor]:\n merged_mask = None\n batch_size, seq_len, _ = query.shape\n\n def convert_to_float_mask(mask):\n if mask.dtype == torch.bool:\n return mask.float().masked_fill(mask, float(\"-inf\"))\n return mask\n\n # key_padding_mask -> float mask\n if key_padding_mask is not None:\n # shape (B, T) -> (B, 1, 1, T) -> expand to (B, num_heads, 1, T)\n key_padding_mask = key_padding_mask.view(batch_size, 1, 1, seq_len).expand(\n -1, self.num_heads, -1, -1\n )\n merged_mask = convert_to_float_mask(key_padding_mask)\n\n # attn_mask -> float mask\n if attn_mask is not None:\n if attn_mask.dim() == 2:\n # shape (T, T) -> (B, num_heads, T, T)\n correct_2d_size = (seq_len, seq_len)\n if attn_mask.shape != correct_2d_size:\n raise RuntimeError(\n f\"The shape of the 2D attn_mask is {attn_mask.shape}, \"\n f\"but should be {correct_2d_size}.\"\n )\n attn_mask = attn_mask.unsqueeze(0).expand(\n batch_size, self.num_heads, -1, -1\n )\n elif attn_mask.dim() == 3:\n # shape (B*num_heads, T, T) -> (B, num_heads, T, T)\n correct_3d_size = (batch_size * self.num_heads, seq_len, seq_len)\n if attn_mask.shape != correct_3d_size:\n raise RuntimeError(\n f\"The shape of the 3D attn_mask is {attn_mask.shape}, \"\n f\"but should be {correct_3d_size}.\"\n )\n attn_mask = attn_mask.view(batch_size, self.num_heads, seq_len, seq_len)\n else:\n raise RuntimeError(\n f\"attn_mask's dimension {attn_mask.dim()} is not supported\"\n )\n\n attn_mask = convert_to_float_mask(attn_mask)\n\n if merged_mask is None:\n merged_mask = attn_mask\n else:\n merged_mask = merged_mask + attn_mask\n\n return merged_mask\n\n\nclass TransformerEncoderLayer(nn.Module):\n def __init__(self, config: dict):\n super().__init__()\n embed_dim = config[\"embedding_dim\"]\n num_heads = config.get(\"heads\", 8)\n dropout = config[\"transformer_dropout\"]\n dim_feedforward = config[\"dim_feedforward\"]\n self.norm_first = config.get(\"norm_first\", False)\n\n self.self_attn = CustomMultiheadAttention(embed_dim, num_heads, dropout=dropout)\n self.linear1 = nn.Linear(embed_dim, dim_feedforward)\n self.dropout1 = nn.Dropout(dropout)\n self.linear2 = nn.Linear(dim_feedforward, embed_dim)\n self.dropout2 = nn.Dropout(dropout)\n\n self.norm1 = RMSNorm(embed_dim)\n self.norm2 = RMSNorm(embed_dim)\n self.activation = nn.GELU()\n\n def _sa_block(self, src, attn_mask=None, key_padding_mask=None):\n src2 = self.self_attn(\n src, src, src, key_padding_mask=key_padding_mask, attn_mask=attn_mask\n )\n return self.dropout1(src2)\n\n def _ff_block(self, src):\n src2 = self.linear2(self.dropout2(self.activation(self.linear1(src))))\n return self.dropout2(src2)\n\n def forward(\n self,\n src: torch.Tensor,\n src_key_padding_mask: torch.Tensor = None,\n src_mask: torch.Tensor = None,\n ):\n if self.norm_first:\n # Pre-norm\n src = src + self._sa_block(\n self.norm1(src),\n attn_mask=src_mask,\n 
key_padding_mask=src_key_padding_mask,\n            )\n            src = src + self._ff_block(self.norm2(src))\n        else:\n            # Post-norm\n            src = self.norm1(\n                src\n                + self._sa_block(\n                    src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask\n                )\n            )\n            src = self.norm2(src + self._ff_block(src))\n        return src\n\n\nclass TransformerEncoder(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n        self.encoder = nn.ModuleList(\n            [\n                TransformerEncoderLayer(config)\n                for _ in range(config[\"num_transformer_layers\"])\n            ]\n        )\n\n    def forward(\n        self,\n        src,\n        src_key_padding_mask=None,\n        src_mask=None,\n    ):\n        \"\"\"\n        src: shape (B, T, E)\n        \"\"\"\n        for layer in self.encoder:\n            src = layer(\n                src, src_key_padding_mask=src_key_padding_mask, src_mask=src_mask\n            )\n        return src\n\n\n# Preparing vocabularies and the Dataset (with churn)\n\n\ndef prepare_vocabularies(train_df, val_df, test_df):\n    \"\"\"\n    1) Ensures each df is a pandas DataFrame (for easy indexing).\n    2) Builds dictionaries to map string IDs (customer_id, postal_code) to integer indices.\n    3) Finds max article ID, max day, and max age so we can define embedding sizes.\n    \"\"\"\n\n    def to_pandas_if_polars(df):\n        return df.to_pandas() if not hasattr(df, \"iloc\") else df\n\n    train_pd = to_pandas_if_polars(train_df)\n    val_pd = to_pandas_if_polars(val_df)\n    test_pd = to_pandas_if_polars(test_df)\n\n    # Combine for global vocabularies\n    combined = pd.concat([train_pd, val_pd, test_pd], ignore_index=True)\n\n    # Map string-based customer_id -> integer\n    unique_users = combined['customer_id'].unique()\n    user2idx = {u: i for i, u in enumerate(unique_users)}\n    num_customers = len(user2idx)\n\n    # Map string-based postal_code -> integer\n    unique_postals = combined['postal_code'].unique()\n    postal2idx = {p: i for i, p in enumerate(unique_postals)}\n    num_postal = len(postal2idx)\n\n    # Determine max article ID\n    all_articles = []\n    for df_pd in [train_pd, val_pd, test_pd]:\n        for lst in df_pd['articles_ids_lst']:\n            all_articles.extend(lst)  # 'lst' is a list of ints\n    max_article_id = max(all_articles)\n    num_articles = max_article_id + 1  # for embedding\n\n    # Determine max day\n    all_days = []\n    for df_pd in [train_pd, val_pd, test_pd]:\n        for lst in df_pd['days_before_lst']:\n            all_days.extend(lst)\n    max_day = max(all_days)\n\n    # Determine max age if we treat age as discrete\n    max_age = combined['age'].max()\n    num_age = max_age + 1\n\n    return user2idx, postal2idx, num_customers, num_postal, num_articles, max_day, num_age\n\n\nclass CustomerDataset(Dataset):\n    \"\"\"\n    Expects columns:\n      - customer_id (str)\n      - days_before_lst (list[int])\n      - articles_ids_lst (list[int])\n      - regression_label (float)\n      - classification_label (int) (0 means churn, 1 means not churn)\n      - age (int)\n      - postal_code (str)\n    \"\"\"\n    def __init__(self, df, user2idx: Dict[str,int], postal2idx: Dict[str,int]):\n        # Convert to Pandas if Polars\n        if not hasattr(df, \"iloc\"):\n            df = df.to_pandas()\n        self.data = df\n\n        self.user2idx = user2idx\n        self.postal2idx = postal2idx\n\n    def __len__(self):\n        return len(self.data)\n    \n    def __getitem__(self, idx):\n        row = self.data.iloc[idx]\n\n        # Convert string-based IDs to integer indices\n        user_id = self.user2idx[row['customer_id']]\n        postal_id = self.postal2idx[row['postal_code']]\n\n        age = int(row['age'])  # embedding or numeric\n\n        # articles_ids_lst and days_before_lst are lists of ints\n        articles = torch.tensor(row['articles_ids_lst'], dtype=torch.long)\n        days = torch.tensor(row['days_before_lst'], dtype=torch.long)\n\n        regression_label = 
torch.tensor(float(row['regression_label']), dtype=torch.float)\n classification_label = torch.tensor(int(row['classification_label']), dtype=torch.long)\n\n # Return 7 items now\n return (\n user_id,\n articles,\n days,\n age,\n postal_id,\n regression_label,\n classification_label\n )\n\n\n# Custom collate function for variable-length sequences\ndef fixed_length_collate_fn(batch, sequence_length=8):\n \"\"\"\n Pads or truncates the 'articles' and 'days' sequences to 'sequence_length'.\n Each item in the batch is a tuple:\n (user_id, articles, days, age, postal_id, regression_label, classification_label)\n \"\"\"\n user_ids = []\n article_seqs = []\n day_seqs = []\n ages = []\n postal_ids = []\n reg_labels = []\n class_labels = []\n\n # 1) Unpack\n for item in batch:\n (user_id, articles, days, age, postal_id, reg_label, cls_label) = item\n user_ids.append(user_id)\n article_seqs.append(articles)\n day_seqs.append(days)\n ages.append(age)\n postal_ids.append(postal_id)\n reg_labels.append(reg_label)\n class_labels.append(cls_label)\n\n # 2) Pad or truncate each sequence\n def pad_or_trunc(seq, desired_length):\n length = seq.size(0)\n if length > desired_length:\n return seq[:desired_length]\n elif length < desired_length:\n pad_size = desired_length - length\n pad = torch.zeros(pad_size, dtype=seq.dtype)\n return torch.cat([seq, pad], dim=0)\n else:\n return seq\n\n for i in range(len(article_seqs)):\n article_seqs[i] = pad_or_trunc(article_seqs[i], sequence_length)\n day_seqs[i] = pad_or_trunc(day_seqs[i], sequence_length)\n\n # 3) Stack everything\n user_ids_tensor = torch.tensor(user_ids, dtype=torch.long)\n article_seqs_tensor = torch.stack(article_seqs, dim=0) # shape: (batch_size, sequence_length)\n day_seqs_tensor = torch.stack(day_seqs, dim=0) # shape: (batch_size, sequence_length)\n ages_tensor = torch.tensor(ages, dtype=torch.long)\n postal_ids_tensor = torch.tensor(postal_ids, dtype=torch.long)\n reg_labels_tensor = torch.stack(reg_labels, dim=0) # shape: (batch_size,)\n class_labels_tensor = torch.stack(class_labels, dim=0) # shape: (batch_size,)\n\n return (\n user_ids_tensor,\n article_seqs_tensor,\n day_seqs_tensor,\n ages_tensor,\n postal_ids_tensor,\n reg_labels_tensor,\n class_labels_tensor\n )\n\n\n\n# BST model with optional Churn Head\n\n\nclass PositionalEmbedding(nn.Module):\n \"\"\"\n Simple positional embedding that learns a unique embedding per position (0..max_len-1).\n \"\"\"\n def __init__(self, max_len, d_model):\n super().__init__()\n self.pe = nn.Embedding(max_len, d_model)\n\n def forward(self, x):\n # x: (batch_size, seq_length, d_model)\n batch_size, seq_length, _ = x.size()\n positions = torch.arange(seq_length, device=x.device).unsqueeze(0).expand(batch_size, seq_length)\n return self.pe(positions) # (batch_size, seq_length, d_model)\n\n\nclass BST(pl.LightningModule):\n def __init__(\n self,\n num_customers,\n num_articles,\n max_day,\n num_age,\n num_postal,\n sequence_length,\n train_df,\n val_df,\n test_df,\n user2idx,\n postal2idx,\n # embedding dims\n article_emb_dim=16,\n day_emb_dim=8,\n customer_emb_dim=16,\n age_emb_dim=4,\n postal_emb_dim=4,\n # transformer config\n transformer_nhead=2,\n transformer_ff_dim=64,\n num_transformer_layers=1,\n # multi-task config\n predict_churn=False, # <-- Flag to enable or disable churn prediction\n # training\n learning_rate=0.0005\n ):\n super().__init__()\n self.save_hyperparameters(ignore=['train_df','val_df','test_df','user2idx','postal2idx'])\n self.learning_rate = learning_rate\n self.predict_churn = 
predict_churn # store the flag\n\n # DataFrames + Mappings\n self.train_df = train_df\n self.val_df = val_df\n self.test_df = test_df\n self.user2idx = user2idx\n self.postal2idx = postal2idx\n\n # Embeddings\n self.embeddings_customer = nn.Embedding(num_customers, customer_emb_dim)\n self.embeddings_age = nn.Embedding(num_age, age_emb_dim)\n self.embeddings_postal = nn.Embedding(num_postal, postal_emb_dim)\n\n self.embeddings_article = nn.Embedding(num_articles, article_emb_dim)\n self.embeddings_day = nn.Embedding(max_day + 1, day_emb_dim)\n\n # Sequence dimension\n self.seq_feature_dim = article_emb_dim + day_emb_dim\n self.positional_embedding = PositionalEmbedding(sequence_length, self.seq_feature_dim)\n\n # -------------------------\n # Custom Transformer setup\n # -------------------------\n config = {\n \"embedding_dim\": self.seq_feature_dim,\n \"heads\": transformer_nhead,\n \"transformer_dropout\": 0.2,\n \"dim_feedforward\": transformer_ff_dim,\n \"norm_first\": False,\n \"num_transformer_layers\": num_transformer_layers,\n }\n self.transformer = TransformerEncoder(config)\n\n # Flattened dimension after transformer\n self.transformer_output_dim = sequence_length * self.seq_feature_dim\n\n # User features dimension\n user_feature_dim = customer_emb_dim + age_emb_dim + postal_emb_dim\n\n # Combined dimension\n combined_dim = self.transformer_output_dim + user_feature_dim\n\n # -------------------------\n # Separate heads:\n # 1) Regression (always)\n # 2) Classification (optional)\n # -------------------------\n self.regressor_head = nn.Sequential(\n nn.Linear(combined_dim, 512),\n nn.LeakyReLU(),\n nn.Linear(512, 256),\n nn.LeakyReLU(),\n nn.Linear(256, 1)\n )\n\n if self.predict_churn:\n # Binary classification head (churn=0, not-churn=1)\n self.classifier_head = nn.Sequential(\n nn.Linear(combined_dim, 128),\n nn.LeakyReLU(),\n nn.Linear(128, 1) # single logit for BCE\n )\n self.classification_criterion = nn.BCEWithLogitsLoss()\n\n # MSE for regression\n self.regression_criterion = nn.MSELoss()\n\n def encode_input(self, batch):\n # We now receive 7 items instead of 6\n (user_id, articles, days, age, postal_id, regression_label, classification_label) = batch\n\n # Sequence embeddings\n article_embeds = self.embeddings_article(articles) # (B, L, article_emb_dim)\n day_embeds = self.embeddings_day(days) # (B, L, day_emb_dim)\n sequence_features = torch.cat([article_embeds, day_embeds], dim=-1) # (B, L, seq_feature_dim)\n\n # Positional embeddings\n pos_embeds = self.positional_embedding(sequence_features) # (B, L, seq_feature_dim)\n transformer_input = sequence_features + pos_embeds # (B, L, seq_feature_dim)\n\n # Pass through our custom Transformer (B, L, d_model)\n transformer_output = self.transformer(transformer_input) # (B, L, seq_feature_dim)\n\n # Flatten\n transformer_output_flat = transformer_output.reshape(transformer_output.size(0), -1)\n\n # User features\n customer_embed = self.embeddings_customer(user_id)\n age_embed = self.embeddings_age(age)\n postal_embed = self.embeddings_postal(postal_id)\n user_features = torch.cat([customer_embed, age_embed, postal_embed], dim=1)\n\n # Combine\n combined_features = torch.cat([transformer_output_flat, user_features], dim=1)\n return combined_features, regression_label, classification_label\n\n def forward(self, batch):\n features, reg_label, class_label = self.encode_input(batch)\n\n # 1) Regression output\n reg_output = self.regressor_head(features).squeeze(dim=-1)\n\n # 2) Classification output (only if 
predict_churn=True)\n if self.predict_churn:\n class_output = self.classifier_head(features).squeeze(dim=-1) # logit\n else:\n class_output = None\n\n return reg_output, class_output, reg_label, class_label\n\n def training_step(self, batch, batch_idx):\n reg_output, class_output, reg_label, class_label = self(batch)\n\n # Always compute regression loss\n reg_loss = self.regression_criterion(reg_output, reg_label)\n loss = reg_loss # default if churn is off\n\n self.log(\"train_reg_loss\", reg_loss)\n\n # If churn is enabled, compute classification loss\n if self.predict_churn:\n # BCEWithLogitsLoss expects float targets of 0 or 1\n class_label = class_label.float()\n class_loss = self.classification_criterion(class_output, class_label)\n loss = reg_loss + class_loss # simple combined loss\n self.log(\"train_class_loss\", class_loss)\n\n self.log(\"train_loss\", loss)\n return loss\n\n def validation_step(self, batch, batch_idx):\n reg_output, class_output, reg_label, class_label = self(batch)\n\n # Regression\n reg_loss = self.regression_criterion(reg_output, reg_label)\n self.log(\"val_reg_loss\", reg_loss)\n\n loss = reg_loss\n\n # Classification\n if self.predict_churn:\n class_label = class_label.float()\n class_loss = self.classification_criterion(class_output, class_label)\n self.log(\"val_class_loss\", class_loss)\n loss = reg_loss + class_loss\n\n self.log(\"val_loss\", loss)\n return loss\n\n def test_step(self, batch, batch_idx):\n reg_output, class_output, reg_label, class_label = self(batch)\n\n # Regression\n reg_loss = self.regression_criterion(reg_output, reg_label)\n self.log(\"test_reg_loss\", reg_loss)\n\n loss = reg_loss\n\n # Classification\n if self.predict_churn:\n class_label = class_label.float()\n class_loss = self.classification_criterion(class_output, class_label)\n self.log(\"test_class_loss\", class_loss)\n loss = reg_loss + class_loss\n\n self.log(\"test_loss\", loss)\n return loss\n\n def configure_optimizers(self):\n return torch.optim.AdamW(self.parameters(), lr=self.learning_rate)\n\n def setup(self, stage=None):\n if stage == \"fit\" or stage is None:\n self.train_dataset = CustomerDataset(self.train_df, self.user2idx, self.postal2idx)\n self.val_dataset = CustomerDataset(self.val_df, self.user2idx, self.postal2idx)\n if stage == \"test\" or stage is None:\n self.test_dataset = CustomerDataset(self.test_df, self.user2idx, self.postal2idx)\n\n def train_dataloader(self):\n return DataLoader(\n self.train_dataset,\n batch_size=128,\n shuffle=True,\n num_workers=4,\n collate_fn=lambda b: fixed_length_collate_fn(b, sequence_length=self.hparams.sequence_length)\n )\n\n def val_dataloader(self):\n return DataLoader(\n self.val_dataset,\n batch_size=128,\n shuffle=False,\n num_workers=4,\n collate_fn=lambda b: fixed_length_collate_fn(b, sequence_length=self.hparams.sequence_length)\n )\n\n def test_dataloader(self):\n return DataLoader(\n self.test_dataset,\n batch_size=128,\n shuffle=False,\n num_workers=4,\n collate_fn=lambda b: fixed_length_collate_fn(b, sequence_length=self.hparams.sequence_length)\n )\n\n\n# train and test\n\n# do regression only\nmodel = BST(\n num_customers=num_customers,\n num_articles=num_articles,\n max_day=max_day,\n num_age=num_age,\n num_postal=num_postal,\n sequence_length=sequence_length,\n train_df=train_df,\n val_df=val_df,\n test_df=test_df,\n user2idx=user2idx,\n postal2idx=postal2idx,\n article_emb_dim=16,\n day_emb_dim=8,\n customer_emb_dim=16,\n age_emb_dim=4,\n postal_emb_dim=4,\n transformer_nhead=2,\n 
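# transformer_nhead must evenly divide seq_feature_dim (article_emb_dim + day_emb_dim = 24 here)\n    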
transformer_ff_dim=64,\n num_transformer_layers=1,\n predict_churn=False, # <--- Disable churn\n learning_rate=0.0005\n)\n\n\n\n# also predict churn (multi-task: regression + classification):\nmodel = BST(\n num_customers=num_customers,\n num_articles=num_articles,\n max_day=max_day,\n num_age=num_age,\n num_postal=num_postal,\n sequence_length=sequence_length,\n train_df=train_df,\n val_df=val_df,\n test_df=test_df,\n user2idx=user2idx,\n postal2idx=postal2idx,\n article_emb_dim=16,\n day_emb_dim=8,\n customer_emb_dim=16,\n age_emb_dim=4,\n postal_emb_dim=4,\n transformer_nhead=2,\n transformer_ff_dim=64,\n num_transformer_layers=1,\n predict_churn=True, # <--- Enable churn\n learning_rate=0.0005\n)\n\ntrainer = pl.Trainer(accelerator=\"gpu\", devices=\"auto\", max_epochs=1)\ntrainer.fit(model)\ntrainer.test(model)\n\n\n\n# %% [markdown]\n# # REMOVED ADDITIVE POSITIONAL ENCODING\n\n# %% [code] {\"execution\":{\"iopub.status.busy\":\"2025-03-19T10:17:30.358056Z\",\"iopub.status.idle\":\"2025-03-19T10:17:30.358379Z\",\"shell.execute_reply\":\"2025-03-19T10:17:30.358214Z\"}}\nimport math\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport pytorch_lightning as pl\nimport pandas as pd\nfrom torch.utils.data import Dataset, DataLoader\nfrom typing import List, Dict, Optional\n\n\n# custom Transformer class\n\nclass RMSNorm(nn.Module):\n def __init__(self, dim: int, eps: float = 1e-8) -> None:\n super().__init__()\n self.eps = eps\n self.gain = nn.Parameter(torch.ones(dim))\n\n def forward(self, x: torch.Tensor) -> torch.Tensor:\n norm = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)\n return norm * self.gain\n\nclass CustomMultiheadAttention(nn.Module):\n def __init__(self, embed_dim, num_heads, dropout=0.0, bias=True):\n super().__init__()\n assert embed_dim % num_heads == 0\n self.embed_dim = embed_dim\n self.num_heads = num_heads\n self.head_dim = embed_dim // num_heads\n self.in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=bias)\n self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)\n self.dropout = dropout\n \n def forward(self, query, key, value, attn_mask=None, key_padding_mask=None):\n B, T, _ = query.size()\n qkv = self.in_proj(query)\n q, k, v = qkv.chunk(3, dim=-1)\n\n # Reshape for multi-head\n q = q.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)\n k = k.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)\n v = v.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)\n\n # Merge masks if present\n attn_mask = self.merge_masks(attn_mask, key_padding_mask, query)\n\n # Use PyTorch's scaled dot-product attention\n attn_output = F.scaled_dot_product_attention(\n q,\n k,\n v,\n attn_mask=attn_mask,\n dropout_p=self.dropout if self.training else 0.0,\n is_causal=False,\n )\n\n # Reshape back\n attn_output = attn_output.transpose(1, 2).contiguous().view(B, T, self.embed_dim)\n output = self.out_proj(attn_output)\n return output\n\n def merge_masks(\n self,\n attn_mask: Optional[torch.Tensor],\n key_padding_mask: Optional[torch.Tensor],\n query: torch.Tensor,\n ) -> Optional[torch.Tensor]:\n merged_mask = None\n batch_size, seq_len, _ = query.shape\n\n def convert_to_float_mask(mask):\n if mask.dtype == torch.bool:\n return mask.float().masked_fill(mask, float(\"-inf\"))\n return mask\n\n # key_padding_mask -> float mask\n if key_padding_mask is not None:\n # shape (B, T) -> (B, 1, 1, T) -> expand to (B, num_heads, 1, T)\n key_padding_mask = key_padding_mask.view(batch_size, 1, 1, seq_len).expand(\n -1, self.num_heads, 
-1, -1\n )\n merged_mask = convert_to_float_mask(key_padding_mask)\n\n # attn_mask -> float mask\n if attn_mask is not None:\n if attn_mask.dim() == 2:\n # shape (T, T) -> (B, num_heads, T, T)\n correct_2d_size = (seq_len, seq_len)\n if attn_mask.shape != correct_2d_size:\n raise RuntimeError(\n f\"The shape of the 2D attn_mask is {attn_mask.shape}, \"\n f\"but should be {correct_2d_size}.\"\n )\n attn_mask = attn_mask.unsqueeze(0).expand(\n batch_size, self.num_heads, -1, -1\n )\n elif attn_mask.dim() == 3:\n # shape (B*num_heads, T, T) -> (B, num_heads, T, T)\n correct_3d_size = (batch_size * self.num_heads, seq_len, seq_len)\n if attn_mask.shape != correct_3d_size:\n raise RuntimeError(\n f\"The shape of the 3D attn_mask is {attn_mask.shape}, \"\n f\"but should be {correct_3d_size}.\"\n )\n attn_mask = attn_mask.view(batch_size, self.num_heads, seq_len, seq_len)\n else:\n raise RuntimeError(\n f\"attn_mask's dimension {attn_mask.dim()} is not supported\"\n )\n\n attn_mask = convert_to_float_mask(attn_mask)\n\n if merged_mask is None:\n merged_mask = attn_mask\n else:\n merged_mask = merged_mask + attn_mask\n\n return merged_mask\n\n\nclass TransformerEncoderLayer(nn.Module):\n def __init__(self, config: dict):\n super().__init__()\n embed_dim = config[\"embedding_dim\"]\n num_heads = config.get(\"heads\", 8)\n dropout = config[\"transformer_dropout\"]\n dim_feedforward = config[\"dim_feedforward\"]\n self.norm_first = config.get(\"norm_first\", False)\n\n self.self_attn = CustomMultiheadAttention(embed_dim, num_heads, dropout=dropout)\n self.linear1 = nn.Linear(embed_dim, dim_feedforward)\n self.dropout1 = nn.Dropout(dropout)\n self.linear2 = nn.Linear(dim_feedforward, embed_dim)\n self.dropout2 = nn.Dropout(dropout)\n\n self.norm1 = RMSNorm(embed_dim)\n self.norm2 = RMSNorm(embed_dim)\n self.activation = nn.GELU()\n\n def _sa_block(self, src, attn_mask=None, key_padding_mask=None):\n src2 = self.self_attn(\n src, src, src, key_padding_mask=key_padding_mask, attn_mask=attn_mask\n )\n return self.dropout1(src2)\n\n def _ff_block(self, src):\n src2 = self.linear2(self.dropout2(self.activation(self.linear1(src))))\n return self.dropout2(src2)\n\n def forward(\n self,\n src: torch.Tensor,\n src_key_padding_mask: torch.Tensor = None,\n src_mask: torch.Tensor = None,\n ):\n if self.norm_first:\n # Pre-norm\n src = src + self._sa_block(\n self.norm1(src),\n attn_mask=src_mask,\n key_padding_mask=src_key_padding_mask,\n )\n src = src + self._ff_block(self.norm2(src))\n else:\n # Post-norm\n src = self.norm1(\n src\n + self._sa_block(\n src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask\n )\n )\n src = self.norm2(src + self._ff_block(src))\n return src\n\n\nclass TransformerEncoder(nn.Module):\n def __init__(self, config):\n super().__init__()\n self.encoder = nn.ModuleList(\n [\n TransformerEncoderLayer(config)\n for _ in range(config[\"num_transformer_layers\"])\n ]\n )\n\n def forward(\n self,\n src,\n src_key_padding_mask=None,\n src_mask=None,\n ):\n \"\"\"\n src: shape (B, T, E)\n \"\"\"\n for layer in self.encoder:\n src = layer(\n src, src_key_padding_mask=src_key_padding_mask, src_mask=src_mask\n )\n return src\n\n\n\n# Preparing vocabularies and the Dataset (with churn)\n\n\ndef prepare_vocabularies(train_df, val_df, test_df):\n \"\"\"\n 1) Ensures each df is a pandas DataFrame (for easy indexing).\n 2) Builds dictionaries to map string IDs (customer_id, postal_code) to integer indices.\n 3) Finds max article ID, max day, and max age so we can define embedding sizes.\n 
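4) Returns the id mappings plus the vocabulary sizes the BST constructor needs.\n    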
\"\"\"\n\n def to_pandas_if_polars(df):\n return df.to_pandas() if not hasattr(df, \"iloc\") else df\n\n train_pd = to_pandas_if_polars(train_df)\n val_pd = to_pandas_if_polars(val_df)\n test_pd = to_pandas_if_polars(test_df)\n\n # Combine for global vocabularies\n combined = pd.concat([train_pd, val_pd, test_pd], ignore_index=True)\n\n # Map string-based customer_id -> integer\n unique_users = combined['customer_id'].unique()\n user2idx = {u: i for i, u in enumerate(unique_users)}\n num_customers = len(user2idx)\n\n # Map string-based postal_code -> integer\n unique_postals = combined['postal_code'].unique()\n postal2idx = {p: i for i, p in enumerate(unique_postals)}\n num_postal = len(postal2idx)\n\n # Determine max article ID\n all_articles = []\n for df_pd in [train_pd, val_pd, test_pd]:\n for lst in df_pd['articles_ids_lst']:\n all_articles.extend(lst) # 'lst' is a list of ints\n max_article_id = max(all_articles)\n num_articles = max_article_id + 1 # for embedding\n\n # Determine max day\n all_days = []\n for df_pd in [train_pd, val_pd, test_pd]:\n for lst in df_pd['days_before_lst']:\n all_days.extend(lst)\n max_day = max(all_days)\n\n # Determine max age if we treat age as discrete\n max_age = combined['age'].max()\n num_age = max_age + 1\n\n return user2idx, postal2idx, num_customers, num_postal, num_articles, max_day, num_age\n\n\nclass CustomerDataset(Dataset):\n \"\"\"\n Expects columns:\n - customer_id (str)\n - days_before_lst (list[int])\n - articles_ids_lst (list[int])\n - regression_label (float)\n - classification_label (int) (0 means churn, 1 means not churn)\n - age (int)\n - postal_code (str)\n \"\"\"\n def __init__(self, df, user2idx: Dict[str,int], postal2idx: Dict[str,int]):\n # Convert to Pandas if Polars\n if not hasattr(df, \"iloc\"):\n df = df.to_pandas()\n self.data = df\n\n self.user2idx = user2idx\n self.postal2idx = postal2idx\n\n def __len__(self):\n return len(self.data)\n \n def __getitem__(self, idx):\n row = self.data.iloc[idx]\n\n # Convert string-based IDs to integer indices\n user_id = self.user2idx[row['customer_id']]\n postal_id = self.postal2idx[row['postal_code']]\n\n age = int(row['age']) # embedding or numeric\n\n # articles_ids_lst and days_before_lst are lists of ints\n articles = torch.tensor(row['articles_ids_lst'], dtype=torch.long)\n days = torch.tensor(row['days_before_lst'], dtype=torch.long)\n\n regression_label = torch.tensor(float(row['regression_label']), dtype=torch.float)\n classification_label = torch.tensor(int(row['classification_label']), dtype=torch.long)\n\n return (\n user_id,\n articles,\n days,\n age,\n postal_id,\n regression_label,\n classification_label\n )\n\n\n# Custom collate function for variable-length sequences\ndef fixed_length_collate_fn(batch, sequence_length=8):\n \"\"\"\n Pads or truncates the 'articles' and 'days' sequences to 'sequence_length'.\n Each item in the batch is a tuple:\n (user_id, articles, days, age, postal_id, regression_label, classification_label)\n \"\"\"\n user_ids = []\n article_seqs = []\n day_seqs = []\n ages = []\n postal_ids = []\n reg_labels = []\n class_labels = []\n\n # 1) Unpack\n for item in batch:\n (user_id, articles, days, age, postal_id, reg_label, cls_label) = item\n user_ids.append(user_id)\n article_seqs.append(articles)\n day_seqs.append(days)\n ages.append(age)\n postal_ids.append(postal_id)\n reg_labels.append(reg_label)\n class_labels.append(cls_label)\n\n # 2) Pad or truncate each sequence\n def pad_or_trunc(seq, desired_length):\n length = seq.size(0)\n if length > 
desired_length:\n return seq[:desired_length]\n elif length < desired_length:\n pad_size = desired_length - length\n pad = torch.zeros(pad_size, dtype=seq.dtype)\n return torch.cat([seq, pad], dim=0)\n else:\n return seq\n\n for i in range(len(article_seqs)):\n article_seqs[i] = pad_or_trunc(article_seqs[i], sequence_length)\n day_seqs[i] = pad_or_trunc(day_seqs[i], sequence_length)\n\n # 3) Stack everything\n user_ids_tensor = torch.tensor(user_ids, dtype=torch.long)\n article_seqs_tensor = torch.stack(article_seqs, dim=0) # shape: (batch_size, sequence_length)\n day_seqs_tensor = torch.stack(day_seqs, dim=0) # shape: (batch_size, sequence_length)\n ages_tensor = torch.tensor(ages, dtype=torch.long)\n postal_ids_tensor = torch.tensor(postal_ids, dtype=torch.long)\n reg_labels_tensor = torch.stack(reg_labels, dim=0) # shape: (batch_size,)\n class_labels_tensor = torch.stack(class_labels, dim=0) # shape: (batch_size,)\n\n return (\n user_ids_tensor,\n article_seqs_tensor,\n day_seqs_tensor,\n ages_tensor,\n postal_ids_tensor,\n reg_labels_tensor,\n class_labels_tensor\n )\n\n\n\n# BST model WITHOUT separate pos. embed\n\n\nclass BST(pl.LightningModule):\n def __init__(\n self,\n num_customers,\n num_articles,\n max_day,\n num_age,\n num_postal,\n sequence_length,\n train_df,\n val_df,\n test_df,\n user2idx,\n postal2idx,\n # embedding dims\n article_emb_dim=16,\n day_emb_dim=8,\n customer_emb_dim=16,\n age_emb_dim=4,\n postal_emb_dim=4,\n # transformer config\n transformer_nhead=2,\n transformer_ff_dim=64,\n num_transformer_layers=1,\n # multi-task config\n predict_churn=False, # <-- Flag to enable or disable churn prediction\n # training\n learning_rate=0.0005\n ):\n super().__init__()\n self.save_hyperparameters(ignore=['train_df','val_df','test_df','user2idx','postal2idx'])\n self.learning_rate = learning_rate\n self.predict_churn = predict_churn # store the flag\n\n # DataFrames + Mappings\n self.train_df = train_df\n self.val_df = val_df\n self.test_df = test_df\n self.user2idx = user2idx\n self.postal2idx = postal2idx\n\n # Embeddings\n self.embeddings_customer = nn.Embedding(num_customers, customer_emb_dim)\n self.embeddings_age = nn.Embedding(num_age, age_emb_dim)\n self.embeddings_postal = nn.Embedding(num_postal, postal_emb_dim)\n\n self.embeddings_article = nn.Embedding(num_articles, article_emb_dim)\n self.embeddings_day = nn.Embedding(max_day + 1, day_emb_dim)\n\n # We treat \"day_embeds\" as the positional/time feature. 
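The days_before values already encode order and recency, which is why the additive positional embedding is dropped.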
\n # So final dimension of each timestep = article_emb_dim + day_emb_dim\n self.seq_feature_dim = article_emb_dim + day_emb_dim\n\n # -------------------------\n # Custom Transformer setup\n # -------------------------\n config = {\n \"embedding_dim\": self.seq_feature_dim,\n \"heads\": transformer_nhead,\n \"transformer_dropout\": 0.2,\n \"dim_feedforward\": transformer_ff_dim,\n \"norm_first\": False,\n \"num_transformer_layers\": num_transformer_layers,\n }\n self.transformer = TransformerEncoder(config)\n\n # Flattened dimension after transformer\n self.transformer_output_dim = sequence_length * self.seq_feature_dim\n\n # User features dimension\n user_feature_dim = customer_emb_dim + age_emb_dim + postal_emb_dim\n\n # Combined dimension\n combined_dim = self.transformer_output_dim + user_feature_dim\n\n # -------------------------\n # Separate heads:\n # 1) Regression (always)\n # 2) Classification (optional)\n # -------------------------\n self.regressor_head = nn.Sequential(\n nn.Linear(combined_dim, 512),\n nn.LeakyReLU(),\n nn.Linear(512, 256),\n nn.LeakyReLU(),\n nn.Linear(256, 1)\n )\n\n if self.predict_churn:\n # Binary classification head (churn=0, not-churn=1)\n self.classifier_head = nn.Sequential(\n nn.Linear(combined_dim, 128),\n nn.LeakyReLU(),\n nn.Linear(128, 1) # single logit for BCE\n )\n self.classification_criterion = nn.BCEWithLogitsLoss()\n\n # MSE for regression\n self.regression_criterion = nn.MSELoss()\n\n def encode_input(self, batch):\n \"\"\"\n Returns:\n combined_features: (B, combined_dim)\n regression_label: (B,)\n classification_label: (B,)\n \"\"\"\n (user_id, articles, days, age, postal_id, regression_label, classification_label) = batch\n\n # Sequence embeddings\n article_embeds = self.embeddings_article(articles) # (B, L, article_emb_dim)\n day_embeds = self.embeddings_day(days) # (B, L, day_emb_dim)\n\n # No separate positional embedding. 
\n # We treat 'days' as our time/pos feature.\n transformer_input = torch.cat([article_embeds, day_embeds], dim=-1) \n # shape: (B, L, seq_feature_dim)\n\n # Pass through our custom Transformer\n transformer_output = self.transformer(transformer_input) \n # shape: (B, L, seq_feature_dim)\n\n # Flatten the sequence dimension\n transformer_output_flat = transformer_output.reshape(transformer_output.size(0), -1)\n\n # User-level features\n customer_embed = self.embeddings_customer(user_id)\n age_embed = self.embeddings_age(age)\n postal_embed = self.embeddings_postal(postal_id)\n user_features = torch.cat([customer_embed, age_embed, postal_embed], dim=1)\n\n # Combine sequence output + user features\n combined_features = torch.cat([transformer_output_flat, user_features], dim=1)\n\n return combined_features, regression_label, classification_label\n\n def forward(self, batch):\n features, reg_label, class_label = self.encode_input(batch)\n\n # 1) Regression output\n reg_output = self.regressor_head(features).squeeze(dim=-1)\n\n # 2) Classification output (only if predict_churn=True)\n if self.predict_churn:\n class_output = self.classifier_head(features).squeeze(dim=-1) # logit\n else:\n class_output = None\n\n return reg_output, class_output, reg_label, class_label\n\n def training_step(self, batch, batch_idx):\n reg_output, class_output, reg_label, class_label = self(batch)\n\n # Always compute regression loss\n reg_loss = self.regression_criterion(reg_output, reg_label)\n loss = reg_loss # default if churn is off\n self.log(\"train_reg_loss\", reg_loss)\n\n # If churn is enabled, compute classification loss\n if self.predict_churn:\n class_label = class_label.float() # BCE expects float\n class_loss = self.classification_criterion(class_output, class_label)\n loss = reg_loss + class_loss # combine them\n self.log(\"train_class_loss\", class_loss)\n\n self.log(\"train_loss\", loss)\n return loss\n\n def validation_step(self, batch, batch_idx):\n reg_output, class_output, reg_label, class_label = self(batch)\n\n # Regression\n reg_loss = self.regression_criterion(reg_output, reg_label)\n self.log(\"val_reg_loss\", reg_loss)\n loss = reg_loss\n\n # Classification\n if self.predict_churn:\n class_label = class_label.float()\n class_loss = self.classification_criterion(class_output, class_label)\n self.log(\"val_class_loss\", class_loss)\n loss = reg_loss + class_loss\n\n self.log(\"val_loss\", loss)\n return loss\n\n def test_step(self, batch, batch_idx):\n reg_output, class_output, reg_label, class_label = self(batch)\n\n # Regression\n reg_loss = self.regression_criterion(reg_output, reg_label)\n self.log(\"test_reg_loss\", reg_loss)\n loss = reg_loss\n\n # Classification\n if self.predict_churn:\n class_label = class_label.float()\n class_loss = self.classification_criterion(class_output, class_label)\n self.log(\"test_class_loss\", class_loss)\n loss = reg_loss + class_loss\n\n self.log(\"test_loss\", loss)\n return loss\n\n def configure_optimizers(self):\n return torch.optim.AdamW(self.parameters(), lr=self.learning_rate)\n\n def setup(self, stage=None):\n if stage == \"fit\" or stage is None:\n self.train_dataset = CustomerDataset(self.train_df, self.user2idx, self.postal2idx)\n self.val_dataset = CustomerDataset(self.val_df, self.user2idx, self.postal2idx)\n if stage == \"test\" or stage is None:\n self.test_dataset = CustomerDataset(self.test_df, self.user2idx, self.postal2idx)\n\n def train_dataloader(self):\n return DataLoader(\n self.train_dataset,\n batch_size=128,\n 
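# NOTE: the lambda collate_fn relies on fork-started workers (the Linux default); wrap fixed_length_collate_fn in functools.partial if workers are spawned\n            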
shuffle=True,\n num_workers=4,\n collate_fn=lambda b: fixed_length_collate_fn(b, sequence_length=self.hparams.sequence_length)\n )\n\n def val_dataloader(self):\n return DataLoader(\n self.val_dataset,\n batch_size=128,\n shuffle=False,\n num_workers=4,\n collate_fn=lambda b: fixed_length_collate_fn(b, sequence_length=self.hparams.sequence_length)\n )\n\n def test_dataloader(self):\n return DataLoader(\n self.test_dataset,\n batch_size=128,\n shuffle=False,\n num_workers=4,\n collate_fn=lambda b: fixed_length_collate_fn(b, sequence_length=self.hparams.sequence_length)\n )\n\n\n\n# Train and Test\n\n\nmodel = BST(\n num_customers=num_customers,\n num_articles=num_articles,\n max_day=max_day,\n num_age=num_age,\n num_postal=num_postal,\n sequence_length=sequence_length,\n train_df=train_df,\n val_df=val_df,\n test_df=test_df,\n user2idx=user2idx,\n postal2idx=postal2idx,\n article_emb_dim=16,\n day_emb_dim=8,\n customer_emb_dim=16,\n age_emb_dim=4,\n postal_emb_dim=4,\n transformer_nhead=2,\n transformer_ff_dim=64,\n num_transformer_layers=1,\n predict_churn=True, # <--- Enable churn\n learning_rate=0.0005\n)\n\ntrainer = pl.Trainer(accelerator=\"gpu\", devices=\"auto\", max_epochs=1)\ntrainer.fit(model)\ntrainer.test(model)\n\n\n\n# %% [markdown]\n# # REMOVED customer_id\n\n# %% [code] {\"execution\":{\"iopub.status.busy\":\"2025-03-19T10:17:30.360894Z\",\"iopub.status.idle\":\"2025-03-19T10:17:30.361166Z\",\"shell.execute_reply\":\"2025-03-19T10:17:30.361049Z\"}}\nimport math\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport pytorch_lightning as pl\nimport pandas as pd\nfrom torch.utils.data import Dataset, DataLoader\nfrom typing import List, Dict, Optional\n\n\n# Custom Transformer Classes\n\nclass RMSNorm(nn.Module):\n def __init__(self, dim: int, eps: float = 1e-8) -> None:\n super().__init__()\n self.eps = eps\n self.gain = nn.Parameter(torch.ones(dim))\n\n def forward(self, x: torch.Tensor) -> torch.Tensor:\n norm = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)\n return norm * self.gain\n\nclass CustomMultiheadAttention(nn.Module):\n def __init__(self, embed_dim, num_heads, dropout=0.0, bias=True):\n super().__init__()\n assert embed_dim % num_heads == 0\n self.embed_dim = embed_dim\n self.num_heads = num_heads\n self.head_dim = embed_dim // num_heads\n self.in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=bias)\n self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)\n self.dropout = dropout\n \n def forward(self, query, key, value, attn_mask=None, key_padding_mask=None):\n B, T, _ = query.size()\n qkv = self.in_proj(query)\n q, k, v = qkv.chunk(3, dim=-1)\n\n # Reshape for multi-head attention\n q = q.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)\n k = k.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)\n v = v.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)\n\n attn_mask = self.merge_masks(attn_mask, key_padding_mask, query)\n attn_output = F.scaled_dot_product_attention(\n q, k, v,\n attn_mask=attn_mask,\n dropout_p=self.dropout if self.training else 0.0,\n is_causal=False,\n )\n\n attn_output = attn_output.transpose(1, 2).contiguous().view(B, T, self.embed_dim)\n output = self.out_proj(attn_output)\n return output\n\n def merge_masks(self,\n attn_mask: Optional[torch.Tensor],\n key_padding_mask: Optional[torch.Tensor],\n query: torch.Tensor) -> Optional[torch.Tensor]:\n merged_mask = None\n batch_size, seq_len, _ = query.shape\n\n def convert_to_float_mask(mask):\n if mask.dtype == torch.bool:\n 
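# True entries mean \"do not attend\": converting to float and filling with -inf\n                # yields an additive bias that zeroes those positions after softmax.\n                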
return mask.float().masked_fill(mask, float(\"-inf\"))\n            return mask\n\n        if key_padding_mask is not None:\n            key_padding_mask = key_padding_mask.view(batch_size, 1, 1, seq_len).expand(\n                -1, self.num_heads, -1, -1)\n            merged_mask = convert_to_float_mask(key_padding_mask)\n\n        if attn_mask is not None:\n            if attn_mask.dim() == 2:\n                correct_2d_size = (seq_len, seq_len)\n                if attn_mask.shape != correct_2d_size:\n                    raise RuntimeError(f\"The shape of the 2D attn_mask is {attn_mask.shape}, \"\n                                       f\"but should be {correct_2d_size}.\")\n                attn_mask = attn_mask.unsqueeze(0).expand(batch_size, self.num_heads, -1, -1)\n            elif attn_mask.dim() == 3:\n                correct_3d_size = (batch_size * self.num_heads, seq_len, seq_len)\n                if attn_mask.shape != correct_3d_size:\n                    raise RuntimeError(f\"The shape of the 3D attn_mask is {attn_mask.shape}, \"\n                                       f\"but should be {correct_3d_size}.\")\n                attn_mask = attn_mask.view(batch_size, self.num_heads, seq_len, seq_len)\n            else:\n                raise RuntimeError(f\"attn_mask's dimension {attn_mask.dim()} is not supported\")\n            attn_mask = convert_to_float_mask(attn_mask)\n            if merged_mask is None:\n                merged_mask = attn_mask\n            else:\n                merged_mask = merged_mask + attn_mask\n        return merged_mask\n\nclass TransformerEncoderLayer(nn.Module):\n    def __init__(self, config: dict):\n        super().__init__()\n        embed_dim = config[\"embedding_dim\"]\n        num_heads = config.get(\"heads\", 8)\n        dropout = config[\"transformer_dropout\"]\n        dim_feedforward = config[\"dim_feedforward\"]\n        self.norm_first = config.get(\"norm_first\", False)\n\n        self.self_attn = CustomMultiheadAttention(embed_dim, num_heads, dropout=dropout)\n        self.linear1 = nn.Linear(embed_dim, dim_feedforward)\n        self.dropout1 = nn.Dropout(dropout)\n        self.linear2 = nn.Linear(dim_feedforward, embed_dim)\n        self.dropout2 = nn.Dropout(dropout)\n        self.norm1 = RMSNorm(embed_dim)\n        self.norm2 = RMSNorm(embed_dim)\n        self.activation = nn.GELU()\n\n    def _sa_block(self, src, attn_mask=None, key_padding_mask=None):\n        src2 = self.self_attn(src, src, src,\n                              key_padding_mask=key_padding_mask,\n                              attn_mask=attn_mask)\n        return self.dropout1(src2)\n\n    def _ff_block(self, src):\n        src2 = self.linear2(self.dropout2(self.activation(self.linear1(src))))\n        return self.dropout2(src2)\n\n    def forward(self,\n                src: torch.Tensor,\n                src_key_padding_mask: torch.Tensor = None,\n                src_mask: torch.Tensor = None):\n        if self.norm_first:\n            src = src + self._sa_block(self.norm1(src),\n                                       attn_mask=src_mask,\n                                       key_padding_mask=src_key_padding_mask)\n            src = src + self._ff_block(self.norm2(src))\n        else:\n            src = self.norm1(src + self._sa_block(src,\n                                                  attn_mask=src_mask,\n                                                  key_padding_mask=src_key_padding_mask))\n            src = self.norm2(src + self._ff_block(src))\n        return src\n\nclass TransformerEncoder(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n        self.encoder = nn.ModuleList([\n            TransformerEncoderLayer(config)\n            for _ in range(config[\"num_transformer_layers\"])\n        ])\n\n    def forward(self,\n                src,\n                src_key_padding_mask=None,\n                src_mask=None):\n        for layer in self.encoder:\n            src = layer(src,\n                        src_key_padding_mask=src_key_padding_mask,\n                        src_mask=src_mask)\n        return src\n\n\n# Preparing vocabularies and the dataset (without customer_id)\n\n\ndef prepare_vocabularies(train_df, val_df, test_df):\n    \"\"\"\n    1) Ensures each df is a pandas DataFrame.\n    2) Builds a postal_code -> index dictionary.\n    3) Finds the max article ID, max day, and max age.\n    \"\"\"\n    def to_pandas_if_polars(df):\n        return df.to_pandas() if not hasattr(df, \"iloc\") else df\n\n    train_pd = to_pandas_if_polars(train_df)\n    val_pd = 
to_pandas_if_polars(val_df)\n test_pd = to_pandas_if_polars(test_df)\n\n combined = pd.concat([train_pd, val_pd, test_pd], ignore_index=True)\n\n unique_postals = combined['postal_code'].unique()\n postal2idx = {p: i for i, p in enumerate(unique_postals)}\n num_postal = len(postal2idx)\n\n all_articles = []\n for df_pd in [train_pd, val_pd, test_pd]:\n for lst in df_pd['articles_ids_lst']:\n all_articles.extend(lst)\n max_article_id = max(all_articles)\n num_articles = max_article_id + 1\n\n all_days = []\n for df_pd in [train_pd, val_pd, test_pd]:\n for lst in df_pd['days_before_lst']:\n all_days.extend(lst)\n max_day = max(all_days)\n\n max_age = combined['age'].max()\n num_age = max_age + 1\n\n return postal2idx, num_postal, num_articles, max_day, num_age\n\nclass CustomerDataset(Dataset):\n \"\"\"\n Expects columns:\n - postal_code (str)\n - days_before_lst (list[int])\n - articles_ids_lst (list[int])\n - regression_label (float)\n - classification_label (int) (0 means churn, 1 means not churn)\n - age (int)\n Note: customer_id is no longer used.\n \"\"\"\n def __init__(self, df, postal2idx: Dict[str, int]):\n if not hasattr(df, \"iloc\"):\n df = df.to_pandas()\n self.data = df\n self.postal2idx = postal2idx\n\n def __len__(self):\n return len(self.data)\n \n def __getitem__(self, idx):\n row = self.data.iloc[idx]\n postal_id = self.postal2idx[row['postal_code']]\n age = int(row['age'])\n articles = torch.tensor(row['articles_ids_lst'], dtype=torch.long)\n days = torch.tensor(row['days_before_lst'], dtype=torch.long)\n regression_label = torch.tensor(float(row['regression_label']), dtype=torch.float)\n classification_label = torch.tensor(int(row['classification_label']), dtype=torch.long)\n return (articles, days, age, postal_id, regression_label, classification_label)\n\n# Custom collate function for variable-length sequences (without customer_id)\ndef fixed_length_collate_fn(batch, sequence_length=8):\n articles_list, days_list, ages, postal_ids, reg_labels, class_labels = [], [], [], [], [], []\n for item in batch:\n articles, days, age, postal_id, reg_label, cls_label = item\n articles_list.append(articles)\n days_list.append(days)\n ages.append(age)\n postal_ids.append(postal_id)\n reg_labels.append(reg_label)\n class_labels.append(cls_label)\n\n def pad_or_trunc(seq, desired_length):\n length = seq.size(0)\n if length > desired_length:\n return seq[:desired_length]\n elif length < desired_length:\n pad = torch.zeros(desired_length - length, dtype=seq.dtype)\n return torch.cat([seq, pad], dim=0)\n else:\n return seq\n\n for i in range(len(articles_list)):\n articles_list[i] = pad_or_trunc(articles_list[i], sequence_length)\n days_list[i] = pad_or_trunc(days_list[i], sequence_length)\n\n articles_tensor = torch.stack(articles_list, dim=0)\n days_tensor = torch.stack(days_list, dim=0)\n ages_tensor = torch.tensor(ages, dtype=torch.long)\n postal_ids_tensor = torch.tensor(postal_ids, dtype=torch.long)\n reg_labels_tensor = torch.stack(reg_labels, dim=0)\n class_labels_tensor = torch.stack(class_labels, dim=0)\n\n return (articles_tensor, days_tensor, ages_tensor, postal_ids_tensor, reg_labels_tensor, class_labels_tensor)\n\n\n# BST model WITHOUT customer_id input\n\nclass BST(pl.LightningModule):\n def __init__(\n self,\n num_articles,\n max_day,\n num_age,\n num_postal,\n sequence_length,\n train_df,\n val_df,\n test_df,\n postal2idx,\n # embedding dims\n article_emb_dim=16,\n day_emb_dim=8,\n age_emb_dim=4,\n postal_emb_dim=4,\n # transformer config\n transformer_nhead=2,\n 
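# must evenly divide seq_feature_dim (= article_emb_dim + day_emb_dim)\n        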
transformer_ff_dim=64,\n num_transformer_layers=1,\n # multi-task config\n predict_churn=False, # Flag to enable/disable churn prediction\n # training\n learning_rate=0.0005\n ):\n super().__init__()\n self.save_hyperparameters(ignore=['train_df','val_df','test_df','postal2idx'])\n self.learning_rate = learning_rate\n self.predict_churn = predict_churn\n\n # DataFrames and mapping\n self.train_df = train_df\n self.val_df = val_df\n self.test_df = test_df\n self.postal2idx = postal2idx\n\n # Embeddings (customer embedding removed)\n self.embeddings_age = nn.Embedding(num_age, age_emb_dim)\n self.embeddings_postal = nn.Embedding(num_postal, postal_emb_dim)\n self.embeddings_article = nn.Embedding(num_articles, article_emb_dim)\n self.embeddings_day = nn.Embedding(max_day + 1, day_emb_dim)\n\n # Sequence features: concatenation of article and day embeddings\n self.seq_feature_dim = article_emb_dim + day_emb_dim\n\n # Custom Transformer setup\n config = {\n \"embedding_dim\": self.seq_feature_dim,\n \"heads\": transformer_nhead,\n \"transformer_dropout\": 0.2,\n \"dim_feedforward\": transformer_ff_dim,\n \"norm_first\": False,\n \"num_transformer_layers\": num_transformer_layers,\n }\n self.transformer = TransformerEncoder(config)\n self.transformer_output_dim = sequence_length * self.seq_feature_dim\n\n # User features: only age and postal embeddings are used\n user_feature_dim = age_emb_dim + postal_emb_dim\n\n combined_dim = self.transformer_output_dim + user_feature_dim\n\n # Separate heads for regression and (optional) classification\n self.regressor_head = nn.Sequential(\n nn.Linear(combined_dim, 512),\n nn.LeakyReLU(),\n nn.Linear(512, 256),\n nn.LeakyReLU(),\n nn.Linear(256, 1)\n )\n\n if self.predict_churn:\n self.classifier_head = nn.Sequential(\n nn.Linear(combined_dim, 128),\n nn.LeakyReLU(),\n nn.Linear(128, 1) # single logit for binary classification\n )\n self.classification_criterion = nn.BCEWithLogitsLoss()\n\n self.regression_criterion = nn.MSELoss()\n\n def encode_input(self, batch):\n # Expected tuple: (articles, days, age, postal_id, regression_label, classification_label)\n articles, days, age, postal_id, regression_label, classification_label = batch\n\n article_embeds = self.embeddings_article(articles) # (B, L, article_emb_dim)\n day_embeds = self.embeddings_day(days) # (B, L, day_emb_dim)\n # Concatenate to form the sequence features; day_embeds serve as the time/position signal.\n transformer_input = torch.cat([article_embeds, day_embeds], dim=-1) # (B, L, seq_feature_dim)\n\n transformer_output = self.transformer(transformer_input) # (B, L, seq_feature_dim)\n transformer_output_flat = transformer_output.reshape(transformer_output.size(0), -1)\n\n age_embed = self.embeddings_age(age)\n postal_embed = self.embeddings_postal(postal_id)\n user_features = torch.cat([age_embed, postal_embed], dim=1)\n\n combined_features = torch.cat([transformer_output_flat, user_features], dim=1)\n return combined_features, regression_label, classification_label\n\n def forward(self, batch):\n features, reg_label, class_label = self.encode_input(batch)\n reg_output = self.regressor_head(features).squeeze(dim=-1)\n if self.predict_churn:\n class_output = self.classifier_head(features).squeeze(dim=-1)\n else:\n class_output = None\n return reg_output, class_output, reg_label, class_label\n\n def training_step(self, batch, batch_idx):\n reg_output, class_output, reg_label, class_label = self(batch)\n reg_loss = self.regression_criterion(reg_output, reg_label)\n loss = reg_loss\n 
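# Regression loss is always optimized; the churn BCE term is added on top only when enabled.\n        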
self.log(\"train_reg_loss\", reg_loss)\n if self.predict_churn:\n class_loss = self.classification_criterion(class_output, class_label.float())\n self.log(\"train_class_loss\", class_loss)\n loss = reg_loss + class_loss\n self.log(\"train_loss\", loss)\n return loss\n\n def validation_step(self, batch, batch_idx):\n reg_output, class_output, reg_label, class_label = self(batch)\n reg_loss = self.regression_criterion(reg_output, reg_label)\n self.log(\"val_reg_loss\", reg_loss)\n loss = reg_loss\n if self.predict_churn:\n class_loss = self.classification_criterion(class_output, class_label.float())\n self.log(\"val_class_loss\", class_loss)\n loss = reg_loss + class_loss\n self.log(\"val_loss\", loss)\n return loss\n\n def test_step(self, batch, batch_idx):\n reg_output, class_output, reg_label, class_label = self(batch)\n reg_loss = self.regression_criterion(reg_output, reg_label)\n self.log(\"test_reg_loss\", reg_loss)\n loss = reg_loss\n if self.predict_churn:\n class_loss = self.classification_criterion(class_output, class_label.float())\n self.log(\"test_class_loss\", class_loss)\n loss = reg_loss + class_loss\n self.log(\"test_loss\", loss)\n return loss\n\n def configure_optimizers(self):\n return torch.optim.AdamW(self.parameters(), lr=self.learning_rate)\n\n def setup(self, stage=None):\n if stage == \"fit\" or stage is None:\n self.train_dataset = CustomerDataset(self.train_df, self.postal2idx)\n self.val_dataset = CustomerDataset(self.val_df, self.postal2idx)\n if stage == \"test\" or stage is None:\n self.test_dataset = CustomerDataset(self.test_df, self.postal2idx)\n\n def train_dataloader(self):\n return DataLoader(\n self.train_dataset,\n batch_size=128,\n shuffle=True,\n num_workers=4,\n collate_fn=lambda b: fixed_length_collate_fn(b, sequence_length=self.hparams.sequence_length)\n )\n\n def val_dataloader(self):\n return DataLoader(\n self.val_dataset,\n batch_size=128,\n shuffle=False,\n num_workers=4,\n collate_fn=lambda b: fixed_length_collate_fn(b, sequence_length=self.hparams.sequence_length)\n )\n\n def test_dataloader(self):\n return DataLoader(\n self.test_dataset,\n batch_size=128,\n shuffle=False,\n num_workers=4,\n collate_fn=lambda b: fixed_length_collate_fn(b, sequence_length=self.hparams.sequence_length)\n )\n\n\n# train and test\n\nmodel = BST(\n #num_customers=num_customers,\n num_articles=num_articles,\n max_day=max_day,\n num_age=num_age,\n num_postal=num_postal,\n sequence_length=sequence_length,\n train_df=train_df,\n val_df=val_df,\n test_df=test_df,\n #user2idx=user2idx,\n postal2idx=postal2idx,\n article_emb_dim=16,\n day_emb_dim=8,\n #customer_emb_dim=16,\n age_emb_dim=4,\n postal_emb_dim=4,\n transformer_nhead=2,\n transformer_ff_dim=64,\n num_transformer_layers=1,\n predict_churn=True, # <--- Enable churn\n learning_rate=0.0005\n)\n\ntrainer = pl.Trainer(accelerator=\"gpu\", devices=\"auto\", max_epochs=1)\ntrainer.fit(model)\ntrainer.test(model)\n\n\n\n# %% [markdown]\n# # ALTERNATIVE VERSION (WITH CUSTOM COLLATOR)\n\n# %% [code]\nimport math\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport pytorch_lightning as pl\nimport pandas as pd\nfrom torch.utils.data import Dataset, DataLoader\nfrom torch.nn.utils.rnn import pad_sequence\nfrom typing import List, Dict, Optional\n\n\nclass RMSNorm(nn.Module):\n def __init__(self, dim: int, eps: float = 1e-8) -> None:\n super().__init__()\n self.eps = eps\n self.gain = nn.Parameter(torch.ones(dim))\n\n def forward(self, x: torch.Tensor) -> torch.Tensor:\n norm = x * 
torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)\n return norm * self.gain\n\nclass CustomMultiheadAttention(nn.Module):\n def __init__(self, embed_dim, num_heads, dropout=0.0, bias=True):\n super().__init__()\n assert embed_dim % num_heads == 0\n self.embed_dim = embed_dim\n self.num_heads = num_heads\n self.head_dim = embed_dim // num_heads\n self.in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=bias)\n self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)\n self.dropout = dropout\n \n def forward(self, query, key, value, attn_mask=None, key_padding_mask=None):\n B, T, _ = query.size()\n qkv = self.in_proj(query)\n q, k, v = qkv.chunk(3, dim=-1)\n # Reshape for multi-head attention\n q = q.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)\n k = k.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)\n v = v.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)\n attn_mask = self.merge_masks(attn_mask, key_padding_mask, query)\n attn_output = F.scaled_dot_product_attention(\n q, k, v,\n attn_mask=attn_mask,\n dropout_p=self.dropout if self.training else 0.0,\n is_causal=False,\n )\n attn_output = attn_output.transpose(1, 2).contiguous().view(B, T, self.embed_dim)\n output = self.out_proj(attn_output)\n return output\n\n def merge_masks(\n self,\n attn_mask: Optional[torch.Tensor],\n key_padding_mask: Optional[torch.Tensor],\n query: torch.Tensor,\n ) -> Optional[torch.Tensor]:\n merged_mask = None\n batch_size, seq_len, _ = query.shape\n\n def convert_to_float_mask(mask):\n if mask.dtype == torch.bool:\n return mask.float().masked_fill(mask, float(\"-inf\"))\n return mask\n\n if key_padding_mask is not None:\n key_padding_mask = key_padding_mask.view(batch_size, 1, 1, seq_len).expand(\n -1, self.num_heads, -1, -1)\n merged_mask = convert_to_float_mask(key_padding_mask)\n\n if attn_mask is not None:\n if attn_mask.dim() == 2:\n correct_2d_size = (seq_len, seq_len)\n if attn_mask.shape != correct_2d_size:\n raise RuntimeError(\n f\"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}.\"\n )\n attn_mask = attn_mask.unsqueeze(0).expand(batch_size, self.num_heads, -1, -1)\n elif attn_mask.dim() == 3:\n correct_3d_size = (batch_size * self.num_heads, seq_len, seq_len)\n if attn_mask.shape != correct_3d_size:\n raise RuntimeError(\n f\"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}.\"\n )\n attn_mask = attn_mask.view(batch_size, self.num_heads, seq_len, seq_len)\n else:\n raise RuntimeError(f\"attn_mask's dimension {attn_mask.dim()} is not supported\")\n attn_mask = convert_to_float_mask(attn_mask)\n if merged_mask is None:\n merged_mask = attn_mask\n else:\n merged_mask = merged_mask + attn_mask\n\n return merged_mask\n\nclass TransformerEncoderLayer(nn.Module):\n def __init__(self, config: dict):\n super().__init__()\n embed_dim = config[\"embedding_dim\"]\n num_heads = config.get(\"heads\", 8)\n dropout = config[\"transformer_dropout\"]\n dim_feedforward = config[\"dim_feedforward\"]\n self.norm_first = config.get(\"norm_first\", False)\n self.self_attn = CustomMultiheadAttention(embed_dim, num_heads, dropout=dropout)\n self.linear1 = nn.Linear(embed_dim, dim_feedforward)\n self.dropout1 = nn.Dropout(dropout)\n self.linear2 = nn.Linear(dim_feedforward, embed_dim)\n self.dropout2 = nn.Dropout(dropout)\n self.norm1 = RMSNorm(embed_dim)\n self.norm2 = RMSNorm(embed_dim)\n self.activation = nn.GELU()\n\n def _sa_block(self, src, attn_mask=None, key_padding_mask=None):\n src2 = self.self_attn(src, 
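# the same tensor serves as query, key, and value (pure self-attention)\n                              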
src, src,\n key_padding_mask=key_padding_mask,\n attn_mask=attn_mask)\n return self.dropout1(src2)\n\n def _ff_block(self, src):\n src2 = self.linear2(self.dropout2(self.activation(self.linear1(src))))\n return self.dropout2(src2)\n\n def forward(self,\n src: torch.Tensor,\n src_key_padding_mask: torch.Tensor = None,\n src_mask: torch.Tensor = None):\n if self.norm_first:\n src = src + self._sa_block(self.norm1(src),\n attn_mask=src_mask,\n key_padding_mask=src_key_padding_mask)\n src = src + self._ff_block(self.norm2(src))\n else:\n src = self.norm1(src + self._sa_block(src,\n attn_mask=src_mask,\n key_padding_mask=src_key_padding_mask))\n src = self.norm2(src + self._ff_block(src))\n return src\n\nclass TransformerEncoder(nn.Module):\n def __init__(self, config):\n super().__init__()\n self.encoder = nn.ModuleList([\n TransformerEncoderLayer(config)\n for _ in range(config[\"num_transformer_layers\"])\n ])\n\n def forward(self,\n src,\n src_key_padding_mask=None,\n src_mask=None):\n for layer in self.encoder:\n src = layer(src,\n src_key_padding_mask=src_key_padding_mask,\n src_mask=src_mask)\n return src\n\n\ndef prepare_vocabularies(train_df, val_df, test_df):\n \"\"\"\n Ensures each df is a pandas DataFrame, builds a dictionary for postal codes,\n and finds max article ID, max day, and max age.\n \"\"\"\n def to_pandas_if_polars(df):\n return df.to_pandas() if not hasattr(df, \"iloc\") else df\n\n train_pd = to_pandas_if_polars(train_df)\n val_pd = to_pandas_if_polars(val_df)\n test_pd = to_pandas_if_polars(test_df)\n combined = pd.concat([train_pd, val_pd, test_pd], ignore_index=True)\n\n unique_postals = combined['postal_code'].unique()\n postal2idx = {p: i for i, p in enumerate(unique_postals)}\n num_postal = len(postal2idx)\n\n all_articles = []\n for df_pd in [train_pd, val_pd, test_pd]:\n for lst in df_pd['articles_ids_lst']:\n all_articles.extend(lst)\n max_article_id = max(all_articles)\n num_articles = max_article_id + 1\n\n all_days = []\n for df_pd in [train_pd, val_pd, test_pd]:\n for lst in df_pd['days_before_lst']:\n all_days.extend(lst)\n max_day = max(all_days)\n\n max_age = combined['age'].max()\n num_age = max_age + 1\n\n return postal2idx, num_postal, num_articles, max_day, num_age\n\nclass CustomerDataset(Dataset):\n \"\"\"\n Expects columns:\n - postal_code (str)\n - days_before_lst (list[int])\n - articles_ids_lst (list[int])\n - regression_label (float)\n - classification_label (int) (0 means churn, 1 means not churn)\n - age (int)\n Note: customer_id is not used.\n \"\"\"\n def __init__(self, df, postal2idx: Dict[str, int]):\n if not hasattr(df, \"iloc\"):\n df = df.to_pandas()\n self.data = df\n self.postal2idx = postal2idx\n\n def __len__(self):\n return len(self.data)\n \n def __getitem__(self, idx):\n row = self.data.iloc[idx]\n postal_id = self.postal2idx[row['postal_code']]\n age = int(row['age'])\n articles = torch.tensor(row['articles_ids_lst'], dtype=torch.long)\n days = torch.tensor(row['days_before_lst'], dtype=torch.long)\n regression_label = torch.tensor(float(row['regression_label']), dtype=torch.float)\n classification_label = torch.tensor(int(row['classification_label']), dtype=torch.long)\n return (articles, days, age, postal_id, regression_label, classification_label)\n\n\ndef fixed_length_collate_fn(batch: list[tuple[torch.Tensor, torch.Tensor, int, int, torch.Tensor, torch.Tensor]],\n sequence_length: int = 8, padding_value: int = 0) -> tuple[torch.Tensor, ...]:\n \"\"\"\n Efficiently pads sequences using PyTorch's pad_sequence and then 
truncates (or further pads) them to exactly sequence_length timesteps.\n    \n    Args:\n        batch: List of tuples where each tuple contains \n            (articles, days, age, postal_id, regression_label, classification_label)\n        sequence_length: Desired length of the sequences\n        padding_value: Value to use for padding sequences\n    \n    Returns:\n        Tuple of tensors: \n            (article_seqs_tensor, day_seqs_tensor, ages_tensor, postal_ids_tensor, reg_labels_tensor, class_labels_tensor)\n    \"\"\"\n    article_seqs, day_seqs, ages, postal_ids, reg_labels, class_labels = zip(*batch)\n    article_seqs_tensor = pad_sequence(article_seqs, batch_first=True, padding_value=padding_value)\n    day_seqs_tensor = pad_sequence(day_seqs, batch_first=True, padding_value=padding_value)\n    # Truncate to the desired sequence length\n    article_seqs_tensor = article_seqs_tensor[:, :sequence_length]\n    day_seqs_tensor = day_seqs_tensor[:, :sequence_length]\n    # pad_sequence only pads to the longest sequence in the batch; if that is\n    # shorter than sequence_length, right-pad further so the model's flattened\n    # transformer output always covers exactly sequence_length timesteps.\n    deficit = sequence_length - article_seqs_tensor.size(1)\n    if deficit > 0:\n        article_seqs_tensor = F.pad(article_seqs_tensor, (0, deficit), value=padding_value)\n        day_seqs_tensor = F.pad(day_seqs_tensor, (0, deficit), value=padding_value)\n    ages_tensor = torch.tensor(ages, dtype=torch.long)\n    postal_ids_tensor = torch.tensor(postal_ids, dtype=torch.long)\n    reg_labels_tensor = torch.stack(reg_labels, dim=0)\n    class_labels_tensor = torch.stack(class_labels, dim=0)\n    return (\n        article_seqs_tensor,\n        day_seqs_tensor,\n        ages_tensor,\n        postal_ids_tensor,\n        reg_labels_tensor,\n        class_labels_tensor\n    )\n\n\nclass BST(pl.LightningModule):\n    def __init__(\n        self,\n        num_articles,\n        max_day,\n        num_age,\n        num_postal,\n        sequence_length,\n        train_df,\n        val_df,\n        test_df,\n        postal2idx,\n        # Embedding dimensions\n        article_emb_dim: int = 16,\n        day_emb_dim: int = 8,\n        age_emb_dim: int = 4,\n        postal_emb_dim: int = 4,\n        # Transformer config\n        transformer_nhead: int = 2,\n        transformer_ff_dim: int = 64,\n        num_transformer_layers: int = 1,\n        # Multi-task config\n        predict_churn: bool = False,\n        # Training\n        learning_rate: float = 0.0005\n    ):\n        super().__init__()\n        # Save hyperparameters (including sequence_length)\n        self.save_hyperparameters(ignore=['train_df','val_df','test_df','postal2idx'])\n        self.learning_rate = learning_rate\n        self.predict_churn = predict_churn\n\n        # DataFrames and mapping\n        self.train_df = train_df\n        self.val_df = val_df\n        self.test_df = test_df\n        self.postal2idx = postal2idx\n\n        # Embeddings (customer_id removed)\n        self.embeddings_age = nn.Embedding(num_age, age_emb_dim)\n        self.embeddings_postal = nn.Embedding(num_postal, postal_emb_dim)\n        self.embeddings_article = nn.Embedding(num_articles, article_emb_dim)\n        self.embeddings_day = nn.Embedding(max_day + 1, day_emb_dim)\n\n        # Sequence features: concatenation of article and day embeddings\n        self.seq_feature_dim = article_emb_dim + day_emb_dim\n\n        # Custom Transformer setup\n        config = {\n            \"embedding_dim\": self.seq_feature_dim,\n            \"heads\": transformer_nhead,\n            \"transformer_dropout\": 0.2,\n            \"dim_feedforward\": transformer_ff_dim,\n            \"norm_first\": False,\n            \"num_transformer_layers\": num_transformer_layers,\n        }\n        self.transformer = TransformerEncoder(config)\n        self.transformer_output_dim = sequence_length * self.seq_feature_dim\n\n        # User features: only age and postal embeddings are used\n        user_feature_dim = age_emb_dim + postal_emb_dim\n        combined_dim = self.transformer_output_dim + user_feature_dim\n\n        # Separate heads for regression and optional classification\n        self.regressor_head = nn.Sequential(\n            nn.Linear(combined_dim, 512),\n            nn.LeakyReLU(),\n            nn.Linear(512, 256),\n            nn.LeakyReLU(),\n            nn.Linear(256, 1)\n        )\n        if self.predict_churn:\n            self.classifier_head = nn.Sequential(\n                nn.Linear(combined_dim, 128),\n                nn.LeakyReLU(),\n                nn.Linear(128, 1)  # single logit for binary classification\n            )\n            self.classification_criterion = 
nn.BCEWithLogitsLoss()\n self.regression_criterion = nn.MSELoss()\n\n def encode_input(self, batch):\n # Expecting: (articles, days, age, postal_id, regression_label, classification_label)\n articles, days, age, postal_id, regression_label, classification_label = batch\n article_embeds = self.embeddings_article(articles) # (B, L, article_emb_dim)\n day_embeds = self.embeddings_day(days) # (B, L, day_emb_dim)\n # Concatenate article and day embeddings as the sequence features\n transformer_input = torch.cat([article_embeds, day_embeds], dim=-1) # (B, L, seq_feature_dim)\n transformer_output = self.transformer(transformer_input) # (B, L, seq_feature_dim)\n transformer_output_flat = transformer_output.reshape(transformer_output.size(0), -1)\n age_embed = self.embeddings_age(age)\n postal_embed = self.embeddings_postal(postal_id)\n user_features = torch.cat([age_embed, postal_embed], dim=1)\n combined_features = torch.cat([transformer_output_flat, user_features], dim=1)\n return combined_features, regression_label, classification_label\n\n def forward(self, batch):\n features, reg_label, class_label = self.encode_input(batch)\n reg_output = self.regressor_head(features).squeeze(dim=-1)\n if self.predict_churn:\n class_output = self.classifier_head(features).squeeze(dim=-1)\n else:\n class_output = None\n return reg_output, class_output, reg_label, class_label\n\n def training_step(self, batch, batch_idx):\n reg_output, class_output, reg_label, class_label = self(batch)\n reg_loss = self.regression_criterion(reg_output, reg_label)\n loss = reg_loss\n self.log(\"train_reg_loss\", reg_loss)\n if self.predict_churn:\n class_loss = self.classification_criterion(class_output, class_label.float())\n self.log(\"train_class_loss\", class_loss)\n loss = reg_loss + class_loss\n self.log(\"train_loss\", loss)\n return loss\n\n def validation_step(self, batch, batch_idx):\n reg_output, class_output, reg_label, class_label = self(batch)\n reg_loss = self.regression_criterion(reg_output, reg_label)\n self.log(\"val_reg_loss\", reg_loss)\n loss = reg_loss\n if self.predict_churn:\n class_loss = self.classification_criterion(class_output, class_label.float())\n self.log(\"val_class_loss\", class_loss)\n loss = reg_loss + class_loss\n self.log(\"val_loss\", loss)\n return loss\n\n def test_step(self, batch, batch_idx):\n reg_output, class_output, reg_label, class_label = self(batch)\n reg_loss = self.regression_criterion(reg_output, reg_label)\n self.log(\"test_reg_loss\", reg_loss)\n loss = reg_loss\n if self.predict_churn:\n class_loss = self.classification_criterion(class_output, class_label.float())\n self.log(\"test_class_loss\", class_loss)\n loss = reg_loss + class_loss\n self.log(\"test_loss\", loss)\n return loss\n\n def configure_optimizers(self):\n return torch.optim.AdamW(self.parameters(), lr=self.learning_rate)\n\n def setup(self, stage=None):\n if stage == \"fit\" or stage is None:\n self.train_dataset = CustomerDataset(self.train_df, self.postal2idx)\n self.val_dataset = CustomerDataset(self.val_df, self.postal2idx)\n if stage == \"test\" or stage is None:\n self.test_dataset = CustomerDataset(self.test_df, self.postal2idx)\n\n def train_dataloader(self):\n return DataLoader(\n self.train_dataset,\n batch_size=128,\n shuffle=True,\n num_workers=4,\n collate_fn=lambda batch: fixed_length_collate_fn(batch, sequence_length=self.hparams.sequence_length)\n )\n\n def val_dataloader(self):\n return DataLoader(\n self.val_dataset,\n batch_size=128,\n shuffle=False,\n num_workers=4,\n collate_fn=lambda 
batch: fixed_length_collate_fn(batch, sequence_length=self.hparams.sequence_length)\n        )\n\n    def test_dataloader(self):\n        return DataLoader(\n            self.test_dataset,\n            batch_size=128,\n            shuffle=False,\n            num_workers=4,\n            collate_fn=lambda batch: fixed_length_collate_fn(batch, sequence_length=self.hparams.sequence_length)\n        )\n\n\n# train and test\n\npostal2idx, num_postal, num_articles, max_day, num_age = prepare_vocabularies(train_df, val_df, test_df)\n\n# Define the desired sequence length\nsequence_length = 8\n\nmodel = BST(\n    #num_customers=num_customers,\n    num_articles=num_articles,\n    max_day=max_day,\n    num_age=num_age,\n    num_postal=num_postal,\n    sequence_length=sequence_length,\n    train_df=train_df,\n    val_df=val_df,\n    test_df=test_df,\n    #user2idx=user2idx,\n    postal2idx=postal2idx,\n    article_emb_dim=16,\n    day_emb_dim=8,\n    #customer_emb_dim=16,\n    age_emb_dim=4,\n    postal_emb_dim=4,\n    transformer_nhead=2,\n    transformer_ff_dim=64,\n    num_transformer_layers=1,\n    predict_churn=True,  # <--- Enable churn\n    learning_rate=0.0005\n)\n\ntrainer = pl.Trainer(accelerator=\"gpu\", devices=\"auto\", max_epochs=1)\ntrainer.fit(model)\ntrainer.test(model)\n\n# %% [markdown]\n# # GET ENVIRONMENT\n\n# %% [code] {\"execution\":{\"iopub.status.busy\":\"2025-03-19T10:36:31.565379Z\",\"iopub.execute_input\":\"2025-03-19T10:36:31.565814Z\",\"iopub.status.idle\":\"2025-03-19T10:36:33.155288Z\",\"shell.execute_reply.started\":\"2025-03-19T10:36:31.565776Z\",\"shell.execute_reply\":\"2025-03-19T10:36:33.154293Z\"}}\n!pip freeze > requirements.txt\n","metadata":{"_uuid":"89ec5c40-c8ad-4af7-bf21-fd23377925a2","_cell_guid":"bc349ae9-a85f-4236-b88c-61d4c803a458","trusted":true,"collapsed":false,"jupyter":{"outputs_hidden":false}},"outputs":[],"execution_count":null}]}