Skip to content

Commit 735e2d7

Browse files
committed
new splitter & test
1 parent ea266cd commit 735e2d7

2 files changed

Lines changed: 62 additions & 5 deletions

File tree

rectools/model_selection/last_n_split.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2023-2024 MTS (Mobile Telesystems)
1+
# Copyright 2025 MTS (Mobile Telesystems)
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.
@@ -12,13 +12,11 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
"""LastNSplitter."""
1615

1716
import typing as tp
1817

1918
import numpy as np
2019
import pandas as pd
21-
2220
from rectools import Columns
2321
from rectools.dataset import Interactions
2422
from rectools.model_selection.splitter import Splitter
@@ -103,8 +101,14 @@ def _split_without_filter(
103101
df = interactions.df
104102
idx = pd.RangeIndex(0, len(df))
105103

106-
# last event - rank=1
107-
inv_ranks = df.groupby(Columns.User)[Columns.Datetime].rank(method="first", ascending=False)
104+
# Here we guarantee that last appeared interaction in df will have lowest rank when datetime is not unique
105+
time_order = (
106+
df.groupby(Columns.User)[Columns.Datetime]
107+
.rank(method="first", ascending=True)
108+
.astype(int)
109+
)
110+
n_interactions = df.groupby(Columns.User).transform("size").astype(int)
111+
inv_ranks = n_interactions - time_order + 1
108112

109113
for i_split in range(self.n_splits)[::-1]:
110114
min_rank = i_split * self.n # excluded

tests/model_selection/test_last_n_split.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,59 @@ def _shuffle(values: tp.Sequence[int]) -> tp.List[int]:
4040
return sorted(inv_shuffle_arr[values])
4141

4242
return _shuffle
43+
@pytest.fixture
44+
def interactions_equal_timestamps(self, shuffle_arr: np.ndarray) -> Interactions:
45+
df = pd.DataFrame(
46+
[
47+
[1, 1, 1, "2021-09-01"], # 0
48+
[1, 2, 1, "2021-09-02"], # 1
49+
[1, 1, 1, "2021-09-03"], # 2
50+
[1, 2, 1, "2021-09-04"], # 3
51+
[1, 3, 1, "2021-09-05"], # 4
52+
[2, 3, 1, "2021-09-05"], # 5
53+
[2, 2, 1, "2021-08-20"], # 6
54+
[2, 2, 1, "2021-09-06"], # 7
55+
[3, 1, 1, "2021-09-05"], # 8
56+
[1, 6, 1, "2021-09-05"], # 9
57+
],
58+
columns=[Columns.User, Columns.Item, Columns.Weight, Columns.Datetime],
59+
).astype({Columns.Datetime: "datetime64[ns]"})
60+
return Interactions(df)
61+
62+
@pytest.mark.parametrize(
63+
"swap_targets,expected_test_ids, target_item",
64+
(
65+
(
66+
False,
67+
{9, 7, 8},
68+
6
69+
),
70+
(
71+
True,
72+
{9, 7, 8},
73+
3
74+
),
75+
),
76+
)
77+
def test_correct_last_interactions(
78+
self,
79+
interactions_equal_timestamps: Interactions,
80+
swap_targets: bool,
81+
expected_test_ids: tp.List[int],
82+
target_item: int,
83+
) -> None:
84+
# Do not using shuffle fixture, otherwise no valid answers
85+
interactions_et = interactions_equal_timestamps
86+
splitter = LastNSplitter(1, 1, False, False, False)
87+
if swap_targets:
88+
df_swap = interactions_equal_timestamps.df
89+
df_swap.iloc[[4,9]] = df_swap.iloc[[9,4]]
90+
interactions_et = Interactions(df_swap)
91+
loo_split = list(splitter.split(interactions_et, collect_fold_stats=True))
92+
target_ids = loo_split[0][1]
93+
assert set(target_ids) == expected_test_ids
94+
assert set(loo_split[0][0]) == set(range(len(interactions_et.df))) - expected_test_ids
95+
assert target_item in set(interactions_et.df.iloc[target_ids][Columns.Item])
4396

4497
@pytest.fixture
4598
def interactions(self, shuffle_arr: np.ndarray) -> Interactions:

0 commit comments

Comments
 (0)