import torch


def generate_input(rank: int, world_size: int, m: int, n: int, k: int, has_bias: bool, seed: int) -> input_t:
    """
    Generate random input and weights for the Gemm-ReduceScatter operation.

    Args:
        rank: This process's rank; selects the CUDA device and offsets the RNG seed.
        world_size: Total number of ranks; m and k must be divisible by it.
        m, n, k: Global GEMM dimensions.
        has_bias: Whether to generate a bias tensor.
        seed: Base RNG seed.

    Returns:
        Tuple of (
            input: torch.Tensor,           # [m, k // world_size], bf16 on this rank's device
            weight: torch.Tensor,          # [n, k // world_size], bf16 on this rank's device
            bias: Optional[torch.Tensor],  # [n] or None; identical on every rank
        )
    """
    device = torch.device(f'cuda:{rank}')
    gen = torch.Generator(device=device)
    # Offset the seed by rank so each rank draws a distinct input/weight shard.
    gen.manual_seed(seed + rank)

    assert m % world_size == 0, "m must be divisible by world_size"
    assert k % world_size == 0, "k must be divisible by world_size"
    local_k = k // world_size

    # Uniform values in (-0.01, 0.01): small magnitudes keep bf16 GEMM
    # accumulation error well inside the checker's tolerance.
    # ('inp' rather than 'input' to avoid shadowing the builtin.)
    inp = (torch.rand((m, local_k), dtype=torch.bfloat16, device=device, generator=gen) * 2 - 1) * 0.01
    weight = (torch.rand((n, local_k), dtype=torch.bfloat16, device=device, generator=gen) * 2 - 1) * 0.01

    bias = None
    if has_bias:
        # Reseed with the base seed (no rank offset) so bias is the same on all ranks.
        gen.manual_seed(seed)
        bias = (torch.rand((n,), dtype=torch.bfloat16, device=device, generator=gen) * 2 - 1) * 0.01

    return (inp, weight, bias)
3435
@@ -60,4 +61,12 @@ def ref_kernel(data: input_t) -> output_t:
6061 return rs_output
6162
6263
def check_implementation(data: input_t, output: output_t):
    """
    Compare a submitted output against the reference Gemm-ReduceScatter kernel.

    Args:
        data: The tuple produced by generate_input for this rank.
        output: The candidate kernel's output tensor.

    Returns:
        (True, "") on a match; (False, reason) otherwise.
    """
    expected = ref_kernel(data)
    if output.device != expected.device:
        return False, f"Output device mismatch: {output.device} != {expected.device}"
    # torch.allclose raises on shape mismatch; report it as a failure instead.
    if output.shape != expected.shape:
        return False, f"Output shape mismatch: {output.shape} != {expected.shape}"
    if not torch.allclose(output, expected, rtol=1e-2, atol=1e-2):
        # Summarize the discrepancy rather than dumping both full tensors.
        max_err = (output.float() - expected.float()).abs().max().item()
        return False, f"Output values mismatch, max abs error {max_err:.3e}"

    return True, ""
0 commit comments