Skip to content

Commit e304702

Browse files
committed
feature: MLflow E2E Example Notebook (5513)
1 parent ee420cc commit e304702

File tree

1 file changed

+48
-16
lines changed

1 file changed

+48
-16
lines changed

v3-examples/ml-ops-examples/v3-mlflow-train-inference-e2e-example.ipynb

Lines changed: 48 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@
6161
"outputs": [],
6262
"source": [
6363
"import uuid\n",
64+
"import boto3\n",
6465
"from sagemaker.core import image_uris\n",
6566
"from sagemaker.core.helper.session_helper import Session\n",
6667
"\n",
@@ -71,7 +72,9 @@
7172
"MLFLOW_TRACKING_ARN = \"XXXXX\"\n",
7273
"\n",
7374
"# AWS Configuration\n",
74-
"AWS_REGION = Session.boto_region_name\n",
75+
"boto_session = boto3.Session()\n",
76+
"sagemaker_session = Session(boto_session=boto_session)\n",
77+
"AWS_REGION = sagemaker_session.boto_region_name\n",
7578
"\n",
7679
"# Get PyTorch training image dynamically\n",
7780
"PYTORCH_TRAINING_IMAGE = image_uris.retrieve(\n",
@@ -297,6 +300,7 @@
297300
"\n",
298301
"# Training on SageMaker managed infrastructure\n",
299302
"model_trainer = ModelTrainer(\n",
303+
" sagemaker_session=sagemaker_session,\n",
300304
" training_image=PYTORCH_TRAINING_IMAGE,\n",
301305
" source_code=SourceCode(\n",
302306
" source_dir=training_code_dir,\n",
@@ -333,22 +337,29 @@
333337
"from mlflow import MlflowClient\n",
334338
"\n",
335339
"client = MlflowClient()\n",
336-
"registered_model = client.get_registered_model(name=MLFLOW_REGISTERED_MODEL_NAME)\n",
337340
"\n",
338-
"latest_version = registered_model.latest_versions[0]\n",
341+
"# Use search_model_versions (compatible with MLflow 3.x)\n",
342+
"# Note: latest_versions attribute was removed in MLflow 3.x\n",
343+
"model_versions = client.search_model_versions(\n",
344+
" filter_string=f\"name='{MLFLOW_REGISTERED_MODEL_NAME}'\",\n",
345+
" order_by=['version_number DESC'],\n",
346+
" max_results=1\n",
347+
")\n",
348+
"\n",
349+
"if not model_versions:\n",
350+
" raise ValueError(f\"No versions found for model '{MLFLOW_REGISTERED_MODEL_NAME}'\")\n",
351+
"\n",
352+
"latest_version = model_versions[0]\n",
339353
"model_version = latest_version.version\n",
340354
"model_source = latest_version.source\n",
341355
"\n",
342-
"# Get S3 URL of model files (for info only)\n",
343-
"artifact_uri = client.get_model_version_download_uri(MLFLOW_REGISTERED_MODEL_NAME, model_version)\n",
344-
"\n",
345356
"# MLflow model registry path to use with ModelBuilder\n",
346357
"mlflow_model_path = f\"models:/{MLFLOW_REGISTERED_MODEL_NAME}/{model_version}\"\n",
347358
"\n",
348359
"print(f\"Registered Model: {MLFLOW_REGISTERED_MODEL_NAME}\")\n",
349360
"print(f\"Latest Version: {model_version}\")\n",
350-
"print(f\"Source: {model_source}\")\n",
351-
"print(f\"Model artifacts location: {artifact_uri}\")"
361+
"print(f\"Source (artifact location): {model_source}\")\n",
362+
"print(f\"MLflow model path for deployment: {mlflow_model_path}\")"
352363
]
353364
},
354365
{
@@ -481,23 +492,44 @@
481492
"metadata": {},
482493
"outputs": [],
483494
"source": [
484-
"import boto3\n",
485-
"\n",
486495
"# Test with JSON input\n",
487496
"test_data = [[0.1, 0.2, 0.3, 0.4]]\n",
488497
"\n",
489-
"runtime_client = boto3.client('sagemaker-runtime')\n",
490-
"response = runtime_client.invoke_endpoint(\n",
491-
" EndpointName=core_endpoint.endpoint_name,\n",
492-
" Body=json.dumps(test_data),\n",
493-
" ContentType='application/json'\n",
498+
"result = core_endpoint.invoke(\n",
499+
" body=json.dumps(test_data),\n",
500+
" content_type='application/json'\n",
494501
")\n",
495502
"\n",
496-
"prediction = json.loads(response['Body'].read().decode('utf-8'))\n",
503+
"# Decode and display the result\n",
504+
"prediction = json.loads(result.body.read().decode('utf-8'))\n",
497505
"print(f\"Input: {test_data}\")\n",
498506
"print(f\"Prediction: {prediction}\")"
499507
]
500508
},
509+
{
510+
"cell_type": "code",
511+
"execution_count": null,
512+
"metadata": {},
513+
"outputs": [],
514+
"source": [
515+
"# Test with different tensor inputs\n",
516+
"test_inputs = [\n",
517+
" [[0.5, 0.3, 0.2, 0.1]],\n",
518+
" [[0.9, 0.1, 0.8, 0.2]],\n",
519+
" [[0.2, 0.7, 0.4, 0.6]]\n",
520+
"]\n",
521+
"\n",
522+
"for i, test_input in enumerate(test_inputs, 1):\n",
523+
" result = core_endpoint.invoke(\n",
524+
" body=json.dumps(test_input),\n",
525+
" content_type='application/json'\n",
526+
" )\n",
527+
" \n",
528+
" prediction = json.loads(result.body.read().decode('utf-8'))\n",
529+
" print(f\"Test {i} - Input {test_input}: {prediction}\")\n",
530+
" print(\"-\" * 50)"
531+
]
532+
},
501533
{
502534
"cell_type": "markdown",
503535
"metadata": {},

0 commit comments

Comments
 (0)