Skip to content

Commit 1df9fc1

Browse files
committed
[Python][ML] Add test for the jax output format
1 parent 47dd04d commit 1df9fc1

1 file changed

Lines changed: 85 additions & 0 deletions

File tree

bindings/pyroot/pythonizations/test/ml_dataloader.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1116,6 +1116,91 @@ def test16_vector_padding(self):
11161116
self.teardown_file(self.file_name3)
11171117
raise
11181118

1119+
def test17_JAX(self):
1120+
file_name = "multiple_target_columns.root"
1121+
1122+
ROOT.RDataFrame(10).Define("b1", "(Short_t) rdfentry_").Define("b2", "(UShort_t) b1 * b1").Define(
1123+
"b3", "(double) rdfentry_ * 10"
1124+
).Define("b4", "(double) b3 * 10").Snapshot("myTree", file_name)
1125+
1126+
try:
1127+
df = ROOT.RDataFrame("myTree", file_name)
1128+
1129+
dl = ROOT.Experimental.ML.RDataLoader(
1130+
df,
1131+
batch_size=3,
1132+
batches_in_memory=2,
1133+
target=["b2", "b4"],
1134+
weights="b3",
1135+
shuffle=False,
1136+
drop_remainder=False,
1137+
)
1138+
1139+
gen_train, gen_validation = dl.train_test_split(0.4)
1140+
1141+
results_x_train = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]
1142+
results_x_val = [6.0, 7.0, 8.0, 9.0]
1143+
results_y_train = [0.0, 0.0, 1.0, 100.0, 4.0, 200.0, 9.0, 300.0, 16.0, 400.0, 25.0, 500.0]
1144+
results_y_val = [36.0, 600.0, 49.0, 700.0, 64.0, 800.0, 81.0, 900.0]
1145+
results_z_train = [0.0, 10.0, 20.0, 30.0, 40.0, 50.0]
1146+
results_z_val = [60.0, 70.0, 80.0, 90.0]
1147+
1148+
collected_x_train = []
1149+
collected_x_val = []
1150+
collected_y_train = []
1151+
collected_y_val = []
1152+
collected_z_train = []
1153+
collected_z_val = []
1154+
1155+
iter_train = iter(gen_train.as_jax(device="cpu"))
1156+
iter_val = iter(gen_validation.as_jax())
1157+
1158+
for _ in range(self.n_train_batch):
1159+
x, y, z = next(iter_train)
1160+
self.assertTrue(x.shape == (3, 1))
1161+
self.assertTrue(y.shape == (3, 2))
1162+
self.assertTrue(z.shape == (3, 1))
1163+
collected_x_train.append(x.tolist())
1164+
collected_y_train.append(y.tolist())
1165+
collected_z_train.append(z.tolist())
1166+
1167+
for _ in range(self.n_val_batch):
1168+
x, y, z = next(iter_val)
1169+
self.assertTrue(x.shape == (3, 1))
1170+
self.assertTrue(y.shape == (3, 2))
1171+
self.assertTrue(z.shape == (3, 1))
1172+
collected_x_val.append(x.tolist())
1173+
collected_y_val.append(y.tolist())
1174+
collected_z_val.append(z.tolist())
1175+
1176+
x, y, z = next(iter_val)
1177+
self.assertTrue(x.shape == (self.val_remainder, 1))
1178+
self.assertTrue(y.shape == (self.val_remainder, 2))
1179+
self.assertTrue(z.shape == (self.val_remainder, 1))
1180+
collected_x_val.append(x.tolist())
1181+
collected_y_val.append(y.tolist())
1182+
collected_z_val.append(z.tolist())
1183+
1184+
flat_x_train = [x for xl in collected_x_train for xs in xl for x in xs]
1185+
flat_x_val = [x for xl in collected_x_val for xs in xl for x in xs]
1186+
flat_y_train = [y for yl in collected_y_train for ys in yl for y in ys]
1187+
flat_y_val = [y for yl in collected_y_val for ys in yl for y in ys]
1188+
flat_z_train = [z for zl in collected_z_train for zs in zl for z in zs]
1189+
flat_z_val = [z for zl in collected_z_val for zs in zl for z in zs]
1190+
1191+
self.assertEqual(results_x_train, flat_x_train)
1192+
self.assertEqual(results_x_val, flat_x_val)
1193+
self.assertEqual(results_y_train, flat_y_train)
1194+
self.assertEqual(results_y_val, flat_y_val)
1195+
self.assertEqual(results_z_train, flat_z_train)
1196+
self.assertEqual(results_z_val, flat_z_val)
1197+
1198+
self.teardown_file(file_name)
1199+
1200+
except:
1201+
self.teardown_file(file_name)
1202+
raise
1203+
11191204

11201205
class DataLoaderEagerLoading(unittest.TestCase):
11211206
file_name1 = "first_half.root"

0 commit comments

Comments
 (0)