@@ -55,7 +55,29 @@ def test_current_device(model, module):
             self_attn_mask_type="padding",
             device=f"cuda:{tensor_device}",
         )
-        num_tokens = torch.randint(0, config.max_seqlen_q, (1,)).item()
+        seqlens_q = torch.randint(
+            1,
+            config.max_seqlen_q,
+            [config.batch_size],
+            dtype=torch.int32,
+            device=f"cuda:{tensor_device}",
+        )
+        cu_seqlens_q = torch.zeros(
+            config.batch_size + 1, dtype=torch.int32, device=f"cuda:{tensor_device}"
+        )
+        cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
+        seqlens_kv = torch.randint(
+            1,
+            config.max_seqlen_kv,
+            [config.batch_size],
+            dtype=torch.int32,
+            device=f"cuda:{tensor_device}",
+        )
+        cu_seqlens_kv = torch.zeros(
+            config.batch_size + 1, dtype=torch.int32, device=f"cuda:{tensor_device}"
+        )
+        cu_seqlens_kv[1:] = torch.cumsum(seqlens_kv, dim=0)
+        num_tokens = cu_seqlens_q[-1]
         args = [
             torch.randn(
                 (num_tokens, config.hidden_size),
@@ -64,9 +86,6 @@ def test_current_device(model, module):
                 requires_grad=True,
             )
         ]
-        cu_seqlens_q, cu_seqlens_kv = [
-            torch.Tensor([0, 2, 3]).to(dtype=torch.int32, device=tensor_device) for _ in range(2)
-        ]
         kwargs["cu_seqlens_q"] = cu_seqlens_q
         kwargs["cu_seqlens_kv"] = cu_seqlens_kv
         kwargs["max_seqlen_q"] = config.max_seqlen_q
@@ -75,26 +94,47 @@ def test_current_device(model, module):
         model = DotProductAttention(
             config.num_heads, config.head_dim_qk, qkv_format="thd", attn_mask_type="padding"
         )
-        num_tokens = torch.randint(0, config.max_seqlen_q, (1,)).item()
+        seqlens_q = torch.randint(
+            1,
+            config.max_seqlen_q,
+            [config.batch_size],
+            dtype=torch.int32,
+            device=f"cuda:{tensor_device}",
+        )
+        cu_seqlens_q = torch.zeros(
+            config.batch_size + 1, dtype=torch.int32, device=f"cuda:{tensor_device}"
+        )
+        cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
+        seqlens_kv = torch.randint(
+            1,
+            config.max_seqlen_kv,
+            [config.batch_size],
+            dtype=torch.int32,
+            device=f"cuda:{tensor_device}",
+        )
+        cu_seqlens_kv = torch.zeros(
+            config.batch_size + 1, dtype=torch.int32, device=f"cuda:{tensor_device}"
+        )
+        cu_seqlens_kv[1:] = torch.cumsum(seqlens_kv, dim=0)
+        num_tokens = cu_seqlens_q[-1]
         args = [
             torch.randn(
                 num_tokens,
                 config.num_heads,
                 config.head_dim_qk,
                 dtype=dtype,
-                device=tensor_device,
+                device=f"cuda:{tensor_device}",
                 requires_grad=True,
             )
             for _ in range(3)
         ]
-        cu_seqlens_q, cu_seqlens_kv = [
-            torch.Tensor([0, 2, 3]).to(dtype=torch.int32, device=tensor_device) for _ in range(2)
-        ]
         kwargs["cu_seqlens_q"] = cu_seqlens_q
         kwargs["cu_seqlens_kv"] = cu_seqlens_kv
         kwargs["max_seqlen_q"] = config.max_seqlen_q
         kwargs["max_seqlen_kv"] = config.max_seqlen_kv
-        bwd_args = [torch.randn(num_tokens, config.hidden_size, dtype=dtype, device=tensor_device)]
+        bwd_args = [
+            torch.randn(num_tokens, config.hidden_size, dtype=dtype, device=f"cuda:{tensor_device}")
+        ]
     elif module == "Linear":
         model = Linear(
             config.hidden_size,
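Note on the pattern these hunks introduce: instead of hard-coding cu_seqlens as [0, 2, 3] while drawing num_tokens independently, the test now draws one random length per sequence and derives both the cumulative offsets and the total token count from the same lengths, so the packed THD-format tensors and the cu_seqlens arguments stay consistent. Below is a minimal, standalone sketch of that pattern; it assumes only plain PyTorch, and batch_size, max_seqlen, and hidden_size are illustrative values rather than the test's config fields.

# Minimal sketch of the cu_seqlens construction (assumption: plain PyTorch;
# batch_size, max_seqlen, and hidden_size are illustrative, not config values).
import torch

batch_size, max_seqlen, hidden_size = 4, 16, 64

# One random length per sequence, at least 1 token each.
seqlens = torch.randint(1, max_seqlen, [batch_size], dtype=torch.int32)

# cu_seqlens[i] is the start offset of sequence i in the packed (THD) token
# dimension; cu_seqlens[-1] is therefore the total number of tokens.
cu_seqlens = torch.zeros(batch_size + 1, dtype=torch.int32)
cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)

num_tokens = int(cu_seqlens[-1])
packed = torch.randn(num_tokens, hidden_size)  # packed input, no padding tokens

assert num_tokens == int(seqlens.sum())

Deriving num_tokens from cu_seqlens_q[-1], as the diff does, keeps the first dimension of the packed args tensors in agreement with the offsets passed via kwargs["cu_seqlens_q"].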