@@ -207,7 +207,17 @@ def rename_tensor(tensor, name):
207207 input_tensor_positions = tuple (mapping .values ())
208208 other_tensor_positions = tuple (mapping .keys ())
209209
210- block_size_mapping = {}
210+ mapping = {}
211+
212+ for old_tensor , new_tensor in itertools .chain (
213+ zip (input_kernel .tensors , input_tensors ),
214+ zip (other_kernel .tensors , other_tensors ),
215+ ):
216+ old_names = sorted (old_tensor .names (), key = str )
217+ new_names = sorted (new_tensor .names (), key = str )
218+
219+ for old_name , new_name in zip (old_names , new_names ):
220+ mapping [old_name ] = new_name
211221
212222 for input_tensor_position , other_tensor_position in zip (
213223 input_tensor_positions , other_tensor_positions
@@ -233,11 +243,11 @@ def rename_tensor(tensor, name):
233243 lower_bound = new_lower_bound , upper_bound = new_upper_bound
234244 )
235245
236- block_size_mapping [input_block_size ] = new_block_size
237- block_size_mapping [other_block_size ] = new_block_size
246+ mapping [input_block_size ] = new_block_size
247+ mapping [other_block_size ] = new_block_size
238248
239249 for tensor in itertools .chain (input_tensors_arranged , other_tensors_arranged ):
240- _replace_history (tensor , block_size_mapping )
250+ _replace_history (tensor , mapping )
241251
242252 fusion_info = _get_fusion_info (
243253 input_tensors_arranged [input_tensor_positions [0 ]],
0 commit comments