Skip to content

Commit d7ee9b0

Browse files
Fix the export issue
Signed-off-by: Onur Yilmaz <oyilmaz@nvidia.com>
1 parent c7549b8 commit d7ee9b0

6 files changed

Lines changed: 203 additions & 3 deletions

File tree

nemo_export/model_adapters/embedding/embedding_adapter.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
import torch.nn.functional as F
2121
from transformers import AutoModel, AutoTokenizer
2222

23+
from nemo_export.model_adapters.masking import patch_bidirectional_mask_for_export
24+
2325

2426
class LlamaBidirectionalHFAdapter(torch.nn.Module):
2527
"""
@@ -266,5 +268,9 @@ def get_llama_bidirectional_hf_model(
266268
if attn_implementation:
267269
model.config._attn_implementation = attn_implementation
268270

271+
if attn_implementation:
272+
# Replace the transformers>=5.0 bidirectional mask builder, which is not ONNX-traceable.
273+
patch_bidirectional_mask_for_export(model)
274+
269275
adapted_model = LlamaBidirectionalHFAdapter(model=model, normalize=normalize, pooling_module=pooling_module)
270276
return adapted_model, tokenizer
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
import types
17+
18+
import torch
19+
20+
21+
def patch_bidirectional_mask_for_export(model: torch.nn.Module) -> bool:
    """Swap in an export-friendly ``_create_bidirectional_mask`` on every module that defines one.

    Background: with transformers>=5.0 the stock ``create_bidirectional_mask`` helper cannot be
    traced by the TorchScript ONNX exporter. During tracing it routes through ``sdpa_mask`` (this
    happens even with ``attn_implementation="eager"`` because ``eager_mask`` reuses ``sdpa_mask``)
    and fails with ``IndexError: tuple index out of range`` while converting the deprecated
    ``cache_position`` argument. ONNX export runs with eager attention, so the additive 4D mask
    is assembled here directly; for fully-bidirectional attention this is numerically equivalent
    and it traces without issue.

    The whole module tree is walked, so the override is installed no matter whether the method
    sits on the top-level model (embedding case) or on a nested backbone (reranker case:
    ``LlamaBidirectionalForSequenceClassification.model``). Modules are modified in place by
    binding the replacement on each matching instance; the classes themselves are left untouched.

    Args:
        model: The loaded HuggingFace model (or a wrapper around it) to patch in place.

    Returns:
        bool: True when one or more modules received the override, False when none matched.
    """

    def _create_bidirectional_mask(self, input_embeds, attention_mask):
        # Without a padding mask there is nothing to hide, so no additive mask is produced.
        if attention_mask is not None:
            target_dtype = input_embeds.dtype
            # Insert two singleton axes: result is (batch, 1, 1, seq_len).
            keep = attention_mask[:, None, None, :].to(target_dtype)
            # Kept positions (1) contribute zero; padded positions (0) get the dtype minimum.
            return (1.0 - keep) * torch.finfo(target_dtype).min
        return None

    method_name = "_create_bidirectional_mask"
    # Look the attribute up on the class so only modules whose type defines the method match.
    targets = [candidate for candidate in model.modules() if hasattr(type(candidate), method_name)]
    for target in targets:
        setattr(target, method_name, types.MethodType(_create_bidirectional_mask, target))
    return len(targets) > 0

nemo_export/model_adapters/reranker/reranker_adapter.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
import torch
1919
from transformers import AutoModelForSequenceClassification, AutoTokenizer
2020

21+
from nemo_export.model_adapters.masking import patch_bidirectional_mask_for_export
22+
2123

2224
class SequenceClassificationModelAdapterWithoutTypeIds(torch.nn.Module):
2325
"""Adapter for sequence classification models that don't use token type IDs.
@@ -137,6 +139,8 @@ def get_llama_reranker_hf_model(
137139
# reset config to handle case where config is mutated after init
138140
# TODO: remove when we're no longer using Llama 3.1 model with `_attn_implementation` set in __init__ method.
139141
model.config._attn_implementation = attn_implementation
142+
# Replace the transformers>=5.0 bidirectional mask builder, which is not ONNX-traceable.
143+
patch_bidirectional_mask_for_export(model)
140144

141145
tokenizer = AutoTokenizer.from_pretrained(
142146
model_name_or_path,

tests/unit_tests/export/model_adapters/embedding/test_embedding_adapter.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -341,9 +341,10 @@ def test_get_model_with_trust_remote_code(self, mock_auto_model, mock_auto_token
341341
mock_auto_tokenizer.from_pretrained.assert_called_once_with("test/model", trust_remote_code=True)
342342
mock_auto_model.from_pretrained.assert_called_once_with("test/model", torch_dtype=None, trust_remote_code=True)
343343

344+
@patch("nemo_export.model_adapters.embedding.embedding_adapter.patch_bidirectional_mask_for_export")
344345
@patch("nemo_export.model_adapters.embedding.embedding_adapter.AutoTokenizer")
345346
@patch("nemo_export.model_adapters.embedding.embedding_adapter.AutoModel")
346-
def test_get_model_with_attn_implementation(self, mock_auto_model, mock_auto_tokenizer):
347+
def test_get_model_with_attn_implementation(self, mock_auto_model, mock_auto_tokenizer, mock_patch_mask):
347348
"""Test model loading with a specific attention implementation."""
348349
mock_tokenizer = Mock()
349350
mock_tokenizer.padding_side = "right"
@@ -365,9 +366,31 @@ def test_get_model_with_attn_implementation(self, mock_auto_model, mock_auto_tok
365366
"test/model", torch_dtype=None, trust_remote_code=False, attn_implementation="eager"
366367
)
367368
assert mock_config._attn_implementation == "eager"
369+
# The bidirectional mask builder is patched for ONNX export compatibility.
370+
mock_patch_mask.assert_called_once_with(mock_model)
368371
assert isinstance(adapted_model, LlamaBidirectionalHFAdapter)
369372
assert tokenizer == mock_tokenizer
370373

374+
@patch("nemo_export.model_adapters.embedding.embedding_adapter.patch_bidirectional_mask_for_export")
@patch("nemo_export.model_adapters.embedding.embedding_adapter.AutoTokenizer")
@patch("nemo_export.model_adapters.embedding.embedding_adapter.AutoModel")
def test_get_model_without_attn_implementation_skips_mask_patch(
    self, mock_auto_model, mock_auto_tokenizer, mock_patch_mask
):
    """Loading without an attention implementation must leave the mask builder unpatched."""
    # Fake model whose eval() hands back the same object.
    fake_model = Mock(config=Mock())
    fake_model.eval.return_value = fake_model
    mock_auto_model.from_pretrained.return_value = fake_model

    # Fake tokenizer configured with right-side padding.
    fake_tokenizer = Mock(padding_side="right")
    mock_auto_tokenizer.from_pretrained.return_value = fake_tokenizer

    # No attn_implementation argument is supplied here.
    get_llama_bidirectional_hf_model(model_name_or_path="test/model", normalize=True)

    mock_patch_mask.assert_not_called()
393+
371394
@patch("nemo_export.model_adapters.embedding.embedding_adapter.AutoTokenizer")
372395
@patch("nemo_export.model_adapters.embedding.embedding_adapter.AutoModel")
373396
def test_get_model_pooling_mode_adjustment_last(self, mock_auto_model, mock_auto_tokenizer):

tests/unit_tests/export/model_adapters/reranker/test_reranker_adapter.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -269,9 +269,10 @@ def test_get_model_with_trust_remote_code(self, mock_auto_model, mock_auto_token
269269
# Verify tokenizer loading with trust_remote_code
270270
mock_auto_tokenizer.from_pretrained.assert_called_once_with("test-model", trust_remote_code=True)
271271

272+
@patch("nemo_export.model_adapters.reranker.reranker_adapter.patch_bidirectional_mask_for_export")
272273
@patch("nemo_export.model_adapters.reranker.reranker_adapter.AutoTokenizer")
273274
@patch("nemo_export.model_adapters.reranker.reranker_adapter.AutoModelForSequenceClassification")
274-
def test_get_model_with_attn_implementation(self, mock_auto_model, mock_auto_tokenizer):
275+
def test_get_model_with_attn_implementation(self, mock_auto_model, mock_auto_tokenizer, mock_patch_mask):
275276
"""Test loading a model with specific attention implementation."""
276277
# Setup mocks
277278
mock_model = Mock()
@@ -295,6 +296,28 @@ def test_get_model_with_attn_implementation(self, mock_auto_model, mock_auto_tok
295296

296297
# Verify config is reset after init
297298
assert mock_config._attn_implementation == attn_impl
299+
# The bidirectional mask builder is patched for ONNX export compatibility.
300+
mock_patch_mask.assert_called_once_with(mock_model)
301+
302+
@patch("nemo_export.model_adapters.reranker.reranker_adapter.patch_bidirectional_mask_for_export")
@patch("nemo_export.model_adapters.reranker.reranker_adapter.AutoTokenizer")
@patch("nemo_export.model_adapters.reranker.reranker_adapter.AutoModelForSequenceClassification")
def test_get_model_without_attn_implementation_skips_mask_patch(
    self, mock_auto_model, mock_auto_tokenizer, mock_patch_mask
):
    """Loading without an attention implementation must leave the mask builder unpatched."""
    # Fake tokenizer advertising the two standard model inputs.
    fake_tokenizer = Mock(model_input_names=["input_ids", "attention_mask"])
    mock_auto_tokenizer.from_pretrained.return_value = fake_tokenizer

    # Fake model whose eval() hands back the same object.
    fake_model = Mock(config=Mock())
    fake_model.eval.return_value = fake_model
    mock_auto_model.from_pretrained.return_value = fake_model

    # Only the model name is passed; no attention implementation is requested.
    get_llama_reranker_hf_model("test-model")

    mock_patch_mask.assert_not_called()
298321

299322
@patch("nemo_export.model_adapters.reranker.reranker_adapter.AutoTokenizer")
300323
@patch("nemo_export.model_adapters.reranker.reranker_adapter.AutoModelForSequenceClassification")
@@ -325,9 +348,10 @@ def test_get_model_with_pathlike_input(self, mock_auto_model, mock_auto_tokenize
325348
# Verify tokenizer loading
326349
mock_auto_tokenizer.from_pretrained.assert_called_once_with(model_path, trust_remote_code=False)
327350

351+
@patch("nemo_export.model_adapters.reranker.reranker_adapter.patch_bidirectional_mask_for_export")
328352
@patch("nemo_export.model_adapters.reranker.reranker_adapter.AutoTokenizer")
329353
@patch("nemo_export.model_adapters.reranker.reranker_adapter.AutoModelForSequenceClassification")
330-
def test_get_model_all_parameters(self, mock_auto_model, mock_auto_tokenizer):
354+
def test_get_model_all_parameters(self, mock_auto_model, mock_auto_tokenizer, mock_patch_mask):
331355
"""Test loading a model with all parameters specified."""
332356
# Setup mocks
333357
mock_model = Mock()
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import torch
16+
17+
from nemo_export.model_adapters.masking import patch_bidirectional_mask_for_export
18+
19+
20+
class _BidirectionalModule(torch.nn.Module):
21+
"""Minimal stand-in for a LlamaBidirectionalModel exposing the patched method."""
22+
23+
def _create_bidirectional_mask(self, input_embeds, attention_mask):
24+
# Original (sentinel) implementation that the patch must replace.
25+
return "original"
26+
27+
28+
class _Wrapper(torch.nn.Module):
    """Fake of the reranker layout: the patch target is held as a nested ``model`` submodule."""

    def __init__(self):
        super().__init__()
        # Registering the backbone as an attribute makes it a child module of the wrapper.
        backbone = _BidirectionalModule()
        self.model = backbone
34+
35+
36+
class TestPatchBidirectionalMaskForExport:
    """Unit tests covering patch_bidirectional_mask_for_export."""

    def test_patches_top_level_module(self):
        module = _BidirectionalModule()
        was_patched = patch_bidirectional_mask_for_export(module)
        assert was_patched is True

        # With no attention mask the replacement yields None rather than the sentinel string.
        embeds = torch.zeros(1, 3, 4)
        assert module._create_bidirectional_mask(embeds, None) is None

    def test_patches_nested_module(self):
        outer = _Wrapper()
        was_patched = patch_bidirectional_mask_for_export(outer)
        assert was_patched is True

        # The override must have reached the nested backbone.
        embeds = torch.zeros(1, 3, 4)
        assert outer.model._create_bidirectional_mask(embeds, None) is None

    def test_returns_false_when_method_absent(self):
        # A plain Linear layer defines no _create_bidirectional_mask, so nothing is patched.
        plain = torch.nn.Linear(4, 4)
        assert patch_bidirectional_mask_for_export(plain) is False

    def test_mask_none_returns_none(self):
        module = _BidirectionalModule()
        patch_bidirectional_mask_for_export(module)
        result = module._create_bidirectional_mask(torch.zeros(1, 2, 4), None)
        assert result is None

    def test_additive_mask_values(self):
        module = _BidirectionalModule()
        patch_bidirectional_mask_for_export(module)

        dtype = torch.float32
        padding_mask = torch.tensor([[1, 1, 0]])
        embeds = torch.zeros(1, 3, 4, dtype=dtype)

        additive = module._create_bidirectional_mask(embeds, padding_mask)

        assert additive.shape == (1, 1, 1, 3)
        assert additive.dtype == dtype
        # Real tokens stay at 0.0 (unmasked); the padded slot receives the dtype minimum.
        assert additive[0, 0, 0].tolist() == [0.0, 0.0, torch.finfo(dtype).min]

    def test_additive_mask_matches_input_dtype(self):
        module = _BidirectionalModule()
        patch_bidirectional_mask_for_export(module)

        half = torch.float16
        padding_mask = torch.tensor([[1, 0]])
        embeds = torch.zeros(1, 2, 4, dtype=half)

        additive = module._create_bidirectional_mask(embeds, padding_mask)
        # The mask takes its dtype from the embeddings, not from the integer attention mask.
        assert additive.dtype == half
        assert additive[0, 0, 0, 1].item() == torch.finfo(half).min

0 commit comments

Comments
 (0)