Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 65 additions & 2 deletions flytekit/bin/entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@
import datetime
import os
import pathlib
import re
import signal
import subprocess
import sys
import tempfile
import textwrap
import time
import traceback
import typing
import uuid
import warnings
from sys import exit
Expand Down Expand Up @@ -398,6 +400,66 @@ def get_one_of(*args) -> str:
return ""


_CHECKPOINT_ATTEMPT_RE = re.compile(r"(.*)-dn0-(\d+)(/_flytecheckpoints/.*)$")


def _build_checkpoint_src_list(
checkpoint_path: str,
prev_checkpoint: typing.Optional[str],
) -> typing.List[str]:
"""Build an ordered list of previous checkpoint paths to try during restore.

Flyte Propeller passes a single ``prev_checkpoint`` pointing at attempt N-1.
If that attempt was killed before writing a checkpoint the chain breaks.
This function generates paths for *all* previous attempts (N-1 … 0) so the
checkpointer can walk back and find the most recent successful one.

The checkpoint path pattern is:
``<prefix>-dn0-{attempt}/_flytecheckpoints/``

If we cannot parse the pattern we fall back to ``[prev_checkpoint]``.

Args:
checkpoint_path: The checkpoint *destination* for the current attempt.
prev_checkpoint: The single previous checkpoint path from Propeller (may be None).

Returns:
A list of checkpoint source paths ordered most-recent-attempt first.
"""
if not checkpoint_path:
return [prev_checkpoint] if prev_checkpoint else []

m = _CHECKPOINT_ATTEMPT_RE.search(checkpoint_path)
if not m:
# Cannot parse — fall back to the single prev_checkpoint from Propeller
return [prev_checkpoint] if prev_checkpoint else []

prefix, current_attempt_str, suffix = m.group(1), m.group(2), m.group(3)
current_attempt = int(current_attempt_str)

if current_attempt == 0:
# First attempt — nothing to walk back to
return []

# Build paths N-1, N-2, … 0 using the same prefix structure.
# NOTE: The prefix (including the hash) may differ across attempts because
# Propeller computes it per-attempt. We use the *current* attempt's prefix
# as the best guess — the hash typically stays the same within one execution.
# If Propeller provided prev_checkpoint, keep it first because its prefix is
# guaranteed correct for attempt N-1.
srcs: typing.List[str] = []
if prev_checkpoint:
srcs.append(prev_checkpoint)

for attempt in range(current_attempt - 1, -1, -1):
candidate = f"{prefix}-dn0-{attempt}{suffix}"
if candidate not in srcs:
srcs.append(candidate)

logger.debug(f"Checkpoint source candidates ({len(srcs)}): {srcs}")
return srcs


@contextlib.contextmanager
def setup_execution(
raw_output_data_prefix: str,
Expand Down Expand Up @@ -440,8 +502,9 @@ def setup_execution(

checkpointer = None
if checkpoint_path is not None:
checkpointer = SyncCheckpoint(checkpoint_dest=checkpoint_path, checkpoint_src=prev_checkpoint)
logger.debug(f"Checkpointer created with source {prev_checkpoint} and dest {checkpoint_path}")
checkpoint_srcs = _build_checkpoint_src_list(checkpoint_path, prev_checkpoint)
checkpointer = SyncCheckpoint(checkpoint_dest=checkpoint_path, checkpoint_src=checkpoint_srcs)
logger.debug(f"Checkpointer created with {len(checkpoint_srcs)} source(s) and dest {checkpoint_path}")

execution_parameters = ExecutionParameters(
execution_id=_identifier.WorkflowExecutionIdentifier(
Expand Down
86 changes: 79 additions & 7 deletions flytekit/core/checkpointer.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import io
import logging
import tempfile
import typing
from abc import abstractmethod
from pathlib import Path

logger = logging.getLogger(__name__)


class Checkpoint(object):
"""
Expand Down Expand Up @@ -72,28 +75,57 @@ class SyncCheckpoint(Checkpoint):
SRC_LOCAL_FOLDER = "prev_cp"
TMP_DST_PATH = "_dst_cp"

def __init__(self, checkpoint_dest: str, checkpoint_src: typing.Optional[str] = None):
def __init__(
self,
checkpoint_dest: str,
checkpoint_src: typing.Optional[typing.Union[str, typing.List[str]]] = None,
):
"""
Args:
checkpoint_src: If a previous checkpoint should exist, this path should be set to the folder that contains the checkpoint information
checkpoint_dest: Location where the new checkpoint should be copied to
checkpoint_src: One or more paths to previous checkpoint directories, tried in order.
Accepts a single path string or a list of path strings (most-recent-attempt first).
The first path that contains data wins.
checkpoint_dest: Location where the new checkpoint should be copied to.
"""
self._checkpoint_dest = checkpoint_dest
self._checkpoint_src = checkpoint_src if checkpoint_src and checkpoint_src != "" else None
if checkpoint_src is None:
self._checkpoint_srcs: typing.List[str] = []
elif isinstance(checkpoint_src, str):
self._checkpoint_srcs = [checkpoint_src] if checkpoint_src != "" else []
else:
self._checkpoint_srcs = [s for s in checkpoint_src if s and s != ""]
# Keep for backwards-compat: first candidate (or None)
self._checkpoint_src = self._checkpoint_srcs[0] if self._checkpoint_srcs else None
self._td = tempfile.TemporaryDirectory()
self._prev_download_path: typing.Optional[Path] = None

def __del__(self):
self._td.cleanup()

def prev_exists(self) -> bool:
return self._checkpoint_src is not None
return len(self._checkpoint_srcs) > 0

def restore(self, path: typing.Optional[typing.Union[Path, str]] = None) -> typing.Optional[Path]:
"""Download a previous checkpoint, walking back through attempts until one succeeds.

Tries each candidate in ``self._checkpoint_srcs`` (most-recent first). The first
path that contains data is used. On success the checkpoint is also copied to
``checkpoint_dest`` so the *next* attempt can find it without walking back again.

Args:
path: Local directory to download into. A temp directory is used when *None*.

Returns:
The local path where the checkpoint was restored, or *None* if no candidates exist.

Raises:
ValueError: If *path* is not a directory.
FlyteDataNotFoundException: If none of the candidates contain data.
"""
# We have to lazy load, until we fix the imports
from flytekit.core.context_manager import FlyteContextManager

if self._checkpoint_src is None or self._checkpoint_src == "":
if not self._checkpoint_srcs:
return None

if self._prev_download_path:
Expand All @@ -109,10 +141,50 @@ def restore(self, path: typing.Optional[typing.Union[Path, str]] = None) -> typi
if not path.is_dir():
raise ValueError("Checkpoints can be restored to a directory only.")

FlyteContextManager.current_context().file_access.download_directory(self._checkpoint_src, str(path))
fa = FlyteContextManager.current_context().file_access
last_err: typing.Optional[Exception] = None

for idx, src in enumerate(self._checkpoint_srcs):
try:
fa.download_directory(src, str(path))
# Check that the download actually produced files
if any(path.iterdir()):
logger.info(f"Checkpoint restored from candidate {idx}: {src}")
self._prev_download_path = path
self._auto_forward(fa, path)
return self._prev_download_path
# Empty directory — treat as missing and try the next candidate
logger.debug(f"Checkpoint candidate {idx} was empty: {src}")
except Exception as e:
logger.debug(f"Checkpoint candidate {idx} failed ({src}): {e}")
last_err = e

# None of the candidates worked. Re-raise the last download error if we had one,
# otherwise fall through to the original single-source behaviour so existing
# callers see the same exception they always did.
if last_err is not None:
raise last_err

# All candidates were empty directories — download from the first source so the
# original behaviour (returning the path) is preserved.
fa.download_directory(self._checkpoint_srcs[0], str(path))
self._prev_download_path = path
return self._prev_download_path

def _auto_forward(self, fa: typing.Any, local_path: Path) -> None:
"""Copy a successfully restored checkpoint to this attempt's dest path.

This "auto-forward" ensures that the next retry can always find a valid
checkpoint at attempt N's path even if N is killed before writing its own.
"""
try:
if self._checkpoint_dest:
fa.upload_directory(str(local_path), self._checkpoint_dest)
logger.debug(f"Auto-forwarded checkpoint to {self._checkpoint_dest}")
except Exception:
# Best-effort — don't let a forwarding failure block the restore.
logger.warning("Failed to auto-forward checkpoint to dest", exc_info=True)

def save(self, cp: typing.Union[Path, str, io.BufferedReader]):
# We have to lazy load, until we fix the imports
from flytekit.core.context_manager import FlyteContextManager
Expand Down