From 4a7a11c245549b41823d097f7ccd208bc4ef2ef0 Mon Sep 17 00:00:00 2001 From: Sami Date: Sun, 24 May 2026 19:47:22 +0530 Subject: [PATCH 01/13] Implement routed experts delta replay (with branched deltas) Squashed from origin/r3-delta (tip 5c94833f, which extends the earlier 3799bda06 with 'Support branched routed expert deltas' for cases where the routed-experts payload diverges across siblings in a group). Adapts delta replay to main's deferred routed-experts chunk concat: first step starts at 0; extended steps use prefix_len - 1; row 0 fills the boundary, remaining rows append as the new suffix. Bumps router wheel pin to local-path. Bumps deps/verifiers gitlink to d39cc5876. Adds four debug configs for router-replay validation. Co-Authored-By: S1ro1 --- ...wen3_30b_a3b_pd_rlm_swe_router_replay.toml | 126 +++++++++++++ ..._pd_rlm_swe_router_replay_start_audit.toml | 126 +++++++++++++ ...qwen3_30b_a3b_pd_wordle_router_replay.toml | 113 +++++++++++ ...b_pd_wordle_router_replay_start_audit.toml | 113 +++++++++++ deps/verifiers | 2 +- pyproject.toml | 8 +- skills/configs/SKILL.md | 4 + src/prime_rl/inference/vllm/routed_experts.py | 8 +- src/prime_rl/inference/vllm/serving_tokens.py | 6 +- src/prime_rl/orchestrator/trajectories.py | 85 +++++++-- tests/unit/inference/test_serving_tokens.py | 2 +- tests/unit/orchestrator/test_trajectories.py | 137 +++++++++++++- uv.lock | 178 +++++++++++++++++- 13 files changed, 875 insertions(+), 33 deletions(-) create mode 100644 configs/debug/qwen3_30b_a3b_pd_rlm_swe_router_replay.toml create mode 100644 configs/debug/qwen3_30b_a3b_pd_rlm_swe_router_replay_start_audit.toml create mode 100644 configs/debug/qwen3_30b_a3b_pd_wordle_router_replay.toml create mode 100644 configs/debug/qwen3_30b_a3b_pd_wordle_router_replay_start_audit.toml diff --git a/configs/debug/qwen3_30b_a3b_pd_rlm_swe_router_replay.toml b/configs/debug/qwen3_30b_a3b_pd_rlm_swe_router_replay.toml new file mode 100644 index 0000000000..6d47e248da --- /dev/null +++ b/configs/debug/qwen3_30b_a3b_pd_rlm_swe_router_replay.toml @@ -0,0 +1,126 @@ +output_dir = "/beegfs/outputs/qwen3-30b-a3b-pd-rlm-swe-router-replay-r3-delta-5step" +clean_output_dir = true +max_steps = 5 +seq_len = 8192 +max_async_level = 1 + +[log] +level = "debug" + +[model] +name = "Qwen/Qwen3-30B-A3B-Instruct-2507" + +[deployment] +type = "multi_node" +num_train_nodes = 1 +num_infer_nodes = 2 +gpus_per_node = 8 + +[slurm] +job_name = "qwen3-pd-rlm-swe-r3-delta-5" +partition = "cluster" +time = "08:00:00" +exclude = "ltc-idc3-hgx8-h200-[53,88]" + +[wandb] +project = "qwen3-router-replay-e2e" +name = "qwen3-30b-a3b-pd-rlm-swe-r3-delta-5step" +group = "router-replay-e2e" +offline = true +shared = false + +[weight_broadcast] +type = "nccl" +timeout = 3600 + +[trainer] +enable_router_replay = true +max_concurrent_runs = 1 +dist_timeout_seconds = 3600 + +[trainer.model] +impl = "custom" +attn = "flash_attention_3" +optim_cpu_offload = true +ep = 8 + +[trainer.model.ac] +mode = "full" +freq = 1 + +[trainer.model.ac_offloading] +max_inflight_activations = 5 + +[trainer.model.compile] + +[trainer.optim] +type = "adamw" +lr = 1e-6 + +[inference] +gpu_memory_utilization = 0.9 +enable_return_routed_experts = true +enable_eplb = false + +[inference.parallel] +tp = 8 + +[inference.model] +max_model_len = 8192 + +[inference.deployment] +type = "disaggregated" +num_prefill_nodes = 1 +num_decode_nodes = 1 + +[orchestrator] +filters = [] +batch_size = 2 +max_inflight_rollouts = 2 +rollouts_per_example = 1 +max_off_policy_steps = 1 + +[orchestrator.train.sampling] +temperature = 1.0 +repetition_penalty = 1.0 +max_completion_tokens = 256 +min_tokens = 0 + +[orchestrator.client] +extra_headers_from_state = { "X-Session-ID" = "example_id" } + +[[orchestrator.train.env]] +id = "rlm_swe" +name = "rlm-swe-low" +num_workers = 1 +max_retries = 0 +max_total_completion_tokens = 1536 + +[orchestrator.train.env.args] +task_type = "swebench" +dataset_name = "PrimeIntellect/SWE-Bench-Verified-Quick" +max_turns = 6 +rlm_max_turns = 6 +timeout_seconds = 1800 +poll_interval = 5 +sandbox_cpu_cores = 2 +sandbox_memory_gb = 4 +sandbox_disk_size_gb = 4 +sandbox_client_max_workers = 16 +rlm_ref = "f466fccb6bc682092c88edf2b344951d7cbbd000" +labels = ["rlm-r3-delta-swe"] + +[orchestrator.student.client] +timeout = 1800 +wait_for_ready_timeout = 1800 + +[orchestrator.renderer] +name = "qwen3" +preserve_all_thinking = true + +[orchestrator.eval] +env = [] + +[orchestrator.buffer] +easy_threshold = 1.0 +hard_threshold = 0.0 diff --git a/configs/debug/qwen3_30b_a3b_pd_rlm_swe_router_replay_start_audit.toml b/configs/debug/qwen3_30b_a3b_pd_rlm_swe_router_replay_start_audit.toml new file mode 100644 index 0000000000..d3866f8721 --- /dev/null +++ b/configs/debug/qwen3_30b_a3b_pd_rlm_swe_router_replay_start_audit.toml @@ -0,0 +1,126 @@ +output_dir = "/beegfs/outputs/qwen3-30b-a3b-pd-rlm-swe-router-replay-start-audit" +clean_output_dir = true +max_steps = 5 +seq_len = 8192 +max_async_level = 1 + +[log] +level = "debug" + +[model] +name = "Qwen/Qwen3-30B-A3B-Instruct-2507" + +[deployment] +type = "multi_node" +num_train_nodes = 1 +num_infer_nodes = 2 +gpus_per_node = 8 + +[slurm] +job_name = "qwen3-pd-rlm-swe-start-audit" +partition = "cluster" +time = "08:00:00" +exclude = "ltc-idc3-hgx8-h200-[53,88]" + +[wandb] +project = "qwen3-router-replay-e2e" +name = "qwen3-30b-a3b-pd-rlm-swe-start-audit" +group = "router-replay-e2e" +offline = true +shared = false + +[weight_broadcast] +type = "nccl" +timeout = 3600 + +[trainer] +enable_router_replay = true +max_concurrent_runs = 1 +dist_timeout_seconds = 3600 + +[trainer.model] +impl = "custom" +attn = "flash_attention_3" +optim_cpu_offload = true +ep = 8 + +[trainer.model.ac] +mode = "full" +freq = 1 + +[trainer.model.ac_offloading] +max_inflight_activations = 5 + +[trainer.model.compile] + +[trainer.optim] +type = "adamw" +lr = 1e-6 + +[inference] +gpu_memory_utilization = 0.9 +enable_return_routed_experts = true +enable_eplb = false + +[inference.parallel] +tp = 8 + +[inference.model] +max_model_len = 8192 + +[inference.deployment] +type = "disaggregated" +num_prefill_nodes = 1 +num_decode_nodes = 1 + +[orchestrator] +filters = [] +batch_size = 2 +max_inflight_rollouts = 2 +rollouts_per_example = 1 +max_off_policy_steps = 1 + +[orchestrator.train.sampling] +temperature = 1.0 +repetition_penalty = 1.0 +max_completion_tokens = 256 +min_tokens = 0 + +[orchestrator.client] +extra_headers_from_state = { "X-Session-ID" = "example_id" } + +[[orchestrator.train.env]] +id = "rlm_swe" +name = "rlm-swe-low" +num_workers = 1 +max_retries = 0 +max_total_completion_tokens = 1536 + +[orchestrator.train.env.args] +task_type = "swebench" +dataset_name = "PrimeIntellect/SWE-Bench-Verified-Quick" +max_turns = 6 +rlm_max_turns = 6 +timeout_seconds = 1800 +poll_interval = 5 +sandbox_cpu_cores = 2 +sandbox_memory_gb = 4 +sandbox_disk_size_gb = 4 +sandbox_client_max_workers = 16 +rlm_ref = "f466fccb6bc682092c88edf2b344951d7cbbd000" +labels = ["rlm-r3-delta-swe-start-audit"] + +[orchestrator.student.client] +timeout = 1800 +wait_for_ready_timeout = 1800 + +[orchestrator.renderer] +name = "qwen3" +preserve_all_thinking = true + +[orchestrator.eval] +env = [] + +[orchestrator.buffer] +easy_threshold = 1.0 +hard_threshold = 0.0 diff --git a/configs/debug/qwen3_30b_a3b_pd_wordle_router_replay.toml b/configs/debug/qwen3_30b_a3b_pd_wordle_router_replay.toml new file mode 100644 index 0000000000..240bc027d7 --- /dev/null +++ b/configs/debug/qwen3_30b_a3b_pd_wordle_router_replay.toml @@ -0,0 +1,113 @@ +output_dir = "/beegfs/outputs/qwen3-30b-a3b-pd-wordle-router-replay-r3-delta-10step" +clean_output_dir = true +max_steps = 10 +seq_len = 4096 +max_async_level = 1 + +[log] +level = "debug" + +[model] +name = "Qwen/Qwen3-30B-A3B-Instruct-2507" + +[deployment] +type = "multi_node" +num_train_nodes = 1 +num_infer_nodes = 2 +gpus_per_node = 8 + +[slurm] +job_name = "qwen3-pd-wordle-r3-delta-10" +partition = "cluster" +time = "06:00:00" +exclude = "ltc-idc3-hgx8-h200-[53,88]" + +[wandb] +project = "qwen3-router-replay-e2e" +name = "qwen3-30b-a3b-pd-wordle-r3-delta-10step" +group = "router-replay-e2e" +offline = true +shared = false + +[weight_broadcast] +type = "nccl" +timeout = 3600 + +[trainer] +enable_router_replay = true +max_concurrent_runs = 1 +dist_timeout_seconds = 3600 + +[trainer.model] +impl = "custom" +attn = "flash_attention_3" +optim_cpu_offload = true +ep = 8 + +[trainer.model.ac] +mode = "full" +freq = 1 + +[trainer.model.ac_offloading] +max_inflight_activations = 5 + +[trainer.model.compile] + +[trainer.optim] +type = "adamw" +lr = 1e-6 + +[inference] +gpu_memory_utilization = 0.9 +enable_return_routed_experts = true +enable_eplb = false + +[inference.parallel] +tp = 8 + +[inference.model] +max_model_len = 4096 + +[inference.deployment] +type = "disaggregated" +num_prefill_nodes = 1 +num_decode_nodes = 1 + +[orchestrator] +filters = [] +batch_size = 8 +max_inflight_rollouts = 8 +rollouts_per_example = 1 +max_off_policy_steps = 2 + +[orchestrator.train.sampling] +temperature = 1.0 +repetition_penalty = 1.0 +max_completion_tokens = 512 +min_tokens = 0 + +[[orchestrator.train.env]] +id = "wordle" +name = "wordle" +num_workers = 1 +max_retries = 0 +max_total_completion_tokens = 1024 + +[orchestrator.train.env.args] +num_train_examples = 8 +num_eval_examples = 4 + +[orchestrator.train.env.extra_env_kwargs] +max_total_completion_tokens = 1024 +max_seq_len = 4096 + +[orchestrator.student.client] +timeout = 1200 +wait_for_ready_timeout = 1800 + +[orchestrator.renderer] +preserve_all_thinking = true + +[orchestrator.buffer] +easy_threshold = 1.0 +hard_threshold = 0.0 diff --git a/configs/debug/qwen3_30b_a3b_pd_wordle_router_replay_start_audit.toml b/configs/debug/qwen3_30b_a3b_pd_wordle_router_replay_start_audit.toml new file mode 100644 index 0000000000..8949cc05f2 --- /dev/null +++ b/configs/debug/qwen3_30b_a3b_pd_wordle_router_replay_start_audit.toml @@ -0,0 +1,113 @@ +output_dir = "/beegfs/outputs/qwen3-30b-a3b-pd-wordle-router-replay-start-audit" +clean_output_dir = true +max_steps = 3 +seq_len = 4096 +max_async_level = 1 + +[log] +level = "debug" + +[model] +name = "Qwen/Qwen3-30B-A3B-Instruct-2507" + +[deployment] +type = "multi_node" +num_train_nodes = 1 +num_infer_nodes = 2 +gpus_per_node = 8 + +[slurm] +job_name = "qwen3-pd-wordle-start-audit" +partition = "cluster" +time = "03:00:00" +exclude = "ltc-idc3-hgx8-h200-[53,88]" + +[wandb] +project = "qwen3-router-replay-e2e" +name = "qwen3-30b-a3b-pd-wordle-start-audit" +group = "router-replay-e2e" +offline = true +shared = false + +[weight_broadcast] +type = "nccl" +timeout = 3600 + +[trainer] +enable_router_replay = true +max_concurrent_runs = 1 +dist_timeout_seconds = 3600 + +[trainer.model] +impl = "custom" +attn = "flash_attention_3" +optim_cpu_offload = true +ep = 8 + +[trainer.model.ac] +mode = "full" +freq = 1 + +[trainer.model.ac_offloading] +max_inflight_activations = 5 + +[trainer.model.compile] + +[trainer.optim] +type = "adamw" +lr = 1e-6 + +[inference] +gpu_memory_utilization = 0.9 +enable_return_routed_experts = true +enable_eplb = false + +[inference.parallel] +tp = 8 + +[inference.model] +max_model_len = 4096 + +[inference.deployment] +type = "disaggregated" +num_prefill_nodes = 1 +num_decode_nodes = 1 + +[orchestrator] +filters = [] +batch_size = 8 +max_inflight_rollouts = 8 +rollouts_per_example = 1 +max_off_policy_steps = 2 + +[orchestrator.train.sampling] +temperature = 1.0 +repetition_penalty = 1.0 +max_completion_tokens = 512 +min_tokens = 0 + +[[orchestrator.train.env]] +id = "wordle" +name = "wordle" +num_workers = 1 +max_retries = 0 +max_total_completion_tokens = 1024 + +[orchestrator.train.env.args] +num_train_examples = 8 +num_eval_examples = 4 + +[orchestrator.train.env.extra_env_kwargs] +max_total_completion_tokens = 1024 +max_seq_len = 4096 + +[orchestrator.student.client] +timeout = 1200 +wait_for_ready_timeout = 1800 + +[orchestrator.renderer] +preserve_all_thinking = true + +[orchestrator.buffer] +easy_threshold = 1.0 +hard_threshold = 0.0 diff --git a/deps/verifiers b/deps/verifiers index 04651bec3e..d39cc5876a 160000 --- a/deps/verifiers +++ b/deps/verifiers @@ -1 +1 @@ -Subproject commit 04651bec3e87b3a4e116adbc7317602307de8a31 +Subproject commit d39cc5876a8595cb021746d67c7a088652e872e0 diff --git a/pyproject.toml b/pyproject.toml index 1553034690..827e6926e3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -83,10 +83,12 @@ envs = [ "opencode-science", "opencode-swe", "reverse-text", + "rlm-swe", "science-env", "simpleqa-verified", "tau2-bench", "wiki-search", + "wordle", ] disagg = [ "deep-ep ; platform_machine == 'x86_64'", @@ -132,6 +134,7 @@ members = [ "deps/verifiers/environments/math_python", "deps/verifiers/environments/reverse_text", "deps/verifiers/environments/wiki_search", + "deps/verifiers/environments/wordle", "deps/research-environments/environments/aime2024", "deps/research-environments/environments/aime2025", "deps/research-environments/environments/code_env", @@ -152,6 +155,7 @@ members = [ "deps/research-environments/environments/opencode_math", "deps/research-environments/environments/opencode_science", "deps/research-environments/environments/opencode_swe", + "deps/research-environments/environments/rlm_swe", "deps/research-environments/environments/science_env", "deps/research-environments/environments/simpleqa_verified", "deps/research-environments/environments/tau2_bench", @@ -220,10 +224,12 @@ opencode-math = { workspace = true } opencode-science = { workspace = true } opencode-swe = { workspace = true } reverse-text = { workspace = true } +rlm-swe = { workspace = true } science-env = { workspace = true } simpleqa-verified = { workspace = true } tau2-bench = { workspace = true } wiki-search = { workspace = true } +wordle = { workspace = true } torch = { index = "pytorch-cu128" } torchvision = { index = "pytorch-cu128" } torchaudio = { index = "pytorch-cu128" } @@ -232,7 +238,7 @@ dion = { git = "https://github.com/samsja/dion.git", rev = "d891eeb" } transformers = { git = "https://github.com/huggingface/transformers.git", rev = "c1c3424" } flash-attn-4 = { git = "https://github.com/Dao-AILab/flash-attention.git", subdirectory = "flash_attn/cute", rev = "96bd151" } prime-pydantic-config = { workspace = true } -vllm-router = { url = "https://github.com/PrimeIntellect-ai/router/releases/download/v0.1.25/vllm_router-0.1.25-cp38-abi3-manylinux_2_28_x86_64.whl" } +vllm-router = { path = "third_party/router/dist/vllm_router-0.1.25-cp38-abi3-linux_x86_64.whl" } vllm = [ { url = "https://github.com/PrimeIntellect-ai/prime-rl/releases/download/v0.5.0/vllm-0.21.0+cu129.r42434.pr39568.a106aa6-cp38-abi3-manylinux_2_34_x86_64.whl", marker = "platform_machine == 'x86_64'" }, { url = "https://github.com/vllm-project/vllm/releases/download/v0.21.0/vllm-0.21.0+cu129-cp38-abi3-manylinux_2_34_aarch64.whl", marker = "platform_machine == 'aarch64'" }, diff --git a/skills/configs/SKILL.md b/skills/configs/SKILL.md index 83f7dd8d47..ab982f823c 100644 --- a/skills/configs/SKILL.md +++ b/skills/configs/SKILL.md @@ -70,6 +70,10 @@ For rollout debugging, enable trainer-side token export under `trainer.experimen Leave it unset for normal training. When enabled, it exports every sequence from each exporting rank. +## RLM SWE harness args + +For `rlm_swe` / `rlm-swe` configs using the composable RLM harness, use current harness kwargs such as `rlm_max_turns`, `rlm_exec_timeout`, `rlm_max_depth`, `summarize_at_tokens`, `rlm_ref`, `local_checkout`, `append_to_system_prompt`, and `rlm_tools`. Do not use the stale `rlm_max_turns_in_context` key with the composable harness; it is not accepted by `rlm_harness`. + ## Key files - `packages/prime-rl-configs/src/prime_rl/` — config classes under `configs/`; `utils/config.py` re-exports `BaseConfig` and `cli` diff --git a/src/prime_rl/inference/vllm/routed_experts.py b/src/prime_rl/inference/vllm/routed_experts.py index cad97e8574..e9ef70b049 100644 --- a/src/prime_rl/inference/vllm/routed_experts.py +++ b/src/prime_rl/inference/vllm/routed_experts.py @@ -8,7 +8,7 @@ from vllm.outputs import RequestOutput -def serialize_routed_experts(routed_experts: Any) -> dict[str, Any] | None: +def serialize_routed_experts(routed_experts: Any, start: int = 0) -> dict[str, Any] | None: if routed_experts is None: return None @@ -23,18 +23,20 @@ def serialize_routed_experts(routed_experts: Any) -> dict[str, Any] | None: return { "data": pybase64.b64encode(memoryview(compact)).decode("ascii"), "shape": list(compact.shape), + "start": start, } class RoutedExpertsCapture: - def __init__(self, generator: AsyncIterator[RequestOutput]): + def __init__(self, generator: AsyncIterator[RequestOutput], start: int = 0): self._generator = generator + self._start = start self.routed_experts: dict[int, dict[str, Any]] = {} async def __aiter__(self): async for request_output in self._generator: for output in request_output.outputs: - encoded = serialize_routed_experts(getattr(output, "routed_experts", None)) + encoded = serialize_routed_experts(getattr(output, "routed_experts", None), start=self._start) if encoded is not None: self.routed_experts[output.index] = encoded yield request_output diff --git a/src/prime_rl/inference/vllm/serving_tokens.py b/src/prime_rl/inference/vllm/serving_tokens.py index afaabef0e6..fb15aa4dc8 100644 --- a/src/prime_rl/inference/vllm/serving_tokens.py +++ b/src/prime_rl/inference/vllm/serving_tokens.py @@ -266,7 +266,11 @@ async def serve_tokens_full_generator( # type: ignore[override] # experts surface in the JSON. capture: _GenerateRoutedExpertsCapture | None = None if self.model_config.enable_return_routed_experts: - capture = _GenerateRoutedExpertsCapture(result_generator) + start = request.sampling_params.routed_experts_prompt_start + capture = _GenerateRoutedExpertsCapture( + result_generator, + start=start, + ) result_generator = capture response = await super().serve_tokens_full_generator( diff --git a/src/prime_rl/orchestrator/trajectories.py b/src/prime_rl/orchestrator/trajectories.py index 42af5ac042..f1c164f2c8 100644 --- a/src/prime_rl/orchestrator/trajectories.py +++ b/src/prime_rl/orchestrator/trajectories.py @@ -244,11 +244,13 @@ def prepare_step_tokens(step: vf.TrajectoryStep, step_idx: int) -> dict[str, Any if tokens is not None: routed_experts_payload = tokens.get("routed_experts") routed_experts = None + routed_experts_start = None if routed_experts_payload is not None: decoded_routed_experts = pybase64.b64decode_as_bytearray(routed_experts_payload["data"]) routed_experts = np.frombuffer(decoded_routed_experts, dtype=np.uint8).reshape( routed_experts_payload["shape"] ) + routed_experts_start = routed_experts_payload["start"] return { "prompt_ids": list(tokens["prompt_ids"]), @@ -257,6 +259,7 @@ def prepare_step_tokens(step: vf.TrajectoryStep, step_idx: int) -> dict[str, Any "completion_mask": list(map(bool, tokens["completion_mask"])), "completion_logprobs": list(tokens["completion_logprobs"]), "routed_experts": routed_experts, + "routed_experts_start": routed_experts_start, # Renderer-emitted multimodal sidecar (placeholders + per-item # processed tensors). Populated when the rollout went through # a multimodal-aware renderer (e.g. Qwen3VLRenderer); absent @@ -277,6 +280,12 @@ def prepare_step_tokens(step: vf.TrajectoryStep, step_idx: int) -> dict[str, Any # Deferred routed_experts state per sample: O(N) chunk list concatenated # once at finalize, replacing the prior O(N²) per-extension unpack/repack. sample_routed_state: dict[int, dict[str, Any]] = {} + routed_prefix_states: dict[int, list[tuple[list[int], list[int], dict[str, Any]]]] = {} + + # Track (prefix_tokens, sample, step_indices) per active sample. step_indices + # is the explicit list of prepared_steps positions merged into this sample — + # non-contiguous when other agents' steps interleave. + active_samples: list[tuple[list[int], TrainingSample, list[int]]] = [] def make_sample(tokens: dict[str, Any]) -> TrainingSample: """Create a new TrainingSample from a trajectory step.""" @@ -306,9 +315,37 @@ def make_sample(tokens: dict[str, Any]) -> TrainingSample: # each extension is a no-op append rather than a destructive write. step_routed = tokens.get("routed_experts") if step_routed is not None: + routed_start = tokens["routed_experts_start"] + assert routed_start is not None + chunks: list[np.ndarray] = [] + running_len = 0 + if routed_start > 0: + source_len = routed_start + 1 + source_state = None + for prompt_ids, completion_ids, candidate_state in routed_prefix_states[source_len]: + prompt_len = len(prompt_ids) + if ( + tokens["prompt_ids"][:prompt_len] == prompt_ids + and tokens["prompt_ids"][prompt_len:source_len] == completion_ids + ): + source_state = candidate_state + break + assert source_state is not None + assert source_state["running_len"] >= routed_start + remaining = routed_start + for chunk in source_state["chunks"]: + if remaining == 0: + break + take = min(remaining, int(chunk.shape[0])) + chunks.append(chunk[:take]) + remaining -= take + assert remaining == 0 + running_len = routed_start + chunks.append(step_routed) + running_len += int(step_routed.shape[0]) sample_routed_state[id(sample)] = { - "chunks": [step_routed], - "running_len": int(step_routed.shape[0]), + "chunks": chunks, + "running_len": running_len, } return sample @@ -339,30 +376,31 @@ def extend_sample( step_routed = tokens.get("routed_experts") state = sample_routed_state.get(id(sample)) - if step_routed is not None and state is not None: - # vLLM doesn't capture a routing decision for the *last* token of any - # request, so the previous step left no entry for token at index - # (prefix_len - 1). The next step's forward pass *did* process that - # token (as part of its prompt) and produced step_routed[prefix_len-1]. - # Append that single boundary entry as its own chunk, then append the - # genuinely new entries from this step. No prior bytes touched. - if prefix_len > 0 and prefix_len <= step_routed.shape[0]: - boundary_chunk = step_routed[prefix_len - 1 : prefix_len] + if state is not None: + assert step_routed is not None + if step_routed is not None: + assert state is not None + assert tokens["routed_experts_start"] == prefix_len - 1 + # Delta payloads start at prefix_len - 1. Row 0 fills the boundary + # token missing from the previous request; the rest is the new suffix. + if prefix_len > 0: + boundary_chunk = step_routed[:1] state["chunks"].append(boundary_chunk) state["running_len"] += 1 - new_chunk = step_routed[prefix_len:] + step_routed = step_routed[1:] + new_chunk = step_routed state["chunks"].append(new_chunk) state["running_len"] += int(new_chunk.shape[0]) - # Track (prefix_tokens, sample, step_indices) per active sample. step_indices - # is the explicit list of prepared_steps positions merged into this sample — - # non-contiguous when other agents' steps interleave. - active_samples: list[tuple[list[int], TrainingSample, list[int]]] = [] - first_tokens = prepared_steps[0] first_prefix = first_tokens["prompt_ids"] + first_tokens["completion_ids"] first_sample = make_sample(first_tokens) active_samples.append((first_prefix, first_sample, [0])) + first_routed_state = sample_routed_state.get(id(first_sample)) + if first_routed_state is not None: + routed_prefix_states.setdefault(len(first_prefix), []).append( + (first_tokens["prompt_ids"], first_tokens["completion_ids"], first_routed_state) + ) for step_idx, _step in enumerate(trajectory[1:], start=1): tokens = prepared_steps[step_idx] @@ -379,11 +417,17 @@ def extend_sample( # Extension holds - merge into matched sample prefix_tokens, sample, step_indices = active_samples[matched_idx] extend_sample(sample, len(prefix_tokens), step_idx=step_idx) + new_prefix = tokens["prompt_ids"] + tokens["completion_ids"] active_samples[matched_idx] = ( - tokens["prompt_ids"] + tokens["completion_ids"], + new_prefix, sample, step_indices + [step_idx], ) + routed_state = sample_routed_state.get(id(sample)) + if routed_state is not None: + routed_prefix_states.setdefault(len(new_prefix), []).append( + (tokens["prompt_ids"], tokens["completion_ids"], routed_state) + ) else: # No prefix matches - start a new sample logger.debug( @@ -393,6 +437,11 @@ def extend_sample( new_prefix = tokens["prompt_ids"] + tokens["completion_ids"] sample = make_sample(tokens) active_samples.append((new_prefix, sample, [step_idx])) + routed_state = sample_routed_state.get(id(sample)) + if routed_state is not None: + routed_prefix_states.setdefault(len(new_prefix), []).append( + (tokens["prompt_ids"], tokens["completion_ids"], routed_state) + ) # Finalize routed_experts for each sample. One concat per sample (O(N) byte # work) replaces the previous per-step unpack/concat/repack (O(N²)). The diff --git a/tests/unit/inference/test_serving_tokens.py b/tests/unit/inference/test_serving_tokens.py index 1882e57e55..5045850928 100644 --- a/tests/unit/inference/test_serving_tokens.py +++ b/tests/unit/inference/test_serving_tokens.py @@ -73,7 +73,7 @@ def test_serialize_routed_experts_uses_compact_raw_payload(): def test_generate_response_post_process_replaces_upstream_routed_experts(): - compact_routed_experts = {"data": "AQID", "shape": [1, 1, 3]} + compact_routed_experts = {"data": "AQID", "shape": [1, 1, 3], "start": 0} capture = _GenerateRoutedExpertsCapture(_empty_request_outputs()) capture.routed_experts[0] = compact_routed_experts response = GenerateResponse( diff --git a/tests/unit/orchestrator/test_trajectories.py b/tests/unit/orchestrator/test_trajectories.py index 36c9ef1008..1ecf2e3c46 100644 --- a/tests/unit/orchestrator/test_trajectories.py +++ b/tests/unit/orchestrator/test_trajectories.py @@ -31,11 +31,12 @@ def _decode_mm_thw(sample) -> list: return np.frombuffer(g.data, dtype=np.dtype(g.dtype)).reshape(g.shape).tolist() -def _routed_experts_payload(data) -> dict: +def _routed_experts_payload(data, start: int = 0) -> dict: arr = np.asarray(data, dtype=np.uint8) return { "data": pybase64.b64encode(memoryview(np.ascontiguousarray(arr))).decode("ascii"), "shape": list(arr.shape), + "start": start, } @@ -910,8 +911,9 @@ def test_interleave_rollout_multi_step_with_routed_experts(): """Routed experts are extended and aligned across multi-step trajectories.""" # Step 1: prompt=[1,2], completion=[3,4] -> 4 tokens, vLLM returns 3 step1_experts = np.asarray([[[1, 2]], [[3, 4]], [[5, 6]]], dtype=np.uint8) - # Step 2: prompt=[1,2,3,4,5,6], completion=[7,8] -> 8 tokens, vLLM returns 7 - step2_experts = np.asarray([[[1, 0]], [[2, 0]], [[3, 0]], [[4, 0]], [[5, 0]], [[6, 0]], [[7, 0]]], dtype=np.uint8) + # Step 2: prompt=[1,2,3,4,5,6], completion=[7,8], bridged from prefix len 4. + # vLLM returns routed experts starting at row 3: boundary token 4, then 5, 6, 7. + step2_experts = np.asarray([[[40, 41]], [[50, 51]], [[60, 61]], [[70, 71]]], dtype=np.uint8) output = vf.RolloutOutput( example_id=0, @@ -952,7 +954,7 @@ def test_interleave_rollout_multi_step_with_routed_experts(): completion_logprobs=[-0.3, -0.4], overlong_prompt=False, is_truncated=False, - routed_experts=_routed_experts_payload(step2_experts), + routed_experts=_routed_experts_payload(step2_experts, start=3), ), reward=None, advantage=None, @@ -973,7 +975,132 @@ def test_interleave_rollout_multi_step_with_routed_experts(): # Merged sample: prompt=[1,2], completion=[3,4,5,6,7,8] -> 8 tokens total assert len(sample.prompt_ids) + len(sample.completion_ids) == 8 assert sample.routed_experts is not None - assert _sample_routed_experts(sample).shape == (8, 1, 2) + routed_experts = _sample_routed_experts(sample) + assert routed_experts.shape == (8, 1, 2) + np.testing.assert_array_equal( + routed_experts, + np.asarray( + [ + [[1, 2]], + [[3, 4]], + [[5, 6]], + [[40, 41]], + [[50, 51]], + [[60, 61]], + [[70, 71]], + [[0, 0]], + ], + dtype=np.uint8, + ), + ) + + +def test_interleave_rollout_branch_delta_uses_prior_routed_prefix(): + step1_experts = np.asarray([[[1, 2]], [[3, 4]], [[5, 6]]], dtype=np.uint8) + step2_experts = np.asarray([[[40, 41]], [[50, 51]], [[60, 61]], [[70, 71]]], dtype=np.uint8) + step3_experts = np.asarray([[[80, 81]], [[90, 91]], [[100, 101]], [[110, 111]]], dtype=np.uint8) + + output = vf.RolloutOutput( + example_id=0, + trajectory=[ + vf.TrajectoryStep( + prompt=[{"role": "user", "content": "U1"}], + completion=[{"role": "assistant", "content": "A1"}], + response=MagicMock(), + tokens=vf.TrajectoryStepTokens( + prompt_ids=[1, 2], + prompt_mask=[0, 0], + completion_ids=[3, 4], + completion_mask=[1, 1], + completion_logprobs=[-0.1, -0.2], + overlong_prompt=False, + is_truncated=False, + routed_experts=_routed_experts_payload(step1_experts), + ), + reward=None, + advantage=None, + is_truncated=False, + trajectory_id="1", + extras={}, + ), + vf.TrajectoryStep( + prompt=[ + {"role": "user", "content": "U1"}, + {"role": "assistant", "content": "A1"}, + {"role": "user", "content": "U2"}, + ], + completion=[{"role": "assistant", "content": "A2"}], + response=MagicMock(), + tokens=vf.TrajectoryStepTokens( + prompt_ids=[1, 2, 3, 4, 5, 6], + prompt_mask=[0, 0, 0, 0, 0, 0], + completion_ids=[7, 8], + completion_mask=[1, 1], + completion_logprobs=[-0.3, -0.4], + overlong_prompt=False, + is_truncated=False, + routed_experts=_routed_experts_payload(step2_experts, start=3), + ), + reward=None, + advantage=None, + is_truncated=False, + trajectory_id="1", + extras={}, + ), + vf.TrajectoryStep( + prompt=[ + {"role": "user", "content": "U1"}, + {"role": "assistant", "content": "A1"}, + {"role": "user", "content": "branch"}, + ], + completion=[{"role": "assistant", "content": "A3"}], + response=MagicMock(), + tokens=vf.TrajectoryStepTokens( + prompt_ids=[1, 2, 3, 4, 9, 10], + prompt_mask=[0, 0, 0, 0, 0, 0], + completion_ids=[11, 12], + completion_mask=[1, 1], + completion_logprobs=[-0.5, -0.6], + overlong_prompt=False, + is_truncated=False, + routed_experts=_routed_experts_payload(step3_experts, start=3), + ), + reward=None, + advantage=None, + is_truncated=False, + trajectory_id="1", + extras={}, + ), + ], + sampling_args={"temperature": 1.0}, + error=None, + ) + + rollouts = interleave_rollout(output) + + assert rollouts is not None + assert len(rollouts) == 2 + branched = rollouts[1] + assert branched.prompt_ids == [1, 2, 3, 4, 9, 10] + assert branched.completion_ids == [11, 12] + routed_experts = _sample_routed_experts(branched) + assert routed_experts.shape == (8, 1, 2) + np.testing.assert_array_equal( + routed_experts, + np.asarray( + [ + [[1, 2]], + [[3, 4]], + [[5, 6]], + [[80, 81]], + [[90, 91]], + [[100, 101]], + [[110, 111]], + [[0, 0]], + ], + dtype=np.uint8, + ), + ) def test_interleave_rollout_none_routed_experts_stays_none(): diff --git a/uv.lock b/uv.lock index a3d9158657..d5f90a349d 100644 --- a/uv.lock +++ b/uv.lock @@ -62,11 +62,13 @@ members = [ "prime-rl-configs", "renderers", "reverse-text", + "rlm-swe", "science-env", "simpleqa-verified", "tau2-bench", "verifiers", "wiki-search", + "wordle", ] overrides = [ { name = "nvidia-cudnn-cu12", specifier = ">=9.15" }, @@ -360,6 +362,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/df/73/b6e24bd22e6720ca8ee9a85a0c4a2971af8497d8f3193fa05390cbd46e09/backoff-2.2.1-py3-none-any.whl", hash = "sha256:63579f9a0628e06278f7e47b7d7d5b6ce20dc65c5e96a6f3ca99a6adca0396e8", size = 15148, upload-time = "2022-10-05T19:19:30.546Z" }, ] +[[package]] +name = "bashlex" +version = "0.18" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/76/60/aae0bb54f9af5e0128ba90eb83d8d0d506ee8f0475c4fdda3deeda20b1d2/bashlex-0.18.tar.gz", hash = "sha256:5bb03a01c6d5676338c36fd1028009c8ad07e7d61d8a1ce3f513b7fff52796ee", size = 68742, upload-time = "2023-01-18T15:21:26.402Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f4/be/6985abb1011fda8a523cfe21ed9629e397d6e06fb5bae99750402b25c95b/bashlex-0.18-py2.py3-none-any.whl", hash = "sha256:91d73a23a3e51711919c1c899083890cdecffc91d8c088942725ac13e9dcfffa", size = 69539, upload-time = "2023-01-18T15:21:24.167Z" }, +] + [[package]] name = "bcrypt" version = "5.0.0" @@ -894,6 +905,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/23/18/4cedda786e7da429e7489549a9e5461530d4133130e541f25fb94f015776/cyclopts-4.11.2-py3-none-any.whl", hash = "sha256:838020120b939549ff7c8423aca29c86764b5dd1d8a5d7f3753a6327861f537b", size = 213537, upload-time = "2026-05-04T00:11:56.103Z" }, ] +[[package]] +name = "dataclasses-json" +version = "0.6.7" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "marshmallow", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "typing-inspect", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/64/a4/f71d9cf3a5ac257c993b5ca3f93df5f7fb395c725e7f1e6479d2514173c3/dataclasses_json-0.6.7.tar.gz", hash = "sha256:b6b3e528266ea45b9535223bc53ca645f5208833c29229e847b3f26a1cc55fc0", size = 32227, upload-time = "2024-06-09T16:20:19.103Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c3/be/d0d44e092656fe7a06b55e6103cbce807cdbdee17884a5367c68c9860853/dataclasses_json-0.6.7-py3-none-any.whl", hash = "sha256:0dbf33f26c8d5305befd61b39d2b3414e8a407bedc2834dea9b8d642666fb40a", size = 28686, upload-time = "2024-06-09T16:20:16.715Z" }, +] + [[package]] name = "datasets" version = "4.6.1" @@ -2514,6 +2538,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/89/e0/4486f11e51bbba8b0c041098859e869e304d1c261e59244baa3d295d47b7/markupsafe-3.0.3-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:77f0643abe7495da77fb436f50f8dab76dbc6e5fd25d39589a0f1fe6548bfa2b", size = 23015, upload-time = "2025-09-27T18:36:37.868Z" }, ] +[[package]] +name = "marshmallow" +version = "3.26.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "packaging", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/55/79/de6c16cc902f4fc372236926b0ce2ab7845268dcc30fb2fbb7f71b418631/marshmallow-3.26.2.tar.gz", hash = "sha256:bbe2adb5a03e6e3571b573f42527c6fe926e17467833660bebd11593ab8dfd57", size = 222095, upload-time = "2025-12-22T06:53:53.309Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/be/2f/5108cb3ee4ba6501748c4908b908e55f42a5b66245b4cfe0c99326e1ef6e/marshmallow-3.26.2-py3-none-any.whl", hash = "sha256:013fa8a3c4c276c24d26d84ce934dc964e2aa794345a0f8c7e5a7191482c8a73", size = 50964, upload-time = "2025-12-22T06:53:51.801Z" }, +] + [[package]] name = "math-env" version = "0.1.5" @@ -2841,6 +2877,26 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b4/68/04d7a8f0f786545cf9b8c280c57aa6befb5977af6e884b8b54191cbe44b3/msgspec-0.21.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:ef3ec2296248d1f8b9231acb051b6d471dfde8f21819e86c9adaaa9f42918521", size = 227303, upload-time = "2026-04-12T21:44:13.709Z" }, ] +[[package]] +name = "multi-swe-bench" +version = "1.1.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "dataclasses-json", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "docker", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "gitpython", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "pygithub", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "pyyaml", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "swe-rex", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "toml", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "tqdm", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "unidiff", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/48/ad/6b7cda600a50392c790b14ee420b9a3bb318a982a298c05f2d1c066a434f/multi_swe_bench-1.1.2.tar.gz", hash = "sha256:44944bc6608d7d9b8d4390f3ce0a3b2c69122ea6be6e35766c6fde2328f50392", size = 1267660, upload-time = "2025-12-18T07:16:09.584Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/08/a8/060eb46096742944d8d37c34094d4e0fb34b28c6291a877388543ea65660/multi_swe_bench-1.1.2-py3-none-any.whl", hash = "sha256:09a5770096d6a035383c5240762ffa8c87b1e8df7d374110de8fb781b4e5a9f9", size = 4942355, upload-time = "2025-12-18T07:16:07.468Z" }, +] + [[package]] name = "multidict" version = "6.7.1" @@ -2870,6 +2926,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/6c/28/dd72947e59a6a8c856448a5e74da6201cb5502ddff644fbc790e4bd40b9a/multiprocess-0.70.18-py39-none-any.whl", hash = "sha256:e78ca805a72b1b810c690b6b4cc32579eba34f403094bbbae962b7b5bf9dfcb8", size = 133478, upload-time = "2025-04-17T03:11:26.253Z" }, ] +[[package]] +name = "mypy-extensions" +version = "1.1.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/a2/6e/371856a3fb9d31ca8dac321cda606860fa4548858c0cc45d9d1d4ca2628b/mypy_extensions-1.1.0.tar.gz", hash = "sha256:52e68efc3284861e772bbcd66823fde5ae21fd2fdb51c62a211403730b916558", size = 6343, upload-time = "2025-04-22T14:54:24.164Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/79/7b/2c79738432f5c924bef5071f933bcc9efd0473bac3b4aa584a6f7c1c8df8/mypy_extensions-1.1.0-py3-none-any.whl", hash = "sha256:1be4cccdb0f2482337c4743e60421de3a356cd97508abadd57d47403e94f5505", size = 4963, upload-time = "2025-04-22T14:54:22.983Z" }, +] + [[package]] name = "narwhals" version = "2.21.0" @@ -3968,10 +4033,12 @@ envs = [ { name = "opencode-science", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, { name = "opencode-swe", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, { name = "reverse-text", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "rlm-swe", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, { name = "science-env", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, { name = "simpleqa-verified", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, { name = "tau2-bench", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, { name = "wiki-search", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "wordle", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, ] flash-attn = [ { name = "flash-attn", version = "2.8.3+cu128torch2.11", source = { url = "https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.9.4/flash_attn-2.8.3+cu128torch2.11-cp312-cp312-linux_x86_64.whl" }, marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, @@ -4059,6 +4126,7 @@ requires-dist = [ { name = "reverse-text", marker = "extra == 'envs'", editable = "deps/verifiers/environments/reverse_text" }, { name = "rich", specifier = ">=14.0.0" }, { name = "ring-flash-attn", specifier = ">=0.1.8" }, + { name = "rlm-swe", marker = "extra == 'envs'", editable = "deps/research-environments/environments/rlm_swe" }, { name = "science-env", marker = "extra == 'envs'", editable = "deps/research-environments/environments/science_env" }, { name = "setproctitle", specifier = ">=1.3.0" }, { name = "simpleqa-verified", marker = "extra == 'envs'", editable = "deps/research-environments/environments/simpleqa_verified" }, @@ -4076,9 +4144,10 @@ requires-dist = [ { name = "vllm", marker = "platform_machine != 'aarch64' and platform_machine != 'x86_64'", specifier = ">=0.21.0" }, { name = "vllm", marker = "platform_machine == 'aarch64'", url = "https://github.com/vllm-project/vllm/releases/download/v0.21.0/vllm-0.21.0+cu129-cp38-abi3-manylinux_2_34_aarch64.whl" }, { name = "vllm", marker = "platform_machine == 'x86_64'", url = "https://github.com/PrimeIntellect-ai/prime-rl/releases/download/v0.5.0/vllm-0.21.0+cu129.r42434.pr39568.a106aa6-cp38-abi3-manylinux_2_34_x86_64.whl" }, - { name = "vllm-router", marker = "platform_machine == 'x86_64' and extra == 'disagg'", url = "https://github.com/PrimeIntellect-ai/router/releases/download/v0.1.25/vllm_router-0.1.25-cp38-abi3-manylinux_2_28_x86_64.whl" }, + { name = "vllm-router", marker = "platform_machine == 'x86_64' and extra == 'disagg'", path = "third_party/router/dist/vllm_router-0.1.25-cp38-abi3-linux_x86_64.whl" }, { name = "wandb", specifier = ">=0.26.1" }, { name = "wiki-search", marker = "extra == 'envs'", editable = "deps/verifiers/environments/wiki_search" }, + { name = "wordle", marker = "extra == 'envs'", editable = "deps/verifiers/environments/wordle" }, ] provides-extras = ["flash-attn", "flash-attn-3", "flash-attn-cute", "envs", "disagg", "gpt-oss", "quack", "all"] @@ -4445,6 +4514,22 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/1a/03/bef6fff907e212d67a0003f8ea4819307bba91b2856074a0763dd483ccc4/pyfiglet-1.0.2-py3-none-any.whl", hash = "sha256:889b351d79c99e50a3f619c8f8e6ffdb27fd8c939fc43ecbd7559bd57d5f93ea", size = 1085824, upload-time = "2023-09-13T20:56:18.707Z" }, ] +[[package]] +name = "pygithub" +version = "2.9.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pyjwt", extra = ["crypto"], marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "pynacl", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "requests", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "typing-extensions", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "urllib3", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/ab/c3/8465a311197e16cf5ab68789fe689535e90f6b61ab524cc32a39e67237ae/pygithub-2.9.1.tar.gz", hash = "sha256:59771d7ff63d54d427be2e7d0dad2208dfffc2b0a045fec959263787739b611c", size = 2594989, upload-time = "2026-04-14T07:26:13.622Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/77/aa/81a5506f089a26338bff17535e4339b3b22049ebd1bcdeff756c4d7a7559/pygithub-2.9.1-py3-none-any.whl", hash = "sha256:2ec78fca30092d51a42d76f4ddb02131b6f0c666a35dfdf364cf302cdda115b9", size = 449710, upload-time = "2026-04-14T07:26:12.382Z" }, +] + [[package]] name = "pygments" version = "2.20.0" @@ -4468,6 +4553,25 @@ crypto = [ { name = "cryptography", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, ] +[[package]] +name = "pynacl" +version = "1.6.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "cffi", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and platform_python_implementation != 'PyPy' and sys_platform == 'linux') or (platform_machine == 'x86_64' and platform_python_implementation != 'PyPy' and sys_platform == 'linux') or (platform_machine == 'aarch64' and platform_python_implementation == 'PyPy' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'x86_64' and platform_python_implementation == 'PyPy' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/d9/9a/4019b524b03a13438637b11538c82781a5eda427394380381af8f04f467a/pynacl-1.6.2.tar.gz", hash = "sha256:018494d6d696ae03c7e656e5e74cdfd8ea1326962cc401bcf018f1ed8436811c", size = 3511692, upload-time = "2026-01-01T17:48:10.851Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/1e/b4/e927e0653ba63b02a4ca5b4d852a8d1d678afbf69b3dbf9c4d0785ac905c/pynacl-1.6.2-cp38-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:8845c0631c0be43abdd865511c41eab235e0be69c81dc66a50911594198679b0", size = 800020, upload-time = "2026-01-01T17:32:18.34Z" }, + { url = "https://files.pythonhosted.org/packages/7f/81/d60984052df5c97b1d24365bc1e30024379b42c4edcd79d2436b1b9806f2/pynacl-1.6.2-cp38-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:22de65bb9010a725b0dac248f353bb072969c94fa8d6b1f34b87d7953cf7bbe4", size = 1399174, upload-time = "2026-01-01T17:32:20.239Z" }, + { url = "https://files.pythonhosted.org/packages/68/f7/322f2f9915c4ef27d140101dd0ed26b479f7e6f5f183590fd32dfc48c4d3/pynacl-1.6.2-cp38-abi3-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:46065496ab748469cdd999246d17e301b2c24ae2fdf739132e580a0e94c94a87", size = 835085, upload-time = "2026-01-01T17:32:22.24Z" }, + { url = "https://files.pythonhosted.org/packages/3e/d0/f301f83ac8dbe53442c5a43f6a39016f94f754d7a9815a875b65e218a307/pynacl-1.6.2-cp38-abi3-manylinux_2_26_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:8a66d6fb6ae7661c58995f9c6435bda2b1e68b54b598a6a10247bfcdadac996c", size = 1437614, upload-time = "2026-01-01T17:32:23.766Z" }, + { url = "https://files.pythonhosted.org/packages/c4/58/fc6e649762b029315325ace1a8c6be66125e42f67416d3dbd47b69563d61/pynacl-1.6.2-cp38-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:26bfcd00dcf2cf160f122186af731ae30ab120c18e8375684ec2670dccd28130", size = 818251, upload-time = "2026-01-01T17:32:25.69Z" }, + { url = "https://files.pythonhosted.org/packages/c9/a8/b917096b1accc9acd878819a49d3d84875731a41eb665f6ebc826b1af99e/pynacl-1.6.2-cp38-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:c8a231e36ec2cab018c4ad4358c386e36eede0319a0c41fed24f840b1dac59f6", size = 1402859, upload-time = "2026-01-01T17:32:27.215Z" }, + { url = "https://files.pythonhosted.org/packages/85/42/fe60b5f4473e12c72f977548e4028156f4d340b884c635ec6b063fe7e9a5/pynacl-1.6.2-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:68be3a09455743ff9505491220b64440ced8973fe930f270c8e07ccfa25b1f9e", size = 791926, upload-time = "2026-01-01T17:32:29.314Z" }, + { url = "https://files.pythonhosted.org/packages/fa/f9/e40e318c604259301cc091a2a63f237d9e7b424c4851cafaea4ea7c4834e/pynacl-1.6.2-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:8b097553b380236d51ed11356c953bf8ce36a29a3e596e934ecabe76c985a577", size = 1363101, upload-time = "2026-01-01T17:32:31.263Z" }, +] + [[package]] name = "pyparsing" version = "3.3.2" @@ -4884,6 +4988,25 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c1/02/18ba0727a1c755c528d6a52b363d62c0b7a8e64cf961b3030c046107db4d/ring_flash_attn-0.1.8-py3-none-any.whl", hash = "sha256:296c929516c3b21f7bcdaeca44a99bb541779a7b63979eb0f67837dcb18a2bb9", size = 25437, upload-time = "2025-09-10T11:53:07.565Z" }, ] +[[package]] +name = "rlm-swe" +version = "0.3.4" +source = { editable = "deps/research-environments/environments/rlm_swe" } +dependencies = [ + { name = "multi-swe-bench", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "prime-sandboxes", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "swebench", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "verifiers", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, +] + +[package.metadata] +requires-dist = [ + { name = "multi-swe-bench", specifier = ">=1.1.2" }, + { name = "prime-sandboxes", specifier = ">=0.2.19" }, + { name = "swebench", specifier = "==4.1.0" }, + { name = "verifiers", specifier = ">=0.1.13.dev8" }, +] + [[package]] name = "rpds-py" version = "0.30.0" @@ -5229,6 +5352,25 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/0e/65/5e726c372da8a5e35022a94388b12252710aad0c2351699c3d76ae8dba78/supervisor-4.3.0-py2.py3-none-any.whl", hash = "sha256:0bcb763fddafba410f35cbde226aa7f8514b9fb82eb05a0c85f6588d1c13f8db", size = 320736, upload-time = "2025-08-23T18:25:00.767Z" }, ] +[[package]] +name = "swe-rex" +version = "1.4.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "bashlex", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "fastapi", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "pexpect", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "pydantic", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "python-multipart", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "requests", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "rich", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "uvicorn", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/94/86/a069f93ec866151a4d476d546e60220e66b3788878b6e248b2df3ab2c5f1/swe_rex-1.4.0.tar.gz", hash = "sha256:14f8a24c49a63f9e251340b1109ac75a4aacbaece410f8599209de9bfca843c0", size = 41755, upload-time = "2025-08-14T01:19:20.22Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/98/0d/d06ab2aa78138055c297490762cd7b4d8ac58a544783f874c869cdb7b534/swe_rex-1.4.0-py3-none-any.whl", hash = "sha256:61261ad03eb23b717b5901cd5d229f24f6e1be2e120aad5c2e5ea3384a1d15ad", size = 47756, upload-time = "2025-08-14T01:19:18.93Z" }, +] + [[package]] name = "swebench" version = "4.1.0" @@ -5782,6 +5924,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/18/67/36e9267722cc04a6b9f15c7f3441c2363321a3ea07da7ae0c0707beb2a9c/typing_extensions-4.15.0-py3-none-any.whl", hash = "sha256:f0fa19c6845758ab08074a0cfa8b7aecb71c999ca73d62883bc25cc018c4e548", size = 44614, upload-time = "2025-08-25T13:49:24.86Z" }, ] +[[package]] +name = "typing-inspect" +version = "0.9.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "mypy-extensions", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "typing-extensions", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/dc/74/1789779d91f1961fa9438e9a8710cdae6bd138c80d7303996933d117264a/typing_inspect-0.9.0.tar.gz", hash = "sha256:b23fc42ff6f6ef6954e4852c1fb512cdd18dbea03134f91f856a95ccc9461f78", size = 13825, upload-time = "2023-05-24T20:25:47.612Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/65/f3/107a22063bf27bdccf2024833d3445f4eea42b2e598abfbd46f6a63b6cb0/typing_inspect-0.9.0-py3-none-any.whl", hash = "sha256:9ee6fc59062311ef8547596ab6b955e1b8aa46242d854bfc78f4f6b0eff35f9f", size = 8827, upload-time = "2023-05-24T20:25:45.287Z" }, +] + [[package]] name = "typing-inspection" version = "0.4.2" @@ -6412,7 +6567,7 @@ provides-extras = ["zen", "bench", "tensorizer", "fastsafetensors", "instanttens [[package]] name = "vllm-router" version = "0.1.25" -source = { url = "https://github.com/PrimeIntellect-ai/router/releases/download/v0.1.25/vllm_router-0.1.25-cp38-abi3-manylinux_2_28_x86_64.whl" } +source = { path = "third_party/router/dist/vllm_router-0.1.25-cp38-abi3-linux_x86_64.whl" } dependencies = [ { name = "aiohttp", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, { name = "fastapi", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, @@ -6422,7 +6577,7 @@ dependencies = [ { name = "uvicorn", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, ] wheels = [ - { url = "https://github.com/PrimeIntellect-ai/router/releases/download/v0.1.25/vllm_router-0.1.25-cp38-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:e84e731a0779f820bfe3cf4ce78cea2d09993c0a6501c63bcda93826bcd21fd0" }, + { filename = "vllm_router-0.1.25-cp38-abi3-linux_x86_64.whl", hash = "sha256:6386039d3256f2e7a70cb5317a7d86813f72691b8902a5dcf0d4211ebfc72abb" }, ] [package.metadata] @@ -6580,6 +6735,23 @@ requires-dist = [ { name = "verifiers", specifier = ">=0.1.9" }, ] +[[package]] +name = "wordle" +version = "0.1.7" +source = { editable = "deps/verifiers/environments/wordle" } +dependencies = [ + { name = "nltk", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "textarena", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "verifiers", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, +] + +[package.metadata] +requires-dist = [ + { name = "nltk", specifier = ">=3.9.2" }, + { name = "textarena", specifier = "==0.7.4" }, + { name = "verifiers", specifier = ">=0.1.9.post3" }, +] + [[package]] name = "wrapt" version = "1.17.3" From 5ee18c884bec295c77bfe6d9d276197985b01a38 Mon Sep 17 00:00:00 2001 From: Sami Date: Mon, 25 May 2026 22:57:19 +0530 Subject: [PATCH 02/13] Handle token export mkdir races --- src/prime_rl/trainer/rl/token_export.py | 18 ++++++++- tests/unit/train/rl/test_token_export.py | 49 ++++++++++++++++++++++++ 2 files changed, 66 insertions(+), 1 deletion(-) create mode 100644 tests/unit/train/rl/test_token_export.py diff --git a/src/prime_rl/trainer/rl/token_export.py b/src/prime_rl/trainer/rl/token_export.py index e8c7dd23bc..702f1c8f99 100644 --- a/src/prime_rl/trainer/rl/token_export.py +++ b/src/prime_rl/trainer/rl/token_export.py @@ -1,6 +1,7 @@ import atexit import json import math +import time from collections.abc import Mapping, Sequence from pathlib import Path from typing import Any @@ -12,6 +13,8 @@ from prime_rl.trainer.rl.loss import compute_importance_ratio_and_mismatch_kl SCHEMA_VERSION = 1 +MKDIR_RETRY_DELAY_SECONDS = 0.1 +MKDIR_MAX_ATTEMPTS = 5 class DisabledTokenExporter: @@ -88,7 +91,7 @@ def _start_step(self, step: int) -> None: self._current_step = step self._sequences_this_step = 0 step_dir = self.output_dir / f"step_{step}" - step_dir.mkdir(parents=True, exist_ok=True) + _mkdir_existing_dir_ok(step_dir) self._file = (step_dir / f"rank_{self.rank}.jsonl").open("w", encoding="utf-8") def _write(self, record: dict[str, Any]) -> None: @@ -113,6 +116,19 @@ def setup_token_exporter( return exporter +def _mkdir_existing_dir_ok(path: Path) -> None: + for attempt in range(MKDIR_MAX_ATTEMPTS): + try: + path.mkdir(parents=True, exist_ok=True) + return + except FileExistsError: + if path.is_dir(): + return + if attempt == MKDIR_MAX_ATTEMPTS - 1: + raise + time.sleep(MKDIR_RETRY_DELAY_SECONDS) + + def _export_columns( micro_batch: Mapping[str, Any], model_output: Mapping[str, Tensor], loss_config: Any ) -> dict[str, list[Any]]: diff --git a/tests/unit/train/rl/test_token_export.py b/tests/unit/train/rl/test_token_export.py new file mode 100644 index 0000000000..2eddc28824 --- /dev/null +++ b/tests/unit/train/rl/test_token_export.py @@ -0,0 +1,49 @@ +from pathlib import Path + +import pytest + +from prime_rl.trainer.rl import token_export +from prime_rl.trainer.rl.token_export import _mkdir_existing_dir_ok + + +def test_mkdir_existing_dir_ok_retries_transient_file_exists( + tmp_path: Path, monkeypatch: pytest.MonkeyPatch +) -> None: + target = tmp_path / "token_exports" / "step_31" + original_mkdir = Path.mkdir + calls = 0 + + def flaky_mkdir( + self: Path, + mode: int = 0o777, + parents: bool = False, + exist_ok: bool = False, + ) -> None: + nonlocal calls + if self == target and calls == 0: + calls += 1 + raise FileExistsError(str(self)) + original_mkdir(self, mode=mode, parents=parents, exist_ok=exist_ok) + + def create_dir_during_retry(_: float) -> None: + original_mkdir(target, parents=True, exist_ok=True) + + monkeypatch.setattr(Path, "mkdir", flaky_mkdir) + monkeypatch.setattr(token_export.time, "sleep", create_dir_during_retry) + + _mkdir_existing_dir_ok(target) + + assert target.is_dir() + assert calls == 1 + + +def test_mkdir_existing_dir_ok_raises_when_path_is_file( + tmp_path: Path, monkeypatch: pytest.MonkeyPatch +) -> None: + target = tmp_path / "token_exports" / "step_31" + target.parent.mkdir(parents=True) + target.write_text("not a directory", encoding="utf-8") + monkeypatch.setattr(token_export.time, "sleep", lambda _: None) + + with pytest.raises(FileExistsError): + _mkdir_existing_dir_ok(target) From d36434e8866271ab77393d7ee11a4f909edd093a Mon Sep 17 00:00:00 2001 From: Mika Senghaas Date: Mon, 25 May 2026 21:31:00 +0000 Subject: [PATCH 03/13] fix(orchestrator): match longest active prefix in interleave_rollout The first-match-wins loop over active_samples picks the wrong sample when one active prefix is a strict prefix of another. This can happen after a compaction/rollback step whose prompt is shorter than an existing sample's prefix and whose completion re-generates the same tokens and extends past them: the new sample's prefix then starts with the older sample's prefix, and any later step that extends the new sample also satisfies the slice check against the older one. When that happens, extend_sample folds the newer sample's generated tokens into the older sample as user-input tokens (mask=False, logprob=0) and leaves the newer sample stale -- a silent Exact-Prefix invariant violation. Switch to longest-match: strictly more specific, never worse than first-match when only one prefix matches. Co-authored-by: Cursor (cherry picked from commit 0e239d1b41b3ca381da34bd63a1d5df26e90c466) --- src/prime_rl/orchestrator/trajectories.py | 12 +- tests/unit/orchestrator/test_trajectories.py | 134 +++++++++++++++++++ 2 files changed, 143 insertions(+), 3 deletions(-) diff --git a/src/prime_rl/orchestrator/trajectories.py b/src/prime_rl/orchestrator/trajectories.py index f1c164f2c8..7abf25fe2d 100644 --- a/src/prime_rl/orchestrator/trajectories.py +++ b/src/prime_rl/orchestrator/trajectories.py @@ -406,12 +406,18 @@ def extend_sample( tokens = prepared_steps[step_idx] step_prompt_ids = tokens["prompt_ids"] - # Check if this step extends ANY active prefix + # Pick the *longest* matching active prefix. With compaction/rollback, + # one active sample's prefix can be a strict prefix of another (e.g. a + # later sample re-generated tokens that overlap an earlier sample's + # prefix). Both would satisfy the slice check; the shorter would + # silently absorb the longer sample's generated tokens as user input. matched_idx = None + matched_len = -1 for idx, (prefix_tokens, _, _) in enumerate(active_samples): - if step_prompt_ids[: len(prefix_tokens)] == prefix_tokens: + pl = len(prefix_tokens) + if pl > matched_len and step_prompt_ids[:pl] == prefix_tokens: matched_idx = idx - break + matched_len = pl if matched_idx is not None: # Extension holds - merge into matched sample diff --git a/tests/unit/orchestrator/test_trajectories.py b/tests/unit/orchestrator/test_trajectories.py index 1ecf2e3c46..87df2646c6 100644 --- a/tests/unit/orchestrator/test_trajectories.py +++ b/tests/unit/orchestrator/test_trajectories.py @@ -749,6 +749,140 @@ def test_interleave_rollout_interleaved_agents(interleaved_agents_trajectory): assert agent2_sample.completion_logprobs == [-0.5, -0.6] +@pytest.fixture +def prefix_of_prefix_trajectory(): + """ + Trajectory where one active sample's prefix is a strict prefix of another's. + + Construction: + - step 0: prompt=[1,2], completion=[3,4] -> sample A, P_A=[1,2,3,4] + - step 1: extends A. prompt=[1,2,3,4,5], completion=[6] -> P_A=[1,2,3,4,5,6] + - step 2: rollback/regenerate. prompt=[1,2] (shorter than P_A so no match), + completion=[3,4,5,6,7] -> sample B, P_B=[1,2,3,4,5,6,7] + P_B starts with P_A. + - step 3: extends B. prompt=[1,2,3,4,5,6,7,8], completion=[9] + Both P_A and P_B are token-prefixes of the step's prompt. + + The correct match is the longer P_B. First-match-wins picks P_A and silently + folds B's generated tokens into A as user-input tokens (mask=False). + """ + output = vf.RolloutOutput( + example_id=2, + task="test", + trajectory=[ + vf.TrajectoryStep( + prompt="step 0", + completion="completion 0", + response=None, + tokens=vf.TrajectoryStepTokens( + prompt_ids=[1, 2], + prompt_mask=[0, 0], + completion_ids=[3, 4], + completion_mask=[1, 1], + completion_logprobs=[-0.1, -0.2], + overlong_prompt=False, + is_truncated=False, + ), + reward=None, + advantage=None, + is_truncated=False, + trajectory_id="traj_A", + extras={}, + ), + vf.TrajectoryStep( + prompt="step 1", + completion="completion 1", + response=None, + tokens=vf.TrajectoryStepTokens( + prompt_ids=[1, 2, 3, 4, 5], + prompt_mask=[0, 0, 0, 0, 0], + completion_ids=[6], + completion_mask=[1], + completion_logprobs=[-0.3], + overlong_prompt=False, + is_truncated=False, + ), + reward=None, + advantage=None, + is_truncated=False, + trajectory_id="traj_A", + extras={}, + ), + vf.TrajectoryStep( + prompt="step 2 (rollback)", + completion="completion 2", + response=None, + tokens=vf.TrajectoryStepTokens( + prompt_ids=[1, 2], + prompt_mask=[0, 0], + completion_ids=[3, 4, 5, 6, 7], + completion_mask=[1, 1, 1, 1, 1], + completion_logprobs=[-0.4, -0.5, -0.6, -0.7, -0.8], + overlong_prompt=False, + is_truncated=False, + ), + reward=None, + advantage=None, + is_truncated=False, + trajectory_id="traj_B", + extras={}, + ), + vf.TrajectoryStep( + prompt="step 3 (extends B)", + completion="completion 3", + response=None, + tokens=vf.TrajectoryStepTokens( + prompt_ids=[1, 2, 3, 4, 5, 6, 7, 8], + prompt_mask=[0, 0, 0, 0, 0, 0, 0, 0], + completion_ids=[9], + completion_mask=[1], + completion_logprobs=[-0.9], + overlong_prompt=False, + is_truncated=False, + ), + reward=None, + advantage=None, + is_truncated=False, + trajectory_id="traj_B", + extras={}, + ), + ], + sampling_args={"temperature": 1.0}, + error=None, + ) + return output + + +def test_interleave_rollout_picks_longest_matching_prefix(prefix_of_prefix_trajectory): + """ + When two active samples both match (one's prefix is a strict prefix of the + other's), the longer prefix is the correct extension. Previously the first- + match-wins loop folded the longer sample's generated tokens into the shorter + sample as user input (mask=False) and left the longer sample stale. + """ + rollouts = interleave_rollout(prefix_of_prefix_trajectory) + + assert rollouts is not None + assert len(rollouts) == 2 + + # Sample A: steps 0 and 1 only. Step 3 must NOT have been folded in here. + sample_a = rollouts[0] + assert sample_a.prompt_ids == [1, 2] + # step 0 completion [3,4] + step 1 new prompt [5] + step 1 completion [6] + assert sample_a.completion_ids == [3, 4, 5, 6] + assert sample_a.completion_mask == [True, True, False, True] + assert sample_a.completion_logprobs == [-0.1, -0.2, 0.0, -0.3] + + # Sample B: steps 2 and 3 merged. The token 7 (from step 2's completion) + # must remain masked as a generated token, not silently re-classified. + sample_b = rollouts[1] + assert sample_b.prompt_ids == [1, 2] + # step 2 completion [3,4,5,6,7] + step 3 new prompt [8] + step 3 completion [9] + assert sample_b.completion_ids == [3, 4, 5, 6, 7, 8, 9] + assert sample_b.completion_mask == [True, True, True, True, True, False, True] + assert sample_b.completion_logprobs == [-0.4, -0.5, -0.6, -0.7, -0.8, 0.0, -0.9] + + def test_interleave_rollout_empty_trajectory(): """Empty trajectory returns None.""" output = vf.RolloutOutput( From cbe5be7e6ae1c2f879c57c430d630a2b6d0d8b71 Mon Sep 17 00:00:00 2001 From: Mika Senghaas Date: Mon, 25 May 2026 21:32:54 +0000 Subject: [PATCH 04/13] warn on ambiguous prefix match in interleave_rollout When more than one active prefix matches a step's prompt, log a warning with the example id, step index, set of matching prefix lengths, total active prefixes, and the prompt length. Longest-match still picks the correct extension; the warning just surfaces the rare ambiguous case so it's debuggable if it starts showing up in real rollouts (e.g. from compaction/rollback turns). Co-authored-by: Cursor (cherry picked from commit ca38614087109a9b2abfaa0580034e25b47d4bd5) --- src/prime_rl/orchestrator/trajectories.py | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/src/prime_rl/orchestrator/trajectories.py b/src/prime_rl/orchestrator/trajectories.py index 7abf25fe2d..3d0e83bfce 100644 --- a/src/prime_rl/orchestrator/trajectories.py +++ b/src/prime_rl/orchestrator/trajectories.py @@ -413,11 +413,26 @@ def extend_sample( # silently absorb the longer sample's generated tokens as user input. matched_idx = None matched_len = -1 + matching_prefix_lens: list[int] = [] for idx, (prefix_tokens, _, _) in enumerate(active_samples): pl = len(prefix_tokens) - if pl > matched_len and step_prompt_ids[:pl] == prefix_tokens: - matched_idx = idx - matched_len = pl + if step_prompt_ids[:pl] == prefix_tokens: + matching_prefix_lens.append(pl) + if pl > matched_len: + matched_idx = idx + matched_len = pl + + if len(matching_prefix_lens) > 1: + # Ambiguous extension: rare, but reachable via compaction/rollback + # where a new sample's prefix happens to start with an older + # sample's prefix. Longest-match is the correct choice; surface + # the ambiguity so we can audit if it shows up in real rollouts. + logger.warning( + f"Ambiguous prefix match at step {step_idx} for example {output['example_id']}: " + f"{len(matching_prefix_lens)} of {len(active_samples)} active prefixes match " + f"(lens={sorted(matching_prefix_lens)}, step_prompt_len={len(step_prompt_ids)}). " + f"Extending the longest (len={matched_len})." + ) if matched_idx is not None: # Extension holds - merge into matched sample From 8f8bd915bbeef82c612ee71dc77c2d1279605cd0 Mon Sep 17 00:00:00 2001 From: Sami Date: Wed, 27 May 2026 08:19:49 +0530 Subject: [PATCH 05/13] Log per-server inference metrics --- .../orchestrator/inference_metrics.py | 112 +++++++++++++++++- .../orchestrator/test_inference_metrics.py | 61 ++++++++++ 2 files changed, 172 insertions(+), 1 deletion(-) create mode 100644 tests/unit/orchestrator/test_inference_metrics.py diff --git a/src/prime_rl/orchestrator/inference_metrics.py b/src/prime_rl/orchestrator/inference_metrics.py index 7e756622f8..b624d9a8f7 100644 --- a/src/prime_rl/orchestrator/inference_metrics.py +++ b/src/prime_rl/orchestrator/inference_metrics.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import re import time from collections import deque @@ -98,8 +99,21 @@ def __init__(self, admin_clients: list[AsyncClient]): self._rate_history: dict[str, deque[float]] = {} self._prev_counters: dict[str, tuple[float, float]] = {} self._prev_histograms: dict[str, tuple[float, float, float]] = {} + self._server_gauge_history: dict[str, dict[str, deque[float]]] = {} + self._server_rate_history: dict[str, dict[str, deque[float]]] = {} + self._server_prev_counters: dict[str, dict[str, tuple[float, float]]] = {} + self._server_prev_histograms: dict[str, dict[str, tuple[float, float, float]]] = {} + self._server_names = [self._server_name(idx, client) for idx, client in enumerate(admin_clients)] self._task: asyncio.Task | None = None + @staticmethod + def _server_name(idx: int, client: AsyncClient) -> str: + host = client.base_url.host or f"server_{idx}" + port = client.base_url.port + raw = f"{host}_{port}" if port is not None else host + safe = re.sub(r"[^A-Za-z0-9_.-]+", "_", raw).strip("_") + return f"server_{idx:02d}_{safe or 'unknown'}" + async def start(self): wandb.define_metric("inference/*", step_metric="_timestamp") @@ -132,13 +146,16 @@ async def fetch(client: AsyncClient) -> str | None: agg_sum_gauges: dict[str, float] = {} agg_counters: dict[str, float] = {} agg_histograms: dict[str, tuple[float, float]] = {} + active_servers: set[str] = set() n_servers = 0 - for text in results: + for server_name, text in zip(self._server_names, results, strict=True): if text is None: continue + active_servers.add(server_name) n_servers += 1 gauges, counters, histograms = parse_prometheus_text(text) + self._update_server_histories(server_name, now, gauges, counters, histograms) for name, value in gauges.items(): if name in _DUAL_AGG_GAUGES: @@ -154,6 +171,7 @@ async def fetch(client: AsyncClient) -> str | None: agg_histograms[name] = (prev[0] + h_sum, prev[1] + h_count) if n_servers == 0: + wandb.log({**self._server_up_metrics(active_servers), "_timestamp": time.time()}) return # Update gauge history — sum gauges @@ -221,11 +239,103 @@ async def fetch(client: AsyncClient) -> str | None: for rate_name, values in self._rate_history.items(): if values: metrics[f"inference/{rate_name}"] = sum(values) / len(values) + self._add_cache_alias_metrics(metrics) + metrics.update(self._server_metrics(active_servers)) + metrics.update(self._server_up_metrics(active_servers)) if metrics: metrics["_timestamp"] = time.time() wandb.log(metrics) + def _update_server_histories( + self, + server_name: str, + now: float, + gauges: dict[str, float], + counters: dict[str, float], + histograms: dict[str, tuple[float, float]], + ) -> None: + gauge_history = self._server_gauge_history.setdefault(server_name, {}) + rate_history = self._server_rate_history.setdefault(server_name, {}) + prev_counters = self._server_prev_counters.setdefault(server_name, {}) + prev_histograms = self._server_prev_histograms.setdefault(server_name, {}) + + for name, value in gauges.items(): + short = name.removeprefix("vllm:") + if name in _DUAL_AGG_GAUGES: + short = f"{short}_max" + gauge_history.setdefault(short, deque(maxlen=WINDOW_SIZE)).append(value) + + for name, value in counters.items(): + rate_name = COUNTER_RATE_NAMES[name] + prev = prev_counters.get(name) + prev_counters[name] = (now, value) + if prev is None: + continue + prev_time, prev_value = prev + dt = now - prev_time + if dt <= 0: + continue + delta = value - prev_value + if delta < 0: + continue + rate_history.setdefault(rate_name, deque(maxlen=WINDOW_SIZE)).append(delta / dt) + + for name, (h_sum, h_count) in histograms.items(): + short = name.removeprefix("vllm:") + rate_name = f"{short}_avg_ms" + prev = prev_histograms.get(name) + prev_histograms[name] = (now, h_sum, h_count) + if prev is None: + continue + _, prev_sum, prev_count = prev + d_sum = h_sum - prev_sum + d_count = h_count - prev_count + if d_count < 0 or d_sum < 0: + continue + if d_count > 0: + rate_history.setdefault(rate_name, deque(maxlen=WINDOW_SIZE)).append((d_sum / d_count) * 1000.0) + + def _server_metrics(self, active_servers: set[str]) -> dict[str, float]: + metrics: dict[str, float] = {} + for server_name, gauge_history in self._server_gauge_history.items(): + if server_name not in active_servers: + continue + for short, values in gauge_history.items(): + if values: + metrics[f"inference/server/{server_name}/{short}"] = sum(values) / len(values) + for server_name, rate_history in self._server_rate_history.items(): + if server_name not in active_servers: + continue + for rate_name, values in rate_history.items(): + if values: + metrics[f"inference/server/{server_name}/{rate_name}"] = sum(values) / len(values) + self._add_cache_alias_metrics(metrics) + return metrics + + def _server_up_metrics(self, active_servers: set[str]) -> dict[str, float]: + return {f"inference/server/{server_name}/up": float(server_name in active_servers) for server_name in self._server_names} + + @classmethod + def _add_cache_alias_metrics(cls, metrics: dict[str, float]) -> None: + for key, value in list(metrics.items()): + if key.endswith("/gpu_prefix_cache_hit_rate_max"): + prefix = key.removesuffix("gpu_prefix_cache_hit_rate_max") + metrics[f"{prefix}kv_cache_hit_rate_max"] = value + elif key.endswith("/gpu_prefix_cache_hit_rate_mean"): + prefix = key.removesuffix("gpu_prefix_cache_hit_rate_mean") + metrics[f"{prefix}kv_cache_hit_rate_mean"] = value + elif key.endswith("/gpu_cache_usage_perc_max"): + prefix = key.removesuffix("gpu_cache_usage_perc_max") + metrics[f"{prefix}kv_cache_left_perc_min"] = cls._cache_left(value) + elif key.endswith("/gpu_cache_usage_perc_mean"): + prefix = key.removesuffix("gpu_cache_usage_perc_mean") + metrics[f"{prefix}kv_cache_left_perc_mean"] = cls._cache_left(value) + + @staticmethod + def _cache_left(usage: float) -> float: + return min(max(1.0 - usage, 0.0), 1.0) + async def stop(self): if self._task is not None: self._task.cancel() diff --git a/tests/unit/orchestrator/test_inference_metrics.py b/tests/unit/orchestrator/test_inference_metrics.py new file mode 100644 index 0000000000..5926d11e5d --- /dev/null +++ b/tests/unit/orchestrator/test_inference_metrics.py @@ -0,0 +1,61 @@ +import asyncio + +import pytest +from httpx import AsyncClient + +from prime_rl.orchestrator.inference_metrics import InferenceMetricsCollector + + +async def _close_clients(clients: list[AsyncClient]) -> None: + await asyncio.gather(*(client.aclose() for client in clients)) + + +def test_server_metrics_are_namespaced_and_track_up(): + clients = [ + AsyncClient(base_url="http://ltc-idc3-hgx8-h200-6:8100"), + AsyncClient(base_url="http://ltc-idc3-hgx8-h200-11:8200"), + ] + try: + collector = InferenceMetricsCollector(clients) + active_servers = {collector._server_names[0]} + + collector._update_server_histories( + collector._server_names[0], + 10.0, + { + "vllm:num_requests_running": 3.0, + "vllm:gpu_cache_usage_perc": 0.5, + "vllm:gpu_prefix_cache_hit_rate": 0.2, + }, + {"vllm:prompt_tokens": 100.0}, + {"vllm:nixl_xfer_time_seconds": (1.0, 2.0)}, + ) + collector._update_server_histories( + collector._server_names[0], + 15.0, + { + "vllm:num_requests_running": 5.0, + "vllm:gpu_cache_usage_perc": 0.7, + "vllm:gpu_prefix_cache_hit_rate": 0.4, + }, + {"vllm:prompt_tokens": 150.0}, + {"vllm:nixl_xfer_time_seconds": (1.5, 3.0)}, + ) + + server_0 = "server_00_ltc-idc3-hgx8-h200-6_8100" + server_1 = "server_01_ltc-idc3-hgx8-h200-11_8200" + metrics = collector._server_metrics(active_servers) + up_metrics = collector._server_up_metrics(active_servers) + + assert metrics[f"inference/server/{server_0}/num_requests_running"] == 4.0 + assert metrics[f"inference/server/{server_0}/gpu_cache_usage_perc_max"] == 0.6 + assert metrics[f"inference/server/{server_0}/kv_cache_left_perc_min"] == pytest.approx(0.4) + assert metrics[f"inference/server/{server_0}/gpu_prefix_cache_hit_rate_max"] == pytest.approx(0.3) + assert metrics[f"inference/server/{server_0}/kv_cache_hit_rate_max"] == pytest.approx(0.3) + assert metrics[f"inference/server/{server_0}/prefill_throughput_tps"] == 10.0 + assert metrics[f"inference/server/{server_0}/nixl_xfer_time_seconds_avg_ms"] == 500.0 + assert all(f"/{server_1}/" not in key for key in metrics) + assert up_metrics[f"inference/server/{server_0}/up"] == 1.0 + assert up_metrics[f"inference/server/{server_1}/up"] == 0.0 + finally: + asyncio.run(_close_clients(clients)) From 84ed29172a3d2d77619242e4d890cc464e1053a7 Mon Sep 17 00:00:00 2001 From: S1ro1 Date: Wed, 27 May 2026 19:09:59 +0530 Subject: [PATCH 06/13] fix: pin vLLM no-coordinator wheel --- pyproject.toml | 2 +- src/prime_rl/inference/patches.py | 82 ------------------------------- uv.lock | 12 ++--- 3 files changed, 7 insertions(+), 89 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 827e6926e3..ac3e2a44fe 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -240,7 +240,7 @@ flash-attn-4 = { git = "https://github.com/Dao-AILab/flash-attention.git", subdi prime-pydantic-config = { workspace = true } vllm-router = { path = "third_party/router/dist/vllm_router-0.1.25-cp38-abi3-linux_x86_64.whl" } vllm = [ - { url = "https://github.com/PrimeIntellect-ai/prime-rl/releases/download/v0.5.0/vllm-0.21.0+cu129.r42434.pr39568.a106aa6-cp38-abi3-manylinux_2_34_x86_64.whl", marker = "platform_machine == 'x86_64'" }, + { path = "/home/matej/dev/vllm-0.21-revert42434-pr39568/dist-no-dp-coordinator/vllm-0.21.0+cu129.r42434.pr39568.a106aa6-cp38-abi3-manylinux_2_34_x86_64.whl", marker = "platform_machine == 'x86_64'" }, { url = "https://github.com/vllm-project/vllm/releases/download/v0.21.0/vllm-0.21.0+cu129-cp38-abi3-manylinux_2_34_aarch64.whl", marker = "platform_machine == 'aarch64'" }, ] deep-ep = { url = "https://github.com/PrimeIntellect-ai/prime-rl/releases/download/v0.5.0/deep_ep-1.2.1+29d31c0-cp312-cp312-linux_x86_64.whl" } diff --git a/src/prime_rl/inference/patches.py b/src/prime_rl/inference/patches.py index 35b2d1d26c..d5990b213b 100644 --- a/src/prime_rl/inference/patches.py +++ b/src/prime_rl/inference/patches.py @@ -18,7 +18,6 @@ def transformers_v5_compat(): _patch_qwen35_lora() _patch_lora_key_prefix() monkey_patch_deep_gemm_silu_mul_quant_int64() - monkey_patch_dp_engine_core_pause_resume_deadlock() monkey_patch_vllm_layerwise_reload_alias_buffers() monkey_patch_vllm_padded_input_scrub() monkey_patch_return_routed_experts_with_nixl_connector() @@ -782,87 +781,6 @@ def _patched_to_sampling_params(self, max_tokens, default_sampling_params): ChatCompletionRequest.to_sampling_params = _patched_to_sampling_params -def monkey_patch_dp_engine_core_pause_resume_deadlock(): - """Fix DP pause/resume deadlocks around weight updates. - - Bug 1 (job 3756): while paused, START_DP_WAVE can wake idle ranks into the - DP loop. Those ranks then run dummy batches and hit DP collectives while - other ranks are still in NCCL weight transfer. - - Bug 2 (jobs 3769/3771): resume ties the DP running state to local - unfinished requests, but the DP wave state is global. Ranks with no local - work still need to re-enter the loop so they can participate in the same - DP collectives as ranks that are resuming remote-KV or decode work. - - Fix: - - ignore START_DP_WAVE wakeups while paused - - on resume, wake every DP rank and force an immediate global unfinished - sync instead of waiting for the normal 32-step cadence - - This also bypasses vLLM's two-phase DP pause implementation - (https://github.com/vllm-project/vllm/pull/39366), which makes resume - reject states that our weight-update flow can validly hit. - """ - from vllm.config import ParallelConfig - from vllm.v1.core.sched.interface import PauseState - from vllm.v1.engine import EngineCoreOutputs, EngineCoreRequestType - from vllm.v1.engine.core import DPEngineCoreProc, EngineCore, EngineCoreProc - from vllm.v1.request import Request - - _base_add_request = EngineCore.add_request - _base_handle_client_request = EngineCoreProc._handle_client_request - _base_pause_complete = EngineCoreProc._pause_complete - _base_resume_scheduler = EngineCoreProc.resume_scheduler - - def _patched_add_request(self, request: Request, request_wave: int = 0): - _base_add_request(self, request, request_wave) - if self.has_coordinator and request_wave != self.current_wave: - if request_wave > self.current_wave: - self.current_wave = request_wave - elif not self.engines_running and self.scheduler.pause_state == PauseState.UNPAUSED: - self.engines_running = True - self.output_queue.put_nowait((-1, EngineCoreOutputs(start_wave=self.current_wave))) - - def _patched_handle_client_request(self, request_type, request): - if request_type == EngineCoreRequestType.START_DP_WAVE: - new_wave, exclude_eng_index = request - if exclude_eng_index != self.engine_index and new_wave >= self.current_wave: - self.current_wave = new_wave - if not self.engines_running and self.scheduler.pause_state == PauseState.UNPAUSED: - self.engines_running = True - else: - _base_handle_client_request(self, request_type, request) - - def _patched_pause_complete(self) -> bool: - self.pending_pause = False - self.ignore_start_dp_wave = False - return _base_pause_complete(self) - - def _patched_resume_scheduler(self): - was_paused = self.scheduler.pause_state != PauseState.UNPAUSED - self.pending_pause = False - self.ignore_start_dp_wave = False - _base_resume_scheduler(self) - if was_paused: - self.engines_running = True - self._force_dp_running_state_sync = True - - def _patched_has_global_unfinished_reqs(self, local_unfinished: bool) -> bool: - self.step_counter += 1 - if getattr(self, "_force_dp_running_state_sync", False): - self._force_dp_running_state_sync = False - return ParallelConfig.has_unfinished_dp(self.dp_group, local_unfinished) - if self.step_counter % 32 != 0: - return True - return ParallelConfig.has_unfinished_dp(self.dp_group, local_unfinished) - - DPEngineCoreProc.add_request = _patched_add_request - DPEngineCoreProc._handle_client_request = _patched_handle_client_request - DPEngineCoreProc._pause_complete = _patched_pause_complete - DPEngineCoreProc.resume_scheduler = _patched_resume_scheduler - DPEngineCoreProc._has_global_unfinished_reqs = _patched_has_global_unfinished_reqs - - def monkey_patch_no_moe_lora(): """This disables LoRA for MoE layers and makes them pick better kernels. diff --git a/uv.lock b/uv.lock index d5f90a349d..f87f597b95 100644 --- a/uv.lock +++ b/uv.lock @@ -3986,7 +3986,7 @@ dependencies = [ { name = "uvloop", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, { name = "verifiers", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, { name = "vllm", version = "0.21.0+cu129", source = { url = "https://github.com/vllm-project/vllm/releases/download/v0.21.0/vllm-0.21.0+cu129-cp38-abi3-manylinux_2_34_aarch64.whl" }, marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine != 'aarch64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "vllm", version = "0.21.0+cu129.r42434.pr39568.a106aa6", source = { url = "https://github.com/PrimeIntellect-ai/prime-rl/releases/download/v0.5.0/vllm-0.21.0+cu129.r42434.pr39568.a106aa6-cp38-abi3-manylinux_2_34_x86_64.whl" }, marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "vllm", version = "0.21.0+cu129.r42434.pr39568.a106aa6", source = { path = "/home/matej/dev/vllm-0.21-revert42434-pr39568/dist-no-dp-coordinator/vllm-0.21.0+cu129.r42434.pr39568.a106aa6-cp38-abi3-manylinux_2_34_x86_64.whl" }, marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, { name = "wandb", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, ] @@ -4143,7 +4143,7 @@ requires-dist = [ { name = "verifiers", editable = "deps/verifiers" }, { name = "vllm", marker = "platform_machine != 'aarch64' and platform_machine != 'x86_64'", specifier = ">=0.21.0" }, { name = "vllm", marker = "platform_machine == 'aarch64'", url = "https://github.com/vllm-project/vllm/releases/download/v0.21.0/vllm-0.21.0+cu129-cp38-abi3-manylinux_2_34_aarch64.whl" }, - { name = "vllm", marker = "platform_machine == 'x86_64'", url = "https://github.com/PrimeIntellect-ai/prime-rl/releases/download/v0.5.0/vllm-0.21.0+cu129.r42434.pr39568.a106aa6-cp38-abi3-manylinux_2_34_x86_64.whl" }, + { name = "vllm", marker = "platform_machine == 'x86_64'", path = "/home/matej/dev/vllm-0.21-revert42434-pr39568/dist-no-dp-coordinator/vllm-0.21.0+cu129.r42434.pr39568.a106aa6-cp38-abi3-manylinux_2_34_x86_64.whl" }, { name = "vllm-router", marker = "platform_machine == 'x86_64' and extra == 'disagg'", path = "third_party/router/dist/vllm_router-0.1.25-cp38-abi3-linux_x86_64.whl" }, { name = "wandb", specifier = ">=0.26.1" }, { name = "wiki-search", marker = "extra == 'envs'", editable = "deps/verifiers/environments/wiki_search" }, @@ -6091,7 +6091,7 @@ rl = [ { name = "torch", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, { name = "transformers", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, { name = "vllm", version = "0.21.0+cu129", source = { url = "https://github.com/vllm-project/vllm/releases/download/v0.21.0/vllm-0.21.0+cu129-cp38-abi3-manylinux_2_34_aarch64.whl" }, marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine != 'aarch64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "vllm", version = "0.21.0+cu129.r42434.pr39568.a106aa6", source = { url = "https://github.com/PrimeIntellect-ai/prime-rl/releases/download/v0.5.0/vllm-0.21.0+cu129.r42434.pr39568.a106aa6-cp38-abi3-manylinux_2_34_x86_64.whl" }, marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "vllm", version = "0.21.0+cu129.r42434.pr39568.a106aa6", source = { path = "/home/matej/dev/vllm-0.21-revert42434-pr39568/dist-no-dp-coordinator/vllm-0.21.0+cu129.r42434.pr39568.a106aa6-cp38-abi3-manylinux_2_34_x86_64.whl" }, marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, { name = "wandb", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, ] ta = [ @@ -6169,7 +6169,7 @@ requires-dist = [ { name = "uvloop", marker = "platform_python_implementation != 'PyPy' and sys_platform != 'cygwin' and sys_platform != 'win32'", specifier = ">=0.21.0" }, { name = "vllm", marker = "platform_machine == 'aarch64' and extra == 'rl'", url = "https://github.com/vllm-project/vllm/releases/download/v0.21.0/vllm-0.21.0+cu129-cp38-abi3-manylinux_2_34_aarch64.whl" }, { name = "vllm", marker = "platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'rl'", specifier = ">=0.10.0,<0.11.0" }, - { name = "vllm", marker = "platform_machine == 'x86_64' and extra == 'rl'", url = "https://github.com/PrimeIntellect-ai/prime-rl/releases/download/v0.5.0/vllm-0.21.0+cu129.r42434.pr39568.a106aa6-cp38-abi3-manylinux_2_34_x86_64.whl" }, + { name = "vllm", marker = "platform_machine == 'x86_64' and extra == 'rl'", path = "/home/matej/dev/vllm-0.21-revert42434-pr39568/dist-no-dp-coordinator/vllm-0.21.0+cu129.r42434.pr39568.a106aa6-cp38-abi3-manylinux_2_34_x86_64.whl" }, { name = "wandb", marker = "extra == 'rl'" }, ] provides-extras = ["browser", "openenv", "renderers", "rg", "rl", "ta"] @@ -6390,7 +6390,7 @@ provides-extras = ["zen", "bench", "tensorizer", "fastsafetensors", "instanttens [[package]] name = "vllm" version = "0.21.0+cu129.r42434.pr39568.a106aa6" -source = { url = "https://github.com/PrimeIntellect-ai/prime-rl/releases/download/v0.5.0/vllm-0.21.0+cu129.r42434.pr39568.a106aa6-cp38-abi3-manylinux_2_34_x86_64.whl" } +source = { path = "/home/matej/dev/vllm-0.21-revert42434-pr39568/dist-no-dp-coordinator/vllm-0.21.0+cu129.r42434.pr39568.a106aa6-cp38-abi3-manylinux_2_34_x86_64.whl" } resolution-markers = [ "platform_machine == 'x86_64' and sys_platform == 'linux'", ] @@ -6466,7 +6466,7 @@ dependencies = [ { name = "xgrammar", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, ] wheels = [ - { url = "https://github.com/PrimeIntellect-ai/prime-rl/releases/download/v0.5.0/vllm-0.21.0+cu129.r42434.pr39568.a106aa6-cp38-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:0f0533a9122dfa738fc0f5532cb6795556f072b133f0058a22ff4be2d2f165f0" }, + { filename = "vllm-0.21.0+cu129.r42434.pr39568.a106aa6-cp38-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:2e822e07dd000c5f765530fc8bd2429f4cde6784be56849213f3ca3aaa4d9e8b" }, ] [package.metadata] From 51f3087c44b0e7884a704ebaf45e492ae40b6e85 Mon Sep 17 00:00:00 2001 From: Mika Senghaas Date: Fri, 29 May 2026 21:25:12 +0000 Subject: [PATCH 07/13] feat: configurable grace period for SLURM cleanup MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add slurm.cleanup_grace_period_seconds (default 3600) so that when a component exits — completion, crash, or SIGTERM — the multi-node RL and inference sbatch teardown sends SIGTERM and then waits up to the grace period for the remaining processes to exit before force-killing and releasing the allocation. This gives in-flight work, notably trainer checkpoint writes, a bounded window to flush. The wait ends as soon as all processes exit, so it is only an upper bound; set to 0 for the previous immediate force-kill behavior. Closes #2664 Co-authored-by: Cursor --- CHANGELOG.md | 5 +++++ .../src/prime_rl/configs/shared.py | 4 ++++ src/prime_rl/templates/inference.sbatch.j2 | 20 ++++++++++++++++++- .../templates/multi_node_rl.sbatch.j2 | 20 ++++++++++++++++++- 4 files changed, 47 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 484a5ea4ea..7aa196d745 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,11 @@ Documenting **breaking** configuration changes — renamed, removed, or moved fields that require users to update existing configs. +- **`slurm.cleanup_grace_period_seconds` (NEW)**: Added `cleanup_grace_period_seconds: int` to `SlurmConfig` (default: `3600`). When a component exits (completion, crash, or SIGTERM), the multi-node RL and inference sbatch teardown now sends SIGTERM and then waits up to this many seconds for the remaining processes to exit before force-killing (SIGKILL) and releasing the allocation, giving in-flight work — notably trainer checkpoint writes — a bounded window to flush. The wait ends as soon as all processes exit, so it is only an upper bound. **Behavioral change**: teardown previously force-killed immediately; set `slurm.cleanup_grace_period_seconds = 0` to restore that. Should fit inside `slurm.time` so the job is not reaped mid-grace. (2026-05-29) +- **`sampling.min_tokens`, `sampling.repetition_penalty`, `sampling.seed` removed**: Dropped from both `TrainSamplingConfig` and `EvalSamplingConfig` (group-level `[orchestrator.train.sampling]` / `[orchestrator.eval.sampling]` and per-env `[[orchestrator.train.env.sampling]]` / `[[orchestrator.eval.env.sampling]]`). `min_tokens` suppressed natural EOS, `repetition_penalty` distorts the on-policy sampling distribution, and `seed` wasn't pulling its weight — none belonged on the supported config surface. Existing configs setting any of these must delete the field. Hard-deprecation, no migration window. (2026-05-27) +- **`wandb.shared` removed**: The deprecation shim that popped `wandb.shared` from input dicts with a `FutureWarning` (introduced in #2649) is gone. The `rl` entrypoint always uses shared W&B mode now, and existing configs that still set `wandb.shared = true` (or `false`) will fail validation. Drop the field from your config. (2026-05-27) +- **`max_async_level` and `strict_async_level` removed**: The async-execution semantics between trainer and orchestrator are now design invariants, not config knobs. The trainer always runs exactly one step ahead of inference, and the orchestrator always adopts the freshest checkpoint that doesn't violate the one-step barrier. The shared top-level `max_async_level`, the per-sub-config `trainer.max_async_level` / `orchestrator.max_async_level`, and `orchestrator.strict_async_level` have all been removed. Existing configs setting any of these must drop the field; the previous defaults (`max_async_level = 1`, `strict_async_level = false`) match the new hardcoded behavior. Bench mode no longer bypasses the weight-ckpt wait (the `int(1e9)` workaround is gone) and `multimodal/rl_color_codeword_feat_renderer.toml`'s prior `max_async_level = 0` (fully synchronous on-policy) is no longer expressible. (2026-05-25) +- **`teacher_inference` removed from RL entrypoint**: The `[teacher_inference]` config block and the `deployment.num_teacher_gpus` / `deployment.num_teacher_nodes` fields have been removed. The `rl` entrypoint now only manages the student-policy inference server. External teachers (used by OPD and local-vLLM SFT) must be started manually (e.g. `CUDA_VISIBLE_DEVICES=1 uv run inference --model.name --server.port 8001 ...`) and pointed at via `[orchestrator.teacher.client]`. Existing configs using `[teacher_inference]` or `deployment.num_teacher_gpus` / `deployment.num_teacher_nodes` must drop those fields and bring up the teacher out-of-band. (2026-05-25) - **`rollouts_per_example` → `group_size`**: The orchestrator-level field, the group-level `[orchestrator.eval]` field, and the per-env `[[orchestrator.eval.env]]` field have all been renamed. The old name still parses as a validation alias (in both TOML and CLI), so existing configs keep working without changes; new configs should prefer `group_size`. (2026-05-22) - **`AdvantageInputs` / `AdvantageOutputs` are now per-group, and `AdvantageOutputs.advantages` is a plain `list[float]`** (second breaking change to this API in three weeks). `AdvantageInputs.rollouts` is now `list[vf.RolloutOutput]` (a single group) instead of `list[list[vf.RolloutOutput]]`, and `AdvantageOutputs.advantages` is now `list[float]` instead of a 2D `Float[Tensor, "num_examples rollouts_per_example"]`. `compute_advantages` calls `advantage_fn` once per group, which lets partial-group training (groups smaller than `rollouts_per_example` after rollout errors) round-trip without the previous bucket-by-size workaround. Custom advantage functions must drop the outer list dimension and return a list of floats — e.g. `AdvantageOutputs(advantages=(rewards - rewards.mean(dim=1, keepdim=True)).tolist())` becomes `AdvantageOutputs(advantages=[r - mean for r in rewards])` (or `.tolist()` if you keep torch internally). (2026-05-22) - **`[model.vlm]` requires `orchestrator.use_renderer = true`**: VLMs must go through the renderer path; the `vlm_requires_renderer` validator rejects `use_renderer = false` when `[model.vlm]` is set. The renderer owns the HF processor per-slot and ships generic `mm_kwargs` keyed by the model's forward signature. Since `use_renderer` already defaults to `true`, most VLM configs need no change. (2026-05-19) diff --git a/packages/prime-rl-configs/src/prime_rl/configs/shared.py b/packages/prime-rl-configs/src/prime_rl/configs/shared.py index 0dffd79fd3..72fe418d3e 100644 --- a/packages/prime-rl-configs/src/prime_rl/configs/shared.py +++ b/packages/prime-rl-configs/src/prime_rl/configs/shared.py @@ -35,6 +35,9 @@ class SlurmConfig(BaseConfig): pre_run_command: str | None = None """Shell command to run on the head node after cd, .env sourcing, and venv activation. Useful for cleanup like ``sudo pkill -f vllm``; wrap with ``srun bash -c '...'`` to fan out to all nodes.""" + cleanup_grace_period_seconds: int = Field(3600, ge=0) + """When a component exits (completion, crash, or SIGTERM), the job sends SIGTERM to the remaining processes and then waits up to this many seconds for them to exit before force-killing (SIGKILL) and releasing the allocation. Gives in-flight work — notably trainer checkpoint writes — a bounded window to flush. The wait ends as soon as all processes exit, so this is only an upper bound. Set to 0 for an immediate force-kill. Should fit inside ``time`` so the job is not reaped mid-grace.""" + @property def template_vars(self) -> dict: """Common template variables for all SLURM templates.""" @@ -47,6 +50,7 @@ def template_vars(self) -> dict: "account": self.account, "time": self.time, "pre_run_command": self.pre_run_command, + "cleanup_grace_period_seconds": self.cleanup_grace_period_seconds, } @model_validator(mode="after") diff --git a/src/prime_rl/templates/inference.sbatch.j2 b/src/prime_rl/templates/inference.sbatch.j2 index 5ba6a37631..c206f15070 100755 --- a/src/prime_rl/templates/inference.sbatch.j2 +++ b/src/prime_rl/templates/inference.sbatch.j2 @@ -324,8 +324,26 @@ srun --kill-on-bad-exit=1 bash -c ' wait -n EXIT_CODE=$? echo "[$(hostname)] component exited with code $EXIT_CODE - terminating remaining processes" -for pid in $(jobs -p); do +REMAINING_PIDS=$(jobs -p) +for pid in $REMAINING_PIDS; do kill -TERM $pid 2>/dev/null || true done +{%- if cleanup_grace_period_seconds %} +# Give in-flight work a bounded window to flush after SIGTERM before we +# force-kill and release the allocation. The wait ends early as soon as every +# process exits, so this is only an upper bound. +GRACE_DEADLINE=$((SECONDS + {{ cleanup_grace_period_seconds }})) +for pid in $REMAINING_PIDS; do + while kill -0 $pid 2>/dev/null && [ "$SECONDS" -lt "$GRACE_DEADLINE" ]; do + sleep 5 + done +done +for pid in $REMAINING_PIDS; do + if kill -0 $pid 2>/dev/null; then + echo "[$(hostname)] process $pid still alive after {{ cleanup_grace_period_seconds }}s grace period - force-killing" + kill -KILL $pid 2>/dev/null || true + fi +done +{%- endif %} exit $EXIT_CODE ' diff --git a/src/prime_rl/templates/multi_node_rl.sbatch.j2 b/src/prime_rl/templates/multi_node_rl.sbatch.j2 index 3cb2718fc6..3044258b88 100755 --- a/src/prime_rl/templates/multi_node_rl.sbatch.j2 +++ b/src/prime_rl/templates/multi_node_rl.sbatch.j2 @@ -421,8 +421,26 @@ else wait -n EXIT_CODE=$? echo "[$(hostname)] component exited with code $EXIT_CODE - terminating remaining processes" -for pid in $(jobs -p); do +REMAINING_PIDS=$(jobs -p) +for pid in $REMAINING_PIDS; do kill -TERM $pid 2>/dev/null || true done +{%- if cleanup_grace_period_seconds %} +# Give in-flight work (notably trainer checkpoint writes) a bounded window to +# flush after SIGTERM before we force-kill and release the allocation. The wait +# ends early as soon as every process exits, so this is only an upper bound. +GRACE_DEADLINE=$((SECONDS + {{ cleanup_grace_period_seconds }})) +for pid in $REMAINING_PIDS; do + while kill -0 $pid 2>/dev/null && [ "$SECONDS" -lt "$GRACE_DEADLINE" ]; do + sleep 5 + done +done +for pid in $REMAINING_PIDS; do + if kill -0 $pid 2>/dev/null; then + echo "[$(hostname)] process $pid still alive after {{ cleanup_grace_period_seconds }}s grace period - force-killing" + kill -KILL $pid 2>/dev/null || true + fi +done +{%- endif %} exit $EXIT_CODE ' From 3a45b7ff09333b76290848843beb9d50f632bab9 Mon Sep 17 00:00:00 2001 From: Mika Senghaas Date: Fri, 29 May 2026 21:26:56 +0000 Subject: [PATCH 08/13] refactor: rename cleanup_grace_period_seconds to cleanup_grace_period Drop the _seconds suffix; the unit is documented in the field docstring. Co-authored-by: Cursor --- CHANGELOG.md | 2 +- packages/prime-rl-configs/src/prime_rl/configs/shared.py | 6 +++--- src/prime_rl/templates/inference.sbatch.j2 | 6 +++--- src/prime_rl/templates/multi_node_rl.sbatch.j2 | 6 +++--- 4 files changed, 10 insertions(+), 10 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 7aa196d745..5d4acbdb53 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,7 +2,7 @@ Documenting **breaking** configuration changes — renamed, removed, or moved fields that require users to update existing configs. -- **`slurm.cleanup_grace_period_seconds` (NEW)**: Added `cleanup_grace_period_seconds: int` to `SlurmConfig` (default: `3600`). When a component exits (completion, crash, or SIGTERM), the multi-node RL and inference sbatch teardown now sends SIGTERM and then waits up to this many seconds for the remaining processes to exit before force-killing (SIGKILL) and releasing the allocation, giving in-flight work — notably trainer checkpoint writes — a bounded window to flush. The wait ends as soon as all processes exit, so it is only an upper bound. **Behavioral change**: teardown previously force-killed immediately; set `slurm.cleanup_grace_period_seconds = 0` to restore that. Should fit inside `slurm.time` so the job is not reaped mid-grace. (2026-05-29) +- **`slurm.cleanup_grace_period` (NEW)**: Added `cleanup_grace_period: int` (seconds) to `SlurmConfig` (default: `3600`). When a component exits (completion, crash, or SIGTERM), the multi-node RL and inference sbatch teardown now sends SIGTERM and then waits up to this many seconds for the remaining processes to exit before force-killing (SIGKILL) and releasing the allocation, giving in-flight work — notably trainer checkpoint writes — a bounded window to flush. The wait ends as soon as all processes exit, so it is only an upper bound. **Behavioral change**: teardown previously force-killed immediately; set `slurm.cleanup_grace_period = 0` to restore that. Should fit inside `slurm.time` so the job is not reaped mid-grace. (2026-05-29) - **`sampling.min_tokens`, `sampling.repetition_penalty`, `sampling.seed` removed**: Dropped from both `TrainSamplingConfig` and `EvalSamplingConfig` (group-level `[orchestrator.train.sampling]` / `[orchestrator.eval.sampling]` and per-env `[[orchestrator.train.env.sampling]]` / `[[orchestrator.eval.env.sampling]]`). `min_tokens` suppressed natural EOS, `repetition_penalty` distorts the on-policy sampling distribution, and `seed` wasn't pulling its weight — none belonged on the supported config surface. Existing configs setting any of these must delete the field. Hard-deprecation, no migration window. (2026-05-27) - **`wandb.shared` removed**: The deprecation shim that popped `wandb.shared` from input dicts with a `FutureWarning` (introduced in #2649) is gone. The `rl` entrypoint always uses shared W&B mode now, and existing configs that still set `wandb.shared = true` (or `false`) will fail validation. Drop the field from your config. (2026-05-27) - **`max_async_level` and `strict_async_level` removed**: The async-execution semantics between trainer and orchestrator are now design invariants, not config knobs. The trainer always runs exactly one step ahead of inference, and the orchestrator always adopts the freshest checkpoint that doesn't violate the one-step barrier. The shared top-level `max_async_level`, the per-sub-config `trainer.max_async_level` / `orchestrator.max_async_level`, and `orchestrator.strict_async_level` have all been removed. Existing configs setting any of these must drop the field; the previous defaults (`max_async_level = 1`, `strict_async_level = false`) match the new hardcoded behavior. Bench mode no longer bypasses the weight-ckpt wait (the `int(1e9)` workaround is gone) and `multimodal/rl_color_codeword_feat_renderer.toml`'s prior `max_async_level = 0` (fully synchronous on-policy) is no longer expressible. (2026-05-25) diff --git a/packages/prime-rl-configs/src/prime_rl/configs/shared.py b/packages/prime-rl-configs/src/prime_rl/configs/shared.py index 72fe418d3e..b1a1557be5 100644 --- a/packages/prime-rl-configs/src/prime_rl/configs/shared.py +++ b/packages/prime-rl-configs/src/prime_rl/configs/shared.py @@ -35,8 +35,8 @@ class SlurmConfig(BaseConfig): pre_run_command: str | None = None """Shell command to run on the head node after cd, .env sourcing, and venv activation. Useful for cleanup like ``sudo pkill -f vllm``; wrap with ``srun bash -c '...'`` to fan out to all nodes.""" - cleanup_grace_period_seconds: int = Field(3600, ge=0) - """When a component exits (completion, crash, or SIGTERM), the job sends SIGTERM to the remaining processes and then waits up to this many seconds for them to exit before force-killing (SIGKILL) and releasing the allocation. Gives in-flight work — notably trainer checkpoint writes — a bounded window to flush. The wait ends as soon as all processes exit, so this is only an upper bound. Set to 0 for an immediate force-kill. Should fit inside ``time`` so the job is not reaped mid-grace.""" + cleanup_grace_period: int = Field(3600, ge=0) + """Seconds to wait for processes to exit during teardown. When a component exits (completion, crash, or SIGTERM), the job sends SIGTERM to the remaining processes and then waits up to this long for them to exit before force-killing (SIGKILL) and releasing the allocation. Gives in-flight work — notably trainer checkpoint writes — a bounded window to flush. The wait ends as soon as all processes exit, so this is only an upper bound. Set to 0 for an immediate force-kill. Should fit inside ``time`` so the job is not reaped mid-grace.""" @property def template_vars(self) -> dict: @@ -50,7 +50,7 @@ def template_vars(self) -> dict: "account": self.account, "time": self.time, "pre_run_command": self.pre_run_command, - "cleanup_grace_period_seconds": self.cleanup_grace_period_seconds, + "cleanup_grace_period": self.cleanup_grace_period, } @model_validator(mode="after") diff --git a/src/prime_rl/templates/inference.sbatch.j2 b/src/prime_rl/templates/inference.sbatch.j2 index c206f15070..4d2b76a9e2 100755 --- a/src/prime_rl/templates/inference.sbatch.j2 +++ b/src/prime_rl/templates/inference.sbatch.j2 @@ -328,11 +328,11 @@ REMAINING_PIDS=$(jobs -p) for pid in $REMAINING_PIDS; do kill -TERM $pid 2>/dev/null || true done -{%- if cleanup_grace_period_seconds %} +{%- if cleanup_grace_period %} # Give in-flight work a bounded window to flush after SIGTERM before we # force-kill and release the allocation. The wait ends early as soon as every # process exits, so this is only an upper bound. -GRACE_DEADLINE=$((SECONDS + {{ cleanup_grace_period_seconds }})) +GRACE_DEADLINE=$((SECONDS + {{ cleanup_grace_period }})) for pid in $REMAINING_PIDS; do while kill -0 $pid 2>/dev/null && [ "$SECONDS" -lt "$GRACE_DEADLINE" ]; do sleep 5 @@ -340,7 +340,7 @@ for pid in $REMAINING_PIDS; do done for pid in $REMAINING_PIDS; do if kill -0 $pid 2>/dev/null; then - echo "[$(hostname)] process $pid still alive after {{ cleanup_grace_period_seconds }}s grace period - force-killing" + echo "[$(hostname)] process $pid still alive after {{ cleanup_grace_period }}s grace period - force-killing" kill -KILL $pid 2>/dev/null || true fi done diff --git a/src/prime_rl/templates/multi_node_rl.sbatch.j2 b/src/prime_rl/templates/multi_node_rl.sbatch.j2 index 3044258b88..96273acc19 100755 --- a/src/prime_rl/templates/multi_node_rl.sbatch.j2 +++ b/src/prime_rl/templates/multi_node_rl.sbatch.j2 @@ -425,11 +425,11 @@ REMAINING_PIDS=$(jobs -p) for pid in $REMAINING_PIDS; do kill -TERM $pid 2>/dev/null || true done -{%- if cleanup_grace_period_seconds %} +{%- if cleanup_grace_period %} # Give in-flight work (notably trainer checkpoint writes) a bounded window to # flush after SIGTERM before we force-kill and release the allocation. The wait # ends early as soon as every process exits, so this is only an upper bound. -GRACE_DEADLINE=$((SECONDS + {{ cleanup_grace_period_seconds }})) +GRACE_DEADLINE=$((SECONDS + {{ cleanup_grace_period }})) for pid in $REMAINING_PIDS; do while kill -0 $pid 2>/dev/null && [ "$SECONDS" -lt "$GRACE_DEADLINE" ]; do sleep 5 @@ -437,7 +437,7 @@ for pid in $REMAINING_PIDS; do done for pid in $REMAINING_PIDS; do if kill -0 $pid 2>/dev/null; then - echo "[$(hostname)] process $pid still alive after {{ cleanup_grace_period_seconds }}s grace period - force-killing" + echo "[$(hostname)] process $pid still alive after {{ cleanup_grace_period }}s grace period - force-killing" kill -KILL $pid 2>/dev/null || true fi done From 57e751083eddfca03583ced108684eb6333edbfe Mon Sep 17 00:00:00 2001 From: Mika Senghaas Date: Fri, 29 May 2026 21:29:54 +0000 Subject: [PATCH 09/13] chore: drop CHANGELOG entry for cleanup_grace_period Co-authored-by: Cursor --- CHANGELOG.md | 1 - 1 file changed, 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5d4acbdb53..e245407e5e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,7 +2,6 @@ Documenting **breaking** configuration changes — renamed, removed, or moved fields that require users to update existing configs. -- **`slurm.cleanup_grace_period` (NEW)**: Added `cleanup_grace_period: int` (seconds) to `SlurmConfig` (default: `3600`). When a component exits (completion, crash, or SIGTERM), the multi-node RL and inference sbatch teardown now sends SIGTERM and then waits up to this many seconds for the remaining processes to exit before force-killing (SIGKILL) and releasing the allocation, giving in-flight work — notably trainer checkpoint writes — a bounded window to flush. The wait ends as soon as all processes exit, so it is only an upper bound. **Behavioral change**: teardown previously force-killed immediately; set `slurm.cleanup_grace_period = 0` to restore that. Should fit inside `slurm.time` so the job is not reaped mid-grace. (2026-05-29) - **`sampling.min_tokens`, `sampling.repetition_penalty`, `sampling.seed` removed**: Dropped from both `TrainSamplingConfig` and `EvalSamplingConfig` (group-level `[orchestrator.train.sampling]` / `[orchestrator.eval.sampling]` and per-env `[[orchestrator.train.env.sampling]]` / `[[orchestrator.eval.env.sampling]]`). `min_tokens` suppressed natural EOS, `repetition_penalty` distorts the on-policy sampling distribution, and `seed` wasn't pulling its weight — none belonged on the supported config surface. Existing configs setting any of these must delete the field. Hard-deprecation, no migration window. (2026-05-27) - **`wandb.shared` removed**: The deprecation shim that popped `wandb.shared` from input dicts with a `FutureWarning` (introduced in #2649) is gone. The `rl` entrypoint always uses shared W&B mode now, and existing configs that still set `wandb.shared = true` (or `false`) will fail validation. Drop the field from your config. (2026-05-27) - **`max_async_level` and `strict_async_level` removed**: The async-execution semantics between trainer and orchestrator are now design invariants, not config knobs. The trainer always runs exactly one step ahead of inference, and the orchestrator always adopts the freshest checkpoint that doesn't violate the one-step barrier. The shared top-level `max_async_level`, the per-sub-config `trainer.max_async_level` / `orchestrator.max_async_level`, and `orchestrator.strict_async_level` have all been removed. Existing configs setting any of these must drop the field; the previous defaults (`max_async_level = 1`, `strict_async_level = false`) match the new hardcoded behavior. Bench mode no longer bypasses the weight-ckpt wait (the `int(1e9)` workaround is gone) and `multimodal/rl_color_codeword_feat_renderer.toml`'s prior `max_async_level = 0` (fully synchronous on-policy) is no longer expressible. (2026-05-25) From 73ca5a11cfce6e9b3e4e7f9bf67d0d12fc2df067 Mon Sep 17 00:00:00 2001 From: Mika Senghaas Date: Fri, 29 May 2026 21:40:40 +0000 Subject: [PATCH 10/13] fix: make cleanup grace actually cover cross-node teardown The previous SIGTERM-then-wait approach didn't help the target case (inference dies while the trainer is mid-checkpoint on another node): that teardown is driven by `srun --kill-on-bad-exit=1`, which reaps the trainer task via SLURM's own KillWait path and never runs our in-task grace loop. Instead, on a non-zero exit the failing node now stays alive (signalling nothing) for the grace period before propagating the exit. Because --kill-on-bad-exit only fires when a task exits, holding the failing task keeps peer nodes' checkpointing trainers running untouched until they flush. Clean (zero-exit) completion is unaffected. Scope to multi_node_rl only; the inference-only template has no trainer checkpoints to protect, so it reverts to immediate teardown. Co-authored-by: Cursor --- .../src/prime_rl/configs/shared.py | 2 +- src/prime_rl/templates/inference.sbatch.j2 | 20 +----------- .../templates/multi_node_rl.sbatch.j2 | 32 ++++++++----------- 3 files changed, 16 insertions(+), 38 deletions(-) diff --git a/packages/prime-rl-configs/src/prime_rl/configs/shared.py b/packages/prime-rl-configs/src/prime_rl/configs/shared.py index b1a1557be5..ef51f2f181 100644 --- a/packages/prime-rl-configs/src/prime_rl/configs/shared.py +++ b/packages/prime-rl-configs/src/prime_rl/configs/shared.py @@ -36,7 +36,7 @@ class SlurmConfig(BaseConfig): """Shell command to run on the head node after cd, .env sourcing, and venv activation. Useful for cleanup like ``sudo pkill -f vllm``; wrap with ``srun bash -c '...'`` to fan out to all nodes.""" cleanup_grace_period: int = Field(3600, ge=0) - """Seconds to wait for processes to exit during teardown. When a component exits (completion, crash, or SIGTERM), the job sends SIGTERM to the remaining processes and then waits up to this long for them to exit before force-killing (SIGKILL) and releasing the allocation. Gives in-flight work — notably trainer checkpoint writes — a bounded window to flush. The wait ends as soon as all processes exit, so this is only an upper bound. Set to 0 for an immediate force-kill. Should fit inside ``time`` so the job is not reaped mid-grace.""" + """Seconds to keep the allocation alive when a multi-node RL job is torn down by a *non-zero* exit (crash, SIGTERM, wall-time), giving in-flight work — notably trainer checkpoint writes — time to flush. The failing node stays alive and signals nothing for this window, which both lets a local checkpoint finish and prevents ``srun --kill-on-bad-exit`` from reaping checkpointing trainers on peer nodes before they flush; teardown then proceeds. Clean (zero-exit) completion is unaffected and tears down immediately. This is a fixed wait, so pick a value that comfortably covers a checkpoint write. Set to 0 to tear down immediately on failure. Should fit inside ``time`` so the job is not reaped mid-grace.""" @property def template_vars(self) -> dict: diff --git a/src/prime_rl/templates/inference.sbatch.j2 b/src/prime_rl/templates/inference.sbatch.j2 index 4d2b76a9e2..5ba6a37631 100755 --- a/src/prime_rl/templates/inference.sbatch.j2 +++ b/src/prime_rl/templates/inference.sbatch.j2 @@ -324,26 +324,8 @@ srun --kill-on-bad-exit=1 bash -c ' wait -n EXIT_CODE=$? echo "[$(hostname)] component exited with code $EXIT_CODE - terminating remaining processes" -REMAINING_PIDS=$(jobs -p) -for pid in $REMAINING_PIDS; do +for pid in $(jobs -p); do kill -TERM $pid 2>/dev/null || true done -{%- if cleanup_grace_period %} -# Give in-flight work a bounded window to flush after SIGTERM before we -# force-kill and release the allocation. The wait ends early as soon as every -# process exits, so this is only an upper bound. -GRACE_DEADLINE=$((SECONDS + {{ cleanup_grace_period }})) -for pid in $REMAINING_PIDS; do - while kill -0 $pid 2>/dev/null && [ "$SECONDS" -lt "$GRACE_DEADLINE" ]; do - sleep 5 - done -done -for pid in $REMAINING_PIDS; do - if kill -0 $pid 2>/dev/null; then - echo "[$(hostname)] process $pid still alive after {{ cleanup_grace_period }}s grace period - force-killing" - kill -KILL $pid 2>/dev/null || true - fi -done -{%- endif %} exit $EXIT_CODE ' diff --git a/src/prime_rl/templates/multi_node_rl.sbatch.j2 b/src/prime_rl/templates/multi_node_rl.sbatch.j2 index 96273acc19..a3e68e6888 100755 --- a/src/prime_rl/templates/multi_node_rl.sbatch.j2 +++ b/src/prime_rl/templates/multi_node_rl.sbatch.j2 @@ -420,27 +420,23 @@ else # tears down instead of leaving zombies holding the allocation. wait -n EXIT_CODE=$? -echo "[$(hostname)] component exited with code $EXIT_CODE - terminating remaining processes" REMAINING_PIDS=$(jobs -p) -for pid in $REMAINING_PIDS; do - kill -TERM $pid 2>/dev/null || true -done {%- if cleanup_grace_period %} -# Give in-flight work (notably trainer checkpoint writes) a bounded window to -# flush after SIGTERM before we force-kill and release the allocation. The wait -# ends early as soon as every process exits, so this is only an upper bound. -GRACE_DEADLINE=$((SECONDS + {{ cleanup_grace_period }})) -for pid in $REMAINING_PIDS; do - while kill -0 $pid 2>/dev/null && [ "$SECONDS" -lt "$GRACE_DEADLINE" ]; do - sleep 5 - done -done +# Cross-node teardown is driven by `srun --kill-on-bad-exit=1`, which reaps peer +# tasks only once a task *exits* non-zero. So on a non-zero exit (crash, SIGTERM, +# wall-time) we stay alive for the grace period and signal nothing: this lets an +# in-flight trainer checkpoint on this node finish, and — because the task has +# not exited yet — keeps --kill-on-bad-exit from reaping checkpointing trainers +# on peer nodes. We then propagate the exit code so the allocation tears down. A +# clean (zero) exit can't trip --kill-on-bad-exit, so we tear down immediately. +if [ "$EXIT_CODE" -ne 0 ]; then + echo "[$(hostname)] component exited with code $EXIT_CODE - waiting up to {{ cleanup_grace_period }}s for in-flight checkpoints to flush before teardown" + sleep {{ cleanup_grace_period }} +fi +{%- endif %} +echo "[$(hostname)] tearing down (exit code $EXIT_CODE) - terminating remaining processes" for pid in $REMAINING_PIDS; do - if kill -0 $pid 2>/dev/null; then - echo "[$(hostname)] process $pid still alive after {{ cleanup_grace_period }}s grace period - force-killing" - kill -KILL $pid 2>/dev/null || true - fi + kill -TERM $pid 2>/dev/null || true done -{%- endif %} exit $EXIT_CODE ' From 6685b3fed05b25a5ef92d158ab8be2288e97c92e Mon Sep 17 00:00:00 2001 From: Mika Senghaas Date: Fri, 29 May 2026 21:48:24 +0000 Subject: [PATCH 11/13] docs: note cleanup_grace_period is multi-node RL only Co-authored-by: Cursor --- packages/prime-rl-configs/src/prime_rl/configs/shared.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/prime-rl-configs/src/prime_rl/configs/shared.py b/packages/prime-rl-configs/src/prime_rl/configs/shared.py index ef51f2f181..a8393976f1 100644 --- a/packages/prime-rl-configs/src/prime_rl/configs/shared.py +++ b/packages/prime-rl-configs/src/prime_rl/configs/shared.py @@ -36,7 +36,7 @@ class SlurmConfig(BaseConfig): """Shell command to run on the head node after cd, .env sourcing, and venv activation. Useful for cleanup like ``sudo pkill -f vllm``; wrap with ``srun bash -c '...'`` to fan out to all nodes.""" cleanup_grace_period: int = Field(3600, ge=0) - """Seconds to keep the allocation alive when a multi-node RL job is torn down by a *non-zero* exit (crash, SIGTERM, wall-time), giving in-flight work — notably trainer checkpoint writes — time to flush. The failing node stays alive and signals nothing for this window, which both lets a local checkpoint finish and prevents ``srun --kill-on-bad-exit`` from reaping checkpointing trainers on peer nodes before they flush; teardown then proceeds. Clean (zero-exit) completion is unaffected and tears down immediately. This is a fixed wait, so pick a value that comfortably covers a checkpoint write. Set to 0 to tear down immediately on failure. Should fit inside ``time`` so the job is not reaped mid-grace.""" + """Seconds to keep the allocation alive when a multi-node RL job is torn down by a *non-zero* exit (crash, SIGTERM, wall-time), giving in-flight work — notably trainer checkpoint writes — time to flush. The failing node stays alive and signals nothing for this window, which both lets a local checkpoint finish and prevents ``srun --kill-on-bad-exit`` from reaping checkpointing trainers on peer nodes before they flush; teardown then proceeds. Clean (zero-exit) completion is unaffected and tears down immediately. This is a fixed wait, so pick a value that comfortably covers a checkpoint write. Set to 0 to tear down immediately on failure. Should fit inside ``time`` so the job is not reaped mid-grace. Only applies to multi-node RL: single-node SLURM teardown is driven by the Python launcher, not the sbatch script, so this knob has no effect there.""" @property def template_vars(self) -> dict: From d8a184c69af168dc177ad19fed05d823b2a2413f Mon Sep 17 00:00:00 2001 From: Mika Senghaas Date: Fri, 29 May 2026 21:49:04 +0000 Subject: [PATCH 12/13] docs: minimize cleanup_grace_period docstring Co-authored-by: Cursor --- packages/prime-rl-configs/src/prime_rl/configs/shared.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/prime-rl-configs/src/prime_rl/configs/shared.py b/packages/prime-rl-configs/src/prime_rl/configs/shared.py index a8393976f1..744e423a9b 100644 --- a/packages/prime-rl-configs/src/prime_rl/configs/shared.py +++ b/packages/prime-rl-configs/src/prime_rl/configs/shared.py @@ -36,7 +36,7 @@ class SlurmConfig(BaseConfig): """Shell command to run on the head node after cd, .env sourcing, and venv activation. Useful for cleanup like ``sudo pkill -f vllm``; wrap with ``srun bash -c '...'`` to fan out to all nodes.""" cleanup_grace_period: int = Field(3600, ge=0) - """Seconds to keep the allocation alive when a multi-node RL job is torn down by a *non-zero* exit (crash, SIGTERM, wall-time), giving in-flight work — notably trainer checkpoint writes — time to flush. The failing node stays alive and signals nothing for this window, which both lets a local checkpoint finish and prevents ``srun --kill-on-bad-exit`` from reaping checkpointing trainers on peer nodes before they flush; teardown then proceeds. Clean (zero-exit) completion is unaffected and tears down immediately. This is a fixed wait, so pick a value that comfortably covers a checkpoint write. Set to 0 to tear down immediately on failure. Should fit inside ``time`` so the job is not reaped mid-grace. Only applies to multi-node RL: single-node SLURM teardown is driven by the Python launcher, not the sbatch script, so this knob has no effect there.""" + """Seconds to wait before tearing down a multi-node RL job that hit a non-zero exit, letting in-flight checkpoints flush. Set to 0 to tear down immediately.""" @property def template_vars(self) -> dict: From ca0745d551b4e700a7fcbde7fef819fcda8ea05d Mon Sep 17 00:00:00 2001 From: S1ro1 Date: Sat, 30 May 2026 04:50:15 +0530 Subject: [PATCH 13/13] Feat: retry/logging --- src/prime_rl/inference/vllm/server.py | 1 + src/prime_rl/utils/client.py | 35 ++++++++++++++++++++++++++- 2 files changed, 35 insertions(+), 1 deletion(-) diff --git a/src/prime_rl/inference/vllm/server.py b/src/prime_rl/inference/vllm/server.py index 699b1cec75..f411695c78 100644 --- a/src/prime_rl/inference/vllm/server.py +++ b/src/prime_rl/inference/vllm/server.py @@ -61,6 +61,7 @@ def models(request: Request) -> OpenAIServingModels: @router.post("/pause") async def pause(request: Request): + logger.debug("Received /pause request (mode=keep, clear_cache=False)") await engine_client(request).pause_generation(mode="keep", clear_cache=False) return {"status": "paused"} diff --git a/src/prime_rl/utils/client.py b/src/prime_rl/utils/client.py index beb41e8ab6..61cc8a4620 100644 --- a/src/prime_rl/utils/client.py +++ b/src/prime_rl/utils/client.py @@ -306,13 +306,46 @@ async def _check_health(admin_client: AsyncClient) -> None: NCCL_READY_MARKER = "NCCL_READY" +def _is_retryable_pause_error(exception: BaseException) -> bool: + """Check if an exception should trigger a retry for pausing engines.""" + if isinstance(exception, httpx.HTTPStatusError): + # Retry on transient server errors (5xx, e.g. engine briefly unresponsive); + # client errors (4xx) won't fix themselves on retry. + return exception.response.status_code >= 500 + # Retry on transport-level failures (timeouts, connection resets, etc.) so the + # per-attempt read timeout below turns a stuck server into a bounded retry loop + # instead of hanging forever on the global timeout=None admin client. + if isinstance(exception, (httpx.TimeoutException, httpx.TransportError)): + return True + return False + + +# Per-attempt and total bounds for `/pause`. Pausing drains in-flight requests +# (mode="keep"), so a single attempt can legitimately take a while, but the global +# admin AsyncClient uses `timeout=None`, so a stuck server would hang the weight +# update forever. `_READ_TIMEOUT` converts a hang into a TimeoutException so +# tenacity retries; `_TOTAL` is the wall-clock budget across all retries. +PAUSE_READ_TIMEOUT_S = 120.0 +PAUSE_TOTAL_TIMEOUT_S = 300.0 + + async def _pause_engines(admin_clients: list[AsyncClient]) -> None: """Pause all inference engines, waiting for in-flight requests to drain.""" logger = get_logger() logger.info("Pausing inference engines for weight update") + @retry( + retry=retry_if_exception(_is_retryable_pause_error), + stop=stop_after_delay(PAUSE_TOTAL_TIMEOUT_S) | stop_after_attempt(10), + wait=wait_exponential(multiplier=1, min=1, max=10), + reraise=True, + ) async def _pause(client: AsyncClient) -> None: - response = await client.post("/pause", params={"mode": "keep", "clear_cache": "false"}) + response = await client.post( + "/pause", + params={"mode": "keep", "clear_cache": "false"}, + timeout=httpx.Timeout(connect=10.0, read=PAUSE_READ_TIMEOUT_S, write=60.0, pool=10.0), + ) response.raise_for_status() await asyncio.gather(*[_pause(client) for client in admin_clients])