Skip to content

Commit e5f4f45

Browse files
sbeinnsmith-
andauthored
Fix lwtnn schema validation and add FastSim scale factor example (#343)
* Accept LWTNN nodes in schema validation * Simplify tests / examples --------- Co-authored-by: Nick Smith <nick.smith@cern.ch>
1 parent 309971a commit e5f4f45

4 files changed

Lines changed: 777 additions & 591 deletions

File tree

src/correctionlib/schemav2.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -319,7 +319,7 @@ def validate_content(cls, content: list[CategoryItem]) -> list[CategoryItem]:
319319
def walk_content(content: Content, func: Callable[[Content], None]) -> None:
320320
"""Visit all content nodes in a tree, applying func to each node."""
321321
func(content)
322-
if isinstance(content, (float, Formula, FormulaRef, HashPRNG)):
322+
if isinstance(content, (float, Formula, FormulaRef, HashPRNG, LWTNN)):
323323
pass
324324
elif isinstance(content, (Binning, MultiBinning)):
325325
for bin in content.content:
@@ -354,6 +354,12 @@ def _validate_input(allowed_names: set[str], node: Content) -> None:
354354
if inp not in allowed_names:
355355
msg = f"{nodename} input {inp!r} not found in Correction inputs {allowed_names}"
356356
raise ValueError(msg)
357+
elif isinstance(node, LWTNN):
358+
for inp in node.opaque.get("inputs", []):
359+
name = inp.get("name")
360+
if name not in allowed_names:
361+
msg = f"{nodename} input {name!r} not found in Correction inputs {allowed_names}"
362+
raise ValueError(msg)
357363
# FormulaRef has no direct input names
358364

359365

src/lwtnn_demo.cc

Lines changed: 7 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,33 +4,25 @@
44

55
#include "correction.h"
66

7-
static constexpr double PT_EPS = 1e-4;
8-
static constexpr double ISO_EPS = 1e-6;
9-
10-
static double safe_log10(double x, double eps) {
11-
return std::log10(std::max(x, eps));
12-
}
13-
14-
int main(int argc, char** argv) {
15-
if (argc != 2) {
7+
int main(int argc, char **argv)
8+
{
9+
if (argc != 2)
10+
{
1611
std::cerr << "Usage: " << argv[0] << " lwtnn_correction.json\n";
1712
return 1;
1813
}
1914

2015
const std::string json_path = argv[1];
2116
auto correction_set = correction::CorrectionSet::from_file(json_path);
22-
auto correction = correction_set->at("electron_sf");
17+
auto correction = correction_set->at("electron_fastsim_sf");
2318

2419
// Mock "GEN-matched" electron values (replace with real NanoAOD lookup later)
25-
const double gen_pt = 15.0;
20+
const double gen_pt = 15.0;
2621
const double gen_eta = 0.4;
2722
const double gen_phi = 2.1;
2823
const double gen_iso = 1e-3;
2924

30-
const double pt_log10 = safe_log10(gen_pt, PT_EPS);
31-
const double iso_log10 = safe_log10(gen_iso, ISO_EPS);
32-
33-
double sf = correction->evaluate({pt_log10, gen_eta, gen_phi, iso_log10});
25+
double sf = correction->evaluate({gen_pt, gen_eta, gen_phi, gen_iso});
3426

3527
std::cout << std::setprecision(17);
3628
std::cout << "sf_fullOverFast " << sf << "\n";

0 commit comments

Comments
 (0)