Skip to content

Commit 73f28a5

Browse files
committed
[TRTLLM-12154][test] Move model path helper to common
Signed-off-by: Brian Nguyen <brnguyen@nvidia.com>
1 parent aa0ff5d commit 73f28a5

2 files changed

Lines changed: 15 additions & 12 deletions

File tree

tests/integration/defs/common.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,15 @@ def _war_check_output(*args, **kwargs):
5656
return venv.run_cmd(cmd, caller=_war_check_output, env=env, **kwargs)
5757

5858

59+
def resolve_llm_model_path(model_path: str) -> str:
60+
"""Resolve a model subpath relative to the test LLM model root."""
61+
if os.path.isabs(model_path):
62+
return model_path
63+
64+
from .conftest import llm_models_root
65+
return os.path.join(llm_models_root(), model_path)
66+
67+
5968
def venv_mpi_check_call(venv, mpi_cmd, python_cmd, **kwargs):
6069
"""
6170
This function WAR check_call() to run python_cmd with mpi.

tests/integration/defs/disaggregated/test_disaggregated.py

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,8 @@
3030
import pytest
3131
import yaml
3232
from defs.common import get_free_port_in_ci as get_free_port
33-
from defs.common import parse_gsm8k_output, wait_for_server
33+
from defs.common import (parse_gsm8k_output, resolve_llm_model_path,
34+
wait_for_server)
3435
from defs.conftest import (get_sm_version, llm_models_root, skip_arm,
3536
skip_no_hopper, skip_pre_blackwell, skip_pre_hopper)
3637
from defs.trt_test_alternative import check_call, check_output, print_info
@@ -294,13 +295,6 @@ def setup_model_symlink(llm_venv, model_root, dest_subpath):
294295
os.symlink(model_root, dst, target_is_directory=True)
295296

296297

297-
def _resolve_llm_model_path(model_path: str) -> str:
298-
"""Resolve a model subpath relative to the test LLM model root."""
299-
if os.path.isabs(model_path):
300-
return model_path
301-
return os.path.join(llm_models_root(), model_path)
302-
303-
304298
ClientTestSet = namedtuple('ClientTestSet', [
305299
'completion', 'completion_streaming', 'chat', 'chat_streaming',
306300
'verify_completion', 'verify_streaming_completion', 'verify_chat',
@@ -510,7 +504,7 @@ def setup_disagg_cluster(
510504
if isinstance(speculative_config, dict):
511505
speculative_model = speculative_config.get("speculative_model")
512506
if speculative_model:
513-
speculative_config["speculative_model"] = _resolve_llm_model_path(
507+
speculative_config["speculative_model"] = resolve_llm_model_path(
514508
speculative_model)
515509

516510
disagg_cluster = get_default_disagg_cluster_config()
@@ -544,7 +538,7 @@ def setup_disagg_cluster(
544538
# Launch workers
545539
model = model_name or config.get("model")
546540
if model:
547-
model = _resolve_llm_model_path(model)
541+
model = resolve_llm_model_path(model)
548542
ctx_workers = []
549543
gen_workers = []
550544
disagg_server = None
@@ -2099,7 +2093,7 @@ def test_disaggregated_gpt_oss_120b_harmony(disaggregated_test_root,
20992093
def test_disaggregated_qwen3_32b_fp8(disaggregated_test_root,
21002094
disaggregated_example_root, llm_venv,
21012095
model_path):
2102-
model_dir = _resolve_llm_model_path(model_path)
2096+
model_dir = resolve_llm_model_path(model_path)
21032097
setup_model_symlink(llm_venv, model_dir, model_path)
21042098

21052099
run_disaggregated_test(disaggregated_example_root,
@@ -2140,7 +2134,7 @@ def test_disaggregated_stress_test(disaggregated_test_root,
21402134
# Unpack configuration from dataclass
21412135
model_path = test_config.model_path
21422136
test_desc = test_config.test_desc
2143-
model_dir = _resolve_llm_model_path(model_path)
2137+
model_dir = resolve_llm_model_path(model_path)
21442138
setup_model_symlink(llm_venv, model_dir, model_path)
21452139

21462140
config_file = get_test_config(test_desc, disaggregated_example_root,

0 commit comments

Comments
 (0)