Skip to content

Commit 3d84e35

Browse files
committed
fix fake tensors
1 parent f0a6ead commit 3d84e35

3 files changed

Lines changed: 16 additions & 2 deletions

File tree

_unittests/ut_torch_models/test_validate_whole_models1.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,7 @@ def test_n_validate_phi35_mini_instruct(self):
226226
self.clean_dump()
227227

228228
@hide_stdout()
229+
@requires_transformers("4.57")
229230
def test_o_validate_model_export_fake(self):
230231
mid = "arnir0/Tiny-LLM"
231232
summary, data = validate_model(

onnx_diagnostic/export/shape_helper.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,7 @@ def make_fake_with_dynamic_dimensions(
210210
This uses function :func:`onnx_diagnostic.helpers.fake_tensor_helper.make_fake`.
211211
Parameter ``existing`` is used to reused the same object when the dynamic
212212
dimension is given the same name as another one.
213+
This function works with caches only if ``transformers>=4.57``.
213214
214215
A simple tensor:
215216

onnx_diagnostic/helpers/fake_tensor_helper.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -144,17 +144,18 @@ def make_fake_with_dynamic_dimensions(self, x: Any, dynamic_shapes: Any) -> Any:
144144
"""
145145
See
146146
:func:`onnx_diagnostic.export.shape_helper.make_fake_with_dynamic_dimensions`.
147+
If caches are used, it requires ``transformers>=4.57``.
147148
"""
148149
if x is None:
149150
return None, None
150-
if isinstance(x, (list, tuple)):
151+
if type(x) in (list, tuple):
151152
return x.__class__(
152153
[
153154
self.make_fake_with_dynamic_dimensions(i, dynamic_shapes=ds)
154155
for i, ds in zip(x, dynamic_shapes)
155156
]
156157
)
157-
if isinstance(x, dict):
158+
if type(x) is dict:
158159
return {
159160
k: self.make_fake_with_dynamic_dimensions(v, dynamic_shapes=dynamic_shapes[k])
160161
for k, v in x.items()
@@ -187,6 +188,17 @@ def make_fake_with_dynamic_dimensions(self, x: Any, dynamic_shapes: Any) -> Any:
187188
x.cross_attention_cache, dynamic_shapes=dynamic_shapes[1]
188189
)
189190
return x
191+
if x.__class__.__name__ == "BaseModelOutput":
192+
assert (
193+
list(x.keys()) == ["last_hidden_state"] and x.last_hidden_state is not None
194+
), (
195+
f"Field 'last_hidden_state' is empty for {type(x)} or other fields "
196+
f"{list(x.keys())} are used."
197+
)
198+
x.last_hidden_state = self.make_fake_with_dynamic_dimensions(
199+
x.last_hidden_state, dynamic_shapes=dynamic_shapes[0]
200+
)
201+
return x
190202
if hasattr(x, "shape"):
191203
assert dynamic_shapes is None or isinstance(dynamic_shapes, dict), (
192204
f"dynamic_shapes must be a dictionary at this stage but "

0 commit comments

Comments
 (0)