diff --git a/flytekit/bin/entrypoint.py b/flytekit/bin/entrypoint.py index 93cd3e3846..911b4debb9 100644 --- a/flytekit/bin/entrypoint.py +++ b/flytekit/bin/entrypoint.py @@ -3,6 +3,7 @@ import datetime import os import pathlib +import re import signal import subprocess import sys @@ -10,6 +11,7 @@ import textwrap import time import traceback +import typing import uuid import warnings from sys import exit @@ -398,6 +400,66 @@ def get_one_of(*args) -> str: return "" +_CHECKPOINT_ATTEMPT_RE = re.compile(r"(.*)-dn0-(\d+)(/_flytecheckpoints/.*)$") + + +def _build_checkpoint_src_list( + checkpoint_path: str, + prev_checkpoint: typing.Optional[str], +) -> typing.List[str]: + """Build an ordered list of previous checkpoint paths to try during restore. + + Flyte Propeller passes a single ``prev_checkpoint`` pointing at attempt N-1. + If that attempt was killed before writing a checkpoint the chain breaks. + This function generates paths for *all* previous attempts (N-1 … 0) so the + checkpointer can walk back and find the most recent successful one. + + The checkpoint path pattern is: + ``-dn0-{attempt}/_flytecheckpoints/`` + + If we cannot parse the pattern we fall back to ``[prev_checkpoint]``. + + Args: + checkpoint_path: The checkpoint *destination* for the current attempt. + prev_checkpoint: The single previous checkpoint path from Propeller (may be None). + + Returns: + A list of checkpoint source paths ordered most-recent-attempt first. + """ + if not checkpoint_path: + return [prev_checkpoint] if prev_checkpoint else [] + + m = _CHECKPOINT_ATTEMPT_RE.search(checkpoint_path) + if not m: + # Cannot parse — fall back to the single prev_checkpoint from Propeller + return [prev_checkpoint] if prev_checkpoint else [] + + prefix, current_attempt_str, suffix = m.group(1), m.group(2), m.group(3) + current_attempt = int(current_attempt_str) + + if current_attempt == 0: + # First attempt — nothing to walk back to + return [] + + # Build paths N-1, N-2, … 0 using the same prefix structure. + # NOTE: The prefix (including the hash) may differ across attempts because + # Propeller computes it per-attempt. We use the *current* attempt's prefix + # as the best guess — the hash typically stays the same within one execution. + # If Propeller provided prev_checkpoint, keep it first because its prefix is + # guaranteed correct for attempt N-1. + srcs: typing.List[str] = [] + if prev_checkpoint: + srcs.append(prev_checkpoint) + + for attempt in range(current_attempt - 1, -1, -1): + candidate = f"{prefix}-dn0-{attempt}{suffix}" + if candidate not in srcs: + srcs.append(candidate) + + logger.debug(f"Checkpoint source candidates ({len(srcs)}): {srcs}") + return srcs + + @contextlib.contextmanager def setup_execution( raw_output_data_prefix: str, @@ -440,8 +502,9 @@ def setup_execution( checkpointer = None if checkpoint_path is not None: - checkpointer = SyncCheckpoint(checkpoint_dest=checkpoint_path, checkpoint_src=prev_checkpoint) - logger.debug(f"Checkpointer created with source {prev_checkpoint} and dest {checkpoint_path}") + checkpoint_srcs = _build_checkpoint_src_list(checkpoint_path, prev_checkpoint) + checkpointer = SyncCheckpoint(checkpoint_dest=checkpoint_path, checkpoint_src=checkpoint_srcs) + logger.debug(f"Checkpointer created with {len(checkpoint_srcs)} source(s) and dest {checkpoint_path}") execution_parameters = ExecutionParameters( execution_id=_identifier.WorkflowExecutionIdentifier( diff --git a/flytekit/core/checkpointer.py b/flytekit/core/checkpointer.py index d0fdf129e4..7a9aadfc3a 100644 --- a/flytekit/core/checkpointer.py +++ b/flytekit/core/checkpointer.py @@ -1,9 +1,12 @@ import io +import logging import tempfile import typing from abc import abstractmethod from pathlib import Path +logger = logging.getLogger(__name__) + class Checkpoint(object): """ @@ -72,14 +75,27 @@ class SyncCheckpoint(Checkpoint): SRC_LOCAL_FOLDER = "prev_cp" TMP_DST_PATH = "_dst_cp" - def __init__(self, checkpoint_dest: str, checkpoint_src: typing.Optional[str] = None): + def __init__( + self, + checkpoint_dest: str, + checkpoint_src: typing.Optional[typing.Union[str, typing.List[str]]] = None, + ): """ Args: - checkpoint_src: If a previous checkpoint should exist, this path should be set to the folder that contains the checkpoint information - checkpoint_dest: Location where the new checkpoint should be copied to + checkpoint_src: One or more paths to previous checkpoint directories, tried in order. + Accepts a single path string or a list of path strings (most-recent-attempt first). + The first path that contains data wins. + checkpoint_dest: Location where the new checkpoint should be copied to. """ self._checkpoint_dest = checkpoint_dest - self._checkpoint_src = checkpoint_src if checkpoint_src and checkpoint_src != "" else None + if checkpoint_src is None: + self._checkpoint_srcs: typing.List[str] = [] + elif isinstance(checkpoint_src, str): + self._checkpoint_srcs = [checkpoint_src] if checkpoint_src != "" else [] + else: + self._checkpoint_srcs = [s for s in checkpoint_src if s and s != ""] + # Keep for backwards-compat: first candidate (or None) + self._checkpoint_src = self._checkpoint_srcs[0] if self._checkpoint_srcs else None self._td = tempfile.TemporaryDirectory() self._prev_download_path: typing.Optional[Path] = None @@ -87,13 +103,29 @@ def __del__(self): self._td.cleanup() def prev_exists(self) -> bool: - return self._checkpoint_src is not None + return len(self._checkpoint_srcs) > 0 def restore(self, path: typing.Optional[typing.Union[Path, str]] = None) -> typing.Optional[Path]: + """Download a previous checkpoint, walking back through attempts until one succeeds. + + Tries each candidate in ``self._checkpoint_srcs`` (most-recent first). The first + path that contains data is used. On success the checkpoint is also copied to + ``checkpoint_dest`` so the *next* attempt can find it without walking back again. + + Args: + path: Local directory to download into. A temp directory is used when *None*. + + Returns: + The local path where the checkpoint was restored, or *None* if no candidates exist. + + Raises: + ValueError: If *path* is not a directory. + FlyteDataNotFoundException: If none of the candidates contain data. + """ # We have to lazy load, until we fix the imports from flytekit.core.context_manager import FlyteContextManager - if self._checkpoint_src is None or self._checkpoint_src == "": + if not self._checkpoint_srcs: return None if self._prev_download_path: @@ -109,10 +141,50 @@ def restore(self, path: typing.Optional[typing.Union[Path, str]] = None) -> typi if not path.is_dir(): raise ValueError("Checkpoints can be restored to a directory only.") - FlyteContextManager.current_context().file_access.download_directory(self._checkpoint_src, str(path)) + fa = FlyteContextManager.current_context().file_access + last_err: typing.Optional[Exception] = None + + for idx, src in enumerate(self._checkpoint_srcs): + try: + fa.download_directory(src, str(path)) + # Check that the download actually produced files + if any(path.iterdir()): + logger.info(f"Checkpoint restored from candidate {idx}: {src}") + self._prev_download_path = path + self._auto_forward(fa, path) + return self._prev_download_path + # Empty directory — treat as missing and try the next candidate + logger.debug(f"Checkpoint candidate {idx} was empty: {src}") + except Exception as e: + logger.debug(f"Checkpoint candidate {idx} failed ({src}): {e}") + last_err = e + + # None of the candidates worked. Re-raise the last download error if we had one, + # otherwise fall through to the original single-source behaviour so existing + # callers see the same exception they always did. + if last_err is not None: + raise last_err + + # All candidates were empty directories — download from the first source so the + # original behaviour (returning the path) is preserved. + fa.download_directory(self._checkpoint_srcs[0], str(path)) self._prev_download_path = path return self._prev_download_path + def _auto_forward(self, fa: typing.Any, local_path: Path) -> None: + """Copy a successfully restored checkpoint to this attempt's dest path. + + This "auto-forward" ensures that the next retry can always find a valid + checkpoint at attempt N's path even if N is killed before writing its own. + """ + try: + if self._checkpoint_dest: + fa.upload_directory(str(local_path), self._checkpoint_dest) + logger.debug(f"Auto-forwarded checkpoint to {self._checkpoint_dest}") + except Exception: + # Best-effort — don't let a forwarding failure block the restore. + logger.warning("Failed to auto-forward checkpoint to dest", exc_info=True) + def save(self, cp: typing.Union[Path, str, io.BufferedReader]): # We have to lazy load, until we fix the imports from flytekit.core.context_manager import FlyteContextManager