-
Notifications
You must be signed in to change notification settings - Fork 1.3k
feature: MLFlow E2E Example Notebook (5513) #5701
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -35,6 +35,7 @@ | |
| "outputs": [], | ||
| "source": [ | ||
| "# Install fix for MLflow path resolution issues\n", | ||
|
aviruthen marked this conversation as resolved.
|
||
| "# Note: Version pins (mlflow==3.4.0) may need updating - check for latest compatible versions\n", | ||
| "%pip install mlflow==3.4.0" | ||
| ] | ||
| }, | ||
|
|
@@ -60,7 +61,9 @@ | |
| "metadata": {}, | ||
| "outputs": [], | ||
| "source": [ | ||
| "import json\n", | ||
|
aviruthen marked this conversation as resolved.
|
||
| "import uuid\n", | ||
| "import boto3\n", | ||
| "from sagemaker.core import image_uris\n", | ||
| "from sagemaker.core.helper.session_helper import Session\n", | ||
| "\n", | ||
|
|
@@ -71,7 +74,8 @@ | |
| "MLFLOW_TRACKING_ARN = \"XXXXX\"\n", | ||
| "\n", | ||
| "# AWS Configuration\n", | ||
| "AWS_REGION = Session.boto_region_name\n", | ||
| "sagemaker_session = Session()\n", | ||
| "AWS_REGION = sagemaker_session.boto_region_name\n", | ||
| "\n", | ||
| "# Get PyTorch training image dynamically\n", | ||
| "PYTORCH_TRAINING_IMAGE = image_uris.retrieve(\n", | ||
|
|
@@ -297,6 +301,7 @@ | |
| "\n", | ||
| "# Training on SageMaker managed infrastructure\n", | ||
| "model_trainer = ModelTrainer(\n", | ||
| " sagemaker_session=sagemaker_session,\n", | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Good addition of the explicit `sagemaker_session=sagemaker_session` argument to `ModelTrainer` — presumably this ensures the trainer uses the same session/region as the rest of the notebook. |
||
| " training_image=PYTORCH_TRAINING_IMAGE,\n", | ||
| " source_code=SourceCode(\n", | ||
| " source_dir=training_code_dir,\n", | ||
|
|
@@ -333,9 +338,14 @@ | |
| "from mlflow import MlflowClient\n", | ||
| "\n", | ||
| "client = MlflowClient()\n", | ||
| "registered_model = client.get_registered_model(name=MLFLOW_REGISTERED_MODEL_NAME)\n", | ||
| "\n", | ||
| "latest_version = registered_model.latest_versions[0]\n", | ||
| "# Use search_model_versions (compatible with MLflow 3.x)\n", | ||
| "model_versions = client.search_model_versions(\n", | ||
| " filter_string=f\"name='{MLFLOW_REGISTERED_MODEL_NAME}'\",\n", | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Minor: consider guarding against an empty result before indexing, e.g.:

```python
if not model_versions:
    raise RuntimeError(f"No model versions found for '{MLFLOW_REGISTERED_MODEL_NAME}'")
```

This would give users a clear error message instead of an unexplained `IndexError`. |
||
| " order_by=['version_number DESC'],\n", | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The |
||
| " max_results=1\n", | ||
| ")\n", | ||
| "latest_version = model_versions[0]\n", | ||
| "model_version = latest_version.version\n", | ||
| "model_source = latest_version.source\n", | ||
| "\n", | ||
|
|
@@ -366,54 +376,23 @@ | |
| "metadata": {}, | ||
| "outputs": [], | ||
| "source": [ | ||
| "import json\n", | ||
| "import torch\n", | ||
| "from sagemaker.serve.marshalling.custom_payload_translator import CustomPayloadTranslator\n", | ||
| "from sagemaker.serve.builder.schema_builder import SchemaBuilder\n", | ||
| "\n", | ||
| "# =============================================================================\n", | ||
| "# Custom translators for PyTorch tensor conversion\n", | ||
| "# \n", | ||
| "# PyTorch models expect tensors, but SageMaker endpoints communicate via JSON.\n", | ||
| "# These translators handle the conversion between JSON payloads and PyTorch tensors.\n", | ||
| "# Schema Builder for MLflow Model\n", | ||
| "#\n", | ||
| "# When deploying from MLflow Model Registry, the MLflow pyfunc wrapper handles\n", | ||
| "# serialization/deserialization automatically. We only need to provide sample\n", | ||
| "# input/output for schema inference - no custom translators needed.\n", | ||
| "# =============================================================================\n", | ||
| "\n", | ||
| "class PyTorchInputTranslator(CustomPayloadTranslator):\n", | ||
| " \"\"\"Handles input serialization/deserialization for PyTorch models.\"\"\"\n", | ||
| " def __init__(self):\n", | ||
| " super().__init__(content_type='application/json', accept_type='application/json')\n", | ||
| " \n", | ||
| " def serialize_payload_to_bytes(self, payload: object) -> bytes:\n", | ||
| " if isinstance(payload, torch.Tensor):\n", | ||
| " return json.dumps(payload.tolist()).encode('utf-8')\n", | ||
| " return json.dumps(payload).encode('utf-8')\n", | ||
| " \n", | ||
| " def deserialize_payload_from_stream(self, stream) -> object:\n", | ||
| " data = json.load(stream)\n", | ||
| " return torch.tensor(data, dtype=torch.float32)\n", | ||
| "\n", | ||
| "class PyTorchOutputTranslator(CustomPayloadTranslator):\n", | ||
| " \"\"\"Handles output serialization/deserialization for PyTorch models.\"\"\"\n", | ||
| " def __init__(self):\n", | ||
| " super().__init__(content_type='application/json', accept_type='application/json')\n", | ||
| " \n", | ||
| " def serialize_payload_to_bytes(self, payload: object) -> bytes:\n", | ||
| " if isinstance(payload, torch.Tensor):\n", | ||
| " return json.dumps(payload.tolist()).encode('utf-8')\n", | ||
| " return json.dumps(payload).encode('utf-8')\n", | ||
| " \n", | ||
| " def deserialize_payload_from_stream(self, stream) -> object:\n", | ||
| " return json.load(stream)\n", | ||
| "\n", | ||
| "# Sample input/output for schema inference\n", | ||
| "sample_input = [[0.1, 0.2, 0.3, 0.4]]\n", | ||
| "sample_output = [[0.8, 0.2]]\n", | ||
| "\n", | ||
| "schema_builder = SchemaBuilder(\n", | ||
| " sample_input=sample_input,\n", | ||
| " sample_output=sample_output,\n", | ||
| " input_translator=PyTorchInputTranslator(),\n", | ||
| " output_translator=PyTorchOutputTranslator()\n", | ||
| " sample_output=sample_output\n", | ||
| ")" | ||
| ] | ||
| }, | ||
|
|
@@ -481,19 +460,15 @@ | |
| "metadata": {}, | ||
| "outputs": [], | ||
| "source": [ | ||
| "import boto3\n", | ||
| "\n", | ||
| "# Test with JSON input\n", | ||
| "test_data = [[0.1, 0.2, 0.3, 0.4]]\n", | ||
| "\n", | ||
| "runtime_client = boto3.client('sagemaker-runtime')\n", | ||
| "response = runtime_client.invoke_endpoint(\n", | ||
| " EndpointName=core_endpoint.endpoint_name,\n", | ||
| " Body=json.dumps(test_data),\n", | ||
| " ContentType='application/json'\n", | ||
| "result = core_endpoint.invoke(\n", | ||
| " body=json.dumps(test_data),\n", | ||
| " content_type='application/json'\n", | ||
|
aviruthen marked this conversation as resolved.
|
||
| ")\n", | ||
| "\n", | ||
| "prediction = json.loads(response['Body'].read().decode('utf-8'))\n", | ||
| "prediction = json.loads(result.body.read().decode('utf-8'))\n", | ||
| "print(f\"Input: {test_data}\")\n", | ||
| "print(f\"Prediction: {prediction}\")" | ||
| ] | ||
|
|
@@ -551,7 +526,9 @@ | |
| "- `ModelBuilder` with `MLFLOW_MODEL_PATH` - deploy from registry\n", | ||
| "\n", | ||
| "Key patterns:\n", | ||
| "- Custom `PayloadTranslator` classes for PyTorch tensor serialization\n" | ||
| "- MLflow pyfunc handles model serialization automatically\n", | ||
| "- `SchemaBuilder` with sample input/output for schema inference\n", | ||
| "- `core_endpoint.invoke()` for V3-style endpoint invocation\n" | ||
| ] | ||
| } | ||
| ], | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The `%pip install` cell only installs `mlflow==3.4.0`, but the training code's `requirements.txt` (referenced in Step 4) presumably still pins `mlflow==3.4.0`. The PR description mentions this inconsistency, but the diff only shows changes to the notebook cells and the `dependencies` dict in `ModelBuilder`. If `requirements.txt` is a separate file in the repo, it should also be updated to use `>=` constraints for consistency. Could you confirm whether `requirements.txt` is part of this PR or needs a separate change?