@@ -35,14 +35,15 @@ class RaggedAttentionTest(unittest.TestCase):
3535 head_dim = 128
3636
3737 dtype = jnp .float32
38- key = jax .random .key (0 )
39- k1 , k2 , k3 = jax .random .split (key , 3 )
4038
4139 @pytest .mark .tpu_only
4240 def test_ragged_mqa (self ):
43- q = jax .random .normal (self .k1 , (self .batch_size , 1 , self .head_dim ), dtype = self .dtype )
44- k = jax .random .normal (self .k2 , (self .batch_size , self .max_target_length , self .head_dim ), dtype = self .dtype )
45- v = jax .random .normal (self .k3 , (self .batch_size , self .max_target_length , self .head_dim ), dtype = self .dtype )
41+ key = jax .random .key (0 )
42+ k1 , k2 , k3 = jax .random .split (key , 3 )
43+
44+ q = jax .random .normal (k1 , (self .batch_size , 1 , self .head_dim ), dtype = self .dtype )
45+ k = jax .random .normal (k2 , (self .batch_size , self .max_target_length , self .head_dim ), dtype = self .dtype )
46+ v = jax .random .normal (k3 , (self .batch_size , self .max_target_length , self .head_dim ), dtype = self .dtype )
4647 lengths = jnp .array (np .random .randint (1 , self .max_target_length , self .batch_size ), dtype = jnp .int32 )
4748
4849 ragged_out , ragged_max , ragged_denom = ragged_mqa (q , k , v , lengths )
@@ -58,12 +59,15 @@ def test_ragged_mqa(self):
5859
5960 @pytest .mark .tpu_only
6061 def test_ragged_mha (self ):
61- q = jax .random .normal (self .k1 , (self .batch_size , 1 , self .num_query_heads , self .head_dim ), dtype = self .dtype )
62+ key = jax .random .key (0 )
63+ k1 , k2 , k3 = jax .random .split (key , 3 )
64+
65+ q = jax .random .normal (k1 , (self .batch_size , 1 , self .num_query_heads , self .head_dim ), dtype = self .dtype )
6266 k = jax .random .normal (
63- self . k2 , (self .batch_size , self .max_target_length , self .num_query_heads , self .head_dim ), dtype = self .dtype
67+ k2 , (self .batch_size , self .max_target_length , self .num_query_heads , self .head_dim ), dtype = self .dtype
6468 )
6569 v = jax .random .normal (
66- self . k3 , (self .batch_size , self .max_target_length , self .num_query_heads , self .head_dim ), dtype = self .dtype
70+ k3 , (self .batch_size , self .max_target_length , self .num_query_heads , self .head_dim ), dtype = self .dtype
6771 )
6872 lengths = jnp .array (np .random .randint (1 , self .max_target_length , self .batch_size ), dtype = jnp .int32 )
6973
@@ -81,12 +85,15 @@ def test_ragged_mha(self):
8185
8286 @pytest .mark .tpu_only
8387 def test_ragged_gqa (self ):
84- q = jax .random .normal (self .k1 , (self .batch_size , 1 , self .num_query_heads , self .head_dim ), dtype = self .dtype )
88+ key = jax .random .key (0 )
89+ k1 , k2 , k3 = jax .random .split (key , 3 )
90+
91+ q = jax .random .normal (k1 , (self .batch_size , 1 , self .num_query_heads , self .head_dim ), dtype = self .dtype )
8592 k = jax .random .normal (
86- self . k2 , (self .batch_size , self .max_target_length , self .num_kv_heads , self .head_dim ), dtype = self .dtype
93+ k2 , (self .batch_size , self .max_target_length , self .num_kv_heads , self .head_dim ), dtype = self .dtype
8794 )
8895 v = jax .random .normal (
89- self . k3 , (self .batch_size , self .max_target_length , self .num_kv_heads , self .head_dim ), dtype = self .dtype
96+ k3 , (self .batch_size , self .max_target_length , self .num_kv_heads , self .head_dim ), dtype = self .dtype
9097 )
9198 lengths = jnp .array (np .random .randint (1 , self .max_target_length , self .batch_size ), dtype = jnp .int32 )
9299
0 commit comments