Skip to content

Commit 2ee023a

Browse files
committed
+ linter
1 parent 735e2d7 commit 2ee023a

3 files changed

Lines changed: 10 additions & 20 deletions

File tree

rectools/model_selection/last_n_split.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
import numpy as np
1919
import pandas as pd
20+
2021
from rectools import Columns
2122
from rectools.dataset import Interactions
2223
from rectools.model_selection.splitter import Splitter
@@ -102,11 +103,7 @@ def _split_without_filter(
102103
idx = pd.RangeIndex(0, len(df))
103104

104105
# Here we guarantee that last appeared interaction in df will have lowest rank when datetime is not unique
105-
time_order = (
106-
df.groupby(Columns.User)[Columns.Datetime]
107-
.rank(method="first", ascending=True)
108-
.astype(int)
109-
)
106+
time_order = df.groupby(Columns.User)[Columns.Datetime].rank(method="first", ascending=True).astype(int)
110107
n_interactions = df.groupby(Columns.User).transform("size").astype(int)
111108
inv_ranks = n_interactions - time_order + 1
112109

rectools/models/nn/item_net.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -486,4 +486,4 @@ def forward(self, items: torch.Tensor) -> torch.Tensor:
486486
@property
487487
def out_dim(self) -> int:
488488
"""Return item net constructor output dimension."""
489-
return self.item_net_blocks[0].out_dim # type: ignore[return-value]
489+
return self.item_net_blocks[0].out_dim

tests/model_selection/test_last_n_split.py

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ def _shuffle(values: tp.Sequence[int]) -> tp.List[int]:
4040
return sorted(inv_shuffle_arr[values])
4141

4242
return _shuffle
43+
4344
@pytest.fixture
4445
def interactions_equal_timestamps(self, shuffle_arr: np.ndarray) -> Interactions:
4546
df = pd.DataFrame(
@@ -62,37 +63,29 @@ def interactions_equal_timestamps(self, shuffle_arr: np.ndarray) -> Interactions
6263
@pytest.mark.parametrize(
6364
"swap_targets,expected_test_ids, target_item",
6465
(
65-
(
66-
False,
67-
{9, 7, 8},
68-
6
69-
),
70-
(
71-
True,
72-
{9, 7, 8},
73-
3
74-
),
66+
(False, {9, 7, 8}, 6),
67+
(True, {9, 7, 8}, 3),
7568
),
7669
)
7770
def test_correct_last_interactions(
7871
self,
7972
interactions_equal_timestamps: Interactions,
8073
swap_targets: bool,
81-
expected_test_ids: tp.List[int],
74+
expected_test_ids: tp.Set[int],
8275
target_item: int,
8376
) -> None:
8477
# Do not using shuffle fixture, otherwise no valid answers
8578
interactions_et = interactions_equal_timestamps
8679
splitter = LastNSplitter(1, 1, False, False, False)
8780
if swap_targets:
8881
df_swap = interactions_equal_timestamps.df
89-
df_swap.iloc[[4,9]] = df_swap.iloc[[9,4]]
82+
df_swap.iloc[[4, 9]] = df_swap.iloc[[9, 4]]
9083
interactions_et = Interactions(df_swap)
9184
loo_split = list(splitter.split(interactions_et, collect_fold_stats=True))
9285
target_ids = loo_split[0][1]
9386
assert set(target_ids) == expected_test_ids
94-
assert set(loo_split[0][0]) == set(range(len(interactions_et.df))) - expected_test_ids
95-
assert target_item in set(interactions_et.df.iloc[target_ids][Columns.Item])
87+
assert set(loo_split[0][0]) == set(range(len(interactions_et.df))) - expected_test_ids
88+
assert target_item in set(interactions_et.df.iloc[target_ids][Columns.Item])
9689

9790
@pytest.fixture
9891
def interactions(self, shuffle_arr: np.ndarray) -> Interactions:

0 commit comments

Comments
 (0)