Skip to content

Commit 8dc1601

Browse files
author
CI Runner
committed
fix: pass model_id to load_pipeline for modelscope/hf download
1 parent 5d97219 commit 8dc1601

1 file changed

Lines changed: 9 additions & 2 deletions

File tree

.github/ci_runners/run_pipeline.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def save_image(image, output_dir: str, filename: str) -> str:
3333
return filepath
3434

3535

36-
def load_pipeline(pipeline_class_name: str, module_path: str, weight_path: str, backend: str, device: str, torch_dtype: torch.dtype):
36+
def load_pipeline(pipeline_class_name: str, module_path: str, weight_path: str, model_id: str, backend: str, device: str, torch_dtype: torch.dtype):
3737
module = importlib.import_module(module_path)
3838
pipeline_cls = getattr(module, pipeline_class_name)
3939

@@ -45,12 +45,18 @@ def load_pipeline(pipeline_class_name: str, module_path: str, weight_path: str,
4545
)
4646
elif backend == "modelscope":
4747
from modelscope import snapshot_download
48-
local_path = snapshot_download(weight_path)
48+
local_path = snapshot_download(model_id or weight_path)
4949
pipe = pipeline_cls.from_pretrained(
5050
local_path,
5151
torch_dtype=torch_dtype,
5252
trust_remote_code=True,
5353
)
54+
elif backend == "hf":
55+
pipe = pipeline_cls.from_pretrained(
56+
model_id,
57+
torch_dtype=torch_dtype,
58+
trust_remote_code=True,
59+
)
5460
else:
5561
pipe = pipeline_cls.from_pretrained(
5662
weight_path,
@@ -182,6 +188,7 @@ def generate_reference_only(config_dir: str, output_dir: str):
182188
pipe = load_pipeline(
183189
pipeline_class, module_path,
184190
variant_data.get("weight_path", variant_data["model_id"]),
191+
variant_data["model_id"],
185192
variant_data.get("backend", "local"),
186193
device, torch_dtype,
187194
)

0 commit comments

Comments
 (0)