Skip to content

Commit afd477d

Browse files
arthurpaulinogabriel-barrett
authored andcommitted
Use named structures for QueryCounts
Replace the anonymous `Nat × Nat × Nat × Nat` quads / `Nat × Nat` pairs shipped from the executor with `GroupStats { groupIdx, totalWidth, uniqueRows, totalHits }` and `MemoryCount { uniqueRows, totalHits }`, bundled into a `QueryCounts { functionStats, memoryCounts }` struct. The FFI extern keeps the tuple shape (so the Rust side can build it without matching a Lean structure ctor); `Bytecode.Toplevel.execute` converts to the structured form at the public API boundary. Callers (`Kernel.lean`, `Statistics.computeStats`) lose the `.1`/`.2.2.1` indexing and use named field access.
1 parent 3cbeab4 commit afd477d

3 files changed

Lines changed: 54 additions & 42 deletions

File tree

Ix/Aiur/Semantics/BytecodeFfi.lean

Lines changed: 41 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -50,52 +50,69 @@ instance : BEq IOBuffer where
5050
-- via `Std.HashMap.beq_iff_equiv` + `Std.HashMap.Equiv.{refl,symm,trans}`,
5151
-- bypassing the need for `LawfulBEq` on the outer `IOBuffer`.
5252

53-
/-- Per-circuit query counts for one circuit (one per function circuit, then
54-
one per memory size). `uniqueRows` is the trace height; `totalHits` is the sum
55-
of query multiplicities. The difference `totalHits - uniqueRows` is the number
56-
of cache hits. -/
57-
structure QueryCount where
53+
namespace Bytecode.Toplevel
54+
55+
/-- Per-split execution stats for one return-group of one function.
56+
`groupIdx` keys the corresponding entry in `Function.groupNames`. -/
57+
structure GroupStats where
58+
groupIdx : Nat
59+
totalWidth : Nat
5860
uniqueRows : Nat
5961
totalHits : Nat
60-
deriving Inhabited
62+
deriving Inhabited, Repr
6163

62-
namespace Bytecode.Toplevel
64+
/-- Per-memory-size counts. `uniqueRows` is the trace height; `totalHits`
65+
is the sum of multiplicities. `totalHits - uniqueRows` is the cache-hit
66+
count. -/
67+
structure MemoryCount where
68+
uniqueRows : Nat
69+
totalHits : Nat
70+
deriving Inhabited, Repr
6371

64-
/-- Per-function execution stats. One entry per split (return group), keyed
65-
by group index. Each quadruple is
66-
`(groupIdx, totalWidth, uniqueRows, totalHits)`. The display name is looked
67-
up via `Function.groupNames[groupIdx]`. -/
68-
abbrev FunctionStats := Array (Array (Nat × Nat × Nat × Nat))
72+
/-- Per-function execution stats. Outer index is `FunIdx`; inner index is
73+
the `USize` group index used by `Ctrl.return`. -/
74+
abbrev FunctionStats := Array (Array GroupStats)
6975

70-
/-- Per-memory-size `(uniqueRows, totalHits)` pairs. -/
71-
abbrev MemoryCounts := Array (Nat × Nat)
76+
/-- Per-memory-size counts, parallel to `Toplevel.memorySizes`. -/
77+
abbrev MemoryCounts := Array MemoryCount
7278

73-
/-- Query counts shipped back from the Rust executor: per-function split stats
74-
plus per-memory pairs. -/
75-
abbrev QueryCounts := FunctionStats × MemoryCounts
79+
/-- Query counts shipped back from the Rust executor. -/
80+
structure QueryCounts where
81+
functionStats : FunctionStats
82+
memoryCounts : MemoryCounts
83+
deriving Inhabited
7684

85+
/-- Raw FFI tuple shape — kept tuple-flat so the Rust side can build it
86+
without declaring matching Lean structure ctors. `execute` wraps the
87+
result in the structured `QueryCounts` immediately. -/
7788
@[extern "rs_aiur_toplevel_execute"]
7889
private opaque execute' : @& Bytecode.Toplevel →
7990
@& Bytecode.FunIdx → @& Array G → (ioData : @& Array G) →
8091
(ioMap : @& Array (Array G × IOKeyInfo)) →
81-
Except String (Array G × (Array G × Array (Array G × IOKeyInfo)) × QueryCounts)
92+
Except String (Array G × (Array G × Array (Array G × IOKeyInfo))
93+
× (Array (Array (Nat × Nat × Nat × Nat)) × Array (Nat × Nat)))
8294

8395
/-- Executes the bytecode function `funIdx` with the given `args` and `ioBuffer`,
8496
returning the raw output of the function, the updated `IOBuffer`, and a
85-
`QueryCounts` (per-function split stats + per-memory `(uniqueRows, totalHits)`
86-
pairs). Returns `Except.error msg` when execution fails (e.g. `assert_eq!`
87-
mismatch from a typechecker rejecting a constant), so callers can recover
88-
instead of crashing. -/
97+
`QueryCounts`. Returns `Except.error msg` when execution fails (e.g.
98+
`assert_eq!` mismatch from a typechecker rejecting a constant), so callers
99+
can recover instead of crashing. -/
89100
def execute (toplevel : @& Bytecode.Toplevel)
90101
(funIdx : @& Bytecode.FunIdx) (args : @& Array G) (ioBuffer : IOBuffer) :
91102
Except String (Array G × IOBuffer × QueryCounts) :=
92103
let ioData := ioBuffer.data
93104
let ioMap := ioBuffer.map
94105
match execute' toplevel funIdx args ioData ioMap.toArray with
95106
| .error e => .error e
96-
| .ok (output, (ioData, ioMap), queryCounts) =>
107+
| .ok (output, (ioData, ioMap), rawFn, rawMem) =>
97108
let ioMap := ioMap.foldl (fun acc (k, v) => acc.insert k v) ∅
98-
.ok (output, ⟨ioData, ioMap⟩, queryCounts)
109+
let functionStats : FunctionStats := rawFn.map fun perFn =>
110+
perFn.map fun quad =>
111+
{ groupIdx := quad.1, totalWidth := quad.2.1,
112+
uniqueRows := quad.2.2.1, totalHits := quad.2.2.2 }
113+
let memoryCounts : MemoryCounts := rawMem.map fun pair =>
114+
{ uniqueRows := pair.1, totalHits := pair.2 }
115+
.ok (output, ⟨ioData, ioMap⟩, { functionStats, memoryCounts })
99116

100117
end Bytecode.Toplevel
101118

Ix/Aiur/Statistics.lean

Lines changed: 12 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,8 @@ def fftCost (w h : Nat) : Float :=
3939
let hf := h.toFloat
4040
wf * hf * (max hf 2.0).log2
4141

42-
def computeStats (compiled : CompiledToplevel)
43-
(functionStats : Array (Array (Nat × Nat × Nat × Nat)))
44-
(memoryCounts : Array (Nat × Nat)) : ExecutionStats :=
42+
def computeStats (compiled : CompiledToplevel) (qc : Bytecode.Toplevel.QueryCounts) :
43+
ExecutionStats :=
4544
let t := compiled.bytecode
4645
-- Invert nameMap to get FunIdx → String
4746
let reverseMap := compiled.nameMap.fold (init := (∅ : Std.HashMap Bytecode.FunIdx String))
@@ -53,24 +52,20 @@ def computeStats (compiled : CompiledToplevel)
5352
if t.functions[i]!.constrained then
5453
let baseName := reverseMap[i]?.getD s!"<fn {i}>"
5554
let groupNames := t.functions[i]!.groupNames
56-
for quad in functionStats[i]! do
57-
let groupIdx := quad.1
58-
let w := quad.2.1
59-
let h := quad.2.2.1
60-
let totalHits := quad.2.2.2
61-
let hits := totalHits - h
62-
let group := groupNames[groupIdx]?.getD ""
55+
for gs in qc.functionStats[i]! do
56+
let group := groupNames[gs.groupIdx]?.getD ""
6357
let name := if group.isEmpty then baseName else s!"{baseName} [{group}]"
64-
acc := acc.push { name, width := w, height := h, cacheHits := hits, fftCost := fftCost w h : CircuitStats }
58+
let hits := gs.totalHits - gs.uniqueRows
59+
acc := acc.push
60+
{ name, width := gs.totalWidth, height := gs.uniqueRows,
61+
cacheHits := hits, fftCost := fftCost gs.totalWidth gs.uniqueRows : CircuitStats }
6562
acc
6663
let memoryCircuits := t.memorySizes.mapIdx fun i size =>
6764
let w := size + 11
68-
let pair := memoryCounts[i]!
69-
let h := pair.1
70-
let totalHits := pair.2
71-
let hits := totalHits - h
72-
{ name := s!"memory[{size}]",
73-
width := w, height := h, cacheHits := hits, fftCost := fftCost w h : CircuitStats }
65+
let mc := qc.memoryCounts[i]!
66+
let hits := mc.totalHits - mc.uniqueRows
67+
{ name := s!"memory[{size}]", width := w, height := mc.uniqueRows,
68+
cacheHits := hits, fftCost := fftCost w mc.uniqueRows : CircuitStats }
7469
let circuits := (functionCircuits ++ memoryCircuits).qsort (·.fftCost > ·.fftCost)
7570
let totalFftCost := circuits.foldl (· + ·.fftCost) 0.0
7671
let totalUncachedFftCost := circuits.foldl (fun acc cs => acc + fftCost cs.width (cs.height + cs.cacheHits)) 0.0

Kernel.lean

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ where
5959
if ioBuffer != testCase.expectedIOBuffer then
6060
IO.eprintln s!"{name}: IOBuffer mismatch"
6161
return 1
62-
let stats := Aiur.computeStats compiled queryCounts.1 queryCounts.2
62+
let stats := Aiur.computeStats compiled queryCounts
6363
Aiur.printStats stats
6464
pure 0
6565
interpCheck decls name env : IO UInt32 := do

0 commit comments

Comments
 (0)