@@ -17,6 +17,7 @@ def arrangement(
1717 present_key_slot ,
1818 present_value_slot ,
1919 attn_mask ,
20+ is_causal ,
2021 scale ,
2122 output ,
2223 with_attn_mask ,
@@ -67,6 +68,7 @@ def arrange_attn_mask(input):
6768 present_value_slot
6869 )
6970 attn_mask_arranged = arrange_attn_mask (attn_mask )
71+ is_causal_arranged = is_causal
7072 scale_arranged = scale
7173 output_arranged = arrange_query_or_output (output )
7274 with_attn_mask_arranged = with_attn_mask
@@ -81,6 +83,7 @@ def arrange_attn_mask(input):
8183 present_key_slot_arranged ,
8284 present_value_slot_arranged ,
8385 attn_mask_arranged ,
86+ is_causal_arranged ,
8487 scale_arranged ,
8588 output_arranged ,
8689 with_attn_mask_arranged ,
@@ -91,6 +94,7 @@ def arrange_attn_mask(input):
9194 key_arranged ,
9295 value_arranged ,
9396 attn_mask_arranged ,
97+ is_causal_arranged ,
9498 scale_arranged ,
9599 output_arranged ,
96100 with_attn_mask_arranged ,
@@ -106,6 +110,7 @@ def application_with_kv_cache(
106110 present_key_slot ,
107111 present_value_slot ,
108112 attn_mask ,
113+ is_causal ,
109114 scale ,
110115 output ,
111116 with_attn_mask ,
@@ -114,12 +119,12 @@ def application_with_kv_cache(
114119 present_value_slot = present_value # noqa: F841
115120
116121 application_without_kv_cache (
117- query , key , value , attn_mask , scale , output , with_attn_mask
122+ query , key , value , attn_mask , is_causal , scale , output , with_attn_mask
118123 )
119124
120125
121126def application_without_kv_cache (
122- query , key , value , attn_mask , scale , output , with_attn_mask
127+ query , key , value , attn_mask , is_causal , scale , output , with_attn_mask
123128):
124129 for i in range (query .shape [0 ]):
125130 query_i = (1.4426950408889634 * scale * query [i ]).to (query [i ].dtype )
@@ -135,6 +140,10 @@ def application_without_kv_cache(
135140 if with_attn_mask :
136141 qk += attn_mask [j ]
137142
143+ if is_causal :
144+ mask = query [i ].offsets (- 2 )[:, None ] >= key [j ].offsets (- 2 )[None , :]
145+ qk = ntl .where (mask , qk , float ("-inf" ))
146+
138147 next_max = ntl .maximum (max , ntl .max (qk , 1 ))
139148 stable_qk = ntl .exp2 (qk - next_max [:, None ])
140149
@@ -168,7 +177,7 @@ def make(with_kv_cache):
168177 for _ in range (4 )
169178 )
170179 scale = Tensor (0 )
171- with_attn_mask = Tensor (0 , constexpr = True )
180+ is_causal , with_attn_mask = ( Tensor (0 , constexpr = True ) for _ in range ( 2 ) )
172181
173182 if with_kv_cache :
174183 application = application_with_kv_cache
@@ -184,6 +193,7 @@ def make(with_kv_cache):
184193 present_key_slot ,
185194 present_value_slot ,
186195 attn_mask ,
196+ is_causal ,
187197 scale ,
188198 output ,
189199 with_attn_mask ,
0 commit comments