File tree Expand file tree Collapse file tree
_unittests/ut_torch_export_patches
onnx_diagnostic/torch_export_patches/patches Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -984,12 +984,12 @@ def test_prepare_inputs_for_generation_decoder_llm(self):
984984 self .assertTrue (model_inputs ["use_cache" ] is True )
985985 self .assertTrue (model_inputs ["foo" ] == "bar" )
986986
987+ init_input_ids = input_ids [:, :2 ]
988+ dynamic_cache = transformers .cache_utils .DynamicCache (config = config )
989+ dynamic_cache = model (
990+ init_input_ids , past_key_values = dynamic_cache
991+ ).past_key_values
987992 with self .subTest (case = "case5" ):
988- init_input_ids = input_ids [:, :2 ]
989- dynamic_cache = transformers .cache_utils .DynamicCache (config = config )
990- dynamic_cache = model (
991- init_input_ids , past_key_values = dynamic_cache
992- ).past_key_values
993993 with self .assertRaises ((AttributeError , TypeError )):
994994 model_inputs = model .prepare_inputs_for_generation (
995995 input_ids , past_key_values = dynamic_cache
Original file line number Diff line number Diff line change 11import inspect
2- import os
3- from typing import Optional , Tuple , Union
2+ from typing import Optional , Tuple
43import packaging .version as pv
54import torch
65import transformers
@@ -22,11 +21,11 @@ class patched_GenerationMixin:
2221 if pv .Version (transformers .__version__ ) >= pv .Version ("4.56" )
2322 else "prepare_inputs_for_generation"
2423 ),
25- (
26- "_sample"
27- if pv .Version (transformers .__version__ ) == pv .Version ("4.57.0.dev0" )
28- else None
29- ),
24+ # (
25+ # "_sample"
26+ # if pv.Version(transformers.__version__) == pv.Version("4.57.0.dev0")
27+ # else None
28+ # ),
3029 ]
3130 _PATCHED_CLASS_ = transformers .generation .utils .GenerationMixin
3231
@@ -299,6 +298,8 @@ def prepare_inputs_for_generation(
299298 model_inputs .pop ("labels" , None )
300299 return model_inputs
301300
301+ '''
302+ # drops a patch since it is for a very specific version.
302303 def _sample(
303304 self,
304305 input_ids: torch.LongTensor,
@@ -484,3 +485,4 @@ def _sample(
484485 )
485486 else:
486487 return input_ids
488+ '''
You can’t perform that action at this time.
0 commit comments