Skip to content

Commit a121629

Browse files
authored
Merge pull request #18 from BlackVectorOps/perf/ir-canonicalizer-optimization-4018986549606110049
⚡ Optimize instruction string building in Canonicalizer
2 parents e2e7ec9 + befc54f commit a121629

2 files changed

Lines changed: 195 additions & 60 deletions

File tree

pkg/analysis/ir/benchmark_test.go

Lines changed: 114 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,114 @@
1+
package ir_test
2+
3+
import (
4+
"go/token"
5+
"os"
6+
"path/filepath"
7+
"strings"
8+
"testing"
9+
10+
"github.com/BlackVectorOps/semantic_firewall/v3/pkg/analysis/ir"
11+
"golang.org/x/tools/go/packages"
12+
"golang.org/x/tools/go/ssa"
13+
)
14+
15+
// compileForBenchmark writes src into a throwaway module on disk, loads it
// with go/packages, builds SSA through ir.BuildSSAFromPackages, and returns
// the package-level function whose name is funcName (or ends in "."+funcName).
//
// Every failure (temp dir, file writes, load, compile errors, SSA build, or
// the function not being found) aborts the benchmark via b.Fatal/b.Fatalf.
func compileForBenchmark(b *testing.B, src, funcName string) *ssa.Function {
16+
b.Helper()
17+
dir, err := os.MkdirTemp("", "ssa-bench-")
18+
if err != nil {
19+
b.Fatalf("failed to create temp dir: %v", err)
20+
}
21+
// Remove the scratch module when this helper exits.
// NOTE(review): b.TempDir() would provide the same cleanup automatically.
defer os.RemoveAll(dir)
22+
23+
// A minimal go.mod so the source below is loaded as module "testmod".
modPath := filepath.Join(dir, "go.mod")
24+
if err := os.WriteFile(modPath, []byte("module testmod\n\ngo 1.23\n"), 0644); err != nil {
25+
b.Fatalf("failed to create go.mod: %v", err)
26+
}
27+
28+
path := filepath.Join(dir, "main.go")
29+
if err := os.WriteFile(path, []byte(src), 0644); err != nil {
30+
b.Fatalf("write source: %v", err)
31+
}
32+
33+
// Inherit the current environment, then force module mode, turn the
// module proxy off, and disable cgo for the load.
env := append(os.Environ(), "GO111MODULE=on", "GOPROXY=off", "CGO_ENABLED=0")
34+
35+
cfg := &packages.Config{
36+
Dir: dir,
37+
Mode: packages.LoadAllSyntax,
38+
Fset: token.NewFileSet(),
39+
Env: env,
40+
}
41+
42+
// The "file=" query selects the package that contains main.go.
pkgs, err := packages.Load(cfg, "file="+path)
43+
if err != nil {
44+
b.Fatalf("packages.Load: %v", err)
45+
}
46+
// Load can succeed while individual packages still carry errors.
if packages.PrintErrors(pkgs) > 0 {
47+
b.Fatal("compilation errors in test source")
48+
}
49+
50+
prog, _, err := ir.BuildSSAFromPackages(pkgs)
51+
if err != nil {
52+
b.Fatalf("BuildSSA: %v", err)
53+
}
54+
55+
// Scan the members of each loaded package for the requested function.
for _, pkg := range pkgs {
56+
ssaPkg := prog.Package(pkg.Types)
57+
if ssaPkg == nil {
58+
continue
59+
}
60+
for _, member := range ssaPkg.Members {
61+
if fn, ok := member.(*ssa.Function); ok {
62+
// Accept an exact name match or a name ending in ".funcName".
if fn.Name() == funcName || strings.HasSuffix(fn.Name(), "."+funcName) {
63+
return fn
64+
}
65+
}
66+
}
67+
}
68+
69+
b.Fatalf("function %q not found in SSA program", funcName)
70+
// Not reached at runtime (Fatalf stops the benchmark goroutine); the
// compiler still requires a terminating return statement.
return nil
71+
}
72+
73+
// BenchmarkCanonicalizeFunction measures one acquire / CanonicalizeFunction /
// release cycle per iteration against a single fixture function,
// "everything", using ir.DefaultLiteralPolicy. The fixture is compiled once
// by compileForBenchmark before the timer is reset.
func BenchmarkCanonicalizeFunction(b *testing.B) {
74+
// Source code with various instruction types to exercise different paths
75+
src := `package main
76+
77+
func everything(ch chan int, m map[string]interface{}) interface{} {
78+
// Defer
79+
defer func() { recover() }()
80+
81+
res := 0
82+
83+
// Select
84+
select {
85+
case x := <-ch:
86+
// Map Update & Interface
87+
m["val"] = x
88+
res = x
89+
default:
90+
// MakeSlice & Go
91+
go func() { _ = make([]int, 10, 20) }()
92+
res = 1
93+
}
94+
95+
// Type Assert
96+
if val, ok := m["val"].(int); ok {
97+
return val * 2
98+
}
99+
100+
return res
101+
}`
102+
103+
fn := compileForBenchmark(b, src, "everything")
104+
// Use default policy
105+
policy := ir.DefaultLiteralPolicy
106+
107+
// Keep the fixture compilation above out of the measured time.
b.ResetTimer()
108+
for i := 0; i < b.N; i++ {
109+
// Acquire/Release per iteration to simulate real usage
110+
c := ir.AcquireCanonicalizer(policy)
111+
c.CanonicalizeFunction(fn)
112+
ir.ReleaseCanonicalizer(c)
113+
}
114+
}

pkg/analysis/ir/canonicalizer.go

Lines changed: 81 additions & 60 deletions
Original file line number | Diff line number | Diff line change
@@ -70,6 +70,7 @@ type Canonicalizer struct {
7070
blockMap map[*ssa.BasicBlock]string
7171
regCounter int
7272
output strings.Builder
73+
scratch strings.Builder
7374
virtualInstrs map[ssa.Instruction]*virtualInstr
7475
virtualBlocks map[*ssa.BasicBlock]*virtualBlock
7576
virtualBinOps map[*ssa.BinOp]token.Token
@@ -814,32 +815,42 @@ func isCommutative(instr *ssa.BinOp) bool {
814815
}
815816

816817
func (c *Canonicalizer) processInstruction(instr ssa.Instruction) {
817-
var rhs strings.Builder
818+
c.scratch.Reset()
818819
val, isValue := instr.(ssa.Value)
819820
isControlFlow := false
820821

821822
switch i := instr.(type) {
822823
case *ssa.Call:
823-
rhs.WriteString("Call ")
824-
c.writeCallCommon(&rhs, &i.Call, instr)
824+
c.scratch.WriteString("Call ")
825+
c.writeCallCommon(&c.scratch, &i.Call, instr)
825826
case *ssa.BinOp:
826827
normX := c.NormalizeOperand(i.X, instr)
827828
normY := c.NormalizeOperand(i.Y, instr)
828829
op := c.getVirtualBinOpToken(i)
830+
c.scratch.WriteString("BinOp ")
831+
c.scratch.WriteString(op.String())
832+
c.scratch.WriteString(", ")
829833
if isCommutative(i) && normX > normY {
830-
rhs.WriteString(fmt.Sprintf("BinOp %s, %s, %s", op.String(), normY, normX))
834+
c.scratch.WriteString(normY)
835+
c.scratch.WriteString(", ")
836+
c.scratch.WriteString(normX)
831837
} else {
832-
rhs.WriteString(fmt.Sprintf("BinOp %s, %s, %s", op.String(), normX, normY))
838+
c.scratch.WriteString(normX)
839+
c.scratch.WriteString(", ")
840+
c.scratch.WriteString(normY)
833841
}
834842
case *ssa.UnOp:
835-
rhs.WriteString(fmt.Sprintf("UnOp %s, %s", i.Op.String(), c.NormalizeOperand(i.X, instr)))
843+
c.scratch.WriteString("UnOp ")
844+
c.scratch.WriteString(i.Op.String())
845+
c.scratch.WriteString(", ")
846+
c.scratch.WriteString(c.NormalizeOperand(i.X, instr))
836847
if i.CommaOk {
837-
rhs.WriteString(", CommaOk")
848+
c.scratch.WriteString(", CommaOk")
838849
}
839850
case *ssa.Phi:
840-
c.writePhi(&rhs, i, instr)
851+
c.writePhi(&c.scratch, i, instr)
841852
case *ssa.Alloc:
842-
rhs.WriteString("Alloca ")
853+
c.scratch.WriteString("Alloca ")
843854
handled := false
844855
if ptrType, ok := i.Type().Underlying().(*types.Pointer); ok {
845856
elemType := ptrType.Elem()
@@ -852,130 +863,140 @@ func (c *Canonicalizer) processInstruction(instr ssa.Instruction) {
852863
typeRep = fmt.Sprintf("[<len_literal>]%s", sanitizeType(arrType.Elem()))
853864
}
854865
}
855-
rhs.WriteString(typeRep)
866+
c.scratch.WriteString(typeRep)
856867
handled = true
857868
} else {
858-
rhs.WriteString(sanitizeType(elemType))
869+
c.scratch.WriteString(sanitizeType(elemType))
859870
handled = true
860871
}
861872
}
862873
if !handled {
863-
rhs.WriteString(sanitizeType(i.Type().Underlying()))
874+
c.scratch.WriteString(sanitizeType(i.Type().Underlying()))
864875
}
865876
case *ssa.Store:
866-
rhs.WriteString(fmt.Sprintf("Store %s, %s", c.NormalizeOperand(i.Addr, instr), c.NormalizeOperand(i.Val, instr)))
877+
c.scratch.WriteString("Store ")
878+
c.scratch.WriteString(c.NormalizeOperand(i.Addr, instr))
879+
c.scratch.WriteString(", ")
880+
c.scratch.WriteString(c.NormalizeOperand(i.Val, instr))
867881
case *ssa.If:
868882
isControlFlow = true
869883
succs := c.getVirtualSuccessors(i.Block())
870-
rhs.WriteString(fmt.Sprintf("If %s, %s, %s", c.NormalizeOperand(i.Cond, instr), c.blockMap[succs[0]], c.blockMap[succs[1]]))
884+
c.scratch.WriteString("If ")
885+
c.scratch.WriteString(c.NormalizeOperand(i.Cond, instr))
886+
c.scratch.WriteString(", ")
887+
c.scratch.WriteString(c.blockMap[succs[0]])
888+
c.scratch.WriteString(", ")
889+
c.scratch.WriteString(c.blockMap[succs[1]])
871890
case *ssa.Jump:
872891
isControlFlow = true
873892
if len(i.Block().Succs) > 0 {
874-
rhs.WriteString(fmt.Sprintf("Jump %s", c.blockMap[i.Block().Succs[0]]))
893+
c.scratch.WriteString("Jump ")
894+
c.scratch.WriteString(c.blockMap[i.Block().Succs[0]])
875895
} else {
876-
rhs.WriteString("Jump <invalid>")
896+
c.scratch.WriteString("Jump <invalid>")
877897
}
878898
case *ssa.Return:
879899
isControlFlow = true
880-
rhs.WriteString("Return")
900+
c.scratch.WriteString("Return")
881901
for j, res := range i.Results {
882902
if j > 0 {
883-
rhs.WriteString(",")
903+
c.scratch.WriteString(",")
884904
}
885-
rhs.WriteString(" " + c.NormalizeOperand(res, instr))
905+
c.scratch.WriteString(" ")
906+
c.scratch.WriteString(c.NormalizeOperand(res, instr))
886907
}
887908
case *ssa.IndexAddr:
888-
rhs.WriteString(fmt.Sprintf("IndexAddr %s, %s", c.NormalizeOperand(i.X, instr), c.NormalizeOperand(i.Index, instr)))
909+
c.scratch.WriteString(fmt.Sprintf("IndexAddr %s, %s", c.NormalizeOperand(i.X, instr), c.NormalizeOperand(i.Index, instr)))
889910
case *ssa.Index:
890-
rhs.WriteString(fmt.Sprintf("Index %s, %s", c.NormalizeOperand(i.X, instr), c.NormalizeOperand(i.Index, instr)))
911+
c.scratch.WriteString(fmt.Sprintf("Index %s, %s", c.NormalizeOperand(i.X, instr), c.NormalizeOperand(i.Index, instr)))
891912
case *ssa.Select:
892-
c.writeSelect(&rhs, i, instr)
913+
c.writeSelect(&c.scratch, i, instr)
893914
case *ssa.Range:
894-
rhs.WriteString(fmt.Sprintf("Range %s", c.NormalizeOperand(i.X, instr)))
915+
c.scratch.WriteString(fmt.Sprintf("Range %s", c.NormalizeOperand(i.X, instr)))
895916
case *ssa.Next:
896-
rhs.WriteString(fmt.Sprintf("Next %s", c.NormalizeOperand(i.Iter, instr)))
917+
c.scratch.WriteString(fmt.Sprintf("Next %s", c.NormalizeOperand(i.Iter, instr)))
897918
case *ssa.Extract:
898-
rhs.WriteString(fmt.Sprintf("Extract %s, %d", c.NormalizeOperand(i.Tuple, instr), i.Index))
919+
c.scratch.WriteString(fmt.Sprintf("Extract %s, %d", c.NormalizeOperand(i.Tuple, instr), i.Index))
899920
case *ssa.Slice:
900-
rhs.WriteString(fmt.Sprintf("Slice %s", c.NormalizeOperand(i.X, instr)))
921+
c.scratch.WriteString(fmt.Sprintf("Slice %s", c.NormalizeOperand(i.X, instr)))
901922
if i.Low != nil {
902-
rhs.WriteString(fmt.Sprintf(", Low:%s", c.NormalizeOperand(i.Low, instr)))
923+
c.scratch.WriteString(fmt.Sprintf(", Low:%s", c.NormalizeOperand(i.Low, instr)))
903924
}
904925
if i.High != nil {
905-
rhs.WriteString(fmt.Sprintf(", High:%s", c.NormalizeOperand(i.High, instr)))
926+
c.scratch.WriteString(fmt.Sprintf(", High:%s", c.NormalizeOperand(i.High, instr)))
906927
}
907928
if i.Max != nil {
908-
rhs.WriteString(fmt.Sprintf(", Max:%s", c.NormalizeOperand(i.Max, instr)))
929+
c.scratch.WriteString(fmt.Sprintf(", Max:%s", c.NormalizeOperand(i.Max, instr)))
909930
}
910931
case *ssa.MakeSlice:
911-
rhs.WriteString(fmt.Sprintf("MakeSlice %s, Len:%s, Cap:%s", sanitizeType(i.Type()), c.NormalizeOperand(i.Len, instr), c.NormalizeOperand(i.Cap, instr)))
932+
c.scratch.WriteString(fmt.Sprintf("MakeSlice %s, Len:%s, Cap:%s", sanitizeType(i.Type()), c.NormalizeOperand(i.Len, instr), c.NormalizeOperand(i.Cap, instr)))
912933
case *ssa.MakeMap:
913-
rhs.WriteString(fmt.Sprintf("MakeMap %s", sanitizeType(i.Type())))
934+
c.scratch.WriteString(fmt.Sprintf("MakeMap %s", sanitizeType(i.Type())))
914935
if i.Reserve != nil {
915-
rhs.WriteString(fmt.Sprintf(", Reserve:%s", c.NormalizeOperand(i.Reserve, instr)))
936+
c.scratch.WriteString(fmt.Sprintf(", Reserve:%s", c.NormalizeOperand(i.Reserve, instr)))
916937
}
917938
case *ssa.MapUpdate:
918-
rhs.WriteString(fmt.Sprintf("MapUpdate %s, Key:%s, Val:%s", c.NormalizeOperand(i.Map, instr), c.NormalizeOperand(i.Key, instr), c.NormalizeOperand(i.Value, instr)))
939+
c.scratch.WriteString(fmt.Sprintf("MapUpdate %s, Key:%s, Val:%s", c.NormalizeOperand(i.Map, instr), c.NormalizeOperand(i.Key, instr), c.NormalizeOperand(i.Value, instr)))
919940
case *ssa.Lookup:
920-
rhs.WriteString(fmt.Sprintf("Lookup %s, Key:%s", c.NormalizeOperand(i.X, instr), c.NormalizeOperand(i.Index, instr)))
941+
c.scratch.WriteString(fmt.Sprintf("Lookup %s, Key:%s", c.NormalizeOperand(i.X, instr), c.NormalizeOperand(i.Index, instr)))
921942
if i.CommaOk {
922-
rhs.WriteString(", CommaOk")
943+
c.scratch.WriteString(", CommaOk")
923944
}
924945
case *ssa.TypeAssert:
925-
rhs.WriteString(fmt.Sprintf("TypeAssert %s, AssertedType:%s", c.NormalizeOperand(i.X, instr), sanitizeType(i.AssertedType)))
946+
c.scratch.WriteString(fmt.Sprintf("TypeAssert %s, AssertedType:%s", c.NormalizeOperand(i.X, instr), sanitizeType(i.AssertedType)))
926947
if i.CommaOk {
927-
rhs.WriteString(", CommaOk")
948+
c.scratch.WriteString(", CommaOk")
928949
}
929950
case *ssa.MakeInterface:
930-
rhs.WriteString(fmt.Sprintf("MakeInterface %s, %s", sanitizeType(i.Type()), c.NormalizeOperand(i.X, instr)))
951+
c.scratch.WriteString(fmt.Sprintf("MakeInterface %s, %s", sanitizeType(i.Type()), c.NormalizeOperand(i.X, instr)))
931952
case *ssa.ChangeType:
932-
rhs.WriteString(fmt.Sprintf("ChangeType %s, %s", sanitizeType(i.Type()), c.NormalizeOperand(i.X, instr)))
953+
c.scratch.WriteString(fmt.Sprintf("ChangeType %s, %s", sanitizeType(i.Type()), c.NormalizeOperand(i.X, instr)))
933954
case *ssa.Convert:
934-
rhs.WriteString(fmt.Sprintf("Convert %s, %s", sanitizeType(i.Type()), c.NormalizeOperand(i.X, instr)))
955+
c.scratch.WriteString(fmt.Sprintf("Convert %s, %s", sanitizeType(i.Type()), c.NormalizeOperand(i.X, instr)))
935956
case *ssa.Go:
936-
rhs.WriteString("Go ")
937-
c.writeCallCommon(&rhs, &i.Call, instr)
957+
c.scratch.WriteString("Go ")
958+
c.writeCallCommon(&c.scratch, &i.Call, instr)
938959
case *ssa.Defer:
939-
rhs.WriteString("Defer ")
940-
c.writeCallCommon(&rhs, &i.Call, instr)
960+
c.scratch.WriteString("Defer ")
961+
c.writeCallCommon(&c.scratch, &i.Call, instr)
941962
case *ssa.RunDefers:
942-
rhs.WriteString("RunDefers")
963+
c.scratch.WriteString("RunDefers")
943964
case *ssa.Panic:
944-
rhs.WriteString(fmt.Sprintf("Panic %s", c.NormalizeOperand(i.X, instr)))
965+
c.scratch.WriteString(fmt.Sprintf("Panic %s", c.NormalizeOperand(i.X, instr)))
945966
case *ssa.MakeClosure:
946-
rhs.WriteString(fmt.Sprintf("MakeClosure %s", c.NormalizeOperand(i.Fn, instr)))
967+
c.scratch.WriteString(fmt.Sprintf("MakeClosure %s", c.NormalizeOperand(i.Fn, instr)))
947968
if len(i.Bindings) > 0 {
948-
rhs.WriteString(" [")
969+
c.scratch.WriteString(" [")
949970
for j, binding := range i.Bindings {
950971
if j > 0 {
951-
rhs.WriteString(", ")
972+
c.scratch.WriteString(", ")
952973
}
953-
rhs.WriteString(c.NormalizeOperand(binding, instr))
974+
c.scratch.WriteString(c.NormalizeOperand(binding, instr))
954975
}
955-
rhs.WriteString("]")
976+
c.scratch.WriteString("]")
956977
}
957978
case *ssa.FieldAddr:
958-
rhs.WriteString(fmt.Sprintf("FieldAddr %s, field(%d)", c.NormalizeOperand(i.X, instr), i.Field))
979+
c.scratch.WriteString(fmt.Sprintf("FieldAddr %s, field(%d)", c.NormalizeOperand(i.X, instr), i.Field))
959980
case *ssa.Field:
960-
rhs.WriteString(fmt.Sprintf("Field %s, field(%d)", c.NormalizeOperand(i.X, instr), i.Field))
981+
c.scratch.WriteString(fmt.Sprintf("Field %s, field(%d)", c.NormalizeOperand(i.X, instr), i.Field))
961982
case *ssa.Send:
962-
rhs.WriteString(fmt.Sprintf("Send %s, %s", c.NormalizeOperand(i.Chan, instr), c.NormalizeOperand(i.X, instr)))
983+
c.scratch.WriteString(fmt.Sprintf("Send %s, %s", c.NormalizeOperand(i.Chan, instr), c.NormalizeOperand(i.X, instr)))
963984
case *ssa.MakeChan:
964-
rhs.WriteString(fmt.Sprintf("MakeChan %s, Size:%s", sanitizeType(i.Type()), c.NormalizeOperand(i.Size, instr)))
985+
c.scratch.WriteString(fmt.Sprintf("MakeChan %s, Size:%s", sanitizeType(i.Type()), c.NormalizeOperand(i.Size, instr)))
965986
case *ssa.ChangeInterface:
966-
rhs.WriteString(fmt.Sprintf("ChangeInterface %s, %s", sanitizeType(i.Type()), c.NormalizeOperand(i.X, instr)))
987+
c.scratch.WriteString(fmt.Sprintf("ChangeInterface %s, %s", sanitizeType(i.Type()), c.NormalizeOperand(i.X, instr)))
967988
case *ssa.SliceToArrayPointer:
968-
rhs.WriteString(fmt.Sprintf("SliceToArrayPointer %s, %s", sanitizeType(i.Type()), c.NormalizeOperand(i.X, instr)))
989+
c.scratch.WriteString(fmt.Sprintf("SliceToArrayPointer %s, %s", sanitizeType(i.Type()), c.NormalizeOperand(i.X, instr)))
969990
case *ssa.MultiConvert:
970-
rhs.WriteString(fmt.Sprintf("MultiConvert %s, %s", sanitizeType(i.Type()), c.NormalizeOperand(i.X, instr)))
991+
c.scratch.WriteString(fmt.Sprintf("MultiConvert %s, %s", sanitizeType(i.Type()), c.NormalizeOperand(i.X, instr)))
971992
case *ssa.DebugRef:
972993
return
973994

974995
default:
975996
if c.StrictMode {
976997
panic(fmt.Sprintf("STRICT MODE: Unhandled SSA instruction %T", instr))
977998
}
978-
rhs.WriteString(fmt.Sprintf("UnhandledInstr<%T>", instr))
999+
c.scratch.WriteString(fmt.Sprintf("UnhandledInstr<%T>", instr))
9791000
}
9801001

9811002
c.output.WriteString(" ")
@@ -991,7 +1012,7 @@ func (c *Canonicalizer) processInstruction(instr ssa.Instruction) {
9911012
c.output.WriteString(fmt.Sprintf("%s = ", name))
9921013
}
9931014
}
994-
c.output.WriteString(rhs.String() + "\n")
1015+
c.output.WriteString(c.scratch.String() + "\n")
9951016
}
9961017

9971018
func (c *Canonicalizer) writeSelect(w *strings.Builder, i *ssa.Select, context ssa.Instruction) {

0 commit comments

Comments
 (0)