|
35 | 35 | "outputs": [], |
36 | 36 | "source": [ |
37 | 37 | "# Install fix for MLflow path resolution issues\n", |
38 | | - "# Note: Version pins (mlflow==3.4.0) may need updating - check for latest compatible versions\n", |
39 | | - "%pip install mlflow==3.4.0" |
| 38 | + "# Using minimum version constraint so the notebook stays compatible with future releases\n", |
| 39 | + "%pip install 'mlflow>=3.4.0'" |
40 | 40 | ] |
41 | 41 | }, |
42 | 42 | { |
|
63 | 63 | "source": [ |
64 | 64 | "import json\n", |
65 | 65 | "import uuid\n", |
66 | | - "import boto3\n", |
67 | 66 | "from sagemaker.core import image_uris\n", |
68 | 67 | "from sagemaker.core.helper.session_helper import Session\n", |
69 | 68 | "\n", |
|
301 | 300 | "\n", |
302 | 301 | "# Training on SageMaker managed infrastructure\n", |
303 | 302 | "model_trainer = ModelTrainer(\n", |
304 | | - " sagemaker_session=sagemaker_session,\n", |
305 | 303 | " training_image=PYTORCH_TRAINING_IMAGE,\n", |
306 | 304 | " source_code=SourceCode(\n", |
307 | 305 | " source_dir=training_code_dir,\n", |
308 | 306 | " entry_script=\"train.py\",\n", |
309 | 307 | " requirements=\"requirements.txt\",\n", |
310 | 308 | " ),\n", |
311 | 309 | " base_job_name=training_job_name,\n", |
| 310 | + " sagemaker_session=sagemaker_session,\n", |
312 | 311 | ")\n", |
313 | 312 | "\n", |
314 | 313 | "# Start training job\n", |
|
340 | 339 | "client = MlflowClient()\n", |
341 | 340 | "\n", |
342 | 341 | "# Use search_model_versions (compatible with MLflow 3.x)\n", |
| 342 | + "# Note: order_by field name may vary across MLflow versions;\n", |
| 343 | + "# 'creation_timestamp' is broadly supported.\n", |
343 | 344 | "model_versions = client.search_model_versions(\n", |
344 | 345 | " filter_string=f\"name='{MLFLOW_REGISTERED_MODEL_NAME}'\",\n", |
345 | | - " order_by=['version_number DESC'],\n", |
| 346 | + " order_by=['creation_timestamp DESC'],\n", |
346 | 347 | " max_results=1\n", |
347 | 348 | ")\n", |
348 | 349 | "latest_version = model_versions[0]\n", |
|
413 | 414 | " \"MLFLOW_MODEL_PATH\": mlflow_model_path,\n", |
414 | 415 | " \"MLFLOW_TRACKING_ARN\": MLFLOW_TRACKING_ARN\n", |
415 | 416 | " },\n", |
416 | | - " dependencies={\"auto\": False, \"custom\": [\"mlflow==3.4.0\", \"sagemaker==3.3.1\", \"numpy==2.4.1\", \"cloudpickle==3.1.2\"]},\n", |
| 417 | + " dependencies={\"auto\": False, \"custom\": [\"mlflow>=3.4.0\", \"sagemaker>=3.3.1\", \"numpy>=2.4.1\", \"cloudpickle>=3.1.2\"]},\n", |
417 | 418 | ")\n", |
418 | 419 | "\n", |
419 | 420 | "print(f\"ModelBuilder configured with MLflow model: {mlflow_model_path}\")" |
|
464 | 465 | "test_data = [[0.1, 0.2, 0.3, 0.4]]\n", |
465 | 466 | "\n", |
466 | 467 | "result = core_endpoint.invoke(\n", |
467 | | - " body=json.dumps(test_data),\n", |
| 468 | + " body=json.dumps(test_data).encode('utf-8'),\n", |
468 | 469 | " content_type='application/json'\n", |
469 | 470 | ")\n", |
470 | 471 | "\n", |
471 | | - "prediction = json.loads(result.body.read().decode('utf-8'))\n", |
| 472 | + "# The invoke() response body may be a streaming object or bytes;\n", |
| 473 | + "# handle both cases for robustness.\n", |
| 474 | + "response_body = result.body\n", |
| 475 | + "if hasattr(response_body, 'read'):\n", |
| 476 | + " response_body = response_body.read()\n", |
| 477 | + "if isinstance(response_body, bytes):\n", |
| 478 | + " response_body = response_body.decode('utf-8')\n", |
| 479 | + "prediction = json.loads(response_body)\n", |
472 | 480 | "print(f\"Input: {test_data}\")\n", |
473 | 481 | "print(f\"Prediction: {prediction}\")" |
474 | 482 | ] |
|
0 commit comments