Skip to content

Commit 919b2e6

Browse files
committed
first step for t5gemma
1 parent 3d96a61 commit 919b2e6

3 files changed

Lines changed: 100 additions & 53 deletions

File tree

onnx_diagnostic/tasks/image_text_to_text.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,10 @@
1313
__TASK__ = "image-text-to-text"
1414

1515

16+
def should_have_vision_config(config):
17+
return config.architectures != ["FuyuForCausalLM"]
18+
19+
1620
def reduce_model_config(config: Any) -> Dict[str, Any]:
1721
"""Reduces a model size."""
1822
kwargs: Dict[str, Any] = {}
@@ -477,7 +481,8 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
477481
"hidden_size",
478482
"pad_token_id",
479483
)
480-
check_hasattr(config, "vision_config", ("image_token_index", "image_token_id"))
484+
if should_have_vision_config(config):
485+
check_hasattr(config, "vision_config", ("image_token_index", "image_token_id"))
481486
text_config = True
482487
else:
483488
check_hasattr(
@@ -491,7 +496,8 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
491496
"vision_config",
492497
)
493498
text_config = False
494-
check_hasattr(config.vision_config, ("num_channels", "in_chans", "in_channels"))
499+
if should_have_vision_config(config):
500+
check_hasattr(config.vision_config, ("num_channels", "in_chans", "in_channels"))
495501
kwargs = dict(
496502
head_dim=(
497503
16
@@ -552,17 +558,21 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
552558
),
553559
width=(
554560
224
555-
if config is None or not hasattr(config.vision_config, "image_size")
561+
if config is None
562+
or not should_have_vision_config(config)
563+
or not hasattr(config.vision_config, "image_size")
556564
else config.vision_config.image_size
557565
),
558566
height=(
559567
224
560-
if config is None or not hasattr(config.vision_config, "image_size")
568+
if config is None
569+
or not should_have_vision_config(config)
570+
or not hasattr(config.vision_config, "image_size")
561571
else config.vision_config.image_size
562572
),
563573
num_channels=(
564574
3
565-
if config is None
575+
if config is None or not should_have_vision_config(config)
566576
else _pick(config.vision_config, "num_channels", "in_chans", "in_channels")
567577
),
568578
pad_token_id=(

onnx_diagnostic/tasks/text2text_generation.py

Lines changed: 84 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,22 @@ def reduce_model_config(config: Any) -> Dict[str, Any]:
1818
config.num_decoder_layers = min(config.num_decoder_layers, 2)
1919
if hasattr(config, "num_hidden_layers"):
2020
config.num_hidden_layers = min(config.num_hidden_layers, nhl())
21+
if hasattr(config, "encoder") and hasattr(config.encoder, "layer_types"):
22+
default_layer_types = [
23+
"sliding_attention",
24+
"full_attention",
25+
"sliding_attention",
26+
"full_attention",
27+
]
28+
config.encoder.num_hidden_layers = 4
29+
config.encoder.layer_types = (
30+
default_layer_types if config is None else config.encoder.layer_types[:4]
31+
)
32+
config.decoder.num_hidden_layers = 4
33+
config.decoder.layer_types = (
34+
default_layer_types if config is None else config.decoder.layer_types[:4]
35+
)
36+
2137
update_config(config, kwargs)
2238
return kwargs
2339

@@ -178,54 +194,74 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
178194
If the configuration is None, the function selects typical dimensions.
179195
"""
180196
if config is not None:
181-
check_hasattr(
182-
config,
183-
"vocab_size",
184-
"hidden_size",
185-
"num_attention_heads",
186-
("num_hidden_layers", "num_layers"),
187-
("n_positions", "d_model"),
188-
(
189-
"num_key_value_heads",
190-
"num_heads",
191-
("decoder_attention_heads", "encoder_attention_heads"),
192-
),
193-
)
194-
# exceptions = {
195-
# "PLBartForConditionalGeneration": (
196-
# lambda c: c.encoder_attention_heads + c.decoder_attention_heads
197-
# )
198-
# }
199-
kwargs = dict(
200-
batch_size=2,
201-
sequence_length=30,
202-
sequence_length2=3,
203-
head_dim_encoder=16 if config is None else _pick(config, "d_kv", "encoder_ffn_dim"),
204-
head_dim_decoder=16 if config is None else _pick(config, "d_kv", "decoder_ffn_dim"),
205-
dummy_max_token_id=31999 if config is None else config.vocab_size - 1,
206-
num_hidden_layers=(
207-
8 if config is None else _pick(config, "num_hidden_layers", "num_layers")
208-
),
209-
num_key_value_heads_encoder=(
210-
16
211-
if config is None
212-
else _pick(
197+
if hasattr(config, "num_attention_heads"):
198+
check_hasattr(
213199
config,
214-
"encoder_attention_heads",
215-
"num_key_value_heads",
216-
"num_heads",
200+
"vocab_size",
201+
"hidden_size",
202+
"num_attention_heads",
203+
("num_hidden_layers", "num_layers"),
204+
("n_positions", "d_model"),
205+
(
206+
"num_key_value_heads",
207+
"num_heads",
208+
("decoder_attention_heads", "encoder_attention_heads"),
209+
),
217210
)
218-
),
219-
num_key_value_heads_decoder=(
220-
16
221-
if config is None
222-
else _pick(
223-
config,
224-
"decoder_attention_heads",
225-
"num_key_value_heads",
226-
"num_heads",
227-
)
228-
),
229-
encoder_dim=512 if config is None else _pick(config, "n_positions", "d_model"),
230-
)
211+
path = 1
212+
else:
213+
check_hasattr(config, "encoder", "decoder")
214+
path = 2
215+
216+
if path == 1:
217+
kwargs = dict(
218+
batch_size=2,
219+
sequence_length=30,
220+
sequence_length2=3,
221+
head_dim_encoder=(
222+
16 if config is None else _pick(config, "d_kv", "encoder_ffn_dim")
223+
),
224+
head_dim_decoder=(
225+
16 if config is None else _pick(config, "d_kv", "decoder_ffn_dim")
226+
),
227+
dummy_max_token_id=31999 if config is None else config.vocab_size - 1,
228+
num_hidden_layers=(
229+
8 if config is None else _pick(config, "num_hidden_layers", "num_layers")
230+
),
231+
num_key_value_heads_encoder=(
232+
16
233+
if config is None
234+
else _pick(
235+
config,
236+
"encoder_attention_heads",
237+
"num_key_value_heads",
238+
"num_heads",
239+
)
240+
),
241+
num_key_value_heads_decoder=(
242+
16
243+
if config is None
244+
else _pick(
245+
config,
246+
"decoder_attention_heads",
247+
"num_key_value_heads",
248+
"num_heads",
249+
)
250+
),
251+
encoder_dim=512 if config is None else _pick(config, "n_positions", "d_model"),
252+
)
253+
else:
254+
kwargs = dict(
255+
batch_size=2,
256+
sequence_length=30,
257+
sequence_length2=3,
258+
dummy_max_token_id=config.encoder.vocab_size - 1,
259+
num_key_value_heads_encoder=config.encoder.num_key_value_heads,
260+
num_key_value_heads_decoder=config.decoder.num_key_value_heads,
261+
num_hidden_layers=len(config.encoder.layer_types),
262+
head_dim_encoder=config.encoder.head_dim,
263+
head_dim_decoder=config.decoder.head_dim,
264+
encoder_dim=256,
265+
)
266+
231267
return kwargs, get_inputs

onnx_diagnostic/torch_models/hghub/hub_data.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,7 @@
140140
SwinModel,image-feature-extraction
141141
Swinv2Model,image-feature-extraction
142142
T5ForConditionalGeneration,text2text-generation
143+
T5GemmaForConditionalGeneration,text2text-generation
143144
TableTransformerModel,image-feature-extraction
144145
TableTransformerForObjectDetection,object-detection
145146
UNet2DConditionModel,text-to-image

0 commit comments

Comments
 (0)