
Commit e383ad1

move mask computation outside graph + infra improvements
Signed-off-by: Balamurugan Marimuthu <246387390+bmarimuthu-nv@users.noreply.github.com>
1 parent e906b02 · commit e383ad1

6 files changed

Lines changed: 287 additions & 107 deletions


tensorrt_llm/_torch/auto_deploy/custom_ops/attention/torch_backend_attention.py

Lines changed: 167 additions & 18 deletions
@@ -70,6 +70,24 @@ def _apply_logit_softcapping(attn_scores: torch.Tensor, logit_cap: Optional[float]
     return attn_scores


+def _write_generate_kv_cache(
+    k: torch.Tensor,
+    v: torch.Tensor,
+    k_cache: torch.Tensor,
+    v_cache: torch.Tensor,
+    slot_idx: torch.Tensor,
+    input_pos: torch.Tensor,
+):
+    """Write single-token decode K/V into the cache."""
+    b, s = k.shape[:2]
+    assert s == 1, f"Expected sequence length 1 for generate phase, got {s}"
+    for i in range(b):
+        cache_idx = slot_idx[i].item()
+        pos = input_pos[i].item()
+        k_cache[cache_idx, pos] = k[i, 0]  # Remove sequence dim
+        v_cache[cache_idx, pos] = v[i, 0]  # Remove sequence dim
+
+
 def _torch_generate_mha(
     q: torch.Tensor,
     k: torch.Tensor,
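The new _write_generate_kv_cache helper isolates the decode-phase cache write that previously lived inline in _torch_generate_mha. Below is a minimal, self-contained sketch of the same write pattern; every shape, slot index, and position is invented for illustration, and only the indexing mirrors the helper:

    import torch

    b, n_kv_heads, head_dim, max_seq, num_slots = 2, 4, 8, 16, 8
    k = torch.randn(b, 1, n_kv_heads, head_dim)  # one new token per sequence
    v = torch.randn(b, 1, n_kv_heads, head_dim)
    k_cache = torch.zeros(num_slots, max_seq, n_kv_heads, head_dim)
    v_cache = torch.zeros(num_slots, max_seq, n_kv_heads, head_dim)
    slot_idx = torch.tensor([3, 0])    # cache slot owned by each sequence
    input_pos = torch.tensor([5, 2])   # position where the new token lands

    for i in range(b):  # same loop as _write_generate_kv_cache
        k_cache[slot_idx[i], input_pos[i]] = k[i, 0]
        v_cache[slot_idx[i], input_pos[i]] = v[i, 0]

    assert torch.equal(k_cache[3, 5], k[0, 0])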
@@ -89,12 +107,7 @@ def _torch_generate_mha(
     assert s == 1, f"Expected sequence length 1 for generate phase, got {s}"
     n_kv_heads = k.shape[2]  # k has shape (b, 1, n_kv_heads, head_dim)

-    # Update KV cache for single token
-    for i in range(b):
-        cache_idx = slot_idx[i].item()
-        pos = input_pos[i].item()
-        k_cache[cache_idx, pos] = k[i, 0]  # Remove sequence dim
-        v_cache[cache_idx, pos] = v[i, 0]  # Remove sequence dim
+    _write_generate_kv_cache(k, v, k_cache, v_cache, slot_idx, input_pos)

     # Compute attention for each sequence using manual computation
     for i in range(b):
@@ -156,6 +169,60 @@ def _torch_generate_mha(
         out[i] = attn_out.squeeze(1)  # [n_heads, v_head_dim]


+def _torch_generate_mha_readonly(
+    q: torch.Tensor,
+    k_cache: torch.Tensor,
+    v_cache: torch.Tensor,
+    slot_idx: torch.Tensor,
+    input_pos: torch.Tensor,
+    scale: float,
+    out: torch.Tensor,
+    logit_cap: Optional[float] = None,
+    sliding_window_size: Optional[int] = None,
+    sinks: Optional[torch.Tensor] = None,
+):
+    """Generate-only attention using an existing KV cache without writing current-layer K/V."""
+    b, s, n_heads, head_dim = q.shape
+    assert s == 1, f"Expected sequence length 1 for generate phase, got {s}"
+    n_kv_heads = k_cache.shape[2]
+
+    for i in range(b):
+        cache_idx = slot_idx[i].item()
+        pos = input_pos[i].item()
+        q_i = q[i, 0]
+
+        if sliding_window_size is not None and sliding_window_size > 0:
+            start_pos = max(0, pos - sliding_window_size + 1)
+            k_i = k_cache[cache_idx, start_pos : pos + 1]
+            v_i = v_cache[cache_idx, start_pos : pos + 1]
+        else:
+            k_i = k_cache[cache_idx, : pos + 1]
+            v_i = v_cache[cache_idx, : pos + 1]
+
+        q_i = q_i.unsqueeze(1)
+        k_i = k_i.transpose(0, 1)
+        v_i = v_i.transpose(0, 1)
+
+        if n_heads != n_kv_heads:
+            n_rep = n_heads // n_kv_heads
+            k_i = repeat_kv(k_i.unsqueeze(0), n_rep)[0]
+            v_i = repeat_kv(v_i.unsqueeze(0), n_rep)[0]
+
+        attn_scores = torch.matmul(q_i, k_i.transpose(-2, -1)) * scale
+        attn_scores = _apply_logit_softcapping(attn_scores, logit_cap)
+
+        if sinks is not None:
+            sinks = sinks.reshape(-1, 1, 1)
+            attn_weights = torch.cat([attn_scores, sinks], dim=-1)
+            attn_weights = torch.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
+            attn_out = torch.matmul(attn_weights[..., : -sinks.size(-1)], v_i)
+        else:
+            attn_weights = torch.softmax(attn_scores, dim=-1, dtype=torch.float32).to(q.dtype)
+            attn_out = torch.matmul(attn_weights, v_i)
+
+        out[i] = attn_out.squeeze(1)
+
+
 def _torch_context_mha(
     q: torch.Tensor,
     k: torch.Tensor,
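_torch_generate_mha_readonly reads K/V straight from the cache and, for grouped-query attention, expands the KV heads to the query head count via the repeat_kv helper this module already uses. The sketch below is an illustrative reimplementation of that expansion under the assumed (batch, n_kv_heads, seq, head_dim) layout, not the module's actual helper:

    import torch

    def repeat_kv_sketch(x: torch.Tensor, n_rep: int) -> torch.Tensor:
        """Expand (b, n_kv_heads, s, d) -> (b, n_kv_heads * n_rep, s, d) by repeating each KV head."""
        b, n_kv, s, d = x.shape
        if n_rep == 1:
            return x
        return x[:, :, None, :, :].expand(b, n_kv, n_rep, s, d).reshape(b, n_kv * n_rep, s, d)

    k_i = torch.randn(1, 2, 6, 8)          # 2 KV heads over 6 cached positions
    print(repeat_kv_sketch(k_i, 4).shape)  # torch.Size([1, 8, 6, 8]), matching 8 query heads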
@@ -174,7 +241,6 @@ def _torch_context_mha(
     sinks: Optional[torch.Tensor] = None,
 ) -> None:
     """Context attention (multiple tokens, potentially multiple sequences) using existing torch functions."""
-    # Update KV cache first using existing function
     _update_kv_cache(k, v, k_cache, v_cache, seq_len, input_pos, slot_idx, seq_start)

     # Compute attention for each sequence
@@ -293,9 +359,85 @@ def _torch_context_mha(
         out.copy_(torch.cat(attn_outputs, dim=0))


-@torch.library.custom_op(
-    "auto_deploy::torch_cached_attention_with_cache", mutates_args=("k_cache", "v_cache")
-)
+def _torch_context_mha_readonly(
+    q: torch.Tensor,
+    input_pos: torch.Tensor,
+    slot_idx: torch.Tensor,
+    k_cache: torch.Tensor,
+    v_cache: torch.Tensor,
+    seq_len: torch.Tensor,
+    seq_start: torch.Tensor,
+    scale: float,
+    out: torch.Tensor,
+    logit_cap: Optional[float] = None,
+    sliding_window_size: Optional[int] = None,
+    sinks: Optional[torch.Tensor] = None,
+) -> None:
+    """Context attention using an existing KV cache without writing current-layer K/V."""
+    attn_outputs = []
+    for idx in range(seq_len.shape[0]):
+        seq_len_i = seq_len[idx].item()
+        input_pos_i = input_pos[idx].item()
+        slot_idx_i = slot_idx[idx].item()
+        seq_start_i = seq_start[idx].item()
+
+        if seq_len_i == 0:
+            continue
+
+        q_seq = q[seq_start_i : seq_start_i + seq_len_i]
+        kv_seq_len = input_pos_i + seq_len_i
+        k_seq = k_cache[slot_idx_i, :kv_seq_len]
+        v_seq = v_cache[slot_idx_i, :kv_seq_len]
+
+        n_heads = q_seq.shape[1]
+        n_kv_heads = k_seq.shape[1]
+
+        q_seq_t = q_seq.transpose(0, 1).unsqueeze(0)
+        k_seq_t = k_seq.transpose(0, 1).unsqueeze(0)
+        v_seq_t = v_seq.transpose(0, 1).unsqueeze(0)
+
+        if n_heads != n_kv_heads:
+            n_rep = n_heads // n_kv_heads
+            k_seq_t = repeat_kv(k_seq_t, n_rep)
+            v_seq_t = repeat_kv(v_seq_t, n_rep)
+
+        attn_scores = torch.matmul(q_seq_t, k_seq_t.transpose(-2, -1)) * scale
+
+        causal_mask = torch.triu(
+            torch.ones(seq_len_i, kv_seq_len, device=q.device, dtype=torch.bool),
+            diagonal=1 + input_pos_i,
+        )
+        attn_scores.masked_fill_(causal_mask.unsqueeze(0).unsqueeze(0), float("-inf"))
+
+        if sliding_window_size is not None and sliding_window_size > 0:
+            query_positions = torch.arange(input_pos_i, input_pos_i + seq_len_i, device=q.device)
+            key_positions = torch.arange(kv_seq_len, device=q.device)
+            pos_diff = query_positions.unsqueeze(1) - key_positions.unsqueeze(0)
+            sliding_window_mask = (pos_diff < 0) | (pos_diff >= sliding_window_size)
+            attn_scores.masked_fill_(sliding_window_mask.unsqueeze(0).unsqueeze(0), float("-inf"))
+
+        attn_scores = _apply_logit_softcapping(attn_scores, logit_cap)
+
+        if sinks is not None:
+            new_sinks = sinks.reshape(1, -1, 1, 1).expand(1, n_heads, seq_len_i, 1)
+            attn_weights = torch.cat([attn_scores, new_sinks], dim=-1)
+            attn_weights = torch.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
+            attn_out = torch.matmul(attn_weights[..., : -new_sinks.size(-1)], v_seq_t)
+        else:
+            attn_weights = torch.softmax(attn_scores, dim=-1, dtype=torch.float32).to(q.dtype)
+            attn_out = torch.matmul(attn_weights, v_seq_t)
+
+        attn_outputs.append(attn_out[0].transpose(0, 1))
+
+    if len(attn_outputs) == 0:
+        out.zero_()
+    elif len(attn_outputs) == 1:
+        out.copy_(attn_outputs[0])
+    else:
+        out.copy_(torch.cat(attn_outputs, dim=0))
+
+
+@torch.library.custom_op("auto_deploy::torch_cached_attention_with_cache", mutates_args=())
 def torch_backend_mha_with_cache(
     # Q, K, V
     q: torch.Tensor,
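_torch_context_mha_readonly builds the causal and sliding-window masks per sequence from input_pos_i, seq_len_i, and sliding_window_size. A tiny worked example of the same two masks, with sizes chosen purely for illustration (3 new tokens appended after 2 cached ones, window of 4):

    import torch

    seq_len_i, input_pos_i, sliding_window_size = 3, 2, 4
    kv_seq_len = input_pos_i + seq_len_i  # 5 keys visible in the cache

    causal_mask = torch.triu(
        torch.ones(seq_len_i, kv_seq_len, dtype=torch.bool), diagonal=1 + input_pos_i
    )
    print(causal_mask.int())  # True (=1) marks future keys that get -inf
    # tensor([[0, 0, 0, 1, 1],
    #         [0, 0, 0, 0, 1],
    #         [0, 0, 0, 0, 0]], dtype=torch.int32)

    q_pos = torch.arange(input_pos_i, input_pos_i + seq_len_i)
    k_pos = torch.arange(kv_seq_len)
    pos_diff = q_pos[:, None] - k_pos[None, :]
    window_mask = (pos_diff < 0) | (pos_diff >= sliding_window_size)
    print(window_mask.int())  # also masks keys 4+ positions behind the query
    # tensor([[0, 0, 0, 1, 1],
    #         [0, 0, 0, 0, 1],
    #         [1, 0, 0, 0, 0]], dtype=torch.int32)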
@@ -320,6 +462,7 @@ def torch_backend_mha_with_cache(
     sinks: Optional[torch.Tensor] = None,
     sliding_window_size: Optional[int] = None,
     logit_cap: Optional[float] = None,
+    read_cache_only: bool = False,
     out: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     """Torch backend MHA with cache that takes q, k, v in BSND layout."""
@@ -359,12 +502,15 @@ def torch_backend_mha_with_cache(
     y = q.new_empty(*bs_view, num_heads, v_head_dim).contiguous()

     # Compute attention
+    if not read_cache_only:
+        if s == 1:
+            _write_generate_kv_cache(k, v, k_cache, v_cache, slot_idx, input_pos)
+        else:
+            _update_kv_cache(k, v, k_cache, v_cache, seq_len, input_pos, slot_idx, seq_start)
+
     if s == 1:
-        # Generate-only phase
-        _torch_generate_mha(
+        _torch_generate_mha_readonly(
             q,
-            k,
-            v,
             k_cache,
             v_cache,
             slot_idx,
@@ -376,11 +522,8 @@ def torch_backend_mha_with_cache(
             sinks,
         )
     else:
-        # Context phase
-        _torch_context_mha(
+        _torch_context_mha_readonly(
             q,
-            k,
-            v,
             input_pos,
             slot_idx,
             k_cache,
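With read_cache_only in place, the op optionally writes this layer's K/V into the cache and then always runs the read-only attention path against the cache. A purely illustrative restatement of that gating (the strings name the helpers above; nothing here is part of the library API):

    from typing import List

    def plan_attention_calls(read_cache_only: bool, is_decode: bool) -> List[str]:
        calls = []
        if not read_cache_only:  # layers that own their KV still write it
            calls.append("_write_generate_kv_cache" if is_decode else "_update_kv_cache")
        # attention itself now always reads from the cache
        calls.append("_torch_generate_mha_readonly" if is_decode else "_torch_context_mha_readonly")
        return calls

    print(plan_attention_calls(read_cache_only=False, is_decode=True))
    # ['_write_generate_kv_cache', '_torch_generate_mha_readonly']
    print(plan_attention_calls(read_cache_only=True, is_decode=False))
    # ['_torch_context_mha_readonly']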
@@ -437,6 +580,7 @@ def torch_backend_mha_with_cache_fake(
     sinks: Optional[torch.Tensor] = None,
     sliding_window_size: Optional[int] = None,
     logit_cap: Optional[float] = None,
+    read_cache_only: bool = False,
     out: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
@@ -464,6 +608,10 @@ def get_source_attention_op(cls) -> OpOverloadPacket:
     def get_cached_attention_op(cls) -> MHACallable:
         return torch.ops.auto_deploy.torch_cached_attention_with_cache.default

+    @classmethod
+    def supports_shared_kv(cls) -> bool:
+        return True
+
     @classmethod
     def get_standard_metadata_args(cls) -> List[str]:
         return ["batch_info_host", "seq_len", "input_pos", "slot_idx", "cu_seqlen"]
@@ -537,4 +685,5 @@ def get_constants(cls, source_attn_node: Node) -> List[Constant]:
             sinks,  # sinks parameter
             sliding_window,  # sliding window parameter
             logit_cap,  # logit cap parameter
+            cls.get_shared_kv_source_layer_idx(source_attn_node) is not None,  # read_cache_only
         ]

tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_paged_attention.py

Lines changed: 4 additions & 9 deletions
@@ -1359,27 +1359,22 @@ def get_dynamic_inputs(cls, source_attn_node: Node) -> List[Optional[Node]]:

     @classmethod
     def get_constants(cls, source_attn_node: Node) -> List[Constant]:
-        layout = extract_op_args(source_attn_node, "layout")[0]
+        layout, scale, attn_mask, dropout_p, is_causal = extract_op_args(
+            source_attn_node, "layout", "scale", "attn_mask", "dropout_p", "is_causal"
+        )
+
         if layout != "bsnd":
             raise RuntimeError(
                 f"Expected torch_attention layout='bsnd' but got {layout!r} "
                 f"for node: {source_attn_node.format_node()}"
             )

-        attn_mask, dropout_p, is_causal = extract_op_args(
-            source_attn_node, "attn_mask", "dropout_p", "is_causal"
-        )
         if dropout_p != 0.0 or not is_causal:
             ad_logger.debug(
                 "Unsupported attention arguments for "
                 f"{source_attn_node=}: {attn_mask=}, {dropout_p=}, {is_causal=}"
             )

-        if len(source_attn_node.args) > 6:
-            scale = source_attn_node.args[6]
-        else:
-            scale = source_attn_node.kwargs.get("scale", None)
-
         if not (isinstance(scale, float) or scale is None):
             ad_logger.warning(f"Provided {scale=}, is not a float. Using default scale instead.")
             scale = None
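The get_constants cleanup folds three separate lookups, including the manual args[6] fallback for scale, into a single extract_op_args call that resolves all five names at once. A simplified stand-in for that pattern is sketched below; it only consults a kwargs dict plus a defaults table, whereas the real helper also resolves positional arguments from the op schema:

    from typing import Any, Dict, List

    def extract_op_args_sketch(kwargs: Dict[str, Any], defaults: Dict[str, Any], *names: str) -> List[Any]:
        """Return the requested argument values in order, falling back to defaults."""
        return [kwargs.get(name, defaults.get(name)) for name in names]

    node_kwargs = {"layout": "bsnd", "is_causal": True}
    defaults = {"scale": None, "attn_mask": None, "dropout_p": 0.0, "is_causal": False}
    layout, scale, attn_mask, dropout_p, is_causal = extract_op_args_sketch(
        node_kwargs, defaults, "layout", "scale", "attn_mask", "dropout_p", "is_causal"
    )
    print(layout, scale, dropout_p, is_causal)  # bsnd None 0.0 True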
