Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@
"TRAINING_JOB_PREFIX = \"e2e-v3-pytorch\"\n",
"\n",
"# AWS Configuration\n",
"AWS_REGION = Session.boto_region_name\n",
"AWS_REGION = Session().boto_region_name\n",
"PYTORCH_TRAINING_IMAGE = f\"763104351884.dkr.ecr.{AWS_REGION}.amazonaws.com/pytorch-training:1.13.1-cpu-py39\"\n",
"\n",
"# Generate unique identifiers\n",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@
"MLFLOW_TRACKING_ARN = \"XXXXX\"\n",
"\n",
"# AWS Configuration\n",
"AWS_REGION = Session.boto_region_name\n",
"AWS_REGION = Session().boto_region_name\n",
"\n",
"# Get PyTorch training image dynamically\n",
"PYTORCH_TRAINING_IMAGE = image_uris.retrieve(\n",
Expand Down Expand Up @@ -330,25 +330,33 @@
"outputs": [],
"source": [
"# Get the latest version of the registered model\n",
"# NOTE: MLflow 3.x removed `registered_model.latest_versions`. Use\n",
"# `client.search_model_versions()` instead.\n",
"from mlflow import MlflowClient\n",
"\n",
"client = MlflowClient()\n",
"registered_model = client.get_registered_model(name=MLFLOW_REGISTERED_MODEL_NAME)\n",
"\n",
"latest_version = registered_model.latest_versions[0]\n",
"# Search for the latest version of the registered model (MLflow 3.x compatible)\n",
"versions = client.search_model_versions(\n",
" filter_string=f\"name='{MLFLOW_REGISTERED_MODEL_NAME}'\",\n",
" order_by=['version_number DESC'],\n",
" max_results=1\n",
")\n",
"\n",
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor: please verify the `order_by` value `'version_number DESC'` against the MLflow `search_model_versions` API. The set of sortable field names may differ between MLflow versions (for example, `version_number` in some and `creation_timestamp` in others). If this notebook targets MLflow 3.x specifically, please confirm that `version_number` is the correct field name there, so users do not hit a runtime error.

"if not versions:\n",
" raise ValueError(f\"No versions found for model '{MLFLOW_REGISTERED_MODEL_NAME}'\")\n",
"\n",
"latest_version = versions[0]\n",
"model_version = latest_version.version\n",
"model_source = latest_version.source\n",
"\n",
"# Get S3 URL of model files (for info only)\n",
"artifact_uri = client.get_model_version_download_uri(MLFLOW_REGISTERED_MODEL_NAME, model_version)\n",
"\n",
"# MLflow model registry path to use with ModelBuilder\n",
"mlflow_model_path = f\"models:/{MLFLOW_REGISTERED_MODEL_NAME}/{model_version}\"\n",
"\n",
"print(f\"Registered Model: {MLFLOW_REGISTERED_MODEL_NAME}\")\n",
"print(f\"Latest Version: {model_version}\")\n",
"print(f\"Source: {model_source}\")\n",
"print(f\"Model artifacts location: {artifact_uri}\")"
"print(f\"Source (artifact location): {model_source}\")\n",
"print(f\"MLflow model path for deployment: {mlflow_model_path}\")"
]
},
{
Expand Down Expand Up @@ -427,6 +435,8 @@
"from sagemaker.serve.mode.function_pointers import Mode\n",
"\n",
"# Cloud deployment to SageMaker endpoint\n",
"# Note: 'dependencies' parameter is deprecated. You may see a deprecation warning.\n",
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good addition of the deprecation note for the `dependencies` parameter. However, the comment recommends `configure_for_torchserve()` for new projects — could you (1) verify that this is the correct V3 replacement method name, and (2) add a brief code example or a documentation link? As written, users may not know how to apply this guidance.

"# Use configure_for_torchserve() for new projects.\n",
"model_builder = ModelBuilder(\n",
" mode=Mode.SAGEMAKER_ENDPOINT,\n",
" schema_builder=schema_builder,\n",
Expand Down Expand Up @@ -481,23 +491,43 @@
"metadata": {},
"outputs": [],
"source": [
"import boto3\n",
"\n",
"# Test with JSON input\n",
"# Test with JSON input using V3-native endpoint invocation\n",
"test_data = [[0.1, 0.2, 0.3, 0.4]]\n",
"\n",
"runtime_client = boto3.client('sagemaker-runtime')\n",
"response = runtime_client.invoke_endpoint(\n",
" EndpointName=core_endpoint.endpoint_name,\n",
" Body=json.dumps(test_data),\n",
" ContentType='application/json'\n",
"result = core_endpoint.invoke(\n",
" body=json.dumps(test_data),\n",
" content_type='application/json'\n",
")\n",
"\n",
"prediction = json.loads(response['Body'].read().decode('utf-8'))\n",
"prediction = json.loads(result.body.read().decode('utf-8'))\n",
"print(f\"Input: {test_data}\")\n",
"print(f\"Prediction: {prediction}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Test with multiple inputs\n",
"test_inputs = [\n",
" [[0.5, 0.3, 0.2, 0.1]],\n",
" [[0.9, 0.1, 0.8, 0.2]],\n",
" [[0.2, 0.7, 0.4, 0.6]]\n",
"]\n",
"\n",
"for i, test_input in enumerate(test_inputs, 1):\n",
" result = core_endpoint.invoke(\n",
" body=json.dumps(test_input),\n",
" content_type='application/json'\n",
" )\n",
" \n",
" prediction = json.loads(result.body.read().decode('utf-8'))\n",
" print(f\"Test {i} - Input {test_input}: {prediction}\")\n",
" print('-' * 50)"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down Expand Up @@ -551,7 +581,12 @@
"- `ModelBuilder` with `MLFLOW_MODEL_PATH` - deploy from registry\n",
"\n",
"Key patterns:\n",
"- Custom `PayloadTranslator` classes for PyTorch tensor serialization\n"
"- Custom `PayloadTranslator` classes for PyTorch tensor serialization\n",
"- V3-native `core_endpoint.invoke()` for inference\n",
"\n",
"**MLflow 3.x API Note:**\n",
"- Use `client.search_model_versions()` instead of the removed `registered_model.latest_versions` attribute\n",
"- Use `latest_version.source` for artifact location instead of `client.get_model_version_download_uri()`\n"
]
}
],
Expand Down
Loading