Skip to content

Commit 07b513b

Browse files
committed
feature: MLflow E2E Example Notebook (5513)
1 parent e161199 commit 07b513b

File tree

1 file changed

+26
-49
lines changed

1 file changed

+26
-49
lines changed

v3-examples/ml-ops-examples/v3-mlflow-train-inference-e2e-example.ipynb

Lines changed: 26 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
"outputs": [],
3636
"source": [
3737
"# Install fix for MLflow path resolution issues\n",
38+
"# Note: Version pins (mlflow==3.4.0) may need updating - check for latest compatible versions\n",
3839
"%pip install mlflow==3.4.0"
3940
]
4041
},
@@ -60,7 +61,9 @@
6061
"metadata": {},
6162
"outputs": [],
6263
"source": [
64+
"import json\n",
6365
"import uuid\n",
66+
"import boto3\n",
6467
"from sagemaker.core import image_uris\n",
6568
"from sagemaker.core.helper.session_helper import Session\n",
6669
"\n",
@@ -71,7 +74,8 @@
7174
"MLFLOW_TRACKING_ARN = \"XXXXX\"\n",
7275
"\n",
7376
"# AWS Configuration\n",
74-
"AWS_REGION = Session.boto_region_name\n",
77+
"sagemaker_session = Session()\n",
78+
"AWS_REGION = sagemaker_session.boto_region_name\n",
7579
"\n",
7680
"# Get PyTorch training image dynamically\n",
7781
"PYTORCH_TRAINING_IMAGE = image_uris.retrieve(\n",
@@ -297,6 +301,7 @@
297301
"\n",
298302
"# Training on SageMaker managed infrastructure\n",
299303
"model_trainer = ModelTrainer(\n",
304+
" sagemaker_session=sagemaker_session,\n",
300305
" training_image=PYTORCH_TRAINING_IMAGE,\n",
301306
" source_code=SourceCode(\n",
302307
" source_dir=training_code_dir,\n",
@@ -333,9 +338,14 @@
333338
"from mlflow import MlflowClient\n",
334339
"\n",
335340
"client = MlflowClient()\n",
336-
"registered_model = client.get_registered_model(name=MLFLOW_REGISTERED_MODEL_NAME)\n",
337341
"\n",
338-
"latest_version = registered_model.latest_versions[0]\n",
342+
"# Use search_model_versions (compatible with MLflow 3.x)\n",
343+
"model_versions = client.search_model_versions(\n",
344+
" filter_string=f\"name='{MLFLOW_REGISTERED_MODEL_NAME}'\",\n",
345+
" order_by=['version_number DESC'],\n",
346+
" max_results=1\n",
347+
")\n",
348+
"latest_version = model_versions[0]\n",
339349
"model_version = latest_version.version\n",
340350
"model_source = latest_version.source\n",
341351
"\n",
@@ -366,54 +376,23 @@
366376
"metadata": {},
367377
"outputs": [],
368378
"source": [
369-
"import json\n",
370-
"import torch\n",
371-
"from sagemaker.serve.marshalling.custom_payload_translator import CustomPayloadTranslator\n",
372379
"from sagemaker.serve.builder.schema_builder import SchemaBuilder\n",
373380
"\n",
374381
"# =============================================================================\n",
375-
"# Custom translators for PyTorch tensor conversion\n",
376-
"# \n",
377-
"# PyTorch models expect tensors, but SageMaker endpoints communicate via JSON.\n",
378-
"# These translators handle the conversion between JSON payloads and PyTorch tensors.\n",
382+
"# Schema Builder for MLflow Model\n",
383+
"#\n",
384+
"# When deploying from MLflow Model Registry, the MLflow pyfunc wrapper handles\n",
385+
"# serialization/deserialization automatically. We only need to provide sample\n",
386+
"# input/output for schema inference - no custom translators needed.\n",
379387
"# =============================================================================\n",
380388
"\n",
381-
"class PyTorchInputTranslator(CustomPayloadTranslator):\n",
382-
" \"\"\"Handles input serialization/deserialization for PyTorch models.\"\"\"\n",
383-
" def __init__(self):\n",
384-
" super().__init__(content_type='application/json', accept_type='application/json')\n",
385-
" \n",
386-
" def serialize_payload_to_bytes(self, payload: object) -> bytes:\n",
387-
" if isinstance(payload, torch.Tensor):\n",
388-
" return json.dumps(payload.tolist()).encode('utf-8')\n",
389-
" return json.dumps(payload).encode('utf-8')\n",
390-
" \n",
391-
" def deserialize_payload_from_stream(self, stream) -> object:\n",
392-
" data = json.load(stream)\n",
393-
" return torch.tensor(data, dtype=torch.float32)\n",
394-
"\n",
395-
"class PyTorchOutputTranslator(CustomPayloadTranslator):\n",
396-
" \"\"\"Handles output serialization/deserialization for PyTorch models.\"\"\"\n",
397-
" def __init__(self):\n",
398-
" super().__init__(content_type='application/json', accept_type='application/json')\n",
399-
" \n",
400-
" def serialize_payload_to_bytes(self, payload: object) -> bytes:\n",
401-
" if isinstance(payload, torch.Tensor):\n",
402-
" return json.dumps(payload.tolist()).encode('utf-8')\n",
403-
" return json.dumps(payload).encode('utf-8')\n",
404-
" \n",
405-
" def deserialize_payload_from_stream(self, stream) -> object:\n",
406-
" return json.load(stream)\n",
407-
"\n",
408389
"# Sample input/output for schema inference\n",
409390
"sample_input = [[0.1, 0.2, 0.3, 0.4]]\n",
410391
"sample_output = [[0.8, 0.2]]\n",
411392
"\n",
412393
"schema_builder = SchemaBuilder(\n",
413394
" sample_input=sample_input,\n",
414-
" sample_output=sample_output,\n",
415-
" input_translator=PyTorchInputTranslator(),\n",
416-
" output_translator=PyTorchOutputTranslator()\n",
395+
" sample_output=sample_output\n",
417396
")"
418397
]
419398
},
@@ -481,19 +460,15 @@
481460
"metadata": {},
482461
"outputs": [],
483462
"source": [
484-
"import boto3\n",
485-
"\n",
486463
"# Test with JSON input\n",
487464
"test_data = [[0.1, 0.2, 0.3, 0.4]]\n",
488465
"\n",
489-
"runtime_client = boto3.client('sagemaker-runtime')\n",
490-
"response = runtime_client.invoke_endpoint(\n",
491-
" EndpointName=core_endpoint.endpoint_name,\n",
492-
" Body=json.dumps(test_data),\n",
493-
" ContentType='application/json'\n",
466+
"result = core_endpoint.invoke(\n",
467+
" body=json.dumps(test_data),\n",
468+
" content_type='application/json'\n",
494469
")\n",
495470
"\n",
496-
"prediction = json.loads(response['Body'].read().decode('utf-8'))\n",
471+
"prediction = json.loads(result.body.read().decode('utf-8'))\n",
497472
"print(f\"Input: {test_data}\")\n",
498473
"print(f\"Prediction: {prediction}\")"
499474
]
@@ -551,7 +526,9 @@
551526
"- `ModelBuilder` with `MLFLOW_MODEL_PATH` - deploy from registry\n",
552527
"\n",
553528
"Key patterns:\n",
554-
"- Custom `PayloadTranslator` classes for PyTorch tensor serialization\n"
529+
"- MLflow pyfunc handles model serialization automatically\n",
530+
"- `SchemaBuilder` with sample input/output for schema inference\n",
531+
"- `core_endpoint.invoke()` for V3-style endpoint invocation\n"
555532
]
556533
}
557534
],

0 commit comments

Comments
 (0)