@@ -96,22 +96,21 @@ def _ensure_transformers_config_compat(config: DistilBertConfig) -> DistilBertCo
     The benchmark can run with container images whose transformers version differs
     from the launcher environment. Some versions assume these attributes exist.
     """
+    # Use a default config instance as the source of canonical attributes for the
+    # transformers version available on the worker. This avoids chasing one
+    # missing field at a time (e.g. torchscript, output_attentions).
+    default_config = DistilBertConfig()
+    for key, value in default_config.to_dict().items():
+        if not hasattr(config, key):
+            setattr(config, key, value)
+
+    # Keep non-serialized fields explicit for mixes of older/newer transformers.
     if not hasattr(config, 'pruned_heads'):
         config.pruned_heads = {}
     if not hasattr(config, 'torchscript'):
         config.torchscript = False
     if not hasattr(config, 'return_dict'):
         config.return_dict = True
-    if not hasattr(config, 'output_attentions'):
-        config.output_attentions = False
-    if not hasattr(config, 'output_hidden_states'):
-        config.output_hidden_states = False
-    if not hasattr(config, 'use_cache'):
-        config.use_cache = False
-    if not hasattr(config, 'is_decoder'):
-        config.is_decoder = False
-    if not hasattr(config, 'add_cross_attention'):
-        config.add_cross_attention = False
     return config
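For reference, a minimal standalone sketch of the backfill pattern this commit introduces. It assumes a transformers version whose DistilBertConfig stores these fields as instance attributes (so they appear in to_dict()); deleting the attribute below just simulates a config produced by a mismatched transformers version:

    from transformers import DistilBertConfig

    config = DistilBertConfig()
    del config.output_attentions  # simulate a field missing on a mismatched config

    # Backfill every attribute the local default config carries.
    default_config = DistilBertConfig()
    for key, value in default_config.to_dict().items():
        if not hasattr(config, key):
            setattr(config, key, value)

    assert config.output_attentions is False  # restored to the local default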