Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@
"TRAINING_JOB_PREFIX = \"e2e-v3-pytorch\"\n",
"\n",
"# AWS Configuration\n",
"AWS_REGION = Session.boto_region_name\n",
"AWS_REGION = Session().boto_region_name\n",
"PYTORCH_TRAINING_IMAGE = f\"763104351884.dkr.ecr.{AWS_REGION}.amazonaws.com/pytorch-training:1.13.1-cpu-py39\"\n",
"\n",
"# Generate unique identifiers\n",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@
"MLFLOW_TRACKING_ARN = \"XXXXX\"\n",
"\n",
"# AWS Configuration\n",
"AWS_REGION = Session.boto_region_name\n",
"AWS_REGION = Session().boto_region_name\n",
"\n",
"# Get PyTorch training image dynamically\n",
"PYTORCH_TRAINING_IMAGE = image_uris.retrieve(\n",
Expand Down Expand Up @@ -330,25 +330,33 @@
"outputs": [],
"source": [
"# Get the latest version of the registered model\n",
"# NOTE: MLflow 3.x removed `registered_model.latest_versions`. Use\n",
"# `client.search_model_versions()` instead.\n",
"from mlflow import MlflowClient\n",
"\n",
"client = MlflowClient()\n",
"registered_model = client.get_registered_model(name=MLFLOW_REGISTERED_MODEL_NAME)\n",
"\n",
"latest_version = registered_model.latest_versions[0]\n",
"# Search for the latest version of the registered model (MLflow 3.x compatible)\n",
"versions = client.search_model_versions(\n",
" filter_string=f\"name='{MLFLOW_REGISTERED_MODEL_NAME}'\",\n",
" order_by=['version_number DESC'],\n",
" max_results=1\n",
")\n",
"\n",
"if not versions:\n",
" raise ValueError(f\"No versions found for model '{MLFLOW_REGISTERED_MODEL_NAME}'\")\n",
"\n",
"latest_version = versions[0]\n",
"model_version = latest_version.version\n",
"model_source = latest_version.source\n",
"\n",
"# Get S3 URL of model files (for info only)\n",
"artifact_uri = client.get_model_version_download_uri(MLFLOW_REGISTERED_MODEL_NAME, model_version)\n",
"\n",
"# MLflow model registry path to use with ModelBuilder\n",
"mlflow_model_path = f\"models:/{MLFLOW_REGISTERED_MODEL_NAME}/{model_version}\"\n",
"\n",
"print(f\"Registered Model: {MLFLOW_REGISTERED_MODEL_NAME}\")\n",
"print(f\"Latest Version: {model_version}\")\n",
"print(f\"Source: {model_source}\")\n",
"print(f\"Model artifacts location: {artifact_uri}\")"
"print(f\"Source (artifact location): {model_source}\")\n",
"print(f\"MLflow model path for deployment: {mlflow_model_path}\")"
]
},
{
Expand Down Expand Up @@ -427,6 +435,8 @@
"from sagemaker.serve.mode.function_pointers import Mode\n",
"\n",
"# Cloud deployment to SageMaker endpoint\n",
"# Note: 'dependencies' parameter is deprecated. You may see a deprecation warning.\n",
"# Use configure_for_torchserve() for new projects.\n",
"model_builder = ModelBuilder(\n",
" mode=Mode.SAGEMAKER_ENDPOINT,\n",
" schema_builder=schema_builder,\n",
Expand Down Expand Up @@ -481,23 +491,43 @@
"metadata": {},
"outputs": [],
"source": [
"import boto3\n",
"\n",
"# Test with JSON input\n",
"# Test with JSON input using V3-native endpoint invocation\n",
"test_data = [[0.1, 0.2, 0.3, 0.4]]\n",
"\n",
"runtime_client = boto3.client('sagemaker-runtime')\n",
"response = runtime_client.invoke_endpoint(\n",
" EndpointName=core_endpoint.endpoint_name,\n",
" Body=json.dumps(test_data),\n",
" ContentType='application/json'\n",
"result = core_endpoint.invoke(\n",
" body=json.dumps(test_data),\n",
" content_type='application/json'\n",
")\n",
"\n",
"prediction = json.loads(response['Body'].read().decode('utf-8'))\n",
"prediction = json.loads(result.body.read().decode('utf-8'))\n",
"print(f\"Input: {test_data}\")\n",
"print(f\"Prediction: {prediction}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Test with multiple inputs\n",
"test_inputs = [\n",
" [[0.5, 0.3, 0.2, 0.1]],\n",
" [[0.9, 0.1, 0.8, 0.2]],\n",
" [[0.2, 0.7, 0.4, 0.6]]\n",
"]\n",
"\n",
"for i, test_input in enumerate(test_inputs, 1):\n",
" result = core_endpoint.invoke(\n",
" body=json.dumps(test_input),\n",
" content_type='application/json'\n",
" )\n",
" \n",
" prediction = json.loads(result.body.read().decode('utf-8'))\n",
" print(f\"Test {i} - Input {test_input}: {prediction}\")\n",
" print('-' * 50)"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down Expand Up @@ -551,7 +581,12 @@
"- `ModelBuilder` with `MLFLOW_MODEL_PATH` - deploy from registry\n",
"\n",
"Key patterns:\n",
"- Custom `PayloadTranslator` classes for PyTorch tensor serialization\n"
"- Custom `PayloadTranslator` classes for PyTorch tensor serialization\n",
"- V3-native `core_endpoint.invoke()` for inference\n",
"\n",
"**MLflow 3.x API Note:**\n",
"- Use `client.search_model_versions()` instead of the removed `registered_model.latest_versions` attribute\n",
"- Use `latest_version.source` for artifact location instead of `client.get_model_version_download_uri()`\n"
]
}
],
Expand Down
Loading