File tree Expand file tree Collapse file tree
bindings/pyroot/pythonizations/test Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -4570,6 +4570,44 @@ def test16_vector_padding(self):
45704570 self .teardown_file (self .file_name5 )
45714571 raise
45724572
def test17_shuffled_split_varies_with_seed(self):
    """Check that the validation split of a shuffled loader depends on the seed.

    Two loaders are built over the same pair of data frames with identical
    settings except for ``set_seed`` (42 vs 43). The values yielded by
    their validation generators are flattened and sorted, so the
    comparison ignores ordering and only fails when both seeds produce
    the same multiset of validation values.
    """
    self.create_file1()
    self.create_file2()

    try:
        first_df = ROOT.RDataFrame(self.tree_name, self.file_name1)
        second_df = ROOT.RDataFrame(self.tree_name, self.file_name2)

        def build_loader(seed):
            # Everything except the seed is held constant across loaders.
            return ROOT.Experimental.ML.RDataLoader(
                [first_df, second_df],
                batch_size=3,
                target="b2",
                shuffle=True,
                drop_remainder=False,
                set_seed=seed,
            )

        def sorted_values(generator):
            # Flatten the first element of every batch into one sorted list.
            collected = []
            for batch_x, _batch_y in generator.as_numpy():
                collected.extend(batch_x.flatten().tolist())
            collected.sort()
            return collected

        loader_a = build_loader(42)
        loader_b = build_loader(43)

        # Only the second element (validation generator) of each split is used.
        _, validation_a = loader_a.train_test_split(0.4)
        _, validation_b = loader_b.train_test_split(0.4)

        values_a = sorted_values(validation_a)
        values_b = sorted_values(validation_b)

        self.assertNotEqual(values_a, values_b)

    finally:
        # Remove the ROOT files whether or not the assertion passed.
        self.teardown_file(self.file_name1)
        self.teardown_file(self.file_name2)
4610+
45734611
45744612class DataLoaderRandomUndersampling (unittest .TestCase ):
45754613 file_name1 = "major.root"
You can’t perform that action at this time.
0 commit comments