@@ -144,17 +144,18 @@ def make_fake_with_dynamic_dimensions(self, x: Any, dynamic_shapes: Any) -> Any:
144144 """
145145 See
146146 :func:`onnx_diagnostic.export.shape_helper.make_fake_with_dynamic_dimensions`.
147+ If caches are used, it requires ``transformers>=4.57``.
147148 """
148149 if x is None :
149150 return None , None
150- if isinstance ( x , (list , tuple ) ):
151+ if type ( x ) in (list , tuple ):
151152 return x .__class__ (
152153 [
153154 self .make_fake_with_dynamic_dimensions (i , dynamic_shapes = ds )
154155 for i , ds in zip (x , dynamic_shapes )
155156 ]
156157 )
157- if isinstance ( x , dict ) :
158+ if type ( x ) is dict :
158159 return {
159160 k : self .make_fake_with_dynamic_dimensions (v , dynamic_shapes = dynamic_shapes [k ])
160161 for k , v in x .items ()
@@ -187,6 +188,17 @@ def make_fake_with_dynamic_dimensions(self, x: Any, dynamic_shapes: Any) -> Any:
187188 x .cross_attention_cache , dynamic_shapes = dynamic_shapes [1 ]
188189 )
189190 return x
191+ if x .__class__ .__name__ == "BaseModelOutput" :
192+ assert (
193+ list (x .keys ()) == ["last_hidden_state" ] and x .last_hidden_state is not None
194+ ), (
195+ f"Field 'last_hidden_state' is empty for { type (x )} or other fields "
196+ f"{ list (x .keys ())} are used."
197+ )
198+ x .last_hidden_state = self .make_fake_with_dynamic_dimensions (
199+ x .last_hidden_state , dynamic_shapes = dynamic_shapes [0 ]
200+ )
201+ return x
190202 if hasattr (x , "shape" ):
191203 assert dynamic_shapes is None or isinstance (dynamic_shapes , dict ), (
192204 f"dynamic_shapes must be a dictionary at this stage but "
0 commit comments