Commit 567c4fc

Address proof-bound delta review
1 parent 541f492 commit 567c4fc

9 files changed

Lines changed: 306 additions & 32 deletions

readme.md

Lines changed: 29 additions & 4 deletions
@@ -32,6 +32,8 @@ To solve this, we created **ZKLoRA** a zero-knowledge verification protocol that

 This implementation uses a native Halo2 backend for transcript-bound proof artifacts. The v2 proof contract verifies exact quantized LoRA delta correctness for the statement the base user actually sent and received, and binds the proof to a pre-inference adapter manifest. It does not claim an end-to-end proof that the base model computed those activations.

+Verifier trust boundary: `expected_adapters` must be obtained and pinned by the verifier out-of-band before inference starts, for example by recording the exact manifest file or digest. A contributor-generated adapter manifest is only a convenience handoff artifact; if it is first generated after inference or supplied only alongside proofs, it is not trusted verifier input.
+
 For detailed information about this research, please refer to [our paper](https://arxiv.org/abs/2501.13965).

 <h2 align="center">Quick Usage Instructions</h2>
@@ -45,7 +47,7 @@ pip install zklora

 Use `src/scripts/lora_contributor_sample_script.py` to:
 - Host LoRA submodules
-- Write a pre-inference adapter manifest
+- Write a pre-inference adapter manifest for the verifier to pin out-of-band
 - Handle inference requests
 - Generate proof artifacts

@@ -57,18 +59,36 @@ import time
 from zklora import LoRAServer, LoRAServerSocket

 def main():
-    parser = argparse.ArgumentParser()
+    parser = argparse.ArgumentParser(
+        description=(
+            "Run a sample LoRA contributor server and write the adapter manifest "
+            "that the verifier should pin out-of-band before inference."
+        )
+    )
     parser.add_argument("--host", default="127.0.0.1")
     parser.add_argument("--port_a", type=int, default=30000)
     parser.add_argument("--base_model", default="distilgpt2")
     parser.add_argument("--lora_model_id", default="ng0-k1/distilgpt2-finetuned-es")
     parser.add_argument("--out_dir", default="a-out")
-    parser.add_argument("--adapter_manifest", default="adapter-manifest.json")
+    parser.add_argument(
+        "--adapter_manifest",
+        default="adapter-manifest.json",
+        help=(
+            "Convenience manifest handoff path. The verifier must obtain and pin "
+            "this manifest out-of-band before inference; a post-inference manifest "
+            "is not trusted expected_adapters input."
+        ),
+    )
     args = parser.parse_args()

     stop_event = threading.Event()
     server_obj = LoRAServer(args.base_model, args.lora_model_id, args.out_dir)
     server_obj.write_adapter_manifest(args.adapter_manifest)
+    print(f"[A-Server] wrote adapter manifest => {args.adapter_manifest}")
+    print(
+        "[A-Server] verifier must pin this manifest out-of-band before inference; "
+        "post-inference manifests are not trusted expected_adapters."
+    )
     t = LoRAServerSocket(args.host, args.port_a, server_obj, stop_event)
     t.start()

@@ -141,6 +161,8 @@ if __name__=="__main__":

 Use `src/scripts/verify_proofs.py` to validate the proof artifacts:

+`--expected_adapters` must point to the verifier's pinned pre-inference adapter manifest. Do not accept a contributor manifest that was generated after inference, or first delivered with the proof bundle, as trusted verifier input; it is useful only as a handoff artifact to compare against the pinned expectation.
+
 ```python
 #!/usr/bin/env python3
 """
@@ -173,7 +195,10 @@ def main():
         "--expected_adapters",
         type=str,
         required=True,
-        help="Pre-inference adapter manifest JSON agreed by the verifier."
+        help=(
+            "Verifier-pinned pre-inference adapter manifest JSON. This must be "
+            "obtained out-of-band before inference, not first supplied with proofs."
+        )
     )
     parser.add_argument(
         "--verbose",

src/Cargo.toml

Lines changed: 0 additions & 1 deletion
@@ -15,7 +15,6 @@ extension-module = ["python", "pyo3/extension-module"]
 [dependencies]
 halo2_proofs = "0.3.2"
 halo2_gadgets = "0.4"
-halo2_poseidon = "0.1"
 ff = "0.13"
 num-bigint = "0.4"
 num-integer = "0.1"

src/README.md

Lines changed: 5 additions & 1 deletion
@@ -28,10 +28,12 @@ src/
 The zero-knowledge proof system in ZKLoRA is built on transcript-bound LoRA delta statements and native Halo2 proofs. The `zk_proof_generator.py` module orchestrates the proof generation process by:

 1. Capturing the base user's local transcript of activations and returned LoRA deltas
-2. Binding each proof to a pre-inference adapter manifest with a Poseidon adapter commitment
+2. Binding each proof to a verifier-pinned pre-inference adapter manifest with a Poseidon adapter commitment
 3. Generating native `.zklora.*` proof artifacts for contributor-side LoRA invocations
 4. Verifying proof artifacts against both the base user's transcript and expected adapter manifest before accepting a module

+The verifier must obtain and pin `expected_adapters` out-of-band before inference starts. Contributor-generated adapter manifests are convenience handoff artifacts only; if a manifest is generated after inference or first delivered alongside proofs, it is not trusted to define the expected adapter.
+
 ### Multi-Party Inference Protocol

 The MPI system enables interaction between the base model user (B) and LoRA provider (A) through:
@@ -113,4 +115,6 @@ verify_time, num_proofs = batch_verify_proofs(
 )
 ```

+In this example, `adapter-manifest.json` is the verifier's pre-inference pinned copy or digest-matched file, not a manifest first generated after inference.
+
 For detailed implementation information, please refer to the individual module documentation.
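
Review note: the "pinned copy or digest-matched file" wording above can be made concrete with a digest comparison. The following is a minimal sketch, not part of ZKLoRA's API. `sha256_file` and `matches_pinned_digest` are hypothetical helper names, and the sketch assumes the verifier recorded a SHA-256 hex digest of the manifest before inference started.

```python
import hashlib
import hmac
import os
import tempfile


def sha256_file(path: str) -> str:
    """Return the SHA-256 hex digest of the file's raw bytes."""
    digest = hashlib.sha256()
    with open(path, "rb") as handle:
        # Read in chunks so large manifests do not need to fit in memory.
        for chunk in iter(lambda: handle.read(65536), b""):
            digest.update(chunk)
    return digest.hexdigest()


def matches_pinned_digest(path: str, pinned_hex: str) -> bool:
    """Check a delivered manifest against the digest pinned before inference.

    Hypothetical helper: the pinned digest must come from the verifier's own
    pre-inference record, never from the proof bundle itself.
    """
    return hmac.compare_digest(sha256_file(path), pinned_hex.lower())


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        manifest = os.path.join(tmp, "adapter-manifest.json")
        with open(manifest, "wb") as handle:
            handle.write(b'{"adapters": []}')
        pinned = sha256_file(manifest)  # recorded before inference
        print(matches_pinned_digest(manifest, pinned))
```

A byte-level digest is deliberately strict: any re-serialization of the JSON changes the digest, so the verifier should pin the exact file that will later be passed as `--expected_adapters`.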

src/scripts/lora_contributor_sample_script.py

Lines changed: 37 additions & 7 deletions
@@ -6,19 +6,49 @@


 def main():
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--host", default="127.0.0.1")
-    parser.add_argument("--port_a", type=int, default=30000)
-    parser.add_argument("--base_model", default="distilgpt2")
-    parser.add_argument("--lora_model_id", default="ng0-k1/distilgpt2-finetuned-es")
-    parser.add_argument("--out_dir", default="proof_artifacts")
-    parser.add_argument("--adapter_manifest", default="adapter-manifest.json")
+    parser = argparse.ArgumentParser(
+        description=(
+            "Run a sample LoRA contributor server and write the adapter manifest "
+            "that the verifier should pin out-of-band before inference."
+        ),
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+    )
+    parser.add_argument("--host", default="127.0.0.1", help="Contributor bind host.")
+    parser.add_argument("--port_a", type=int, default=30000, help="Contributor port.")
+    parser.add_argument(
+        "--base_model",
+        default="distilgpt2",
+        help="Base model name expected for the LoRA adapter.",
+    )
+    parser.add_argument(
+        "--lora_model_id",
+        default="ng0-k1/distilgpt2-finetuned-es",
+        help="LoRA model ID or local path served by this contributor.",
+    )
+    parser.add_argument(
+        "--out_dir",
+        default="proof_artifacts",
+        help="Directory where native .zklora proof artifacts are written.",
+    )
+    parser.add_argument(
+        "--adapter_manifest",
+        default="adapter-manifest.json",
+        help=(
+            "Convenience manifest handoff path. The verifier must obtain and pin "
+            "this manifest out-of-band before inference; a post-inference manifest "
+            "is not trusted expected_adapters input."
+        ),
+    )
     args = parser.parse_args()

     stop_event = threading.Event()
     server_obj = LoRAServer(args.base_model, args.lora_model_id, args.out_dir)
     server_obj.write_adapter_manifest(args.adapter_manifest)
     print(f"[A-Server] wrote adapter manifest => {args.adapter_manifest}")
+    print(
+        "[A-Server] verifier must pin this manifest out-of-band before inference; "
+        "post-inference manifests are not trusted expected_adapters."
+    )
     t = LoRAServerSocket(
         args.host, args.port_a, server_obj, stop_event, stop_timeout=1.0
     )

src/src/lib.rs

Lines changed: 33 additions & 2 deletions
@@ -26,6 +26,8 @@ use std::convert::TryInto;

 const ADAPTER_COMMITMENT_DOMAIN: u64 = 0x5a4b4c4f5241; // "ZKLORA"
 const ADAPTER_COMMITMENT_VERSION: u64 = 1;
+// Must match proof_contract.SCHEMA_VERSION; it is hashed into adapter commitments.
+const ARTIFACT_SCHEMA_VERSION: u64 = 2;
 const FIELD_SAFE_BITS: usize = 250;
 const POSEIDON_PAIR_ROWS: usize = 96;

@@ -282,7 +284,7 @@ impl Circuit<Fp> for LoraCircuit {
             &config,
             &mut adapter_words,
             &mut offset,
-            BigInt::from(2u64),
+            BigInt::from(ARTIFACT_SCHEMA_VERSION),
             "adapter schema version",
         )?;
         for (label, value) in [
@@ -1371,6 +1373,35 @@ mod tests {
         }
     }

+    fn minimal_circuit() -> LoraCircuit {
+        let input = AdapterCommitmentInput {
+            schema_version: ARTIFACT_SCHEMA_VERSION,
+            in_dim: 1,
+            rank: 1,
+            out_dim: 1,
+            fixed_point: FixedPointConfig {
+                scale_bits: 0,
+                value_bits: 3,
+                intermediate_bits: 4,
+            },
+            scaling_num: 1,
+            scaling_den: 1,
+            a: vec![vec![1]],
+            b: vec![vec![1]],
+        };
+        LoraCircuit {
+            a: input.a.clone(),
+            b: input.b.clone(),
+            x: vec![1],
+            delta: vec![1],
+            fixed_point: input.fixed_point.clone(),
+            scaling_num: input.scaling_num,
+            scaling_den: input.scaling_den,
+            adapter_commitment: adapter_commitment_for_input(&input).unwrap(),
+            statement_digest: "22".repeat(32),
+        }
+    }
+
     #[test]
     fn poseidon_adapter_commitment_is_deterministic() {
         let input = adapter_input();
@@ -1409,7 +1440,7 @@ mod tests {
     #[test]
     #[ignore = "IPA proof generation for the Poseidon/range-check circuit is intentionally slow"]
     fn real_proof_verifies_for_tiny_relation() {
-        let circuit = valid_circuit();
+        let circuit = minimal_circuit();
         let statement = NativeStatement {
             x: circuit.x.clone(),
             delta: circuit.delta.clone(),
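
Review note: the new `ARTIFACT_SCHEMA_VERSION` constant matters because, per its comment, the schema version is hashed into adapter commitments, so the Rust and Python sides must agree on it. The sketch below illustrates that property only. It uses SHA-256 over canonical JSON as a stand-in for the Poseidon commitment, and `toy_commitment` is a hypothetical name, not a function from this repository.

```python
import hashlib
import json

# Mirrors the Rust constant added in this commit.
ARTIFACT_SCHEMA_VERSION = 2


def toy_commitment(schema_version: int, fields: dict) -> str:
    """Commit to the schema version together with the adapter fields.

    Stand-in only: SHA-256 over canonical JSON, not the Poseidon hash
    that the native prover uses.
    """
    payload = json.dumps(
        {"schema_version": schema_version, **fields},
        sort_keys=True,
        separators=(",", ":"),
    )
    return hashlib.sha256(payload.encode("utf-8")).hexdigest()


if __name__ == "__main__":
    adapter = {"in_dim": 1, "rank": 1, "out_dim": 1, "a": [[1]], "b": [[1]]}
    same = toy_commitment(ARTIFACT_SCHEMA_VERSION, adapter)
    drifted = toy_commitment(ARTIFACT_SCHEMA_VERSION + 1, adapter)
    # Identical adapter, different schema version: the commitments differ.
    print(same == drifted)
```

Replacing the literal `2u64` with the named constant keeps this coupling visible at the call site instead of leaving it as a magic number.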

src/zklora/base_model_user_mpi/__init__.py

Lines changed: 70 additions & 5 deletions
@@ -1,4 +1,5 @@
 import json
+import math
 import socket
 import uuid
 from typing import Any
@@ -168,14 +169,33 @@ def forward(self, x: torch.Tensor):
         scaling_den = (
             self.transcript_recorder.scaling_den if self.transcript_recorder else 1
         )
-        if remote_out is None:
-            raise RuntimeError(f"[B] submodule '{self.sub_name}' => no output from A.")
-        out_t = torch.tensor(remote_out, dtype=torch.float32)
+        if q_delta is not None and self.transcript_recorder is not None:
+            out_t = _dequantize_q_delta(
+                q_delta,
+                self.transcript_recorder.fixed_point,
+                tuple(base_out.shape),
+                base_out.device,
+                base_out.dtype if torch.is_floating_point(base_out) else torch.float32,
+            )
+            remote_out_for_record = out_t.detach().cpu().numpy()
+        elif self.transcript_recorder is not None:
+            raise RuntimeError(
+                f"[B] submodule '{self.sub_name}' => proof-bound response missing q_delta."
+            )
+        else:
+            if remote_out is None:
+                raise RuntimeError(
+                    f"[B] submodule '{self.sub_name}' => no output from A."
+                )
+            out_t = torch.tensor(
+                remote_out, dtype=torch.float32, device=base_out.device
+            )
+            remote_out_for_record = remote_out
         if self.transcript_recorder is not None:
             self.transcript_recorder.record(
                 self.sub_name,
                 arr,
-                remote_out,
+                remote_out_for_record,
                 scaling_num=scaling_num,
                 scaling_den=scaling_den,
                 q_delta_values=q_delta,
@@ -275,7 +295,9 @@ def _canonical_rows(values):


 def _canonical_int_rows(values):
-    tensor = torch.as_tensor(_to_list(values), dtype=torch.int64)
+    values = _to_list(values)
+    _assert_exact_int_values(values)
+    tensor = torch.as_tensor(values, dtype=torch.int64)
     if tensor.numel() == 0:
         return []
     if tensor.ndim == 0:
@@ -288,6 +310,49 @@ def _canonical_int_rows(values):
     ]


+def _assert_exact_int_values(values):
+    if isinstance(values, bool):
+        raise ValueError("q_delta values must be integers, not booleans")
+    if isinstance(values, int):
+        return
+    if isinstance(values, (list, tuple)):
+        for value in values:
+            _assert_exact_int_values(value)
+        return
+    raise ValueError(f"q_delta values must be integers, got {type(values).__name__}")
+
+
+def _dequantize_q_delta(
+    q_delta_values,
+    fixed_point: FixedPointConfig,
+    target_shape: tuple[int, ...],
+    device,
+    dtype,
+):
+    q_delta_rows = _canonical_int_rows(q_delta_values)
+    if not q_delta_rows:
+        raise RuntimeError("Received empty q_delta for proof-bound LoRA response.")
+    q_delta = torch.tensor(q_delta_rows, dtype=torch.float64)
+    expected_rows = math.prod(target_shape[:-1]) if target_shape else 1
+    expected_cols = target_shape[-1] if target_shape else 1
+    if list(q_delta.shape) != [expected_rows, expected_cols]:
+        raise RuntimeError(
+            "q_delta shape does not match local module output rows: "
+            f"{list(q_delta.shape)} != {[expected_rows, expected_cols]}"
+        )
+    expected = math.prod(target_shape)
+    if q_delta.numel() != expected:
+        raise RuntimeError(
+            "q_delta shape does not match local module output: "
+            f"{list(q_delta.shape)} cannot reshape to {list(target_shape)}"
+        )
+    return (
+        (q_delta / float(fixed_point.scale))
+        .reshape(target_shape)
+        .to(device=device, dtype=dtype)
+    )
+
+
 class BaseModelClient:
     def __init__(
         self,
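
Review note: the two helpers added in this file reject non-integer `q_delta` values and then divide by the fixed-point scale. A dependency-free sketch of the same idea follows. `assert_exact_ints` and `dequantize_rows` are hypothetical names, and the sketch assumes the scale equals `2 ** scale_bits`; the diff only shows `fixed_point.scale`, so that relationship is an assumption.

```python
def assert_exact_ints(values) -> None:
    """Recursively require plain Python ints (no bools, floats, or strings)."""
    # bool is a subclass of int in Python, so it must be rejected first.
    if isinstance(values, bool):
        raise ValueError("values must be integers, not booleans")
    if isinstance(values, int):
        return
    if isinstance(values, (list, tuple)):
        for value in values:
            assert_exact_ints(value)
        return
    raise ValueError(f"values must be integers, got {type(values).__name__}")


def dequantize_rows(q_rows, scale_bits: int, expected_shape: tuple) -> list:
    """Convert fixed-point integer rows to floats after validating the shape.

    Assumption: scale = 2 ** scale_bits (not confirmed by the diff).
    """
    assert_exact_ints(q_rows)
    rows, cols = expected_shape
    if not isinstance(q_rows, (list, tuple)) or len(q_rows) != rows:
        raise ValueError("row count does not match the expected shape")
    for row in q_rows:
        if not isinstance(row, (list, tuple)) or len(row) != cols:
            raise ValueError("column count does not match the expected shape")
    scale = 1 << scale_bits
    return [[value / scale for value in row] for row in q_rows]


if __name__ == "__main__":
    # With scale_bits=2 the scale is 4, so 4 -> 1.0 and -2 -> -0.5.
    print(dequantize_rows([[4, -2]], scale_bits=2, expected_shape=(1, 2)))
```

Validating before conversion matters because a silent cast (for example `1.9` truncated to `1`) would let the base user apply a delta that differs from the one the proof is bound to.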

src/zklora/proof_contract.py

Lines changed: 14 additions & 5 deletions
@@ -1,6 +1,7 @@
 from __future__ import annotations

 import hashlib
+import importlib
 import json
 import os
 import re
@@ -424,11 +425,19 @@ def _native_witness_json(witness: InvocationWitness) -> str:

 def _native_module():
     try:
-        from zklora import _native_prover  # type: ignore
-
-        return _native_prover
-    except Exception:
-        return None
+        return importlib.import_module("zklora._native_prover")
+    except ModuleNotFoundError as exc:
+        if exc.name == "zklora._native_prover":
+            return None
+        raise
+    except ImportError as exc:
+        raise ProofContractError(
+            f"failed to import native Halo2 prover: {exc}"
+        ) from exc
+    except Exception as exc:
+        raise ProofContractError(
+            f"failed to import native Halo2 prover: {exc}"
+        ) from exc


 def write_invocation_artifacts(
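
Review note: the rewritten `_native_module` treats "the optional module is absent" differently from "the module exists but failed to import". The sketch below isolates that pattern using only the standard library. `load_optional_module` is a hypothetical helper name, and it omits the `ProofContractError` wrapping shown in the diff.

```python
import importlib


def load_optional_module(name: str):
    """Return the module, or None only when that exact module is missing.

    A ModuleNotFoundError raised for some other module (a missing parent
    package or a missing dependency of `name`) is re-raised, so a broken
    install is not mistaken for "optional module absent".
    """
    try:
        return importlib.import_module(name)
    except ModuleNotFoundError as exc:
        # exc.name holds the module that could not be found.
        if exc.name == name:
            return None
        raise


if __name__ == "__main__":
    print(load_optional_module("json") is not None)
    print(load_optional_module("json.no_such_submodule_zz") is None)
```

The previous `except Exception: return None` would have hidden a broken native extension behind the same `None` that signals a missing one, which is the failure mode this change removes.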
