
[ML] ROOT.Experimental.ML.CreatePyTorchGenerators only uses the first dataframe #21782

@toicca

Check duplicate issues.

  • Checked for duplicates

Description

I'm trying to implement a simple NN training with RDF as input, using ROOT.Experimental.ML.CreatePyTorchGenerators as shown in the tutorial. The function accepts a list of RDFs, so I pass it two, one with signal events and one with background events. During training, however, it seems that only the first RDF is used.

I built a reproducer for this based on the tutorial. Let me know if there's a mistake on my end or if this is actually a bug. The issue is fairly critical, since signal and background samples are often saved in separate files, which then have to be loaded as separate RDFs.

Thanks for the help!

Reproducer

import ROOT
import torch
import shutil
 
batch_size = 128
chunk_size = 5000
block_size = 300

file_name = str(ROOT.gROOT.GetTutorialDir()) + "/machine_learning/data/Higgs_data.root"

# Copy the file to the current directory to emulate having two files
shutil.copy(file_name, "./Higgs_data_bckg_dummy.root")

print("Creating signal RDF")
rdf_sig = ROOT.RDataFrame("sig_tree", file_name)

print("Creating background RDF")
rdf_bckg = ROOT.RDataFrame("bkg_tree", "./Higgs_data_bckg_dummy.root")

target = "Type"
 
print("Creating PyTorch generators")
gen_train, gen_validation = ROOT.Experimental.ML.CreatePyTorchGenerators(
    [rdf_bckg, rdf_sig],
    batch_size,
    chunk_size,
    block_size,
    target=target,
    validation_split=0.3,
    shuffle=True,
    drop_remainder=True,
)
 
input_columns = gen_train.train_columns 
num_features = len(input_columns)
print(f"Input columns: {input_columns}, number of features: {num_features}")
 
 
def calc_accuracy(targets, pred):
    return torch.sum(targets == pred.round()) / pred.size(0)
 
 
# Initialize PyTorch model
model = torch.nn.Sequential(
    torch.nn.Linear(num_features, 300),
    torch.nn.Tanh(),
    torch.nn.Linear(300, 300),
    torch.nn.Tanh(),
    torch.nn.Linear(300, 300),
    torch.nn.Tanh(),
    torch.nn.Linear(300, 1),
    torch.nn.Sigmoid(),
)
loss_fn = torch.nn.MSELoss(reduction="mean")
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
 
number_of_epochs = 2
 
number_of_bckg_events = 0
number_of_sig_events = 0
for epoch in range(number_of_epochs):
    # Named `epoch` so the batch loops below, which use `i`, do not shadow it
    print("Epoch ", epoch)
    print("Training")
    model.train()

    # Loop through the training set and train model
    for i, (x_train, y_train) in enumerate(gen_train):
        # Make prediction and calculate loss
        pred = model(x_train)
        loss = loss_fn(pred, y_train)
        print(f"Batch {i}, Loss: {loss.item()}")
        print(f"Predictions: {pred[:5].squeeze().tolist()}, Targets: {y_train[:5].squeeze().tolist()}")
        number_of_bckg_events += (y_train == 0).sum().item()
        number_of_sig_events += (y_train == 1).sum().item()
        print(f"Number of background events in this batch: {(y_train == 0).sum().item()}")
        print(f"Number of signal events in this batch: {(y_train == 1).sum().item()}")
        print(f"Total number of background events so far: {number_of_bckg_events}")
        print(f"Total number of signal events so far: {number_of_sig_events}")
        print(f"Number of signal / background events ratio in this batch: {(y_train == 1).sum().item() / (y_train == 0).sum().item() if (y_train == 0).sum().item() > 0 else 'N/A'}")
 
        # improve model
        model.zero_grad()
        loss.backward()
        optimizer.step()
 
        # Calculate accuracy
        accuracy = calc_accuracy(y_train, pred)
 
        print(f"Training => accuracy: {accuracy}")
 
    # #################################################################
    # # Validation
    # #################################################################
 
    model.eval()
    # Evaluate the model on the validation set
    for i, (x_val, y_val) in enumerate(gen_validation):
        # Make prediction and calculate accuracy
        pred = model(x_val)
        accuracy = calc_accuracy(y_val, pred)
 
        print(f"Validation => accuracy: {accuracy}")

# These are double counted over the epochs, but good enough for the demonstration
print(f"Total number of background events: {number_of_bckg_events}")
print(f"Total number of signal events: {number_of_sig_events}")
print(f"Overall number of signal / background events ratio: {number_of_sig_events / number_of_bckg_events if number_of_bckg_events > 0 else 'N/A'}")
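
The per-batch counting above can also be condensed into a small helper, which makes the check easier to repeat on other generators. The following is only a sketch: `tally_labels` and `fake_batches` are hypothetical names that are not part of ROOT, and it assumes each batch is a `(features, targets)` pair with a single target column, as the generators in the reproducer appear to yield.

```python
from collections import Counter


def tally_labels(batches):
    """Count how often each target value occurs across (features, targets) batches."""
    counts = Counter()
    for _, targets in batches:
        # torch tensors and NumPy arrays expose tolist(); plain lists pass through
        values = targets.tolist() if hasattr(targets, "tolist") else list(targets)
        for value in values:
            # Targets of shape (batch, 1) become nested lists; unwrap them.
            # Assumption: there is exactly one target column.
            while isinstance(value, list):
                value = value[0]
            counts[int(value)] += 1
    return counts


# Made-up stand-in batches, used here in place of gen_train
fake_batches = [
    ([[0.1], [0.2], [0.3]], [[0], [1], [0]]),
    ([[0.4], [0.5]], [[1], [1]]),
]
counts = tally_labels(fake_batches)
print(counts)
```

Assuming, as the reproducer does, that background is labelled 0 and signal 1, calling `tally_labels(gen_train)` should report both labels if both dataframes are read. A count for only one label would match the behaviour described in this issue.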

ROOT version

ROOT Version: 6.39.01
Built for linuxx8664gcc on Apr 01 2026, 22:52:11
From heads/master@v6-39-01-1692-g04365d12bb4

Installation method

LCG Nightlies

Operating system

EL9

Additional context

No response
