Skip to content

Commit 779075c

Browse files
author
Jonathan Lee
committed
fix: multiartifact support for JS
1 parent b839617 commit 779075c

2 files changed

Lines changed: 257 additions & 3 deletions

File tree

sagemaker-core/src/sagemaker/core/jumpstart/types.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1726,9 +1726,13 @@ def get_top_config_from_ranking(
17261726
ranked_config_names = rankings.rankings
17271727
for config_name in ranked_config_names:
17281728
resolved_config = self.configs[config_name].resolved_config
1729-
if instance_type and instance_type not in getattr(
1730-
resolved_config, instance_type_attribute
1731-
):
1729+
# Fix: resolved_config is a dict (from deep_override_dict), not an object
1730+
# Use dict.get() for dicts, getattr() for objects
1731+
if isinstance(resolved_config, dict):
1732+
supported_instance_types = resolved_config.get(instance_type_attribute, [])
1733+
else:
1734+
supported_instance_types = getattr(resolved_config, instance_type_attribute, [])
1735+
if instance_type and instance_type not in supported_instance_types:
17321736
continue
17331737
return self.configs[config_name]
17341738

Lines changed: 250 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,250 @@
1+
"""
2+
Unit tests for JumpStartMetadataConfigs.get_top_config_from_ranking()
3+
4+
These tests verify that config selection correctly filters by instance_type,
5+
handling the case where resolved_config is a dict (from deep_override_dict).
6+
7+
This addresses a bug where getattr() was incorrectly used on dict objects
8+
instead of dict key access, causing instance_type filtering to fail.
9+
"""
10+
11+
import pytest
12+
from typing import Any, Dict, List, Optional
13+
from unittest.mock import MagicMock, patch
14+
15+
16+
class TestGetTopConfigFromRanking:
17+
"""Tests for get_top_config_from_ranking method."""
18+
19+
@pytest.fixture
20+
def mock_gpu_config(self):
21+
"""Create a mock GPU config with dict resolved_config."""
22+
config = MagicMock()
23+
config.config_name = "gpu-lmi-tgi"
24+
# resolved_config is a dict (as returned by deep_override_dict)
25+
config.resolved_config = {
26+
"supported_inference_instance_types": [
27+
"ml.g5.xlarge",
28+
"ml.g5.2xlarge",
29+
"ml.g5.4xlarge",
30+
"ml.g5.12xlarge",
31+
"ml.p4d.24xlarge",
32+
],
33+
"model_id": "meta-llama/Llama-2-7b",
34+
}
35+
return config
36+
37+
@pytest.fixture
38+
def mock_neuron_config(self):
39+
"""Create a mock Neuron config with dict resolved_config."""
40+
config = MagicMock()
41+
config.config_name = "neuron-inference"
42+
# resolved_config is a dict (as returned by deep_override_dict)
43+
config.resolved_config = {
44+
"supported_inference_instance_types": [
45+
"ml.inf2.xlarge",
46+
"ml.inf2.8xlarge",
47+
"ml.inf2.24xlarge",
48+
"ml.inf2.48xlarge",
49+
],
50+
"model_id": "meta-llama/Llama-2-7b",
51+
}
52+
return config
53+
54+
@pytest.fixture
55+
def mock_ranking(self):
56+
"""Create a mock ranking with GPU first, then Neuron."""
57+
ranking = MagicMock()
58+
ranking.rankings = ["gpu-lmi-tgi", "neuron-inference"]
59+
return ranking
60+
61+
def test_no_instance_type_returns_highest_ranked(
62+
self, mock_gpu_config, mock_neuron_config, mock_ranking
63+
):
64+
"""When no instance_type specified, return highest ranked config."""
65+
from sagemaker.core.jumpstart.types import JumpStartMetadataConfigs
66+
from sagemaker.core.jumpstart.enums import JumpStartScriptScope
67+
68+
configs = JumpStartMetadataConfigs(
69+
configs={
70+
"gpu-lmi-tgi": mock_gpu_config,
71+
"neuron-inference": mock_neuron_config,
72+
},
73+
config_rankings={"default": mock_ranking},
74+
scope=JumpStartScriptScope.INFERENCE,
75+
)
76+
77+
result = configs.get_top_config_from_ranking(instance_type=None)
78+
assert result is not None
79+
assert result.config_name == "gpu-lmi-tgi"
80+
81+
def test_gpu_instance_returns_gpu_config(
82+
self, mock_gpu_config, mock_neuron_config, mock_ranking
83+
):
84+
"""When GPU instance specified, return GPU config."""
85+
from sagemaker.core.jumpstart.types import JumpStartMetadataConfigs
86+
from sagemaker.core.jumpstart.enums import JumpStartScriptScope
87+
88+
configs = JumpStartMetadataConfigs(
89+
configs={
90+
"gpu-lmi-tgi": mock_gpu_config,
91+
"neuron-inference": mock_neuron_config,
92+
},
93+
config_rankings={"default": mock_ranking},
94+
scope=JumpStartScriptScope.INFERENCE,
95+
)
96+
97+
result = configs.get_top_config_from_ranking(instance_type="ml.g5.xlarge")
98+
assert result is not None
99+
assert result.config_name == "gpu-lmi-tgi"
100+
101+
def test_inferentia_instance_returns_neuron_config(
102+
self, mock_gpu_config, mock_neuron_config, mock_ranking
103+
):
104+
"""
105+
When Inferentia instance specified, return Neuron config.
106+
107+
This is the critical test case that was failing before the fix.
108+
The bug caused GPU config to be returned even for Inferentia instances
109+
because getattr() was used on a dict instead of dict key access.
110+
"""
111+
from sagemaker.core.jumpstart.types import JumpStartMetadataConfigs
112+
from sagemaker.core.jumpstart.enums import JumpStartScriptScope
113+
114+
configs = JumpStartMetadataConfigs(
115+
configs={
116+
"gpu-lmi-tgi": mock_gpu_config,
117+
"neuron-inference": mock_neuron_config,
118+
},
119+
config_rankings={"default": mock_ranking},
120+
scope=JumpStartScriptScope.INFERENCE,
121+
)
122+
123+
result = configs.get_top_config_from_ranking(instance_type="ml.inf2.24xlarge")
124+
assert result is not None
125+
assert result.config_name == "neuron-inference"
126+
127+
def test_unsupported_instance_returns_none(
128+
self, mock_gpu_config, mock_neuron_config, mock_ranking
129+
):
130+
"""When unsupported instance specified, return None."""
131+
from sagemaker.core.jumpstart.types import JumpStartMetadataConfigs
132+
from sagemaker.core.jumpstart.enums import JumpStartScriptScope
133+
134+
configs = JumpStartMetadataConfigs(
135+
configs={
136+
"gpu-lmi-tgi": mock_gpu_config,
137+
"neuron-inference": mock_neuron_config,
138+
},
139+
config_rankings={"default": mock_ranking},
140+
scope=JumpStartScriptScope.INFERENCE,
141+
)
142+
143+
result = configs.get_top_config_from_ranking(instance_type="ml.trn1.32xlarge")
144+
assert result is None
145+
146+
def test_training_scope_uses_training_instance_types(self):
147+
"""Verify training scope uses supported_training_instance_types."""
148+
from sagemaker.core.jumpstart.types import JumpStartMetadataConfigs
149+
from sagemaker.core.jumpstart.enums import JumpStartScriptScope
150+
151+
gpu_config = MagicMock()
152+
gpu_config.config_name = "gpu-training"
153+
gpu_config.resolved_config = {
154+
"supported_training_instance_types": [
155+
"ml.p4d.24xlarge",
156+
"ml.p5.48xlarge",
157+
],
158+
}
159+
160+
trn_config = MagicMock()
161+
trn_config.config_name = "trainium-training"
162+
trn_config.resolved_config = {
163+
"supported_training_instance_types": [
164+
"ml.trn1.32xlarge",
165+
"ml.trn1n.32xlarge",
166+
],
167+
}
168+
169+
ranking = MagicMock()
170+
ranking.rankings = ["gpu-training", "trainium-training"]
171+
172+
configs = JumpStartMetadataConfigs(
173+
configs={
174+
"gpu-training": gpu_config,
175+
"trainium-training": trn_config,
176+
},
177+
config_rankings={"default": ranking},
178+
scope=JumpStartScriptScope.TRAINING,
179+
)
180+
181+
# Trainium instance should select trainium config
182+
result = configs.get_top_config_from_ranking(instance_type="ml.trn1.32xlarge")
183+
assert result is not None
184+
assert result.config_name == "trainium-training"
185+
186+
def test_resolved_config_as_object_still_works(self):
187+
"""
188+
Verify that if resolved_config is an object (not dict), getattr still works.
189+
190+
This ensures backwards compatibility with any code paths where
191+
resolved_config might be an object with attributes.
192+
"""
193+
from sagemaker.core.jumpstart.types import JumpStartMetadataConfigs
194+
from sagemaker.core.jumpstart.enums import JumpStartScriptScope
195+
196+
# Create a config where resolved_config is an object, not a dict
197+
class ResolvedConfigObject:
198+
supported_inference_instance_types = ["ml.g5.xlarge", "ml.g5.2xlarge"]
199+
200+
config = MagicMock()
201+
config.config_name = "object-config"
202+
config.resolved_config = ResolvedConfigObject()
203+
204+
ranking = MagicMock()
205+
ranking.rankings = ["object-config"]
206+
207+
configs = JumpStartMetadataConfigs(
208+
configs={"object-config": config},
209+
config_rankings={"default": ranking},
210+
scope=JumpStartScriptScope.INFERENCE,
211+
)
212+
213+
result = configs.get_top_config_from_ranking(instance_type="ml.g5.xlarge")
214+
assert result is not None
215+
assert result.config_name == "object-config"
216+
217+
218+
class TestResolvedConfigType:
219+
"""Tests verifying that resolved_config is correctly identified as dict."""
220+
221+
def test_deep_override_dict_returns_dict(self):
222+
"""Verify deep_override_dict returns a plain dict."""
223+
from sagemaker.core.common_utils import deep_override_dict
224+
225+
base = {"field1": "value1", "nested": {"a": 1}}
226+
override = {"field2": "value2", "nested": {"b": 2}}
227+
228+
result = deep_override_dict(base, override)
229+
230+
assert isinstance(result, dict)
231+
assert "field1" in result
232+
assert "field2" in result
233+
234+
def test_getattr_fails_on_dict(self):
235+
"""Verify that getattr fails on dict for non-existent attributes."""
236+
d = {"supported_inference_instance_types": ["ml.g5.xlarge"]}
237+
238+
with pytest.raises(AttributeError):
239+
getattr(d, "supported_inference_instance_types")
240+
241+
def test_dict_get_works(self):
242+
"""Verify that dict.get() works correctly."""
243+
d = {"supported_inference_instance_types": ["ml.g5.xlarge"]}
244+
245+
result = d.get("supported_inference_instance_types", [])
246+
assert result == ["ml.g5.xlarge"]
247+
248+
# Non-existent key returns default
249+
result = d.get("nonexistent", [])
250+
assert result == []

0 commit comments

Comments
 (0)