@@ -33,7 +33,7 @@ def save_image(image, output_dir: str, filename: str) -> str:
3333 return filepath
3434
3535
36- def load_pipeline (pipeline_class_name : str , module_path : str , weight_path : str , backend : str , device : str , torch_dtype : torch .dtype ):
36+ def load_pipeline (pipeline_class_name : str , module_path : str , weight_path : str , model_id : str , backend : str , device : str , torch_dtype : torch .dtype ):
3737 module = importlib .import_module (module_path )
3838 pipeline_cls = getattr (module , pipeline_class_name )
3939
@@ -45,12 +45,18 @@ def load_pipeline(pipeline_class_name: str, module_path: str, weight_path: str,
4545 )
4646 elif backend == "modelscope" :
4747 from modelscope import snapshot_download
48- local_path = snapshot_download (weight_path )
48+ local_path = snapshot_download (model_id or weight_path )
4949 pipe = pipeline_cls .from_pretrained (
5050 local_path ,
5151 torch_dtype = torch_dtype ,
5252 trust_remote_code = True ,
5353 )
54+ elif backend == "hf" :
55+ pipe = pipeline_cls .from_pretrained (
56+ model_id ,
57+ torch_dtype = torch_dtype ,
58+ trust_remote_code = True ,
59+ )
5460 else :
5561 pipe = pipeline_cls .from_pretrained (
5662 weight_path ,
@@ -182,6 +188,7 @@ def generate_reference_only(config_dir: str, output_dir: str):
182188 pipe = load_pipeline (
183189 pipeline_class , module_path ,
184190 variant_data .get ("weight_path" , variant_data ["model_id" ]),
191+ variant_data ["model_id" ],
185192 variant_data .get ("backend" , "local" ),
186193 device , torch_dtype ,
187194 )
0 commit comments