@@ -81,6 +81,15 @@ def process(self, text: str) -> Iterable[tuple[str, dict]]:
8181 yield text , {k : torch .squeeze (v ) for k , v in tokenized .items ()}
8282
8383
class DistilBertForSequenceClassificationCompat(DistilBertForSequenceClassification):
  """DistilBERT sequence classifier that builds its own config on construction.

  Creating the ``DistilBertConfig`` inside the worker process — instead of
  pickling a pre-built config into the pipeline — avoids config drift
  between the launcher environment and the worker runtime.

  Args:
    model_name: HuggingFace model name or path passed to
      ``DistilBertConfig.from_pretrained``.
    num_labels: Number of classification labels (default 2).
  """

  def __init__(self, model_name: str, num_labels: int = 2):
    # Build the config in the worker runtime, then normalize it for
    # cross-version compatibility before handing it to the parent class.
    raw_config = DistilBertConfig.from_pretrained(
        model_name, num_labels=num_labels)
    super().__init__(_ensure_transformers_config_compat(raw_config))
91+
92+
8493class RateLimitDoFn (beam .DoFn ):
8594 def __init__ (self , rate_per_sec : float ):
8695 self .delay = 1.0 / rate_per_sec
@@ -265,12 +274,12 @@ def run(
265274 method = beam .io .WriteToBigQuery .Method .STREAMING_INSERTS
266275 pipeline_options .view_as (StandardOptions ).streaming = True
267276
268- model_config = _ensure_transformers_config_compat (
269- DistilBertConfig .from_pretrained (known_args .model_path , num_labels = 2 ))
270-
271277 model_handler = PytorchModelHandlerKeyedTensor (
272- model_class = DistilBertForSequenceClassification ,
273- model_params = {'config' : model_config },
278+ model_class = DistilBertForSequenceClassificationCompat ,
279+ model_params = {
280+ 'model_name' : known_args .model_path ,
281+ 'num_labels' : 2 ,
282+ },
274283 state_dict_path = known_args .model_state_dict_path ,
275284 device = 'GPU' )
276285
0 commit comments