|
| 1 | +{ |
| 2 | + "cells": [ |
| 3 | + { |
| 4 | + "cell_type": "markdown", |
| 5 | + "metadata": {}, |
| 6 | + "source": [ |
| 7 | + "# Ray checkpointing example\n", |
| 8 | + "\n", |
| 9 | + "This notebook runs a **Ray Train** checkpointing demo on **Red Hat OpenShift AI** using the CodeFlare SDK:\n", |
| 10 | + "\n", |
| 11 | + "- **Red Hat build of Kueue** must be configured in your cluster (ResourceFlavor → ClusterQueue → **LocalQueue** in your namespace, and `kueue.openshift.io/managed=true` on the namespace). See the OpenShift *AI workloads* documentation for [Red Hat build of Kueue](https://docs.redhat.com/en/documentation/openshift_container_platform/4.21/html/ai_workloads/red-hat-build-of-kueue).\n", |
| 12 | + "- Submit a **RayJob** with a **managed Ray cluster** (`ManagedClusterConfig`) so that KubeRay creates the cluster with the job and tears it down when the job finishes (`shutdownAfterJobFinishes`). The RayJob is labeled for your **LocalQueue** via `local_queue` (example: `\"default\"`).\n", |
| 13 | + "- Configure **AWS credentials** for the S3 bucket used by Ray Train checkpoints.\n", |
| 14 | + "- **Monitor** training in the Ray dashboard (**Jobs** tab), then **suspend and resume** the RayJob (`job.stop()` / `job.resubmit()`) to verify training **resumes from S3** after a simulated interruption.\n", |
| 15 | + "\n", |
| 16 | + "Training script: `train_with_checkpoints.py` in this directory (same source as the CodeFlare SDK guided demo)." |
| 17 | + ] |
| 18 | + }, |
| 19 | + { |
| 20 | + "cell_type": "markdown", |
| 21 | + "metadata": {}, |
| 22 | + "source": [ |
| 23 | + "## Import required libraries" |
| 24 | + ] |
| 25 | + }, |
| 26 | + { |
| 27 | + "cell_type": "code", |
| 28 | + "execution_count": null, |
| 29 | + "metadata": {}, |
| 30 | + "outputs": [], |
| 31 | + "source": [ |
| 32 | + "from codeflare_sdk import RayJob, ManagedClusterConfig, set_api_client, get_cluster\n", |
| 33 | + "from kube_authkit import AuthConfig, get_k8s_client\n", |
| 34 | + "import time" |
| 35 | + ] |
| 36 | + }, |
| 37 | + { |
| 38 | + "cell_type": "markdown", |
| 39 | + "metadata": {}, |
| 40 | + "source": [ |
| 41 | + "## Authenticate to your OpenShift cluster" |
| 42 | + ] |
| 43 | + }, |
| 44 | + { |
| 45 | + "cell_type": "code", |
| 46 | + "execution_count": null, |
| 47 | + "metadata": {}, |
| 48 | + "outputs": [], |
| 49 | + "source": [ |
| 50 | + "import urllib3\n", |
| 51 | + "\n", |
| 52 | + "urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)\n", |
| 53 | + "\n", |
| 54 | + "# Authenticate to your Kubernetes/OpenShift cluster using kube-authkit\n", |
| 55 | + "\n", |
| 56 | + "# Option 1: Auto-detect credentials (kubeconfig or in-cluster service account)\n", |
| 57 | + "# NOTE: In RHOAI Workbenches the workbench service account may not have Ray RBAC\n", |
| 58 | + "# permissions. Use Option 2 (token) unless your admin has granted SA permissions\n", |
| 59 | + "# (see RHOAIENG-46748). Auto-detect works if you have a local kubeconfig.\n", |
| 60 | + "# auth_config = AuthConfig(method=\"auto\")\n", |
| 61 | + "\n", |
| 62 | + "# Option 2 (Recommended for RHOAI Workbenches): Token-based authentication\n", |
| 63 | + "# Get your token with: oc whoami -t (or from the OpenShift console → Copy login command)\n", |
| 64 | + "auth_config = AuthConfig(\n", |
| 65 | + "    method=\"openshift\",\n", |
| 66 | + "    k8s_api_host=\"https://api.example.com:6443\",\n", |
| 67 | + "    token=\"sha256~XXXXX\",  # oc whoami -t\n", |
| 68 | + ")\n", |
| 69 | + "\n", |
| 70 | + "# Option 3: OIDC authentication (for BYOIDC-enabled clusters)\n", |
| 71 | + "# auth_config = AuthConfig(\n", |
| 72 | + "# method=\"oidc\",\n", |
| 73 | + "# k8s_api_host=\"https://api.example.com:6443\",\n", |
| 74 | + "# oidc_issuer=\"https://your-oidc-provider.com\",\n", |
| 75 | + "# client_id=\"your-client-id\",\n", |
| 76 | + "# use_device_flow=True, # Interactive device flow for notebook environments\n", |
| 77 | + "# )\n", |
| 78 | + "\n", |
| 79 | + "api_client = get_k8s_client(config=auth_config)\n", |
| 80 | + "# Disables TLS verification, e.g. for self-signed / dev API certificates (optional; omit this line for clusters with trusted certificates).\n", |
| 81 | + "api_client.configuration.verify_ssl = False\n", |
| 82 | + "set_api_client(api_client)\n", |
| 83 | + "\n", |
| 84 | + "NAMESPACE = \"your-namespace\" # Data Science Project where LocalQueue + RayJob run\n", |
| 85 | + "JOB_NAME = \"checkpointing-job\"\n", |
| 86 | + "# Must match metadata.name of a LocalQueue in NAMESPACE (create per OpenShift Kueue docs).\n", |
| 87 | + "LOCAL_QUEUE = \"default\"" |
| 88 | + ] |
| 89 | + }, |
| 90 | + { |
| 91 | + "cell_type": "markdown", |
| 92 | + "metadata": {}, |
| 93 | + "source": [ |
| 94 | + "## Red Hat build of Kueue (required before submit)\n", |
| 95 | + "\n", |
| 96 | + "Configure **ResourceFlavor** → **ClusterQueue** → **LocalQueue** in your project namespace, and label the namespace so Kueue manages workloads there. Official OpenShift 4.21 *AI workloads* docs:\n", |
| 97 | + "\n", |
| 98 | + "- [Configuring a resource flavor](https://docs.redhat.com/en/documentation/openshift_container_platform/4.21/html/ai_workloads/red-hat-build-of-kueue#configuring-resourceflavors_configuring-quotas)\n", |
| 99 | + "- [Configuring a cluster queue](https://docs.redhat.com/en/documentation/openshift_container_platform/4.21/html/ai_workloads/red-hat-build-of-kueue#configuring-clusterqueues_configuring-quotas)\n", |
| 100 | + "- [Configuring a local queue](https://docs.redhat.com/en/documentation/openshift_container_platform/4.21/html/ai_workloads/red-hat-build-of-kueue#configuring-localqueues_configuring-quotas) — the `LocalQueue` `metadata.name` must match **`LOCAL_QUEUE`** above (e.g. `default`).\n", |
| 101 | + "- [Labeling namespaces to allow Red Hat build of Kueue to manage jobs](https://docs.redhat.com/en/documentation/openshift_container_platform/4.21/html/ai_workloads/red-hat-build-of-kueue#labeling-namespaces-to-allow-red-hat-build-of-kueue-to-manage-jobs_managing-jobs-and-workloads): `oc label namespace <namespace> kueue.openshift.io/managed=true`\n", |
| 102 | + "\n", |
| 103 | + "After submit, the RayJob may show **Suspended** until Kueue **admits** a `Workload` — use `oc get workloads.kueue.x-k8s.io -n $NAMESPACE` if needed. That is **not** the same as the manual suspend used later for the checkpoint demo." |
| 104 | + ] |
| 105 | + }, |
| 106 | + { |
| 107 | + "cell_type": "code", |
| 108 | + "execution_count": null, |
| 109 | + "metadata": {}, |
| 110 | + "outputs": [], |
| 111 | + "source": [ |
| 112 | + "# Optional: verify Kueue objects exist (cluster admin / user with read access)\n", |
| 113 | + "# !oc get resourceflavor.kueue.x-k8s.io\n", |
| 114 | + "# !oc get clusterqueue.kueue.x-k8s.io\n", |
| 115 | + "# !oc get localqueue.kueue.x-k8s.io -n $NAMESPACE\n", |
| 116 | + "\n", |
| 117 | + "print(f\"Namespace: {NAMESPACE!r}, RayJob name: {JOB_NAME!r}, LocalQueue: {LOCAL_QUEUE!r}\")" |
| 118 | + ] |
| 119 | + }, |
| 120 | + { |
| 121 | + "cell_type": "markdown", |
| 122 | + "metadata": {}, |
| 123 | + "source": [ |
| 124 | + "## Set your AWS credentials" |
| 125 | + ] |
| 126 | + }, |
| 127 | + { |
| 128 | + "cell_type": "code", |
| 129 | + "execution_count": null, |
| 130 | + "metadata": {}, |
| 131 | + "outputs": [], |
| 132 | + "source": [ |
| 133 | + "# Set your AWS credentials\n", |
| 134 | + "# WARNING: Do not commit credentials to version control. For production,\n", |
| 135 | + "# use OpenShift AI Data Connections or OpenShift Secrets instead.\n", |
| 136 | + "AWS_CREDENTIALS = {\n", |
| 137 | + "    \"AWS_ACCESS_KEY_ID\": \"your-access-key\",\n", |
| 138 | + "    \"AWS_SECRET_ACCESS_KEY\": \"your-secret-key\",\n", |
| 139 | + "    \"AWS_DEFAULT_REGION\": \"us-east-1\",\n", |
| 140 | + "    \"AWS_S3_BUCKET\": \"your-bucket-name\",\n", |
| 141 | + "}\n", |
| 142 | + "\n", |
| 143 | + "# If using temporary credentials (SSO/federated), add the session token:\n", |
| 144 | + "# AWS_CREDENTIALS[\"AWS_SESSION_TOKEN\"] = \"your-session-token\"\n", |
| 145 | + "\n", |
| 146 | + "print(f\"Using bucket: {AWS_CREDENTIALS['AWS_S3_BUCKET']}\")" |
| 147 | + ] |
| 148 | + }, |
| 149 | + { |
| 150 | + "cell_type": "markdown", |
| 151 | + "metadata": {}, |
| 152 | + "source": [ |
| 153 | + "## Submit RayJob (managed Ray cluster + Kueue local queue)" |
| 154 | + ] |
| 155 | + }, |
| 156 | + { |
| 157 | + "cell_type": "code", |
| 158 | + "execution_count": null, |
| 159 | + "metadata": {}, |
| 160 | + "outputs": [], |
| 161 | + "source": [ |
| 162 | + "managed = ManagedClusterConfig(\n", |
| 163 | + "    num_workers=2,\n", |
| 164 | + "    head_cpu_requests=2,\n", |
| 165 | + "    head_cpu_limits=4,\n", |
| 166 | + "    head_memory_requests=4,\n", |
| 167 | + "    head_memory_limits=8,\n", |
| 168 | + "    worker_cpu_requests=2,\n", |
| 169 | + "    worker_cpu_limits=4,\n", |
| 170 | + "    worker_memory_requests=4,\n", |
| 171 | + "    worker_memory_limits=8,\n", |
| 172 | + ")\n", |
| 173 | + "\n", |
| 174 | + "job = RayJob(\n", |
| 175 | + "    job_name=JOB_NAME,\n", |
| 176 | + "    entrypoint=\"python train_with_checkpoints.py\",\n", |
| 177 | + "    cluster_config=managed,\n", |
| 178 | + "    namespace=NAMESPACE,\n", |
| 179 | + "    local_queue=LOCAL_QUEUE,\n", |
| 180 | + "    runtime_env={\n", |
| 181 | + "        \"working_dir\": \"./\",\n", |
| 182 | + "        \"pip\": [\"torch\", \"torchvision\", \"s3fs\", \"pyarrow\"],\n", |
| 183 | + "        \"env_vars\": {\n", |
| 184 | + "            **AWS_CREDENTIALS,\n", |
| 185 | + "            \"RAY_TRAIN_WORKER_GROUP_START_TIMEOUT_S\": \"120\",  # Allow time for worker scheduling\n", |
| 186 | + "        },\n", |
| 187 | + "    },\n", |
| 188 | + ")\n", |
| 189 | + "\n", |
| 190 | + "job.submit()\n", |
| 191 | + "print(\n", |
| 192 | + "    \"RayJob submitted. If status stays Suspended briefly, Kueue may still be admitting the Workload.\"\n", |
| 193 | + ")\n", |
| 194 | + "print(f\"RayCluster name template (KubeRay assigns the final name): {job.cluster_name}\")\n", |
| 195 | + "print(\"Watch logs for: NO CHECKPOINT FOUND - Starting fresh\")" |
| 196 | + ] |
| 197 | + }, |
| 198 | + { |
| 199 | + "cell_type": "markdown", |
| 200 | + "metadata": {}, |
| 201 | + "source": [ |
| 202 | + "## Monitor job progress (status + Ray dashboard Jobs)\n", |
| 203 | + "\n", |
| 204 | + "Poll `job.status()`, then open the **Ray dashboard** URL from the RayCluster created by your RayJob. Use **Jobs** in the dashboard for live driver logs (epochs, checkpoint messages)." |
| 205 | + ] |
| 206 | + }, |
| 207 | + { |
| 208 | + "cell_type": "code", |
| 209 | + "execution_count": null, |
| 210 | + "metadata": {}, |
| 211 | + "outputs": [], |
| 212 | + "source": [ |
| 213 | + "print(job.status())\n", |
| 214 | + "# KubeRay assigns a generated name — not job.cluster_name (the template).\n", |
| 215 | + "cluster = None\n", |
| 216 | + "for _ in range(36):\n", |
| 217 | + "    status_data = job._api.get_job_status(name=job.name, k8s_namespace=NAMESPACE)\n", |
| 218 | + "    ray_cluster_name = (status_data or {}).get(\"rayClusterName\")\n", |
| 219 | + "    if ray_cluster_name:\n", |
| 220 | + "        cluster = get_cluster(ray_cluster_name, namespace=NAMESPACE, verify_tls=False)\n", |
| 221 | + "        if cluster is not None:\n", |
| 222 | + "            break\n", |
| 223 | + "    time.sleep(5)\n", |
| 224 | + "if cluster is None:\n", |
| 225 | + "    raise RuntimeError(\n", |
| 226 | + "        \"RayCluster not ready. Check RayJob / Workload admission and operator logs.\"\n", |
| 227 | + "    )\n", |
| 228 | + "print(f\"Ray Dashboard (open in browser): {cluster.cluster_dashboard_uri()}\")\n", |
| 229 | + "print(\"In the dashboard, open Jobs and stream logs for the training driver.\")\n", |
| 230 | + "print(\n", |
| 231 | + "    \"Wait for at least one full epoch and a checkpoint to S3 before running the suspend cell.\"\n", |
| 232 | + ")" |
| 233 | + ] |
| 234 | + }, |
| 235 | + { |
| 236 | + "cell_type": "markdown", |
| 237 | + "metadata": {}, |
| 238 | + "source": [ |
| 239 | + "## Suspend RayJob (checkpoint demo)\n", |
| 240 | + "\n", |
| 241 | + "After logs show at least **one epoch** and a checkpoint written to **S3**, suspend the RayJob. This is a **manual** suspend for the demo (distinct from Kueue holding the job until admission right after submit).\n", |
| 242 | + "\n", |
| 243 | + "Use **Pause** in the OpenShift AI UI, or run the next cell (`job.stop()`)." |
| 244 | + ] |
| 245 | + }, |
| 246 | + { |
| 247 | + "cell_type": "code", |
| 248 | + "execution_count": null, |
| 249 | + "metadata": {}, |
| 250 | + "outputs": [], |
| 251 | + "source": [ |
| 252 | + "print(\"=\" * 60)\n", |
| 253 | + "print(\"SUSPENDING RayJob (checkpoint demo — not deleting the RayJob CR)\")\n", |
| 254 | + "print(\"Checkpoints remain in S3.\")\n", |
| 255 | + "print(\"=\" * 60)\n", |
| 256 | + "\n", |
| 257 | + "job.stop()\n", |
| 258 | + "print(\"Stop requested; poll job.status() until the RayJob reports suspended / non-running.\")" |
| 259 | + ] |
| 260 | + }, |
| 261 | + { |
| 262 | + "cell_type": "markdown", |
| 263 | + "metadata": {}, |
| 264 | + "source": [ |
| 265 | + "## Resume RayJob\n", |
| 266 | + "\n", |
| 267 | + "Use **Resume** in the OpenShift AI UI, or run `job.resubmit()` in the next cell. When the RayCluster is back, confirm in the dashboard **Jobs** view: `RESUMING FROM CHECKPOINT - Starting at epoch N`." |
| 268 | + ] |
| 269 | + }, |
| 270 | + { |
| 271 | + "cell_type": "code", |
| 272 | + "execution_count": null, |
| 273 | + "metadata": {}, |
| 274 | + "outputs": [], |
| 275 | + "source": [ |
| 276 | + "print(\"=\" * 60)\n", |
| 277 | + "print(\"RESUMING RayJob after suspend\")\n", |
| 278 | + "print(\"Watch for RESUMING FROM CHECKPOINT in dashboard Jobs logs\")\n", |
| 279 | + "print(\"=\" * 60)\n", |
| 280 | + "\n", |
| 281 | + "job.resubmit()\n", |
| 282 | + "time.sleep(10)\n", |
| 283 | + "print(job.status())" |
| 284 | + ] |
| 285 | + }, |
| 286 | + { |
| 287 | + "cell_type": "markdown", |
| 288 | + "metadata": {}, |
| 289 | + "source": [ |
| 290 | + "## Verify resume from checkpoint\n", |
| 291 | + "\n", |
| 292 | + "In the Ray dashboard **Jobs** tab, look for:\n", |
| 293 | + "\n", |
| 294 | + "```\n", |
| 295 | + "RESUMING FROM CHECKPOINT - Starting at epoch N\n", |
| 296 | + "Previous loss: X.XXXX\n", |
| 297 | + "```\n", |
| 298 | + "\n", |
| 299 | + "That confirms the optimizer state and training progress were restored from S3 across the suspend/resume cycle." |
| 300 | + ] |
| 301 | + }, |
| 302 | + { |
| 303 | + "cell_type": "code", |
| 304 | + "execution_count": null, |
| 305 | + "metadata": {}, |
| 306 | + "outputs": [], |
| 307 | + "source": [ |
| 308 | + "print(job.status())\n", |
| 309 | + "try:\n", |
| 310 | + "    cluster = get_cluster(job.cluster_name, namespace=NAMESPACE, verify_tls=False)\n", |
| 311 | + "    print(f\"Ray Dashboard: {cluster.cluster_dashboard_uri()}\")\n", |
| 312 | + "except Exception as e:\n", |
| 313 | + "    print(f\"Could not resolve cluster yet: {e}\")\n", |
| 314 | + "print(\"Check Jobs tab for: RESUMING FROM CHECKPOINT - Starting at epoch N\")" |
| 315 | + ] |
| 316 | + }, |
| 317 | + { |
| 318 | + "cell_type": "markdown", |
| 319 | + "metadata": {}, |
| 320 | + "source": [ |
| 321 | + "## Cleanup\n", |
| 322 | + "\n", |
| 323 | + "Delete the RayJob and tear down the RayCluster if it is still present." |
| 324 | + ] |
| 325 | + }, |
| 326 | + { |
| 327 | + "cell_type": "code", |
| 328 | + "execution_count": null, |
| 329 | + "metadata": {}, |
| 330 | + "outputs": [], |
| 331 | + "source": [ |
| 332 | + "print(\"Cleaning up...\")\n", |
| 333 | + "cluster_name = job.cluster_name\n", |
| 334 | + "try:\n", |
| 335 | + "    job.delete()\n", |
| 336 | + "except Exception as e:\n", |
| 337 | + "    print(f\"RayJob delete failed: {e}\")\n", |
| 338 | + "\n", |
| 339 | + "try:\n", |
| 340 | + "    c = get_cluster(cluster_name, namespace=NAMESPACE, verify_tls=False)\n", |
| 341 | + "    c.down()\n", |
| 342 | + "except Exception as e:\n", |
| 343 | + "    print(f\"RayCluster lookup or teardown did not complete: {e}\")\n", |
| 344 | + "\n", |
| 345 | + "print(\"Cleanup attempted (RayJob delete; cluster.down if RayCluster still exists).\")" |
| 346 | + ] |
| 347 | + }, |
| 348 | + { |
| 349 | + "cell_type": "code", |
| 350 | + "execution_count": null, |
| 351 | + "metadata": {}, |
| 352 | + "outputs": [], |
| 353 | + "source": [ |
| 354 | + "# No explicit logout needed - authentication is managed automatically by kube-authkit" |
| 355 | + ] |
| 356 | + } |
| 357 | + ], |
| 358 | + "metadata": { |
| 359 | + "kernelspec": { |
| 360 | + "display_name": "Python 3", |
| 361 | + "language": "python", |
| 362 | + "name": "python3" |
| 363 | + }, |
| 364 | + "language_info": { |
| 365 | + "codemirror_mode": { |
| 366 | + "name": "ipython", |
| 367 | + "version": 3 |
| 368 | + }, |
| 369 | + "file_extension": ".py", |
| 370 | + "mimetype": "text/x-python", |
| 371 | + "name": "python", |
| 372 | + "nbconvert_exporter": "python", |
| 373 | + "pygments_lexer": "ipython3", |
| 374 | + "version": "3.12.11" |
| 375 | + } |
| 376 | + }, |
| 377 | + "nbformat": 4, |
| 378 | + "nbformat_minor": 2 |
| 379 | +} |
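
The monitoring cell in this notebook polls the RayJob status until a `rayClusterName` appears. As a rough sketch only, that loop could be factored into a small helper that is easy to test without a cluster. The helper name `wait_for_ray_cluster_name`, its parameters, and the stand-in status responses below are hypothetical and are not part of the CodeFlare SDK.

```python
import time
from typing import Callable, Optional


def wait_for_ray_cluster_name(
    get_status: Callable[[], Optional[dict]],
    attempts: int = 36,
    delay_s: float = 5.0,
    sleep: Callable[[float], None] = time.sleep,
) -> Optional[str]:
    """Poll get_status() until it reports a non-empty "rayClusterName".

    Returns the name, or None if it never appears within `attempts` polls.
    """
    for attempt in range(attempts):
        status = get_status() or {}  # tolerate a None status, as the notebook does
        name = status.get("rayClusterName")
        if name:
            return name
        if attempt < attempts - 1:  # no need to sleep after the final poll
            sleep(delay_s)
    return None


# Stand-in for the real status call (hypothetical data): the name shows up on the third poll.
_responses = iter([None, {}, {"rayClusterName": "demo-raycluster-abc12"}])
found = wait_for_ray_cluster_name(
    lambda: next(_responses), attempts=5, sleep=lambda s: None
)
print(found)
```

In the notebook itself, `get_status` could wrap the call the monitoring cell already makes, `job._api.get_job_status(name=job.name, k8s_namespace=NAMESPACE)`, and the returned name could then be passed to `get_cluster`. Injecting `sleep` keeps the helper fast to exercise in a test.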