@@ -18,6 +18,22 @@ def reduce_model_config(config: Any) -> Dict[str, Any]:
1818 config .num_decoder_layers = min (config .num_decoder_layers , 2 )
1919 if hasattr (config , "num_hidden_layers" ):
2020 config .num_hidden_layers = min (config .num_hidden_layers , nhl ())
21+ if hasattr (config , "encoder" ) and hasattr (config .encoder , "layer_types" ):
22+ default_layer_types = [
23+ "sliding_attention" ,
24+ "full_attention" ,
25+ "sliding_attention" ,
26+ "full_attention" ,
27+ ]
28+ config .encoder .num_hidden_layers = 4
29+ config .encoder .layer_types = (
30+ default_layer_types if config is None else config .encoder .layer_types [:4 ]
31+ )
32+ config .decoder .num_hidden_layers = 4
33+ config .decoder .layer_types = (
34+ default_layer_types if config is None else config .decoder .layer_types [:4 ]
35+ )
36+
2137 update_config (config , kwargs )
2238 return kwargs
2339
@@ -178,54 +194,74 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
178194 If the configuration is None, the function selects typical dimensions.
179195 """
180196 if config is not None :
181- check_hasattr (
182- config ,
183- "vocab_size" ,
184- "hidden_size" ,
185- "num_attention_heads" ,
186- ("num_hidden_layers" , "num_layers" ),
187- ("n_positions" , "d_model" ),
188- (
189- "num_key_value_heads" ,
190- "num_heads" ,
191- ("decoder_attention_heads" , "encoder_attention_heads" ),
192- ),
193- )
194- # exceptions = {
195- # "PLBartForConditionalGeneration": (
196- # lambda c: c.encoder_attention_heads + c.decoder_attention_heads
197- # )
198- # }
199- kwargs = dict (
200- batch_size = 2 ,
201- sequence_length = 30 ,
202- sequence_length2 = 3 ,
203- head_dim_encoder = 16 if config is None else _pick (config , "d_kv" , "encoder_ffn_dim" ),
204- head_dim_decoder = 16 if config is None else _pick (config , "d_kv" , "decoder_ffn_dim" ),
205- dummy_max_token_id = 31999 if config is None else config .vocab_size - 1 ,
206- num_hidden_layers = (
207- 8 if config is None else _pick (config , "num_hidden_layers" , "num_layers" )
208- ),
209- num_key_value_heads_encoder = (
210- 16
211- if config is None
212- else _pick (
197+ if hasattr (config , "num_attention_heads" ):
198+ check_hasattr (
213199 config ,
214- "encoder_attention_heads" ,
215- "num_key_value_heads" ,
216- "num_heads" ,
200+ "vocab_size" ,
201+ "hidden_size" ,
202+ "num_attention_heads" ,
203+ ("num_hidden_layers" , "num_layers" ),
204+ ("n_positions" , "d_model" ),
205+ (
206+ "num_key_value_heads" ,
207+ "num_heads" ,
208+ ("decoder_attention_heads" , "encoder_attention_heads" ),
209+ ),
217210 )
218- ),
219- num_key_value_heads_decoder = (
220- 16
221- if config is None
222- else _pick (
223- config ,
224- "decoder_attention_heads" ,
225- "num_key_value_heads" ,
226- "num_heads" ,
227- )
228- ),
229- encoder_dim = 512 if config is None else _pick (config , "n_positions" , "d_model" ),
230- )
211+ path = 1
212+ else :
213+ check_hasattr (config , "encoder" , "decoder" )
214+ path = 2
215+
216+ if path == 1 :
217+ kwargs = dict (
218+ batch_size = 2 ,
219+ sequence_length = 30 ,
220+ sequence_length2 = 3 ,
221+ head_dim_encoder = (
222+ 16 if config is None else _pick (config , "d_kv" , "encoder_ffn_dim" )
223+ ),
224+ head_dim_decoder = (
225+ 16 if config is None else _pick (config , "d_kv" , "decoder_ffn_dim" )
226+ ),
227+ dummy_max_token_id = 31999 if config is None else config .vocab_size - 1 ,
228+ num_hidden_layers = (
229+ 8 if config is None else _pick (config , "num_hidden_layers" , "num_layers" )
230+ ),
231+ num_key_value_heads_encoder = (
232+ 16
233+ if config is None
234+ else _pick (
235+ config ,
236+ "encoder_attention_heads" ,
237+ "num_key_value_heads" ,
238+ "num_heads" ,
239+ )
240+ ),
241+ num_key_value_heads_decoder = (
242+ 16
243+ if config is None
244+ else _pick (
245+ config ,
246+ "decoder_attention_heads" ,
247+ "num_key_value_heads" ,
248+ "num_heads" ,
249+ )
250+ ),
251+ encoder_dim = 512 if config is None else _pick (config , "n_positions" , "d_model" ),
252+ )
253+ else :
254+ kwargs = dict (
255+ batch_size = 2 ,
256+ sequence_length = 30 ,
257+ sequence_length2 = 3 ,
258+ dummy_max_token_id = config .encoder .vocab_size - 1 ,
259+ num_key_value_heads_encoder = config .encoder .num_key_value_heads ,
260+ num_key_value_heads_decoder = config .decoder .num_key_value_heads ,
261+ num_hidden_layers = len (config .encoder .layer_types ),
262+ head_dim_encoder = config .encoder .head_dim ,
263+ head_dim_decoder = config .decoder .head_dim ,
264+ encoder_dim = 256 ,
265+ )
266+
231267 return kwargs , get_inputs
0 commit comments