1+ import enum
12import functools
23
34import ninetoothed
89BLOCK_SIZE_N = ninetoothed .block_size ()
910
1011
class CausalVariant(enum.IntEnum):
    """Alignment convention for the causal attention mask.

    Mirrors PyTorch's ``torch.nn.attention.bias.CausalVariant``; see
    `<https://docs.pytorch.org/docs/stable/generated/torch.nn.attention.bias.CausalVariant.html>`_.

    The member values are part of the kernel contract: the compiled
    application code compares the ``causal_variant`` scalar against the
    integer value ``2`` for the lower-right variant, so the values are
    pinned explicitly here (matching what ``enum.auto()`` produced) and
    must not be reordered.
    """

    # Mask diagonal anchored at the upper-left corner of the score matrix
    # (standard causal masking: query i attends to keys 0..i).
    UPPER_LEFT = 1

    # Mask diagonal anchored at the lower-right corner — used when the key
    # sequence is longer than the query sequence (e.g. decoding with a
    # prefix), shifting the allowed region by (key_len - query_len).
    LOWER_RIGHT = 2
19+
1120def arrangement (
1221 query ,
1322 key ,
@@ -21,6 +30,7 @@ def arrangement(
2130 scale ,
2231 output ,
2332 with_attn_mask ,
33+ causal_variant ,
2434 with_kv_cache ,
2535 block_size_m = None ,
2636 block_size_n = None ,
@@ -78,6 +88,7 @@ def arrange_attn_mask(input):
7888 scale_arranged = scale
7989 output_arranged = arrange_query_or_output (output )
8090 with_attn_mask_arranged = with_attn_mask
91+ causal_variant_arranged = causal_variant
8192
8293 if with_kv_cache :
8394 return (
@@ -93,6 +104,7 @@ def arrange_attn_mask(input):
93104 scale_arranged ,
94105 output_arranged ,
95106 with_attn_mask_arranged ,
107+ causal_variant_arranged ,
96108 )
97109
98110 return (
@@ -104,6 +116,7 @@ def arrange_attn_mask(input):
104116 scale_arranged ,
105117 output_arranged ,
106118 with_attn_mask_arranged ,
119+ causal_variant_arranged ,
107120 )
108121
109122
@@ -120,17 +133,34 @@ def application_with_kv_cache(
120133 scale ,
121134 output ,
122135 with_attn_mask ,
136+ causal_variant ,
123137):
124138 present_key_slot = present_key # noqa: F841
125139 present_value_slot = present_value # noqa: F841
126140
127141 application_without_kv_cache (
128- query , key , value , attn_mask , is_causal , scale , output , with_attn_mask
142+ query ,
143+ key ,
144+ value ,
145+ attn_mask ,
146+ is_causal ,
147+ scale ,
148+ output ,
149+ with_attn_mask ,
150+ causal_variant ,
129151 )
130152
131153
132154def application_without_kv_cache (
133- query , key , value , attn_mask , is_causal , scale , output , with_attn_mask
155+ query ,
156+ key ,
157+ value ,
158+ attn_mask ,
159+ is_causal ,
160+ scale ,
161+ output ,
162+ with_attn_mask ,
163+ causal_variant ,
134164):
135165 for i in range (query .shape [0 ]):
136166 query_i = (1.4426950408889634 * scale * query [i ]).to (query [i ].dtype )
@@ -147,7 +177,16 @@ def application_without_kv_cache(
147177 qk += attn_mask [j ]
148178
149179 if is_causal :
150- mask = query [i ].offsets (- 2 )[:, None ] >= key [j ].offsets (- 2 )[None , :]
180+ if causal_variant == 2 : # CausalVariant.LOWER_RIGHT:
181+ mask = (
182+ query [i ].offsets (- 2 )[:, None ]
183+ + key .source .shape [- 2 ]
184+ - query .source .shape [- 2 ]
185+ >= key [j ].offsets (- 2 )[None , :]
186+ )
187+ else :
188+ mask = query [i ].offsets (- 2 )[:, None ] >= key [j ].offsets (- 2 )[None , :]
189+
151190 qk = ntl .where (mask , qk , float ("-inf" ))
152191
153192 next_max = ntl .maximum (max , ntl .max (qk , 1 ))
@@ -167,6 +206,7 @@ def premake(
167206 emb_dim = None ,
168207 is_causal = None ,
169208 with_attn_mask = None ,
209+ causal_variant = None ,
170210 dtype = None ,
171211 block_size_m = None ,
172212 block_size_n = None ,
@@ -192,6 +232,7 @@ def premake(
192232 scale = Tensor (0 , dtype = dtype )
193233 is_causal = Tensor (0 , dtype = dtype , constexpr = True , value = is_causal )
194234 with_attn_mask = Tensor (0 , dtype = dtype , constexpr = True , value = with_attn_mask )
235+ causal_variant = Tensor (0 , dtype = dtype , constexpr = True , value = causal_variant )
195236
196237 if emb_dim is not None :
197238 for tensor in (query , key , value , attn_mask , output ):
@@ -215,6 +256,7 @@ def premake(
215256 scale ,
216257 output ,
217258 with_attn_mask ,
259+ causal_variant ,
218260 )
219261
220262 return arrangement_ , application , tensors
0 commit comments