Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 83 additions & 45 deletions src/package_io/input_parser.star
Original file line number Diff line number Diff line change
Expand Up @@ -483,8 +483,13 @@ def input_parser(plan, input_args):
)

if "zkboost" in result["additional_services"]:
# Inject default mock zkvm if none configured.
if len(result["zkboost_params"]["zkvms"]) == 0:
has_instance_zkvms = False
for instance in result["zkboost_params"]["instances"]:
if len(instance.get("zkvms", [])) > 0:
has_instance_zkvms = True

# Inject default mock zkvm if none configured globally or per instance.
if len(result["zkboost_params"]["zkvms"]) == 0 and not has_instance_zkvms:
result["zkboost_params"]["zkvms"] = [
{
"kind": "mock",
Expand Down Expand Up @@ -529,77 +534,102 @@ def input_parser(plan, input_args):
"reth-zisk",
]
configured_proof_types = []
for idx, zkvm in enumerate(result["zkboost_params"]["zkvms"]):
effective_zkvms = []
effective_ere_zkvms = (
[]
) # All ere zkvms with instance context for GPU validation
for instance_idx, instance in enumerate(result["zkboost_params"]["instances"]):
instance_name = instance.get("name", str(instance_idx))
instance_zkvms = instance.get("zkvms", result["zkboost_params"]["zkvms"])
instance_proof_types = []
for zkvm_idx, zkvm in enumerate(instance_zkvms):
proof_type = zkvm.get("proof_type")
if proof_type in instance_proof_types:
fail(
"zkboost_params.instances[{0}].zkvms[{1}]: duplicate proof_type '{2}' in instance '{3}'".format(
instance_idx,
zkvm_idx,
proof_type,
instance_name,
)
)
instance_proof_types.append(proof_type)
effective_zkvms.append((instance_idx, zkvm_idx, zkvm))
if zkvm.get("kind") == "ere":
effective_ere_zkvms.append(
{
"instance_idx": instance_idx,
"instance_name": instance_name,
"zkvm_idx": zkvm_idx,
"zkvm": zkvm,
}
)

for inst_idx, zkvm_idx, zkvm in effective_zkvms:
kind = zkvm.get("kind")
proof_type = zkvm.get("proof_type")
zkvm_path = "zkboost_params.instances[{0}].zkvms[{1}]".format(
inst_idx, zkvm_idx
)

if kind not in ["mock", "ere", "external", "verifier"]:
fail(
"zkboost_params.zkvms[{0}]: unsupported kind '{1}', please use 'mock', 'ere', 'external', or 'verifier'".format(
idx, kind
"{0}: unsupported kind '{1}', please use 'mock', 'ere', 'external', or 'verifier'".format(
zkvm_path, kind
)
)

if proof_type not in valid_proof_types:
fail(
"zkboost_params.zkvms[{0}]: unsupported proof_type '{1}', please use one of: {2}".format(
idx, proof_type, ", ".join(valid_proof_types)
"{0}: unsupported proof_type '{1}', please use one of: {2}".format(
zkvm_path, proof_type, ", ".join(valid_proof_types)
)
)

if proof_type in configured_proof_types:
fail(
"zkboost_params.zkvms[{0}]: duplicate proof_type '{1}'".format(
idx, proof_type
)
)
configured_proof_types.append(proof_type)
if proof_type not in configured_proof_types:
configured_proof_types.append(proof_type)

proof_timeout = zkvm.get("proof_timeout_secs", 12)
if proof_timeout <= 0:
fail(
"zkboost_params.zkvms[{0}]: proof_timeout_secs must be > 0, got {1}".format(
idx, proof_timeout
"{0}: proof_timeout_secs must be > 0, got {1}".format(
zkvm_path, proof_timeout
)
)

if kind == "external":
if zkvm.get("endpoint", "") == "":
fail(
"zkboost_params.zkvms[{0}]: external zkvm requires 'endpoint'".format(
idx
)
)
fail("{0}: external zkvm requires 'endpoint'".format(zkvm_path))

if kind == "mock":
mock_proving_time = zkvm.get("mock_proving_time")
if mock_proving_time != None:
pt_kind = mock_proving_time.get("kind", "constant")
if pt_kind not in ["constant", "random", "linear"]:
fail(
"zkboost_params.zkvms[{0}]: unsupported mock_proving_time kind '{1}', please use 'constant', 'random' or 'linear'".format(
idx, pt_kind
"{0}: unsupported mock_proving_time kind '{1}', please use 'constant', 'random' or 'linear'".format(
zkvm_path, pt_kind
)
)
if pt_kind == "random":
min_ms = mock_proving_time.get("min_ms", 0)
max_ms = mock_proving_time.get("max_ms", 0)
if min_ms > max_ms:
fail(
"zkboost_params.zkvms[{0}]: mock_proving_time random min_ms ({1}) must be <= max_ms ({2})".format(
idx, min_ms, max_ms
"{0}: mock_proving_time random min_ms ({1}) must be <= max_ms ({2})".format(
zkvm_path, min_ms, max_ms
)
)

mock_proof_size = zkvm.get("mock_proof_size", 128 << 10)
if mock_proof_size < 32:
fail(
"zkboost_params.zkvms[{0}]: mock_proof_size must be >= 32, got {1}".format(
idx, mock_proof_size
"{0}: mock_proof_size must be >= 32, got {1}".format(
zkvm_path, mock_proof_size
)
)

_validate_ere_gpu_config(result["zkboost_params"]["zkvms"])
_validate_ere_gpu_config(effective_ere_zkvms)
_validate_requested_proof_types(result["participants"], configured_proof_types)

if (
Expand Down Expand Up @@ -1123,13 +1153,23 @@ def input_parser(plan, input_args):
)


def _validate_ere_gpu_config(zkvms):
services_using_count = []
gpu_device_usage = {} # device_id -> proof_type
def _validate_ere_gpu_config(ere_zkvms):
"""Validate GPU configuration for all ere zkvms.

for zkvm in zkvms:
if zkvm.get("kind") != "ere":
continue
Args:
ere_zkvms: List of dicts with instance_idx, instance_name, zkvm_idx, zkvm
"""
services_using_count = []
gpu_device_usage = {} # device_id -> instance_name (for error messages)

for entry in ere_zkvms:
instance_idx = entry["instance_idx"]
instance_name = entry["instance_name"]
zkvm_idx = entry["zkvm_idx"]
zkvm = entry["zkvm"]
zkvm_path = "zkboost_params.instances[{0}].zkvms[{1}]".format(
instance_idx, zkvm_idx
)

proof_type = zkvm.get("proof_type")
gpu_cfg = zkvm.get("gpu", {})
Expand All @@ -1141,37 +1181,35 @@ def _validate_ere_gpu_config(zkvms):
# Pre-built ere-server images are CUDA-enabled and require GPU for proving.
if not has_gpu:
fail(
"proof_type '{0}' has kind=ere but no GPU configured. ".format(
proof_type
)
"{0}: kind=ere but no GPU configured. ".format(zkvm_path)
+ "ere-server requires GPU for proving. "
+ "Either add gpu.device_ids or gpu.count, or use 'kind: mock' for testing. "
+ "For verification-only use cases, use 'kind: verifier' instead."
)

# Check: GPU device_id overlap
# Check: GPU device_id overlap across all instances
for device_id in device_ids:
if device_id in gpu_device_usage:
fail(
"GPU device '{0}' is used by multiple ere entries: '{1}' and '{2}'. ".format(
device_id, gpu_device_usage[device_id], proof_type
"{0}: GPU device '{1}' is already used by instance '{2}'. ".format(
zkvm_path, device_id, gpu_device_usage[device_id]
)
+ "Each ere-server requires exclusive GPU access."
)
gpu_device_usage[device_id] = proof_type
gpu_device_usage[device_id] = instance_name

if count > 0 and len(device_ids) == 0:
services_using_count.append(proof_type)
services_using_count.append(instance_name)

# Check: Multiple services using gpu.count without device_ids
if len(services_using_count) > 1:
fail(
"Multiple ere services specify gpu.count without gpu.device_ids: [{0}]. ".format(
"Multiple ere instances specify gpu.count without gpu.device_ids: [{0}]. ".format(
", ".join(services_using_count)
)
+ "Docker assigns GPUs from the same pool when gpu.count is used, so all services "
+ "requesting GPUs this way will receive the same device(s). "
+ "Use gpu.device_ids to explicitly assign distinct GPU(s) to each service instead "
+ "Use gpu.device_ids to explicitly assign distinct GPU(s) to each instance instead "
+ '(e.g. gpu: {{device_ids: ["0"]}} and gpu: {{device_ids: ["1"]}}).'
)

Expand Down Expand Up @@ -1216,7 +1254,7 @@ def _validate_requested_proof_types(participants, configured_proof_types):
"participants[{0}] requests proof_type '{1}' (ID {2}) via --proof-types flag, ".format(
idx, proof_type, proof_type_id
)
+ "but no zkvm is configured for it in zkboost_params.zkvms. "
+ "but no zkvm is configured for it in zkboost_params.instances[*].zkvms. "
+ "Either add a zkvm entry for '{0}' or remove ID {1} from --proof-types. ".format(
proof_type, proof_type_id
)
Expand Down
50 changes: 33 additions & 17 deletions src/zkboost/zkboost_launcher.star
Original file line number Diff line number Diff line change
Expand Up @@ -69,31 +69,40 @@ def launch_zkboost(

# Per-instance zkvms: each instance falls back to the global
# `zkboost_params.zkvms` if no per-instance list is set. Resolve artifacts
# (ere image/elf_url, verifier program_vk_url) for each, then collect every
# `kind: ere` entry across instances so we launch each ere-server once.
# (ere image/elf_url, verifier program_vk_url) for each, then launch an
# ere-server for each `kind: ere` entry. Each instance gets its own
# ere-server(s), enabling separate GPU pools per instance.
# `verifier` entries get no ere-server — zkboost links the in-process
# `ere-verifier-*` crate and only needs the .vk URL downloaded at startup.
instance_zkvms = []
for instance in zkboost_params.instances:
raw = instance.get("zkvms", zkboost_params.zkvms)
instance_zkvms.append(_resolve_zkvm_artifacts(raw, zkboost_params.image))

ere_server_endpoints = {}
# Launch ere-servers per instance. Each instance's ere zkvms get their own
# ere-server, allowing different GPU configurations per instance.
# ere_server_endpoints[instance_index][proof_type] = endpoint
ere_server_endpoints = []
metrics_jobs = []
for resolved in instance_zkvms:
for zkvm in resolved:
for instance_index, instance in enumerate(zkboost_params.instances):
instance_name = instance["name"]
instance_endpoints = {}
for zkvm in instance_zkvms[instance_index]:
if zkvm["kind"] != "ere":
continue

proof_type = zkvm["proof_type"]
if proof_type in ere_server_endpoints:
continue

endpoint = _launch_ere_server(
plan, zkvm, global_node_selectors, tolerations, tempo_otlp_grpc_url
plan,
zkvm,
instance_name,
global_node_selectors,
tolerations,
tempo_otlp_grpc_url,
)
ere_server_endpoints[proof_type] = endpoint
metrics_jobs.append(_get_ere_server_metrics_job(proof_type))
instance_endpoints[proof_type] = endpoint
metrics_jobs.append(_get_ere_server_metrics_job(instance_name, proof_type))
ere_server_endpoints.append(instance_endpoints)

for instance_index, instance in enumerate(zkboost_params.instances):
name = instance["name"]
Expand Down Expand Up @@ -127,7 +136,9 @@ def launch_zkboost(
),
}
if zkvm["kind"] == "ere":
entry["Endpoint"] = ere_server_endpoints[zkvm["proof_type"]]
entry["Endpoint"] = ere_server_endpoints[instance_index][
zkvm["proof_type"]
]
elif zkvm["kind"] == "external":
entry[
"Kind"
Expand Down Expand Up @@ -264,11 +275,15 @@ def get_config(


def _launch_ere_server(
plan, zkvm, global_node_selectors, tolerations, tempo_otlp_grpc_url
plan, zkvm, instance_name, global_node_selectors, tolerations, tempo_otlp_grpc_url
):
"""Launch an ere-server prover service and return its HTTP endpoint."""
"""Launch an ere-server prover service and return its HTTP endpoint.

Each zkboost instance gets its own ere-server(s), named after the instance
and proof_type to allow separate GPU configurations per instance.
"""
proof_type = zkvm["proof_type"]
service_name = "ere-server-{0}".format(proof_type)
service_name = "ere-server-{0}-{1}".format(instance_name, proof_type)
zkvm_kind = _zkvm_kind_from_proof_type(proof_type)

gpu = dict(zkvm.get("gpu", {}))
Expand Down Expand Up @@ -540,15 +555,16 @@ def _parse_cargo_dependency_version(cargo_toml, dependency):
return None


def _get_ere_server_metrics_job(proof_type):
service_name = "ere-server-{0}".format(proof_type)
def _get_ere_server_metrics_job(instance_name, proof_type):
service_name = "ere-server-{0}-{1}".format(instance_name, proof_type)
return {
"Name": service_name,
"Endpoint": "{0}:{1}".format(service_name, ERE_SERVER_PORT),
"MetricsPath": "/metrics",
"Labels": {
"service": service_name,
"client_type": "ere-server",
"instance": instance_name,
"proof_type": proof_type,
},
"ScrapeInterval": "15s",
Expand Down
Loading