Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 66 additions & 22 deletions src/correction.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include <rapidjson/error/en.h>
#include <optional>
#include <algorithm>
#include <deque>
#include <stdexcept>
#include <cmath>
#include <cstdlib> // std::abort
Expand Down Expand Up @@ -72,6 +73,44 @@ namespace {
const std::vector<Variable::Type>& values;
};

// Per-thread scratch storage for Transform::evaluate.
//
// Every live TransformScratch borrows one vector from a thread_local pool,
// indexed by the current nesting depth, so a Transform that is evaluated
// while another Transform on the same thread is still running gets a buffer
// of its own. Slots are never removed from the pool; releasing one only
// lowers the depth counter, so the same vectors are handed out again later.
class TransformScratch {
  public:
    // Claims the next free slot and overwrites it with a copy of `values`.
    //
    // The destructor does not run for an object whose constructor throws, so
    // a failed copy has to hand the slot back here; otherwise depth_ would
    // stay one too high for the rest of the thread's lifetime.
    explicit TransformScratch(const std::vector<Variable::Type>& values):
      slot_(acquire_slot())
    {
      try {
        slot_ = values;
      } catch (...) {
        release_slot();
        throw;
      }
    }

    // Returns the slot to the pool by lowering the depth counter.
    ~TransformScratch() {
      release_slot();
    }

    // One instance accounts for exactly one depth_ increment; a copy would
    // release the same slot a second time from its own destructor.
    TransformScratch(const TransformScratch&) = delete;
    TransformScratch& operator=(const TransformScratch&) = delete;

    // Mutable access to the buffer owned by this scope.
    std::vector<Variable::Type>& values() {
      return slot_;
    }

  private:
    // Hands out the slot for the current depth, growing the pool on demand.
    // A deque is used because emplace_back leaves references to existing
    // elements valid, so slots held by enclosing scopes are not disturbed.
    static std::vector<Variable::Type>& acquire_slot() {
      if (depth_ == slots_.size()) {
        slots_.emplace_back();
      }
      return slots_[depth_++];
    }

    // Marks the most recently acquired slot as free again.
    static void release_slot() {
      --depth_;
    }

    std::vector<Variable::Type>& slot_;
    static thread_local std::deque<std::vector<Variable::Type>> slots_;
    static thread_local std::size_t depth_;
};

thread_local std::deque<std::vector<Variable::Type>> TransformScratch::slots_;
thread_local std::size_t TransformScratch::depth_ = 0;

std::size_t find_bin_idx(Variable::Type value_variant,
const detail::EdgesType &bins_,
const detail::FlowBehavior &flow,
Expand Down Expand Up @@ -110,7 +149,7 @@ namespace {

// otherwise we have non-uniform binning
using namespace std::string_literals;
const auto bins = std::get<detail::NonUniformBins>(bins_);
const auto& bins = std::get<detail::NonUniformBins>(bins_);
if ( flow == detail::FlowBehavior::wrap ) {
double low = bins[0];
double high = bins[bins.size() - 1];
Expand Down Expand Up @@ -309,7 +348,9 @@ Transform::Transform(const JSONObject& json, const Correction& context) {
}

double Transform::evaluate(const std::vector<Variable::Type>& values) const {
std::vector<Variable::Type> new_values(values);
TransformScratch scratch(values);
auto& new_values = scratch.values();

double vnew = std::visit(node_evaluate{values}, *rule_);
auto& v = new_values[variableIdx_];
if ( std::holds_alternative<double>(v) ) {
Expand Down Expand Up @@ -558,27 +599,25 @@ Category::Category(const JSONObject& json, const Correction& context)
double Category::evaluate(const std::vector<Variable::Type>& values) const {
const Content* child = nullptr;
if ( auto pval = std::get_if<std::string>(&values[variableIdx_]) ) {
try {
child = &std::get<StrMap>(map_).at(*pval);
} catch (std::out_of_range& ex) {
if ( default_ ) {
child = default_.get();
}
else {
throw std::out_of_range("Index not available in Category for input argument " + std::to_string(variableIdx_) + " val: " + *pval);
}
const auto& m = std::get<StrMap>(map_);
auto it = m.find(*pval);
if ( it != m.end() ) {
child = &it->second;
} else if ( default_ ) {
child = default_.get();
} else {
throw std::out_of_range("Index not available in Category for input argument " + std::to_string(variableIdx_) + " val: " + *pval);
}
}
else if ( auto pval = std::get_if<int64_t>(&values[variableIdx_]) ) {
try {
child = &std::get<IntMap>(map_).at(*pval);
} catch (std::out_of_range& ex) {
if ( default_ ) {
child = default_.get();
}
else {
throw std::out_of_range("Index not available in Category for input argument " + std::to_string(variableIdx_) + " val: " + std::to_string(*pval));
}
const auto& m = std::get<IntMap>(map_);
auto it = m.find(*pval);
if ( it != m.end() ) {
child = &it->second;
} else if ( default_ ) {
child = default_.get();
} else {
throw std::out_of_range("Index not available in Category for input argument " + std::to_string(variableIdx_) + " val: " + std::to_string(*pval));
}
} else {
throw std::runtime_error("Invalid variable type");
Expand Down Expand Up @@ -682,16 +721,21 @@ size_t CompoundCorrection::input_index(const std::string_view name) const {
}

double CompoundCorrection::evaluate(const std::vector<Variable::Type>& values) const {
// Per-thread scratch storage. This call site is not re-entrant so we
// can use a simpler implementation than for TransformScratch
static thread_local std::vector<Variable::Type> ivalues;
static thread_local std::vector<Variable::Type> cvalues;

if ( values.size() != inputs_.size() ) {
throw std::invalid_argument("Incorrect number of inputs (got " + std::to_string(values.size())
+ ", expected " + std::to_string(inputs_.size()) + ")");
}
for (size_t i=0; i < inputs_.size(); ++i) {
inputs_[i].validate(values[i]);
}
std::vector<Variable::Type> ivalues(values);
std::vector<Variable::Type> cvalues;
ivalues = values;
cvalues.reserve(values.size());

double out = 0.;
double sf = 0.;
bool start{true};
Expand Down
70 changes: 70 additions & 0 deletions tests/test_transform_nested.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import correctionlib._core as core
from correctionlib import schemav2 as schema


def wrap(*corrs):
    """Bundle the given schema corrections into a CorrectionSet and load it via ``core``.

    The set is serialized to JSON with ``model_dump_json`` and parsed back
    with ``core.CorrectionSet.from_string``; the parsed set is returned.
    """
    payload = schema.CorrectionSet(
        schema_version=schema.VERSION,
        corrections=list(corrs),
    ).model_dump_json()
    return core.CorrectionSet.from_string(payload)


def test_transform_nested_rule_and_content():
    """Evaluate a tree with transforms nested in both rule and content positions.

    Exercises recursive evaluation and scratch-buffer reuse.
    """

    def sum_xy():
        # A fresh "x + y" formula node for every place the tree needs one.
        return schema.Formula(
            nodetype="formula",
            expression="x + y",
            parser="TFormula",
            variables=["x", "y"],
        )

    def transform(input_name, rule, content):
        return schema.Transform(
            nodetype="transform",
            input=input_name,
            rule=rule,
            content=content,
        )

    # Outer rule: replace y by 1, then evaluate x + y.
    outer_rule = transform("y", 1.0, sum_xy())
    # Rule of the content transform: replace x by 2, then evaluate x + y.
    content_rule = transform("x", 2.0, sum_xy())
    # Outer content: replace y by the value of content_rule, then evaluate x + y.
    outer_content = transform("y", content_rule, sum_xy())

    correction = schema.Correction(
        name="nested_transform",
        version=2,
        inputs=[
            schema.Variable(name="x", type="real"),
            schema.Variable(name="y", type="real"),
        ],
        output=schema.Variable(name="out", type="real"),
        data=transform("x", outer_rule, outer_content),
    )

    corr = wrap(correction)["nested_transform"]

    # With x=3, y=4:
    #   outer rule:    y -> 1, x + y = 3 + 1 = 4, so x becomes 4
    #   content rule:  x -> 2, x + y = 2 + 4 = 6, so y becomes 6
    #   final content: x + y = 4 + 6 = 10
    assert corr.evaluate(3.0, 4.0) == 10.0
Loading