Skip to content

Commit 1560c57

Browse files
committed
remove one patch
1 parent f2805e3 commit 1560c57

2 files changed

Lines changed: 14 additions & 12 deletions

File tree

_unittests/ut_torch_export_patches/test_patch_transformers.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -984,12 +984,12 @@ def test_prepare_inputs_for_generation_decoder_llm(self):
984984
self.assertTrue(model_inputs["use_cache"] is True)
985985
self.assertTrue(model_inputs["foo"] == "bar")
986986

987+
init_input_ids = input_ids[:, :2]
988+
dynamic_cache = transformers.cache_utils.DynamicCache(config=config)
989+
dynamic_cache = model(
990+
init_input_ids, past_key_values=dynamic_cache
991+
).past_key_values
987992
with self.subTest(case="case5"):
988-
init_input_ids = input_ids[:, :2]
989-
dynamic_cache = transformers.cache_utils.DynamicCache(config=config)
990-
dynamic_cache = model(
991-
init_input_ids, past_key_values=dynamic_cache
992-
).past_key_values
993993
with self.assertRaises((AttributeError, TypeError)):
994994
model_inputs = model.prepare_inputs_for_generation(
995995
input_ids, past_key_values=dynamic_cache

onnx_diagnostic/torch_export_patches/patches/_patch_transformers_generation_mixin.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import inspect
2-
import os
3-
from typing import Optional, Tuple, Union
2+
from typing import Optional, Tuple
43
import packaging.version as pv
54
import torch
65
import transformers
@@ -22,11 +21,11 @@ class patched_GenerationMixin:
2221
if pv.Version(transformers.__version__) >= pv.Version("4.56")
2322
else "prepare_inputs_for_generation"
2423
),
25-
(
26-
"_sample"
27-
if pv.Version(transformers.__version__) == pv.Version("4.57.0.dev0")
28-
else None
29-
),
24+
# (
25+
# "_sample"
26+
# if pv.Version(transformers.__version__) == pv.Version("4.57.0.dev0")
27+
# else None
28+
# ),
3029
]
3130
_PATCHED_CLASS_ = transformers.generation.utils.GenerationMixin
3231

@@ -299,6 +298,8 @@ def prepare_inputs_for_generation(
299298
model_inputs.pop("labels", None)
300299
return model_inputs
301300

301+
'''
302+
# drops a patch since it is for a very specific version.
302303
def _sample(
303304
self,
304305
input_ids: torch.LongTensor,
@@ -484,3 +485,4 @@ def _sample(
484485
)
485486
else:
486487
return input_ids
488+
'''

0 commit comments

Comments
 (0)