@@ -1116,6 +1116,91 @@ def test16_vector_padding(self):
11161116 self .teardown_file (self .file_name3 )
11171117 raise
11181118
1119+ def test17_JAX (self ):
1120+ file_name = "multiple_target_columns.root"
1121+
1122+ ROOT .RDataFrame (10 ).Define ("b1" , "(Short_t) rdfentry_" ).Define ("b2" , "(UShort_t) b1 * b1" ).Define (
1123+ "b3" , "(double) rdfentry_ * 10"
1124+ ).Define ("b4" , "(double) b3 * 10" ).Snapshot ("myTree" , file_name )
1125+
1126+ try :
1127+ df = ROOT .RDataFrame ("myTree" , file_name )
1128+
1129+ dl = ROOT .Experimental .ML .RDataLoader (
1130+ df ,
1131+ batch_size = 3 ,
1132+ batches_in_memory = 2 ,
1133+ target = ["b2" , "b4" ],
1134+ weights = "b3" ,
1135+ shuffle = False ,
1136+ drop_remainder = False ,
1137+ )
1138+
1139+ gen_train , gen_validation = dl .train_test_split (0.4 )
1140+
1141+ results_x_train = [0.0 , 1.0 , 2.0 , 3.0 , 4.0 , 5.0 ]
1142+ results_x_val = [6.0 , 7.0 , 8.0 , 9.0 ]
1143+ results_y_train = [0.0 , 0.0 , 1.0 , 100.0 , 4.0 , 200.0 , 9.0 , 300.0 , 16.0 , 400.0 , 25.0 , 500.0 ]
1144+ results_y_val = [36.0 , 600.0 , 49.0 , 700.0 , 64.0 , 800.0 , 81.0 , 900.0 ]
1145+ results_z_train = [0.0 , 10.0 , 20.0 , 30.0 , 40.0 , 50.0 ]
1146+ results_z_val = [60.0 , 70.0 , 80.0 , 90.0 ]
1147+
1148+ collected_x_train = []
1149+ collected_x_val = []
1150+ collected_y_train = []
1151+ collected_y_val = []
1152+ collected_z_train = []
1153+ collected_z_val = []
1154+
1155+ iter_train = iter (gen_train .as_jax (device = "cpu" ))
1156+ iter_val = iter (gen_validation .as_jax ())
1157+
1158+ for _ in range (self .n_train_batch ):
1159+ x , y , z = next (iter_train )
1160+ self .assertTrue (x .shape == (3 , 1 ))
1161+ self .assertTrue (y .shape == (3 , 2 ))
1162+ self .assertTrue (z .shape == (3 , 1 ))
1163+ collected_x_train .append (x .tolist ())
1164+ collected_y_train .append (y .tolist ())
1165+ collected_z_train .append (z .tolist ())
1166+
1167+ for _ in range (self .n_val_batch ):
1168+ x , y , z = next (iter_val )
1169+ self .assertTrue (x .shape == (3 , 1 ))
1170+ self .assertTrue (y .shape == (3 , 2 ))
1171+ self .assertTrue (z .shape == (3 , 1 ))
1172+ collected_x_val .append (x .tolist ())
1173+ collected_y_val .append (y .tolist ())
1174+ collected_z_val .append (z .tolist ())
1175+
1176+ x , y , z = next (iter_val )
1177+ self .assertTrue (x .shape == (self .val_remainder , 1 ))
1178+ self .assertTrue (y .shape == (self .val_remainder , 2 ))
1179+ self .assertTrue (z .shape == (self .val_remainder , 1 ))
1180+ collected_x_val .append (x .tolist ())
1181+ collected_y_val .append (y .tolist ())
1182+ collected_z_val .append (z .tolist ())
1183+
1184+ flat_x_train = [x for xl in collected_x_train for xs in xl for x in xs ]
1185+ flat_x_val = [x for xl in collected_x_val for xs in xl for x in xs ]
1186+ flat_y_train = [y for yl in collected_y_train for ys in yl for y in ys ]
1187+ flat_y_val = [y for yl in collected_y_val for ys in yl for y in ys ]
1188+ flat_z_train = [z for zl in collected_z_train for zs in zl for z in zs ]
1189+ flat_z_val = [z for zl in collected_z_val for zs in zl for z in zs ]
1190+
1191+ self .assertEqual (results_x_train , flat_x_train )
1192+ self .assertEqual (results_x_val , flat_x_val )
1193+ self .assertEqual (results_y_train , flat_y_train )
1194+ self .assertEqual (results_y_val , flat_y_val )
1195+ self .assertEqual (results_z_train , flat_z_train )
1196+ self .assertEqual (results_z_val , flat_z_val )
1197+
1198+ self .teardown_file (file_name )
1199+
1200+ except :
1201+ self .teardown_file (file_name )
1202+ raise
1203+
11191204
11201205class DataLoaderEagerLoading (unittest .TestCase ):
11211206 file_name1 = "first_half.root"
0 commit comments