Skip to content

Commit df07f95

Browse files
author
zhangyue
committed
feat(flash_attention): add vLLM-style sliding_window entry (additive)
Keeps the native `window_left` / `window_right` pair as-is and adds an optional `std::optional<int64_t> sliding_window` parameter. When set, the base class normalizes it to the causal-sliding pair `(sliding_window - 1, 0)`; when both forms are supplied, the normalized values must agree.

Callers can now use either entry point:

    // Pair form (existing, unchanged):
    flash_attention(..., window_left=255, window_right=0, ...)
    // vLLM form:
    flash_attention(..., sliding_window=256, ...)

The Ascend impl reads the resolved pair from the base-class members (`window_left_` / `window_right_`) so `sliding_window` is honored at both construction and call time.

Also extends `generate_wrappers.py` to set `py::arg(...) = py::none()` defaults for all `std::optional<...>` parameters (previously only `std::optional<Tensor>`), so `sliding_window` is properly optional on the Python side.

Adds `test_flash_attention_sliding_window_equivalence`, asserting bit-exact equality between the two entry points.
1 parent 592b493 commit df07f95

4 files changed

Lines changed: 147 additions & 22 deletions

File tree

scripts/generate_wrappers.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,9 @@ def _is_optional_tensor(arg):
121121
return True
122122
return "std::optional" in arg.type.spelling and "Tensor" in arg.type.spelling
123123

124+
def _is_optional(arg):
125+
return "std::optional" in arg.type.spelling
126+
124127
def _is_vector_tensor(arg):
125128
if arg.spelling in vector_tensor_params:
126129
return True
@@ -177,7 +180,7 @@ def _generate_py_args(node):
177180
if arg.spelling == "stream":
178181
continue
179182

180-
if _is_optional_tensor(arg):
183+
if _is_optional(arg):
181184
parts.append(f'py::arg("{arg.spelling}") = py::none()')
182185
else:
183186
parts.append(f'py::arg("{arg.spelling}")')

src/ascend/flash_attention/kernel.h

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -114,10 +114,12 @@ class Operator<FlashAttention, Device::Type::kAscend> : public FlashAttention {
114114
std::optional<Tensor> block_table, int64_t num_heads,
115115
int64_t num_kv_heads, int64_t head_size, double scale, bool causal,
116116
int64_t window_left, int64_t window_right, int64_t block_size,
117-
Tensor output)
117+
Tensor output,
118+
std::optional<int64_t> sliding_window = std::nullopt)
118119
: FlashAttention(query, key, value, cu_seqlens_q, cu_seqlens_kv,
119120
block_table, num_heads, num_kv_heads, head_size, scale,
120-
causal, window_left, window_right, block_size, output) {
121+
causal, window_left, window_right, block_size, output,
122+
sliding_window) {
121123
paged_ = block_table.has_value() && block_size > 0;
122124
aclDataType acl_dt = ascend::toAclDtype(query.dtype());
123125

@@ -126,9 +128,11 @@ class Operator<FlashAttention, Device::Type::kAscend> : public FlashAttention {
126128
prefill_q_cache_ = ascend::AclTensorCache(query);
127129
prefill_out_cache_ = ascend::AclTensorCache(output);
128130

129-
// Pre-compute causal mask once (sparse_mode >= 2).
131+
// Pre-compute causal mask once (sparse_mode >= 2). Read the
132+
// resolved pair from base-class members so `sliding_window`
133+
// normalization is honored at cache-key construction.
130134
if (causal) {
131-
int64_t sm = (window_left >= 0) ? 4 : 3;
135+
int64_t sm = (window_left_ >= 0) ? 4 : 3;
132136
if (sm >= 2) {
133137
causal_mask_ = detail::makeCausalMask(&causal_mask_buf_, nullptr);
134138
}
@@ -169,26 +173,36 @@ class Operator<FlashAttention, Device::Type::kAscend> : public FlashAttention {
169173
std::optional<Tensor> block_table, int64_t num_heads,
170174
int64_t num_kv_heads, int64_t head_size, double scale,
171175
bool causal, int64_t window_left, int64_t window_right,
172-
int64_t block_size, Tensor output) const override {
176+
int64_t block_size, Tensor output,
177+
std::optional<int64_t> sliding_window) const override {
173178
auto stream = static_cast<aclrtStream>(stream_);
174179
const bool paged = paged_;
175180

181+
// The base class stored the resolved window pair in `window_left_` /
182+
// `window_right_` at construction; prefer those over the call-site
183+
// args so that `sliding_window` is honored here as well.
184+
int64_t wl = window_left_;
185+
int64_t wr = window_right_;
186+
(void)window_left;
187+
(void)window_right;
188+
(void)sliding_window;
189+
176190
int64_t sparse_mode;
177191
int64_t pre_tokens = 2147483647;
178192
int64_t next_tokens = 2147483647;
179193
if (causal) {
180-
if (window_left >= 0) {
194+
if (wl >= 0) {
181195
sparse_mode = 4;
182-
pre_tokens = window_left;
196+
pre_tokens = wl;
183197
next_tokens = 0;
184198
} else {
185199
sparse_mode = 3;
186200
next_tokens = 0;
187201
}
188202
} else {
189203
sparse_mode = 0;
190-
if (window_left >= 0) pre_tokens = window_left;
191-
if (window_right >= 0) next_tokens = window_right;
204+
if (wl >= 0) pre_tokens = wl;
205+
if (wr >= 0) next_tokens = wr;
192206
}
193207

194208
if (!paged) {

src/base/flash_attention.h

Lines changed: 43 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,21 +11,30 @@ namespace infini::ops {
1111

1212
class FlashAttention : public Operator<FlashAttention> {
1313
public:
14+
// `window_left` / `window_right` is the native InfiniOps pair-form
15+
// window (left-context / right-context tokens, `-1` = disabled).
16+
// `sliding_window` is a vLLM-style single-parameter shortcut: when
17+
// set, it is normalized to `(sliding_window - 1, 0)` — i.e. causal
18+
// sliding over the most recent `sliding_window` tokens. When both
19+
// forms are supplied the normalized values must agree. Callers may
20+
// use whichever form is more natural; the kernel only sees the
21+
// resolved pair.
1422
FlashAttention(const Tensor query, const Tensor key, const Tensor value,
1523
std::optional<Tensor> cu_seqlens_q,
1624
std::optional<Tensor> cu_seqlens_kv,
1725
std::optional<Tensor> block_table, int64_t num_heads,
1826
int64_t num_kv_heads, int64_t head_size, double scale,
1927
bool causal, int64_t window_left, int64_t window_right,
20-
int64_t block_size, Tensor output)
28+
int64_t block_size, Tensor output,
29+
std::optional<int64_t> sliding_window = std::nullopt)
2130
: num_tokens_{query.size(0)},
2231
num_heads_{num_heads},
2332
num_kv_heads_{num_kv_heads},
2433
head_size_{head_size},
2534
scale_{scale},
2635
causal_{causal},
27-
window_left_{window_left},
28-
window_right_{window_right},
36+
window_left_{resolveWindowLeft(window_left, sliding_window)},
37+
window_right_{resolveWindowRight(window_right, sliding_window)},
2938
block_size_{block_size},
3039
dtype_{query.dtype()},
3140
query_shape_{query.shape()},
@@ -45,15 +54,37 @@ class FlashAttention : public Operator<FlashAttention> {
4554
"`FlashAttention` requires query to be 3D [T, N, D]");
4655
}
4756

48-
virtual void operator()(const Tensor query, const Tensor key,
49-
const Tensor value,
50-
std::optional<Tensor> cu_seqlens_q,
51-
std::optional<Tensor> cu_seqlens_kv,
52-
std::optional<Tensor> block_table, int64_t num_heads,
53-
int64_t num_kv_heads, int64_t head_size, double scale,
54-
bool causal, int64_t window_left,
55-
int64_t window_right, int64_t block_size,
56-
Tensor output) const = 0;
57+
// Pure-virtual call operator that each backend specialization overrides.
// The parameter list mirrors the constructor's: the native
// `window_left` / `window_right` pair plus the optional vLLM-style
// `sliding_window`, which defaults to `std::nullopt` here.
//
// NOTE(review): default arguments on virtual functions are selected by
// the static type of the call expression, not the dynamic type. The
// Ascend override in this commit declares `sliding_window` without a
// default, so the default applies only to calls made through
// `FlashAttention` itself — confirm that is the intended call path.
virtual void operator()(
58+
const Tensor query, const Tensor key, const Tensor value,
59+
std::optional<Tensor> cu_seqlens_q, std::optional<Tensor> cu_seqlens_kv,
60+
std::optional<Tensor> block_table, int64_t num_heads, int64_t num_kv_heads,
61+
int64_t head_size, double scale, bool causal, int64_t window_left,
62+
int64_t window_right, int64_t block_size, Tensor output,
63+
std::optional<int64_t> sliding_window = std::nullopt) const = 0;
64+
65+
private:
66+
// Normalize the two accepted window spellings into the native
// `(window_left, window_right)` pair.
//
// - Without `sliding_window`, the explicit pair is passed through
//   unchanged.
// - With `sliding_window`, the result is the causal-sliding pair
//   `(sliding_window - 1, 0)`. `sliding_window` must be positive, and an
//   explicit pair value supplied alongside it must be either `-1`
//   (disabled) or equal to the derived value.
//
// Violations are treated as programmer errors and reported via `assert`,
// so the checks are active only in builds without `NDEBUG`.

// Returns the resolved left-context size.
static int64_t resolveWindowLeft(int64_t window_left,
                                 std::optional<int64_t> sliding_window) {
  if (!sliding_window.has_value()) return window_left;

  // `sliding_window == 0` would derive `-1`, which is the "disabled"
  // sentinel of the pair form, and negative values derive meaningless
  // sizes. Reject both instead of silently dropping the window.
  assert(sliding_window.value() > 0 &&
         "`FlashAttention`: `sliding_window` must be positive");

  const int64_t derived = sliding_window.value() - 1;
  assert((window_left == -1 || window_left == derived) &&
         "`FlashAttention`: `window_left` inconsistent with `sliding_window`");
  return derived;
}

// Returns the resolved right-context size (always `0` when
// `sliding_window` is supplied).
static int64_t resolveWindowRight(int64_t window_right,
                                  std::optional<int64_t> sliding_window) {
  if (!sliding_window.has_value()) return window_right;

  // Same positivity requirement as `resolveWindowLeft`, repeated so this
  // helper is safe to call on its own.
  assert(sliding_window.value() > 0 &&
         "`FlashAttention`: `sliding_window` must be positive");

  assert((window_right == -1 || window_right == 0) &&
         "`FlashAttention`: `window_right` inconsistent with `sliding_window` "
         "(vLLM sliding_window implies right=0)");
  return 0;
}
86+
87+
public:
5788

5889
protected:
5990
Tensor::Size num_tokens_{0};

tests/test_flash_attention.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -537,3 +537,80 @@ def _ref_flash_attention_paged(
537537
outputs.append(out)
538538

539539
return torch.cat(outputs, dim=0).to(query.device)
540+
541+
542+
@pytest.mark.parametrize("sliding_window", (4, 16))
@pytest.mark.parametrize("device", ("npu",))
def test_flash_attention_sliding_window_equivalence(sliding_window, device):
    """Calling with `sliding_window=N` must match the pair-form call.

    The same inputs are run once with `window_left=N-1, window_right=0`
    and once with the pair disabled (`-1, -1`) plus `sliding_window=N`;
    the two outputs are compared for exact equality.
    """
    npu_ready = hasattr(torch, "npu") and torch.npu.is_available()
    if not npu_ready:
        pytest.skip("NPU not available")

    num_tokens, num_heads, num_kv_heads, head_size = 32, 8, 8, 64
    scale = 1.0 / head_size**0.5
    dtype = torch.float16

    def _random_input(heads):
        # One `[tokens, heads, head_size]` tensor with default strides.
        return randn_strided(
            (num_tokens, heads, head_size), None, dtype=dtype, device=device
        )

    # Creation order (query, key, value) is kept so RNG consumption matches.
    query = _random_input(num_heads)
    key = _random_input(num_kv_heads)
    value = _random_input(num_kv_heads)

    cu_seqlens_q = torch.tensor([0, num_tokens], dtype=torch.int64, device=device)
    cu_seqlens_kv = torch.tensor([0, num_tokens], dtype=torch.int64, device=device)

    def _attend(window_left, window_right, **window_kwargs):
        # Shared call site: only the window arguments differ between runs.
        result = torch.empty_like(query)
        infini.ops.flash_attention(
            query,
            key,
            value,
            cu_seqlens_q,
            cu_seqlens_kv,
            None,
            num_heads,
            num_kv_heads,
            head_size,
            scale,
            True,
            window_left,
            window_right,
            0,
            result,
            **window_kwargs,
            stream=get_npu_stream(query),
        )
        return result

    # Pair-form call.
    out_pair = _attend(sliding_window - 1, 0)

    # vLLM-style single-parameter call.
    out_sw = _attend(-1, -1, sliding_window=sliding_window)

    assert torch.equal(out_pair, out_sw), (
        f"Max diff: {(out_pair.float() - out_sw.float()).abs().max().item()}"
    )

0 commit comments

Comments
 (0)