Skip to content

Commit 64ffb82

Browse files
authored
Fix: run_llm.py reports several errors (#4163)
1 parent a6ce2ee commit 64ffb82

3 files changed

Lines changed: 35 additions & 4 deletions

File tree

py/torch_tensorrt/dynamo/lowering/passes/_FakeTensorUpdater.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import torch
88
import torch.fx
99
from torch._dispatch.python import enable_python_dispatcher
10+
import torch._inductor.fx_passes.reinplace
1011
from torch._inductor.fx_utils import get_fake_args_kwargs, get_node_storage, get_storage
1112
from torch._subclasses.fake_tensor import FakeTensorMode
1213
from torch.fx.experimental.symbolic_shapes import (

tools/llm/static_cache_v1.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,9 +105,18 @@ def get_static_tensor(tensor: torch.Tensor):
105105
min_max_opt = extract_var_range_info(seq_len)
106106
max_seq_len = min_max_opt["max"]
107107

108-
from torch.fx.experimental.symbolic_shapes import ShapeEnv
108+
# Get the ShapeEnv from the existing fake tensors in the graph rather than
109+
# creating a new one. Using a fresh ShapeEnv causes a KeyError in
110+
# FakeTensorUpdater because the unbacked symints (u0, u1) are unknown to
111+
# the FakeTensorMode's ShapeEnv.
112+
fake_tensors = [
113+
node.meta["val"]
114+
for node in gm.graph.nodes
115+
if "val" in node.meta
116+
and isinstance(node.meta["val"], torch._subclasses.fake_tensor.FakeTensor)
117+
]
118+
shape_env = fake_tensors[0].fake_mode.shape_env
109119

110-
shape_env = ShapeEnv()
111120
# Create symbolic ints for start_idx and end_idx with range [0, seq_len] inclusive
112121
start_idx_unbacked_symint = shape_env.create_unbacked_symint()
113122
torch._check(start_idx_unbacked_symint >= 0)
@@ -120,6 +129,12 @@ def get_static_tensor(tensor: torch.Tensor):
120129
start_idx_input.meta["val"] = start_idx_unbacked_symint
121130
end_idx_input.meta["val"] = end_idx_unbacked_symint
122131

132+
# u0/u1 are scalar index values, not tensor shape dimensions, so they will
133+
# never appear in any output tensor shape. Clear them from the pending list
134+
# so FakeTensorUpdater doesn't raise PendingUnbackedSymbolNotFound when
135+
# processing subsequent call_function nodes (placeholder nodes are skipped).
136+
shape_env.pending_fresh_unbacked_symbols.clear()
137+
123138
return kv_inputs, start_idx_input, end_idx_input
124139

125140

tools/llm/static_cache_v2.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -110,9 +110,18 @@ def get_static_tensor(tensor: torch.Tensor):
110110
else:
111111
max_seq_len = seq_len
112112

113-
from torch.fx.experimental.symbolic_shapes import ShapeEnv
113+
# Get the ShapeEnv from the existing fake tensors in the graph rather than
114+
# creating a new one. Using a fresh ShapeEnv causes a KeyError in
115+
# FakeTensorUpdater because the unbacked symints (u0, u1) are unknown to
116+
# the FakeTensorMode's ShapeEnv.
117+
fake_tensors = [
118+
node.meta["val"]
119+
for node in gm.graph.nodes
120+
if "val" in node.meta
121+
and isinstance(node.meta["val"], torch._subclasses.fake_tensor.FakeTensor)
122+
]
123+
shape_env = fake_tensors[0].fake_mode.shape_env
114124

115-
shape_env = ShapeEnv()
116125
# Create symbolic ints for start_idx and end_idx with range [0, seq_len] inclusive
117126
start_idx_unbacked_symint = shape_env.create_unbacked_symint()
118127
torch._check(start_idx_unbacked_symint >= 0)
@@ -125,6 +134,12 @@ def get_static_tensor(tensor: torch.Tensor):
125134
start_idx_input.meta["val"] = start_idx_unbacked_symint
126135
end_idx_input.meta["val"] = end_idx_unbacked_symint
127136

137+
# u0/u1 are scalar index values, not tensor shape dimensions, so they will
138+
# never appear in any output tensor shape. Clear them from the pending list
139+
# so FakeTensorUpdater doesn't raise PendingUnbackedSymbolNotFound when
140+
# processing subsequent call_function nodes (placeholder nodes are skipped).
141+
shape_env.pending_fresh_unbacked_symbols.clear()
142+
128143
return kv_inputs, start_idx_input, end_idx_input
129144

130145

0 commit comments

Comments
 (0)