|
6 | 6 | ) |
7 | 7 |
|
8 | 8 | import logging |
| 9 | +import traceback |
9 | 10 | from dataclasses import ( |
10 | 11 | dataclass, |
11 | 12 | ) |
@@ -110,6 +111,12 @@ def resolve_full_validation_start_step( |
def parse_validation_metric(metric: str) -> tuple[str, str]:
    """Parse the configured full validation metric.

    Normalizes *metric* and returns it together with its lookup key
    from METRIC_KEY_MAP.

    Raises:
        ValueError: if the normalized metric is not a supported key.
    """
    normalized = normalize_full_validation_metric(metric)
    # Fast path: a supported metric maps straight to its key.
    if normalized in METRIC_KEY_MAP:
        return normalized, METRIC_KEY_MAP[normalized]
    supported_metrics = ", ".join(item.upper() for item in METRIC_KEY_MAP)
    raise ValueError(
        "validating.validation_metric must be one of "
        f"{supported_metrics}, got {metric!r}."
    )
114 | 121 |
|
115 | 122 |
|
@@ -255,7 +262,7 @@ def __init__( |
255 | 262 | self.enabled = ( |
256 | 263 | self.full_validation |
257 | 264 | and self.start_step is not None |
258 | | - and self.start_step < num_steps |
| 265 | + and self.start_step <= num_steps |
259 | 266 | ) |
260 | 267 | self.step_column_width = max(len("step"), len(str(num_steps))) |
261 | 268 | self._write_mode = "a" if restart_training else "w" |
@@ -308,21 +315,60 @@ def run( |
308 | 315 | dist.barrier() |
309 | 316 |
|
310 | 317 | result: FullValidationResult | None = None |
| 318 | + caught_exception: Exception | None = None |
| 319 | + error_message = None |
311 | 320 | save_path = [None] |
312 | 321 | if self.rank == 0: |
313 | | - result = self._evaluate(display_step) |
314 | | - save_path[0] = result.saved_best_path |
| 322 | + try: |
| 323 | + result = self._evaluate(display_step) |
| 324 | + save_path[0] = result.saved_best_path |
| 325 | + except Exception as exc: |
| 326 | + caught_exception = exc |
| 327 | + error_message = ( |
| 328 | + "Full validation failed on rank 0 during evaluation:\n" |
| 329 | + f"{traceback.format_exc()}" |
| 330 | + ) |
| 331 | + |
| 332 | + self._raise_if_distributed_error(error_message, caught_exception) |
315 | 333 |
|
316 | 334 | if self.is_distributed: |
317 | 335 | dist.broadcast_object_list(save_path, src=0) |
318 | 336 |
|
319 | 337 | if save_path[0] is not None: |
320 | | - save_checkpoint(Path(save_path[0]), lr=lr, step=step_id) |
321 | | - if self.rank == 0: |
322 | | - self._prune_best_checkpoints(keep_names={Path(save_path[0]).name}) |
| 338 | + try: |
| 339 | + if not self.is_distributed or self.zero_stage == 0: |
| 340 | + if self.rank == 0: |
| 341 | + save_checkpoint(Path(save_path[0]), lr=lr, step=step_id) |
| 342 | + else: |
| 343 | + save_checkpoint(Path(save_path[0]), lr=lr, step=step_id) |
| 344 | + if self.rank == 0: |
| 345 | + self._prune_best_checkpoints(keep_names={Path(save_path[0]).name}) |
| 346 | + except Exception as exc: |
| 347 | + caught_exception = exc |
| 348 | + error_message = ( |
| 349 | + "Full validation failed while saving the best checkpoint:\n" |
| 350 | + f"{traceback.format_exc()}" |
| 351 | + ) |
| 352 | + else: |
| 353 | + error_message = None |
| 354 | + caught_exception = None |
| 355 | + |
| 356 | + self._raise_if_distributed_error(error_message, caught_exception) |
323 | 357 |
|
324 | 358 | if self.rank == 0: |
325 | | - self._log_result(result) |
| 359 | + try: |
| 360 | + self._log_result(result) |
| 361 | + except Exception as exc: |
| 362 | + caught_exception = exc |
| 363 | + error_message = ( |
| 364 | + "Full validation failed while writing logs:\n" |
| 365 | + f"{traceback.format_exc()}" |
| 366 | + ) |
| 367 | + else: |
| 368 | + error_message = None |
| 369 | + caught_exception = None |
| 370 | + |
| 371 | + self._raise_if_distributed_error(error_message, caught_exception) |
326 | 372 |
|
327 | 373 | if self.is_distributed: |
328 | 374 | dist.barrier() |
@@ -367,8 +413,12 @@ def evaluate_all_systems(self) -> dict[str, float]: |
367 | 413 |
|
368 | 414 | system_metrics = [] |
369 | 415 | for dataset in self.validation_data.systems: |
370 | | - assert isinstance(dataset, DeepmdDataSetForLoader) |
371 | | - system_metrics.append(self._evaluate_system(dataset._data_system)) |
| 416 | + if not isinstance(dataset, DeepmdDataSetForLoader): |
| 417 | + raise TypeError( |
| 418 | + "Full validation expects each dataset in validation_data.systems " |
| 419 | + f"to be DeepmdDataSetForLoader, got {type(dataset)!r}." |
| 420 | + ) |
| 421 | + system_metrics.append(self._evaluate_system(dataset.data_system)) |
372 | 422 |
|
373 | 423 | aggregated = weighted_average([metric for metric in system_metrics if metric]) |
374 | 424 | return { |
@@ -555,6 +605,25 @@ def _initialize_best_checkpoints(self, restart_training: bool) -> None: |
555 | 605 | else: |
556 | 606 | self._prune_best_checkpoints() |
557 | 607 |
|
| 608 | + def _raise_if_distributed_error( |
| 609 | + self, |
| 610 | + local_error_message: str | None, |
| 611 | + local_exception: Exception | None = None, |
| 612 | + ) -> None: |
| 613 | + """Propagate a local error to all ranks and raise consistently.""" |
| 614 | + error_message = local_error_message |
| 615 | + if self.is_distributed: |
| 616 | + gathered_errors = [None] * dist.get_world_size() |
| 617 | + dist.all_gather_object(gathered_errors, local_error_message) |
| 618 | + error_message = next( |
| 619 | + (message for message in gathered_errors if message is not None), None |
| 620 | + ) |
| 621 | + if error_message is None: |
| 622 | + return |
| 623 | + if local_exception is not None: |
| 624 | + raise RuntimeError(error_message) from local_exception |
| 625 | + raise RuntimeError(error_message) |
| 626 | + |
558 | 627 | def _log_result(self, result: FullValidationResult | None) -> None: |
559 | 628 | """Log and persist full validation results on rank 0.""" |
560 | 629 | assert result is not None |
|
0 commit comments