-
Notifications
You must be signed in to change notification settings - Fork 55
Expand file tree
/
Copy pathdataset.py
More file actions
485 lines (409 loc) · 18.8 KB
/
dataset.py
File metadata and controls
485 lines (409 loc) · 18.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
# Copyright 2022-2025 MTS (Mobile Telesystems)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Dataset - all data container."""
import typing as tp
from collections.abc import Hashable
import attr
import numpy as np
import pandas as pd
import typing_extensions as tpe
from pydantic import PlainSerializer
from scipy import sparse
from rectools import Columns
from rectools.utils.config import BaseConfig
from .features import AbsentIdError, DenseFeatures, Features, SparseFeatureName, SparseFeatures
from .identifiers import IdMap
from .interactions import Interactions
# Feature name as it may appear in features objects: either a plain string
# or a `SparseFeatureName` (imported from `.features`).
AnyFeatureName = tp.Union[str, SparseFeatureName]
def _serialize_feature_name(spec: tp.Any) -> Hashable:
type_error = TypeError(
f"""
Serialization for feature name '{spec}' is not supported.
Please convert your feature names and category feature values to strings, numbers, booleans
or their tuples.
"""
)
if isinstance(spec, (list, np.ndarray)):
raise type_error
if isinstance(spec, tuple):
return tuple(_serialize_feature_name(item) for item in spec)
if isinstance(spec, (int, float, str, bool)):
return spec
if hasattr(spec, "dtype") and (np.issubdtype(spec.dtype, np.number) or np.issubdtype(spec.dtype, np.bool_)):
# numpy str is handled by isinstance(spec, str)
return spec.item()
raise type_error
# Feature name type used inside schemas. Only in JSON mode (e.g. `model_dump(mode="json")`)
# is it passed through `_serialize_feature_name`, which unwraps numpy scalars and rejects
# unsupported types with a `TypeError`.
FeatureName = tpe.Annotated[AnyFeatureName, PlainSerializer(_serialize_feature_name, when_used="json")]
# Plain dict form of `DatasetSchema`, as returned by `Dataset.get_schema`.
DatasetSchemaDict = tp.Dict[str, tp.Any]
class BaseFeaturesSchema(BaseConfig):
    """Features schema: the part shared by dense and sparse features schemas."""

    # Feature names; `Dataset` fills this from `features.names`.
    names: tp.Tuple[FeatureName, ...]
class DenseFeaturesSchema(BaseFeaturesSchema):
    """Dense features schema."""

    # Constant tag distinguishing this schema from `SparseFeaturesSchema` (whose tag is "sparse").
    kind: tp.Literal["dense"] = "dense"
class SparseFeaturesSchema(BaseFeaturesSchema):
    """Sparse features schema."""

    # Constant tag distinguishing this schema from `DenseFeaturesSchema` (whose tag is "dense").
    kind: tp.Literal["sparse"] = "sparse"
    # Indices of categorical features, taken from `SparseFeatures.cat_feature_indices`
    # (presumably positions within `names` - confirm).
    cat_feature_indices: tp.List[int]
    # Number of stored values (`nnz`) in the categorical part of the features matrix.
    cat_n_stored_values: int
class InteractionsFeaturesSchema(BaseConfig):
    """Interactions features schema: extra feature columns of the interactions table."""

    # Names of categorical feature columns of the interactions table.
    cat_feature_names: tp.List[str]
    # (feature name, value) pairs: one pair per unique value of every categorical feature.
    # NOTE(review): values are declared as `str`, but `Dataset.construct` fills them from
    # `DataFrame.unique()`, so non-string (e.g. integer or NaN) categories may fail
    # validation - confirm the intended value type.
    cat_feature_names_w_values: tp.List[tp.Tuple[str, str]]
    # Names of direct (non-categorical) feature columns of the interactions table.
    direct_feature_names: tp.List[str]
# Schema of entity features: either dense or sparse; the two differ by their `kind` tag.
FeaturesSchema = tp.Union[DenseFeaturesSchema, SparseFeaturesSchema]
class IdMapSchema(BaseConfig):
    """IdMap schema."""

    # Number of ids in the mapping (`IdMap.size`).
    size: int
    # String form of the external ids dtype (`IdMap.external_dtype.str`); presumably a numpy
    # array-protocol type string such as "<i8" - confirm.
    dtype: str
class EntitySchema(BaseConfig):
    """Entity (users or items) schema."""

    # Number of hot entities, i.e. entities that are present in interactions.
    n_hot: int
    # Schema of the entity id mapping.
    id_map: IdMapSchema
    # Schema of the entity explicit features; ``None`` if the entity has no features.
    features: tp.Optional[FeaturesSchema] = None
class DatasetSchema(BaseConfig):
    """Dataset schema."""

    # Total number of interactions (rows of the interactions table).
    n_interactions: int
    # Users part of the schema.
    users: EntitySchema
    # Items part of the schema.
    items: EntitySchema
    # Schema of interactions features; ``None`` when the dataset has no interactions features schema.
    interactions: tp.Optional[InteractionsFeaturesSchema] = None
@attr.s(slots=True, frozen=True)
class Dataset:
    """
    Container class for all data for a recommendation model.

    It stores data about internal-external id mapping,
    user-item interactions, user and item features
    in special `rectools` structures for convenient future usage.

    WARNING: It's highly not recommended to create `Dataset` object directly.
    Use `construct` class method instead.

    Parameters
    ----------
    user_id_map : IdMap
        User identifiers mapping.
    item_id_map : IdMap
        Item identifiers mapping.
    interactions : Interactions
        User-item interactions.
    user_features : DenseFeatures or SparseFeatures, optional
        User explicit features.
    item_features : DenseFeatures or SparseFeatures, optional
        Item explicit features.
    interactions_schema : InteractionsFeaturesSchema, optional
        Description of extra (categorical and direct) feature columns of interactions,
        as built by `construct`. ``None`` if no interactions features were declared.
    """

    user_id_map: IdMap = attr.ib()
    item_id_map: IdMap = attr.ib()
    interactions: Interactions = attr.ib()
    user_features: tp.Optional[Features] = attr.ib(default=None)
    item_features: tp.Optional[Features] = attr.ib(default=None)
    interactions_schema: tp.Optional[InteractionsFeaturesSchema] = attr.ib(default=None)

    @staticmethod
    def _get_feature_schema(features: tp.Optional[Features]) -> tp.Optional[FeaturesSchema]:
        """Build a dense or sparse features schema, or return ``None`` when there are no features."""
        if features is None:
            return None
        if isinstance(features, SparseFeatures):
            return SparseFeaturesSchema(
                names=features.names,
                cat_feature_indices=features.cat_feature_indices.tolist(),
                cat_n_stored_values=features.get_cat_features().values.nnz,
            )
        return DenseFeaturesSchema(
            names=features.names,
        )

    @staticmethod
    def _get_id_map_schema(id_map: IdMap) -> IdMapSchema:
        """Build an id map schema from the map size and the string form of its external dtype."""
        return IdMapSchema(size=id_map.size, dtype=id_map.external_dtype.str)

    def get_schema(self) -> DatasetSchemaDict:
        """Get dataset schema in a dict form that contains all the information about the dataset and its statistics."""
        user_schema = EntitySchema(
            n_hot=self.n_hot_users,
            id_map=self._get_id_map_schema(self.user_id_map),
            features=self._get_feature_schema(self.user_features),
        )
        item_schema = EntitySchema(
            n_hot=self.n_hot_items,
            id_map=self._get_id_map_schema(self.item_id_map),
            features=self._get_feature_schema(self.item_features),
        )
        schema = DatasetSchema(
            n_interactions=self.interactions.df.shape[0],
            users=user_schema,
            items=item_schema,
            interactions=self.interactions_schema,
        )
        # JSON mode triggers `_serialize_feature_name` for feature names.
        return schema.model_dump(mode="json")

    @property
    def n_hot_users(self) -> int:
        """
        Return number of hot users in dataset.

        Users with internal ids from `0` to `n_hot_users - 1` are hot (they are present in interactions).
        Users with internal ids from `n_hot_users` to `dataset.user_id_map.size - 1` are warm
        (they aren't present in interactions, but they have features).
        Returns ``0`` if there are no interactions.
        """
        user_ids = self.interactions.df[Columns.User]
        # `max()` of an empty column is NaN, which would break `range(n_hot_users)` downstream.
        if user_ids.empty:
            return 0
        return int(user_ids.max()) + 1

    @property
    def n_hot_items(self) -> int:
        """
        Return number of hot items in dataset.

        Items with internal ids from `0` to `n_hot_items - 1` are hot (they are present in interactions).
        Items with internal ids from `n_hot_items` to `dataset.item_id_map.size - 1` are warm
        (they aren't present in interactions, but they have features).
        Returns ``0`` if there are no interactions.
        """
        item_ids = self.interactions.df[Columns.Item]
        # `max()` of an empty column is NaN, which would break `range(n_hot_items)` downstream.
        if item_ids.empty:
            return 0
        return int(item_ids.max()) + 1

    def get_hot_user_features(self) -> tp.Optional[Features]:
        """User features for hot users."""
        if self.user_features is None:
            return None
        return self.user_features.take(range(self.n_hot_users))

    def get_hot_item_features(self) -> tp.Optional[Features]:
        """Item features for hot items."""
        if self.item_features is None:
            return None
        return self.item_features.take(range(self.n_hot_items))

    @classmethod
    def construct(  # pylint: disable=too-many-locals
        cls,
        interactions_df: pd.DataFrame,
        user_features_df: tp.Optional[pd.DataFrame] = None,
        cat_user_features: tp.Iterable[str] = (),
        make_dense_user_features: bool = False,
        item_features_df: tp.Optional[pd.DataFrame] = None,
        cat_item_features: tp.Iterable[str] = (),
        make_dense_item_features: bool = False,
        keep_extra_cols: bool = False,
        interactions_cat_features: tp.Iterable[str] = (),
        interactions_direct_features: tp.Iterable[str] = (),
    ) -> "Dataset":
        """Class method for convenient `Dataset` creation.

        Use it to create dataset from raw data.

        Parameters
        ----------
        interactions_df : pd.DataFrame
            Table where every row contains user-item interaction and columns are:
                - `Columns.User` - user id;
                - `Columns.Item` - item id;
                - `Columns.Weight` - weight of interaction, `float`,
                  use ``1`` if interactions have no weight;
                - `Columns.Datetime` - timestamp of interactions,
                  assign random value if you're not going to use it later.
        user_features_df, item_features_df : pd.DataFrame, optional
            User (item) explicit features table.
            It will be used to create `SparseFeatures` using `from_flatten` class method
            or `DenseFeatures` using `from_dataframe` class method
            depending on `make_dense_user_features` (`make_dense_item_features`) flag.
            See detailed info about the table structure in these methods description.
        cat_user_features, cat_item_features : tp.Iterable[str], default ``()``
            List of categorical user (item) feature names for
            `SparseFeatures.from_flatten` method.
            Used only if `make_dense_user_features` (`make_dense_item_features`)
            flag is ``False`` and `user_features_df` (`item_features_df`) is not ``None``.
        make_dense_user_features, make_dense_item_features : bool, default ``False``
            Create user (item) features as dense or sparse.
            Used only if `user_features_df` (`item_features_df`) is not ``None``.
            - if ``False``, `SparseFeatures.from_flatten` method will be used;
            - if ``True``, `DenseFeatures.from_dataframe` method will be used.
        keep_extra_cols: bool, default ``False``
            Flag to keep all columns from interactions besides the default ones.
        interactions_cat_features : tp.Iterable[str], default ``()``
            List of categorical feature names in interactions dataframe.
            Duplicates are ignored; the order of first occurrence is kept in the schema.
        interactions_direct_features : tp.Iterable[str], default ``()``
            List of direct (non-categorical) feature names in interactions dataframe.
            Duplicates are ignored; the order of first occurrence is kept in the schema.

        Returns
        -------
        Dataset
            Container with all input data, converted to `rectools` structures.

        Raises
        ------
        KeyError
            If user/item id columns or any of the declared interactions feature columns
            are missing from `interactions_df`.
        """
        for col in (Columns.User, Columns.Item):
            if col not in interactions_df:
                raise KeyError(f"Column '{col}' must be present in `interactions_df`")

        # Validate interactions features.
        # De-duplicate while keeping input order: a `set` would make the schema column order
        # depend on string hash randomization, i.e. differ between processes.
        cat_features = list(dict.fromkeys(interactions_cat_features))
        direct_features = list(dict.fromkeys(interactions_direct_features))
        required_columns = set(cat_features) | set(direct_features)
        actual_columns = set(interactions_df.columns)
        if not actual_columns >= required_columns:
            raise KeyError(f"Missed columns {required_columns - actual_columns}")

        # Create interactions feature schema.
        # NOTE(review): the schema declares values as `str`; non-string categories (ints, NaN)
        # coming from `unique()` may fail validation - confirm the intended value type.
        cat_feature_names_w_values = [
            (cat_feature, value)
            for cat_feature in cat_features
            for value in interactions_df[cat_feature].unique()  # TODO: decide NaN values
        ]
        interactions_schema = (
            InteractionsFeaturesSchema(
                cat_feature_names=list(cat_features),
                direct_feature_names=list(direct_features),
                cat_feature_names_w_values=cat_feature_names_w_values,
            )
            if cat_features or direct_features
            else None
        )

        user_id_map = IdMap.from_values(interactions_df[Columns.User].values)
        item_id_map = IdMap.from_values(interactions_df[Columns.Item].values)
        interactions = Interactions.from_raw(interactions_df, user_id_map, item_id_map, keep_extra_cols)

        # Feature tables may introduce warm ids, so the id maps are extended here.
        user_features, user_id_map = cls._make_features(
            user_features_df,
            cat_user_features,
            make_dense_user_features,
            user_id_map,
            Columns.User,
            "user",
        )
        item_features, item_id_map = cls._make_features(
            item_features_df,
            cat_item_features,
            make_dense_item_features,
            item_id_map,
            Columns.Item,
            "item",
        )
        return cls(
            user_id_map=user_id_map,
            item_id_map=item_id_map,
            interactions=interactions,
            user_features=user_features,
            item_features=item_features,
            interactions_schema=interactions_schema,
        )

    @staticmethod
    def _make_features(
        df: tp.Optional[pd.DataFrame],
        cat_features: tp.Iterable[str],
        make_dense: bool,
        id_map: IdMap,
        possible_id_col: str,
        feature_type: str,
    ) -> tp.Tuple[tp.Optional[Features], IdMap]:
        """
        Build dense or sparse features from `df` and extend `id_map` with ids found in it.

        The id column is `possible_id_col` if present in `df`, otherwise ``"id"``.
        Returns ``(None, id_map)`` unchanged when `df` is ``None``.

        Raises
        ------
        ValueError
            If dense features are requested but some ids from `id_map` are absent in `df`.
        RuntimeError
            If features construction fails for any other reason.
        """
        if df is None:
            return None, id_map

        id_col = possible_id_col if possible_id_col in df else "id"
        id_map = id_map.add_ids(df[id_col].values, raise_if_already_present=False)

        if make_dense:
            try:
                return DenseFeatures.from_dataframe(df, id_map, id_col=id_col), id_map
            except AbsentIdError as e:
                raise ValueError(
                    f"An error has occurred while constructing {feature_type} features: "
                    "When using dense features all ids from interactions must be present in features table"
                ) from e
            except Exception as e:  # pragma: no cover
                raise RuntimeError(f"An error has occurred while constructing {feature_type} features: {e!r}") from e

        try:
            return SparseFeatures.from_flatten(df, id_map, cat_features, id_col=id_col), id_map
        except Exception as e:  # pragma: no cover
            raise RuntimeError(f"An error has occurred while constructing {feature_type} features: {e!r}") from e

    def get_user_item_matrix(
        self,
        include_weights: bool = True,
        include_warm_users: bool = False,
        include_warm_items: bool = False,
        dtype: tp.Type = np.float32,
    ) -> sparse.csr_matrix:
        """
        Construct user-item CSR matrix based on `interactions` attribute.

        Return a resized user-item matrix.
        Resizing is done using `user_id_map` and `item_id_map`,
        hence if either a user or an item is not presented in interactions,
        but presented in id map, then it's going to be in the returned matrix.

        Parameters
        ----------
        include_weights : bool, default ``True``
            Whether include interaction weights in matrix or not.
            If False, all values in returned matrix will be equal to ``1``.
        include_warm_users : bool, default ``False``
            Whether to include warm users into the matrix or not.
            Rows for warm users will be added to the end of matrix, they will contain only zeros.
        include_warm_items : bool, default ``False``
            Whether to include warm items into the matrix or not.
            Columns for warm items will be added to the end of matrix, they will contain only zeros.
        dtype : type, default ``np.float32``
            Data type of matrix values, passed to `Interactions.get_user_item_matrix`.

        Returns
        -------
        csr_matrix
            Resized user-item CSR matrix
        """
        matrix = self.interactions.get_user_item_matrix(include_weights, dtype)
        n_rows = self.user_id_map.size if include_warm_users else matrix.shape[0]
        n_columns = self.item_id_map.size if include_warm_items else matrix.shape[1]
        matrix.resize(n_rows, n_columns)
        return matrix

    def get_raw_interactions(
        self, include_weight: bool = True, include_datetime: bool = True, include_extra_cols: bool = True
    ) -> pd.DataFrame:
        """
        Return interactions as a `pd.DataFrame` object with replacing internal user and item ids to external ones.

        Parameters
        ----------
        include_weight : bool, default ``True``
            Whether to include weight column into resulting table or not.
        include_datetime : bool, default ``True``
            Whether to include datetime column into resulting table or not.
        include_extra_cols: bool, default ``True``
            Whether to include extra columns into resulting table or not.

        Returns
        -------
        pd.DataFrame
        """
        return self.interactions.to_external(
            self.user_id_map, self.item_id_map, include_weight, include_datetime, include_extra_cols
        )

    def filter_interactions(
        self,
        row_indexes_to_keep: np.ndarray,
        keep_external_ids: bool = True,
        keep_features_for_removed_entities: bool = True,
    ) -> "Dataset":
        """
        Generate filtered dataset that contains only provided `row_indexes_to_keep` from original
        dataset interactions dataframe.

        Resulting dataset will get new id mapping for both users and items.
        The interactions features schema (if any) is carried over unchanged.

        Parameters
        ----------
        row_indexes_to_keep : np.ndarray
            Original dataset interactions df row indexes that are to be kept
        keep_external_ids : bool, default `True`
            Whether to keep external ids -> 2x internal ids mapping (default).
            Otherwise internal -> 2x internal ids mapping will be created.
        keep_features_for_removed_entities : bool, default `True`
            Whether to keep all features for users and items that are not hot any more.

        Returns
        -------
        Dataset
            Filtered dataset that has only selected interactions, new ids mapping and processed features.
        """
        interactions_df = self.interactions.df.iloc[row_indexes_to_keep]

        # 1x internal -> 2x internal
        user_id_map = IdMap.from_values(interactions_df[Columns.User].values)
        item_id_map = IdMap.from_values(interactions_df[Columns.Item].values)
        # We shouldn't drop extra columns if they are present
        interactions = Interactions.from_raw(interactions_df, user_id_map, item_id_map, keep_extra_cols=True)

        def _handle_features(
            features: tp.Optional[Features], target_id_map: IdMap
        ) -> tp.Tuple[tp.Optional[Features], IdMap]:
            """Reorder features to the new id map, optionally extending it with all previously known ids."""
            if features is None:
                return None, target_id_map

            if keep_features_for_removed_entities:
                # Feature rows are indexed by 1x internal ids, so every row index is a valid id.
                all_features_ids = np.arange(len(features))
                target_id_map = target_id_map.add_ids(all_features_ids, raise_if_already_present=False)

            needed_ids = target_id_map.get_external_sorted_by_internal()
            features = features.take(needed_ids)
            return features, target_id_map

        user_features_new, user_id_map = _handle_features(self.user_features, user_id_map)
        item_features_new, item_id_map = _handle_features(self.item_features, item_id_map)

        if keep_external_ids:  # external -> 2x internal
            user_id_map = IdMap(self.user_id_map.convert_to_external(user_id_map.external_ids))
            item_id_map = IdMap(self.item_id_map.convert_to_external(item_id_map.external_ids))

        filtered_dataset = Dataset(
            user_id_map=user_id_map,
            item_id_map=item_id_map,
            interactions=interactions,
            user_features=user_features_new,
            item_features=item_features_new,
            # Extra interaction columns are kept above, so their schema must be kept too.
            interactions_schema=self.interactions_schema,
        )
        return filtered_dataset