Skip to content

Commit 0d86285

Browse files
Merge branch 'master' into master
2 parents 11311c1 + a6303a4 commit 0d86285

File tree

9 files changed

+79
-41
lines changed

9 files changed

+79
-41
lines changed

sagemaker-mlops/src/sagemaker/mlops/workflow/_repack_model.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,8 @@ def _is_bad_path(path, base):
5757
bool: True if the path is not rooted under the base directory, False otherwise.
5858
"""
5959
# joinpath will ignore base if path is absolute
60-
return not _get_resolved_path(joinpath(base, path)).startswith(base)
60+
resolved = _get_resolved_path(joinpath(base, path))
61+
return os.path.commonpath([resolved, base]) != base
6162

6263

6364
def _is_bad_link(info, base):
@@ -77,19 +78,18 @@ def _is_bad_link(info, base):
7778
return _is_bad_path(info.linkname, base=tip)
7879

7980

80-
def _get_safe_members(members):
81+
def _get_safe_members(members, base):
8182
"""A generator that yields members that are safe to extract.
8283
8384
It filters out bad paths and bad links.
8485
8586
Args:
8687
members (list): A list of members to check.
88+
base (str): The base directory for extraction.
8789
8890
Yields:
8991
tarfile.TarInfo: The tar file info.
9092
"""
91-
base = _get_resolved_path("")
92-
9393
for file_info in members:
9494
if _is_bad_path(file_info.name, base):
9595
logger.error("%s is blocked (illegal path)", file_info.name)
@@ -120,7 +120,8 @@ def custom_extractall_tarfile(tar, extract_path):
120120
if hasattr(tarfile, "data_filter"):
121121
tar.extractall(path=extract_path, filter="data")
122122
else:
123-
tar.extractall(path=extract_path, members=_get_safe_members(tar))
123+
base = _get_resolved_path(extract_path)
124+
tar.extractall(path=extract_path, members=_get_safe_members(tar.getmembers(), base))
124125

125126

126127
def repack(inference_script, model_archive, source_dir=None): # pragma: no cover

sagemaker-mlops/tests/unit/workflow/test_repack_model.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ def test_get_safe_members_all_safe():
105105
mock_member2.islnk = Mock(return_value=False)
106106

107107
members = [mock_member1, mock_member2]
108-
safe_members = list(_get_safe_members(members))
108+
safe_members = list(_get_safe_members(members, "/tmp/extract"))
109109

110110
assert len(safe_members) == 2
111111
assert mock_member1 in safe_members
@@ -128,7 +128,7 @@ def test_get_safe_members_filters_bad_path():
128128
mock_is_bad.side_effect = lambda name, base: name == "/etc/passwd"
129129

130130
members = [mock_member_safe, mock_member_bad]
131-
safe_members = list(_get_safe_members(members))
131+
safe_members = list(_get_safe_members(members, "/tmp/extract"))
132132

133133
assert len(safe_members) == 1
134134
assert mock_member_safe in safe_members
@@ -152,7 +152,7 @@ def test_get_safe_members_filters_bad_symlink():
152152
mock_is_bad_link.return_value = True
153153

154154
members = [mock_member_safe, mock_member_symlink]
155-
safe_members = list(_get_safe_members(members))
155+
safe_members = list(_get_safe_members(members, "/tmp/extract"))
156156

157157
assert len(safe_members) == 1
158158
assert mock_member_safe in safe_members
@@ -176,7 +176,7 @@ def test_get_safe_members_filters_bad_hardlink():
176176
mock_is_bad_link.return_value = True
177177

178178
members = [mock_member_safe, mock_member_hardlink]
179-
safe_members = list(_get_safe_members(members))
179+
safe_members = list(_get_safe_members(members, "/tmp/extract"))
180180

181181
assert len(safe_members) == 1
182182
assert mock_member_safe in safe_members

sagemaker-train/src/sagemaker/ai_registry/dataset.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -389,7 +389,6 @@ def get_versions(self) -> List["DataSet"]:
389389

390390
return datasets
391391

392-
@classmethod
393392
@classmethod
394393
@_telemetry_emitter(feature=Feature.MODEL_CUSTOMIZATION, func_name="DataSet.get_all")
395394
def get_all(cls, max_results: Optional[int] = None, sagemaker_session=None):

sagemaker-train/src/sagemaker/train/common_utils/show_results_utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -341,6 +341,8 @@ def _parse_response(response_str: str) -> str:
341341

342342
def _format_score(score: float) -> str:
343343
"""Format score as percentage: 0.8333 -> '83.3%' """
344+
if score is None:
345+
return "N/A"
344346
return f"{score * 100:.1f}%"
345347

346348

sagemaker-train/src/sagemaker/train/evaluate/base_evaluator.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -701,10 +701,18 @@ def _get_base_template_context(
701701
Returns:
702702
dict: Base template context dictionary
703703
"""
704+
# Generate default mlflow_experiment_name if not provided
705+
# This is required by AWS when ModelPackageGroupArn is not provided in training jobs
706+
mlflow_experiment_name = self.mlflow_experiment_name
707+
if not mlflow_experiment_name and self.mlflow_resource_arn:
708+
# Use pipeline_name as default experiment name
709+
mlflow_experiment_name = '{{ pipeline_name }}'
710+
_logger.info("No mlflow_experiment_name provided, using pipeline_name as default")
711+
704712
return {
705713
'role_arn': role_arn,
706714
'mlflow_resource_arn': self.mlflow_resource_arn,
707-
'mlflow_experiment_name': self.mlflow_experiment_name,
715+
'mlflow_experiment_name': mlflow_experiment_name,
708716
'mlflow_run_name': self.mlflow_run_name,
709717
'model_package_group_arn': model_package_group_arn,
710718
'source_model_package_arn': self._source_model_package_arn,

sagemaker-train/src/sagemaker/train/evaluate/pipeline_templates.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1028,6 +1028,9 @@
10281028
{% if kms_key_id %},
10291029
"KmsKeyId": "{{ kms_key_id }}"
10301030
{% endif %}
1031+
},
1032+
"ModelPackageConfig": {
1033+
"ModelPackageGroupArn": "{{ model_package_group_arn }}"
10311034
}{% if dataset_uri %},
10321035
"InputDataConfig": [
10331036
{

sagemaker-train/tests/integ/ai_registry/conftest.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -106,13 +106,8 @@ def cleanup_list():
106106
"""Track resources for cleanup."""
107107
resources = []
108108
yield resources
109-
for evaluator in resources:
109+
for resource in resources:
110110
try:
111-
from sagemaker.ai_registry.air_hub import AIRHub
112-
AIRHub.delete_hub_content(
113-
hub_content_type=evaluator.hub_content_type,
114-
hub_content_name=evaluator.name,
115-
hub_content_version=evaluator.version
116-
)
111+
resource.delete()
117112
except Exception:
118113
pass

sagemaker-train/tests/integ/ai_registry/test_dataset.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -129,9 +129,10 @@ def test_create_dataset_from_s3_nova_eval(self, unique_name, test_bucket, cleanu
129129
cleanup_list.append(dataset)
130130
assert dataset.name == unique_name
131131

132-
def test_get_dataset(self, unique_name, sample_jsonl_file):
132+
def test_get_dataset(self, unique_name, sample_jsonl_file, cleanup_list):
133133
"""Test retrieving dataset by name."""
134134
created = DataSet.create(name=unique_name, source=sample_jsonl_file, wait=False)
135+
cleanup_list.append(created)
135136
retrieved = DataSet.get(unique_name)
136137
assert retrieved.name == created.name
137138
assert retrieved.arn == created.arn
@@ -141,16 +142,18 @@ def test_get_all_datasets(self):
141142
datasets = list(DataSet.get_all(max_results=5))
142143
assert isinstance(datasets, list)
143144

144-
def test_dataset_refresh(self, unique_name, sample_jsonl_file):
145+
def test_dataset_refresh(self, unique_name, sample_jsonl_file, cleanup_list):
145146
"""Test refreshing dataset status."""
146147
dataset = DataSet.create(name=unique_name, source=sample_jsonl_file, wait=False)
148+
cleanup_list.append(dataset)
147149
dataset.refresh()
148150
time.sleep(3)
149151
assert dataset.status in [HubContentStatus.IMPORTING.value, HubContentStatus.AVAILABLE.value]
150152

151-
def test_dataset_get_versions(self, unique_name, sample_jsonl_file):
153+
def test_dataset_get_versions(self, unique_name, sample_jsonl_file, cleanup_list):
152154
"""Test getting dataset versions."""
153155
dataset = DataSet.create(name=unique_name, source=sample_jsonl_file, wait=False)
156+
cleanup_list.append(dataset)
154157
versions = dataset.get_versions()
155158
assert len(versions) >= 1
156159
assert all(isinstance(v, DataSet) for v in versions)
@@ -178,7 +181,7 @@ def test_create_dataset_version(self, unique_name, sample_jsonl_file, cleanup_li
178181
"""Test creating new dataset version."""
179182
dataset = DataSet.create(name=unique_name, source=sample_jsonl_file, wait=False)
180183
result = dataset.create_version(sample_jsonl_file)
181-
cleanup_list.append(cleanup_list)
184+
cleanup_list.append(dataset)
182185
assert result is True
183186

184187
def test_dataset_validation_invalid_extension(self, unique_name):

sagemaker-train/tests/integ/train/test_llm_as_judge_base_model_fix.py

Lines changed: 46 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -144,14 +144,23 @@ def test_base_model_evaluation_uses_correct_weights(self):
144144
# Check that we have both base and custom inference steps
145145
step_names = [step.name for step in execution.status.step_details] if execution.status.step_details else []
146146

147-
logger.info(f"Pipeline steps: {step_names}")
147+
logger.info(f"Pipeline steps ({len(step_names)}): {step_names}")
148148

149-
# Verify both inference steps exist
150-
has_base_step = any("BaseInference" in name for name in step_names)
151-
has_custom_step = any("CustomInference" in name for name in step_names)
149+
# If no steps yet, wait a bit for pipeline to initialize
150+
if not step_names:
151+
logger.info("No steps found yet, waiting for pipeline initialization...")
152+
import time
153+
time.sleep(10)
154+
execution.refresh()
155+
step_names = [step.name for step in execution.status.step_details] if execution.status.step_details else []
156+
logger.info(f"Pipeline steps after wait ({len(step_names)}): {step_names}")
152157

153-
assert has_base_step, "Pipeline should have EvaluateBaseInferenceModel step"
154-
assert has_custom_step, "Pipeline should have EvaluateCustomInferenceModel step"
158+
# Verify both inference steps exist (case-insensitive, flexible matching)
159+
has_base_step = any("base" in name.lower() and "inference" in name.lower() for name in step_names)
160+
has_custom_step = any("custom" in name.lower() and "inference" in name.lower() for name in step_names)
161+
162+
assert has_base_step, f"Pipeline should have base inference step. Found steps: {step_names}"
163+
assert has_custom_step, f"Pipeline should have custom inference step. Found steps: {step_names}"
155164

156165
logger.info(f"✓ Pipeline has both base and custom inference steps")
157166
logger.info(f" Base model step: {'Found' if has_base_step else 'Missing'}")
@@ -175,7 +184,11 @@ def test_base_model_evaluation_uses_correct_weights(self):
175184

176185
# Display results
177186
logger.info(" Fetching results (first 10 rows)...")
178-
execution.show_results(limit=10, offset=0, show_explanations=False)
187+
try:
188+
execution.show_results(limit=10, offset=0, show_explanations=False)
189+
except (TypeError, ValueError) as e:
190+
logger.warning(f" Could not display results due to formatting issue: {e}")
191+
logger.info(" Results are available but display utility has a bug with None scores")
179192

180193
# Verify S3 output path
181194
assert execution.s3_output_path is not None
@@ -206,14 +219,19 @@ def test_base_model_evaluation_uses_correct_weights(self):
206219
if execution.status.failure_reason:
207220
logger.error(f" Failure reason: {execution.status.failure_reason}")
208221

209-
# Log step failures
222+
# Log step failures with detailed information
210223
if execution.status.step_details:
211-
logger.error("\nFailed steps:")
224+
logger.error("\n" + "=" * 80)
225+
logger.error("DETAILED STEP FAILURE INFORMATION:")
226+
logger.error("=" * 80)
212227
for step in execution.status.step_details:
213-
if "failed" in step.status.lower():
214-
logger.error(f" {step.name}: {step.status}")
215-
if step.failure_reason:
216-
logger.error(f" Reason: {step.failure_reason}")
228+
logger.error(f"\nStep: {step.name}")
229+
logger.error(f" Status: {step.status}")
230+
logger.error(f" Start Time: {step.start_time}")
231+
logger.error(f" End Time: {step.end_time}")
232+
if step.failure_reason:
233+
logger.error(f" ❌ FAILURE REASON: {step.failure_reason}")
234+
logger.error("=" * 80)
217235

218236
# Re-raise to fail the test
219237
raise
@@ -259,14 +277,23 @@ def test_base_model_false_still_works(self):
259277
execution.refresh()
260278
step_names = [step.name for step in execution.status.step_details] if execution.status.step_details else []
261279

262-
logger.info(f"Pipeline steps: {step_names}")
280+
logger.info(f"Pipeline steps ({len(step_names)}): {step_names}")
281+
282+
# If no steps yet, wait a bit for pipeline to initialize
283+
if not step_names:
284+
logger.info("No steps found yet, waiting for pipeline initialization...")
285+
import time
286+
time.sleep(10)
287+
execution.refresh()
288+
step_names = [step.name for step in execution.status.step_details] if execution.status.step_details else []
289+
logger.info(f"Pipeline steps after wait ({len(step_names)}): {step_names}")
263290

264-
# Should NOT have base inference step
265-
has_base_step = any("BaseInference" in name for name in step_names)
266-
has_custom_step = any("CustomInference" in name for name in step_names)
291+
# Should NOT have base inference step (case-insensitive, flexible matching)
292+
has_base_step = any("base" in name.lower() and "inference" in name.lower() for name in step_names)
293+
has_custom_step = any("custom" in name.lower() and "inference" in name.lower() for name in step_names)
267294

268-
assert not has_base_step, "Pipeline should NOT have EvaluateBaseInferenceModel step when evaluate_base_model=False"
269-
assert has_custom_step, "Pipeline should have EvaluateCustomInferenceModel step"
295+
assert not has_base_step, f"Pipeline should NOT have base inference step when evaluate_base_model=False. Found steps: {step_names}"
296+
assert has_custom_step, f"Pipeline should have custom inference step. Found steps: {step_names}"
270297

271298
logger.info(f"✓ Pipeline structure correct for evaluate_base_model=False")
272299
logger.info(f" Base model step: {'Found (ERROR!)' if has_base_step else 'Not present (correct)'}")

0 commit comments

Comments
 (0)