@@ -355,6 +355,235 @@ TEST_CASE("Branches", "[scalar]")
355355 }
356356}
357357
358+ TEST_CASE (" Branch boundary conditions" , " [scalar]" )
359+ {
360+ const double eps = std::numeric_limits<double >::epsilon ();
361+
362+ // -------------------------------------------------------------------
363+ // 1. Direct-scalar selector: branch(c, T, F)
364+ // True branch fires when c > 0 (strictly).
365+ // -------------------------------------------------------------------
366+ SECTION (" Selector conditions +1/+eps/0/-eps/-1 - eval" )
367+ {
368+ Workspace ws;
369+ Scalar c = ws.make_scalar ();
370+ Scalar E = branch (c, 10.0 , -5.0 );
371+
372+ c.set_value (1.0 ); REQUIRE (approx (E.eval (), 10.0 )); // +1 → true
373+ c.set_value (eps); REQUIRE (approx (E.eval (), 10.0 )); // +eps → true
374+ c.set_value (0.0 ); REQUIRE (approx (E.eval (), -5.0 )); // 0 → false
375+ c.set_value (-eps); REQUIRE (approx (E.eval (), -5.0 )); // -eps → false
376+ c.set_value (-1.0 ); REQUIRE (approx (E.eval (), -5.0 )); // -1 → false
377+ }
378+
379+ SECTION (" Selector conditions +1/+eps/0/-eps/-1 - compiled" )
380+ {
381+ Workspace ws;
382+ Scalar c = ws.make_scalar ();
383+ Scalar E = branch (c, 10.0 , -5.0 );
384+ Compiled<double > compiled ({E}, " branch_selector" , symx::get_codegen_dir (), E.get_checksum ());
385+
386+ compiled.set (c, 1.0 ); REQUIRE (approx (compiled.run ()[0 ], 10.0 )); // +1 → true
387+ compiled.set (c, eps); REQUIRE (approx (compiled.run ()[0 ], 10.0 )); // +eps → true
388+ compiled.set (c, 0.0 ); REQUIRE (approx (compiled.run ()[0 ], -5.0 )); // 0 → false
389+ compiled.set (c, -eps); REQUIRE (approx (compiled.run ()[0 ], -5.0 )); // -eps → false
390+ compiled.set (c, -1.0 ); REQUIRE (approx (compiled.run ()[0 ], -5.0 )); // -1 → false
391+ }
392+
393+ // -------------------------------------------------------------------
394+ // 2. operator> : branch(a > b, T, F)
395+ // Internally stores cond = a - b. True branch fires when a - b > 0.
396+ // -------------------------------------------------------------------
397+ SECTION (" operator> boundary - eval" )
398+ {
399+ Workspace ws;
400+ Scalar a = ws.make_scalar ();
401+ Scalar b = ws.make_scalar ();
402+ Scalar E = branch (a > b, 10.0 , -5.0 );
403+
404+ a.set_value (2.0 ); b.set_value (1.0 ); REQUIRE (approx (E.eval (), 10.0 )); // a > b → true
405+ a.set_value (1.0 + eps); b.set_value (1.0 ); REQUIRE (approx (E.eval (), 10.0 )); // a = b+eps → true
406+ a.set_value (1.0 ); b.set_value (1.0 ); REQUIRE (approx (E.eval (), -5.0 )); // a == b → false (equality → false branch)
407+ a.set_value (1.0 - eps); b.set_value (1.0 ); REQUIRE (approx (E.eval (), -5.0 )); // a = b-eps → false
408+ a.set_value (0.0 ); b.set_value (1.0 ); REQUIRE (approx (E.eval (), -5.0 )); // a < b → false
409+ }
410+
411+ SECTION (" operator> boundary - compiled" )
412+ {
413+ Workspace ws;
414+ Scalar a = ws.make_scalar ();
415+ Scalar b = ws.make_scalar ();
416+ Scalar E = branch (a > b, 10.0 , -5.0 );
417+ Compiled<double > compiled ({E}, " branch_gt" , symx::get_codegen_dir (), E.get_checksum ());
418+
419+ compiled.set (a, 2.0 ); compiled.set (b, 1.0 ); REQUIRE (approx (compiled.run ()[0 ], 10.0 )); // a > b
420+ compiled.set (a, 1.0 + eps); compiled.set (b, 1.0 ); REQUIRE (approx (compiled.run ()[0 ], 10.0 )); // a = b+eps
421+ compiled.set (a, 1.0 ); compiled.set (b, 1.0 ); REQUIRE (approx (compiled.run ()[0 ], -5.0 )); // a == b → false
422+ compiled.set (a, 1.0 - eps); compiled.set (b, 1.0 ); REQUIRE (approx (compiled.run ()[0 ], -5.0 )); // a = b-eps
423+ compiled.set (a, 0.0 ); compiled.set (b, 1.0 ); REQUIRE (approx (compiled.run ()[0 ], -5.0 )); // a < b
424+ }
425+
426+ // -------------------------------------------------------------------
427+ // 3. operator< : branch(a < b, T, F)
428+ // Internally stores cond = b - a. True branch fires when b - a > 0.
429+ // -------------------------------------------------------------------
430+ SECTION (" operator< boundary - eval" )
431+ {
432+ Workspace ws;
433+ Scalar a = ws.make_scalar ();
434+ Scalar b = ws.make_scalar ();
435+ Scalar E = branch (a < b, 10.0 , -5.0 );
436+
437+ a.set_value (0.0 ); b.set_value (1.0 ); REQUIRE (approx (E.eval (), 10.0 )); // a < b → true
438+ a.set_value (1.0 - eps); b.set_value (1.0 ); REQUIRE (approx (E.eval (), 10.0 )); // a = b-eps → true
439+ a.set_value (1.0 ); b.set_value (1.0 ); REQUIRE (approx (E.eval (), -5.0 )); // a == b → false
440+ a.set_value (1.0 + eps); b.set_value (1.0 ); REQUIRE (approx (E.eval (), -5.0 )); // a = b+eps → false
441+ a.set_value (2.0 ); b.set_value (1.0 ); REQUIRE (approx (E.eval (), -5.0 )); // a > b → false
442+ }
443+
444+ SECTION (" operator< boundary - compiled" )
445+ {
446+ Workspace ws;
447+ Scalar a = ws.make_scalar ();
448+ Scalar b = ws.make_scalar ();
449+ Scalar E = branch (a < b, 10.0 , -5.0 );
450+ Compiled<double > compiled ({E}, " branch_lt" , symx::get_codegen_dir (), E.get_checksum ());
451+
452+ compiled.set (a, 0.0 ); compiled.set (b, 1.0 ); REQUIRE (approx (compiled.run ()[0 ], 10.0 )); // a < b
453+ compiled.set (a, 1.0 - eps); compiled.set (b, 1.0 ); REQUIRE (approx (compiled.run ()[0 ], 10.0 )); // a = b-eps
454+ compiled.set (a, 1.0 ); compiled.set (b, 1.0 ); REQUIRE (approx (compiled.run ()[0 ], -5.0 )); // a == b → false
455+ compiled.set (a, 1.0 + eps); compiled.set (b, 1.0 ); REQUIRE (approx (compiled.run ()[0 ], -5.0 )); // a = b+eps
456+ compiled.set (a, 2.0 ); compiled.set (b, 1.0 ); REQUIRE (approx (compiled.run ()[0 ], -5.0 )); // a > b
457+ }
458+
459+ // -------------------------------------------------------------------
460+ // 4. Comparison against zero (common activation pattern)
461+ // -------------------------------------------------------------------
462+ SECTION (" branch(a > 0) activation pattern - eval and compiled" )
463+ {
464+ Workspace ws;
465+ Scalar a = ws.make_scalar ();
466+ Scalar E = branch (a > 0.0 , 10.0 , -5.0 );
467+
468+ // --- eval ---
469+ a.set_value (1.0 ); REQUIRE (approx (E.eval (), 10.0 )); // +1 → true
470+ a.set_value (eps); REQUIRE (approx (E.eval (), 10.0 )); // +eps → true
471+ a.set_value (0.0 ); REQUIRE (approx (E.eval (), -5.0 )); // 0 → false (activation off)
472+ a.set_value (-eps); REQUIRE (approx (E.eval (), -5.0 )); // -eps → false
473+ a.set_value (-1.0 ); REQUIRE (approx (E.eval (), -5.0 )); // -1 → false
474+
475+ // --- compiled ---
476+ Compiled<double > compiled ({E}, " branch_gt_zero" , symx::get_codegen_dir (), E.get_checksum ());
477+ compiled.set (a, 1.0 ); REQUIRE (approx (compiled.run ()[0 ], 10.0 ));
478+ compiled.set (a, eps); REQUIRE (approx (compiled.run ()[0 ], 10.0 ));
479+ compiled.set (a, 0.0 ); REQUIRE (approx (compiled.run ()[0 ], -5.0 ));
480+ compiled.set (a, -eps); REQUIRE (approx (compiled.run ()[0 ], -5.0 ));
481+ compiled.set (a, -1.0 ); REQUIRE (approx (compiled.run ()[0 ], -5.0 ));
482+ }
483+
484+ SECTION (" branch(a < 0) activation pattern - eval and compiled" )
485+ {
486+ Workspace ws;
487+ Scalar a = ws.make_scalar ();
488+ Scalar E = branch (a < 0.0 , 10.0 , -5.0 );
489+
490+ // --- eval ---
491+ a.set_value (-1.0 ); REQUIRE (approx (E.eval (), 10.0 )); // -1 → true
492+ a.set_value (-eps); REQUIRE (approx (E.eval (), 10.0 )); // -eps → true
493+ a.set_value (0.0 ); REQUIRE (approx (E.eval (), -5.0 )); // 0 → false
494+ a.set_value (eps); REQUIRE (approx (E.eval (), -5.0 )); // +eps → false
495+ a.set_value (1.0 ); REQUIRE (approx (E.eval (), -5.0 )); // +1 → false
496+
497+ // --- compiled ---
498+ Compiled<double > compiled ({E}, " branch_lt_zero" , symx::get_codegen_dir (), E.get_checksum ());
499+ compiled.set (a, -1.0 ); REQUIRE (approx (compiled.run ()[0 ], 10.0 ));
500+ compiled.set (a, -eps); REQUIRE (approx (compiled.run ()[0 ], 10.0 ));
501+ compiled.set (a, 0.0 ); REQUIRE (approx (compiled.run ()[0 ], -5.0 ));
502+ compiled.set (a, eps); REQUIRE (approx (compiled.run ()[0 ], -5.0 ));
503+ compiled.set (a, 1.0 ); REQUIRE (approx (compiled.run ()[0 ], -5.0 ));
504+ }
505+
506+ // -------------------------------------------------------------------
507+ // 5. Derived operations (min, max, abs, sign) at equality boundary
508+ // -------------------------------------------------------------------
509+ SECTION (" min and max at boundary - eval" )
510+ {
511+ Workspace ws;
512+ Scalar a = ws.make_scalar ();
513+ Scalar b = ws.make_scalar ();
514+ Scalar mn = min (a, b);
515+ Scalar mx = max (a, b);
516+
517+ // a > b
518+ a.set_value (3.0 ); b.set_value (2.0 );
519+ REQUIRE (approx (mn.eval (), 2.0 ));
520+ REQUIRE (approx (mx.eval (), 3.0 ));
521+
522+ // a < b
523+ a.set_value (1.0 ); b.set_value (2.0 );
524+ REQUIRE (approx (mn.eval (), 1.0 ));
525+ REQUIRE (approx (mx.eval (), 2.0 ));
526+
527+ // a == b: both branches collapse to the same numeric value
528+ a.set_value (2.0 ); b.set_value (2.0 );
529+ REQUIRE (approx (mn.eval (), 2.0 ));
530+ REQUIRE (approx (mx.eval (), 2.0 ));
531+
532+ // a = b + eps (a strictly greater)
533+ a.set_value (1.0 + eps); b.set_value (1.0 );
534+ REQUIRE (approx (mn.eval (), 1.0 , 1e-14 ));
535+ REQUIRE (approx (mx.eval (), 1.0 + eps, 1e-14 ));
536+ }
537+
538+ SECTION (" abs and sign at boundary - eval" )
539+ {
540+ Workspace ws;
541+ Scalar a = ws.make_scalar ();
542+ Scalar ab = abs (a);
543+ Scalar sg = sign (a);
544+
545+ // Positive: true branch
546+ a.set_value (3.0 ); REQUIRE (approx (ab.eval (), 3.0 )); REQUIRE (approx (sg.eval (), 1.0 ));
547+ a.set_value (eps); REQUIRE (approx (ab.eval (), eps)); REQUIRE (approx (sg.eval (), 1.0 ));
548+
549+ // Zero: cond = 0 is not > 0, so false branch
550+ // abs(0) = -(-0) = 0 (both branches give 0 anyway)
551+ // sign(0) = -1.0 (false branch)
552+ a.set_value (0.0 ); REQUIRE (approx (ab.eval (), 0.0 )); REQUIRE (approx (sg.eval (), -1.0 ));
553+
554+ // Negative: false branch
555+ a.set_value (-eps); REQUIRE (approx (ab.eval (), eps)); REQUIRE (approx (sg.eval (), -1.0 ));
556+ a.set_value (-3.0 ); REQUIRE (approx (ab.eval (), 3.0 )); REQUIRE (approx (sg.eval (), -1.0 ));
557+ }
558+
559+ // -------------------------------------------------------------------
560+ // 6. Eval vs compiled consistency across the boundary
561+ // -------------------------------------------------------------------
562+ SECTION (" Eval vs compiled consistency at boundary" )
563+ {
564+ Workspace ws;
565+ Scalar a = ws.make_scalar ();
566+ Scalar b = ws.make_scalar ();
567+
568+ // Use a non-trivial expression so both paths exercise real computation
569+ Scalar E = branch (a > b, a * a - b, b.sqrt () + a);
570+ Compiled<double > compiled ({E}, " branch_boundary_consistency" ,
571+ symx::get_codegen_dir (), E.get_checksum ());
572+
573+ auto check = [&](double A, double B) {
574+ a.set_value (A); b.set_value (B);
575+ compiled.set (a, A); compiled.set (b, B);
576+ REQUIRE (approx (E.eval (), compiled.run ()[0 ]));
577+ };
578+
579+ check (3.0 , 2.0 ); // a > b
580+ check (1.0 + eps, 1.0 ); // a = b+eps (just above boundary)
581+ check (1.0 , 1.0 ); // a == b (boundary → false branch)
582+ check (1.0 - eps, 1.0 ); // a = b-eps (just below boundary)
583+ check (0.0 , 2.0 ); // a < b
584+ }
585+ }
586+
358587TEST_CASE (" Log edge cases" , " [scalar]" )
359588{
360589 Workspace ws;
0 commit comments