Skip to content

Commit 4c184d4

Browse files
authored
fix: bug: ModelBuilder overwrites user-provided HF_MODEL_ID for DJL Serving, preventi (5529) (#5734)
* fix: bug: ModelBuilder overwrites user-provided HF_MODEL_ID for DJL Serving, preventi (5529) * fix: address review comments (iteration #1) * fix: address review comments (iteration #1) * fix: address review comments (iteration #1) * fix: address review comments (iteration #2)
1 parent f20a7e2 commit 4c184d4

File tree

2 files changed

+390
-8
lines changed

2 files changed

+390
-8
lines changed

sagemaker-serve/src/sagemaker/serve/model_builder_servers.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ def _build_for_torchserve(self) -> Model:
136136
if isinstance(self.model, str):
137137
# Configure HuggingFace model support
138138
if not self._is_jumpstart_model_id():
139-
self.env_vars.update({"HF_MODEL_ID": self.model})
139+
self.env_vars.setdefault("HF_MODEL_ID", self.model)
140140

141141
# Add HuggingFace token if available
142142
if self.env_vars.get("HUGGING_FACE_HUB_TOKEN"):
@@ -212,7 +212,7 @@ def _build_for_tgi(self) -> Model:
212212

213213
if isinstance(self.model, str) and not self._is_jumpstart_model_id():
214214
# Configure HuggingFace model for TGI
215-
self.env_vars.update({"HF_MODEL_ID": self.model})
215+
self.env_vars.setdefault("HF_MODEL_ID", self.model)
216216

217217
self.hf_model_config = _get_model_config_properties_from_hf(
218218
self.model, self.env_vars.get("HUGGING_FACE_HUB_TOKEN")
@@ -319,9 +319,9 @@ def _build_for_djl(self) -> Model:
319319
logger.debug(f"Using detected notebook instance type: {nb_instance}")
320320

321321
if isinstance(self.model, str) and not self._is_jumpstart_model_id():
322-
# Configure HuggingFace model for DJL (preserve user-provided HF_MODEL_ID)
322+
# Configure HuggingFace model for DJL
323323
self.env_vars.setdefault("HF_MODEL_ID", self.model)
324-
324+
325325
# Get model configuration for DJL optimization
326326
self.hf_model_config = _get_model_config_properties_from_hf(
327327
self.model, self.env_vars.get("HUGGING_FACE_HUB_TOKEN")
@@ -432,7 +432,7 @@ def _build_for_triton(self) -> Model:
432432
self.env_vars.update({"HF_TASK": model_task})
433433

434434
# Configure HuggingFace authentication
435-
self.env_vars.update({"HF_MODEL_ID": self.model})
435+
self.env_vars.setdefault("HF_MODEL_ID", self.model)
436436
if self.env_vars.get("HUGGING_FACE_HUB_TOKEN"):
437437
self.env_vars["HF_TOKEN"] = self.env_vars.get("HUGGING_FACE_HUB_TOKEN")
438438

@@ -538,7 +538,7 @@ def _build_for_tei(self) -> Model:
538538

539539
if isinstance(self.model, str) and not self._is_jumpstart_model_id():
540540
# Configure HuggingFace model for TEI
541-
self.env_vars.update({"HF_MODEL_ID": self.model})
541+
self.env_vars.setdefault("HF_MODEL_ID", self.model)
542542

543543
self.hf_model_config = _get_model_config_properties_from_hf(
544544
self.model, self.env_vars.get("HUGGING_FACE_HUB_TOKEN")
@@ -682,7 +682,7 @@ def _build_for_transformers(self) -> Model:
682682
if self.inference_spec is not None:
683683
hf_model_id = self.inference_spec.get_model()
684684
if isinstance(hf_model_id, str): # Only if it's a valid HF model ID
685-
self.env_vars.update({"HF_MODEL_ID": hf_model_id})
685+
self.env_vars.setdefault("HF_MODEL_ID", hf_model_id)
686686
# Get HF config only for string model IDs
687687
if hasattr(self.env_vars, "HF_API_TOKEN"):
688688
self.hf_model_config = _get_model_config_properties_from_hf(
@@ -701,7 +701,7 @@ def _build_for_transformers(self) -> Model:
701701
if model_task:
702702
self.env_vars.update({"HF_TASK": model_task})
703703

704-
self.env_vars.update({"HF_MODEL_ID": self.model})
704+
self.env_vars.setdefault("HF_MODEL_ID", self.model)
705705

706706
# Add HuggingFace token if available
707707
if self.env_vars.get("HUGGING_FACE_HUB_TOKEN"):

0 commit comments

Comments
 (0)