[ML] ROOT.Experimental.ML.CreatePyTorchGenerators only uses the first dataframe #21782
Check duplicate issues.
- Checked for duplicates
Description
I am trying to implement a simple NN training with RDataFrames (RDFs) as input, using ROOT.Experimental.ML.CreatePyTorchGenerators as shown in the tutorial. The function takes a list of RDFs, so I passed two: one with signal events and one with background events. During training, however, it seems that only the first RDF is used.
I built a reproducer for this based on the tutorial. Let me know if there is a mistake on my end, or if this is actually a bug. The issue is fairly critical: signal and background samples are often saved in separate files, which requires loading them as separate RDFs.
Thanks for the help!
Reproducer
import ROOT
import torch
import shutil
batch_size = 128
chunk_size = 5000
block_size = 300
file_name = str(ROOT.gROOT.GetTutorialDir()) + "/machine_learning/data/Higgs_data.root"
# Copy the file to the current directory to emulate having two files
shutil.copy(file_name, "./Higgs_data_bckg_dummy.root")
print("Creating signal RDF")
rdf_sig = ROOT.RDataFrame("sig_tree", file_name)
print("Creating background RDF")
rdf_bckg = ROOT.RDataFrame("bkg_tree", "./Higgs_data_bckg_dummy.root")
target = "Type"
print("Creating PyTorch generators")
gen_train, gen_validation = ROOT.Experimental.ML.CreatePyTorchGenerators(
    [rdf_bckg, rdf_sig],
    batch_size,
    chunk_size,
    block_size,
    target=target,
    validation_split=0.3,
    shuffle=True,
    drop_remainder=True,
)
input_columns = gen_train.train_columns
num_features = len(input_columns)
print(f"Input columns: {input_columns}, number of features: {num_features}")
def calc_accuracy(targets, pred):
    return torch.sum(targets == pred.round()) / pred.size(0)
# Initialize PyTorch model
model = torch.nn.Sequential(
    torch.nn.Linear(num_features, 300),
    torch.nn.Tanh(),
    torch.nn.Linear(300, 300),
    torch.nn.Tanh(),
    torch.nn.Linear(300, 300),
    torch.nn.Tanh(),
    torch.nn.Linear(300, 1),
    torch.nn.Sigmoid(),
)
loss_fn = torch.nn.MSELoss(reduction="mean")
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
number_of_epochs = 2
number_of_bckg_events = 0
number_of_sig_events = 0
for i in range(number_of_epochs):
    print("Epoch ", i)
    print("Training")
    model.train()
    # Loop through the training set and train model
    for i, (x_train, y_train) in enumerate(gen_train):
        # Make prediction and calculate loss
        pred = model(x_train)
        loss = loss_fn(pred, y_train)
        print(f"Batch {i}, Loss: {loss.item()}")
        print(f"Predictions: {pred[:5].squeeze().tolist()}, Targets: {y_train[:5].squeeze().tolist()}")
        number_of_bckg_events += (y_train == 0).sum().item()
        number_of_sig_events += (y_train == 1).sum().item()
        print(f"Number of background events in this batch: {(y_train == 0).sum().item()}")
        print(f"Number of signal events in this batch: {(y_train == 1).sum().item()}")
        print(f"Total number of background events so far: {number_of_bckg_events}")
        print(f"Total number of signal events so far: {number_of_sig_events}")
        print(f"Number of signal / background events ratio in this batch: {(y_train == 1).sum().item() / (y_train == 0).sum().item() if (y_train == 0).sum().item() > 0 else 'N/A'}")
        # improve model
        model.zero_grad()
        loss.backward()
        optimizer.step()
        # Calculate accuracy
        accuracy = calc_accuracy(y_train, pred)
        print(f"Training => accuracy: {accuracy}")
    # #################################################################
    # # Validation
    # #################################################################
    model.eval()
    # Evaluate the model on the validation set
    for i, (x_val, y_val) in enumerate(gen_validation):
        # Make prediction and calculate accuracy
        pred = model(x_val)
        accuracy = calc_accuracy(y_val, pred)
        print(f"Validation => accuracy: {accuracy}")
# These are double counted over the epochs, but good enough for the demonstration
print(f"Total number of background events: {number_of_bckg_events}")
print(f"Total number of signal events: {number_of_sig_events}")
print(f"Overall number of signal / background events ratio: {number_of_sig_events / number_of_bckg_events if number_of_bckg_events > 0 else 'N/A'}")
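For a quicker check that does not involve the model, the per-batch label counting above can be pulled out into a small helper. This is only a sketch: `count_labels` and `_flatten` are hypothetical names that are not part of ROOT or PyTorch, and the helper assumes the generator yields `(features, targets)` pairs as in the training loop above. The fake batches use plain lists so that the snippet runs without ROOT or torch.

```python
from collections import Counter


def _flatten(values):
    """Yield scalars from an arbitrarily nested list/tuple (hypothetical helper)."""
    for value in values:
        if isinstance(value, (list, tuple)):
            yield from _flatten(value)
        else:
            yield value


def count_labels(batches):
    """Tally target labels over an iterable of (features, targets) batches.

    Hypothetical helper, not a ROOT API. Assumes targets are 0/1 labels.
    """
    totals = Counter()
    for _, targets in batches:
        # Tensors and arrays expose flatten()/tolist(); plain lists do not.
        if hasattr(targets, "flatten"):
            targets = targets.flatten().tolist()
        totals.update(int(round(label)) for label in _flatten(targets))
    return totals


# Stand-in batches with plain lists: three background (0) and one signal (1) label.
fake_batches = [
    ([[0.1], [0.2]], [[0.0], [0.0]]),
    ([[0.3], [0.4]], [[0.0], [1.0]]),
]
counts = count_labels(fake_batches)
print(f"background: {counts[0]}, signal: {counts[1]}")
```

In the reproducer this would be called as `count_labels(gen_train)`, which iterates over one full pass of the training generator. If only one of the two labels shows up in the result, that is consistent with only one of the dataframes being read.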
ROOT version
ROOT Version: 6.39.01
Built for linuxx8664gcc on Apr 01 2026, 22:52:11
From heads/master@v6-39-01-1692-g04365d12bb4
Installation method
LCG Nightlies
Operating system
EL9
Additional context
No response