Skip to content

Commit 02bde9f

Browse files
committed
feat: Pipeline-level ORAM normalization + branchless oram_access_impl
Move ORAM scan from PUSH/POP handlers into dispatch_unit pipeline so every sub-instruction does exactly 1 scan at identical cost: - Add Phase D.oram between operand resolve and handler dispatch - Branchless MUX computes addr/value/direction from decoded opcode: PUSH: addr=vm_sp-8, value=encode(plain_a), write POP: addr=vm_sp, read else: addr=0, dummy read - PUSH/POP handlers no longer call Oram::read/write directly; POP reads from VmExecution::oram_read_result staging field - Remove per-DU dummy_scan (replaced by per-sub-insn scan) Fix oram_access_impl branchless read/write: - Previous code branched on is_write inside the 64-line inner loop - New code always accumulates read result AND computes written word, then selects via bitmask — no data-dependent branches on access type Fix benchmark OramPop setup: LOAD_CONST BB uses fallthrough instead of JMP (which was incorrectly fixed-up to skip PUSH BBs). Fix benchmark runner: use dispatch_unit() for setup instead of step() (step() lacks Phase D.oram, so PUSH setup via step() never wrote ORAM).
1 parent 6735a1b commit 02bde9f

7 files changed

Lines changed: 209 additions & 109 deletions

File tree

runtime/bench/program_factory.cpp

Lines changed: 76 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -78,19 +78,37 @@ static TestBB make_nop_bb(uint32_t bb_id, uint8_t epoch_base, uint32_t N) {
7878
return bb;
7979
}
8080

81-
/// Build a setup BB of exactly N insns with the given instructions + NOP padding + JMP.
82-
static TestBB make_setup_bb(uint32_t bb_id, uint8_t epoch_base,
83-
const std::vector<TestInstruction>& setup_insns,
84-
uint32_t target_bb_id, uint32_t N) {
85-
auto bb = make_bb(bb_id, epoch_base);
86-
for (const auto& insn : setup_insns)
87-
bb.instructions.push_back(insn);
88-
// Pad with NOPs, leaving room for JMP at end
89-
while (bb.instructions.size() + 1 < N)
90-
bb.instructions.push_back({VmOpcode::NOP, f_none(), 0, 0, 0});
91-
// JMP to first measured BB (or next setup BB)
92-
bb.instructions.push_back({VmOpcode::JMP, f_none(), 0, 0, target_bb_id});
93-
return bb;
81+
/// Build setup BBs of exactly N insns each.
82+
///
83+
/// For N≥2: one BB with setup_insns + NOP padding + JMP (fits in N insns).
84+
/// For N=1: one BB per setup instruction (no JMP — fallthrough handles it).
85+
///
86+
/// Returns the BBs (may be 1 or many) and updates next_bb_id.
87+
static std::vector<TestBB> make_setup_bbs(uint32_t& next_bb_id,
88+
uint8_t epoch_base,
89+
const std::vector<TestInstruction>& setup_insns,
90+
uint32_t target_bb_id, uint32_t N) {
91+
std::vector<TestBB> result;
92+
93+
if (N >= 2) {
94+
// All setup insns + NOPs + JMP fit in one N-insn BB
95+
auto bb = make_bb(next_bb_id++, epoch_base);
96+
for (const auto& insn : setup_insns)
97+
bb.instructions.push_back(insn);
98+
while (bb.instructions.size() + 1 < N)
99+
bb.instructions.push_back({VmOpcode::NOP, f_none(), 0, 0, 0});
100+
bb.instructions.push_back({VmOpcode::JMP, f_none(), 0, 0, target_bb_id});
101+
result.push_back(std::move(bb));
102+
} else {
103+
// N=1: each setup insn is its own 1-insn BB (fallthrough to next)
104+
for (const auto& insn : setup_insns) {
105+
auto bb = make_bb(next_bb_id++, epoch_base);
106+
bb.instructions.push_back(insn);
107+
result.push_back(std::move(bb));
108+
}
109+
}
110+
111+
return result;
94112
}
95113

96114
// ─── Trivial native for NATIVE_CALL benchmarks ────────────────────────
@@ -169,56 +187,49 @@ DUBenchProgram build_du_program(const OpcodeBenchSpec& spec,
169187
// ── Build setup BBs (untimed) ───────────────────────────────────
170188
uint32_t first_measured_id = 0; // set after setup BBs
171189

190+
// Helper: add setup BBs and update setup_du_count.
191+
auto add_setup = [&](const std::vector<TestInstruction>& insns) {
192+
auto sbs = make_setup_bbs(next_bb_id, 0xA0, insns, 0, N);
193+
prog.setup_du_count += static_cast<uint32_t>(sbs.size());
194+
for (auto& sb : sbs)
195+
bbs.push_back(std::move(sb));
196+
};
197+
172198
switch (spec.setup) {
173-
case Setup::Reg1: {
199+
case Setup::Reg1:
174200
pool.push_back(42);
175-
auto sb = make_setup_bb(next_bb_id++, 0xA0,
176-
{{VmOpcode::LOAD_CONST, f_pool(), spec.reg_a, 0, 0}},
177-
0 /* target set later */, N);
178-
bbs.push_back(std::move(sb));
179-
prog.setup_du_count = 1;
201+
add_setup({{VmOpcode::LOAD_CONST, f_pool(), spec.reg_a, 0, 0}});
180202
break;
181-
}
182-
case Setup::Reg2: {
203+
204+
case Setup::Reg2:
183205
pool.push_back(42);
184206
pool.push_back(3);
185-
auto sb = make_setup_bb(next_bb_id++, 0xA0,
186-
{{VmOpcode::LOAD_CONST, f_pool(), spec.reg_a, 0, 0},
187-
{VmOpcode::LOAD_CONST, f_pool(), spec.reg_b, 0, 1}},
188-
0, N);
189-
bbs.push_back(std::move(sb));
190-
prog.setup_du_count = 1;
207+
add_setup({{VmOpcode::LOAD_CONST, f_pool(), spec.reg_a, 0, 0},
208+
{VmOpcode::LOAD_CONST, f_pool(), spec.reg_b, 0, 1}});
191209
break;
192-
}
193-
case Setup::Memory: {
210+
211+
case Setup::Memory:
194212
pool.push_back(42);
195-
auto sb = make_setup_bb(next_bb_id++, 0xA0,
196-
{{VmOpcode::LOAD_CONST, f_pool(), 0, 0, 0},
197-
{VmOpcode::STORE, f_rm(), 0, 0, 0}},
198-
0, N);
199-
bbs.push_back(std::move(sb));
200-
prog.setup_du_count = 1;
213+
add_setup({{VmOpcode::LOAD_CONST, f_pool(), 0, 0, 0},
214+
{VmOpcode::STORE, f_rm(), 0, 0, 0}});
201215
break;
202-
}
203-
case Setup::OramPush: {
216+
217+
case Setup::OramPush:
204218
pool.push_back(42);
205-
auto sb = make_setup_bb(next_bb_id++, 0xA0,
206-
{{VmOpcode::LOAD_CONST, f_pool(), 0, 0, 0}},
207-
0, N);
208-
bbs.push_back(std::move(sb));
209-
prog.setup_du_count = 1;
219+
add_setup({{VmOpcode::LOAD_CONST, f_pool(), 0, 0, 0}});
210220
break;
211-
}
221+
212222
case Setup::OramPop: {
213-
// Need K PUSHes to fill stack, then K POPs measured.
214-
// Setup: 1 LOAD_CONST BB + K PUSH BBs.
223+
// Setup: 1 LOAD_CONST BB + K PUSH BBs, all fallthrough (no JMP).
224+
// The LOAD_CONST BB must NOT use add_setup() because add_setup's
225+
// JMP gets fixed to first_measured_id, skipping the PUSH BBs.
215226
pool.push_back(42);
216-
auto sb = make_setup_bb(next_bb_id++, 0xA0,
217-
{{VmOpcode::LOAD_CONST, f_pool(), 0, 0, 0}},
218-
0, N);
219-
bbs.push_back(std::move(sb));
220-
prog.setup_du_count = 1;
221-
227+
{
228+
TestInstruction lc{VmOpcode::LOAD_CONST, f_pool(), 0, 0, 0};
229+
bbs.push_back(make_measured_bb(next_bb_id++, 0xA0,
230+
lc, N, false));
231+
prog.setup_du_count++;
232+
}
222233
// K PUSH BBs (each is a DU with 1 PUSH + N-1 NOP)
223234
for (uint32_t i = 0; i < K; ++i) {
224235
TestInstruction push_insn{VmOpcode::PUSH, f_r(), 0, 0, 0};
@@ -228,36 +239,32 @@ DUBenchProgram build_du_program(const OpcodeBenchSpec& spec,
228239
}
229240
break;
230241
}
231-
case Setup::Pool: {
242+
243+
case Setup::Pool:
232244
for (uint32_t i = 0; i < K; ++i)
233245
pool.push_back(i + 100);
234-
// Pool index cycles: aux = i % pool_count
235246
break;
236-
}
237-
case Setup::CtxWrite: {
247+
248+
case Setup::CtxWrite:
238249
pool.push_back(0x800);
239-
auto sb = make_setup_bb(next_bb_id++, 0xA0,
240-
{{VmOpcode::LOAD_CONST, f_pool(), 0, 0, 0}},
241-
0, N);
242-
bbs.push_back(std::move(sb));
243-
prog.setup_du_count = 1;
250+
add_setup({{VmOpcode::LOAD_CONST, f_pool(), 0, 0, 0}});
244251
break;
245-
}
246-
case Setup::NativeCall: {
247-
// One transition entry for all NATIVE_CALL instructions.
248-
// All point to the same trivial native.
252+
253+
case Setup::NativeCall:
249254
break;
250-
}
255+
251256
default:
252257
break;
253258
}
254259

255-
// Fix setup BB JMP targets → first measured BB
260+
// Fix setup BB JMP targets → first measured BB (N≥2 only; N=1 uses fallthrough)
256261
first_measured_id = next_bb_id;
257262
for (auto& sb : bbs) {
258-
auto& last = sb.instructions.back();
259-
if (last.opcode == VmOpcode::JMP)
260-
last.aux = first_measured_id;
263+
if (!sb.instructions.empty()) {
264+
auto& last = sb.instructions.back();
265+
if (last.opcode == VmOpcode::JMP)
266+
last.aux = first_measured_id;
267+
}
261268
}
262269

263270
// ── Build K measured BBs ────────────────────────────────────────

runtime/bench/runner.hpp

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,9 +50,13 @@ std::vector<BenchResult> run_all(const RunConfig& cfg) {
5050
prog.blob.data(), prog.blob.size(), prog.seed, delta);
5151
if (!engine) return 0;
5252

53-
// Untimed: step through setup DUs
54-
for (uint32_t i = 0; i < prog.setup_du_count * N; ++i) {
55-
auto sr = engine->step();
53+
// Untimed: run setup DUs via dispatch_unit (NOT step()).
54+
//
55+
// WHY dispatch_unit: step() does not have the pipeline-level
56+
// ORAM scan (Phase D.oram). PUSH setup via step() would fail
57+
// to write to ORAM, causing POP to read garbage.
58+
for (uint32_t i = 0; i < prog.setup_du_count; ++i) {
59+
auto sr = engine->dispatch_unit();
5660
if (!sr || *sr == VmResult::Halted) return 0;
5761
}
5862

runtime/include/handler_impls.hpp

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -140,36 +140,42 @@ struct HandlerTraits<VmOpcode::STORE, P> {
140140
};
141141

142142
/// PUSH: register -> ORAM stack.
143-
/// Encode plaintext to memory domain, write via ORAM. No register result.
143+
///
144+
/// Doc 19 pipeline-level ORAM: the ORAM write has already been executed
145+
/// by Phase D.oram in dispatch_unit (branchless, before handler dispatch).
146+
/// The handler only updates vm_sp. No direct Oram::write call.
144147
template<typename P>
145148
struct HandlerTraits<VmOpcode::PUSH, P> {
146149
static constexpr auto security_class = SecurityClass::A;
147150
using oram_tag = UsesOramTag;
148151
template<typename Oram>
149-
static HandlerResult exec(VmExecution& e, VmEpoch&, VmOramState& o,
150-
const VmImmutable& im, const DecodedInsn& i) noexcept {
152+
static HandlerResult exec(VmExecution& e, VmEpoch&, VmOramState&,
153+
const VmImmutable&, const DecodedInsn&) noexcept {
151154
if (e.vm_sp < 8) return tl::make_unexpected(DiagnosticCode::StackOverflow);
152155
e.vm_sp -= 8;
153-
MemVal mem(im.mem.encode_lut().apply(i.plain_a));
154-
Oram::write(o, e.vm_sp, mem);
156+
// ORAM write already done by pipeline Phase D.oram at (vm_sp - 8)
157+
// with encode_lut().apply(plain_a).
155158
return {};
156159
}
157160
};
158161

159162
/// POP: ORAM stack -> register.
160-
/// Read MemVal from ORAM, decode from memory domain to plaintext.
161-
/// Pipeline will FPE-encode regs[dst].
163+
///
164+
/// Doc 19 pipeline-level ORAM: the ORAM read has already been executed
165+
/// by Phase D.oram in dispatch_unit. The result is in exec.oram_read_result.
166+
/// The handler decodes and writes to the destination register.
162167
template<typename P>
163168
struct HandlerTraits<VmOpcode::POP, P> {
164169
static constexpr auto security_class = SecurityClass::A;
165170
using oram_tag = UsesOramTag;
166171
template<typename Oram>
167-
static HandlerResult exec(VmExecution& e, VmEpoch&, VmOramState& o,
172+
static HandlerResult exec(VmExecution& e, VmEpoch&, VmOramState&,
168173
const VmImmutable& im, const DecodedInsn& i) noexcept {
169174
if (e.vm_sp >= VM_OBLIVIOUS_SIZE) return tl::make_unexpected(DiagnosticCode::StackUnderflow);
170-
MemVal mem = Oram::read(o, e.vm_sp);
175+
// ORAM read already done by pipeline Phase D.oram at vm_sp.
176+
// Result is in e.oram_read_result (raw MemVal bits).
171177
e.vm_sp += 8;
172-
uint64_t plain = im.mem.decode_lut().apply(mem.bits);
178+
uint64_t plain = im.mem.decode_lut().apply(e.oram_read_result);
173179
e.regs[i.reg_a] = RegVal(plain);
174180
return {};
175181
}

runtime/include/oram_strategy.hpp

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -70,16 +70,25 @@ struct RollingKeyOram {
7070
/// Scans all 64 cache lines and re-encrypts entire workspace.
7171
static void write(VmOramState& state, uint64_t offset, MemVal val) noexcept;
7272

73-
/// Unconditional dummy scan (Doc 19 §C.1, ORAM Invariant).
73+
/// Unified ORAM access — always performs a full 64-line scan.
7474
///
75-
/// WHY: every dispatch_unit must produce the same ORAM access pattern
76-
/// regardless of opcode mix. A DU with PUSH/POP triggers real ORAM
77-
/// scans; a DU with only ALU ops does not. The dummy scan ensures
78-
/// at least 1 full scan per DU, normalizing the memory bus frequency
79-
/// to constant rate (Doc 19 Appendix C.4).
75+
/// WHY unified (Doc 19 pipeline-level normalization):
76+
/// The dispatch_unit pipeline calls this once per sub-instruction,
77+
/// unconditionally. Branchless MUX in the pipeline selects the
78+
/// address, value, and direction based on the decoded opcode:
79+
/// PUSH: addr=vm_sp-8, value=encoded, is_write=true
80+
/// POP: addr=vm_sp, value=0, is_write=false
81+
/// else: addr=0, value=0, is_write=false (dummy)
8082
///
81-
/// Implementation: read-equivalent — full 64-line scan + re-encrypt +
82-
/// nonce bump, identical cost to a real read. Result discarded.
83+
/// Every sub-instruction does exactly 1 scan at the same cost.
84+
/// PUSH/POP handlers no longer call read/write directly.
85+
///
86+
/// @return read result (meaningful for POP; 0 for PUSH/dummy)
87+
[[nodiscard]] static uint64_t access(VmOramState& state, uint64_t addr,
88+
uint64_t write_value,
89+
bool is_write) noexcept;
90+
91+
/// Unconditional dummy scan (legacy — kept for backward compatibility).
8392
static void dummy_scan(VmOramState& state) noexcept;
8493
};
8594

@@ -103,8 +112,12 @@ struct DirectOram {
103112
/// Write 8 bytes to workspace at `offset` (direct indexed access).
104113
static void write(VmOramState& state, uint64_t offset, MemVal val) noexcept;
105114

115+
/// Unified access for DirectOram — direct indexed, no oblivious scan.
116+
[[nodiscard]] static uint64_t access(VmOramState& state, uint64_t addr,
117+
uint64_t write_value,
118+
bool is_write) noexcept;
119+
106120
/// No-op dummy scan — DirectOram does not need timing normalization.
107-
/// DebugPolicy::constant_time == false, so timing leaks are acceptable.
108121
static void dummy_scan(VmOramState&) noexcept {}
109122
};
110123

runtime/include/vm_state.hpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,17 @@ struct alignas(64) VmExecution {
196196

197197
uint64_t trash_regs[VM_REG_COUNT] = {};
198198

199+
// ── ORAM staging (Doc 19 pipeline-level normalization) ──────────────
200+
201+
/// Result from the per-sub-instruction unconditional ORAM scan.
202+
///
203+
/// WHY staging: ORAM scans are moved from PUSH/POP handlers into the
204+
/// dispatch_unit pipeline so every sub-instruction does exactly 1 scan.
205+
/// POP handler reads this field instead of calling Oram::read directly.
206+
/// PUSH handler ignores it (write-only, result is meaningless).
207+
/// NOP/ALU handlers ignore it (dummy scan at offset 0, result discarded).
208+
uint64_t oram_read_result = 0;
209+
199210
// ── Doc 16 forward-secrecy state ────────────────────────────────────
200211

201212
/// Current Speck-FPE key for register encoding/decoding.

0 commit comments

Comments
 (0)