Skip to content

Commit c58d68f

Browse files
committed
refactor(compiler): simplify logical or conditional lowering
Rename the OR fallback lowering path around its onTrue/onFalse contract, share scalar comparison extraction, and gate base conditions once before entering OR fallback selection. Add E2E coverage for heap-string fallback, tuple fallback, and nested value-position OR expressions.
1 parent 95e6311 commit c58d68f

5 files changed

Lines changed: 48 additions & 29 deletions

File tree

compiler/compiler.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1720,6 +1720,8 @@ func (c *Compiler) compileInfixBasic(expr *ast.InfixExpression, info *ExprInfo)
17201720
// Usually pre-extracted via condLHS, but can still occur when range
17211721
// comparisons are scalarized by an outer loop (e.g. call arg vectorization).
17221722
res = append(res, c.compileCondScalar(expr.Operator, left[i], right[i]))
1723+
case CondOr:
1724+
panic("internal: value-position logical OR must be lowered through conditional expression branching")
17231725
default:
17241726
res = append(res, c.compileInfix(expr.Operator, left[i], right[i], info.OutTypes[i]))
17251727
}

compiler/cond.go

Lines changed: 31 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -475,23 +475,9 @@ func (c *Compiler) extractCondExprs(expr ast.Expression, cond llvm.Value, temps
475475
// Comparisons with ranges can be extracted only when all required iterators
476476
// are already bound by an outer loop (no pending ranges).
477477
if infix, ok := expr.(*ast.InfixExpression); ok && info.HasCondScalar() && len(c.pendingLoopRanges(info.Ranges)) == 0 {
478-
// Bottom-up: extract conditions from operands first
479478
cond, temps = c.extractCondExprs(infix.Left, cond, temps)
480479
cond, temps = c.extractCondExprs(infix.Right, cond, temps)
481-
482-
// Compile both operands (may return pre-extracted values)
483-
left := c.compileExpression(infix.Left, nil)
484-
right := c.compileExpression(infix.Right, nil)
485-
486-
var lhsSyms []*Symbol
487-
lhsSyms, cond = c.handleComparisons(infix.Operator, left, right, info, cond)
488-
489-
c.requireCondLHSFrame()[key(c.FuncNameMangled, expr)] = lhsSyms
490-
temps = append(temps, condTemp{infix.Left, left})
491-
// Free right-side temporaries (only used for comparison).
492-
// Left-side values are retained in condLHS for later substitution.
493-
c.freeTemporary(infix.Right, right)
494-
return cond, temps
480+
return c.extractCondExprSelf(infix, info, cond, temps)
495481
}
496482

497483
// Not a conditional expression — recurse into children
@@ -501,6 +487,21 @@ func (c *Compiler) extractCondExprs(expr ast.Expression, cond llvm.Value, temps
501487
return cond, temps
502488
}
503489

490+
func (c *Compiler) extractCondExprSelf(infix *ast.InfixExpression, info *ExprInfo, cond llvm.Value, temps []condTemp) (llvm.Value, []condTemp) {
491+
left := c.compileExpression(infix.Left, nil)
492+
right := c.compileExpression(infix.Right, nil)
493+
494+
var lhsSyms []*Symbol
495+
lhsSyms, cond = c.handleComparisons(infix.Operator, left, right, info, cond)
496+
497+
c.requireCondLHSFrame()[key(c.FuncNameMangled, infix)] = lhsSyms
498+
temps = append(temps, condTemp{infix.Left, left})
499+
// Free right-side temporaries (only used for comparison).
500+
// Left-side values are retained in condLHS for later substitution.
501+
c.freeTemporary(infix.Right, right)
502+
return cond, temps
503+
}
504+
504505
// cleanupCondExprElse frees temporaries retained during cond-expr extraction
505506
// that are not consumed when the condition evaluates to false.
506507
func (c *Compiler) cleanupCondExprElse(temps []condTemp) {
@@ -530,7 +531,7 @@ func (c *Compiler) compileCondExprValue(expr ast.Expression, baseCond llvm.Value
530531

531532
func (c *Compiler) compileCondExprValueInFrame(expr ast.Expression, baseCond llvm.Value, onTrue func()) {
532533
if c.hasLogicalOrCondExprInTree(expr) {
533-
c.compileCondExprAlternative(expr, baseCond, onTrue, func() {})
534+
c.compileCondExprWithFailure(expr, baseCond, onTrue, func() {})
534535
return
535536
}
536537

@@ -591,7 +592,7 @@ func (c *Compiler) compileCondExprChildrenInFrame(children []ast.Expression, bas
591592
child := children[0]
592593
rest := children[1:]
593594
if c.hasCondExprInTree(child) {
594-
c.compileCondExprAlternative(child, baseCond, func() {
595+
c.compileCondExprWithFailure(child, baseCond, func() {
595596
c.compileCondExprChildrenInFrame(rest, llvm.Value{}, onTrue, onFalse)
596597
}, onFalse)
597598
return
@@ -600,9 +601,9 @@ func (c *Compiler) compileCondExprChildrenInFrame(children []ast.Expression, bas
600601
c.compileCondExprChildrenInFrame(rest, baseCond, onTrue, onFalse)
601602
}
602603

603-
func (c *Compiler) compileCondExprAlternative(expr ast.Expression, baseCond llvm.Value, onTrue func(), onFalse func()) {
604+
func (c *Compiler) compileCondExprWithFailure(expr ast.Expression, baseCond llvm.Value, onTrue func(), onFalse func()) {
604605
if logicalOr, ok := c.logicalOrCondExpr(expr); ok {
605-
c.compileLogicalOrCondExprAlternative(logicalOr, baseCond, onTrue, onFalse)
606+
c.compileLogicalOrCondExprWithFailure(logicalOr, baseCond, onTrue, onFalse)
606607
return
607608
}
608609

@@ -616,13 +617,7 @@ func (c *Compiler) compileCondExprAlternative(expr ast.Expression, baseCond llvm
616617
if infix, ok := expr.(*ast.InfixExpression); ok {
617618
info := c.ExprCache[key(c.FuncNameMangled, infix)]
618619
if info != nil && info.HasCondScalar() && len(c.pendingLoopRanges(info.Ranges)) == 0 {
619-
left := c.compileExpression(infix.Left, nil)
620-
right := c.compileExpression(infix.Right, nil)
621-
622-
lhsSyms, cond := c.handleComparisons(infix.Operator, left, right, info, llvm.Value{})
623-
c.requireCondLHSFrame()[key(c.FuncNameMangled, expr)] = lhsSyms
624-
temps := []condTemp{{infix.Left, left}}
625-
c.freeTemporary(infix.Right, right)
620+
cond, temps := c.extractCondExprSelf(infix, info, llvm.Value{}, nil)
626621
c.compileCondExprBranchWithFailure(cond, temps, onTrue, onFalse)
627622
return
628623
}
@@ -645,7 +640,14 @@ func (c *Compiler) withCondLHS(expr ast.Expression, syms []*Symbol, body func())
645640
delete(frame, exprKey)
646641
}
647642

648-
func (c *Compiler) compileLogicalOrCondExprAlternative(expr *ast.InfixExpression, baseCond llvm.Value, onTrue func(), onFalse func()) {
643+
func (c *Compiler) compileLogicalOrCondExprWithFailure(expr *ast.InfixExpression, baseCond llvm.Value, onTrue func(), onFalse func()) {
644+
if !baseCond.IsNil() {
645+
c.compileCondExprBranchWithFailure(baseCond, nil, func() {
646+
c.compileLogicalOrCondExprWithFailure(expr, llvm.Value{}, onTrue, onFalse)
647+
}, onFalse)
648+
return
649+
}
650+
649651
leftTrue := func() {
650652
left := c.compileExpression(expr.Left, nil)
651653
c.withCondLHS(expr, left, onTrue)
@@ -655,8 +657,8 @@ func (c *Compiler) compileLogicalOrCondExprAlternative(expr *ast.InfixExpression
655657
c.withCondLHS(expr, right, onTrue)
656658
}
657659

658-
c.compileCondExprAlternative(expr.Left, baseCond, leftTrue, func() {
659-
c.compileCondExprAlternative(expr.Right, baseCond, rightTrue, onFalse)
660+
c.compileCondExprWithFailure(expr.Left, llvm.Value{}, leftTrue, func() {
661+
c.compileCondExprWithFailure(expr.Right, llvm.Value{}, rightTrue, onFalse)
660662
})
661663
}
662664

compiler/solver.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1509,6 +1509,9 @@ func (ts *TypeSolver) typeLogicalOrExpression(expr *ast.InfixExpression, left, r
15091509
rightInfo := ts.ExprCache[key(ts.FuncNameMangled, expr.Right)]
15101510

15111511
if ts.InValueExpr {
1512+
// Value-position OR is asymmetric: the left operand must be able to
1513+
// fail, while the right operand may be another condition or a plain
1514+
// fallback value.
15121515
if leftInfo == nil || !leftInfo.HasCondExpr() {
15131516
ts.Errors = append(ts.Errors, &token.CompileError{
15141517
Token: expr.Token,

tests/cond/value_cond_expr.exp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,9 @@ LogicalOrBeforeExisting: 44
6666
LogicalOrExistingFalse: 44
6767
LogicalOrNestedAdd: 8
6868
LogicalOrCallArg: 14
69+
LogicalOrHeapFallback: fallback
70+
LogicalOrPairFallback: 7 8
71+
LogicalOrNestedSum: 10
6972
LogicalOrDefault: -1
7073
LogicalOrDefaultArg: -2
7174
LogicalOrArrayDefault: [-1 -1 -1 3 4 5 6 7 -1 -1]

tests/cond/value_cond_expr.spt

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,15 @@ lorNestedAdd = (b > 2 || d < 10) + 1
270270
lorCallArg = Double(b > 2 || d < 10)
271271
"LogicalOrCallArg: -lorCallArg"
272272

273+
lorHeapFallback = (sl ⊕ sm) > "hello x" || "fallback"
274+
"LogicalOrHeapFallback: -lorHeapFallback"
275+
276+
lorPairA, lorPairB = Pair(1, 2) > Pair(5, 5) || Pair(7, 8)
277+
"LogicalOrPairFallback: -lorPairA -lorPairB"
278+
279+
lorNestedSum = (b > 2 || d < 10) + (a > 10 || c > 2)
280+
"LogicalOrNestedSum: -lorNestedSum"
281+
273282
lorDefault = b > 2 || -1
274283
"LogicalOrDefault: -lorDefault"
275284

0 commit comments

Comments
 (0)