@@ -282,6 +282,55 @@ def inference_output_type(self):
282282 ('model_id' , Optional [str ])])
283283
284284
285+ @ModelHandlerProvider .register_handler_type ('HuggingFacePipeline' )
286+ class HuggingFacePipelineProvider (ModelHandlerProvider ):
287+ def __init__ (
288+ self ,
289+ task : Optional [str ] = None ,
290+ model : Optional [str ] = None ,
291+ preprocess : Optional [dict [str , str ]] = None ,
292+ postprocess : Optional [dict [str , str ]] = None ,
293+ device : Optional [Any ] = None ,
294+ inference_fn : Optional [dict [str , str ]] = None ,
295+ load_pipeline_args : Optional [dict [str , Any ]] = None ,
296+ ** kwargs ):
297+ try :
298+ from apache_beam .ml .inference .huggingface_inference import HuggingFacePipelineModelHandler
299+ except ImportError :
300+ raise ValueError (
301+ 'Unable to import HuggingFacePipelineModelHandler. Please '
302+ 'install transformers dependencies.' )
303+
304+ kwargs = {k : v for k , v in kwargs .items () if not k .startswith ('_' )}
305+
306+ inference_fn_obj = self .parse_processing_transform (
307+ inference_fn , 'inference_fn' ) if inference_fn else None
308+
309+ handler_kwargs = {}
310+ if inference_fn_obj :
311+ handler_kwargs ['inference_fn' ] = inference_fn_obj
312+
313+ _handler = HuggingFacePipelineModelHandler (
314+ task = task ,
315+ model = model ,
316+ device = device ,
317+ load_pipeline_args = load_pipeline_args ,
318+ ** handler_kwargs ,
319+ ** kwargs )
320+
321+ super ().__init__ (_handler , preprocess , postprocess )
322+
323+ @staticmethod
324+ def validate (config ):
325+ if not config or (not config .get ('task' ) and not config .get ('model' )):
326+ raise ValueError (
327+ "HuggingFacePipeline requires either 'task' or "
328+ "'model' to be specified." )
329+
330+ def inference_output_type (self ):
331+ return Any
332+
333+
285334@beam .ptransform .ptransform_fn
286335def run_inference (
287336 pcoll ,
0 commit comments