@@ -105,9 +105,18 @@ def get_static_tensor(tensor: torch.Tensor):
105105 min_max_opt = extract_var_range_info (seq_len )
106106 max_seq_len = min_max_opt ["max" ]
107107
108- from torch .fx .experimental .symbolic_shapes import ShapeEnv
108+ # Get the ShapeEnv from the existing fake tensors in the graph rather than
109+ # creating a new one. Using a fresh ShapeEnv causes a KeyError in
110+ # FakeTensorUpdater because the unbacked symints (u0, u1) are unknown to
111+ # the FakeTensorMode's ShapeEnv.
112+ fake_tensors = [
113+ node .meta ["val" ]
114+ for node in gm .graph .nodes
115+ if "val" in node .meta
116+ and isinstance (node .meta ["val" ], torch ._subclasses .fake_tensor .FakeTensor )
117+ ]
118+ shape_env = fake_tensors [0 ].fake_mode .shape_env
109119
110- shape_env = ShapeEnv ()
111120 # Create symbolic ints for start_idx and end_idx with range [0, seq_len] inclusive
112121 start_idx_unbacked_symint = shape_env .create_unbacked_symint ()
113122 torch ._check (start_idx_unbacked_symint >= 0 )
@@ -120,6 +129,12 @@ def get_static_tensor(tensor: torch.Tensor):
120129 start_idx_input .meta ["val" ] = start_idx_unbacked_symint
121130 end_idx_input .meta ["val" ] = end_idx_unbacked_symint
122131
132+ # u0/u1 are scalar index values, not tensor shape dimensions, so they will
133+ # never appear in any output tensor shape. Clear them from the pending list
134+ # so FakeTensorUpdater doesn't raise PendingUnbackedSymbolNotFound when
135+ # processing subsequent call_function nodes (placeholder nodes are skipped).
136+ shape_env .pending_fresh_unbacked_symbols .clear ()
137+
123138 return kv_inputs , start_idx_input , end_idx_input
124139
125140
0 commit comments