Skip to content

Commit bd47846

Browse files
committed
fix: address review comments (iteration #1)
1 parent 857485f commit bd47846

File tree

2 files changed

+128
-294
lines changed

2 files changed

+128
-294
lines changed

sagemaker-serve/src/sagemaker/serve/model_builder_servers.py

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -319,45 +319,43 @@ def _build_for_djl(self) -> Model:
319319
logger.debug(f"Using detected notebook instance type: {nb_instance}")
320320

321321
if isinstance(self.model, str) and not self._is_jumpstart_model_id():
322-
# Configure HuggingFace model for DJL
322+
# Configure HuggingFace model for DJL (preserve user-provided HF_MODEL_ID)
323323
self.env_vars.setdefault("HF_MODEL_ID", self.model)
324-
324+
325325
# Get model configuration for DJL optimization
326326
self.hf_model_config = _get_model_config_properties_from_hf(
327327
self.model, self.env_vars.get("HUGGING_FACE_HUB_TOKEN")
328328
)
329-
329+
330330
# Apply DJL-specific configurations
331331
default_djl_configurations, _default_max_new_tokens = _get_default_djl_configurations(
332332
self.model, self.hf_model_config, self.schema_builder
333333
)
334334
self.env_vars.update(default_djl_configurations)
335-
335+
336336
# Configure schema builder for text generation
337337
if "parameters" not in self.schema_builder.sample_input:
338338
self.schema_builder.sample_input["parameters"] = {}
339339
self.schema_builder.sample_input["parameters"]["max_new_tokens"] = _default_max_new_tokens
340-
341-
# Set DJL serving defaults
340+
341+
# Set DJL serving defaults (only if not already set by user)
342342
djl_env_vars = {
343343
"OPTION_ENGINE": "Python",
344344
"SERVING_MIN_WORKERS": "1",
345-
"SERVING_MAX_WORKERS": "1",
345+
"SERVING_MAX_WORKERS": "1",
346346
"OPTION_MODEL_LOADING_TIMEOUT": "240",
347347
"OPTION_PREDICT_TIMEOUT": "60",
348-
"TENSOR_PARALLEL_DEGREE": "1", # Default, will be overridden below
349-
"HF_HOME": "/tmp",
350-
"HUGGINGFACE_HUB_CACHE": "/tmp",
348+
"TENSOR_PARALLEL_DEGREE": "1",
351349
}
352-
350+
353351
# Add HuggingFace authentication
354352
if self.env_vars.get("HUGGING_FACE_HUB_TOKEN"):
355353
djl_env_vars["HF_TOKEN"] = self.env_vars.get("HUGGING_FACE_HUB_TOKEN")
356-
354+
357355
# Update with defaults only if not already set
358356
for key, value in djl_env_vars.items():
359357
self.env_vars.setdefault(key, value)
360-
358+
361359
# DJL downloads models directly from HuggingFace Hub
362360
self.s3_upload_path = None
363361

@@ -369,12 +367,15 @@ def _build_for_djl(self) -> Model:
369367
else:
370368
self.s3_model_data_url, _ = self._prepare_for_mode()
371369

370+
# Set HF cache env vars to writable location (unconditionally, using setdefault
371+
# to preserve user-provided values). This is needed because /opt/ml/model/ may be
372+
# read-only when source_code artifacts are mounted there.
373+
self.env_vars.setdefault("HF_HOME", "/tmp")
374+
self.env_vars.setdefault("HUGGINGFACE_HUB_CACHE", "/tmp")
375+
372376
# Cache management based on mode
373377
if self.mode in LOCAL_MODES:
374378
self.env_vars.update({"HF_HUB_OFFLINE": "1"})
375-
else:
376-
self.env_vars["HF_HOME"] = "/tmp"
377-
self.env_vars["HUGGINGFACE_HUB_CACHE"] = "/tmp"
378379

379380
# GPU-based tensor parallel calculation for SAGEMAKER_ENDPOINT mode
380381
if self.mode == Mode.SAGEMAKER_ENDPOINT:

0 commit comments

Comments
 (0)