Skip to content

Commit b86c2ac

Browse files
authored
## fix: resolve MLflow app discovery issues (#5924)
* fix: bypass SageMakerClient singleton for cross-region model package resolution The SageMakerClient singleton caches the first region it is initialized with and ignores subsequent region parameters. This causes Nova integ tests (which run in us-east-1) to fail when the singleton was already created with us-west-2 by an earlier test in the same process. Errors observed: - ModelPackageGroup arn:aws:sagemaker:us-west-2:784379639078:model-package-group/sdk-test-finetuned-models does not exist - DescribeModelPackage: ARN should be scoped to correct region: us-west-2 Fix: use session.boto_session.client("sagemaker") directly instead of ModelPackageGroup.get() / ModelPackage.get() in the three call sites that resolve model package resources. This respects the session's actual region without depending on the singleton's cached state. * test: update unit tests * fix: handle missing pipeline version context in lineage update _update_pipeline_lineage assumed the version context always exists. When it's been deleted or never created (e.g. prior run failure), DescribeContext throws ResourceNotFound. Now catches the error and recreates the version context with proper associations. * fix(test): add mlflow_resource_arn fixture that auto-discovers or creates app Replace hard-coded MLflow app ARN with a conftest fixture that finds an existing ready app or creates a temporary one (cleaned up after tests). Prevents failures when the hard-coded app is deleted or quota is full. X-AI-Prompt: add self-healing mlflow fixture for llm_as_judge integ tests X-AI-Tool: kiro-cli * fix(test): use correct response key "Summaries" for list_mlflow_apps API * mark two slow tests as not serial * fix: use correct response key "Summaries" in _resolve_mlflow_resource_arn * replace not-existing mlflow app * refactor: use session.sagemaker_client instead of boto_session.client Per SDK coding standards, avoid calling boto3 directly. Use the session's sagemaker_client attribute which already has the correct region bound at session creation time. * revert: remove SageMakerClient singleton bypass from feature code * test: mark TestLLMAsJudgeBaseModelFix as serial Tests share the same pipeline definition and conflict when run in parallel (Pipeline has been modified since your last read). X-AI-Prompt: mark llm_as_judge_base_model_fix as serial X-AI-Tool: kiro-cli
1 parent 97a4afe commit b86c2ac

5 files changed

Lines changed: 130 additions & 11 deletions

File tree

sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_feature_processor_lineage.py

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -359,9 +359,39 @@ def _update_pipeline_lineage(
359359

360360
# If pipeline lineage exists then determine whether to create a new version.
361361
pipeline_context: Context = self._get_pipeline_context()
362-
current_pipeline_version_context: Context = self._get_pipeline_version_context(
363-
last_update_time=pipeline_context.properties[LAST_UPDATE_TIME]
364-
)
362+
try:
363+
current_pipeline_version_context: Context = self._get_pipeline_version_context(
364+
last_update_time=pipeline_context.properties[LAST_UPDATE_TIME]
365+
)
366+
except ClientError as e:
367+
if e.response[ERROR][CODE] == RESOURCE_NOT_FOUND:
368+
# Pipeline version context does not exist (possibly deleted or never created).
369+
# Create a new pipeline version context and its associations.
370+
logger.info(
371+
"Pipeline version context not found. Creating new pipeline version lineage."
372+
)
373+
pipeline_context.properties["LastUpdateTime"] = self.pipeline[
374+
LAST_MODIFIED_TIME
375+
].strftime("%s")
376+
PipelineLineageEntityHandler.update_pipeline_context(
377+
pipeline_context=pipeline_context
378+
)
379+
new_pipeline_version_context: Context = self._create_pipeline_version_lineage()
380+
self._add_associations_for_pipeline(
381+
pipeline_context_arn=pipeline_context.context_arn,
382+
pipeline_versions_context_arn=new_pipeline_version_context.context_arn,
383+
input_feature_group_contexts=input_feature_group_contexts,
384+
input_raw_data_artifacts=input_raw_data_artifacts,
385+
output_feature_group_contexts=output_feature_group_contexts,
386+
transformation_code_artifact=transformation_code_artifact,
387+
)
388+
LineageAssociationHandler.add_pipeline_and_pipeline_version_association(
389+
pipeline_context_arn=pipeline_context.context_arn,
390+
pipeline_version_context_arn=new_pipeline_version_context.context_arn,
391+
sagemaker_session=self.sagemaker_session,
392+
)
393+
return
394+
raise e
365395
upstream_feature_group_associations: Iterator[AssociationSummary] = (
366396
LineageAssociationHandler.list_upstream_associations(
367397
# pylint: disable=no-member

sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ def _resolve_mlflow_resource_arn(sagemaker_session, mlflow_resource_arn: Optiona
125125
mlflow_apps_list = []
126126
paginator = sm_client.get_paginator("list_mlflow_apps")
127127
for page in paginator.paginate():
128-
mlflow_apps_list.extend(page.get("MlflowApps", []))
128+
mlflow_apps_list.extend(page.get("Summaries", []))
129129

130130
logger.info("Found %d MLflow apps: %s", len(mlflow_apps_list),
131131
[(a.get("Name", "?"), a.get("Status", "?"), a.get("MlflowVersion", "?")) for a in mlflow_apps_list])

sagemaker-train/tests/integ/train/conftest.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,3 +48,92 @@ def sagemaker_session_us_east_1():
4848
"""Create a SageMaker session in us-east-1 for Nova model tests."""
4949
boto_session = boto3.Session(region_name=NOVA_REGION)
5050
return Session(boto_session=boto_session)
51+
52+
53+
import time
54+
import logging
55+
56+
logger = logging.getLogger(__name__)
57+
58+
59+
@pytest.fixture(scope="module")
60+
def mlflow_resource_arn():
61+
"""Discover or create an MLflow app for integ tests, clean up if created.
62+
63+
Looks for an existing MLflow app in Created/Updated state. If none exists,
64+
creates one and deletes it after the test module finishes.
65+
"""
66+
region = os.environ.get("AWS_DEFAULT_REGION", DEFAULT_REGION)
67+
sm_client = boto3.client("sagemaker", region_name=region)
68+
created_arn = None
69+
70+
# Try to find an existing ready app
71+
try:
72+
paginator = sm_client.get_paginator("list_mlflow_apps")
73+
for page in paginator.paginate():
74+
for app in page.get("Summaries", []):
75+
if app.get("Status") in ("Created", "Updated"):
76+
logger.info(f"Using existing MLflow app: {app['Arn']}")
77+
yield app["Arn"]
78+
return
79+
except Exception as e:
80+
logger.warning(f"Failed to list MLflow apps: {e}")
81+
82+
# No ready app found — create one
83+
logger.info("No ready MLflow app found. Creating one for integ tests...")
84+
sts_client = boto3.client("sts", region_name=region)
85+
account_id = sts_client.get_caller_identity()["Account"]
86+
app_name = f"integ-test-mlflow-{int(time.time())}"
87+
artifact_store_uri = f"s3://sagemaker-{region}-{account_id}/mlflow-artifacts"
88+
89+
# Ensure bucket/prefix exists
90+
s3_client = boto3.client("s3", region_name=region)
91+
bucket_name = f"sagemaker-{region}-{account_id}"
92+
try:
93+
s3_client.head_bucket(Bucket=bucket_name)
94+
except Exception:
95+
if region == "us-east-1":
96+
s3_client.create_bucket(Bucket=bucket_name)
97+
else:
98+
s3_client.create_bucket(
99+
Bucket=bucket_name,
100+
CreateBucketConfiguration={"LocationConstraint": region},
101+
)
102+
try:
103+
s3_client.put_object(Bucket=bucket_name, Key="mlflow-artifacts/")
104+
except Exception:
105+
pass
106+
107+
# Get execution role
108+
from sagemaker.train.defaults import TrainDefaults
109+
boto_session = boto3.Session(region_name=region)
110+
sagemaker_session = Session(boto_session=boto_session)
111+
role_arn = TrainDefaults.get_role(role=None, sagemaker_session=sagemaker_session)
112+
113+
resp = sm_client.create_mlflow_app(
114+
Name=app_name,
115+
ArtifactStoreUri=artifact_store_uri,
116+
RoleArn=role_arn,
117+
AccountDefaultStatus="DISABLED",
118+
)
119+
created_arn = resp["Arn"]
120+
logger.info(f"Created MLflow app: {created_arn}")
121+
122+
# Wait for it to become ready
123+
for _ in range(60):
124+
desc = sm_client.describe_mlflow_app(Arn=created_arn)
125+
status = desc.get("Status")
126+
if status in ("Created", "Updated"):
127+
break
128+
if status in ("Failed", "CreateFailed", "DeleteFailed"):
129+
pytest.skip(f"MLflow app creation failed: {desc.get('FailureReason')}")
130+
time.sleep(10)
131+
132+
yield created_arn
133+
134+
# Cleanup
135+
logger.info(f"Cleaning up MLflow app: {created_arn}")
136+
try:
137+
sm_client.delete_mlflow_app(Arn=created_arn)
138+
except Exception as e:
139+
logger.warning(f"Failed to delete MLflow app {created_arn}: {e}")

sagemaker-train/tests/integ/train/test_llm_as_judge_base_model_fix.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@
8181
class TestLLMAsJudgeBaseModelFix:
8282
"""Integration test for base model fix in LLMAsJudgeEvaluator"""
8383

84-
def test_base_model_evaluation_uses_correct_weights(self):
84+
def test_base_model_evaluation_uses_correct_weights(self, mlflow_resource_arn):
8585
"""
8686
Test that base model evaluation uses original base model weights.
8787
@@ -115,7 +115,7 @@ def test_base_model_evaluation_uses_correct_weights(self):
115115
custom_metrics=TEST_CONFIG["custom_metrics_json"],
116116
s3_output_path=TEST_CONFIG["s3_output_path"],
117117
evaluate_base_model=TEST_CONFIG["evaluate_base_model"],
118-
mlflow_resource_arn=TEST_CONFIG["mlflow_tracking_server_arn"],
118+
mlflow_resource_arn=mlflow_resource_arn,
119119
)
120120

121121
# Verify evaluator configuration
@@ -251,7 +251,7 @@ def test_base_model_evaluation_uses_correct_weights(self):
251251
# Re-raise to fail the test
252252
raise
253253

254-
def test_base_model_false_still_works(self):
254+
def test_base_model_false_still_works(self, mlflow_resource_arn):
255255
"""
256256
Test that evaluate_base_model=False still works correctly (backward compatibility).
257257
@@ -272,7 +272,7 @@ def test_base_model_false_still_works(self):
272272
builtin_metrics=TEST_CONFIG["builtin_metrics"],
273273
s3_output_path=TEST_CONFIG["s3_output_path"],
274274
evaluate_base_model=False, # Only evaluate custom model
275-
mlflow_resource_arn=TEST_CONFIG["mlflow_tracking_server_arn"],
275+
mlflow_resource_arn=mlflow_resource_arn,
276276
)
277277

278278
# Verify evaluator configuration

sagemaker-train/tests/unit/train/common_utils/test_finetune_utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def test__resolve_mlflow_resource_arn_creates_new_app(self, mock_get_client, moc
8383
mock_get_domain.return_value = "d-123456789"
8484
mock_sm_client = Mock()
8585
mock_paginator = Mock()
86-
mock_paginator.paginate.return_value = [{"MlflowApps": []}]
86+
mock_paginator.paginate.return_value = [{"Summaries": []}]
8787
mock_sm_client.get_paginator.return_value = mock_paginator
8888
mock_get_client.return_value = mock_sm_client
8989
expected_arn = "arn:aws:mlflow:us-east-1:123456789012:tracking-server/new-app"
@@ -633,7 +633,7 @@ def test_upgrades_when_below_min_version(self, mock_get_client, mock_upgrade, mo
633633
}
634634
mock_sm_client = Mock()
635635
mock_paginator = Mock()
636-
mock_paginator.paginate.return_value = [{"MlflowApps": [old_app]}]
636+
mock_paginator.paginate.return_value = [{"Summaries": [old_app]}]
637637
mock_sm_client.get_paginator.return_value = mock_paginator
638638
mock_get_client.return_value = mock_sm_client
639639

@@ -659,7 +659,7 @@ def test_no_upgrade_when_meets_version(self, mock_get_client, mock_domain):
659659
}
660660
mock_sm_client = Mock()
661661
mock_paginator = Mock()
662-
mock_paginator.paginate.return_value = [{"MlflowApps": [app]}]
662+
mock_paginator.paginate.return_value = [{"Summaries": [app]}]
663663
mock_sm_client.get_paginator.return_value = mock_paginator
664664
mock_get_client.return_value = mock_sm_client
665665

0 commit comments

Comments
 (0)