@@ -40,6 +40,7 @@ def _shuffle(values: tp.Sequence[int]) -> tp.List[int]:
4040 return sorted (inv_shuffle_arr [values ])
4141
4242 return _shuffle
43+
4344 @pytest .fixture
4445 def interactions_equal_timestamps (self , shuffle_arr : np .ndarray ) -> Interactions :
4546 df = pd .DataFrame (
@@ -62,37 +63,29 @@ def interactions_equal_timestamps(self, shuffle_arr: np.ndarray) -> Interactions
6263 @pytest .mark .parametrize (
6364 "swap_targets,expected_test_ids, target_item" ,
6465 (
65- (
66- False ,
67- {9 , 7 , 8 },
68- 6
69- ),
70- (
71- True ,
72- {9 , 7 , 8 },
73- 3
74- ),
66+ (False , {9 , 7 , 8 }, 6 ),
67+ (True , {9 , 7 , 8 }, 3 ),
7568 ),
7669 )
7770 def test_correct_last_interactions (
7871 self ,
7972 interactions_equal_timestamps : Interactions ,
8073 swap_targets : bool ,
81- expected_test_ids : tp .List [int ],
74+ expected_test_ids : tp .Set [int ],
8275 target_item : int ,
8376 ) -> None :
8477 # Do not using shuffle fixture, otherwise no valid answers
8578 interactions_et = interactions_equal_timestamps
8679 splitter = LastNSplitter (1 , 1 , False , False , False )
8780 if swap_targets :
8881 df_swap = interactions_equal_timestamps .df
89- df_swap .iloc [[4 ,9 ]] = df_swap .iloc [[9 ,4 ]]
82+ df_swap .iloc [[4 , 9 ]] = df_swap .iloc [[9 , 4 ]]
9083 interactions_et = Interactions (df_swap )
9184 loo_split = list (splitter .split (interactions_et , collect_fold_stats = True ))
9285 target_ids = loo_split [0 ][1 ]
9386 assert set (target_ids ) == expected_test_ids
94- assert set (loo_split [0 ][0 ]) == set (range (len (interactions_et .df ))) - expected_test_ids
95- assert target_item in set (interactions_et .df .iloc [target_ids ][Columns .Item ])
87+ assert set (loo_split [0 ][0 ]) == set (range (len (interactions_et .df ))) - expected_test_ids
88+ assert target_item in set (interactions_et .df .iloc [target_ids ][Columns .Item ])
9689
9790 @pytest .fixture
9891 def interactions (self , shuffle_arr : np .ndarray ) -> Interactions :
0 commit comments