Skip to content

Commit c5192b9

Browse files
feat: add checkpoint resume, diff, prune commands and save discoverability
Add three new CLI subcommands to improve checkpoint UX:
- `crewai checkpoint resume [id]` skips the TUI and resumes directly from the latest or the specified checkpoint
- `crewai checkpoint diff <id1> <id2>` compares two checkpoints, showing changes in metadata, inputs, task status, and outputs
- `crewai checkpoint prune --keep N --older-than Xd` removes old checkpoints from JSON directories or SQLite databases

Also writes a resume hint to stderr after every checkpoint save so users discover the command without needing to know it exists.
1 parent 54391fd commit c5192b9

4 files changed

Lines changed: 758 additions & 1 deletion

File tree

lib/crewai/src/crewai/cli/checkpoint_cli.py

Lines changed: 307 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from __future__ import annotations
44

5-
from datetime import datetime
5+
from datetime import datetime, timedelta, timezone
66
import glob
77
import json
88
import os
@@ -37,6 +37,26 @@
3737
LIMIT 1
3838
"""
3939

40+
# Deletes every checkpoint row whose created_at compares below the bound cutoff.
_DELETE_OLDER_THAN = """
DELETE FROM checkpoints
WHERE created_at < ?
"""

# Deletes every row except the N with the highest rowid (N is the bound parameter).
_DELETE_KEEP_N = """
DELETE FROM checkpoints WHERE rowid NOT IN (
SELECT rowid FROM checkpoints ORDER BY rowid DESC LIMIT ?
)
"""

# Counts all rows in the checkpoints table.
_COUNT_CHECKPOINTS = "SELECT COUNT(*) FROM checkpoints"

# Selects rows whose id matches a LIKE pattern, highest rowid first.
_SELECT_LIKE = """
SELECT id, created_at, json(data)
FROM checkpoints
WHERE id LIKE ?
ORDER BY rowid DESC
"""
59+
4060

4161
_DEFAULT_DIR = "./.checkpoints"
4262
_DEFAULT_DB = "./.checkpoints.db"
@@ -262,6 +282,8 @@ def _info_sqlite_latest(db_path: str) -> dict[str, Any] | None:
262282
def _info_sqlite_id(db_path: str, checkpoint_id: str) -> dict[str, Any] | None:
263283
with sqlite3.connect(db_path) as conn:
264284
row = conn.execute(_SELECT_ONE, (checkpoint_id,)).fetchone()
285+
if not row:
286+
row = conn.execute(_SELECT_LIKE, (f"%{checkpoint_id}%",)).fetchone()
265287
if not row:
266288
return None
267289
cid, created_at, raw = row
@@ -384,3 +406,287 @@ def _print_info(meta: dict[str, Any]) -> None:
384406
if len(desc) > 70:
385407
desc = desc[:67] + "..."
386408
click.echo(f" {i + 1}. [{status}] {desc}")
409+
410+
411+
def _resolve_checkpoint(
    location: str, checkpoint_id: str | None
) -> dict[str, Any] | None:
    """Look up checkpoint metadata at ``location``.

    ``location`` may be a SQLite database, a directory of JSON checkpoints,
    or a single JSON checkpoint file. Without ``checkpoint_id`` the latest
    checkpoint is returned. With it, SQLite lookups are delegated to
    ``_info_sqlite_id``; directory lookups pick the most recently modified
    file whose extracted id contains ``checkpoint_id``.

    Returns:
        The metadata dict, or ``None`` when nothing matches or ``location``
        is none of the supported kinds.
    """
    if _is_sqlite(location):
        return (
            _info_sqlite_id(location, checkpoint_id)
            if checkpoint_id
            else _info_sqlite_latest(location)
        )

    if not os.path.isdir(location):
        # A plain file is read directly; the id is not consulted here.
        return _info_json_file(location) if os.path.isfile(location) else None

    if not checkpoint_id:
        return _info_json_latest(location)

    from crewai.state.provider.json_provider import JsonProvider

    provider: JsonProvider = JsonProvider()
    candidates: list[str] = glob.glob(
        os.path.join(location, "**", "*.json"), recursive=True
    )
    # Substring match on the extracted id; newest modification time wins.
    hits: list[str] = sorted(
        (path for path in candidates if checkpoint_id in provider.extract_id(path)),
        key=os.path.getmtime,
        reverse=True,
    )
    return _info_json_file(hits[0]) if hits else None
436+
437+
438+
def _entity_type_from_meta(meta: dict[str, Any]) -> str:
    """Return ``"flow"`` if any entity in ``meta`` has type ``"flow"``, else ``"crew"``."""
    entities = meta.get("entities", [])
    has_flow = any(entity.get("type") == "flow" for entity in entities)
    return "flow" if has_flow else "crew"
443+
444+
445+
def resume_checkpoint(location: str, checkpoint_id: str | None) -> None:
    """Resume a crew or flow from a checkpoint and print its result.

    The checkpoint is resolved with ``_resolve_checkpoint``; if none is
    found a message is printed and the function returns. Otherwise the
    checkpoint summary is printed, the entity is rebuilt via
    ``from_checkpoint`` and its async kickoff is run to completion.

    Args:
        location: Checkpoint database, directory or file.
        checkpoint_id: Id (or id fragment) to resume; ``None`` means latest.
    """
    import asyncio

    meta: dict[str, Any] | None = _resolve_checkpoint(location, checkpoint_id)
    if meta is None:
        not_found = (
            f"Checkpoint not found: {checkpoint_id}"
            if checkpoint_id
            else f"No checkpoints found in {location}"
        )
        click.echo(not_found)
        return

    # Database-backed checkpoints are addressed as "<db>#<name>".
    if meta.get("db"):
        restore_path: str = f"{meta['db']}#{meta['name']}"
    else:
        restore_path = meta.get("path") or meta.get("source", "")

    click.echo(f"Resuming from: {meta.get('name', restore_path)}")
    _print_info(meta)
    click.echo()

    from crewai.state.checkpoint_config import CheckpointConfig

    config: CheckpointConfig = CheckpointConfig(restore_from=restore_path)
    # Empty/missing inputs are passed on as None.
    inputs: dict[str, Any] | None = meta.get("inputs") or None

    if _entity_type_from_meta(meta) == "flow":
        from crewai.flow.flow import Flow

        pending = Flow.from_checkpoint(config).kickoff_async(inputs=inputs)
    else:
        from crewai.crew import Crew

        pending = Crew.from_checkpoint(config).akickoff(inputs=inputs)

    result = asyncio.run(pending)
    click.echo(f"\nResult: {getattr(result, 'raw', result)}")
482+
483+
484+
def _task_list_from_meta(meta: dict[str, Any]) -> list[dict[str, Any]]:
    """Flatten the tasks of every entity in ``meta`` into one list.

    Each item carries the owning entity's name (``"unnamed"`` when absent)
    plus the task's description, completed flag and output, with defaults
    of ``""``, ``False`` and ``""`` for missing keys. Order follows the
    entities, then the tasks within each entity.
    """
    return [
        {
            "entity": entity.get("name", "unnamed"),
            "description": task.get("description", ""),
            "completed": task.get("completed", False),
            "output": task.get("output", ""),
        }
        for entity in meta.get("entities", [])
        for task in entity.get("tasks", [])
    ]
497+
498+
499+
def diff_checkpoints(location: str, id1: str, id2: str) -> None:
    """Print the differences between two checkpoints.

    Both ids are resolved with ``_resolve_checkpoint``. The report starts
    with a ``---``/``+++`` header and then lists only what differs: the
    ``ts``, ``branch``, ``trigger`` and ``event_count`` metadata fields,
    the inputs, and each task's status and output. Tasks are paired by
    position in the flattened task list.

    Args:
        location: Checkpoint database, directory or file.
        id1: Id of the "before" checkpoint (shown with ``-``).
        id2: Id of the "after" checkpoint (shown with ``+``).

    Returns:
        None. If either checkpoint cannot be resolved, a "not found"
        message is printed and nothing else.
    """

    def _short_desc(task: dict[str, Any]) -> str:
        # Tolerate a null/non-string description instead of failing on slicing.
        return str(task.get("description") or "")[:60]

    def _output_text(task: dict[str, Any]) -> str:
        # Outputs may be null or non-string; normalise before stripping.
        return str(task.get("output") or "").strip()

    def _preview(text: str) -> str:
        # Cap long outputs at 80 characters so the report stays scannable.
        if not text:
            return "(empty)"
        return text[:80] + ("..." if len(text) > 80 else "")

    meta1: dict[str, Any] | None = _resolve_checkpoint(location, id1)
    meta2: dict[str, Any] | None = _resolve_checkpoint(location, id2)

    if meta1 is None:
        click.echo(f"Checkpoint not found: {id1}")
        return
    if meta2 is None:
        click.echo(f"Checkpoint not found: {id2}")
        return

    name1: str = meta1.get("name", id1)
    name2: str = meta2.get("name", id2)

    click.echo(f"--- {name1}")
    click.echo(f"+++ {name2}")
    click.echo()

    fields: list[tuple[str, str]] = [
        ("Time", "ts"),
        ("Branch", "branch"),
        ("Trigger", "trigger"),
        ("Events", "event_count"),
    ]
    for label, field_key in fields:
        before: str = str(meta1.get(field_key, ""))
        after: str = str(meta2.get(field_key, ""))
        if before != after:
            click.echo(f" {label}:")
            click.echo(f" - {before}")
            click.echo(f" + {after}")

    # "inputs" may be present but null; `or {}` keeps the key union below safe.
    inputs1: dict[str, Any] = meta1.get("inputs") or {}
    inputs2: dict[str, Any] = meta2.get("inputs") or {}
    changed_inputs: list[tuple[str, Any, Any]] = [
        (k, inputs1.get(k, ""), inputs2.get(k, ""))
        for k in sorted(inputs1.keys() | inputs2.keys())
        if inputs1.get(k) != inputs2.get(k)
    ]
    if changed_inputs:
        click.echo("\n Inputs:")
        for input_key, old_value, new_value in changed_inputs:
            click.echo(f" {input_key}:")
            click.echo(f" - {old_value}")
            click.echo(f" + {new_value}")

    tasks1: list[dict[str, Any]] = _task_list_from_meta(meta1)
    tasks2: list[dict[str, Any]] = _task_list_from_meta(meta2)

    max_tasks: int = max(len(tasks1), len(tasks2))
    if max_tasks == 0:
        return

    click.echo("\n Tasks:")
    for i in range(max_tasks):
        t1: dict[str, Any] | None = tasks1[i] if i < len(tasks1) else None
        t2: dict[str, Any] | None = tasks2[i] if i < len(tasks2) else None

        if t1 is None:
            desc: str = _short_desc(t2) if t2 else ""
            click.echo(f" + {i + 1}. [new] {desc}")
            continue
        if t2 is None:
            desc = _short_desc(t1)
            click.echo(f" - {i + 1}. [removed] {desc}")
            continue

        desc = _short_desc(t1)
        s1: str = "done" if t1["completed"] else "pending"
        s2: str = "done" if t2["completed"] else "pending"

        if s1 != s2:
            click.echo(f" {i + 1}. {desc}")
            click.echo(f" status: {s1} -> {s2}")

        out1: str = _output_text(t1)
        out2: str = _output_text(t2)
        if out1 != out2:
            # The task header was already printed if the status changed.
            if s1 == s2:
                click.echo(f" {i + 1}. {desc}")
            click.echo(" output:")
            click.echo(f" - {_preview(out1)}")
            click.echo(f" + {_preview(out2)}")
589+
590+
591+
def _parse_duration(value: str) -> timedelta:
    """Convert a duration such as ``"7d"``, ``"24h"`` or ``"30m"`` to a timedelta.

    Surrounding whitespace is ignored. The suffix selects days (``d``),
    hours (``h``) or minutes (``m``).

    Raises:
        click.BadParameter: If ``value`` is not digits followed by one of
            the three unit letters.
    """
    parsed: re.Match[str] | None = re.match(r"^(\d+)([dhm])$", value.strip())
    if parsed is None:
        raise click.BadParameter(
            f"Invalid duration: {value!r}. Use format like '7d', '24h', or '30m'."
        )
    count, suffix = parsed.groups()
    # The pattern only admits d/h/m, so anything not d or h is minutes.
    keyword: str = {"d": "days", "h": "hours"}.get(suffix, "minutes")
    return timedelta(**{keyword: int(count)})
604+
605+
606+
def _prune_json(location: str, keep: int | None, older_than: timedelta | None) -> int:
    """Delete JSON checkpoint files under ``location``.

    A file is selected for deletion when it falls outside the ``keep`` most
    recently modified files, or when its modification time is older than
    ``older_than``. When both limits are given, matching either one is
    enough. Sub-directories left empty afterwards are removed, including
    parents that only become empty once their children are gone;
    ``location`` itself is never removed.

    Args:
        location: Root directory searched recursively for ``*.json`` files.
        keep: Number of newest files to retain, or ``None`` for no limit.
        older_than: Maximum age to retain, or ``None`` for no age limit.

    Returns:
        The number of files actually deleted.
    """
    pattern: str = os.path.join(location, "**", "*.json")

    # Stat each file once so the ordering and the age test use the same value,
    # and skip files that vanish between the glob and the stat.
    mtimes: dict[str, float] = {}
    for path in glob.glob(pattern, recursive=True):
        try:
            mtimes[path] = os.path.getmtime(path)
        except OSError:  # noqa: PERF203
            continue
    if not mtimes:
        return 0

    files: list[str] = sorted(mtimes, key=mtimes.__getitem__, reverse=True)
    to_delete: set[str] = set()

    if keep is not None and len(files) > keep:
        to_delete.update(files[keep:])

    if older_than is not None:
        cutoff: datetime = datetime.now(timezone.utc) - older_than
        to_delete.update(
            path
            for path in files
            if datetime.fromtimestamp(mtimes[path], tz=timezone.utc) < cutoff
        )

    deleted: int = 0
    for path in to_delete:
        try:
            os.remove(path)
            deleted += 1
        except OSError:  # noqa: PERF203
            # Best effort: a file that cannot be removed is simply not counted.
            pass

    # os.walk(topdown=False) lists each directory's children before the children
    # are visited, so its dirnames/filenames are stale by the time a parent is
    # yielded. Attempt rmdir on every sub-directory instead: it only succeeds
    # when the directory is really empty, which lets removal cascade upwards.
    for dirpath, _dirnames, _filenames in os.walk(location, topdown=False):
        if dirpath == location:
            continue
        try:
            os.rmdir(dirpath)
        except OSError:
            pass

    return deleted
644+
645+
646+
def _prune_sqlite(db_path: str, keep: int | None, older_than: timedelta | None) -> int:
    """Delete checkpoint rows from the SQLite database at ``db_path``.

    Rows older than ``older_than`` are deleted first; then, if ``keep`` is
    given, every row except the ``keep`` rows with the highest rowid is
    deleted. Both deletions are committed together.

    Args:
        db_path: Path to the SQLite checkpoint database.
        keep: Number of newest rows to retain, or ``None`` for no limit.
        older_than: Maximum age to retain, or ``None`` for no age limit.

    Returns:
        The total number of rows deleted by both statements.
    """
    from contextlib import closing

    deleted: int = 0
    # sqlite3's own context manager only commits or rolls back; closing()
    # is what actually releases the connection (and the file handle).
    with closing(sqlite3.connect(db_path)) as conn, conn:
        if older_than is not None:
            # NOTE(review): assumes created_at is stored as a UTC string in this
            # same layout so that string comparison is chronological - confirm
            # against the code that writes the table.
            cutoff: str = (datetime.now(timezone.utc) - older_than).strftime(
                "%Y%m%dT%H%M%S"
            )
            deleted += conn.execute(_DELETE_OLDER_THAN, (cutoff,)).rowcount

        if keep is not None:
            deleted += conn.execute(_DELETE_KEEP_N, (keep,)).rowcount
    return deleted
662+
663+
664+
def prune_checkpoints(
    location: str, keep: int | None, older_than: str | None, dry_run: bool = False
) -> None:
    """Prune checkpoints at ``location`` and report the outcome.

    ``location`` may be a SQLite database or a directory of JSON
    checkpoints. At least one of ``keep`` / ``older_than`` must be given;
    otherwise a usage hint is printed and nothing is deleted.

    Args:
        location: Checkpoint database or directory.
        keep: Keep only the N most recent checkpoints.
        older_than: Duration string such as ``"7d"``, ``"24h"`` or ``"30m"``.
        dry_run: When true, only print how many checkpoints exist at
            ``location``; nothing is deleted.

    Raises:
        click.BadParameter: If ``keep`` is negative or ``older_than`` is not
            a valid duration.
    """
    if keep is None and older_than is None:
        click.echo("Specify --keep N and/or --older-than DURATION (e.g. 7d, 24h)")
        return

    # A negative count has no sensible meaning and the two backends would
    # disagree on it (list slicing vs. SQLite's "LIMIT -n" = unlimited).
    if keep is not None and keep < 0:
        raise click.BadParameter(f"Invalid --keep value: {keep}. Must be 0 or greater.")

    duration: timedelta | None = _parse_duration(older_than) if older_than else None

    deleted: int
    if _is_sqlite(location):
        if dry_run:
            from contextlib import closing

            # closing() releases the connection; sqlite3's own context
            # manager would leave it open.
            with closing(sqlite3.connect(location)) as conn:
                total: int = conn.execute(_COUNT_CHECKPOINTS).fetchone()[0]
            click.echo(f"Would prune from {total} checkpoint(s) in {location}")
            return
        deleted = _prune_sqlite(location, keep, duration)
    elif os.path.isdir(location):
        if dry_run:
            files: list[str] = glob.glob(
                os.path.join(location, "**", "*.json"), recursive=True
            )
            click.echo(f"Would prune from {len(files)} checkpoint(s) in {location}")
            return
        deleted = _prune_json(location, keep, duration)
    else:
        click.echo(f"Not a directory or SQLite database: {location}")
        return
    click.echo(f"Pruned {deleted} checkpoint(s) from {location}")

lib/crewai/src/crewai/cli/cli.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -873,5 +873,48 @@ def checkpoint_info(path: str) -> None:
873873
info_checkpoint(_detect_location(path))
874874

875875

876+
@checkpoint.command("resume")
@click.argument("checkpoint_id", required=False, default=None)
@click.pass_context
def checkpoint_resume(ctx: click.Context, checkpoint_id: str | None) -> None:
    """Resume from a checkpoint. Defaults to the most recent."""
    # The implementation is imported when the command runs, not at module import.
    from crewai.cli.checkpoint_cli import resume_checkpoint

    # The checkpoint location is read from the shared click context object.
    resume_checkpoint(location=ctx.obj["location"], checkpoint_id=checkpoint_id)
884+
885+
886+
@checkpoint.command("diff")
@click.argument("id1")
@click.argument("id2")
@click.pass_context
def checkpoint_diff(ctx: click.Context, id1: str, id2: str) -> None:
    """Compare two checkpoints side-by-side."""
    # The implementation is imported when the command runs, not at module import.
    from crewai.cli.checkpoint_cli import diff_checkpoints

    # id1 is reported as the "before" side and id2 as the "after" side.
    diff_checkpoints(location=ctx.obj["location"], id1=id1, id2=id2)
895+
896+
897+
@checkpoint.command("prune")
@click.option(
    "--keep",
    type=int,
    default=None,
    help="Keep the N most recent checkpoints.",
)
@click.option(
    "--older-than",
    default=None,
    help="Remove checkpoints older than duration (e.g. 7d, 24h, 30m).",
)
@click.option(
    "--dry-run",
    is_flag=True,
    help="Show what would be pruned without deleting.",
)
@click.pass_context
def checkpoint_prune(
    ctx: click.Context, keep: int | None, older_than: str | None, dry_run: bool
) -> None:
    """Remove old checkpoints."""
    # The implementation is imported when the command runs, not at module import.
    from crewai.cli.checkpoint_cli import prune_checkpoints

    # older_than is forwarded as the raw string; prune_checkpoints parses it.
    prune_checkpoints(
        location=ctx.obj["location"],
        keep=keep,
        older_than=older_than,
        dry_run=dry_run,
    )
917+
918+
876919
if __name__ == "__main__":
877920
crewai()

lib/crewai/src/crewai/state/checkpoint_listener.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,12 @@ def _do_checkpoint(
120120
)
121121
state._chain_lineage(cfg.provider, location)
122122

123+
checkpoint_id: str = cfg.provider.extract_id(location)
124+
msg: str = (
125+
f"Checkpoint saved. Resume with: crewai checkpoint resume {checkpoint_id}"
126+
)
127+
logger.info(msg)
128+
123129
if cfg.max_checkpoints is not None:
124130
cfg.provider.prune(cfg.location, cfg.max_checkpoints, branch=state._branch)
125131

0 commit comments

Comments
 (0)