Skip to content

Commit 2209b68

Browse files
committed
RHOAIENG-57679: add checkpointing guided notebook
1 parent 0de8ae3 commit 2209b68

2 files changed

Lines changed: 541 additions & 0 deletions

File tree

Lines changed: 379 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,379 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"metadata": {},
6+
"source": [
7+
"# Ray checkpointing example\n",
8+
"\n",
9+
"This notebook runs a **Ray Train** checkpointing demo on **Red Hat OpenShift AI** using the CodeFlare SDK:\n",
10+
"\n",
11+
"- **Red Hat build of Kueue** must be configured in your cluster (ResourceFlavor → ClusterQueue → **LocalQueue** in your namespace, and `kueue.openshift.io/managed=true` on the namespace). See the OpenShift *AI workloads* documentation for [Red Hat build of Kueue](https://docs.redhat.com/en/documentation/openshift_container_platform/4.21/html/ai_workloads/red-hat-build-of-kueue).\n",
12+
"- Submit a **RayJob** with a **managed Ray cluster** (`ManagedClusterConfig`) so KubeRay lifecycles the cluster with the job (`shutdownAfterJobFinishes`). The RayJob is labeled for your **LocalQueue** via `local_queue` (example: `\"default\"`).\n",
13+
"- Configure **AWS credentials** for the S3 bucket used by Ray Train checkpoints.\n",
14+
"- **Monitor** training in the Ray dashboard (**Jobs** tab), then **suspend and resume** the RayJob (`job.stop()` / `job.resubmit()`) to verify training **resumes from S3** after a simulated interruption.\n",
15+
"\n",
16+
"Training script: `train_with_checkpoints.py` in this directory (same source as the CodeFlare SDK guided demo)."
17+
]
18+
},
19+
{
20+
"cell_type": "markdown",
21+
"metadata": {},
22+
"source": [
23+
"## Import required libraries"
24+
]
25+
},
26+
{
27+
"cell_type": "code",
28+
"execution_count": null,
29+
"metadata": {},
30+
"outputs": [],
31+
"source": [
32+
"from codeflare_sdk import RayJob, ManagedClusterConfig, set_api_client, get_cluster\n",
33+
"from kube_authkit import AuthConfig, get_k8s_client\n",
34+
"import time"
35+
]
36+
},
37+
{
38+
"cell_type": "markdown",
39+
"metadata": {},
40+
"source": [
41+
"## Authenticate to your OpenShift cluster"
42+
]
43+
},
44+
{
45+
"cell_type": "code",
46+
"execution_count": null,
47+
"metadata": {},
48+
"outputs": [],
49+
"source": [
50+
"import urllib3\n",
51+
"\n",
52+
"urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)\n",
53+
"\n",
54+
"# Authenticate to your Kubernetes/OpenShift cluster using kube-authkit\n",
55+
"\n",
56+
"# Option 1: Auto-detect credentials (kubeconfig or in-cluster service account)\n",
57+
"# NOTE: In RHOAI Workbenches the workbench service account may not have Ray RBAC\n",
58+
"# permissions. Use Option 2 (token) unless your admin has granted SA permissions\n",
59+
"# (see RHOAIENG-46748). Auto-detect works if you have a local kubeconfig.\n",
60+
"# auth_config = AuthConfig(method=\"auto\")\n",
61+
"\n",
62+
"# Option 2 (Recommended for RHOAI Workbenches): Token-based authentication\n",
63+
"# Get your token with: oc whoami -t (or from the OpenShift console → Copy login command)\n",
64+
"auth_config = AuthConfig(\n",
65+
" method=\"openshift\",\n",
66+
" k8s_api_host=\"https://api.example.com:6443\",\n",
67+
" token=\"sha256~XXXXX\", # oc whoami -t\n",
68+
")\n",
69+
"\n",
70+
"# Option 3: OIDC authentication (for BYOIDC-enabled clusters)\n",
71+
"# auth_config = AuthConfig(\n",
72+
"# method=\"oidc\",\n",
73+
"# k8s_api_host=\"https://api.example.com:6443\",\n",
74+
"# oidc_issuer=\"https://your-oidc-provider.com\",\n",
75+
"# client_id=\"your-client-id\",\n",
76+
"# use_device_flow=True, # Interactive device flow for notebook environments\n",
77+
"# )\n",
78+
"\n",
79+
"api_client = get_k8s_client(config=auth_config)\n",
80+
"# Set to False for self-signed / dev API certificates (optional).\n",
81+
"api_client.configuration.verify_ssl = False\n",
82+
"set_api_client(api_client)\n",
83+
"\n",
84+
"NAMESPACE = \"your-namespace\" # Data Science Project where LocalQueue + RayJob run\n",
85+
"JOB_NAME = \"checkpointing-job\"\n",
86+
"# Must match metadata.name of a LocalQueue in NAMESPACE (create per OpenShift Kueue docs).\n",
87+
"LOCAL_QUEUE = \"default\""
88+
]
89+
},
90+
{
91+
"cell_type": "markdown",
92+
"metadata": {},
93+
"source": [
94+
"## Red Hat build of Kueue (required before submit)\n",
95+
"\n",
96+
"Configure **ResourceFlavor** → **ClusterQueue** → **LocalQueue** in your project namespace, and label the namespace so Kueue manages workloads there. Official OpenShift 4.21 *AI workloads* docs:\n",
97+
"\n",
98+
"- [Configuring a resource flavor](https://docs.redhat.com/en/documentation/openshift_container_platform/4.21/html/ai_workloads/red-hat-build-of-kueue#configuring-resourceflavors_configuring-quotas)\n",
99+
"- [Configuring a cluster queue](https://docs.redhat.com/en/documentation/openshift_container_platform/4.21/html/ai_workloads/red-hat-build-of-kueue#configuring-clusterqueues_configuring-quotas)\n",
100+
"- [Configuring a local queue](https://docs.redhat.com/en/documentation/openshift_container_platform/4.21/html/ai_workloads/red-hat-build-of-kueue#configuring-localqueues_configuring-quotas) — the `LocalQueue` `metadata.name` must match **`LOCAL_QUEUE`** above (e.g. `default`).\n",
101+
"- [Labeling namespaces to allow Red Hat build of Kueue to manage jobs](https://docs.redhat.com/en/documentation/openshift_container_platform/4.21/html/ai_workloads/red-hat-build-of-kueue#labeling-namespaces-to-allow-red-hat-build-of-kueue-to-manage-jobs_managing-jobs-and-workloads): `oc label namespace <namespace> kueue.openshift.io/managed=true`\n",
102+
"\n",
103+
"After submit, the RayJob may show **Suspended** until Kueue **admits** a `Workload` — use `oc get workloads.kueue.x-k8s.io -n $NAMESPACE` if needed. That is **not** the same as the manual suspend used later for the checkpoint demo."
104+
]
105+
},
106+
{
107+
"cell_type": "code",
108+
"execution_count": null,
109+
"metadata": {},
110+
"outputs": [],
111+
"source": [
112+
"# Optional: verify Kueue objects exist (cluster admin / user with read access)\n",
113+
"# !oc get resourceflavor.kueue.x-k8s.io\n",
114+
"# !oc get clusterqueue.kueue.x-k8s.io\n",
115+
"# !oc get localqueue.kueue.x-k8s.io -n $NAMESPACE\n",
116+
"\n",
117+
"print(f\"Namespace: {NAMESPACE!r}, RayJob name: {JOB_NAME!r}, LocalQueue: {LOCAL_QUEUE!r}\")"
118+
]
119+
},
120+
{
121+
"cell_type": "markdown",
122+
"metadata": {},
123+
"source": [
124+
"## Set your AWS credentials"
125+
]
126+
},
127+
{
128+
"cell_type": "code",
129+
"execution_count": null,
130+
"metadata": {},
131+
"outputs": [],
132+
"source": [
133+
"# Set your AWS credentials\n",
134+
"# WARNING: Do not commit credentials to version control. For production,\n",
135+
"# use OpenShift AI Data Connections or OpenShift Secrets instead.\n",
136+
"AWS_CREDENTIALS = {\n",
137+
" \"AWS_ACCESS_KEY_ID\": \"your-access-key\",\n",
138+
" \"AWS_SECRET_ACCESS_KEY\": \"your-secret-key\",\n",
139+
" \"AWS_DEFAULT_REGION\": \"us-east-1\", # e.g. \"us-east-1\"\n",
140+
" \"AWS_S3_BUCKET\": \"your-bucket-name\",\n",
141+
"}\n",
142+
"\n",
143+
"# If using temporary credentials (SSO/federated), add the session token:\n",
144+
"# AWS_CREDENTIALS[\"AWS_SESSION_TOKEN\"] = \"your-session-token\"\n",
145+
"\n",
146+
"print(f\"Using bucket: {AWS_CREDENTIALS['AWS_S3_BUCKET']}\")"
147+
]
148+
},
149+
{
150+
"cell_type": "markdown",
151+
"metadata": {},
152+
"source": [
153+
"## Submit RayJob (managed Ray cluster + Kueue local queue)"
154+
]
155+
},
156+
{
157+
"cell_type": "code",
158+
"execution_count": null,
159+
"metadata": {},
160+
"outputs": [],
161+
"source": [
162+
"managed = ManagedClusterConfig(\n",
163+
" num_workers=2,\n",
164+
" head_cpu_requests=2,\n",
165+
" head_cpu_limits=4,\n",
166+
" head_memory_requests=4,\n",
167+
" head_memory_limits=8,\n",
168+
" worker_cpu_requests=2,\n",
169+
" worker_cpu_limits=4,\n",
170+
" worker_memory_requests=4,\n",
171+
" worker_memory_limits=8,\n",
172+
")\n",
173+
"\n",
174+
"job = RayJob(\n",
175+
" job_name=JOB_NAME,\n",
176+
" entrypoint=\"python train_with_checkpoints.py\",\n",
177+
" cluster_config=managed,\n",
178+
" namespace=NAMESPACE,\n",
179+
" local_queue=LOCAL_QUEUE,\n",
180+
" runtime_env={\n",
181+
" \"working_dir\": \"./\",\n",
182+
" \"pip\": [\"torch\", \"torchvision\", \"s3fs\", \"pyarrow\"],\n",
183+
" \"env_vars\": {\n",
184+
" **AWS_CREDENTIALS,\n",
185+
" \"RAY_TRAIN_WORKER_GROUP_START_TIMEOUT_S\": \"120\", # Allow time for worker scheduling\n",
186+
" },\n",
187+
" },\n",
188+
")\n",
189+
"\n",
190+
"job.submit()\n",
191+
"print(\n",
192+
" \"RayJob submitted. If status stays Suspended briefly, Kueue may still be admitting the Workload.\"\n",
193+
")\n",
194+
"print(f\"RayCluster name (when assigned): {job.cluster_name}\")\n",
195+
"print(\"Watch logs for: NO CHECKPOINT FOUND - Starting fresh\")"
196+
]
197+
},
198+
{
199+
"cell_type": "markdown",
200+
"metadata": {},
201+
"source": [
202+
"## Monitor job progress (status + Ray dashboard Jobs)\n",
203+
"\n",
204+
"Poll `job.status()`, then open the **Ray dashboard** URL from the RayCluster created by your RayJob. Use **Jobs** in the dashboard for live driver logs (epochs, checkpoint messages)."
205+
]
206+
},
207+
{
208+
"cell_type": "code",
209+
"execution_count": null,
210+
"metadata": {},
211+
"outputs": [],
212+
"source": [
213+
"print(job.status())\n",
214+
"# KubeRay assigns a generated name — not job.cluster_name (the template).\n",
215+
"cluster = None\n",
216+
"for _ in range(36):\n",
217+
" status_data = job._api.get_job_status(name=job.name, k8s_namespace=NAMESPACE)\n",
218+
" ray_cluster_name = (status_data or {}).get(\"rayClusterName\")\n",
219+
" if ray_cluster_name:\n",
220+
" cluster = get_cluster(ray_cluster_name, namespace=NAMESPACE, verify_tls=False)\n",
221+
" if cluster is not None:\n",
222+
" break\n",
223+
" time.sleep(5)\n",
224+
"if cluster is None:\n",
225+
" raise RuntimeError(\n",
226+
" \"RayCluster not ready — check RayJob / Workload admission and operator logs.\"\n",
227+
" )\n",
228+
"print(f\"Ray Dashboard (open in browser): {cluster.cluster_dashboard_uri()}\")\n",
229+
"print(\"In the dashboard, open Jobs and stream logs for the training driver.\")\n",
230+
"print(\n",
231+
" \"Wait for at least one full epoch and a checkpoint to S3 before running the suspend cell.\"\n",
232+
")"
233+
]
234+
},
235+
{
236+
"cell_type": "markdown",
237+
"metadata": {},
238+
"source": [
239+
"## Suspend RayJob (checkpoint demo)\n",
240+
"\n",
241+
"After logs show at least **one epoch** and a checkpoint written to **S3**, suspend the RayJob. This is a **manual** suspend for the demo (distinct from Kueue holding the job until admission right after submit).\n",
242+
"\n",
243+
"Use **Pause** in the OpenShift AI UI, or run the next cell (`job.stop()`)."
244+
]
245+
},
246+
{
247+
"cell_type": "code",
248+
"execution_count": null,
249+
"metadata": {},
250+
"outputs": [],
251+
"source": [
252+
"print(\"=\" * 60)\n",
253+
"print(\"SUSPENDING RayJob (checkpoint demo — not deleting the RayJob CR)\")\n",
254+
"print(\"Checkpoints remain in S3.\")\n",
255+
"print(\"=\" * 60)\n",
256+
"\n",
257+
"job.stop()\n",
258+
"print(\"Stop requested; poll job.status() until the RayJob reports suspended / non-running.\")"
259+
]
260+
},
261+
{
262+
"cell_type": "markdown",
263+
"metadata": {},
264+
"source": [
265+
"## Resume RayJob\n",
266+
"\n",
267+
"Use **Resume** in the OpenShift AI UI, or run `job.resubmit()` in the next cell. When the RayCluster is back, confirm in the dashboard **Jobs** view: `RESUMING FROM CHECKPOINT - Starting at epoch N`."
268+
]
269+
},
270+
{
271+
"cell_type": "code",
272+
"execution_count": null,
273+
"metadata": {},
274+
"outputs": [],
275+
"source": [
276+
"print(\"=\" * 60)\n",
277+
"print(\"RESUMING RayJob after suspend\")\n",
278+
"print(\"Watch for RESUMING FROM CHECKPOINT in dashboard Jobs logs\")\n",
279+
"print(\"=\" * 60)\n",
280+
"\n",
281+
"job.resubmit()\n",
282+
"time.sleep(10)\n",
283+
"print(job.status())"
284+
]
285+
},
286+
{
287+
"cell_type": "markdown",
288+
"metadata": {},
289+
"source": [
290+
"## Verify resume from checkpoint\n",
291+
"\n",
292+
"In the Ray dashboard **Jobs** tab, look for:\n",
293+
"\n",
294+
"```\n",
295+
"RESUMING FROM CHECKPOINT - Starting at epoch N\n",
296+
"Previous loss: X.XXXX\n",
297+
"```\n",
298+
"\n",
299+
"That confirms optimizer and progress were restored from S3 across the suspend/resume cycle."
300+
]
301+
},
302+
{
303+
"cell_type": "code",
304+
"execution_count": null,
305+
"metadata": {},
306+
"outputs": [],
307+
"source": [
308+
"print(job.status())\n",
309+
"try:\n",
310+
" cluster = get_cluster(job.cluster_name, namespace=NAMESPACE, verify_tls=False)\n",
311+
" print(f\"Ray Dashboard: {cluster.cluster_dashboard_uri()}\")\n",
312+
"except Exception as e:\n",
313+
" print(f\"Could not resolve cluster yet: {e}\")\n",
314+
"print(\"Check Jobs tab for: RESUMING FROM CHECKPOINT - Starting at epoch N\")"
315+
]
316+
},
317+
{
318+
"cell_type": "markdown",
319+
"metadata": {},
320+
"source": [
321+
"## Cleanup\n",
322+
"\n",
323+
"Delete the RayJob and tear down the RayCluster if it is still present."
324+
]
325+
},
326+
{
327+
"cell_type": "code",
328+
"execution_count": null,
329+
"metadata": {},
330+
"outputs": [],
331+
"source": [
332+
"print(\"Cleaning up...\")\n",
333+
"cluster_name = job.cluster_name\n",
334+
"try:\n",
335+
" job.delete()\n",
336+
"except Exception:\n",
337+
" pass\n",
338+
"\n",
339+
"try:\n",
340+
" c = get_cluster(cluster_name, namespace=NAMESPACE, verify_tls=False)\n",
341+
" c.down()\n",
342+
"except Exception:\n",
343+
" pass\n",
344+
"\n",
345+
"print(\"Cleanup attempted (RayJob delete; cluster.down if RayCluster still exists).\")"
346+
]
347+
},
348+
{
349+
"cell_type": "code",
350+
"execution_count": null,
351+
"metadata": {},
352+
"outputs": [],
353+
"source": [
354+
"# No explicit logout needed - authentication is managed automatically by kube-authkit"
355+
]
356+
}
357+
],
358+
"metadata": {
359+
"kernelspec": {
360+
"display_name": "Python 3",
361+
"language": "python",
362+
"name": "python3"
363+
},
364+
"language_info": {
365+
"codemirror_mode": {
366+
"name": "ipython",
367+
"version": 3
368+
},
369+
"file_extension": ".py",
370+
"mimetype": "text/x-python",
371+
"name": "python",
372+
"nbconvert_exporter": "python",
373+
"pygments_lexer": "ipython3",
374+
"version": "3.12.11"
375+
}
376+
},
377+
"nbformat": 4,
378+
"nbformat_minor": 2
379+
}

0 commit comments

Comments
 (0)