Skip to content

Commit 7e7f7f9

Browse files
authored
Merge branch 'master' into fix/skip-none-hyperparameters
2 parents e69f8db + 736781b commit 7e7f7f9

11 files changed

Lines changed: 48 additions & 19 deletions

File tree

sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from sagemaker.core.shapes import ServerlessJobConfig, Channel, DataSource, ModelPackageConfig, MlflowConfig
1616
from sagemaker.train.configs import InputData, OutputDataConfig
1717
from sagemaker.train.defaults import TrainDefaults
18+
from sagemaker.train.constants import get_sagemaker_hub_name
1819

1920
logger = logging.getLogger(__name__)
2021

@@ -317,13 +318,15 @@ def _resolve_model_package_arn(model_package) -> Optional[str]:
317318

318319

319320
def _get_fine_tuning_options_and_model_arn(model_name: str, customization_technique: str, training_type, sagemaker_session,
320-
hub_name: str = "SageMakerPublicHub") -> tuple:
321+
hub_name: Optional[str] = None) -> tuple:
321322
"""Get fine-tuning options and model ARN for given customization technique.
322323
323324
Returns:
324325
tuple: (FineTuningOptions, model_arn, is_gated_model)
325326
"""
326-
327+
if hub_name is None:
328+
hub_name = get_sagemaker_hub_name()
329+
327330
try:
328331

329332
hub_content = _get_hub_content_metadata(

sagemaker-train/src/sagemaker/train/common_utils/model_resolution.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from enum import Enum
1515
import re
1616
from sagemaker.train.base_trainer import BaseTrainer
17+
from sagemaker.train.constants import get_sagemaker_hub_name
1718
from sagemaker.core.utils.utils import Unassigned
1819

1920

@@ -52,8 +53,6 @@ class _ModelResolver:
5253
and fine-tuned ModelPackage objects/ARNs.
5354
"""
5455

55-
DEFAULT_HUB_NAME = "SageMakerPublicHub"
56-
5756
def __init__(self, sagemaker_session=None):
5857
"""
5958
Initialize the resolver.
@@ -89,7 +88,7 @@ def resolve_model_info(
8988
if base_model.startswith("arn:aws:sagemaker:") and ":model-package/" in base_model:
9089
return self._resolve_model_package_arn(base_model)
9190
else:
92-
return self._resolve_jumpstart_model(base_model, hub_name or self.DEFAULT_HUB_NAME)
91+
return self._resolve_jumpstart_model(base_model, hub_name or get_sagemaker_hub_name())
9392
# Handle BaseTrainer type
9493
elif isinstance(base_model, BaseTrainer):
9594
if hasattr(base_model, '_latest_training_job') and hasattr(base_model._latest_training_job,

sagemaker-train/src/sagemaker/train/constants.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,13 @@
4040
+ f"&& {SM_DRIVERS_CONTAINER_PATH}/{TRAIN_SCRIPT}",
4141
]
4242

43-
HUB_NAME = "SageMakerPublicHub"
43+
def get_sagemaker_hub_name() -> str:
44+
"""Return the SageMaker Hub name, honoring SAGEMAKER_HUB_NAME env var override.
45+
46+
Resolved at call time so tests and dev workflows can override the hub
47+
without re-importing this module. Defaults to ``"SageMakerPublicHub"``.
48+
"""
49+
return os.environ.get("SAGEMAKER_HUB_NAME", "SageMakerPublicHub")
4450

4551
# Allowed reward model IDs for RLAIF trainer with region restrictions
4652
_ALLOWED_REWARD_MODEL_IDS = {

sagemaker-train/src/sagemaker/train/dpo_trainer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
)
2424
from sagemaker.core.telemetry.telemetry_logging import _telemetry_emitter
2525
from sagemaker.core.telemetry.constants import Feature
26-
from sagemaker.train.constants import HUB_NAME
26+
from sagemaker.train.constants import get_sagemaker_hub_name
2727

2828
logger = logging.getLogger(__name__)
2929
logger.setLevel(logging.INFO)
@@ -244,7 +244,7 @@ def train(self,
244244
)
245245

246246
vpc_config = self.networking if self.networking else None
247-
tags = _get_studio_tags(self._model_name, HUB_NAME)
247+
tags = _get_studio_tags(self._model_name, get_sagemaker_hub_name())
248248

249249
# Build TrainingJob.create() arguments
250250
create_args = {

sagemaker-train/src/sagemaker/train/evaluate/benchmark_evaluator.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from .execution import EvaluationPipelineExecution
2222
from sagemaker.core.telemetry.telemetry_logging import _telemetry_emitter
2323
from sagemaker.core.telemetry.constants import Feature
24+
from sagemaker.train.constants import get_sagemaker_hub_name
2425

2526
_logger = logging.getLogger(__name__)
2627

@@ -466,7 +467,7 @@ def hyperparameters(self):
466467

467468
override_params = _get_evaluation_override_params(
468469
hub_content_name=hub_content_name,
469-
hub_name="SageMakerPublicHub",
470+
hub_name=get_sagemaker_hub_name(),
470471
evaluation_type=evaluation_type,
471472
region=region,
472473
session=boto_session

sagemaker-train/src/sagemaker/train/evaluate/custom_scorer_evaluator.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from .execution import EvaluationPipelineExecution
1717
from sagemaker.core.telemetry.telemetry_logging import _telemetry_emitter
1818
from sagemaker.core.telemetry.constants import Feature
19+
from sagemaker.train.constants import get_sagemaker_hub_name
1920

2021
_logger = logging.getLogger(__name__)
2122

@@ -240,7 +241,7 @@ def hyperparameters(self):
240241

241242
override_params = _get_evaluation_override_params(
242243
hub_content_name=hub_content_name,
243-
hub_name="SageMakerPublicHub",
244+
hub_name=get_sagemaker_hub_name(),
244245
evaluation_type="DeterministicEvaluation",
245246
region=region,
246247
session=boto_session
@@ -365,7 +366,7 @@ def _get_inference_params_from_hub(self, region: str) -> dict:
365366
_logger.info(f"Fetching evaluation recipe override parameters from hub for model: {hub_content_name}")
366367
override_params = _get_evaluation_override_params(
367368
hub_content_name=hub_content_name,
368-
hub_name="SageMakerPublicHub",
369+
hub_name=get_sagemaker_hub_name(),
369370
evaluation_type="DeterministicEvaluation",
370371
region=region,
371372
session=session

sagemaker-train/src/sagemaker/train/rlaif_trainer.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
)
2828
from sagemaker.core.telemetry.telemetry_logging import _telemetry_emitter
2929
from sagemaker.core.telemetry.constants import Feature
30-
from sagemaker.train.constants import HUB_NAME, _ALLOWED_REWARD_MODEL_IDS
30+
from sagemaker.train.constants import get_sagemaker_hub_name, _ALLOWED_REWARD_MODEL_IDS
3131

3232
logger = logging.getLogger(__name__)
3333

@@ -263,7 +263,7 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validati
263263
)
264264

265265
vpc_config = self.networking if self.networking else None
266-
tags = _get_studio_tags(self._model_name, HUB_NAME)
266+
tags = _get_studio_tags(self._model_name, get_sagemaker_hub_name())
267267

268268
# Build TrainingJob.create() arguments
269269
create_args = {
@@ -358,7 +358,7 @@ def _process_non_builtin_reward_prompt(self):
358358
sagemaker_session=self.sagemaker_session
359359
)
360360
hub_content = _get_hub_content_metadata(
361-
hub_name=HUB_NAME,
361+
hub_name=get_sagemaker_hub_name(),
362362
hub_content_type="JsonDoc",
363363
hub_content_name=self.reward_prompt,
364364
session=session.boto_session,

sagemaker-train/src/sagemaker/train/rlvr_trainer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
)
2626
from sagemaker.core.telemetry.telemetry_logging import _telemetry_emitter
2727
from sagemaker.core.telemetry.constants import Feature
28-
from sagemaker.train.constants import HUB_NAME
28+
from sagemaker.train.constants import get_sagemaker_hub_name
2929

3030
logger = logging.getLogger(__name__)
3131

@@ -251,7 +251,7 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None,
251251
)
252252

253253
vpc_config = self.networking if self.networking else None
254-
tags = _get_studio_tags(self._model_name, HUB_NAME)
254+
tags = _get_studio_tags(self._model_name, get_sagemaker_hub_name())
255255

256256
# Build TrainingJob.create() arguments
257257
create_args = {

sagemaker-train/src/sagemaker/train/sft_trainer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
)
2525
from sagemaker.core.telemetry.telemetry_logging import _telemetry_emitter
2626
from sagemaker.core.telemetry.constants import Feature
27-
from sagemaker.train.constants import HUB_NAME
27+
from sagemaker.train.constants import get_sagemaker_hub_name
2828

2929
logger = logging.getLogger(__name__)
3030
logger.setLevel(logging.INFO)
@@ -245,7 +245,7 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validati
245245
)
246246

247247
vpc_config = self.networking if self.networking else None
248-
tags = _get_studio_tags(self._model_name, HUB_NAME)
248+
tags = _get_studio_tags(self._model_name, get_sagemaker_hub_name())
249249

250250
# Build TrainingJob.create() arguments
251251
create_args = {

sagemaker-train/tests/unit/train/common_utils/test_model_resolution.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,6 @@ def test_resolver_initialization(self):
6363
"""Test ModelResolver initialization."""
6464
resolver = _ModelResolver()
6565
assert resolver.sagemaker_session is None
66-
assert resolver.DEFAULT_HUB_NAME == "SageMakerPublicHub"
6766

6867
def test_resolver_with_session(self):
6968
"""Test ModelResolver with custom session."""

0 commit comments

Comments
 (0)