Skip to content

Commit 3b8dc9a

Browse files
committed
Enhance proof generation by integrating fixed point encoding in vector-matrix multiplication. Add timing for proof generation and improve handling of input data reshaping.
1 parent 962b680 commit 3b8dc9a

1 file changed

Lines changed: 13 additions & 1 deletion

File tree

src/zklora/zk_proof_generator.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from onnx import numpy_helper
1414
import torch
1515
from zklora.fp_coding import fixed_point_encode, fixed_point_decode
16+
import plonky3_py as pl
1617

1718

1819
class ProofPaths(NamedTuple):
@@ -265,8 +266,19 @@ async def generate_proofs(
265266

266267
# Flatten to 1D with correct shape
267268
x = np.array(input_data["input_data"], dtype=np.float32)[0]
268-
x_2d = x.reshape(-1, len(W)) # shape: (batch*seq_len, W.shape[0])
269+
x_2d = x.reshape(-1, m) # shape: (batch*seq_len, W.shape[0])
269270
print("batch x tokens × hidden:", x_2d.shape)
271+
272+
for i in range(len(x_2d)):
273+
v = x_2d[i].tolist()
274+
v_encoded = fixed_point_encode(v, fractional_bits=24)
275+
start_time = time.time()
276+
pl.vector_matrix_multiplication_prove(m, n, v_encoded, W_encoded)
277+
end_time = time.time()
278+
if verbose:
279+
print(f"Proof gen took {end_time - start_time:.2f} sec")
280+
total_prove_time += end_time - start_time
281+
270282

271283
else:
272284
raise ValueError(f"Invalid ZK backend: {zk_backend}")

0 commit comments

Comments
 (0)