1616import skimage
1717import skimage .measure
1818from getScheduler import getScheduler , SCHEDULERS
19- from getPipeline import getPipelineForModel , listAvailablePipelines , clearPipelines
19+ from getPipeline import (
20+ getPipelineClass ,
21+ getPipelineForModel ,
22+ listAvailablePipelines ,
23+ clearPipelines ,
24+ )
2025import re
2126import requests
2227from download import download_model , normalize_model_id
@@ -228,7 +233,7 @@ def sendStatus():
228233 model_dir = os .path .join (MODELS_DIR , normalized_model_id )
229234 pipeline_name = call_inputs .get ("PIPELINE" , None )
230235 if pipeline_name :
231- pipeline_class = getattr ( diffusers_pipelines , pipeline_name )
236+ pipeline_class = getPipelineClass ( pipeline_name )
232237 if last_model_id != normalized_model_id :
233238 # if not downloaded_models.get(normalized_model_id, None):
234239 if not os .path .isdir (model_dir ):
@@ -250,7 +255,7 @@ def sendStatus():
250255 hf_model_id = hf_model_id ,
251256 model_precision = model_precision ,
252257 send_opts = send_opts ,
253- pipeline_class = pipeline_class ,
258+ pipeline_class = pipeline_class if pipeline_name else None ,
254259 )
255260 # downloaded_models.update({normalized_model_id: True})
256261 clearPipelines ()
@@ -267,7 +272,7 @@ def sendStatus():
267272 precision = model_precision ,
268273 revision = model_revision ,
269274 send_opts = send_opts ,
270- pipeline_class = pipeline_class ,
275+ pipeline_class = pipeline_class if pipeline_name else None ,
271276 )
272277 await send (
273278 "loadModel" , "done" , {"startRequestId" : startRequestId }, send_opts
0 commit comments