Skip to content

Commit 96c07a0

Browse files
committed
fix: MLFlow E2E Example Notebook (5513)
1 parent ee420cc commit 96c07a0

File tree

2 files changed

+54
-19
lines changed

2 files changed

+54
-19
lines changed

v3-examples/inference-examples/train-inference-e2e-example.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@
6161
"TRAINING_JOB_PREFIX = \"e2e-v3-pytorch\"\n",
6262
"\n",
6363
"# AWS Configuration\n",
64-
"AWS_REGION = Session.boto_region_name\n",
64+
"AWS_REGION = Session().boto_region_name\n",
6565
"PYTORCH_TRAINING_IMAGE = f\"763104351884.dkr.ecr.{AWS_REGION}.amazonaws.com/pytorch-training:1.13.1-cpu-py39\"\n",
6666
"\n",
6767
"# Generate unique identifiers\n",

v3-examples/ml-ops-examples/v3-mlflow-train-inference-e2e-example.ipynb

Lines changed: 53 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@
7171
"MLFLOW_TRACKING_ARN = \"XXXXX\"\n",
7272
"\n",
7373
"# AWS Configuration\n",
74-
"AWS_REGION = Session.boto_region_name\n",
74+
"AWS_REGION = Session().boto_region_name\n",
7575
"\n",
7676
"# Get PyTorch training image dynamically\n",
7777
"PYTORCH_TRAINING_IMAGE = image_uris.retrieve(\n",
@@ -330,25 +330,33 @@
330330
"outputs": [],
331331
"source": [
332332
"# Get the latest version of the registered model\n",
333+
"# NOTE: MLflow 3.x removed `registered_model.latest_versions`. Use\n",
334+
"# `client.search_model_versions()` instead.\n",
333335
"from mlflow import MlflowClient\n",
334336
"\n",
335337
"client = MlflowClient()\n",
336-
"registered_model = client.get_registered_model(name=MLFLOW_REGISTERED_MODEL_NAME)\n",
337338
"\n",
338-
"latest_version = registered_model.latest_versions[0]\n",
339+
"# Search for the latest version of the registered model (MLflow 3.x compatible)\n",
340+
"versions = client.search_model_versions(\n",
341+
" filter_string=f\"name='{MLFLOW_REGISTERED_MODEL_NAME}'\",\n",
342+
" order_by=['version_number DESC'],\n",
343+
" max_results=1\n",
344+
")\n",
345+
"\n",
346+
"if not versions:\n",
347+
" raise ValueError(f\"No versions found for model '{MLFLOW_REGISTERED_MODEL_NAME}'\")\n",
348+
"\n",
349+
"latest_version = versions[0]\n",
339350
"model_version = latest_version.version\n",
340351
"model_source = latest_version.source\n",
341352
"\n",
342-
"# Get S3 URL of model files (for info only)\n",
343-
"artifact_uri = client.get_model_version_download_uri(MLFLOW_REGISTERED_MODEL_NAME, model_version)\n",
344-
"\n",
345353
"# MLflow model registry path to use with ModelBuilder\n",
346354
"mlflow_model_path = f\"models:/{MLFLOW_REGISTERED_MODEL_NAME}/{model_version}\"\n",
347355
"\n",
348356
"print(f\"Registered Model: {MLFLOW_REGISTERED_MODEL_NAME}\")\n",
349357
"print(f\"Latest Version: {model_version}\")\n",
350-
"print(f\"Source: {model_source}\")\n",
351-
"print(f\"Model artifacts location: {artifact_uri}\")"
358+
"print(f\"Source (artifact location): {model_source}\")\n",
359+
"print(f\"MLflow model path for deployment: {mlflow_model_path}\")"
352360
]
353361
},
354362
{
@@ -427,6 +435,8 @@
427435
"from sagemaker.serve.mode.function_pointers import Mode\n",
428436
"\n",
429437
"# Cloud deployment to SageMaker endpoint\n",
438+
"# Note: the 'dependencies' parameter is deprecated, so you may see a deprecation warning.\n",
439+
"# Use configure_for_torchserve() for new projects.\n",
430440
"model_builder = ModelBuilder(\n",
431441
" mode=Mode.SAGEMAKER_ENDPOINT,\n",
432442
" schema_builder=schema_builder,\n",
@@ -481,23 +491,43 @@
481491
"metadata": {},
482492
"outputs": [],
483493
"source": [
484-
"import boto3\n",
485-
"\n",
486-
"# Test with JSON input\n",
494+
"# Test with JSON input using V3-native endpoint invocation\n",
487495
"test_data = [[0.1, 0.2, 0.3, 0.4]]\n",
488496
"\n",
489-
"runtime_client = boto3.client('sagemaker-runtime')\n",
490-
"response = runtime_client.invoke_endpoint(\n",
491-
" EndpointName=core_endpoint.endpoint_name,\n",
492-
" Body=json.dumps(test_data),\n",
493-
" ContentType='application/json'\n",
497+
"result = core_endpoint.invoke(\n",
498+
" body=json.dumps(test_data),\n",
499+
" content_type='application/json'\n",
494500
")\n",
495501
"\n",
496-
"prediction = json.loads(response['Body'].read().decode('utf-8'))\n",
502+
"prediction = json.loads(result.body.read().decode('utf-8'))\n",
497503
"print(f\"Input: {test_data}\")\n",
498504
"print(f\"Prediction: {prediction}\")"
499505
]
500506
},
507+
{
508+
"cell_type": "code",
509+
"execution_count": null,
510+
"metadata": {},
511+
"outputs": [],
512+
"source": [
513+
"# Test with multiple inputs\n",
514+
"test_inputs = [\n",
515+
" [[0.5, 0.3, 0.2, 0.1]],\n",
516+
" [[0.9, 0.1, 0.8, 0.2]],\n",
517+
" [[0.2, 0.7, 0.4, 0.6]]\n",
518+
"]\n",
519+
"\n",
520+
"for i, test_input in enumerate(test_inputs, 1):\n",
521+
" result = core_endpoint.invoke(\n",
522+
" body=json.dumps(test_input),\n",
523+
" content_type='application/json'\n",
524+
" )\n",
525+
" \n",
526+
" prediction = json.loads(result.body.read().decode('utf-8'))\n",
527+
" print(f\"Test {i} - Input {test_input}: {prediction}\")\n",
528+
" print('-' * 50)"
529+
]
530+
},
501531
{
502532
"cell_type": "markdown",
503533
"metadata": {},
@@ -551,7 +581,12 @@
551581
"- `ModelBuilder` with `MLFLOW_MODEL_PATH` - deploy from registry\n",
552582
"\n",
553583
"Key patterns:\n",
554-
"- Custom `PayloadTranslator` classes for PyTorch tensor serialization\n"
584+
"- Custom `PayloadTranslator` classes for PyTorch tensor serialization\n",
585+
"- V3-native `core_endpoint.invoke()` for inference\n",
586+
"\n",
587+
"**MLflow 3.x API Note:**\n",
588+
"- Use `client.search_model_versions()` instead of the removed `registered_model.latest_versions` attribute\n",
589+
"- Use `latest_version.source` for the artifact location instead of `client.get_model_version_download_uri()`\n"
555590
]
556591
}
557592
],

0 commit comments

Comments
 (0)