From bbde3374504a2fb6bf5b420652a74893f6533b9e Mon Sep 17 00:00:00 2001 From: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com> Date: Mon, 13 Apr 2026 01:37:28 +0000 Subject: [PATCH] fix(checkpointer): walk back through previous attempts to find valid checkpoint When a pod is killed before writing a checkpoint, the single-link chain breaks and subsequent retries start from scratch. This change makes SyncCheckpoint accept a list of previous checkpoint paths and try them in reverse order (N-1, N-2, ..., 0) until it finds one with data. On success, the checkpoint is auto-forwarded to the current attempt's dest path so the next retry can find it without walking back again. The entrypoint now computes the full list of previous attempt paths from the deterministic checkpoint path pattern (dn0-{attempt}) and passes them to SyncCheckpoint. Co-Authored-By: unknown <> --- flytekit/bin/entrypoint.py | 67 ++++++++++++++++++++++++++- flytekit/core/checkpointer.py | 86 ++++++++++++++++++++++++++++++++--- 2 files changed, 144 insertions(+), 9 deletions(-) diff --git a/flytekit/bin/entrypoint.py b/flytekit/bin/entrypoint.py index 93cd3e3846..911b4debb9 100644 --- a/flytekit/bin/entrypoint.py +++ b/flytekit/bin/entrypoint.py @@ -3,6 +3,7 @@ import datetime import os import pathlib +import re import signal import subprocess import sys @@ -10,6 +11,7 @@ import textwrap import time import traceback +import typing import uuid import warnings from sys import exit @@ -398,6 +400,66 @@ def get_one_of(*args) -> str: return "" +_CHECKPOINT_ATTEMPT_RE = re.compile(r"(.*)-dn0-(\d+)(/_flytecheckpoints/.*)$") + + +def _build_checkpoint_src_list( + checkpoint_path: str, + prev_checkpoint: typing.Optional[str], +) -> typing.List[str]: + """Build an ordered list of previous checkpoint paths to try during restore. + + Flyte Propeller passes a single ``prev_checkpoint`` pointing at attempt N-1. + If that attempt was killed before writing a checkpoint the chain breaks. + This function generates paths for *all* previous attempts (N-1 … 0) so the + checkpointer can walk back and find the most recent successful one. + + The checkpoint path pattern is: + ``-dn0-{attempt}/_flytecheckpoints/`` + + If we cannot parse the pattern we fall back to ``[prev_checkpoint]``. + + Args: + checkpoint_path: The checkpoint *destination* for the current attempt. + prev_checkpoint: The single previous checkpoint path from Propeller (may be None). + + Returns: + A list of checkpoint source paths ordered most-recent-attempt first. + """ + if not checkpoint_path: + return [prev_checkpoint] if prev_checkpoint else [] + + m = _CHECKPOINT_ATTEMPT_RE.search(checkpoint_path) + if not m: + # Cannot parse — fall back to the single prev_checkpoint from Propeller + return [prev_checkpoint] if prev_checkpoint else [] + + prefix, current_attempt_str, suffix = m.group(1), m.group(2), m.group(3) + current_attempt = int(current_attempt_str) + + if current_attempt == 0: + # First attempt — nothing to walk back to + return [] + + # Build paths N-1, N-2, … 0 using the same prefix structure. + # NOTE: The prefix (including the hash) may differ across attempts because + # Propeller computes it per-attempt. We use the *current* attempt's prefix + # as the best guess — the hash typically stays the same within one execution. + # If Propeller provided prev_checkpoint, keep it first because its prefix is + # guaranteed correct for attempt N-1. + srcs: typing.List[str] = [] + if prev_checkpoint: + srcs.append(prev_checkpoint) + + for attempt in range(current_attempt - 1, -1, -1): + candidate = f"{prefix}-dn0-{attempt}{suffix}" + if candidate not in srcs: + srcs.append(candidate) + + logger.debug(f"Checkpoint source candidates ({len(srcs)}): {srcs}") + return srcs + + @contextlib.contextmanager def setup_execution( raw_output_data_prefix: str, @@ -440,8 +502,9 @@ def setup_execution( checkpointer = None if checkpoint_path is not None: - checkpointer = SyncCheckpoint(checkpoint_dest=checkpoint_path, checkpoint_src=prev_checkpoint) - logger.debug(f"Checkpointer created with source {prev_checkpoint} and dest {checkpoint_path}") + checkpoint_srcs = _build_checkpoint_src_list(checkpoint_path, prev_checkpoint) + checkpointer = SyncCheckpoint(checkpoint_dest=checkpoint_path, checkpoint_src=checkpoint_srcs) + logger.debug(f"Checkpointer created with {len(checkpoint_srcs)} source(s) and dest {checkpoint_path}") execution_parameters = ExecutionParameters( execution_id=_identifier.WorkflowExecutionIdentifier( diff --git a/flytekit/core/checkpointer.py b/flytekit/core/checkpointer.py index d0fdf129e4..7a9aadfc3a 100644 --- a/flytekit/core/checkpointer.py +++ b/flytekit/core/checkpointer.py @@ -1,9 +1,12 @@ import io +import logging import tempfile import typing from abc import abstractmethod from pathlib import Path +logger = logging.getLogger(__name__) + class Checkpoint(object): """ @@ -72,14 +75,27 @@ class SyncCheckpoint(Checkpoint): SRC_LOCAL_FOLDER = "prev_cp" TMP_DST_PATH = "_dst_cp" - def __init__(self, checkpoint_dest: str, checkpoint_src: typing.Optional[str] = None): + def __init__( + self, + checkpoint_dest: str, + checkpoint_src: typing.Optional[typing.Union[str, typing.List[str]]] = None, + ): """ Args: - checkpoint_src: If a previous checkpoint should exist, this path should be set to the folder that contains the checkpoint information - checkpoint_dest: Location where the new checkpoint should be copied to + checkpoint_src: One or more paths to previous checkpoint directories, tried in order. + Accepts a single path string or a list of path strings (most-recent-attempt first). + The first path that contains data wins. + checkpoint_dest: Location where the new checkpoint should be copied to. """ self._checkpoint_dest = checkpoint_dest - self._checkpoint_src = checkpoint_src if checkpoint_src and checkpoint_src != "" else None + if checkpoint_src is None: + self._checkpoint_srcs: typing.List[str] = [] + elif isinstance(checkpoint_src, str): + self._checkpoint_srcs = [checkpoint_src] if checkpoint_src != "" else [] + else: + self._checkpoint_srcs = [s for s in checkpoint_src if s and s != ""] + # Keep for backwards-compat: first candidate (or None) + self._checkpoint_src = self._checkpoint_srcs[0] if self._checkpoint_srcs else None self._td = tempfile.TemporaryDirectory() self._prev_download_path: typing.Optional[Path] = None @@ -87,13 +103,29 @@ def __del__(self): self._td.cleanup() def prev_exists(self) -> bool: - return self._checkpoint_src is not None + return len(self._checkpoint_srcs) > 0 def restore(self, path: typing.Optional[typing.Union[Path, str]] = None) -> typing.Optional[Path]: + """Download a previous checkpoint, walking back through attempts until one succeeds. + + Tries each candidate in ``self._checkpoint_srcs`` (most-recent first). The first + path that contains data is used. On success the checkpoint is also copied to + ``checkpoint_dest`` so the *next* attempt can find it without walking back again. + + Args: + path: Local directory to download into. A temp directory is used when *None*. + + Returns: + The local path where the checkpoint was restored, or *None* if no candidates exist. + + Raises: + ValueError: If *path* is not a directory. + FlyteDataNotFoundException: If none of the candidates contain data. + """ # We have to lazy load, until we fix the imports from flytekit.core.context_manager import FlyteContextManager - if self._checkpoint_src is None or self._checkpoint_src == "": + if not self._checkpoint_srcs: return None if self._prev_download_path: @@ -109,10 +141,50 @@ def restore(self, path: typing.Optional[typing.Union[Path, str]] = None) -> typi if not path.is_dir(): raise ValueError("Checkpoints can be restored to a directory only.") - FlyteContextManager.current_context().file_access.download_directory(self._checkpoint_src, str(path)) + fa = FlyteContextManager.current_context().file_access + last_err: typing.Optional[Exception] = None + + for idx, src in enumerate(self._checkpoint_srcs): + try: + fa.download_directory(src, str(path)) + # Check that the download actually produced files + if any(path.iterdir()): + logger.info(f"Checkpoint restored from candidate {idx}: {src}") + self._prev_download_path = path + self._auto_forward(fa, path) + return self._prev_download_path + # Empty directory — treat as missing and try the next candidate + logger.debug(f"Checkpoint candidate {idx} was empty: {src}") + except Exception as e: + logger.debug(f"Checkpoint candidate {idx} failed ({src}): {e}") + last_err = e + + # None of the candidates worked. Re-raise the last download error if we had one, + # otherwise fall through to the original single-source behaviour so existing + # callers see the same exception they always did. + if last_err is not None: + raise last_err + + # All candidates were empty directories — download from the first source so the + # original behaviour (returning the path) is preserved. + fa.download_directory(self._checkpoint_srcs[0], str(path)) self._prev_download_path = path return self._prev_download_path + def _auto_forward(self, fa: typing.Any, local_path: Path) -> None: + """Copy a successfully restored checkpoint to this attempt's dest path. + + This "auto-forward" ensures that the next retry can always find a valid + checkpoint at attempt N's path even if N is killed before writing its own. + """ + try: + if self._checkpoint_dest: + fa.upload_directory(str(local_path), self._checkpoint_dest) + logger.debug(f"Auto-forwarded checkpoint to {self._checkpoint_dest}") + except Exception: + # Best-effort — don't let a forwarding failure block the restore. + logger.warning("Failed to auto-forward checkpoint to dest", exc_info=True) + def save(self, cp: typing.Union[Path, str, io.BufferedReader]): # We have to lazy load, until we fix the imports from flytekit.core.context_manager import FlyteContextManager