Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 40 additions & 1 deletion _unittests/ut_torch_export_patches/test_patch_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@
from onnx_diagnostic.torch_models.hghub.hub_api import get_cached_configuration
from onnx_diagnostic.torch_export_patches import torch_export_patches
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
from onnx_diagnostic.torch_export_patches.patches.patch_transformers import patch_qwen2_5
from onnx_diagnostic.torch_export_patches.patches.patch_transformers import (
patch_qwen2_5,
patch_funnel,
)
from onnx_diagnostic.export.api import to_onnx


Expand Down Expand Up @@ -787,6 +790,42 @@ def test_plug_multi_head_attention_qwen25_loopa24_float32(self):
self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=1e-5)
self.assertLess(results.diffs[0]["abs"], 1e-5)

@unittest.skipIf(not patch_funnel, "Funnel not part of this transformers")
def test_model_funnel(self):
from onnx_diagnostic.torch_export_patches.patches.patch_transformers import (
patched_FunnelAttentionStructure,
patched_FunnelRelMultiheadAttention,
)

pos = torch.tensor([0, 4, 5, 8], dtype=torch.long)
stride = 2
config = transformers.models.funnel.modeling_funnel.FunnelConfig()
original = transformers.models.funnel.modeling_funnel.FunnelAttentionStructure(config)
patched = patched_FunnelAttentionStructure()
self.assertEqualArray(
original.relative_pos(pos, stride=stride), patched.relative_pos(pos, stride=stride)
)

rmha = transformers.models.funnel.modeling_funnel.FunnelRelMultiheadAttention(
config, 2
)
patched = patched_FunnelRelMultiheadAttention()
patched.config = config
for att in ["block_index", "r_r_bias", "scale", "r_kernel"]:
setattr(patched, att, getattr(rmha, att))
inputs = dict(
position_embeds=[
[torch.rand((24, 768)), None],
[torch.rand((12, 768)), torch.rand((24, 768))],
[torch.rand((6, 768)), torch.rand((12, 768))],
],
q_head=torch.rand((2, 12, 12, 64)),
context_len=12,
)
expected = rmha.relative_positional_attention(**inputs)
got = patched.relative_positional_attention(**inputs)
self.assertEqualArray(expected, got)


if __name__ == "__main__":
unittest.main(verbosity=2)
7 changes: 6 additions & 1 deletion onnx_diagnostic/ci_models/export_qwen25_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
import os
import sys
import time
import warnings
from typing import Any, Dict, List, Tuple
from .ci_helpers import (
check_for_discrepancies_and_log_everything_into_a_json_file,
Expand Down Expand Up @@ -301,7 +302,11 @@ def main(
print(f"-- config._attn_implementation={model.config._attn_implementation}")
print(f"-- model.dtype={model.dtype}")
print(f"-- model.device={model.device}")
processor = AutoProcessor.from_pretrained(model_id, use_fast=True)
try:
processor = AutoProcessor.from_pretrained(model_id, use_fast=True)
except OSError as e:
warnings.warn(f"Unable to access internet due to {e!r}", ResourceWarning, stacklevel=0)
return
print(f"-- processor={type(processor)}")

export_inputs, other_inputs = None, None
Expand Down
20 changes: 15 additions & 5 deletions onnx_diagnostic/tasks/image_text_to_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@
__TASK__ = "image-text-to-text"


def should_have_vision_config(config):
return config.architectures != ["FuyuForCausalLM"]


def reduce_model_config(config: Any) -> Dict[str, Any]:
"""Reduces a model size."""
kwargs: Dict[str, Any] = {}
Expand Down Expand Up @@ -477,7 +481,8 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
"hidden_size",
"pad_token_id",
)
check_hasattr(config, "vision_config", ("image_token_index", "image_token_id"))
if should_have_vision_config(config):
check_hasattr(config, "vision_config", ("image_token_index", "image_token_id"))
text_config = True
else:
check_hasattr(
Expand All @@ -491,7 +496,8 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
"vision_config",
)
text_config = False
check_hasattr(config.vision_config, ("num_channels", "in_chans", "in_channels"))
if should_have_vision_config(config):
check_hasattr(config.vision_config, ("num_channels", "in_chans", "in_channels"))
kwargs = dict(
head_dim=(
16
Expand Down Expand Up @@ -552,17 +558,21 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
),
width=(
224
if config is None or not hasattr(config.vision_config, "image_size")
if config is None
or not should_have_vision_config(config)
or not hasattr(config.vision_config, "image_size")
else config.vision_config.image_size
),
height=(
224
if config is None or not hasattr(config.vision_config, "image_size")
if config is None
or not should_have_vision_config(config)
or not hasattr(config.vision_config, "image_size")
else config.vision_config.image_size
),
num_channels=(
3
if config is None
if config is None or not should_have_vision_config(config)
else _pick(config.vision_config, "num_channels", "in_chans", "in_channels")
),
pad_token_id=(
Expand Down
132 changes: 84 additions & 48 deletions onnx_diagnostic/tasks/text2text_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,22 @@ def reduce_model_config(config: Any) -> Dict[str, Any]:
config.num_decoder_layers = min(config.num_decoder_layers, 2)
if hasattr(config, "num_hidden_layers"):
config.num_hidden_layers = min(config.num_hidden_layers, nhl())
if hasattr(config, "encoder") and hasattr(config.encoder, "layer_types"):
default_layer_types = [
"sliding_attention",
"full_attention",
"sliding_attention",
"full_attention",
]
config.encoder.num_hidden_layers = 4
config.encoder.layer_types = (
default_layer_types if config is None else config.encoder.layer_types[:4]
)
config.decoder.num_hidden_layers = 4
config.decoder.layer_types = (
default_layer_types if config is None else config.decoder.layer_types[:4]
)

update_config(config, kwargs)
return kwargs

Expand Down Expand Up @@ -177,55 +193,75 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:

If the configuration is None, the function selects typical dimensions.
"""
path = 1
if config is not None:
check_hasattr(
config,
"vocab_size",
"hidden_size",
"num_attention_heads",
("num_hidden_layers", "num_layers"),
("n_positions", "d_model"),
(
"num_key_value_heads",
"num_heads",
("decoder_attention_heads", "encoder_attention_heads"),
),
)
# exceptions = {
# "PLBartForConditionalGeneration": (
# lambda c: c.encoder_attention_heads + c.decoder_attention_heads
# )
# }
kwargs = dict(
batch_size=2,
sequence_length=30,
sequence_length2=3,
head_dim_encoder=16 if config is None else _pick(config, "d_kv", "encoder_ffn_dim"),
head_dim_decoder=16 if config is None else _pick(config, "d_kv", "decoder_ffn_dim"),
dummy_max_token_id=31999 if config is None else config.vocab_size - 1,
num_hidden_layers=(
8 if config is None else _pick(config, "num_hidden_layers", "num_layers")
),
num_key_value_heads_encoder=(
16
if config is None
else _pick(
if hasattr(config, "num_attention_heads"):
check_hasattr(
config,
"encoder_attention_heads",
"num_key_value_heads",
"num_heads",
"vocab_size",
"hidden_size",
"num_attention_heads",
("num_hidden_layers", "num_layers"),
("n_positions", "d_model"),
(
"num_key_value_heads",
"num_heads",
("decoder_attention_heads", "encoder_attention_heads"),
),
)
),
num_key_value_heads_decoder=(
16
if config is None
else _pick(
config,
"decoder_attention_heads",
"num_key_value_heads",
"num_heads",
)
),
encoder_dim=512 if config is None else _pick(config, "n_positions", "d_model"),
)
else:
check_hasattr(config, "encoder", "decoder")
path = 2

if path == 1:
kwargs = dict(
batch_size=2,
sequence_length=30,
sequence_length2=3,
head_dim_encoder=(
16 if config is None else _pick(config, "d_kv", "encoder_ffn_dim")
),
head_dim_decoder=(
16 if config is None else _pick(config, "d_kv", "decoder_ffn_dim")
),
dummy_max_token_id=31999 if config is None else config.vocab_size - 1,
num_hidden_layers=(
8 if config is None else _pick(config, "num_hidden_layers", "num_layers")
),
num_key_value_heads_encoder=(
16
if config is None
else _pick(
config,
"encoder_attention_heads",
"num_key_value_heads",
"num_heads",
)
),
num_key_value_heads_decoder=(
16
if config is None
else _pick(
config,
"decoder_attention_heads",
"num_key_value_heads",
"num_heads",
)
),
encoder_dim=512 if config is None else _pick(config, "n_positions", "d_model"),
)
else:
kwargs = dict(
batch_size=2,
sequence_length=30,
sequence_length2=3,
dummy_max_token_id=config.encoder.vocab_size - 1,
num_key_value_heads_encoder=config.encoder.num_key_value_heads,
num_key_value_heads_decoder=config.decoder.num_key_value_heads,
num_hidden_layers=len(config.encoder.layer_types),
head_dim_encoder=config.encoder.head_dim,
head_dim_decoder=config.decoder.head_dim,
encoder_dim=256,
)

return kwargs, get_inputs
3 changes: 3 additions & 0 deletions onnx_diagnostic/tasks/text_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ def reduce_model_config(config: Any) -> Dict[str, Any]:
state_size=8 if config is None else getattr(config, "state_size", None),
conv_kernel=4 if config is None else getattr(config, "conv_kernel", None),
)
elif config.__class__.__name__ == "FunnelConfig":
# does not support num_hidden_layers
kwargs = dict()
else:
kwargs = dict(
head_dim=getattr(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import torch

try:
import transformers.models.funnel.modeling_funnel

patch_funnel = True
except ImportError:
patch_funnel = False

if patch_funnel:
from transformers.models.funnel.modeling_funnel import _relative_shift_gather

class patched_FunnelAttentionStructure(torch.nn.Module):
_PATCHES_ = ["relative_pos"]
_PATCHED_CLASS_ = transformers.models.funnel.modeling_funnel.FunnelAttentionStructure

def relative_pos(
self, pos: torch.Tensor, stride: int, pooled_pos=None, shift: int = 1
) -> torch.Tensor:
if pooled_pos is None:
pooled_pos = pos
ref_point = pooled_pos[0] - pos[0]
# PATCHED
num_remove = shift * pooled_pos.shape[0]
max_dist = ref_point + num_remove * stride
min_dist = pooled_pos[0] - pos[-1]
return torch.arange(
max_dist.to(torch.long),
(min_dist - 1).to(torch.long),
torch.tensor(-stride, dtype=torch.long),
dtype=torch.long,
device=pos.device,
)

class patched_FunnelRelMultiheadAttention(torch.nn.Module):
_PATCHES_ = ["relative_positional_attention"]
_PATCHED_CLASS_ = (
transformers.models.funnel.modeling_funnel.FunnelRelMultiheadAttention
)

def relative_positional_attention(
self, position_embeds, q_head, context_len, cls_mask=None
):
"""Relative attention score for the positional encodings"""
# q_head has shape batch_size x sea_len x n_head x d_head
if self.config.attention_type == "factorized":
phi, pi, psi, omega = position_embeds
# Shape n_head x d_head
u = self.r_r_bias * self.scale
# Shape d_model x n_head x d_head
w_r = self.r_kernel

# Shape batch_size x sea_len x n_head x d_model
q_r_attention = torch.einsum("binh,dnh->bind", q_head + u, w_r)
q_r_attention_1 = q_r_attention * phi[:, None]
q_r_attention_2 = q_r_attention * pi[:, None]

# Shape batch_size x n_head x seq_len x context_len
positional_attn = torch.einsum(
"bind,jd->bnij", q_r_attention_1, psi
) + torch.einsum("bind,jd->bnij", q_r_attention_2, omega)
else:
shift = 2 if q_head.shape[1] != context_len else 1
r = position_embeds[self.block_index][shift - 1]
# Shape n_head x d_head
v = self.r_r_bias * self.scale
# Shape d_model x n_head x d_head
w_r = self.r_kernel

# Shape max_rel_len x n_head x d_model
r_head = torch.einsum("td,dnh->tnh", r, w_r)
# Shape batch_size x n_head x seq_len x max_rel_len
positional_attn = torch.einsum("binh,tnh->bnit", q_head + v, r_head)
# Shape batch_size x n_head x seq_len x context_len
positional_attn = _relative_shift_gather(positional_attn, context_len, shift)

if cls_mask is not None:
# PATCHED
positional_attn = positional_attn * cls_mask
return positional_attn
Loading
Loading