Skip to content

Commit e7a8360

Browse files
xadupreCopilot
andauthored
Update onnx_diagnostic/investigate/input_observer.py
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent 9ea2e5f commit e7a8360

1 file changed

Lines changed: 8 additions & 2 deletions

File tree

onnx_diagnostic/investigate/input_observer.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -147,16 +147,22 @@ def __init__(
147147

148148
def remove_inputs(self, input_names: Sequence[str | int]):
149149
"""Removes inputs."""
150+
# Work on a mutable copy of positional arguments.
151+
args_list = list(self.args)
152+
150153
for name_or_pos in sorted(input_names, reverse=True):
151154
if isinstance(name_or_pos, int):
152-
if name_or_pos in self.args:
153-
del self.args[name_or_pos]
155+
idx = name_or_pos
156+
if 0 <= idx < len(args_list):
157+
del args_list[idx]
154158
else:
155159
if name_or_pos in self.kwargs:
156160
del self.kwargs[name_or_pos]
157161
elif name_or_pos in self.cst_kwargs:
158162
del self.cst_kwargs[name_or_pos]
159163

164+
# Update stored positional arguments.
165+
self.args = tuple(args_list)
160166
# remove any temporary structures
161167
self.flat_list, self.spec = torch.utils._pytree.tree_flatten((self.args, self.kwargs))
162168
self._position_to_args_kwargs: list[int | str] | None = None

0 commit comments

Comments
 (0)