Skip to content
Closed
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
"outputs": [],
"source": [
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The %pip install cell only installs mlflow>=3.4.0, but the training code's requirements.txt (referenced in Step 4) presumably still pins mlflow==3.4.0. The PR description mentions this inconsistency, but the diff only shows changes to the notebook cells and the dependencies dict in ModelBuilder. If requirements.txt is a separate file in the repo, it should also be updated to use >= constraints for consistency. Could you confirm whether requirements.txt is part of this PR or needs a separate change?

"# Install fix for MLflow path resolution issues\n",
Comment thread
aviruthen marked this conversation as resolved.
"# Note: Version pins (mlflow==3.4.0) may need updating - check for latest compatible versions\n",
"%pip install mlflow==3.4.0"
]
},
Expand All @@ -60,7 +61,9 @@
"metadata": {},
"outputs": [],
"source": [
"import json\n",
Comment thread
aviruthen marked this conversation as resolved.
"import uuid\n",
"import boto3\n",
"from sagemaker.core import image_uris\n",
"from sagemaker.core.helper.session_helper import Session\n",
"\n",
Expand All @@ -71,7 +74,8 @@
"MLFLOW_TRACKING_ARN = \"XXXXX\"\n",
"\n",
"# AWS Configuration\n",
"AWS_REGION = Session.boto_region_name\n",
"sagemaker_session = Session()\n",
"AWS_REGION = sagemaker_session.boto_region_name\n",
"\n",
"# Get PyTorch training image dynamically\n",
"PYTORCH_TRAINING_IMAGE = image_uris.retrieve(\n",
Expand Down Expand Up @@ -297,6 +301,7 @@
"\n",
"# Training on SageMaker managed infrastructure\n",
"model_trainer = ModelTrainer(\n",
" sagemaker_session=sagemaker_session,\n",
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good addition of sagemaker_session=sagemaker_session. However, the SDK convention is to pass sagemaker_session as the last argument in constructor calls, since optional session parameters come at the end. Consider moving it after the other parameters for consistency with the other V3 example notebooks, though this is a minor style point for a notebook.

" training_image=PYTORCH_TRAINING_IMAGE,\n",
" source_code=SourceCode(\n",
" source_dir=training_code_dir,\n",
Expand Down Expand Up @@ -333,9 +338,14 @@
"from mlflow import MlflowClient\n",
"\n",
"client = MlflowClient()\n",
"registered_model = client.get_registered_model(name=MLFLOW_REGISTERED_MODEL_NAME)\n",
"\n",
"latest_version = registered_model.latest_versions[0]\n",
"# Use search_model_versions (compatible with MLflow 3.x)\n",
"model_versions = client.search_model_versions(\n",
" filter_string=f\"name='{MLFLOW_REGISTERED_MODEL_NAME}'\",\n",
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor: the max_results=1 line is missing a trailing comma before the closing parenthesis, and the list ['creation_timestamp DESC'] uses single quotes inside a double-quoted JSON string — this is fine for Python but worth noting for consistency. More importantly, consider adding a guard for the case where model_versions is empty (no versions registered yet), e.g.:

if not model_versions:
    raise RuntimeError(f"No model versions found for '{MLFLOW_REGISTERED_MODEL_NAME}'")

This would give users a clear error message instead of an IndexError.

" order_by=['version_number DESC'],\n",
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please verify that order_by=['version_number DESC'] uses the correct field name for MLflow 3.x's search_model_versions. In some MLflow versions the field is version_number, but in others it may be creation_timestamp or just version. If the field name is incorrect, the notebook will fail at runtime. The MLflow 3.x docs indicate that order_by supports "version_number DESC", but it's worth double-checking against the version pinned in the install cell.

" max_results=1\n",
")\n",
"latest_version = model_versions[0]\n",
"model_version = latest_version.version\n",
"model_source = latest_version.source\n",
"\n",
Expand Down Expand Up @@ -366,54 +376,23 @@
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"import torch\n",
"from sagemaker.serve.marshalling.custom_payload_translator import CustomPayloadTranslator\n",
"from sagemaker.serve.builder.schema_builder import SchemaBuilder\n",
"\n",
"# =============================================================================\n",
"# Custom translators for PyTorch tensor conversion\n",
"# \n",
"# PyTorch models expect tensors, but SageMaker endpoints communicate via JSON.\n",
"# These translators handle the conversion between JSON payloads and PyTorch tensors.\n",
"# Schema Builder for MLflow Model\n",
"#\n",
"# When deploying from MLflow Model Registry, the MLflow pyfunc wrapper handles\n",
"# serialization/deserialization automatically. We only need to provide sample\n",
"# input/output for schema inference - no custom translators needed.\n",
"# =============================================================================\n",
"\n",
"class PyTorchInputTranslator(CustomPayloadTranslator):\n",
" \"\"\"Handles input serialization/deserialization for PyTorch models.\"\"\"\n",
" def __init__(self):\n",
" super().__init__(content_type='application/json', accept_type='application/json')\n",
" \n",
" def serialize_payload_to_bytes(self, payload: object) -> bytes:\n",
" if isinstance(payload, torch.Tensor):\n",
" return json.dumps(payload.tolist()).encode('utf-8')\n",
" return json.dumps(payload).encode('utf-8')\n",
" \n",
" def deserialize_payload_from_stream(self, stream) -> object:\n",
" data = json.load(stream)\n",
" return torch.tensor(data, dtype=torch.float32)\n",
"\n",
"class PyTorchOutputTranslator(CustomPayloadTranslator):\n",
" \"\"\"Handles output serialization/deserialization for PyTorch models.\"\"\"\n",
" def __init__(self):\n",
" super().__init__(content_type='application/json', accept_type='application/json')\n",
" \n",
" def serialize_payload_to_bytes(self, payload: object) -> bytes:\n",
" if isinstance(payload, torch.Tensor):\n",
" return json.dumps(payload.tolist()).encode('utf-8')\n",
" return json.dumps(payload).encode('utf-8')\n",
" \n",
" def deserialize_payload_from_stream(self, stream) -> object:\n",
" return json.load(stream)\n",
"\n",
"# Sample input/output for schema inference\n",
"sample_input = [[0.1, 0.2, 0.3, 0.4]]\n",
"sample_output = [[0.8, 0.2]]\n",
"\n",
"schema_builder = SchemaBuilder(\n",
" sample_input=sample_input,\n",
" sample_output=sample_output,\n",
" input_translator=PyTorchInputTranslator(),\n",
" output_translator=PyTorchOutputTranslator()\n",
" sample_output=sample_output\n",
")"
]
},
Expand Down Expand Up @@ -481,19 +460,15 @@
"metadata": {},
"outputs": [],
"source": [
"import boto3\n",
"\n",
"# Test with JSON input\n",
"test_data = [[0.1, 0.2, 0.3, 0.4]]\n",
"\n",
"runtime_client = boto3.client('sagemaker-runtime')\n",
"response = runtime_client.invoke_endpoint(\n",
" EndpointName=core_endpoint.endpoint_name,\n",
" Body=json.dumps(test_data),\n",
" ContentType='application/json'\n",
"result = core_endpoint.invoke(\n",
" body=json.dumps(test_data),\n",
" content_type='application/json'\n",
Comment thread
aviruthen marked this conversation as resolved.
")\n",
"\n",
"prediction = json.loads(response['Body'].read().decode('utf-8'))\n",
"prediction = json.loads(result.body.read().decode('utf-8'))\n",
"print(f\"Input: {test_data}\")\n",
"print(f\"Prediction: {prediction}\")"
]
Expand Down Expand Up @@ -551,7 +526,9 @@
"- `ModelBuilder` with `MLFLOW_MODEL_PATH` - deploy from registry\n",
"\n",
"Key patterns:\n",
"- Custom `PayloadTranslator` classes for PyTorch tensor serialization\n"
"- MLflow pyfunc handles model serialization automatically\n",
"- `SchemaBuilder` with sample input/output for schema inference\n",
"- `core_endpoint.invoke()` for V3-style endpoint invocation\n"
]
}
],
Expand Down
Loading