From 2e1421eda6b1b03ecbe4fbc59b11c040f773da63 Mon Sep 17 00:00:00 2001 From: Markus Frey Date: Fri, 19 Jun 2026 16:20:03 +0200 Subject: [PATCH] minimal changes for olmes integration --- docs/components/downstream_evaluation.md | 49 ++++++ src/modalities/config/config.py | 9 + src/modalities/config/instantiation_models.py | 2 + src/modalities/config/pydantic_if_types.py | 2 + src/modalities/evaluator.py | 140 ++++++++++++++++ src/modalities/gym.py | 4 +- src/modalities/main.py | 12 ++ src/modalities/registry/components.py | 3 + src/modalities/trainer.py | 5 + tests/test_downstream_evaluator.py | 156 ++++++++++++++++++ 10 files changed, 381 insertions(+), 1 deletion(-) create mode 100644 docs/components/downstream_evaluation.md create mode 100644 tests/test_downstream_evaluator.py diff --git a/docs/components/downstream_evaluation.md b/docs/components/downstream_evaluation.md new file mode 100644 index 000000000..98441f8f6 --- /dev/null +++ b/docs/components/downstream_evaluation.md @@ -0,0 +1,49 @@ +# Downstream Evaluation Pipeline + +The downstream evaluation pipeline in Modalities is a callback system that executes at configurable step intervals during the training loop. + +The order of execution inside `Trainer.train` is: +1. `checkpointing_callback`: Saves the PyTorch/FSDP checkpoint to disk. +2. `downstream_evaluation_callback`: (Optional) Runs external evaluation tools (like OLMES) on the newly created HF checkpoint. + +--- + +## Downstream Evaluation Callback (`DownstreamEvaluator`) + +**Location:** `src/modalities/evaluator.py` + +The `DownstreamEvaluator` checks for the existence of an HF checkpoint, launches an evaluation script via a subprocess, tracks active processes, and syncs OLMES metrics to the active W&B run. + +### Behavior +- Triggered if `num_train_steps_done` is non-zero and `num_train_steps_done % eval_interval == 0` (step 0 is always skipped). +- Only executes on `global_rank == 0`. +- Reads `last_checkpoint_info.json` to find the latest checkpoint. +- Checks if `{checkpoint_path}/hf_checkpoint` exists. 
If it does NOT exist, evaluation is skipped with a warning (assuming conversion failed or was disabled). +- If the HF checkpoint exists, it formats the `olmes_command_template` and launches it asynchronously using `subprocess.Popen(cmd, shell=True)`. +- **Process Tracking**: Stores `(Popen, step, hf_model_dir)` tuples in `self.active_processes`. +- **Graceful Exit**: `wait_for_evaluations()` iterates over `active_processes`, calls `.wait()` on each process, and syncs metrics for every evaluation that exits with return code 0 (a non-zero exit code is logged as a warning and the sync is skipped). +- **W&B Metric Sync**: Extracts the `primary_score` of each task from the OLMES output file (`{hf_model_dir}/olmes_eval_{step}/metrics-all.jsonl`) and logs the scores to the active `wandb.run` as `downstream/{alias}`, keyed to the training step through a dedicated `downstream_step` metric. Gracefully skips if W&B is disabled or not installed. + +### Placeholders +The `olmes_command_template` string can use the following placeholders: +- `{hf_model_dir}`: The path to the `{checkpoint_path}/hf_checkpoint` directory. +- `{tasks}`: A space-separated string of the tasks provided in the config. +- `{step}`: The current `num_train_steps_done`. 
+ +### YAML Configuration +```yaml +downstream_evaluator: + component_key: downstream_evaluator + variant_key: default + config: + tokenizer: + instance_key: tokenizer + pass_type: BY_REFERENCE + tasks: + - "arc_challenge::olmes" + - "hellaswag::olmes" + eval_interval: 100 + checkpoint_dir: ${settings.paths.experiments_root_path}/${settings.experiment_id} + global_rank: ${settings.cuda_env.global_rank} + olmes_command_template: "bash scripts/evaluation/run_olmes_sbatch.sh {hf_model_dir} '{tasks}' {step} 1024 1" +``` diff --git a/src/modalities/config/config.py b/src/modalities/config/config.py index 42a19b99a..48e34b0d6 100644 --- a/src/modalities/config/config.py +++ b/src/modalities/config/config.py @@ -520,6 +520,15 @@ class ParallelDegreeConfig(BaseModel): parallelism_methods: list[ParallelismDegrees] +class DownstreamEvaluatorConfig(BaseModel): + tokenizer: PydanticTokenizerIFType + tasks: list[str] + eval_interval: Annotated[int, Field(strict=True, gt=0)] + checkpoint_dir: Path + global_rank: Annotated[int, Field(strict=True, ge=0)] + olmes_command_template: str + + # Recursive type representing arbitrary-depth YAML config structures. 
YAMLPrimitive = str | int | float | bool | None YAMLValue: TypeAlias = YAMLPrimitive | Path | list["YAMLValue"] | dict[str, "YAMLValue"] diff --git a/src/modalities/config/instantiation_models.py b/src/modalities/config/instantiation_models.py index fd7fd3b78..982d6d5c9 100644 --- a/src/modalities/config/instantiation_models.py +++ b/src/modalities/config/instantiation_models.py @@ -22,6 +22,7 @@ PydanticSteppableProfilerIFType, PydanticTextInferenceComponentType, PydanticTokenizerIFType, + PydanticDownstreamEvaluatorType, ) from modalities.config.utils import parse_torch_device from modalities.dataloader.dataset import Dataset @@ -192,6 +193,7 @@ def _check_last_step_checkpointed(self) -> "TrainingComponentsInstantiationModel mfu_calculator: PydanticMFUCalculatorABCType | None = None scheduled_pipeline: PydanticPipelineType | None = None device_mesh: PydanticDeviceMeshIFType | None = None + downstream_evaluator: Optional[PydanticDownstreamEvaluatorType] = None model_raw: PydanticPytorchModuleType @model_validator(mode="after") diff --git a/src/modalities/config/pydantic_if_types.py b/src/modalities/config/pydantic_if_types.py index 90b7ca951..1460a6aab 100644 --- a/src/modalities/config/pydantic_if_types.py +++ b/src/modalities/config/pydantic_if_types.py @@ -23,6 +23,7 @@ from modalities.logging_broker.subscriber import MessageSubscriberIF from modalities.loss_functions import Loss from modalities.models.parallelism.pipeline_parallelism import Pipeline, StagesGenerator +from modalities.evaluator import DownstreamEvaluator from modalities.nn.model_initialization.initialization_if import ModelInitializationIF from modalities.tokenization.tokenizer_wrapper import TokenizerWrapper from modalities.training.gradient_clipping.gradient_clipper import GradientClipperIF @@ -98,3 +99,4 @@ def __get_pydantic_core_schema__( torch.utils.hooks.RemovableHandle, PydanticThirdPartyTypeIF(torch.utils.hooks.RemovableHandle) ] PydanticDebuggingType = Annotated[Debugging, 
PydanticThirdPartyTypeIF(Debugging)] +PydanticDownstreamEvaluatorType = Annotated[DownstreamEvaluator, PydanticThirdPartyTypeIF(DownstreamEvaluator)] diff --git a/src/modalities/evaluator.py b/src/modalities/evaluator.py index fb9bdc0d3..7ef66ea6f 100644 --- a/src/modalities/evaluator.py +++ b/src/modalities/evaluator.py @@ -1,10 +1,16 @@ from typing import Callable +import json +import logging +import subprocess +from pathlib import Path import torch import torch.distributed as dist import torch.nn as nn from torch.distributed.device_mesh import DeviceMesh +from modalities.tokenization.tokenizer_wrapper import TokenizerWrapper + from modalities.batch import DatasetBatch, EvaluationResultBatch, InferenceResultBatch, ResultItem from modalities.dataloader.dataloader import LLMDataLoader from modalities.logging_broker.messages import ExperimentStatus, MessageTypes, ProgressUpdate @@ -15,6 +21,8 @@ from modalities.running_env.fsdp.reducer import Reducer from modalities.util import TimeRecorder +logger = logging.getLogger(__name__) + class Evaluator: """Evaluator class which is responsible for evaluating the model on a set of datasets""" @@ -197,3 +205,135 @@ def _publish_evaluation_result( evaluation_result_publisher.publish_message( payload=evaluation_result, message_type=MessageTypes.EVALUATION_RESULT ) + + +class DownstreamEvaluator: + """Evaluator that runs OLMES on HF checkpoints produced by the conversion callback. + + Checks if an ``hf_checkpoint`` folder exists inside the latest checkpoint directory + (as written by ``ModelConverter``). If it does, the configured OLMES command template + is executed via subprocess. 
+ """ + + def __init__( + self, + tokenizer: TokenizerWrapper, + tasks: list[str], + eval_interval: int, + checkpoint_dir: Path, + global_rank: int, + olmes_command_template: str, + ) -> None: + self.tokenizer = tokenizer + self.tasks = tasks + self.eval_interval = eval_interval + self.checkpoint_dir = Path(checkpoint_dir) + self.global_rank = global_rank + self.olmes_command_template = olmes_command_template + self.active_processes: list[tuple[subprocess.Popen, int, Path]] = [] + + def evaluate(self, num_train_steps_done: int) -> None: + if num_train_steps_done == 0 or num_train_steps_done % self.eval_interval != 0: + return + if self.global_rank != 0: + return + + hf_model_dir = self._find_hf_checkpoint() + if hf_model_dir is None: + logger.warning( + f"No hf_checkpoint found in {self.checkpoint_dir} at step {num_train_steps_done}, " + "skipping downstream evaluation." + ) + return + + tasks_str = " ".join(self.tasks) + cmd = self.olmes_command_template.format( + hf_model_dir=str(hf_model_dir), + tasks=tasks_str, + step=num_train_steps_done, + ) + + logger.info(f"Running downstream evaluation: {cmd}") + try: + p = subprocess.Popen(cmd, shell=True) + self.active_processes.append((p, num_train_steps_done, hf_model_dir)) + logger.info(f"Downstream evaluation launched for step {num_train_steps_done}.") + except Exception as e: + logger.error(f"Failed to launch downstream evaluation: {e}") + + def wait_for_evaluations(self) -> None: + if not hasattr(self, "active_processes") or not self.active_processes: + return + + logger.info(f"Waiting for {len(self.active_processes)} downstream evaluations to finish...") + for p, step, hf_model_dir in self.active_processes: + p.wait() + if p.returncode == 0: + self._sync_metrics_to_wandb(step, hf_model_dir) + else: + logger.warning(f"Downstream evaluation for step {step} exited with code {p.returncode}, skipping W&B sync.") + logger.info("All downstream evaluations finished.") + self.active_processes = [] + + def 
_sync_metrics_to_wandb(self, step: int, hf_model_dir: Path) -> None: + """Parse OLMES metrics-all.jsonl and log primary scores to the active W&B run.""" + metrics_file = hf_model_dir / f"olmes_eval_{step}" / "metrics-all.jsonl" + if not metrics_file.exists(): + logger.warning(f"No metrics file found at {metrics_file}, skipping W&B sync for step {step}.") + return + + metrics_dict = {} + try: + with open(metrics_file, "r", encoding="utf-8") as f: + for line in f: + obj = json.loads(line) + alias = ( + obj.get("task_config", {}).get("metadata", {}).get("alias") + or obj.get("task_name") + ) + score = obj.get("metrics", {}).get("primary_score") + if alias and score is not None: + metrics_dict[f"downstream/{alias}"] = score + except Exception as e: + logger.error(f"Failed to parse metrics file {metrics_file}: {e}") + return + + if not metrics_dict: + logger.warning(f"No metrics extracted from {metrics_file} for step {step}.") + return + + try: + import wandb + + if wandb.run is not None: + # Define a custom step metric so downstream/* metrics are decoupled from + # the global training step counter (which is already past these steps). 
+ wandb.run.define_metric("downstream_step", hidden=True) + wandb.run.define_metric("downstream/*", step_metric="downstream_step") + metrics_dict["downstream_step"] = step + wandb.run.log(metrics_dict) + logger.info(f"Synced {len(metrics_dict)} OLMES metrics to W&B at step {step}: {metrics_dict}") + else: + logger.info(f"W&B not active, skipping metric sync for step {step}.") + except ImportError: + logger.info(f"wandb not installed, skipping metric sync for step {step}.") + + def _find_hf_checkpoint(self) -> Path | None: + """Read last_checkpoint_info.json and check for hf_checkpoint subfolder.""" + info_file = self.checkpoint_dir / "last_checkpoint_info.json" + if not info_file.exists(): + return None + + with open(info_file, "r", encoding="utf-8") as f: + info = json.load(f) + + checkpoint_path_str = info.get("checkpoint_folder_path") or info.get("model_checkpoint_path") + if checkpoint_path_str is None: + return None + + checkpoint_path = Path(checkpoint_path_str) + if checkpoint_path.is_file(): + checkpoint_path = checkpoint_path.parent + + hf_dir = checkpoint_path / "hf_checkpoint" + return hf_dir if hf_dir.exists() else None diff --git a/src/modalities/gym.py b/src/modalities/gym.py index 010c4ca60..8e673d238 100644 --- a/src/modalities/gym.py +++ b/src/modalities/gym.py @@ -1,6 +1,6 @@ from datetime import datetime from functools import partial -from typing import Callable +from typing import Callable, Optional import torch.nn as nn @@ -42,6 +42,7 @@ def run( evaluation_data_loaders: list[LLMDataLoader], checkpoint_saving: CheckpointSaving, scheduled_pipeline: Pipeline | None = None, + downstream_evaluation_callback: Optional[Callable[[int], None]] = None, ): """Runs the model training, including evaluation and checkpointing. 
@@ -80,6 +81,7 @@ def run( checkpointing_callback=checkpointing_callback, training_log_interval_in_steps=training_log_interval_in_steps, scheduled_pipeline=scheduled_pipeline, + downstream_evaluation_callback=downstream_evaluation_callback, ) print_rank_0(f"Training done at {datetime.now()}.") diff --git a/src/modalities/main.py b/src/modalities/main.py index 49ac97b91..535cd19e9 100644 --- a/src/modalities/main.py +++ b/src/modalities/main.py @@ -220,6 +220,10 @@ def run(self, components: TrainingComponentsInstantiationModel): print_rank_0(report) + downstream_evaluation_callback = None + if getattr(components, "downstream_evaluator", None) is not None: + downstream_evaluation_callback = components.downstream_evaluator.evaluate + gym.run( train_data_loader=components.train_dataloader, evaluation_data_loaders=components.eval_dataloaders, @@ -229,8 +233,16 @@ def run(self, components: TrainingComponentsInstantiationModel): evaluation_interval_in_steps=components.settings.intervals.evaluation_interval_in_steps, training_log_interval_in_steps=components.settings.intervals.training_log_interval_in_steps, scheduled_pipeline=components.scheduled_pipeline, + downstream_evaluation_callback=downstream_evaluation_callback, ) + if getattr(components, "downstream_evaluator", None) is not None: + print_rank_0("\n" + "="*80) + print_rank_0("Training loop complete! 
Waiting for background evaluations to finish...") + print_rank_0("="*80 + "\n") + components.downstream_evaluator.wait_for_evaluations() + print_rank_0("All background evaluations completed successfully!") + def get_logging_publishers( self, progress_subscriber: MessageSubscriberIF[ProgressUpdate], diff --git a/src/modalities/registry/components.py b/src/modalities/registry/components.py index 26df9b432..a1d44f42c 100644 --- a/src/modalities/registry/components.py +++ b/src/modalities/registry/components.py @@ -68,6 +68,7 @@ TorchCheckpointLoadingConfig, WandBEvaluationResultSubscriberConfig, WeightInitializedModelConfig, + DownstreamEvaluatorConfig, ) from modalities.dataloader.collate_fns.collator_fn_wrapper_for_loss_masking import ( LossMaskingCollateFnWrapper, @@ -83,6 +84,7 @@ ResultsSubscriberFactory, ) from modalities.loss_functions import CLMCrossEntropyLoss +from modalities.evaluator import DownstreamEvaluator from modalities.models.coca.coca_model import CoCa, CoCaConfig from modalities.models.coca.collator import CoCaCollateFnConfig, CoCaCollatorFn from modalities.models.components.layer_norms import ( @@ -528,4 +530,5 @@ class ComponentEntity: maybe_model_list(HookRegistration.register_print_forward_hooks), PrintForwardHookConfig, ), + ComponentEntity("downstream_evaluator", "default", DownstreamEvaluator, DownstreamEvaluatorConfig), ] diff --git a/src/modalities/trainer.py b/src/modalities/trainer.py index 4ad54b226..16a22214f 100644 --- a/src/modalities/trainer.py +++ b/src/modalities/trainer.py @@ -207,6 +207,7 @@ def train( evaluation_callback: Callable[[int], None], checkpointing_callback: Callable[[TrainingProgress], None], scheduled_pipeline: Pipeline | None = None, + downstream_evaluation_callback: Optional[Callable[[int], None]] = None, ): """ Trains the model. 
@@ -257,6 +258,8 @@ def train( num_target_tokens=self.num_target_tokens, ) checkpointing_callback(training_progress=training_progress) + if downstream_evaluation_callback is not None: + downstream_evaluation_callback(num_train_steps_done=self.num_seen_train_steps) num_steps_todo = self.num_target_steps - self.num_seen_train_steps num_batches_todo = num_steps_todo * self.gradient_acc_steps @@ -388,6 +391,8 @@ def train( self.gc.run(step_count=training_progress.num_seen_steps_total) evaluation_callback(num_train_steps_done=training_progress.num_seen_steps_total) checkpointing_callback(training_progress=training_progress) + if downstream_evaluation_callback is not None: + downstream_evaluation_callback(num_train_steps_done=training_progress.num_seen_steps_total) profiler_cm.step() diff --git a/tests/test_downstream_evaluator.py b/tests/test_downstream_evaluator.py new file mode 100644 index 000000000..bbfa0fbef --- /dev/null +++ b/tests/test_downstream_evaluator.py @@ -0,0 +1,156 @@ +import json +import tempfile +from pathlib import Path +from unittest.mock import patch + +from pydantic import BaseModel + +from modalities.config.component_factory import ComponentFactory +from modalities.evaluator import DownstreamEvaluator +from modalities.registry.components import COMPONENTS +from modalities.registry.registry import Registry +from modalities.tokenization.tokenizer_wrapper import TokenizerWrapper + + +# ---------- helpers ---------- + +class MockTokenizer(TokenizerWrapper): + def tokenize(self, text: str) -> list[int]: + return [] + + def decode(self, input_ids: list[int]) -> str: + return "" + + @property + def vocab_size(self) -> int: + return 0 + + def get_token_id(self, token: str) -> int: + return 0 + + def is_special_token_id(self, token_id: int) -> bool: + return False + + +# ---------- DownstreamEvaluator tests ---------- + +def test_downstream_evaluator_skips_non_matching_step(): + evaluator = DownstreamEvaluator( + tokenizer=MockTokenizer(), + 
tasks=["arc_challenge::olmes"], + eval_interval=5, + checkpoint_dir=Path("/tmp/fake"), + global_rank=0, + olmes_command_template="echo {hf_model_dir} {tasks} {step}", + ) + with patch("subprocess.Popen") as mock_popen: + evaluator.evaluate(num_train_steps_done=3) + mock_popen.assert_not_called() + + +def test_downstream_evaluator_skips_non_rank_zero(): + evaluator = DownstreamEvaluator( + tokenizer=MockTokenizer(), + tasks=["arc_challenge::olmes"], + eval_interval=5, + checkpoint_dir=Path("/tmp/fake"), + global_rank=1, + olmes_command_template="echo {hf_model_dir} {tasks} {step}", + ) + with patch("subprocess.Popen") as mock_popen: + evaluator.evaluate(num_train_steps_done=5) + mock_popen.assert_not_called() + + +def test_downstream_evaluator_runs_when_hf_checkpoint_exists(): + with tempfile.TemporaryDirectory() as tmpdir: + checkpoint_dir = Path(tmpdir) + ckpt_path = checkpoint_dir / "step_10" + ckpt_path.mkdir() + hf_dir = ckpt_path / "hf_checkpoint" + hf_dir.mkdir() + + info = {"checkpoint_folder_path": str(ckpt_path)} + with open(checkpoint_dir / "last_checkpoint_info.json", "w") as f: + json.dump(info, f) + + evaluator = DownstreamEvaluator( + tokenizer=MockTokenizer(), + tasks=["arc_challenge::olmes", "hellaswag::olmes"], + eval_interval=10, + checkpoint_dir=checkpoint_dir, + global_rank=0, + olmes_command_template="olmes --model {hf_model_dir} --tasks {tasks} --step {step}", + ) + + with patch("subprocess.Popen") as mock_popen: + evaluator.evaluate(num_train_steps_done=10) + mock_popen.assert_called_once() + cmd = mock_popen.call_args[0][0] + assert str(hf_dir) in cmd + assert "arc_challenge::olmes hellaswag::olmes" in cmd + assert "10" in cmd + + +def test_downstream_evaluator_skips_when_no_hf_checkpoint(): + with tempfile.TemporaryDirectory() as tmpdir: + checkpoint_dir = Path(tmpdir) + ckpt_path = checkpoint_dir / "step_10" + ckpt_path.mkdir() + # No hf_checkpoint folder + + info = {"checkpoint_folder_path": str(ckpt_path)} + with open(checkpoint_dir / 
"last_checkpoint_info.json", "w") as f: + json.dump(info, f) + + evaluator = DownstreamEvaluator( + tokenizer=MockTokenizer(), + tasks=["arc_challenge::olmes"], + eval_interval=10, + checkpoint_dir=checkpoint_dir, + global_rank=0, + olmes_command_template="echo {hf_model_dir} {tasks} {step}", + ) + + with patch("subprocess.Popen") as mock_popen: + evaluator.evaluate(num_train_steps_done=10) + mock_popen.assert_not_called() + + +# ---------- Factory instantiation tests ---------- + +def test_downstream_evaluator_factory_instantiation(): + from modalities.config.pydantic_if_types import PydanticDownstreamEvaluatorType + + registry = Registry(COMPONENTS) + component_factory = ComponentFactory(registry=registry) + + tokenizer_mock = MockTokenizer() + + class TrainingModel(BaseModel): + downstream_eval: PydanticDownstreamEvaluatorType + + config_dict = { + "downstream_eval": { + "component_key": "downstream_evaluator", + "variant_key": "default", + "config": { + "tokenizer": tokenizer_mock, + "tasks": ["task_a"], + "eval_interval": 10, + "checkpoint_dir": "/tmp/test_checkpoints", + "global_rank": 0, + "olmes_command_template": "echo {hf_model_dir}", + }, + } + } + + components = component_factory.build_components( + config_dict=config_dict, + components_model_type=TrainingModel, + ) + + assert isinstance(components.downstream_eval, DownstreamEvaluator) + assert components.downstream_eval.tokenizer == tokenizer_mock + assert components.downstream_eval.tasks == ["task_a"] + assert components.downstream_eval.eval_interval == 10