|
35 | 35 | "outputs": [], |
36 | 36 | "source": [ |
37 | 37 | "# Install fix for MLflow path resolution issues\n", |
| 38 | + "# Note: Version pins (mlflow==3.4.0) may need updating - check for latest compatible versions\n", |
38 | 39 | "%pip install mlflow==3.4.0" |
39 | 40 | ] |
40 | 41 | }, |
|
60 | 61 | "metadata": {}, |
61 | 62 | "outputs": [], |
62 | 63 | "source": [ |
| 64 | + "import json\n", |
63 | 65 | "import uuid\n", |
| 66 | + "import boto3\n", |
64 | 67 | "from sagemaker.core import image_uris\n", |
65 | 68 | "from sagemaker.core.helper.session_helper import Session\n", |
66 | 69 | "\n", |
|
71 | 74 | "MLFLOW_TRACKING_ARN = \"XXXXX\"\n", |
72 | 75 | "\n", |
73 | 76 | "# AWS Configuration\n", |
74 | | - "AWS_REGION = Session.boto_region_name\n", |
| 77 | + "sagemaker_session = Session()\n", |
| 78 | + "AWS_REGION = sagemaker_session.boto_region_name\n", |
75 | 79 | "\n", |
76 | 80 | "# Get PyTorch training image dynamically\n", |
77 | 81 | "PYTORCH_TRAINING_IMAGE = image_uris.retrieve(\n", |
|
297 | 301 | "\n", |
298 | 302 | "# Training on SageMaker managed infrastructure\n", |
299 | 303 | "model_trainer = ModelTrainer(\n", |
| 304 | + " sagemaker_session=sagemaker_session,\n", |
300 | 305 | " training_image=PYTORCH_TRAINING_IMAGE,\n", |
301 | 306 | " source_code=SourceCode(\n", |
302 | 307 | " source_dir=training_code_dir,\n", |
|
333 | 338 | "from mlflow import MlflowClient\n", |
334 | 339 | "\n", |
335 | 340 | "client = MlflowClient()\n", |
336 | | - "registered_model = client.get_registered_model(name=MLFLOW_REGISTERED_MODEL_NAME)\n", |
337 | 341 | "\n", |
338 | | - "latest_version = registered_model.latest_versions[0]\n", |
| 342 | + "# Use search_model_versions (compatible with MLflow 3.x)\n", |
| 343 | + "model_versions = client.search_model_versions(\n", |
| 344 | + " filter_string=f\"name='{MLFLOW_REGISTERED_MODEL_NAME}'\",\n", |
| 345 | + " order_by=['version_number DESC'],\n", |
| 346 | + " max_results=1\n", |
| 347 | + ")\n", |
| 348 | + "latest_version = model_versions[0]\n", |
339 | 349 | "model_version = latest_version.version\n", |
340 | 350 | "model_source = latest_version.source\n", |
341 | 351 | "\n", |
|
366 | 376 | "metadata": {}, |
367 | 377 | "outputs": [], |
368 | 378 | "source": [ |
369 | | - "import json\n", |
370 | | - "import torch\n", |
371 | | - "from sagemaker.serve.marshalling.custom_payload_translator import CustomPayloadTranslator\n", |
372 | 379 | "from sagemaker.serve.builder.schema_builder import SchemaBuilder\n", |
373 | 380 | "\n", |
374 | 381 | "# =============================================================================\n", |
375 | | - "# Custom translators for PyTorch tensor conversion\n", |
376 | | - "# \n", |
377 | | - "# PyTorch models expect tensors, but SageMaker endpoints communicate via JSON.\n", |
378 | | - "# These translators handle the conversion between JSON payloads and PyTorch tensors.\n", |
| 382 | + "# Schema Builder for MLflow Model\n", |
| 383 | + "#\n", |
| 384 | + "# When deploying from MLflow Model Registry, the MLflow pyfunc wrapper handles\n", |
| 385 | + "# serialization/deserialization automatically. We only need to provide sample\n", |
| 386 | + "# input/output for schema inference - no custom translators needed.\n", |
379 | 387 | "# =============================================================================\n", |
380 | 388 | "\n", |
381 | | - "class PyTorchInputTranslator(CustomPayloadTranslator):\n", |
382 | | - " \"\"\"Handles input serialization/deserialization for PyTorch models.\"\"\"\n", |
383 | | - " def __init__(self):\n", |
384 | | - " super().__init__(content_type='application/json', accept_type='application/json')\n", |
385 | | - " \n", |
386 | | - " def serialize_payload_to_bytes(self, payload: object) -> bytes:\n", |
387 | | - " if isinstance(payload, torch.Tensor):\n", |
388 | | - " return json.dumps(payload.tolist()).encode('utf-8')\n", |
389 | | - " return json.dumps(payload).encode('utf-8')\n", |
390 | | - " \n", |
391 | | - " def deserialize_payload_from_stream(self, stream) -> object:\n", |
392 | | - " data = json.load(stream)\n", |
393 | | - " return torch.tensor(data, dtype=torch.float32)\n", |
394 | | - "\n", |
395 | | - "class PyTorchOutputTranslator(CustomPayloadTranslator):\n", |
396 | | - " \"\"\"Handles output serialization/deserialization for PyTorch models.\"\"\"\n", |
397 | | - " def __init__(self):\n", |
398 | | - " super().__init__(content_type='application/json', accept_type='application/json')\n", |
399 | | - " \n", |
400 | | - " def serialize_payload_to_bytes(self, payload: object) -> bytes:\n", |
401 | | - " if isinstance(payload, torch.Tensor):\n", |
402 | | - " return json.dumps(payload.tolist()).encode('utf-8')\n", |
403 | | - " return json.dumps(payload).encode('utf-8')\n", |
404 | | - " \n", |
405 | | - " def deserialize_payload_from_stream(self, stream) -> object:\n", |
406 | | - " return json.load(stream)\n", |
407 | | - "\n", |
408 | 389 | "# Sample input/output for schema inference\n", |
409 | 390 | "sample_input = [[0.1, 0.2, 0.3, 0.4]]\n", |
410 | 391 | "sample_output = [[0.8, 0.2]]\n", |
411 | 392 | "\n", |
412 | 393 | "schema_builder = SchemaBuilder(\n", |
413 | 394 | " sample_input=sample_input,\n", |
414 | | - " sample_output=sample_output,\n", |
415 | | - " input_translator=PyTorchInputTranslator(),\n", |
416 | | - " output_translator=PyTorchOutputTranslator()\n", |
| 395 | + " sample_output=sample_output\n", |
417 | 396 | ")" |
418 | 397 | ] |
419 | 398 | }, |
|
481 | 460 | "metadata": {}, |
482 | 461 | "outputs": [], |
483 | 462 | "source": [ |
484 | | - "import boto3\n", |
485 | | - "\n", |
486 | 463 | "# Test with JSON input\n", |
487 | 464 | "test_data = [[0.1, 0.2, 0.3, 0.4]]\n", |
488 | 465 | "\n", |
489 | | - "runtime_client = boto3.client('sagemaker-runtime')\n", |
490 | | - "response = runtime_client.invoke_endpoint(\n", |
491 | | - " EndpointName=core_endpoint.endpoint_name,\n", |
492 | | - " Body=json.dumps(test_data),\n", |
493 | | - " ContentType='application/json'\n", |
| 466 | + "result = core_endpoint.invoke(\n", |
| 467 | + " body=json.dumps(test_data),\n", |
| 468 | + " content_type='application/json'\n", |
494 | 469 | ")\n", |
495 | 470 | "\n", |
496 | | - "prediction = json.loads(response['Body'].read().decode('utf-8'))\n", |
| 471 | + "prediction = json.loads(result.body.read().decode('utf-8'))\n", |
497 | 472 | "print(f\"Input: {test_data}\")\n", |
498 | 473 | "print(f\"Prediction: {prediction}\")" |
499 | 474 | ] |
|
551 | 526 | "- `ModelBuilder` with `MLFLOW_MODEL_PATH` - deploy from registry\n", |
552 | 527 | "\n", |
553 | 528 | "Key patterns:\n", |
554 | | - "- Custom `PayloadTranslator` classes for PyTorch tensor serialization\n" |
| 529 | + "- MLflow pyfunc handles model serialization automatically\n", |
| 530 | + "- `SchemaBuilder` with sample input/output for schema inference\n", |
| 531 | + "- `core_endpoint.invoke()` for V3-style endpoint invocation\n" |
555 | 532 | ] |
556 | 533 | } |
557 | 534 | ], |
|
0 commit comments