Skip to content

Commit 7a0d05b

Browse files
committed
fix
1 parent d90592a commit 7a0d05b

1 file changed

Lines changed: 9 additions & 10 deletions

File tree

sdks/python/apache_beam/examples/inference/pytorch_sentiment.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -96,22 +96,21 @@ def _ensure_transformers_config_compat(config: DistilBertConfig) -> DistilBertCo
9696
The benchmark can run with container images whose transformers version differs
9797
from the launcher environment. Some versions assume these attributes exist.
9898
"""
99+
# Use a default config instance as the source of canonical attributes for the
100+
# transformers version available on the worker. This avoids chasing one
101+
# missing field at a time (e.g. torchscript, output_attentions).
102+
default_config = DistilBertConfig()
103+
for key, value in default_config.to_dict().items():
104+
if not hasattr(config, key):
105+
setattr(config, key, value)
106+
107+
# Keep non-serialized fields explicitly for older/newer transformers mixes.
99108
if not hasattr(config, 'pruned_heads'):
100109
config.pruned_heads = {}
101110
if not hasattr(config, 'torchscript'):
102111
config.torchscript = False
103112
if not hasattr(config, 'return_dict'):
104113
config.return_dict = True
105-
if not hasattr(config, 'output_attentions'):
106-
config.output_attentions = False
107-
if not hasattr(config, 'output_hidden_states'):
108-
config.output_hidden_states = False
109-
if not hasattr(config, 'use_cache'):
110-
config.use_cache = False
111-
if not hasattr(config, 'is_decoder'):
112-
config.is_decoder = False
113-
if not hasattr(config, 'add_cross_attention'):
114-
config.add_cross_attention = False
115114
return config
116115

117116

0 commit comments

Comments
 (0)