From f07d2db010d2c112f2be8c20c84a59c034f4d16f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Thu, 1 Jan 2026 09:54:39 +0100 Subject: [PATCH 1/6] add support for fx.Proxy --- onnx_diagnostic/helpers/helper.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/onnx_diagnostic/helpers/helper.py b/onnx_diagnostic/helpers/helper.py index 972cce4b..923c92ef 100644 --- a/onnx_diagnostic/helpers/helper.py +++ b/onnx_diagnostic/helpers/helper.py @@ -801,6 +801,9 @@ def _torch_sym_int_to_str(value: "torch.SymInt") -> Union[int, str]: # noqa: F print(f"[string_type] TT8:{type(obj)}") return repr(obj).replace(" ", "").replace("\n", " ") + if isinstance(obj, torch.fx.proxy.Proxy): + return repr(obj) + if ignore: if verbose: print(f"[string_type] CACHE4:{type(obj)}") From 31b272f4c7adaf67e84bdbbb4907e2a7a6235a5f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Thu, 1 Jan 2026 19:27:24 +0100 Subject: [PATCH 2/6] fix fake --- onnx_diagnostic/helpers/fake_tensor_helper.py | 8 ++++++-- .../patches/_patch_transformers_dynamic_cache.py | 9 +++++++-- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/onnx_diagnostic/helpers/fake_tensor_helper.py b/onnx_diagnostic/helpers/fake_tensor_helper.py index c32e1830..eed297c4 100644 --- a/onnx_diagnostic/helpers/fake_tensor_helper.py +++ b/onnx_diagnostic/helpers/fake_tensor_helper.py @@ -105,6 +105,8 @@ def fake_reshape( reduced_tensor = self.from_tensor(true_tensor, static_shapes=True).sum( axis=tuple(sorted(sh)), keepdim=True ) + if len(reduced_tensor.shape) == 0 == len(new_shape): + return reduced_tensor return reduced_tensor.expand(*new_shape) def make_fake(self, x: Any) -> Optional["FakeTensor"]: # noqa: F821 @@ -157,7 +159,9 @@ def make_fake_with_dynamic_dimensions(self, x: Any, dynamic_shapes: Any) -> Any: ) if type(x) is dict: return { - k: self.make_fake_with_dynamic_dimensions(v, dynamic_shapes=dynamic_shapes[k]) + k: self.make_fake_with_dynamic_dimensions( + v, dynamic_shapes=dynamic_shapes[k] if dynamic_shapes else None + ) for k, v in x.items() } if x.__class__.__name__ in {"DynamicCache", "StaticCache", "HybridCache"}: @@ -231,7 +235,7 @@ def make_fake_with_dynamic_dimensions(self, x: Any, dynamic_shapes: Any) -> Any: x = torch.empty(tuple(new_shape), dtype=x.dtype, device=x.device) - t = self.fake_reshape(x, dynamic_shapes) # type: ignore[arg-type] + t = self.fake_reshape(x, dynamic_shapes) if dynamic_shapes else x # type: ignore[arg-type] assert t.device == x.device, f"device mismatch {x.device} -> {t.device}" assert t.dtype == x.dtype, f"dtype mismatch {x.dtype} -> {t.dtype}" return t diff --git a/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_dynamic_cache.py b/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_dynamic_cache.py index 79fe9ab1..2a5dba2b 100644 --- a/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_dynamic_cache.py +++ b/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_dynamic_cache.py @@ -27,8 +27,13 @@ def lazy_initialization(self, key_states: torch.Tensor): new_shape = list(key_states.shape) new_shape[-2] = 0 # PATCHED: used a tensor with an empty shape and not en empty list to initialize - self.keys = torch.empty(new_shape, dtype=self.dtype, device=self.device) - self.values = torch.empty(new_shape, dtype=self.dtype, device=self.device) + if isinstance(key_states, torch._subclasses.fake_tensor.FakeTensor): + with key_states.fake_mode: + self.keys = torch.empty(new_shape, dtype=self.dtype, device=self.device) + self.values = torch.empty(new_shape, dtype=self.dtype, device=self.device) + else: + self.keys = torch.empty(new_shape, dtype=self.dtype, device=self.device) + self.values = torch.empty(new_shape, dtype=self.dtype, device=self.device) if patch_is_initialized: self.is_initialized = True From b94c7319750b883b1050aeac8638b15deedf9a28 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Fri, 2 Jan 2026 13:12:29 +0100 Subject: [PATCH 3/6] fix --- .../patches/_patch_transformers_dynamic_cache.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_dynamic_cache.py b/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_dynamic_cache.py index 2a5dba2b..7fcf2718 100644 --- a/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_dynamic_cache.py +++ b/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_dynamic_cache.py @@ -24,6 +24,13 @@ class patched_DynamicLayer: def lazy_initialization(self, key_states: torch.Tensor): self.dtype, self.device = key_states.dtype, key_states.device + assert ( + hasattr(key_states, "shape") and key_states is not None + ), f"Attribute 'shape' is wrong for type {type(key_states)}" + assert isinstance(key_states.shape, tuple), ( + f"Unxpected type {type(key_states.shape)} for key_states.shape, " + f"__dict__={key_states.shape.__dict__}" + ) new_shape = list(key_states.shape) new_shape[-2] = 0 # PATCHED: used a tensor with an empty shape and not en empty list to initialize From 36aa79e447393d08832f7232f5376b7da24500ea Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Fri, 2 Jan 2026 13:20:39 +0100 Subject: [PATCH 4/6] rewrite patch --- onnx_diagnostic/helpers/cache_helper.py | 8 -------- .../patches/_patch_transformers_dynamic_cache.py | 15 +++++---------- 2 files changed, 5 insertions(+), 18 deletions(-) diff --git a/onnx_diagnostic/helpers/cache_helper.py b/onnx_diagnostic/helpers/cache_helper.py index 3ff36d9b..ff0977ab 100644 --- a/onnx_diagnostic/helpers/cache_helper.py +++ b/onnx_diagnostic/helpers/cache_helper.py @@ -28,14 +28,6 @@ def __init__(self, cache=None): ] self.key_cache = [layer.keys for layer in layers] self.value_cache = [layer.values for layer in layers] - if None in self.key_cache or None in self.value_cache: - from .helper import string_type - - raise AssertionError( - f"issue with key_cache={string_type(self.key_cache)}, " - f"or value_cache={string_type(self.value_cache)}, " - f"cache.layers={string_type(cache.layers)}" - ) elif cache is not None and hasattr(cache, "key_cache"): self.key_cache = cache.key_cache self.value_cache = cache.value_cache diff --git a/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_dynamic_cache.py b/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_dynamic_cache.py index 7fcf2718..168b6cd8 100644 --- a/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_dynamic_cache.py +++ b/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_dynamic_cache.py @@ -27,20 +27,15 @@ def lazy_initialization(self, key_states: torch.Tensor): assert ( hasattr(key_states, "shape") and key_states is not None ), f"Attribute 'shape' is wrong for type {type(key_states)}" - assert isinstance(key_states.shape, tuple), ( - f"Unxpected type {type(key_states.shape)} for key_states.shape, " - f"__dict__={key_states.shape.__dict__}" - ) - new_shape = list(key_states.shape) - new_shape[-2] = 0 + like = key_states[:, :0] # PATCHED: used a tensor with an empty shape and not en empty list to initialize if isinstance(key_states, torch._subclasses.fake_tensor.FakeTensor): with key_states.fake_mode: - self.keys = torch.empty(new_shape, dtype=self.dtype, device=self.device) - self.values = torch.empty(new_shape, dtype=self.dtype, device=self.device) + self.keys = torch.empty_like(like, dtype=self.dtype, device=self.device) + self.values = torch.empty_like(like, dtype=self.dtype, device=self.device) else: - self.keys = torch.empty(new_shape, dtype=self.dtype, device=self.device) - self.values = torch.empty(new_shape, dtype=self.dtype, device=self.device) + self.keys = torch.empty_like(like, dtype=self.dtype, device=self.device) + self.values = torch.empty_like(like, dtype=self.dtype, device=self.device) if patch_is_initialized: self.is_initialized = True From 7bd77506abc357af4b40f418384ae91a4db7c8a3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Fri, 2 Jan 2026 14:53:50 +0100 Subject: [PATCH 5/6] fix --- .../patches/_patch_transformers_dynamic_cache.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_dynamic_cache.py b/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_dynamic_cache.py index 168b6cd8..af742380 100644 --- a/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_dynamic_cache.py +++ b/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_dynamic_cache.py @@ -27,7 +27,7 @@ def lazy_initialization(self, key_states: torch.Tensor): assert ( hasattr(key_states, "shape") and key_states is not None ), f"Attribute 'shape' is wrong for type {type(key_states)}" - like = key_states[:, :0] + like = key_states[:, :, :0] # PATCHED: used a tensor with an empty shape and not en empty list to initialize if isinstance(key_states, torch._subclasses.fake_tensor.FakeTensor): with key_states.fake_mode: From 63cc707df582eebe87ceff965ac660273ae1e4f8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Fri, 2 Jan 2026 15:54:53 +0100 Subject: [PATCH 6/6] fix --- .../patches/_patch_transformers_dynamic_cache.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_dynamic_cache.py b/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_dynamic_cache.py index af742380..d29f5616 100644 --- a/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_dynamic_cache.py +++ b/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_dynamic_cache.py @@ -27,7 +27,7 @@ def lazy_initialization(self, key_states: torch.Tensor): assert ( hasattr(key_states, "shape") and key_states is not None ), f"Attribute 'shape' is wrong for type {type(key_states)}" - like = key_states[:, :, :0] + like = torch.narrow(key_states, dim=-2, start=0, length=0) # PATCHED: used a tensor with an empty shape and not en empty list to initialize if isinstance(key_states, torch._subclasses.fake_tensor.FakeTensor): with key_states.fake_mode: