
class SentimentPostProcessor(beam.DoFn):
  """Processes PredictionResult to extract sentiment label and confidence."""
-  def __init__(self, tokenizer: DistilBertTokenizerFast):
-    self.tokenizer = tokenizer
-
  def process(self, element: tuple[str, PredictionResult]) -> Iterable[dict]:
    text, prediction_result = element
    logits = prediction_result.inference['logits']
@@ -62,16 +59,35 @@ def process(self, element: tuple[str, PredictionResult]) -> Iterable[dict]:
    }


-def tokenize_text(text: str,
-                  tokenizer: DistilBertTokenizerFast) -> tuple[str, dict]:
-  """Tokenizes input text using the specified tokenizer."""
-  tokenized = tokenizer(
-      text,
-      padding='max_length',
-      truncation=True,
-      max_length=128,
-      return_tensors="pt")
-  return text, {k: torch.squeeze(v) for k, v in tokenized.items()}
+class TokenizeTextDoFn(beam.DoFn):
+  """Initializes tokenizer per worker and tokenizes input text."""
+  def __init__(self, model_path: str):
+    self.model_path = model_path
+    self.tokenizer = None
+
+  def setup(self):
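+    # setup() runs once per DoFn instance on the worker, so the tokenizer is
+    # loaded here instead of being pickled into the pipeline graph.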
+    self.tokenizer = DistilBertTokenizerFast.from_pretrained(self.model_path)
+    # Some transformers builds expose the pad token through legacy attributes.
+    if not hasattr(self.tokenizer, '_pad_token'):
+      self.tokenizer._pad_token = '[PAD]'
+
+  def process(self, text: str) -> Iterable[tuple[str, dict]]:
+    tokenized = self.tokenizer(
+        text,
+        padding='max_length',
+        truncation=True,
+        max_length=128,
+        return_tensors="pt")
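+    # Squeeze out the batch dimension added by return_tensors="pt";
+    # RunInference batches the keyed tensors itself.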
+    yield text, {k: torch.squeeze(v) for k, v in tokenized.items()}
+
+
+class DistilBertForSequenceClassificationCompat(
+    DistilBertForSequenceClassification):
+  """Builds config in worker runtime to avoid cross-env config drift."""
+  def __init__(self, model_name: str, num_labels: int = 2):
+    config = _ensure_transformers_config_compat(
+        DistilBertConfig.from_pretrained(model_name, num_labels=num_labels))
+    super().__init__(config)


class RateLimitDoFn(beam.DoFn):
@@ -83,6 +99,30 @@ def process(self, element):
    yield element


+def _ensure_transformers_config_compat(config: DistilBertConfig) -> DistilBertConfig:
+  """Adds missing config attributes for cross-version transformers compatibility.
+
+  The benchmark can run with container images whose transformers version differs
+  from the launcher environment's. Some versions assume these attributes exist.
+  """
+  # Use a default config instance as the source of canonical attributes for the
+  # transformers version available on the worker. This avoids chasing one
+  # missing field at a time (e.g. torchscript, output_attentions).
+  default_config = DistilBertConfig()
+  for key, value in default_config.to_dict().items():
+    if not hasattr(config, key):
+      setattr(config, key, value)
+
+  # Keep non-serialized fields explicitly for older/newer transformers mixes.
+  if not hasattr(config, 'pruned_heads'):
+    config.pruned_heads = {}
+  if not hasattr(config, 'torchscript'):
+    config.torchscript = False
+  if not hasattr(config, 'return_dict'):
+    config.return_dict = True
+  return config
+
+
def parse_known_args(argv):
  """Parses command-line arguments for pipeline execution."""
  parser = argparse.ArgumentParser()
@@ -235,13 +275,14 @@ def run(
  pipeline_options.view_as(StandardOptions).streaming = True

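+  # device='GPU' uses CUDA when it is available on the worker; Beam's handler
+  # falls back to CPU otherwise.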
  model_handler = PytorchModelHandlerKeyedTensor(
-      model_class=DistilBertForSequenceClassification,
-      model_params={'config': DistilBertConfig(num_labels=2)},
+      model_class=DistilBertForSequenceClassificationCompat,
+      model_params={
+          'model_name': known_args.model_path,
+          'num_labels': 2,
+      },
      state_dict_path=known_args.model_state_dict_path,
      device='GPU')

-  tokenizer = DistilBertTokenizerFast.from_pretrained(known_args.model_path)
-
  pipeline = test_pipeline or beam.Pipeline(options=pipeline_options)

  # Main pipeline: read, process, write result to BigQuery output table
@@ -264,9 +305,9 @@ def run(

  _ = (
      input
-      | 'Tokenize' >> beam.Map(lambda text: tokenize_text(text, tokenizer))
+      | 'Tokenize' >> beam.ParDo(TokenizeTextDoFn(known_args.model_path))
      | 'RunInference' >> RunInference(KeyedModelHandler(model_handler))
-      | 'PostProcess' >> beam.ParDo(SentimentPostProcessor(tokenizer))
+      | 'PostProcess' >> beam.ParDo(SentimentPostProcessor())
      | 'WriteToBigQuery' >> beam.io.WriteToBigQuery(
          known_args.output_table,
          schema='text:STRING, sentiment:STRING, confidence:FLOAT',
@@ -275,8 +316,12 @@ def run(
          method=method))

  result = pipeline.run()
+  # For streaming benchmarks, run for a bounded duration, then cancel and wait
+  # until a terminal state so the benchmark harness does not observe a lingering
+  # CANCELLING job.
  result.wait_until_finish(duration=1800000)  # 30 min
  result.cancel()
+  result.wait_until_finish(duration=600000)  # up to 10 min to settle the cancel

  cleanup_pubsub_resources(
      project=known_args.project,