Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
336 changes: 262 additions & 74 deletions compiler/compiler.go

Large diffs are not rendered by default.

57 changes: 57 additions & 0 deletions compiler/compiler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,32 @@ res`
require.NotContains(t, scriptIR, mangled+"_ret", "single-scalar return should not use sret struct")
}

func TestDirectScalarParamsStayValuesInCallee(t *testing.T) {
code := `res = Add(x, y)
res = x + y`
script := `res = Add(2, 3)
res`

scriptIR, _ := compileScriptAndCodeIR(t, "direct_param_values", code, script)

require.NotContains(t, scriptIR, "%x = alloca i64", "direct scalar param x should stay in SSA form by default")
require.NotContains(t, scriptIR, "%y = alloca i64", "direct scalar param y should stay in SSA form by default")
require.NotContains(t, scriptIR, "store i64 %0, ptr %x", "callee should not eagerly spill direct scalar param x")
require.NotContains(t, scriptIR, "store i64 %1, ptr %y", "callee should not eagerly spill direct scalar param y")
}

func TestDirectScalarOutputsStayValuesInCallee(t *testing.T) {
code := `sum = Add(x, y)
sum = x + y`
script := `res = Add(2, 3)
res`

scriptIR, _ := compileScriptAndCodeIR(t, "direct_output_values", code, script)

require.NotContains(t, scriptIR, "%sum = alloca i64", "direct scalar output sum should stay in SSA form by default")
require.NotContains(t, scriptIR, "store i64 0, ptr %sum", "callee should not eagerly seed direct scalar output sum via an alloca")
}

func TestPhase1ScalarABIDirectF64(t *testing.T) {
code := `res = AddF(x, y)
res = x + y`
Expand Down Expand Up @@ -102,6 +128,37 @@ res`
require.NotContains(t, scriptIR, mangled+"_ret", "single-scalar range variant should not use sret struct")
}

func TestConditionalDirectCallArgsDoNotPromoteDirectParams(t *testing.T) {
code := `res = Id(x)
res = x

res = CondId(x, y)
res = x
res = y > 0 Id(x)`
script := `res = CondId(2, 1)
res`

scriptIR, _ := compileScriptAndCodeIR(t, "conditional_direct_call_args", code, script)

require.NotContains(t, scriptIR, "%x = alloca i64", "direct call args in conditional lowering should not force param spills")
require.NotContains(t, scriptIR, "%y = alloca i64", "condition-only direct params should not be promoted")
require.NotContains(t, scriptIR, "store i64 %0, ptr %x", "conditional direct call should keep x in value form")
require.NotContains(t, scriptIR, "store i64 %1, ptr %y", "conditional direct call should keep y in value form")
}

func TestRangeBearingDirectOutputsUseSlotsForLoopCarriedState(t *testing.T) {
code := `sum = AccFmt(a, x)
"count-a%n chars"
sum = a + x`
script := `res = 10
res = AccFmt(res, 1:4)
res`

scriptIR, _ := compileScriptAndCodeIR(t, "aliased_direct_output_promotion", code, script)

require.Contains(t, scriptIR, "%sum = alloca i64", "range-bearing direct returns still need an output slot for loop-carried state")
}

func TestInferCallParamTypesUsesScalarizedVariantWhenRangesConsumed(t *testing.T) {
ctx := llvm.NewContext()
defer ctx.Dispose()
Expand Down
40 changes: 27 additions & 13 deletions compiler/cond.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,29 +51,43 @@ func (c *Compiler) compileConditions(stmt *ast.LetStatement) (cond llvm.Value, h
return
}

// collectCallArgIdentifiers walks an expression and records identifiers that
// appear inside call argument subexpressions. These identifiers may be promoted
// to memory by compileArgs, so conditional lowering pre-promotes them before
// branching.
func collectCallArgIdentifiers(expr ast.Expression, out map[string]struct{}) {
// collectPromotableCallArgIdentifiers walks an expression and records bare
// identifier call arguments that lower indirectly. Those identifiers may be
// promoted to memory by lowerCallArgs, so conditional lowering pre-promotes
// only that subset before branching.
func (c *Compiler) collectPromotableCallArgIdentifiers(expr ast.Expression, out map[string]struct{}) {
if ce, ok := expr.(*ast.CallExpression); ok {
for _, arg := range ce.Arguments {
if ident, ok := arg.(*ast.Identifier); ok {
out[ident.Value] = struct{}{}
info := c.ExprCache[key(c.FuncNameMangled, ce)]
if info != nil {
paramTypes := c.inferCallParamTypes(info)
mangled := Mangle(c.MangledPath, ce.Function.Value, paramTypes)
if fnInfo := c.FuncCache[mangled]; fnInfo != nil {
abi := classifyFuncABI(paramTypes, fnInfo.OutTypes)
for i, arg := range ce.Arguments {
if abi.Params[i].Mode != ABIParamIndirect {
continue
}
ident, ok := arg.(*ast.Identifier)
if !ok {
continue
}
out[ident.Value] = struct{}{}
}
}
}
}
for _, child := range ast.ExprChildren(expr) {
collectCallArgIdentifiers(child, out)
c.collectPromotableCallArgIdentifiers(child, out)
}
}

// prePromoteConditionalCallArgs promotes local identifiers that are used as call
// arguments so branch codegen does not introduce path-dependent promotions.
// prePromoteConditionalCallArgs promotes local identifiers that are used as
// indirect call arguments so branch codegen does not introduce path-dependent
// promotions.
func (c *Compiler) prePromoteConditionalCallArgs(exprs []ast.Expression) {
argNames := make(map[string]struct{})
for _, expr := range exprs {
collectCallArgIdentifiers(expr, argNames)
c.collectPromotableCallArgIdentifiers(expr, argNames)
}

for name := range argNames {
Expand All @@ -97,7 +111,7 @@ func (c *Compiler) resolveDestSeed(ident *ast.Identifier, outType Type) *Symbol
if !ok {
return c.makeZeroValue(outType)
}
return c.derefIfPointer(existing, ident.Value+"_cond_seed")
return c.valueSymbol(ident.Value, existing, ident.Value+"_cond_seed")
}

func (c *Compiler) createConditionalTempOutputs(stmt *ast.LetStatement) ([]*ast.Identifier, []Type) {
Expand Down
12 changes: 1 addition & 11 deletions compiler/format.go
Original file line number Diff line number Diff line change
Expand Up @@ -162,17 +162,7 @@ func (c *Compiler) parseMarker(tok token.Token, value string, runes []rune, i in
}

func (c *Compiler) getIdSym(id string) (*Symbol, bool) {
s, ok := Get(c.Scopes, id)
if ok {
return c.derefIfPointer(s, id+"_load"), ok
}

cc := c.CodeCompiler.Compiler
s, ok = Get(cc.Scopes, id)
if ok {
return c.derefIfPointer(s, id+"_load"), ok
}
return s, ok
return c.namedValueSymbol(id, id+"_load")
}

// assumes we have at least one identifier in ids. CustomSpec is printf specifier %...
Expand Down
2 changes: 2 additions & 0 deletions tests/math/acc.exp
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
25
10
25
10
9 changes: 8 additions & 1 deletion tests/math/acc.pt
Original file line number Diff line number Diff line change
@@ -1,2 +1,9 @@
res = Id(x)
res = x

res = AccViaId(a, x)
tmp = Id(a)
res = tmp + x

res = Acc(a, x)
res = a + x
res = a + x
8 changes: 8 additions & 0 deletions tests/math/acc.spt
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,11 @@ res
res = 10
res = Acc(res, 0:-2)
res

res = 10
res = AccViaId(res, 1:6)
res

res = 10
res = AccViaId(res, 0:-2)
res
4 changes: 4 additions & 0 deletions tests/math/acc_fmt.exp
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
count chars
count chars
count chars
8
3 changes: 3 additions & 0 deletions tests/math/acc_fmt.pt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
res = AccFmt(a, x)
"count-a%n chars"
res = a + x
3 changes: 3 additions & 0 deletions tests/math/acc_fmt.spt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
res = 10
res = AccFmt(res, 1:4)
res
Loading