Skip to content

Commit 0c22345

Browse files
committed
[TRTLLM-12154][test] Move model path helper to common
Signed-off-by: Brian Nguyen <brnguyen@nvidia.com>
1 parent 392fd4b commit 0c22345

2 files changed

Lines changed: 15 additions & 12 deletions

File tree

tests/integration/defs/common.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,15 @@ def _war_check_output(*args, **kwargs):
5656
return venv.run_cmd(cmd, caller=_war_check_output, env=env, **kwargs)
5757

5858

59+
def resolve_llm_model_path(model_path: str) -> str:
60+
"""Resolve a model subpath relative to the test LLM model root."""
61+
if os.path.isabs(model_path):
62+
return model_path
63+
64+
from .conftest import llm_models_root
65+
return os.path.join(llm_models_root(), model_path)
66+
67+
5968
def venv_mpi_check_call(venv, mpi_cmd, python_cmd, **kwargs):
6069
"""
6170
This function WAR check_call() to run python_cmd with mpi.

tests/integration/defs/disaggregated/test_disaggregated.py

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,8 @@
3131
import pytest
3232
import yaml
3333
from defs.common import get_free_port_in_ci as get_free_port
34-
from defs.common import parse_gsm8k_output, wait_for_server
34+
from defs.common import (parse_gsm8k_output, resolve_llm_model_path,
35+
wait_for_server)
3536
from defs.conftest import (get_sm_version, llm_models_root, skip_arm,
3637
skip_no_hopper, skip_pre_blackwell, skip_pre_hopper)
3738
from defs.trt_test_alternative import check_call, check_output, print_info
@@ -343,13 +344,6 @@ def setup_model_symlink(llm_venv, model_root, dest_subpath):
343344
os.symlink(model_root, dst, target_is_directory=True)
344345

345346

346-
def _resolve_llm_model_path(model_path: str) -> str:
347-
"""Resolve a model subpath relative to the test LLM model root."""
348-
if os.path.isabs(model_path):
349-
return model_path
350-
return os.path.join(llm_models_root(), model_path)
351-
352-
353347
ClientTestSet = namedtuple('ClientTestSet', [
354348
'completion', 'completion_streaming', 'chat', 'chat_streaming',
355349
'verify_completion', 'verify_streaming_completion', 'verify_chat',
@@ -636,7 +630,7 @@ def setup_disagg_cluster(
636630
if isinstance(speculative_config, dict):
637631
speculative_model = speculative_config.get("speculative_model")
638632
if speculative_model:
639-
speculative_config["speculative_model"] = _resolve_llm_model_path(
633+
speculative_config["speculative_model"] = resolve_llm_model_path(
640634
speculative_model)
641635

642636
disagg_cluster = get_default_disagg_cluster_config()
@@ -670,7 +664,7 @@ def setup_disagg_cluster(
670664
# Launch workers
671665
model = model_name or config.get("model")
672666
if model:
673-
model = _resolve_llm_model_path(model)
667+
model = resolve_llm_model_path(model)
674668
ctx_workers = []
675669
gen_workers = []
676670
disagg_server = None
@@ -2317,7 +2311,7 @@ def test_disaggregated_gpt_oss_120b_harmony(disaggregated_test_root,
23172311
def test_disaggregated_qwen3_32b_fp8(disaggregated_test_root,
23182312
disaggregated_example_root, llm_venv,
23192313
model_path):
2320-
model_dir = _resolve_llm_model_path(model_path)
2314+
model_dir = resolve_llm_model_path(model_path)
23212315
setup_model_symlink(llm_venv, model_dir, model_path)
23222316

23232317
run_disaggregated_test(disaggregated_example_root,
@@ -2407,7 +2401,7 @@ def test_disaggregated_stress_test(disaggregated_test_root,
24072401
# Unpack configuration from dataclass
24082402
model_path = test_config.model_path
24092403
test_desc = test_config.test_desc
2410-
model_dir = _resolve_llm_model_path(model_path)
2404+
model_dir = resolve_llm_model_path(model_path)
24112405
setup_model_symlink(llm_venv, model_dir, model_path)
24122406

24132407
config_file = get_test_config(test_desc, disaggregated_example_root,

0 commit comments

Comments
 (0)