@@ -83,8 +83,8 @@ def torch_matmul():
8383 torch_result = torch_matmul ()
8484
8585 # Create infinicore tensors
86- infini_a = create_infinicore_tensor (torch_a , device )
87- infini_b = create_infinicore_tensor (torch_b , device )
86+ infini_a = create_infinicore_tensor (torch_a , device_str )
87+ infini_b = create_infinicore_tensor (torch_b , device_str )
8888
8989 # Out-of-place matmul
9090 def infini_matmul ():
@@ -93,9 +93,7 @@ def infini_matmul():
9393 infini_result = infini_matmul ()
9494
9595 # Validate results using common method
96- is_valid = compare_results (
97- infini_result , torch_result , dtype , config , device_str , device
98- )
96+ is_valid = compare_results (infini_result , torch_result , dtype , config , device_str )
9997 assert is_valid , "Matmul test failed"
10098
10199 # Performance test
@@ -152,8 +150,8 @@ def torch_matmul_inplace():
152150 torch_matmul_inplace ()
153151
154152 # Create infinicore tensors
155- infini_a = create_infinicore_tensor (torch_a , device )
156- infini_b = create_infinicore_tensor (torch_b , device )
153+ infini_a = create_infinicore_tensor (torch_a , device_str )
154+ infini_b = create_infinicore_tensor (torch_b , device_str )
157155 infini_c = infinicore .empty (
158156 result_shape , dtype = dtype , device = infinicore .device (device_str , 0 )
159157 )
@@ -166,9 +164,7 @@ def infini_matmul_inplace():
166164 infini_matmul_inplace ()
167165
168166 # Validate results using common method
169- is_valid = compare_results (
170- infini_c , torch_preallocated , dtype , config , device_str , device
171- )
167+ is_valid = compare_results (infini_c , torch_preallocated , dtype , config , device_str )
172168 assert is_valid , "In-place matmul test failed"
173169
174170 # Performance test
0 commit comments