Skip to content

Commit 3eed55d

Browse files
committed
branch: clarifications and more tests
1 parent d51ccd3 commit 3eed55d

File tree

4 files changed

+237
-2
lines changed

4 files changed

+237
-2
lines changed

docs/source/symbols.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,12 @@ These can be differentiated and their generated code will feature actual `if-els
5959
Operations such as `min(), max()` use `branch` internally.
6060
Note that the use of `branch` prevents emitting SIMD code.
6161
62+
**Important:** The condition passed to `branch` is a `Scalar` whose **sign** determines which branch executes.
63+
When you write `a > b`, SymX internally stores `a - b` as the condition scalar.
64+
Opposite for `a < b`.
65+
The true branch executes when that condition scalar is **strictly positive** (`> 0`), exact zero will then fall to the false branch.
66+
Due to the floating point nature of the comparison, the user might consider writing activation mechanisms with conditions that are `+1` and `-1` (instead of `0`) to more clearly indicate sign.
67+
6268
## `Vector` and `Matrix`
6369
`Vector` and `Matrix` are simply a list of `Scalar`s that provide typical algebraic operator overloads.
6470
Further, `Vector` provides `norm(), dot(), cross3()` and such, while `Matrix` provides `det(), inv(), trace()` etc.

symx/src/compile/Compilation.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -448,7 +448,7 @@ void symx::Compilation::_add_instructions_scalar(std::string& code, Sequence& se
448448
code += tab() + "}\n";
449449
}
450450
if (op.is_positive_branch()) {
451-
code += tab() + "if (" + idx(op.cond) + " >= 0.0)\n";
451+
code += tab() + "if (" + idx(op.cond) + " > 0.0)\n";
452452
code += tab() + "{\n";
453453
indentation++;
454454
}

symx/src/symbol/Expressions.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@ double Expressions::eval(int32_t expr_id) const
188188
return std::atan(this->eval(expr.a)); break;
189189

190190
case ExprType::Branch:
191-
if (this->eval(expr.cond) >= 0) {
191+
if (this->eval(expr.cond) > 0) {
192192
return this->eval(expr.a);
193193
}
194194
else {

tests/test_scalar.cpp

Lines changed: 229 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -355,6 +355,235 @@ TEST_CASE("Branches", "[scalar]")
355355
}
356356
}
357357

358+
TEST_CASE("Branch boundary conditions", "[scalar]")
359+
{
360+
const double eps = std::numeric_limits<double>::epsilon();
361+
362+
// -------------------------------------------------------------------
363+
// 1. Direct-scalar selector: branch(c, T, F)
364+
// True branch fires when c > 0 (strictly).
365+
// -------------------------------------------------------------------
366+
SECTION("Selector conditions +1/+eps/0/-eps/-1 - eval")
367+
{
368+
Workspace ws;
369+
Scalar c = ws.make_scalar();
370+
Scalar E = branch(c, 10.0, -5.0);
371+
372+
c.set_value(1.0); REQUIRE(approx(E.eval(), 10.0)); // +1 → true
373+
c.set_value(eps); REQUIRE(approx(E.eval(), 10.0)); // +eps → true
374+
c.set_value(0.0); REQUIRE(approx(E.eval(), -5.0)); // 0 → false
375+
c.set_value(-eps); REQUIRE(approx(E.eval(), -5.0)); // -eps → false
376+
c.set_value(-1.0); REQUIRE(approx(E.eval(), -5.0)); // -1 → false
377+
}
378+
379+
SECTION("Selector conditions +1/+eps/0/-eps/-1 - compiled")
380+
{
381+
Workspace ws;
382+
Scalar c = ws.make_scalar();
383+
Scalar E = branch(c, 10.0, -5.0);
384+
Compiled<double> compiled({E}, "branch_selector", symx::get_codegen_dir(), E.get_checksum());
385+
386+
compiled.set(c, 1.0); REQUIRE(approx(compiled.run()[0], 10.0)); // +1 → true
387+
compiled.set(c, eps); REQUIRE(approx(compiled.run()[0], 10.0)); // +eps → true
388+
compiled.set(c, 0.0); REQUIRE(approx(compiled.run()[0], -5.0)); // 0 → false
389+
compiled.set(c, -eps); REQUIRE(approx(compiled.run()[0], -5.0)); // -eps → false
390+
compiled.set(c, -1.0); REQUIRE(approx(compiled.run()[0], -5.0)); // -1 → false
391+
}
392+
393+
// -------------------------------------------------------------------
394+
// 2. operator> : branch(a > b, T, F)
395+
// Internally stores cond = a - b. True branch fires when a - b > 0.
396+
// -------------------------------------------------------------------
397+
SECTION("operator> boundary - eval")
398+
{
399+
Workspace ws;
400+
Scalar a = ws.make_scalar();
401+
Scalar b = ws.make_scalar();
402+
Scalar E = branch(a > b, 10.0, -5.0);
403+
404+
a.set_value(2.0); b.set_value(1.0); REQUIRE(approx(E.eval(), 10.0)); // a > b → true
405+
a.set_value(1.0 + eps); b.set_value(1.0); REQUIRE(approx(E.eval(), 10.0)); // a = b+eps → true
406+
a.set_value(1.0); b.set_value(1.0); REQUIRE(approx(E.eval(), -5.0)); // a == b → false (equality → false branch)
407+
a.set_value(1.0 - eps); b.set_value(1.0); REQUIRE(approx(E.eval(), -5.0)); // a = b-eps → false
408+
a.set_value(0.0); b.set_value(1.0); REQUIRE(approx(E.eval(), -5.0)); // a < b → false
409+
}
410+
411+
SECTION("operator> boundary - compiled")
412+
{
413+
Workspace ws;
414+
Scalar a = ws.make_scalar();
415+
Scalar b = ws.make_scalar();
416+
Scalar E = branch(a > b, 10.0, -5.0);
417+
Compiled<double> compiled({E}, "branch_gt", symx::get_codegen_dir(), E.get_checksum());
418+
419+
compiled.set(a, 2.0); compiled.set(b, 1.0); REQUIRE(approx(compiled.run()[0], 10.0)); // a > b
420+
compiled.set(a, 1.0 + eps); compiled.set(b, 1.0); REQUIRE(approx(compiled.run()[0], 10.0)); // a = b+eps
421+
compiled.set(a, 1.0); compiled.set(b, 1.0); REQUIRE(approx(compiled.run()[0], -5.0)); // a == b → false
422+
compiled.set(a, 1.0 - eps); compiled.set(b, 1.0); REQUIRE(approx(compiled.run()[0], -5.0)); // a = b-eps
423+
compiled.set(a, 0.0); compiled.set(b, 1.0); REQUIRE(approx(compiled.run()[0], -5.0)); // a < b
424+
}
425+
426+
// -------------------------------------------------------------------
427+
// 3. operator< : branch(a < b, T, F)
428+
// Internally stores cond = b - a. True branch fires when b - a > 0.
429+
// -------------------------------------------------------------------
430+
SECTION("operator< boundary - eval")
431+
{
432+
Workspace ws;
433+
Scalar a = ws.make_scalar();
434+
Scalar b = ws.make_scalar();
435+
Scalar E = branch(a < b, 10.0, -5.0);
436+
437+
a.set_value(0.0); b.set_value(1.0); REQUIRE(approx(E.eval(), 10.0)); // a < b → true
438+
a.set_value(1.0 - eps); b.set_value(1.0); REQUIRE(approx(E.eval(), 10.0)); // a = b-eps → true
439+
a.set_value(1.0); b.set_value(1.0); REQUIRE(approx(E.eval(), -5.0)); // a == b → false
440+
a.set_value(1.0 + eps); b.set_value(1.0); REQUIRE(approx(E.eval(), -5.0)); // a = b+eps → false
441+
a.set_value(2.0); b.set_value(1.0); REQUIRE(approx(E.eval(), -5.0)); // a > b → false
442+
}
443+
444+
SECTION("operator< boundary - compiled")
445+
{
446+
Workspace ws;
447+
Scalar a = ws.make_scalar();
448+
Scalar b = ws.make_scalar();
449+
Scalar E = branch(a < b, 10.0, -5.0);
450+
Compiled<double> compiled({E}, "branch_lt", symx::get_codegen_dir(), E.get_checksum());
451+
452+
compiled.set(a, 0.0); compiled.set(b, 1.0); REQUIRE(approx(compiled.run()[0], 10.0)); // a < b
453+
compiled.set(a, 1.0 - eps); compiled.set(b, 1.0); REQUIRE(approx(compiled.run()[0], 10.0)); // a = b-eps
454+
compiled.set(a, 1.0); compiled.set(b, 1.0); REQUIRE(approx(compiled.run()[0], -5.0)); // a == b → false
455+
compiled.set(a, 1.0 + eps); compiled.set(b, 1.0); REQUIRE(approx(compiled.run()[0], -5.0)); // a = b+eps
456+
compiled.set(a, 2.0); compiled.set(b, 1.0); REQUIRE(approx(compiled.run()[0], -5.0)); // a > b
457+
}
458+
459+
// -------------------------------------------------------------------
460+
// 4. Comparison against zero (common activation pattern)
461+
// -------------------------------------------------------------------
462+
SECTION("branch(a > 0) activation pattern - eval and compiled")
463+
{
464+
Workspace ws;
465+
Scalar a = ws.make_scalar();
466+
Scalar E = branch(a > 0.0, 10.0, -5.0);
467+
468+
// --- eval ---
469+
a.set_value(1.0); REQUIRE(approx(E.eval(), 10.0)); // +1 → true
470+
a.set_value(eps); REQUIRE(approx(E.eval(), 10.0)); // +eps → true
471+
a.set_value(0.0); REQUIRE(approx(E.eval(), -5.0)); // 0 → false (activation off)
472+
a.set_value(-eps); REQUIRE(approx(E.eval(), -5.0)); // -eps → false
473+
a.set_value(-1.0); REQUIRE(approx(E.eval(), -5.0)); // -1 → false
474+
475+
// --- compiled ---
476+
Compiled<double> compiled({E}, "branch_gt_zero", symx::get_codegen_dir(), E.get_checksum());
477+
compiled.set(a, 1.0); REQUIRE(approx(compiled.run()[0], 10.0));
478+
compiled.set(a, eps); REQUIRE(approx(compiled.run()[0], 10.0));
479+
compiled.set(a, 0.0); REQUIRE(approx(compiled.run()[0], -5.0));
480+
compiled.set(a, -eps); REQUIRE(approx(compiled.run()[0], -5.0));
481+
compiled.set(a, -1.0); REQUIRE(approx(compiled.run()[0], -5.0));
482+
}
483+
484+
SECTION("branch(a < 0) activation pattern - eval and compiled")
485+
{
486+
Workspace ws;
487+
Scalar a = ws.make_scalar();
488+
Scalar E = branch(a < 0.0, 10.0, -5.0);
489+
490+
// --- eval ---
491+
a.set_value(-1.0); REQUIRE(approx(E.eval(), 10.0)); // -1 → true
492+
a.set_value(-eps); REQUIRE(approx(E.eval(), 10.0)); // -eps → true
493+
a.set_value(0.0); REQUIRE(approx(E.eval(), -5.0)); // 0 → false
494+
a.set_value(eps); REQUIRE(approx(E.eval(), -5.0)); // +eps → false
495+
a.set_value(1.0); REQUIRE(approx(E.eval(), -5.0)); // +1 → false
496+
497+
// --- compiled ---
498+
Compiled<double> compiled({E}, "branch_lt_zero", symx::get_codegen_dir(), E.get_checksum());
499+
compiled.set(a, -1.0); REQUIRE(approx(compiled.run()[0], 10.0));
500+
compiled.set(a, -eps); REQUIRE(approx(compiled.run()[0], 10.0));
501+
compiled.set(a, 0.0); REQUIRE(approx(compiled.run()[0], -5.0));
502+
compiled.set(a, eps); REQUIRE(approx(compiled.run()[0], -5.0));
503+
compiled.set(a, 1.0); REQUIRE(approx(compiled.run()[0], -5.0));
504+
}
505+
506+
// -------------------------------------------------------------------
507+
// 5. Derived operations (min, max, abs, sign) at equality boundary
508+
// -------------------------------------------------------------------
509+
SECTION("min and max at boundary - eval")
510+
{
511+
Workspace ws;
512+
Scalar a = ws.make_scalar();
513+
Scalar b = ws.make_scalar();
514+
Scalar mn = min(a, b);
515+
Scalar mx = max(a, b);
516+
517+
// a > b
518+
a.set_value(3.0); b.set_value(2.0);
519+
REQUIRE(approx(mn.eval(), 2.0));
520+
REQUIRE(approx(mx.eval(), 3.0));
521+
522+
// a < b
523+
a.set_value(1.0); b.set_value(2.0);
524+
REQUIRE(approx(mn.eval(), 1.0));
525+
REQUIRE(approx(mx.eval(), 2.0));
526+
527+
// a == b: both branches collapse to the same numeric value
528+
a.set_value(2.0); b.set_value(2.0);
529+
REQUIRE(approx(mn.eval(), 2.0));
530+
REQUIRE(approx(mx.eval(), 2.0));
531+
532+
// a = b + eps (a strictly greater)
533+
a.set_value(1.0 + eps); b.set_value(1.0);
534+
REQUIRE(approx(mn.eval(), 1.0, 1e-14));
535+
REQUIRE(approx(mx.eval(), 1.0 + eps, 1e-14));
536+
}
537+
538+
SECTION("abs and sign at boundary - eval")
539+
{
540+
Workspace ws;
541+
Scalar a = ws.make_scalar();
542+
Scalar ab = abs(a);
543+
Scalar sg = sign(a);
544+
545+
// Positive: true branch
546+
a.set_value(3.0); REQUIRE(approx(ab.eval(), 3.0)); REQUIRE(approx(sg.eval(), 1.0));
547+
a.set_value(eps); REQUIRE(approx(ab.eval(), eps)); REQUIRE(approx(sg.eval(), 1.0));
548+
549+
// Zero: cond = 0 is not > 0, so false branch
550+
// abs(0) = -(-0) = 0 (both branches give 0 anyway)
551+
// sign(0) = -1.0 (false branch)
552+
a.set_value(0.0); REQUIRE(approx(ab.eval(), 0.0)); REQUIRE(approx(sg.eval(), -1.0));
553+
554+
// Negative: false branch
555+
a.set_value(-eps); REQUIRE(approx(ab.eval(), eps)); REQUIRE(approx(sg.eval(), -1.0));
556+
a.set_value(-3.0); REQUIRE(approx(ab.eval(), 3.0)); REQUIRE(approx(sg.eval(), -1.0));
557+
}
558+
559+
// -------------------------------------------------------------------
560+
// 6. Eval vs compiled consistency across the boundary
561+
// -------------------------------------------------------------------
562+
SECTION("Eval vs compiled consistency at boundary")
563+
{
564+
Workspace ws;
565+
Scalar a = ws.make_scalar();
566+
Scalar b = ws.make_scalar();
567+
568+
// Use a non-trivial expression so both paths exercise real computation
569+
Scalar E = branch(a > b, a * a - b, b.sqrt() + a);
570+
Compiled<double> compiled({E}, "branch_boundary_consistency",
571+
symx::get_codegen_dir(), E.get_checksum());
572+
573+
auto check = [&](double A, double B) {
574+
a.set_value(A); b.set_value(B);
575+
compiled.set(a, A); compiled.set(b, B);
576+
REQUIRE(approx(E.eval(), compiled.run()[0]));
577+
};
578+
579+
check(3.0, 2.0); // a > b
580+
check(1.0 + eps, 1.0); // a = b+eps (just above boundary)
581+
check(1.0, 1.0); // a == b (boundary → false branch)
582+
check(1.0 - eps, 1.0); // a = b-eps (just below boundary)
583+
check(0.0, 2.0); // a < b
584+
}
585+
}
586+
358587
TEST_CASE("Log edge cases", "[scalar]")
359588
{
360589
Workspace ws;

0 commit comments

Comments
 (0)