Skip to content

Commit 8280289

Browse files
committed
Add testing of float with individual epsilon
1 parent fa33561 commit 8280289

2 files changed

Lines changed: 22 additions & 13 deletions

File tree

demo/node/rntuple_test.cxx

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,14 @@ void rntuple_test()
4141
auto IntField = model->MakeField<int>("IntField");
4242
auto FloatField = model->MakeField<float>("FloatField");
4343
auto Float16Field = model->MakeField<float>("Float16Field");
44+
model->GetMutableField("Float16Field").SetColumnRepresentatives({{ROOT::ENTupleColumnType::kReal16}});
45+
46+
auto Real32Trunc = model->MakeField<float>("Real32Trunc");
47+
dynamic_cast<ROOT::RRealField<float> &>(model->GetMutableField("Real32Trunc")).SetTruncated(20);
48+
49+
auto Real32Quant = model->MakeField<float>("Real32Quant");
50+
dynamic_cast<ROOT::RRealField<float> &>(model->GetMutableField("Real32Quant")).SetQuantized(0., 1., 14);
51+
4452
auto DoubleField = model->MakeField<double>("DoubleField");
4553
auto StringField = model->MakeField<std::string>("StringField");
4654
auto BoolField = model->MakeField<bool>("BoolField");
@@ -57,11 +65,6 @@ void rntuple_test()
5765
auto MapIntDouble = model->MakeField<std::map<int,double>>("MapIntDouble");
5866
auto MapStringBool = model->MakeField<std::map<std::string,bool>>("MapStringBool");
5967

60-
for (auto &f : model->GetMutableFieldZero()) {
61-
if (f.GetTypeName() == "Float16Field")
62-
f.SetColumnRepresentatives({{ROOT::ENTupleColumnType::kReal16}});
63-
}
64-
6568
// We hand-over the data model to a newly created ntuple of name "F", stored in kNTupleFileName
6669
// In return, we get a unique pointer to an ntuple that we can fill
6770
auto writer = ROOT::RNTupleWriter::Recreate(std::move(model), "Data", kNTupleFileName);
@@ -70,7 +73,11 @@ void rntuple_test()
7073

7174
*IntField = i;
7275
*FloatField = i*i;
73-
*Float16Field = 0.1987333 * i;
76+
77+
*Float16Field = 0.1987333 * i; // stored as 16 bits float
78+
*Real32Trunc = 123.45 * i; // here only 20 bits preserved
79+
*Real32Quant = 0.03 * (i % 30); // value should be inside [0..1]
80+
7481
*DoubleField = 0.5 * i;
7582
*StringField = "entry_" + std::to_string(i);
7683
*BoolField = (i % 3 == 1);
@@ -117,8 +124,6 @@ void rntuple_test()
117124
}
118125
Vect2Float->emplace_back(vf);
119126
Vect2Bool->emplace_back(vb);
120-
121-
122127
}
123128

124129
writer->Fill();

demo/node/rntuple_test.js

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -79,26 +79,29 @@ else {
7979
// Setup selector to process all fields (so cluster gets loaded)
8080
const selector = new TSelector(),
8181
fields = ['IntField', 'FloatField', 'DoubleField',
82-
'Float16Field',
82+
'Float16Field', 'Real32Trunc',
8383
'StringField', 'BoolField',
8484
'ArrayInt', 'VariantField', 'TupleField',
8585
'VectString', 'VectInt', 'VectBool', 'Vect2Float', 'Vect2Bool', 'MultisetField',
86-
'MapStringFloat', 'MapIntDouble', 'MapStringBool'];
86+
'MapStringFloat', 'MapIntDouble', 'MapStringBool'],
87+
epsilonValues = { Real32Trunc: 0.5, Float16Field: 1e-2 };
88+
8789
for (const f of fields)
8890
selector.addBranch(f);
8991

9092
selector.Begin = () => {
9193
console.log('\nBegin processing to load cluster data...');
9294
};
9395

96+
9497
// Now validate entry data
9598
const EPSILON = 1e-7;
9699

97100
let any_error = false;
98101

99-
function compare(expected, value) {
102+
function compare(expected, value, eps) {
100103
if (typeof expected === 'number')
101-
return Math.abs(value - expected) < EPSILON;
104+
return Math.abs(value - expected) < (eps ?? EPSILON);
102105
if (typeof expected === 'object') {
103106
if (expected.length !== undefined) {
104107
if (expected.length !== value.length)
@@ -126,6 +129,7 @@ selector.Process = function(entryIndex) {
126129
FloatField: entryIndex * entryIndex,
127130
DoubleField: entryIndex * 0.5,
128131
Float16Field: entryIndex * 0.1987333,
132+
Real32Trunc: 123.45 * entryIndex,
129133
StringField: `entry_${entryIndex}`,
130134
BoolField: entryIndex % 3 === 1,
131135
ArrayInt: [entryIndex + 1, entryIndex + 2, entryIndex + 3, entryIndex + 4, entryIndex + 5],
@@ -175,7 +179,7 @@ selector.Process = function(entryIndex) {
175179
const value = this.tgtobj[field],
176180
expected = expectedValues[field];
177181

178-
if (!compare(expected, value)) {
182+
if (!compare(expected, value, epsilonValues[field])) {
179183
console.error(`FAILURE: ${field} at entry ${entryIndex} expected ${JSON.stringify(expected)}, got ${JSON.stringify(value)}`);
180184
any_error = true;
181185
} else

0 commit comments

Comments
 (0)