Skip to content

Commit c74a382

Browse files
committed
fix: address review comments (iteration #1)
1 parent 07b513b commit c74a382

File tree

1 file changed

+16
-8
lines changed

1 file changed

+16
-8
lines changed

v3-examples/ml-ops-examples/v3-mlflow-train-inference-e2e-example.ipynb

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,8 @@
3535
"outputs": [],
3636
"source": [
3737
"# Install fix for MLflow path resolution issues\n",
38-
"# Note: Version pins (mlflow==3.4.0) may need updating - check for latest compatible versions\n",
39-
"%pip install mlflow==3.4.0"
38+
"# Using minimum version constraint so the notebook stays compatible with future releases\n",
39+
"%pip install 'mlflow>=3.4.0'"
4040
]
4141
},
4242
{
@@ -63,7 +63,6 @@
6363
"source": [
6464
"import json\n",
6565
"import uuid\n",
66-
"import boto3\n",
6766
"from sagemaker.core import image_uris\n",
6867
"from sagemaker.core.helper.session_helper import Session\n",
6968
"\n",
@@ -301,14 +300,14 @@
301300
"\n",
302301
"# Training on SageMaker managed infrastructure\n",
303302
"model_trainer = ModelTrainer(\n",
304-
" sagemaker_session=sagemaker_session,\n",
305303
" training_image=PYTORCH_TRAINING_IMAGE,\n",
306304
" source_code=SourceCode(\n",
307305
" source_dir=training_code_dir,\n",
308306
" entry_script=\"train.py\",\n",
309307
" requirements=\"requirements.txt\",\n",
310308
" ),\n",
311309
" base_job_name=training_job_name,\n",
310+
" sagemaker_session=sagemaker_session,\n",
312311
")\n",
313312
"\n",
314313
"# Start training job\n",
@@ -340,9 +339,11 @@
340339
"client = MlflowClient()\n",
341340
"\n",
342341
"# Use search_model_versions (compatible with MLflow 3.x)\n",
342+
"# Note: order_by field name may vary across MLflow versions;\n",
343+
"# 'creation_timestamp' is broadly supported.\n",
343344
"model_versions = client.search_model_versions(\n",
344345
" filter_string=f\"name='{MLFLOW_REGISTERED_MODEL_NAME}'\",\n",
345-
" order_by=['version_number DESC'],\n",
346+
" order_by=['creation_timestamp DESC'],\n",
346347
" max_results=1\n",
347348
")\n",
348349
"latest_version = model_versions[0]\n",
@@ -413,7 +414,7 @@
413414
" \"MLFLOW_MODEL_PATH\": mlflow_model_path,\n",
414415
" \"MLFLOW_TRACKING_ARN\": MLFLOW_TRACKING_ARN\n",
415416
" },\n",
416-
" dependencies={\"auto\": False, \"custom\": [\"mlflow==3.4.0\", \"sagemaker==3.3.1\", \"numpy==2.4.1\", \"cloudpickle==3.1.2\"]},\n",
417+
" dependencies={\"auto\": False, \"custom\": [\"mlflow>=3.4.0\", \"sagemaker>=3.3.1\", \"numpy>=2.4.1\", \"cloudpickle>=3.1.2\"]},\n",
417418
")\n",
418419
"\n",
419420
"print(f\"ModelBuilder configured with MLflow model: {mlflow_model_path}\")"
@@ -464,11 +465,18 @@
464465
"test_data = [[0.1, 0.2, 0.3, 0.4]]\n",
465466
"\n",
466467
"result = core_endpoint.invoke(\n",
467-
" body=json.dumps(test_data),\n",
468+
" body=json.dumps(test_data).encode('utf-8'),\n",
468469
" content_type='application/json'\n",
469470
")\n",
470471
"\n",
471-
"prediction = json.loads(result.body.read().decode('utf-8'))\n",
472+
"# The invoke() response body may be a streaming object or bytes;\n",
473+
"# handle both cases for robustness.\n",
474+
"response_body = result.body\n",
475+
"if hasattr(response_body, 'read'):\n",
476+
" response_body = response_body.read()\n",
477+
"if isinstance(response_body, bytes):\n",
478+
" response_body = response_body.decode('utf-8')\n",
479+
"prediction = json.loads(response_body)\n",
472480
"print(f\"Input: {test_data}\")\n",
473481
"print(f\"Prediction: {prediction}\")"
474482
]

0 commit comments

Comments
 (0)