diff --git a/CHANGELOG.md b/CHANGELOG.md index 558ee220..acc13ced 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,10 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## Unreleased +### Added +- `extras` argument to `SequenceDataset`, `extra_cols` argument to `TransformerDataPreparatorBase`, `session_tower_forward` and `item_tower_forward` methods to `SimilarityModuleBase` ([#287](https://github.com/MobileTeleSystems/RecTools/pull/287)) + ## [0.14.0] - 16.05.2025 ### Added diff --git a/rectools/dataset/dataset.py b/rectools/dataset/dataset.py index 6d7a7d52..22cd71fd 100644 --- a/rectools/dataset/dataset.py +++ b/rectools/dataset/dataset.py @@ -348,7 +348,10 @@ def get_user_item_matrix( return matrix def get_raw_interactions( - self, include_weight: bool = True, include_datetime: bool = True, include_extra_cols: bool = True + self, + include_weight: bool = True, + include_datetime: bool = True, + include_extra_cols: tp.Union[bool, tp.List[str]] = True, ) -> pd.DataFrame: """ Return interactions as a `pd.DataFrame` object with replacing internal user and item ids to external ones. diff --git a/rectools/dataset/interactions.py b/rectools/dataset/interactions.py index 3f06ba70..2cfda50a 100644 --- a/rectools/dataset/interactions.py +++ b/rectools/dataset/interactions.py @@ -167,7 +167,7 @@ def to_external( item_id_map: IdMap, include_weight: bool = True, include_datetime: bool = True, - include_extra_cols: bool = True, + include_extra_cols: tp.Union[bool, tp.List[str]] = True, ) -> pd.DataFrame: """ Convert itself to `pd.DataFrame` with replacing internal user and item ids to external ones. @@ -182,8 +182,9 @@ def to_external( Whether to include weight column into resulting table or not include_datetime : bool, default ``True`` Whether to include datetime column into resulting table or not. - include_extra_cols: bool, default ``True`` - Whether to include extra columns into resulting table or not. + include_extra_cols: bool or List[str], default ``True`` + If bool, indicates whether to include all extra columns into resulting table or not. + If list of strings, indicates which extra columns to include into resulting table. Returns ------- @@ -201,9 +202,13 @@ def to_external( cols_to_add.append(Columns.Weight) if include_datetime: cols_to_add.append(Columns.Datetime) - if include_extra_cols: + + extra_cols = [] + if isinstance(include_extra_cols, list): + extra_cols = [col for col in include_extra_cols if col in self.df and col not in Columns.Interactions] + elif include_extra_cols: extra_cols = [col for col in self.df if col not in Columns.Interactions] - cols_to_add.extend(extra_cols) + cols_to_add.extend(extra_cols) for col in cols_to_add: res[col] = self.df[col] diff --git a/rectools/models/nn/transformers/bert4rec.py b/rectools/models/nn/transformers/bert4rec.py index 8e31d6ff..a58a4502 100644 --- a/rectools/models/nn/transformers/bert4rec.py +++ b/rectools/models/nn/transformers/bert4rec.py @@ -36,7 +36,7 @@ ValMaskCallable, ) from .constants import MASKING_VALUE, PADDING_VALUE -from .data_preparator import InitKwargs, TransformerDataPreparatorBase +from .data_preparator import BatchElement, InitKwargs, TransformerDataPreparatorBase from .negative_sampler import CatalogUniformSampler, TransformerNegativeSamplerBase from .net_blocks import ( LearnableInversePositionalEncoding, @@ -128,7 +128,7 @@ def _mask_session( def _collate_fn_train( self, - batch: List[Tuple[List[int], List[float]]], + batch: tp.List[BatchElement], ) -> Dict[str, torch.Tensor]: """ Mask session elements to receive `x`. @@ -141,7 +141,7 @@ def _collate_fn_train( x = np.zeros((batch_size, self.session_max_len)) y = np.zeros((batch_size, self.session_max_len)) yw = np.zeros((batch_size, self.session_max_len)) - for i, (ses, ses_weights) in enumerate(batch): + for i, (ses, ses_weights, _) in enumerate(batch): masked_session, target = self._mask_session(ses) x[i, -len(ses) :] = masked_session # ses: [session_len] -> x[i]: [session_max_len] y[i, -len(ses) :] = target # ses: [session_len] -> y[i]: [session_max_len] @@ -154,12 +154,12 @@ def _collate_fn_train( ) return batch_dict - def _collate_fn_val(self, batch: List[Tuple[List[int], List[float]]]) -> Dict[str, torch.Tensor]: + def _collate_fn_val(self, batch: tp.List[BatchElement]) -> Dict[str, torch.Tensor]: batch_size = len(batch) x = np.zeros((batch_size, self.session_max_len)) y = np.zeros((batch_size, 1)) # until only leave-one-strategy yw = np.zeros((batch_size, 1)) # until only leave-one-strategy - for i, (ses, ses_weights) in enumerate(batch): + for i, (ses, ses_weights, _) in enumerate(batch): input_session = [ses[idx] for idx, weight in enumerate(ses_weights) if weight == 0] session = input_session.copy() @@ -179,14 +179,14 @@ def _collate_fn_val(self, batch: List[Tuple[List[int], List[float]]]) -> Dict[st ) return batch_dict - def _collate_fn_recommend(self, batch: List[Tuple[List[int], List[float]]]) -> Dict[str, torch.Tensor]: + def _collate_fn_recommend(self, batch: tp.List[BatchElement]) -> Dict[str, torch.Tensor]: """ Right truncation, left padding to `session_max_len` During inference model will use (`session_max_len` - 1) interactions and one extra "MASK" token will be added for making predictions. """ x = np.zeros((len(batch), self.session_max_len)) - for i, (ses, _) in enumerate(batch): + for i, (ses, _, _) in enumerate(batch): session = ses.copy() session = session + [self.extra_token_ids[MASKING_VALUE]] x[i, -len(ses) - 1 :] = session[-self.session_max_len :] diff --git a/rectools/models/nn/transformers/data_preparator.py b/rectools/models/nn/transformers/data_preparator.py index b13ec87d..40993d40 100644 --- a/rectools/models/nn/transformers/data_preparator.py +++ b/rectools/models/nn/transformers/data_preparator.py @@ -32,6 +32,8 @@ from .negative_sampler import TransformerNegativeSamplerBase InitKwargs = tp.Dict[str, tp.Any] +# (user session, session weights, extra columns) +BatchElement = tp.Tuple[tp.List[int], tp.List[float], tp.Dict[str, tp.List[tp.Any]]] class SequenceDataset(TorchDataset): @@ -46,17 +48,26 @@ class SequenceDataset(TorchDataset): Weight of each interaction from the session. """ - def __init__(self, sessions: tp.List[tp.List[int]], weights: tp.List[tp.List[float]]): + def __init__( + self, + sessions: tp.List[tp.List[int]], + weights: tp.List[tp.List[float]], + extras: tp.Optional[tp.Dict[str, tp.List[tp.Any]]] = None, + ): self.sessions = sessions self.weights = weights + self.extras = extras def __len__(self) -> int: return len(self.sessions) - def __getitem__(self, index: int) -> tp.Tuple[tp.List[int], tp.List[float]]: + def __getitem__(self, index: int) -> BatchElement: session = self.sessions[index] # [session_len] weights = self.weights[index] # [session_len] - return session, weights + extras = ( + {feature_name: features[index] for feature_name, features in self.extras.items()} if self.extras else {} + ) + return session, weights, extras @classmethod def from_interactions( @@ -73,17 +84,19 @@ def from_interactions( interactions : pd.DataFrame User-item interactions. """ + cols_to_agg = [col for col in interactions.columns if col != Columns.User] sessions = ( interactions.sort_values(Columns.Datetime, kind="stable") - .groupby(Columns.User, sort=sort_users)[[Columns.Item, Columns.Weight]] + .groupby(Columns.User, sort=sort_users)[cols_to_agg] .agg(list) ) - sessions, weights = ( + sessions_items, weights = ( sessions[Columns.Item].to_list(), sessions[Columns.Weight].to_list(), ) - - return cls(sessions=sessions, weights=weights) + extra_cols = [col for col in interactions.columns if col not in Columns.Interactions] + extras = {col: sessions[col].to_list() for col in extra_cols} if len(extra_cols) > 0 else None + return cls(sessions=sessions_items, weights=weights, extras=extras) class TransformerDataPreparatorBase: # pylint: disable=too-many-instance-attributes @@ -114,6 +127,8 @@ class TransformerDataPreparatorBase: # pylint: disable=too-many-instance-attrib get_val_mask_func_kwargs: optional(InitKwargs), default ``None`` Additional keyword arguments for the get_val_mask_func. Make sure all dict values have JSON serializable types. + extra_cols: optional(List[str]), default ``None`` + Extra columns to keep in train and recommend datasets. """ # We sometimes need data preparators to add +1 to actual session_max_len @@ -133,6 +148,7 @@ def __init__( n_negatives: tp.Optional[int] = None, negative_sampler: tp.Optional[TransformerNegativeSamplerBase] = None, get_val_mask_func_kwargs: tp.Optional[InitKwargs] = None, + extra_cols: tp.Optional[tp.List[str]] = None, **kwargs: tp.Any, ) -> None: self.item_id_map: IdMap @@ -148,6 +164,7 @@ def __init__( self.shuffle_train = shuffle_train self.get_val_mask_func = get_val_mask_func self.get_val_mask_func_kwargs = get_val_mask_func_kwargs + self.extra_cols = extra_cols def get_known_items_sorted_internal_ids(self) -> np.ndarray: """Return internal item ids from processed dataset in sorted order.""" @@ -203,7 +220,8 @@ def _filter_train_interactions(self, train_interactions: pd.DataFrame) -> pd.Dat def process_dataset_train(self, dataset: Dataset) -> None: """Process train dataset and save data.""" - raw_interactions = dataset.get_raw_interactions() + extra_cols = False if self.extra_cols is None else self.extra_cols + raw_interactions = dataset.get_raw_interactions(include_extra_cols=extra_cols) # Exclude val interaction targets from train if needed interactions = raw_interactions @@ -231,7 +249,12 @@ def process_dataset_train(self, dataset: Dataset) -> None: # Prepare train dataset # User features are dropped for now because model doesn't support them - final_interactions = Interactions.from_raw(interactions, user_id_map, item_id_map, keep_extra_cols=True) + final_interactions = Interactions.from_raw( + interactions, + user_id_map, + item_id_map, + keep_extra_cols=True, + ) self.train_dataset = Dataset(user_id_map, item_id_map, final_interactions, item_features=item_features) self.item_id_map = self.train_dataset.item_id_map self._init_extra_token_ids() @@ -246,7 +269,9 @@ def process_dataset_train(self, dataset: Dataset) -> None: val_interactions = interactions[interactions[Columns.User].isin(val_targets[Columns.User].unique())].copy() val_interactions[Columns.Weight] = 0 val_interactions = pd.concat([val_interactions, val_targets], axis=0) - self.val_interactions = Interactions.from_raw(val_interactions, user_id_map, item_id_map).df + self.val_interactions = Interactions.from_raw( + val_interactions, user_id_map, item_id_map, keep_extra_cols=True + ).df def _init_extra_token_ids(self) -> None: extra_token_ids = self.item_id_map.convert_to_internal(self.item_extra_tokens) @@ -340,7 +365,10 @@ def transform_dataset_u2i(self, dataset: Dataset, users: ExternalIds) -> Dataset Final item_id_map is model item_id_map constructed during training. """ # Filter interactions in dataset internal ids - interactions = dataset.interactions.df + required_cols = Columns.Interactions + if self.extra_cols is not None: + required_cols = required_cols + self.extra_cols + interactions = dataset.interactions.df[required_cols] users_internal = dataset.user_id_map.convert_to_internal(users, strict=False) items_internal = dataset.item_id_map.convert_to_internal(self.get_known_item_ids(), strict=False) interactions = interactions[interactions[Columns.User].isin(users_internal)] @@ -359,7 +387,9 @@ def transform_dataset_u2i(self, dataset: Dataset, users: ExternalIds) -> Dataset if n_filtered > 0: explanation = f"""{n_filtered} target users were considered cold because of missing known items""" warnings.warn(explanation) - filtered_interactions = Interactions.from_raw(interactions, rec_user_id_map, self.item_id_map) + filtered_interactions = Interactions.from_raw( + interactions, rec_user_id_map, self.item_id_map, keep_extra_cols=True + ) filtered_dataset = Dataset(rec_user_id_map, self.item_id_map, filtered_interactions) return filtered_dataset @@ -381,26 +411,29 @@ def transform_dataset_i2i(self, dataset: Dataset) -> Dataset: Final user_id_map is the same as dataset original. Final item_id_map is model item_id_map constructed during training. """ - interactions = dataset.get_raw_interactions() + extra_cols = False if self.extra_cols is None else self.extra_cols + interactions = dataset.get_raw_interactions(include_extra_cols=extra_cols) interactions = interactions[interactions[Columns.Item].isin(self.get_known_item_ids())] - filtered_interactions = Interactions.from_raw(interactions, dataset.user_id_map, self.item_id_map) + filtered_interactions = Interactions.from_raw( + interactions, dataset.user_id_map, self.item_id_map, keep_extra_cols=True + ) filtered_dataset = Dataset(dataset.user_id_map, self.item_id_map, filtered_interactions) return filtered_dataset def _collate_fn_train( self, - batch: tp.List[tp.Tuple[tp.List[int], tp.List[float]]], + batch: tp.List[BatchElement], ) -> tp.Dict[str, torch.Tensor]: raise NotImplementedError() def _collate_fn_val( self, - batch: tp.List[tp.Tuple[tp.List[int], tp.List[float]]], + batch: tp.List[BatchElement], ) -> tp.Dict[str, torch.Tensor]: raise NotImplementedError() def _collate_fn_recommend( self, - batch: tp.List[tp.Tuple[tp.List[int], tp.List[float]]], + batch: tp.List[BatchElement], ) -> tp.Dict[str, torch.Tensor]: raise NotImplementedError() diff --git a/rectools/models/nn/transformers/lightning.py b/rectools/models/nn/transformers/lightning.py index df97f882..15eeba8c 100644 --- a/rectools/models/nn/transformers/lightning.py +++ b/rectools/models/nn/transformers/lightning.py @@ -387,7 +387,9 @@ def _get_user_item_embeddings( for batch in recommend_dataloader: batch = {k: v.to(device) for k, v in batch.items()} batch_embs = self.torch_model.encode_sessions(batch, item_embs)[:, -1, :] + batch_embs = self.torch_model.similarity_module.session_tower_forward(batch_embs) user_embs.append(batch_embs.cpu()) + item_embs = self.torch_model.similarity_module.item_tower_forward(item_embs) return torch.cat(user_embs), item_embs diff --git a/rectools/models/nn/transformers/sasrec.py b/rectools/models/nn/transformers/sasrec.py index b8350f72..a3f0c73b 100644 --- a/rectools/models/nn/transformers/sasrec.py +++ b/rectools/models/nn/transformers/sasrec.py @@ -13,7 +13,7 @@ # limitations under the License. import typing as tp -from typing import Dict, List, Tuple +from typing import Dict import numpy as np import torch @@ -36,7 +36,7 @@ TransformerModelConfig, ValMaskCallable, ) -from .data_preparator import InitKwargs, TransformerDataPreparatorBase +from .data_preparator import BatchElement, InitKwargs, TransformerDataPreparatorBase from .negative_sampler import CatalogUniformSampler, TransformerNegativeSamplerBase from .net_blocks import ( LearnableInversePositionalEncoding, @@ -80,7 +80,7 @@ class SASRecDataPreparator(TransformerDataPreparatorBase): def _collate_fn_train( self, - batch: List[Tuple[List[int], List[float]]], + batch: tp.List[BatchElement], ) -> Dict[str, torch.Tensor]: """ Truncate each session from right to keep `session_max_len` items. @@ -91,7 +91,7 @@ def _collate_fn_train( x = np.zeros((batch_size, self.session_max_len)) y = np.zeros((batch_size, self.session_max_len)) yw = np.zeros((batch_size, self.session_max_len)) - for i, (ses, ses_weights) in enumerate(batch): + for i, (ses, ses_weights, _) in enumerate(batch): x[i, -len(ses) + 1 :] = ses[:-1] # ses: [session_len] -> x[i]: [session_max_len] y[i, -len(ses) + 1 :] = ses[1:] # ses: [session_len] -> y[i]: [session_max_len] yw[i, -len(ses) + 1 :] = ses_weights[1:] # ses_weights: [session_len] -> yw[i]: [session_max_len] @@ -103,12 +103,12 @@ def _collate_fn_train( ) return batch_dict - def _collate_fn_val(self, batch: List[Tuple[List[int], List[float]]]) -> Dict[str, torch.Tensor]: + def _collate_fn_val(self, batch: tp.List[BatchElement]) -> Dict[str, torch.Tensor]: batch_size = len(batch) x = np.zeros((batch_size, self.session_max_len)) y = np.zeros((batch_size, 1)) # Only leave-one-strategy is supported for losses yw = np.zeros((batch_size, 1)) # Only leave-one-strategy is supported for losses - for i, (ses, ses_weights) in enumerate(batch): + for i, (ses, ses_weights, _) in enumerate(batch): input_session = [ses[idx] for idx, weight in enumerate(ses_weights) if weight == 0] # take only first target for leave-one-strategy @@ -126,10 +126,10 @@ def _collate_fn_val(self, batch: List[Tuple[List[int], List[float]]]) -> Dict[st ) return batch_dict - def _collate_fn_recommend(self, batch: List[Tuple[List[int], List[float]]]) -> Dict[str, torch.Tensor]: + def _collate_fn_recommend(self, batch: tp.List[BatchElement]) -> Dict[str, torch.Tensor]: """Right truncation, left padding to session_max_len""" x = np.zeros((len(batch), self.session_max_len)) - for i, (ses, _) in enumerate(batch): + for i, (ses, _, _) in enumerate(batch): x[i, -len(ses) :] = ses[-self.session_max_len :] return {"x": torch.LongTensor(x)} diff --git a/rectools/models/nn/transformers/similarity.py b/rectools/models/nn/transformers/similarity.py index da1ac615..006f1ba0 100644 --- a/rectools/models/nn/transformers/similarity.py +++ b/rectools/models/nn/transformers/similarity.py @@ -34,6 +34,14 @@ def _get_pos_neg_logits( ) -> torch.Tensor: raise NotImplementedError() + def session_tower_forward(self, session_embs: torch.Tensor) -> torch.Tensor: + """Forward pass for session tower.""" + return session_embs + + def item_tower_forward(self, item_embs: torch.Tensor) -> torch.Tensor: + """Forward pass for item tower.""" + return item_embs + def forward( self, session_embs: torch.Tensor, @@ -62,7 +70,11 @@ class DistanceSimilarityModule(SimilarityModuleBase): dist_available: tp.List[str] = [Distance.DOT, Distance.COSINE] epsilon_cosine_dist: torch.Tensor = torch.tensor([1e-8]) - def __init__(self, distance: str = "dot") -> None: + def __init__( + self, + distance: str = "dot", + **kwargs: tp.Any, + ) -> None: super().__init__() if distance not in self.dist_available: raise ValueError("`dist` can only be either `dot` or `cosine`.") diff --git a/tests/models/nn/transformers/test_bert4rec.py b/tests/models/nn/transformers/test_bert4rec.py index 140389aa..f89e82ee 100644 --- a/tests/models/nn/transformers/test_bert4rec.py +++ b/tests/models/nn/transformers/test_bert4rec.py @@ -34,7 +34,7 @@ TransformerLightningModule, ) from rectools.models.nn.transformers.bert4rec import MASKING_VALUE, BERT4RecDataPreparator, ValMaskCallable -from rectools.models.nn.transformers.data_preparator import InitKwargs +from rectools.models.nn.transformers.data_preparator import BatchElement, InitKwargs from rectools.models.nn.transformers.negative_sampler import CatalogUniformSampler, TransformerNegativeSamplerBase from rectools.models.nn.transformers.similarity import DistanceSimilarityModule from rectools.models.nn.transformers.torch_backbone import TransformerTorchBackbone @@ -640,13 +640,13 @@ def __init__( def _collate_fn_train( self, - batch: tp.List[tp.Tuple[tp.List[int], tp.List[float]]], + batch: tp.List[BatchElement], ) -> tp.Dict[str, torch.Tensor]: batch_size = len(batch) x = np.zeros((batch_size, self.session_max_len)) y = np.zeros((batch_size, self.session_max_len)) yw = np.zeros((batch_size, self.session_max_len)) - for i, (ses, ses_weights) in enumerate(batch): + for i, (ses, ses_weights, _) in enumerate(batch): y[i, -self.n_last_targets] = ses[-self.n_last_targets] yw[i, -self.n_last_targets] = ses_weights[-self.n_last_targets] x[i, -len(ses) :] = ses diff --git a/tests/models/nn/transformers/test_data_preparator.py b/tests/models/nn/transformers/test_data_preparator.py index 5f41ea8e..ec6bcd67 100644 --- a/tests/models/nn/transformers/test_data_preparator.py +++ b/tests/models/nn/transformers/test_data_preparator.py @@ -31,30 +31,37 @@ class TestSequenceDataset: def interactions_df(self) -> pd.DataFrame: interactions_df = pd.DataFrame( [ - [10, 13, 1, "2021-11-30"], - [10, 11, 1, "2021-11-29"], - [10, 12, 4, "2021-11-29"], - [30, 11, 1, "2021-11-27"], - [30, 12, 2, "2021-11-26"], - [30, 15, 1, "2021-11-25"], - [40, 11, 1, "2021-11-25"], - [40, 17, 8, "2021-11-26"], - [50, 16, 1, "2021-11-25"], - [10, 14, 1, "2021-11-28"], + [10, 13, 1, "2021-11-30", 0], + [10, 11, 1, "2021-11-29", 1], + [10, 12, 4, "2021-11-29", 1], + [30, 11, 1, "2021-11-27", 0], + [30, 12, 2, "2021-11-26", 1], + [30, 15, 1, "2021-11-25", 1], + [40, 11, 1, "2021-11-25", 2], + [40, 17, 8, "2021-11-26", 1], + [50, 16, 1, "2021-11-25", 0], + [10, 14, 1, "2021-11-28", 0], ], - columns=Columns.Interactions, + columns=Columns.Interactions + ["extra_column"], ) return interactions_df @pytest.mark.parametrize( - "expected_sessions, expected_weights", - (([[14, 11, 12, 13], [15, 12, 11], [11, 17], [16]], [[1, 1, 4, 1], [1, 2, 1], [1, 8], [1]]),), + "expected_sessions, expected_weights, expected_extras", + ( + ( + [[14, 11, 12, 13], [15, 12, 11], [11, 17], [16]], + [[1, 1, 4, 1], [1, 2, 1], [1, 8], [1]], + {"extra_column": [[0, 1, 1, 0], [1, 1, 0], [2, 1], [0]]}, + ), + ), ) def test_from_interactions( self, interactions_df: pd.DataFrame, expected_sessions: tp.List[tp.List[int]], expected_weights: tp.List[tp.List[float]], + expected_extras: tp.Dict[str, tp.List[tp.Any]], ) -> None: actual = SequenceDataset.from_interactions(interactions=interactions_df, sort_users=True) assert len(actual.sessions) == len(expected_sessions) @@ -63,6 +70,12 @@ def test_from_interactions( ) assert len(actual.weights) == len(expected_weights) assert all(actual_list == expected_list for actual_list, expected_list in zip(actual.weights, expected_weights)) + assert actual.extras is not None + assert len(actual.extras["extra_column"]) == len(expected_extras["extra_column"]) + assert all( + actual_list == expected_list + for actual_list, expected_list in zip(actual.extras["extra_column"], expected_extras["extra_column"]) + ) class TestTransformerDataPreparatorBase: @@ -71,26 +84,26 @@ class TestTransformerDataPreparatorBase: def interactions_df(self) -> pd.DataFrame: interactions_df = pd.DataFrame( [ - [10, 13, 1, "2021-11-30"], - [10, 11, 1, "2021-11-29"], - [10, 12, 1, "2021-11-29"], - [30, 11, 1, "2021-11-27"], - [30, 12, 2, "2021-11-26"], - [30, 15, 1, "2021-11-25"], - [40, 11, 1, "2021-11-25"], - [40, 17, 1, "2021-11-26"], - [50, 16, 1, "2021-11-25"], - [10, 14, 1, "2021-11-28"], - [10, 16, 1, "2021-11-27"], - [20, 13, 9, "2021-11-28"], + [10, 13, 1, "2021-11-30", 0], + [10, 11, 1, "2021-11-29", 2], + [10, 12, 1, "2021-11-29", 3], + [30, 11, 1, "2021-11-27", 4], + [30, 12, 2, "2021-11-26", 1], + [30, 15, 1, "2021-11-25", 0], + [40, 11, 1, "2021-11-25", 1], + [40, 17, 1, "2021-11-26", 1], + [50, 16, 1, "2021-11-25", 2], + [10, 14, 1, "2021-11-28", 2], + [10, 16, 1, "2021-11-27", 1], + [20, 13, 9, "2021-11-28", 1], ], - columns=Columns.Interactions, + columns=Columns.Interactions + ["extra_column"], ) return interactions_df @pytest.fixture def dataset(self, interactions_df: pd.DataFrame) -> Dataset: - return Dataset.construct(interactions_df) + return Dataset.construct(interactions_df, keep_extra_cols=True) @pytest.fixture def dataset_dense_item_features(self, interactions_df: pd.DataFrame) -> Dataset: @@ -119,6 +132,7 @@ def data_preparator(self) -> TransformerDataPreparatorBase: session_max_len=4, batch_size=4, dataloader_num_workers=0, + extra_cols=["extra_column"], ) @pytest.mark.parametrize( @@ -130,17 +144,17 @@ def data_preparator(self) -> TransformerDataPreparatorBase: Interactions( pd.DataFrame( [ - [0, 1, 1.0, "2021-11-25"], - [1, 2, 1.0, "2021-11-25"], - [0, 3, 2.0, "2021-11-26"], - [1, 4, 1.0, "2021-11-26"], - [0, 2, 1.0, "2021-11-27"], - [2, 5, 1.0, "2021-11-28"], - [2, 2, 1.0, "2021-11-29"], - [2, 3, 1.0, "2021-11-29"], - [2, 6, 1.0, "2021-11-30"], + [0, 1, 1.0, "2021-11-25", 0], + [1, 2, 1.0, "2021-11-25", 1], + [0, 3, 2.0, "2021-11-26", 1], + [1, 4, 1.0, "2021-11-26", 1], + [0, 2, 1.0, "2021-11-27", 4], + [2, 5, 1.0, "2021-11-28", 2], + [2, 2, 1.0, "2021-11-29", 2], + [2, 3, 1.0, "2021-11-29", 3], + [2, 6, 1.0, "2021-11-30", 0], ], - columns=[Columns.User, Columns.Item, Columns.Weight, Columns.Datetime], + columns=[Columns.User, Columns.Item, Columns.Weight, Columns.Datetime, "extra_column"], ), ), ), @@ -165,6 +179,7 @@ def test_process_dataset_train_with_dense_item_features( dataset_dense_item_features: Dataset, data_preparator: TransformerDataPreparatorBase, ) -> None: + data_preparator.extra_cols = None data_preparator.process_dataset_train(dataset_dense_item_features) actual = data_preparator.train_dataset.item_features expected_values = np.array( @@ -192,13 +207,13 @@ def test_process_dataset_train_with_dense_item_features( Interactions( pd.DataFrame( [ - [0, 6, 1.0, "2021-11-30"], - [0, 2, 1.0, "2021-11-29"], - [0, 3, 1.0, "2021-11-29"], - [0, 5, 1.0, "2021-11-28"], - [1, 6, 9.0, "2021-11-28"], + [0, 6, 1.0, "2021-11-30", 0], + [0, 2, 1.0, "2021-11-29", 2], + [0, 3, 1.0, "2021-11-29", 3], + [0, 5, 1.0, "2021-11-28", 2], + [1, 6, 9.0, "2021-11-28", 1], ], - columns=[Columns.User, Columns.Item, Columns.Weight, Columns.Datetime], + columns=[Columns.User, Columns.Item, Columns.Weight, Columns.Datetime, "extra_column"], ), ), ), @@ -228,18 +243,18 @@ def test_transform_dataset_u2i( Interactions( pd.DataFrame( [ - [0, 6, 1.0, "2021-11-30"], - [0, 2, 1.0, "2021-11-29"], - [0, 3, 1.0, "2021-11-29"], - [1, 2, 1.0, "2021-11-27"], - [1, 3, 2.0, "2021-11-26"], - [1, 1, 1.0, "2021-11-25"], - [2, 2, 1.0, "2021-11-25"], - [2, 4, 1.0, "2021-11-26"], - [0, 5, 1.0, "2021-11-28"], - [4, 6, 9.0, "2021-11-28"], + [0, 6, 1.0, "2021-11-30", 0], + [0, 2, 1.0, "2021-11-29", 2], + [0, 3, 1.0, "2021-11-29", 3], + [1, 2, 1.0, "2021-11-27", 4], + [1, 3, 2.0, "2021-11-26", 1], + [1, 1, 1.0, "2021-11-25", 0], + [2, 2, 1.0, "2021-11-25", 1], + [2, 4, 1.0, "2021-11-26", 1], + [0, 5, 1.0, "2021-11-28", 2], + [4, 6, 9.0, "2021-11-28", 1], ], - columns=[Columns.User, Columns.Item, Columns.Weight, Columns.Datetime], + columns=[Columns.User, Columns.Item, Columns.Weight, Columns.Datetime, "extra_column"], ), ), ),