Skip to content

Commit 8f7dddd

Browse files
committed
Include tensor attribute names in history replacement mapping
1 parent 9c4f96d commit 8f7dddd

1 file changed

Lines changed: 14 additions & 4 deletions

File tree

src/ninetoothed/fusion.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,17 @@ def rename_tensor(tensor, name):
207207
input_tensor_positions = tuple(mapping.values())
208208
other_tensor_positions = tuple(mapping.keys())
209209

210-
block_size_mapping = {}
210+
mapping = {}
211+
212+
for old_tensor, new_tensor in itertools.chain(
213+
zip(input_kernel.tensors, input_tensors),
214+
zip(other_kernel.tensors, other_tensors),
215+
):
216+
old_names = sorted(old_tensor.names(), key=str)
217+
new_names = sorted(new_tensor.names(), key=str)
218+
219+
for old_name, new_name in zip(old_names, new_names):
220+
mapping[old_name] = new_name
211221

212222
for input_tensor_position, other_tensor_position in zip(
213223
input_tensor_positions, other_tensor_positions
@@ -233,11 +243,11 @@ def rename_tensor(tensor, name):
233243
lower_bound=new_lower_bound, upper_bound=new_upper_bound
234244
)
235245

236-
block_size_mapping[input_block_size] = new_block_size
237-
block_size_mapping[other_block_size] = new_block_size
246+
mapping[input_block_size] = new_block_size
247+
mapping[other_block_size] = new_block_size
238248

239249
for tensor in itertools.chain(input_tensors_arranged, other_tensors_arranged):
240-
_replace_history(tensor, block_size_mapping)
250+
_replace_history(tensor, mapping)
241251

242252
fusion_info = _get_fusion_info(
243253
input_tensors_arranged[input_tensor_positions[0]],

0 commit comments

Comments
 (0)