Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 0 additions & 8 deletions onnx_diagnostic/helpers/cache_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,6 @@ def __init__(self, cache=None):
]
self.key_cache = [layer.keys for layer in layers]
self.value_cache = [layer.values for layer in layers]
if None in self.key_cache or None in self.value_cache:
from .helper import string_type

raise AssertionError(
f"issue with key_cache={string_type(self.key_cache)}, "
f"or value_cache={string_type(self.value_cache)}, "
f"cache.layers={string_type(cache.layers)}"
)
elif cache is not None and hasattr(cache, "key_cache"):
self.key_cache = cache.key_cache
self.value_cache = cache.value_cache
Expand Down
8 changes: 6 additions & 2 deletions onnx_diagnostic/helpers/fake_tensor_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,8 @@ def fake_reshape(
reduced_tensor = self.from_tensor(true_tensor, static_shapes=True).sum(
axis=tuple(sorted(sh)), keepdim=True
)
if len(reduced_tensor.shape) == 0 == len(new_shape):
return reduced_tensor
return reduced_tensor.expand(*new_shape)

def make_fake(self, x: Any) -> Optional["FakeTensor"]: # noqa: F821
Expand Down Expand Up @@ -157,7 +159,9 @@ def make_fake_with_dynamic_dimensions(self, x: Any, dynamic_shapes: Any) -> Any:
)
if type(x) is dict:
return {
k: self.make_fake_with_dynamic_dimensions(v, dynamic_shapes=dynamic_shapes[k])
k: self.make_fake_with_dynamic_dimensions(
v, dynamic_shapes=dynamic_shapes[k] if dynamic_shapes else None
)
for k, v in x.items()
}
if x.__class__.__name__ in {"DynamicCache", "StaticCache", "HybridCache"}:
Expand Down Expand Up @@ -231,7 +235,7 @@ def make_fake_with_dynamic_dimensions(self, x: Any, dynamic_shapes: Any) -> Any:

x = torch.empty(tuple(new_shape), dtype=x.dtype, device=x.device)

t = self.fake_reshape(x, dynamic_shapes) # type: ignore[arg-type]
t = self.fake_reshape(x, dynamic_shapes) if dynamic_shapes else x # type: ignore[arg-type]
assert t.device == x.device, f"device mismatch {x.device} -> {t.device}"
assert t.dtype == x.dtype, f"dtype mismatch {x.dtype} -> {t.dtype}"
return t
Expand Down
3 changes: 3 additions & 0 deletions onnx_diagnostic/helpers/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -801,6 +801,9 @@ def _torch_sym_int_to_str(value: "torch.SymInt") -> Union[int, str]: # noqa: F
print(f"[string_type] TT8:{type(obj)}")
return repr(obj).replace(" ", "").replace("\n", " ")

if isinstance(obj, torch.fx.proxy.Proxy):
return repr(obj)

if ignore:
if verbose:
print(f"[string_type] CACHE4:{type(obj)}")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,18 @@ class patched_DynamicLayer:

def lazy_initialization(self, key_states: torch.Tensor):
self.dtype, self.device = key_states.dtype, key_states.device
new_shape = list(key_states.shape)
new_shape[-2] = 0
assert (
hasattr(key_states, "shape") and key_states is not None
), f"Attribute 'shape' is wrong for type {type(key_states)}"
like = torch.narrow(key_states, dim=-2, start=0, length=0)
# PATCHED: used a tensor with an empty shape and not an empty list to initialize
self.keys = torch.empty(new_shape, dtype=self.dtype, device=self.device)
self.values = torch.empty(new_shape, dtype=self.dtype, device=self.device)
if isinstance(key_states, torch._subclasses.fake_tensor.FakeTensor):
with key_states.fake_mode:
self.keys = torch.empty_like(like, dtype=self.dtype, device=self.device)
self.values = torch.empty_like(like, dtype=self.dtype, device=self.device)
else:
self.keys = torch.empty_like(like, dtype=self.dtype, device=self.device)
self.values = torch.empty_like(like, dtype=self.dtype, device=self.device)
if patch_is_initialized:
self.is_initialized = True

Expand Down
Loading