Skip to content

Commit 87db451

Browse files
committed
[ML] Add tests for random train/val splitting
1 parent 64bbbd7 commit 87db451

1 file changed

Lines changed: 38 additions & 0 deletions

File tree

bindings/pyroot/pythonizations/test/ml_dataloader.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4570,6 +4570,44 @@ def test16_vector_padding(self):
45704570
self.teardown_file(self.file_name5)
45714571
raise
45724572

4573+
def test17_shuffled_split_varies_with_seed(self):
    """Check that two shuffled loaders with different seeds split differently.

    Two RDataLoader instances are built over the same pair of input files
    with identical arguments except for ``set_seed`` (42 vs. 43). For each
    loader the second generator returned by ``train_test_split(0.4)`` is
    drained via ``as_numpy()``, the feature arrays are flattened into one
    sorted list, and the two sorted lists are asserted to be unequal.
    """
    self.create_file1()
    self.create_file2()

    try:
        first_df = ROOT.RDataFrame(self.tree_name, self.file_name1)
        second_df = ROOT.RDataFrame(self.tree_name, self.file_name2)

        # Build both loaders before splitting either one; only the seed differs.
        loaders = [
            ROOT.Experimental.ML.RDataLoader(
                [first_df, second_df],
                batch_size=3,
                target="b2",
                shuffle=True,
                drop_remainder=False,
                set_seed=seed,
            )
            for seed in (42, 43)
        ]

        # Keep the full (first, second) generator pairs referenced so that
        # neither generator of a split is released while the other is in use.
        splits = [loader.train_test_split(0.4) for loader in loaders]

        collected = []
        for _train_gen, val_gen in splits:
            values = []
            for features, _labels in val_gen.as_numpy():
                values.extend(features.flatten().tolist())
            # Sorting makes the comparison independent of batch ordering:
            # only the *set* of selected values matters.
            values.sort()
            collected.append(values)

        self.assertNotEqual(collected[0], collected[1])

    finally:
        self.teardown_file(self.file_name1)
        self.teardown_file(self.file_name2)
4610+
45734611

45744612
class DataLoaderRandomUndersampling(unittest.TestCase):
45754613
file_name1 = "major.root"

0 commit comments

Comments
 (0)