Skip to content

Commit 8ac0348

Browse files
authored
Hugging face model handler #3 (#38696)
1 parent c31f232 commit 8ac0348

5 files changed

Lines changed: 114 additions & 2 deletions

File tree

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
{
22
"comment": "Modify this file in a trivial way to cause this test suite to run",
3-
"revision": 2
3+
"revision": 3
44
}
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one or more
2+
# contributor license agreements. See the NOTICE file distributed with
3+
# this work for additional information regarding copyright ownership.
4+
# The ASF licenses this file to You under the Apache License, Version 2.0
5+
# (the "License"); you may not use this file except in compliance with
6+
# the License. You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
pipelines:
17+
- pipeline:
18+
type: chain
19+
transforms:
20+
- type: Create
21+
config:
22+
elements:
23+
- text: "I love Apache Beam!"
24+
- text: "I hate this error."
25+
- type: RunInference
26+
config:
27+
model_handler:
28+
type: "HuggingFacePipeline"
29+
config:
30+
task: "text-classification"
31+
inference_fn:
32+
callable: |
33+
def real_inference(batch, pipeline, inference_args):
34+
predictions = pipeline(batch, **inference_args)
35+
36+
# If it's a single dictionary (batch size of 1), wrap it in a list
37+
if isinstance(predictions, dict):
38+
predictions = [predictions]
39+
40+
return {
41+
'label': [p['label'] for p in predictions],
42+
'score': [p['score'] for p in predictions]
43+
}
44+
preprocess:
45+
callable: 'lambda x: x.text'
46+
- type: MapToFields
47+
config:
48+
language: python
49+
fields:
50+
text: text
51+
sentiment:
52+
callable: 'lambda x: x.inference.inference["label"]'
53+
- type: AssertEqual
54+
config:
55+
elements:
56+
- text: "I love Apache Beam!"
57+
sentiment: "POSITIVE"
58+
- text: "I hate this error."
59+
sentiment: "NEGATIVE"
60+
61+
options:
62+
yaml_experimental_features: ['ML']

sdks/python/apache_beam/yaml/tests/runinference.yaml renamed to sdks/python/apache_beam/yaml/tests/runinference_vertexai.yaml

File renamed without changes.

sdks/python/apache_beam/yaml/yaml_ml.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,55 @@ def inference_output_type(self):
282282
('model_id', Optional[str])])
283283

284284

285+
@ModelHandlerProvider.register_handler_type('HuggingFacePipeline')
class HuggingFacePipelineProvider(ModelHandlerProvider):
  """ModelHandlerProvider registered under the 'HuggingFacePipeline' type.

  Builds a HuggingFacePipelineModelHandler from the supplied configuration
  and passes it, together with the optional preprocess/postprocess specs,
  to the ModelHandlerProvider base class.
  """
  def __init__(
      self,
      task: Optional[str] = None,
      model: Optional[str] = None,
      preprocess: Optional[dict[str, str]] = None,
      postprocess: Optional[dict[str, str]] = None,
      device: Optional[Any] = None,
      inference_fn: Optional[dict[str, str]] = None,
      load_pipeline_args: Optional[dict[str, Any]] = None,
      **kwargs):
    """Creates the underlying HuggingFacePipelineModelHandler.

    Args:
      task: Forwarded as ``task`` to HuggingFacePipelineModelHandler.
      model: Forwarded as ``model`` to HuggingFacePipelineModelHandler.
      preprocess: Preprocessing spec handed to the base class unchanged.
      postprocess: Postprocessing spec handed to the base class unchanged.
      device: Forwarded as ``device`` to HuggingFacePipelineModelHandler.
      inference_fn: Optional spec that is resolved through
        ``parse_processing_transform`` and, when that yields a truthy
        object, forwarded as ``inference_fn`` to the handler. When omitted,
        no ``inference_fn`` argument is passed at all.
      load_pipeline_args: Forwarded as ``load_pipeline_args`` to the handler.
      **kwargs: Extra keyword arguments forwarded to the handler, except
        those whose name starts with an underscore.

    Raises:
      ValueError: If HuggingFacePipelineModelHandler cannot be imported.
    """
    try:
      from apache_beam.ml.inference.huggingface_inference import HuggingFacePipelineModelHandler
    except ImportError as exc:
      # Same ValueError as before, but explicitly chained so the traceback
      # reports the ImportError as the direct cause.
      raise ValueError(
          'Unable to import HuggingFacePipelineModelHandler. Please '
          'install transformers dependencies.') from exc

    # Underscore-prefixed keys are not forwarded to the handler.
    # NOTE(review): presumably these are framework-internal entries; confirm.
    kwargs = {k: v for k, v in kwargs.items() if not k.startswith('_')}

    inference_fn_obj = self.parse_processing_transform(
        inference_fn, 'inference_fn') if inference_fn else None

    # Only pass inference_fn when one was resolved, so the handler's own
    # default stays in effect otherwise.
    handler_kwargs = {}
    if inference_fn_obj:
      handler_kwargs['inference_fn'] = inference_fn_obj

    _handler = HuggingFacePipelineModelHandler(
        task=task,
        model=model,
        device=device,
        load_pipeline_args=load_pipeline_args,
        **handler_kwargs,
        **kwargs)

    super().__init__(_handler, preprocess, postprocess)

  @staticmethod
  def validate(config):
    """Checks that the handler config names a task or a model.

    Args:
      config: Mapping of handler configuration values (may be empty/None).

    Raises:
      ValueError: If ``config`` is falsy, or if neither its 'task' nor its
        'model' entry is truthy.
    """
    if not config or (not config.get('task') and not config.get('model')):
      raise ValueError(
          "HuggingFacePipeline requires either 'task' or "
          "'model' to be specified.")

  def inference_output_type(self):
    """Returns ``Any`` as the inference output type."""
    return Any
332+
333+
285334
@beam.ptransform.ptransform_fn
286335
def run_inference(
287336
pcoll,

sdks/python/setup.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -654,7 +654,8 @@ def get_portability_package_data():
654654
'transformers': [
655655
'transformers>=4.28.0,<4.56.0',
656656
'tensorflow>=2.12.0',
657-
'torch>=1.9.0'
657+
# Cap torch below 2.12.0: versions 2.12.0+ segfault when running the unit tests.
658+
'torch>=1.9.0,<2.12.0'
658659
],
659660
'ml_cpu': [
660661
'tensorflow>=2.12.0',

0 commit comments

Comments
 (0)