Skip to content

Commit 4787358

Browse files
committed
added DistilBertForSequenceClassificationCompat
1 parent 7a0d05b commit 4787358

1 file changed

Lines changed: 14 additions & 5 deletions

File tree

sdks/python/apache_beam/examples/inference/pytorch_sentiment.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,15 @@ def process(self, text: str) -> Iterable[tuple[str, dict]]:
8181
yield text, {k: torch.squeeze(v) for k, v in tokenized.items()}
8282

8383

84+
class DistilBertForSequenceClassificationCompat(
85+
DistilBertForSequenceClassification):
86+
"""Builds config in worker runtime to avoid cross-env config drift."""
87+
def __init__(self, model_name: str, num_labels: int = 2):
88+
config = _ensure_transformers_config_compat(
89+
DistilBertConfig.from_pretrained(model_name, num_labels=num_labels))
90+
super().__init__(config)
91+
92+
8493
class RateLimitDoFn(beam.DoFn):
8594
def __init__(self, rate_per_sec: float):
8695
self.delay = 1.0 / rate_per_sec
@@ -265,12 +274,12 @@ def run(
265274
method = beam.io.WriteToBigQuery.Method.STREAMING_INSERTS
266275
pipeline_options.view_as(StandardOptions).streaming = True
267276

268-
model_config = _ensure_transformers_config_compat(
269-
DistilBertConfig.from_pretrained(known_args.model_path, num_labels=2))
270-
271277
model_handler = PytorchModelHandlerKeyedTensor(
272-
model_class=DistilBertForSequenceClassification,
273-
model_params={'config': model_config},
278+
model_class=DistilBertForSequenceClassificationCompat,
279+
model_params={
280+
'model_name': known_args.model_path,
281+
'num_labels': 2,
282+
},
274283
state_dict_path=known_args.model_state_dict_path,
275284
device='GPU')
276285

0 commit comments

Comments
 (0)