Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 49 additions & 0 deletions docs/components/downstream_evaluation.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# Downstream Evaluation Pipeline

The downstream evaluation pipeline in Modalities is a callback system that executes at configurable step intervals during the training loop.

The order of execution inside `Trainer.train` is:
1. `checkpointing_callback`: Saves the PyTorch/FSDP checkpoint to disk.
2. `downstream_evaluation_callback`: (Optional) Runs external evaluation tools (like OLMES) on the newly created HF checkpoint.

---

## Downstream Evaluation Callback (`DownstreamEvaluator`)

**Location:** `src/modalities/evaluator.py`

The `DownstreamEvaluator` checks for the existence of an HF checkpoint, launches an evaluation script via a subprocess, tracks active processes, and syncs OLMES metrics to the active W&B run.

### Behavior
- Triggered if `num_train_steps_done % eval_interval == 0`.
- Only executes on `global_rank == 0`.
- Reads `last_checkpoint_info.json` to find the latest checkpoint.
- Checks if `{checkpoint_path}/hf_checkpoint` exists. If it does NOT exist, evaluation is skipped with a warning (assuming conversion failed or was disabled).
- If the HF checkpoint exists, it formats the `olmes_command_template` and launches it asynchronously using `subprocess.Popen(cmd, shell=True)`.
- **Process Tracking**: Stores `(Popen, step, hf_model_dir)` tuples in `self.active_processes`.
- **Graceful Exit**: `wait_for_evaluations()` iterates over `active_processes`, calls `.wait()`, and syncs metrics after each evaluation completes.
- **W&B Metric Sync**: Extracts `primary_score` for each task alias from the OLMES output file and logs them to the active `wandb.run` as `downstream/{alias}` at the correct training step. Gracefully skips if W&B is disabled or not installed.

### Placeholders
The `olmes_command_template` string can use the following placeholders:
- `{hf_model_dir}`: The path to the `{checkpoint_path}/hf_checkpoint` directory.
- `{tasks}`: A space-separated string of the tasks provided in the config.
- `{step}`: The current `num_train_steps_done`.

### YAML Configuration
```yaml
downstream_evaluator:
component_key: downstream_evaluator
variant_key: default
config:
tokenizer:
instance_key: tokenizer
pass_type: BY_REFERENCE
tasks:
- "arc_challenge::olmes"
- "hellaswag::olmes"
eval_interval: 100
checkpoint_dir: ${settings.paths.experiments_root_path}/${settings.experiment_id}
global_rank: ${settings.cuda_env.global_rank}
olmes_command_template: "bash scripts/evaluation/run_olmes_sbatch.sh {hf_model_dir} '{tasks}' {step} 1024 1"
```
9 changes: 9 additions & 0 deletions src/modalities/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -520,6 +520,15 @@ class ParallelDegreeConfig(BaseModel):
parallelism_methods: list[ParallelismDegrees]


class DownstreamEvaluatorConfig(BaseModel):
tokenizer: PydanticTokenizerIFType
tasks: list[str]
eval_interval: Annotated[int, Field(strict=True, gt=0)]
checkpoint_dir: Path
global_rank: Annotated[int, Field(strict=True, ge=0)]
olmes_command_template: str


# Recursive type representing arbitrary-depth YAML config structures.
YAMLPrimitive = str | int | float | bool | None
YAMLValue: TypeAlias = YAMLPrimitive | Path | list["YAMLValue"] | dict[str, "YAMLValue"]
Expand Down
2 changes: 2 additions & 0 deletions src/modalities/config/instantiation_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
PydanticSteppableProfilerIFType,
PydanticTextInferenceComponentType,
PydanticTokenizerIFType,
PydanticDownstreamEvaluatorType,
)
from modalities.config.utils import parse_torch_device
from modalities.dataloader.dataset import Dataset
Expand Down Expand Up @@ -192,6 +193,7 @@ def _check_last_step_checkpointed(self) -> "TrainingComponentsInstantiationModel
mfu_calculator: PydanticMFUCalculatorABCType | None = None
scheduled_pipeline: PydanticPipelineType | None = None
device_mesh: PydanticDeviceMeshIFType | None = None
downstream_evaluator: Optional[PydanticDownstreamEvaluatorType] = None

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
downstream_evaluator: Optional[PydanticDownstreamEvaluatorType] = None
downstream_evaluator: PydanticDownstreamEvaluatorType | None = None

model_raw: PydanticPytorchModuleType

@model_validator(mode="after")
Expand Down
2 changes: 2 additions & 0 deletions src/modalities/config/pydantic_if_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from modalities.logging_broker.subscriber import MessageSubscriberIF
from modalities.loss_functions import Loss
from modalities.models.parallelism.pipeline_parallelism import Pipeline, StagesGenerator
from modalities.evaluator import DownstreamEvaluator
from modalities.nn.model_initialization.initialization_if import ModelInitializationIF
from modalities.tokenization.tokenizer_wrapper import TokenizerWrapper
from modalities.training.gradient_clipping.gradient_clipper import GradientClipperIF
Expand Down Expand Up @@ -98,3 +99,4 @@ def __get_pydantic_core_schema__(
torch.utils.hooks.RemovableHandle, PydanticThirdPartyTypeIF(torch.utils.hooks.RemovableHandle)
]
PydanticDebuggingType = Annotated[Debugging, PydanticThirdPartyTypeIF(Debugging)]
PydanticDownstreamEvaluatorType = Annotated[DownstreamEvaluator, PydanticThirdPartyTypeIF(DownstreamEvaluator)]
140 changes: 140 additions & 0 deletions src/modalities/evaluator.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
from typing import Callable
import json
import logging
import subprocess
from pathlib import Path

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.device_mesh import DeviceMesh

from modalities.tokenization.tokenizer_wrapper import TokenizerWrapper

from modalities.batch import DatasetBatch, EvaluationResultBatch, InferenceResultBatch, ResultItem
from modalities.dataloader.dataloader import LLMDataLoader
from modalities.logging_broker.messages import ExperimentStatus, MessageTypes, ProgressUpdate
Expand All @@ -15,6 +21,8 @@
from modalities.running_env.fsdp.reducer import Reducer
from modalities.util import TimeRecorder

logger = logging.getLogger(__name__)


class Evaluator:
"""Evaluator class which is responsible for evaluating the model on a set of datasets"""
Expand Down Expand Up @@ -197,3 +205,135 @@ def _publish_evaluation_result(
evaluation_result_publisher.publish_message(
payload=evaluation_result, message_type=MessageTypes.EVALUATION_RESULT
)


class DownstreamEvaluator:

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe we should have this implement an EvaluationIF interface

"""Evaluator that runs OLMES on HF checkpoints produced by the conversion callback.

Checks if an ``hf_checkpoint`` folder exists inside the latest checkpoint directory
(as written by ``ModelConverter``). If it does, the configured OLMES command template
is executed via subprocess.
"""

def __init__(
self,
tokenizer: TokenizerWrapper,
tasks: list[str],
eval_interval: int,
checkpoint_dir: Path,
global_rank: int,
olmes_command_template: str,
) -> None:
self.tokenizer = tokenizer
self.tasks = tasks
self.eval_interval = eval_interval
self.checkpoint_dir = Path(checkpoint_dir)
self.global_rank = global_rank
self.olmes_command_template = olmes_command_template
self.active_processes: list[tuple[subprocess.Popen, int, Path]] = []

def evaluate(self, num_train_steps_done: int) -> None:
if num_train_steps_done == 0 or num_train_steps_done % self.eval_interval != 0:
return
if self.global_rank != 0:
return

hf_model_dir = self._find_hf_checkpoint()
if hf_model_dir is None:
logger.warning(
f"No hf_checkpoint found in {self.checkpoint_dir} at step {num_train_steps_done}, "
"skipping downstream evaluation."
)
return

tasks_str = " ".join(self.tasks)
cmd = self.olmes_command_template.format(
hf_model_dir=str(hf_model_dir),
Comment on lines +250 to +251

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think running eval on the same node(s) as the training is unrealistic. Typically especially for larger models GPU memory headroom for evaluation is rather limited.

tasks=tasks_str,
step=num_train_steps_done,
)

logger.info(f"Running downstream evaluation: {cmd}")
try:
p = subprocess.Popen(cmd, shell=True)
self.active_processes.append((p, num_train_steps_done, hf_model_dir))
logger.info(f"Downstream evaluation launched for step {num_train_steps_done}.")
except Exception as e:
logger.error(f"Failed to launch downstream evaluation: {e}")

def wait_for_evaluations(self) -> None:
if not hasattr(self, "active_processes") or not self.active_processes:

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why so complicated? It's an instance variable.

Suggested change
if not hasattr(self, "active_processes") or not self.active_processes:
if len(self.active_processes) == 0:

return

logger.info(f"Waiting for {len(self.active_processes)} downstream evaluations to finish...")
for p, step, hf_model_dir in self.active_processes:
p.wait()
if p.returncode == 0:
self._sync_metrics_to_wandb(step, hf_model_dir)
else:
logger.warning(f"Downstream evaluation for step {step} exited with code {p.returncode}, skipping W&B sync.")
logger.info("All downstream evaluations finished.")
self.active_processes = []

def _sync_metrics_to_wandb(self, step: int, hf_model_dir: Path) -> None:
"""Parse OLMES metrics-all.jsonl and log primary scores to the active W&B run."""
metrics_file = hf_model_dir / f"olmes_eval_{step}" / "metrics-all.jsonl"
if not metrics_file.exists():
logger.warning(f"No metrics file found at {metrics_file}, skipping W&B sync for step {step}.")
return

metrics_dict = {}
try:
with open(metrics_file, "r", encoding="utf-8") as f:
for line in f:
obj = json.loads(line)
alias = (
obj.get("task_config", {}).get("metadata", {}).get("alias")
or obj.get("task_name")
)
score = obj.get("metrics", {}).get("primary_score")
if alias and score is not None:
metrics_dict[f"downstream/{alias}"] = score
except Exception as e:
logger.error(f"Failed to parse metrics file {metrics_file}: {e}")
return

if not metrics_dict:
logger.warning(f"No metrics extracted from {metrics_file} for step {step}.")
return

try:
import wandb

if wandb.run is not None:
# Define a custom step metric so downstream/* metrics are decoupled from
# the global training step counter (which is already past these steps).
wandb.run.define_metric("downstream_step", hidden=True)
wandb.run.define_metric("downstream/*", step_metric="downstream_step")
metrics_dict["downstream_step"] = step
wandb.run.log(metrics_dict)
logger.info(f"Synced {len(metrics_dict)} OLMES metrics to W&B at step {step}: {metrics_dict}")
else:
logger.info(f"W&B not active, skipping metric sync for step {step}.")
except ImportError:
logger.info(f"wandb not installed, skipping metric sync for step {step}.")

def _find_hf_checkpoint(self) -> Path | None:
"""Read last_checkpoint_info.json and check for hf_checkpoint subfolder."""
info_file = self.checkpoint_dir / "last_checkpoint_info.json"
if not info_file.exists():
return None

with open(info_file, "r", encoding="utf-8") as f:
info = json.load(f)

checkpoint_path_str = info.get("checkpoint_folder_path") or info.get("model_checkpoint_path")
if checkpoint_path_str is None:
return None

checkpoint_path = Path(checkpoint_path_str)
if checkpoint_path.is_file():
checkpoint_path = checkpoint_path.parent

hf_dir = checkpoint_path / "hf_checkpoint"
return hf_dir if hf_dir.exists() else None
4 changes: 3 additions & 1 deletion src/modalities/gym.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from datetime import datetime
from functools import partial
from typing import Callable
from typing import Callable, Optional

import torch.nn as nn

Expand Down Expand Up @@ -42,6 +42,7 @@ def run(
evaluation_data_loaders: list[LLMDataLoader],
checkpoint_saving: CheckpointSaving,
scheduled_pipeline: Pipeline | None = None,
downstream_evaluation_callback: Optional[Callable[[int], None]] = None,
):
"""Runs the model training, including evaluation and checkpointing.

Expand Down Expand Up @@ -80,6 +81,7 @@ def run(
checkpointing_callback=checkpointing_callback,
training_log_interval_in_steps=training_log_interval_in_steps,
scheduled_pipeline=scheduled_pipeline,
downstream_evaluation_callback=downstream_evaluation_callback,
)
print_rank_0(f"Training done at {datetime.now()}.")

Expand Down
12 changes: 12 additions & 0 deletions src/modalities/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,10 @@ def run(self, components: TrainingComponentsInstantiationModel):

print_rank_0(report)

downstream_evaluation_callback = None
if getattr(components, "downstream_evaluator", None) is not None:
downstream_evaluation_callback = components.downstream_evaluator.evaluate

gym.run(
train_data_loader=components.train_dataloader,
evaluation_data_loaders=components.eval_dataloaders,
Expand All @@ -229,8 +233,16 @@ def run(self, components: TrainingComponentsInstantiationModel):
evaluation_interval_in_steps=components.settings.intervals.evaluation_interval_in_steps,
training_log_interval_in_steps=components.settings.intervals.training_log_interval_in_steps,
scheduled_pipeline=components.scheduled_pipeline,
downstream_evaluation_callback=downstream_evaluation_callback,
)

if getattr(components, "downstream_evaluator", None) is not None:
print_rank_0("\n" + "="*80)
print_rank_0("Training loop complete! Waiting for background evaluations to finish...")
print_rank_0("="*80 + "\n")
components.downstream_evaluator.wait_for_evaluations()
print_rank_0("All background evaluations completed successfully!")

Comment on lines +239 to +245

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

as written in the overall feedback, I think we should exclude evaluation from the eval loop.

def get_logging_publishers(
self,
progress_subscriber: MessageSubscriberIF[ProgressUpdate],
Expand Down
3 changes: 3 additions & 0 deletions src/modalities/registry/components.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@
TorchCheckpointLoadingConfig,
WandBEvaluationResultSubscriberConfig,
WeightInitializedModelConfig,
DownstreamEvaluatorConfig,
)
from modalities.dataloader.collate_fns.collator_fn_wrapper_for_loss_masking import (
LossMaskingCollateFnWrapper,
Expand All @@ -83,6 +84,7 @@
ResultsSubscriberFactory,
)
from modalities.loss_functions import CLMCrossEntropyLoss
from modalities.evaluator import DownstreamEvaluator
from modalities.models.coca.coca_model import CoCa, CoCaConfig
from modalities.models.coca.collator import CoCaCollateFnConfig, CoCaCollatorFn
from modalities.models.components.layer_norms import (
Expand Down Expand Up @@ -528,4 +530,5 @@ class ComponentEntity:
maybe_model_list(HookRegistration.register_print_forward_hooks),
PrintForwardHookConfig,
),
ComponentEntity("downstream_evaluator", "default", DownstreamEvaluator, DownstreamEvaluatorConfig),
]
5 changes: 5 additions & 0 deletions src/modalities/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,7 @@ def train(
evaluation_callback: Callable[[int], None],
checkpointing_callback: Callable[[TrainingProgress], None],
scheduled_pipeline: Pipeline | None = None,
downstream_evaluation_callback: Optional[Callable[[int], None]] = None,
):
"""
Trains the model.
Expand Down Expand Up @@ -257,6 +258,8 @@ def train(
num_target_tokens=self.num_target_tokens,
)
checkpointing_callback(training_progress=training_progress)
if downstream_evaluation_callback is not None:
downstream_evaluation_callback(num_train_steps_done=self.num_seen_train_steps)

num_steps_todo = self.num_target_steps - self.num_seen_train_steps
num_batches_todo = num_steps_todo * self.gradient_acc_steps
Expand Down Expand Up @@ -388,6 +391,8 @@ def train(
self.gc.run(step_count=training_progress.num_seen_steps_total)
evaluation_callback(num_train_steps_done=training_progress.num_seen_steps_total)
checkpointing_callback(training_progress=training_progress)
if downstream_evaluation_callback is not None:
downstream_evaluation_callback(num_train_steps_done=training_progress.num_seen_steps_total)

profiler_cm.step()

Expand Down
Loading