Skip to content

Commit 6a3d279

Browse files
authored
feat: merge-train/avm (#22823)
BEGIN_COMMIT_OVERRIDE chore(avm)!: remove hack for bb pilcom to see shifted temp columns (#22723) chore(avm): attacker-simulation tests for shift/unshift row-0 guarantees (#22743) END_COMMIT_OVERRIDE
2 parents b3cddb2 + c277fb5 commit 6a3d279

5 files changed

Lines changed: 157 additions & 39 deletions

File tree

barretenberg/cpp/pil/vm2/scalar_mul.pil

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -168,11 +168,6 @@ namespace scalar_mul;
168168
end * (temp_y - point_y) = 0;
169169
end * (temp_inf - point_inf) = 0;
170170

171-
// TODO(#AVM-212) Hack for bb-pilcom to see shifted temp columns
172-
temp_x' - temp_x' = 0;
173-
temp_y' - temp_y' = 0;
174-
temp_inf' - temp_inf' = 0;
175-
176171
#[DOUBLE]
177172
sel_not_end { temp_x, temp_y, temp_inf, temp_x', temp_y', temp_inf', sel_not_end /* = 1 */ }
178173
in

barretenberg/cpp/src/barretenberg/vm2/constraining/verifier.test.cpp

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
#include "barretenberg/vm2/constraining/verifier.hpp"
22
#include "barretenberg/srs/global_crs.hpp"
3+
#include "barretenberg/vm2/common/constants.hpp"
4+
#include "barretenberg/vm2/constraining/check_circuit.hpp"
5+
#include "barretenberg/vm2/constraining/polynomials.hpp"
36
#include "barretenberg/vm2/constraining/prover.hpp"
7+
#include "barretenberg/vm2/generated/columns.hpp"
48
#include "barretenberg/vm2/proving_helper.hpp"
59
#include "barretenberg/vm2/testing/fixtures.hpp"
610

@@ -72,6 +76,130 @@ TEST_F(AvmVerifierTests, NegativeBadPublicInputs)
7276
ASSERT_TRUE(verified) << "native proof verification failed, but should have succeeded";
7377
}
7478

79+
// Attacker simulation: commit honestly to keccak_memory_addr, but in sumcheck use a DIFFERENT
80+
// (independent) polynomial for keccak_memory_addr_shift whose last-row value is non-zero.
81+
// The honest prover's `key_poly.shifted()` shares memory with the unshifted polynomial and its
82+
// end_index is (unshifted.end_index - 1) <= N - 1, so the last-row shifted value is always past
83+
// the shifted polynomial's end and thus virtually zero.
84+
// A malicious prover can replace this shifted view after AvmProver construction to try and
85+
// smuggle a non-zero value at the last row. This test verifies that the PCS (Shplemini) catches
86+
// the mismatch between the malicious shifted evaluation used in sumcheck and the real shift of
87+
// the commitment, causing the verifier to reject.
88+
TEST_F(AvmVerifierTests, ProvingSystemSecurityShiftedLastRowMustBeZero)
89+
{
90+
if (testing::skip_slow_tests()) {
91+
GTEST_SKIP() << "Skipping slow test";
92+
}
93+
94+
auto [trace, public_inputs] = testing::get_minimal_trace_with_pi();
95+
// Capture the number of witness rows before compute_polynomials consumes the trace.
96+
const size_t num_witness_rows = trace.get_num_witness_rows() + 1;
97+
98+
auto polynomials = constraining::compute_polynomials(trace);
99+
auto proving_key = constraining::proving_key_from_polynomials(polynomials);
100+
auto verification_key = std::make_shared<AvmVerifier::VerificationKey>();
101+
102+
AvmProver prover(proving_key, verification_key, proving_key->commitment_key);
103+
104+
// Attacker: overwrite the shifted view with an independent polynomial carrying a non-zero
105+
// value at the last row of the circuit. The shared-memory link to the unshifted polynomial
106+
// is severed, so the unshifted commitment no longer "agrees" with what sumcheck uses.
107+
using Polynomial = AvmFlavor::Polynomial;
108+
auto make_malicious_shift = [] {
109+
Polynomial p(/*size=*/1, /*virtual_size=*/MAX_AVM_TRACE_SIZE, /*start_index=*/MAX_AVM_TRACE_SIZE - 1);
110+
p.at(MAX_AVM_TRACE_SIZE - 1) = FF(FF::modulus - 1);
111+
return p;
112+
};
113+
prover.prover_polynomials.get(ColumnAndShifts::keccak_memory_addr_shift) = make_malicious_shift();
114+
115+
// Sanity: all relations (main + lookup/permutation) still hold with the attacker's
116+
// polynomials. This demonstrates that any subsequent verification failure is NOT due to a
117+
// relation violation but to the proving system's cryptographic shift-consistency check
118+
// catching the forged shifted value.
119+
AvmFlavor::ProverPolynomials check_polys(*proving_key);
120+
check_polys.get(ColumnAndShifts::keccak_memory_addr_shift) = make_malicious_shift();
121+
ASSERT_NO_THROW(constraining::run_check_circuit(check_polys, num_witness_rows, /*skippable_enabled=*/true));
122+
123+
const auto proof = prover.construct_proof();
124+
125+
Verifier verifier;
126+
const bool verified = verifier.verify_proof(proof, public_inputs.to_columns());
127+
128+
ASSERT_FALSE(verified)
129+
<< "verifier accepted a proof where keccak_memory_addr_shift at the last row was forged to be non-zero";
130+
}
131+
132+
// Symmetric attacker simulation for the UNSHIFTED polynomial at its first row (index 0).
133+
// Unlike the shifted-last-row case, this test is DISABLED by default because it cannot run
134+
// against an unmodified barretenberg tree. The attack path uses a polynomial with
135+
// start_index = 0, which triggers invariants enforced by the honest prover's polynomial
136+
// library. To run this test the following safeguards in
137+
// `barretenberg/cpp/src/barretenberg/polynomials/polynomial.cpp` must be relaxed:
138+
//
139+
// 1. `Polynomial::add_scaled` (and `Polynomial::operator+=`): the two asserts
140+
// `BB_ASSERT_LTE(start_index(), other.start_index)` and
141+
// `BB_ASSERT_GTE(end_index(), other.end_index())` fire during the PCS batching step in
142+
// `execute_pcs_rounds` because the accumulator has start_index = 1 while the malicious
143+
// polynomial has start_index = 0. Replace the asserts with a left/right expansion of
144+
// self's backing memory (using `_clone(..., left_expansion)` /
145+
// `_clone(..., right_expansion)`) so that the malicious value at index 0 contributes
146+
// to the batched polynomial consistently.
147+
//
148+
// 2. `Polynomial::shifted`: asserts `start_ >= 1` because the Gemini shift
149+
// (`A_0 = F + G/X`) is only well-defined when the polynomial has zero constant term.
150+
// To keep the attacker-prover running past this point, special-case start_ == 0 by
151+
// cloning the backing memory and dropping the first element. This is the same
152+
// algebraic step that, on the verifier side, makes the proof unverifiable: the
153+
// committed polynomial and the PCS-derived shifted opening cannot both be consistent
154+
// when f[0] != 0.
155+
//
156+
// Once those patches are applied, this test passes (the verifier rejects the forged proof),
157+
// confirming that the proving system structurally enforces `f[0] = 0` for any polynomial
158+
// that is referenced in shifted form — independent of any PIL relation — thanks to the
159+
// non-cyclic multilinear shift in the PCS.
160+
TEST_F(AvmVerifierTests, DISABLED_ProvingSystemSecurityUnshiftedFirstRowMustBeZero)
161+
{
162+
if (testing::skip_slow_tests()) {
163+
GTEST_SKIP() << "Skipping slow test";
164+
}
165+
166+
auto [trace, public_inputs] = testing::get_minimal_trace_with_pi();
167+
const size_t num_witness_rows = trace.get_num_witness_rows() + 1;
168+
169+
auto polynomials = constraining::compute_polynomials(trace);
170+
auto proving_key = constraining::proving_key_from_polynomials(polynomials);
171+
auto verification_key = std::make_shared<AvmVerifier::VerificationKey>();
172+
173+
AvmProver prover(proving_key, verification_key, proving_key->commitment_key);
174+
175+
// Attacker: overwrite the unshifted polynomial with an independent polynomial carrying a
176+
// non-zero value at index 0. The honest shifted view is left untouched in
177+
// prover_polynomials so we can observe what the verifier does with an inconsistent pair.
178+
using Polynomial = AvmFlavor::Polynomial;
179+
auto make_malicious_addr = [] {
180+
Polynomial p(/*size=*/1024, /*virtual_size=*/MAX_AVM_TRACE_SIZE, /*start_index=*/0);
181+
p.at(0) = FF(FF::modulus - 1);
182+
return p;
183+
};
184+
prover.prover_polynomials.get(ColumnAndShifts::keccak_memory_addr) = make_malicious_addr();
185+
186+
// Sanity: all relations (main + lookup/permutation) still hold with the attacker's
187+
// polynomials. This demonstrates that any subsequent verification failure is NOT due to a
188+
// relation violation but to the proving system's cryptographic shift-consistency check
189+
// catching the forged first-row value.
190+
AvmFlavor::ProverPolynomials check_polys(*proving_key);
191+
check_polys.get(ColumnAndShifts::keccak_memory_addr) = make_malicious_addr();
192+
ASSERT_NO_THROW(constraining::run_check_circuit(check_polys, num_witness_rows, /*skippable_enabled=*/true));
193+
194+
const auto proof = prover.construct_proof();
195+
196+
Verifier verifier;
197+
const bool verified = verifier.verify_proof(proof, public_inputs.to_columns());
198+
199+
ASSERT_FALSE(verified)
200+
<< "verifier accepted a proof where keccak_memory_addr at the first row was forged to be non-zero";
201+
}
202+
75203
// Verify that the actual proof size matches COMPUTED_AVM_PROOF_LENGTH_IN_FIELDS
76204
TEST_F(AvmVerifierTests, ProofSizeMatchesComputedConstant)
77205
{

barretenberg/cpp/src/barretenberg/vm2/generated/relations/scalar_mul.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@ template <typename FF_> class scalar_mulImpl {
1414
public:
1515
using FF = FF_;
1616

17-
static constexpr std::array<size_t, 29> SUBRELATION_PARTIAL_LENGTHS = { 3, 3, 3, 3, 3, 3, 3, 2, 3, 3, 3, 3, 3, 3, 3,
18-
3, 3, 3, 3, 2, 2, 2, 4, 4, 4, 3, 4, 4, 4 };
17+
static constexpr std::array<size_t, 26> SUBRELATION_PARTIAL_LENGTHS = { 3, 3, 3, 3, 3, 3, 3, 2, 3, 3, 3, 3, 3,
18+
3, 3, 3, 3, 3, 3, 4, 4, 4, 3, 4, 4, 4 };
1919

2020
template <typename AllEntities> inline static bool skip(const AllEntities& in)
2121
{

barretenberg/cpp/src/barretenberg/vm2/generated/relations/scalar_mul_impl.hpp

Lines changed: 13 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -140,71 +140,53 @@ void scalar_mulImpl<FF_>::accumulate(ContainerOverSubrelations& evals,
140140
}
141141
{
142142
using View = typename std::tuple_element_t<19, ContainerOverSubrelations>::View;
143-
auto tmp = (static_cast<View>(in.get(C::scalar_mul_temp_x_shift)) -
144-
static_cast<View>(in.get(C::scalar_mul_temp_x_shift)));
145-
std::get<19>(evals) += (tmp * scaling_factor);
146-
}
147-
{
148-
using View = typename std::tuple_element_t<20, ContainerOverSubrelations>::View;
149-
auto tmp = (static_cast<View>(in.get(C::scalar_mul_temp_y_shift)) -
150-
static_cast<View>(in.get(C::scalar_mul_temp_y_shift)));
151-
std::get<20>(evals) += (tmp * scaling_factor);
152-
}
153-
{
154-
using View = typename std::tuple_element_t<21, ContainerOverSubrelations>::View;
155-
auto tmp = (static_cast<View>(in.get(C::scalar_mul_temp_inf_shift)) -
156-
static_cast<View>(in.get(C::scalar_mul_temp_inf_shift)));
157-
std::get<21>(evals) += (tmp * scaling_factor);
158-
}
159-
{
160-
using View = typename std::tuple_element_t<22, ContainerOverSubrelations>::View;
161143
auto tmp = static_cast<View>(in.get(C::scalar_mul_end)) *
162144
((static_cast<View>(in.get(C::scalar_mul_point_x)) * static_cast<View>(in.get(C::scalar_mul_bit)) +
163145
CView(ecc_INFINITY_X) * (FF(1) - static_cast<View>(in.get(C::scalar_mul_bit)))) -
164146
static_cast<View>(in.get(C::scalar_mul_res_x)));
165-
std::get<22>(evals) += (tmp * scaling_factor);
147+
std::get<19>(evals) += (tmp * scaling_factor);
166148
}
167149
{
168-
using View = typename std::tuple_element_t<23, ContainerOverSubrelations>::View;
150+
using View = typename std::tuple_element_t<20, ContainerOverSubrelations>::View;
169151
auto tmp = static_cast<View>(in.get(C::scalar_mul_end)) *
170152
((static_cast<View>(in.get(C::scalar_mul_point_y)) * static_cast<View>(in.get(C::scalar_mul_bit)) +
171153
CView(ecc_INFINITY_Y) * (FF(1) - static_cast<View>(in.get(C::scalar_mul_bit)))) -
172154
static_cast<View>(in.get(C::scalar_mul_res_y)));
173-
std::get<23>(evals) += (tmp * scaling_factor);
155+
std::get<20>(evals) += (tmp * scaling_factor);
174156
}
175157
{
176-
using View = typename std::tuple_element_t<24, ContainerOverSubrelations>::View;
158+
using View = typename std::tuple_element_t<21, ContainerOverSubrelations>::View;
177159
auto tmp = static_cast<View>(in.get(C::scalar_mul_end)) *
178160
(((static_cast<View>(in.get(C::scalar_mul_point_inf)) - FF(1)) *
179161
static_cast<View>(in.get(C::scalar_mul_bit)) +
180162
FF(1)) -
181163
static_cast<View>(in.get(C::scalar_mul_res_inf)));
182-
std::get<24>(evals) += (tmp * scaling_factor);
164+
std::get<21>(evals) += (tmp * scaling_factor);
183165
}
184166
{
185-
using View = typename std::tuple_element_t<25, ContainerOverSubrelations>::View;
167+
using View = typename std::tuple_element_t<22, ContainerOverSubrelations>::View;
186168
auto tmp =
187169
(static_cast<View>(in.get(C::scalar_mul_should_add)) -
188170
static_cast<View>(in.get(C::scalar_mul_sel_not_end)) * static_cast<View>(in.get(C::scalar_mul_bit)));
189-
std::get<25>(evals) += (tmp * scaling_factor);
171+
std::get<22>(evals) += (tmp * scaling_factor);
190172
}
191173
{
192-
using View = typename std::tuple_element_t<26, ContainerOverSubrelations>::View;
174+
using View = typename std::tuple_element_t<23, ContainerOverSubrelations>::View;
193175
auto tmp = CView(scalar_mul_SHOULD_PASS) * (static_cast<View>(in.get(C::scalar_mul_res_x)) -
194176
static_cast<View>(in.get(C::scalar_mul_res_x_shift)));
195-
std::get<26>(evals) += (tmp * scaling_factor);
177+
std::get<23>(evals) += (tmp * scaling_factor);
196178
}
197179
{
198-
using View = typename std::tuple_element_t<27, ContainerOverSubrelations>::View;
180+
using View = typename std::tuple_element_t<24, ContainerOverSubrelations>::View;
199181
auto tmp = CView(scalar_mul_SHOULD_PASS) * (static_cast<View>(in.get(C::scalar_mul_res_y)) -
200182
static_cast<View>(in.get(C::scalar_mul_res_y_shift)));
201-
std::get<27>(evals) += (tmp * scaling_factor);
183+
std::get<24>(evals) += (tmp * scaling_factor);
202184
}
203185
{
204-
using View = typename std::tuple_element_t<28, ContainerOverSubrelations>::View;
186+
using View = typename std::tuple_element_t<25, ContainerOverSubrelations>::View;
205187
auto tmp = CView(scalar_mul_SHOULD_PASS) * (static_cast<View>(in.get(C::scalar_mul_res_inf)) -
206188
static_cast<View>(in.get(C::scalar_mul_res_inf_shift)));
207-
std::get<28>(evals) += (tmp * scaling_factor);
189+
std::get<25>(evals) += (tmp * scaling_factor);
208190
}
209191
}
210192

bb-pilcom/bb-pil-backend/src/vm_builder.rs

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,12 +157,25 @@ fn get_all_col_names<F: FieldElement>(
157157
.map(|name| sanitize_name(name.as_str()))
158158
.collect_vec(),
159159
);
160+
// Collect shifted column references from every expression in every identity:
161+
// both selectors and tuple expressions on each side. Looking only at
162+
// left.selector misses shifts that appear exclusively in lookup/permutation
163+
// tuples (e.g. `{ ..., temp_x', temp_y', ... } in ...`).
160164
let to_be_shifted = sort_cols(
161165
&get_shifted_polys(
162166
analyzed
163167
.identities_with_inlined_intermediate_polynomials()
164168
.iter()
165-
.map(|i| i.left.selector.clone().unwrap())
169+
.flat_map(|i| {
170+
i.left
171+
.selector
172+
.iter()
173+
.chain(i.left.expressions.iter())
174+
.chain(i.right.selector.iter())
175+
.chain(i.right.expressions.iter())
176+
.cloned()
177+
.collect_vec()
178+
})
166179
.collect_vec(),
167180
)
168181
.iter()

0 commit comments

Comments
 (0)