Skip to content

Commit 62ca624

Browse files
committed
allow SAGEMAKER_HUB_NAME env var override for HUB_NAME constant
1 parent 4c184d4 commit 62ca624

6 files changed

Lines changed: 36 additions & 6 deletions

File tree

sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from sagemaker.core.shapes import ServerlessJobConfig, Channel, DataSource, ModelPackageConfig, MlflowConfig
1616
from sagemaker.train.configs import InputData, OutputDataConfig
1717
from sagemaker.train.defaults import TrainDefaults
18+
from sagemaker.train.constants import HUB_NAME
1819

1920
logger = logging.getLogger(__name__)
2021

@@ -317,7 +318,7 @@ def _resolve_model_package_arn(model_package) -> Optional[str]:
317318

318319

319320
def _get_fine_tuning_options_and_model_arn(model_name: str, customization_technique: str, training_type, sagemaker_session,
320-
hub_name: str = "SageMakerPublicHub") -> tuple:
321+
hub_name: str = HUB_NAME) -> tuple:
321322
"""Get fine-tuning options and model ARN for given customization technique.
322323
323324
Returns:

sagemaker-train/src/sagemaker/train/common_utils/model_resolution.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from enum import Enum
1515
import re
1616
from sagemaker.train.base_trainer import BaseTrainer
17+
from sagemaker.train.constants import HUB_NAME
1718
from sagemaker.core.utils.utils import Unassigned
1819

1920

@@ -52,7 +53,7 @@ class _ModelResolver:
5253
and fine-tuned ModelPackage objects/ARNs.
5354
"""
5455

55-
DEFAULT_HUB_NAME = "SageMakerPublicHub"
56+
DEFAULT_HUB_NAME = HUB_NAME
5657

5758
def __init__(self, sagemaker_session=None):
5859
"""

sagemaker-train/src/sagemaker/train/constants.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
+ f"&& {SM_DRIVERS_CONTAINER_PATH}/{TRAIN_SCRIPT}",
4141
]
4242

43-
HUB_NAME = "SageMakerPublicHub"
43+
HUB_NAME = os.environ.get("SAGEMAKER_HUB_NAME", "SageMakerPublicHub")
4444

4545
# Allowed reward model IDs for RLAIF trainer with region restrictions
4646
_ALLOWED_REWARD_MODEL_IDS = {

sagemaker-train/src/sagemaker/train/evaluate/benchmark_evaluator.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from .execution import EvaluationPipelineExecution
2222
from sagemaker.core.telemetry.telemetry_logging import _telemetry_emitter
2323
from sagemaker.core.telemetry.constants import Feature
24+
from sagemaker.train.constants import HUB_NAME
2425

2526
_logger = logging.getLogger(__name__)
2627

@@ -466,7 +467,7 @@ def hyperparameters(self):
466467

467468
override_params = _get_evaluation_override_params(
468469
hub_content_name=hub_content_name,
469-
hub_name="SageMakerPublicHub",
470+
hub_name=HUB_NAME,
470471
evaluation_type=evaluation_type,
471472
region=region,
472473
session=boto_session

sagemaker-train/src/sagemaker/train/evaluate/custom_scorer_evaluator.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from .execution import EvaluationPipelineExecution
1717
from sagemaker.core.telemetry.telemetry_logging import _telemetry_emitter
1818
from sagemaker.core.telemetry.constants import Feature
19+
from sagemaker.train.constants import HUB_NAME
1920

2021
_logger = logging.getLogger(__name__)
2122

@@ -240,7 +241,7 @@ def hyperparameters(self):
240241

241242
override_params = _get_evaluation_override_params(
242243
hub_content_name=hub_content_name,
243-
hub_name="SageMakerPublicHub",
244+
hub_name=HUB_NAME,
244245
evaluation_type="DeterministicEvaluation",
245246
region=region,
246247
session=boto_session
@@ -365,7 +366,7 @@ def _get_inference_params_from_hub(self, region: str) -> dict:
365366
_logger.info(f"Fetching evaluation recipe override parameters from hub for model: {hub_content_name}")
366367
override_params = _get_evaluation_override_params(
367368
hub_content_name=hub_content_name,
368-
hub_name="SageMakerPublicHub",
369+
hub_name=HUB_NAME,
369370
evaluation_type="DeterministicEvaluation",
370371
region=region,
371372
session=session
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
"""Tests for SAGEMAKER_HUB_NAME env-var override of the HUB_NAME constant."""
2+
from __future__ import absolute_import
3+
4+
import importlib
5+
import os
6+
from unittest.mock import patch
7+
8+
9+
def _reload_hub_name():
10+
"""Reload the constants module under the current env and return HUB_NAME."""
11+
from sagemaker.train import constants
12+
importlib.reload(constants)
13+
return constants.HUB_NAME
14+
15+
16+
def test_hub_name_defaults_to_public_hub():
17+
"""When SAGEMAKER_HUB_NAME is unset, HUB_NAME is SageMakerPublicHub."""
18+
env = {k: v for k, v in os.environ.items() if k != "SAGEMAKER_HUB_NAME"}
19+
with patch.dict(os.environ, env, clear=True):
20+
assert _reload_hub_name() == "SageMakerPublicHub"
21+
22+
23+
def test_hub_name_overridden_by_env_var():
24+
"""When SAGEMAKER_HUB_NAME is set, HUB_NAME reflects the override."""
25+
with patch.dict(os.environ, {"SAGEMAKER_HUB_NAME": "MyPrivateHub"}):
26+
assert _reload_hub_name() == "MyPrivateHub"

0 commit comments

Comments
 (0)