Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 27 additions & 3 deletions python/minisgl/attention/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,15 @@ def get_last_indices(self, bs: int) -> torch.Tensor: ...
class BaseAttnBackend(ABC):
@abstractmethod
def forward(
self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, layer_id: int, batch: Batch
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
layer_id: int,
batch: Batch,
*,
window_size: tuple[int, int] = (-1, -1),
softmax_scale: float | None = None,
) -> torch.Tensor: ...

@abstractmethod
Expand All @@ -44,10 +52,26 @@ def __init__(
self.decode_backend = decode_backend

# Dispatch one attention call to either the prefill or the decode backend.
# NOTE(review): this span is a rendered diff with indentation stripped; the one-line
# signature and the five-argument `backend.forward(...)` call look like the
# pre-change lines shown next to their replacements — confirm against the repository.
def forward(
self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, layer_id: int, batch: Batch
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
layer_id: int,
batch: Batch,
# Everything after `*` is keyword-only.
*,
window_size: tuple[int, int] = (-1, -1),
softmax_scale: float | None = None,
) -> torch.Tensor:
# Prefill batches use prefill_backend; every other batch uses decode_backend.
backend = self.prefill_backend if batch.is_prefill else self.decode_backend
return backend.forward(q, k, v, layer_id, batch)
# window_size and softmax_scale are passed through to the chosen backend unchanged.
return backend.forward(
q,
k,
v,
layer_id,
batch,
window_size=window_size,
softmax_scale=softmax_scale,
)

def prepare_metadata(self, batch: Batch) -> None:
backend = self.prefill_backend if batch.is_prefill else self.decode_backend
Expand Down
13 changes: 11 additions & 2 deletions python/minisgl/attention/fa.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,15 @@ def __init__(self, config: ModelConfig):
self.version = 4 if is_sm100_supported() else 3

def forward(
self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, layer_id: int, batch: Batch
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
layer_id: int,
batch: Batch,
*,
window_size: tuple[int, int] = (-1, -1),
softmax_scale: float | None = None,
) -> torch.Tensor:
metadata = batch.attn_metadata
assert isinstance(metadata, FAMetadata)
Expand All @@ -60,8 +68,9 @@ def forward(
cu_seqlens_q=metadata.cu_seqlens_q,
cu_seqlens_k=metadata.cu_seqlens_k,
max_seqlen_q=metadata.max_seqlen_q,
softmax_scale=self.scale,
softmax_scale=self.scale if softmax_scale is None else softmax_scale,
version=self.version,
window_size=window_size,
)

def prepare_metadata(self, batch: Batch) -> None:
Expand Down
180 changes: 149 additions & 31 deletions python/minisgl/attention/fi.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,12 @@ def indices(self) -> torch.Tensor:
return self.page_table


# Bundle of the CUDA-graph decode wrappers created for one captured batch size;
# instances are stored in `graph_wrappers`, keyed by batch size.
@dataclass
class _FIGraphWrappers:
# Wrapper that is always created during capture.
wrapper: CUDAGraphBatchDecodeWithPagedKVCacheWrapper
# Second wrapper, only created when config.has_sliding_attention is set; None otherwise.
sliding_wrapper: CUDAGraphBatchDecodeWithPagedKVCacheWrapper | None = None


@dataclass
class FIMetadata(BaseAttnMetadata):
# fmt: off
Expand All @@ -58,7 +64,17 @@ class FIMetadata(BaseAttnMetadata):
pos_encoding_mode: str
seq_lens_cpu: torch.Tensor # on cpu
dtype: torch.dtype
wrapper: BatchPrefillWithPagedKVCacheWrapper | BatchDecodeWithPagedKVCacheWrapper
wrapper: (
BatchPrefillWithPagedKVCacheWrapper
| BatchDecodeWithPagedKVCacheWrapper
| CUDAGraphBatchDecodeWithPagedKVCacheWrapper
)
sliding_wrapper: (
BatchPrefillWithPagedKVCacheWrapper
| BatchDecodeWithPagedKVCacheWrapper
| CUDAGraphBatchDecodeWithPagedKVCacheWrapper
| None
) = None
initialized: bool = False
# fmt: on

Expand Down Expand Up @@ -101,10 +117,29 @@ def __init__(self, config: ModelConfig) -> None:
kv_layout="NHD",
backend="fa2", # flashinfer fa3 is slow, use fa2 instead
)
self.sliding_prefill_wrapper = None
self.sliding_decode_wrappers = None
if self.config.has_sliding_attention:
self.sliding_prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
self.float_workspace_buffer,
kv_layout="NHD",
backend="fa2",
)
self.sliding_decode_wrappers = BatchDecodeWithPagedKVCacheWrapper(
self.float_workspace_buffer,
use_tensor_cores=self.use_tensor_cores,
kv_layout="NHD",
backend="fa2",
)

# NOTE: some hack to reuse the int_workspace_buffer
self.int_workspace_buffer = self.prefill_wrapper._int_workspace_buffer
self.decode_wrappers._int_workspace_buffer = self.int_workspace_buffer
if self.sliding_prefill_wrapper is not None:
assert self.sliding_decode_wrappers is not None
self.sliding_decode_wrappers._int_workspace_buffer = (
self.sliding_prefill_wrapper._int_workspace_buffer
)

# initialize some data members
tp_size = get_tp_info().size
Expand All @@ -115,23 +150,37 @@ def __init__(self, config: ModelConfig) -> None:
# for cuda graph
self.capture_bs: List[int] = []
self.max_graph_bs = 0
self.graph_wrappers: Dict[int, CUDAGraphBatchDecodeWithPagedKVCacheWrapper] = {}
self.graph_wrappers: Dict[int, _FIGraphWrappers] = {}
self.graph_softmax_scale = (
self.config.query_pre_attn_scalar**-0.5
if self.config.query_pre_attn_scalar is not None
else None
)
self.capture: FICaptureData | None = None
self.last_event = torch.cuda.Event()
self.last_event.record()

def _initialize_metadata_once(self, metadata: FIMetadata) -> None:
if metadata.initialized:
return

from flashinfer import BatchDecodeWithPagedKVCacheWrapper
def _plan_wrapper(
self,
metadata: FIMetadata,
wrapper: (
BatchPrefillWithPagedKVCacheWrapper
| BatchDecodeWithPagedKVCacheWrapper
| CUDAGraphBatchDecodeWithPagedKVCacheWrapper
),
softmax_scale: float | None,
window_left: int,
) -> None:
from flashinfer import (
BatchDecodeWithPagedKVCacheWrapper,
CUDAGraphBatchDecodeWithPagedKVCacheWrapper,
)

metadata.initialized = True
# FlashInfer planning reuses a pinned host staging buffer and launches an
# async H2D copy. Wait here before the next plan mutates that host buffer.
self.last_event.synchronize()
if isinstance(metadata.wrapper, BatchDecodeWithPagedKVCacheWrapper):
metadata.wrapper.plan(
if isinstance(
wrapper,
(BatchDecodeWithPagedKVCacheWrapper, CUDAGraphBatchDecodeWithPagedKVCacheWrapper),
):
wrapper.plan(
indptr=metadata.cu_seqlens_k_cpu,
indices=metadata.indices,
last_page_len=metadata.last_page_len_cpu,
Expand All @@ -144,10 +193,12 @@ def _initialize_metadata_once(self, metadata: FIMetadata) -> None:
data_type=metadata.dtype,
q_data_type=metadata.dtype,
kv_data_type=metadata.dtype,
sm_scale=softmax_scale,
window_left=window_left,
non_blocking=True,
)
else:
metadata.wrapper.plan(
wrapper.plan(
qo_indptr=metadata.cu_seqlens_q_cpu,
paged_kv_indptr=metadata.cu_seqlens_k_cpu,
paged_kv_indices=metadata.indices,
Expand All @@ -160,9 +211,31 @@ def _initialize_metadata_once(self, metadata: FIMetadata) -> None:
seq_lens=metadata.seq_lens_cpu,
q_data_type=metadata.dtype,
kv_data_type=metadata.dtype,
sm_scale=softmax_scale,
window_left=window_left,
non_blocking=True,
causal=True,
)

# Plan the FlashInfer wrapper(s) attached to `metadata`, at most once per metadata object.
# `softmax_scale` is handed to `_plan_wrapper` for the main and the sliding wrapper alike.
# NOTE(review): because of the early return below, a later call on the same metadata
# with a different `softmax_scale` is silently ignored — confirm that all callers pass
# the same scale for a given batch.
def _initialize_metadata_once(
self, metadata: FIMetadata, softmax_scale: float | None = None
) -> None:
# Already planned: nothing to do.
if metadata.initialized:
return

# FlashInfer planning launches async H2D copies from host metadata. Wait
# before the next batch can reuse those host-side staging tensors.
self.last_event.synchronize()
# Main wrapper is planned with window_left=-1; forward() uses it for the default
# window_size of (-1, -1).
self._plan_wrapper(metadata, metadata.wrapper, softmax_scale, window_left=-1)
if metadata.sliding_wrapper is not None:
assert self.config.sliding_window is not None
# Sliding wrapper is planned with window_left = sliding_window - 1, the same left
# bound that forward() requires in window_size when it selects this wrapper.
self._plan_wrapper(
metadata,
metadata.sliding_wrapper,
softmax_scale,
window_left=self.config.sliding_window - 1,
)
# Flag the metadata only after every plan call has returned, then record the event
# that the next invocation synchronizes on.
metadata.initialized = True
self.last_event.record()

def _get_ones_cpu(self, bs: int) -> torch.Tensor:
Expand All @@ -174,18 +247,34 @@ def _get_ones_cpu(self, bs: int) -> torch.Tensor:
return self.cached_ones_cpu[:bs]

# Store this step's k/v into the KV cache and run attention for one layer through the
# FlashInfer wrapper selected by `window_size`.
# NOTE(review): this span is a rendered diff with indentation stripped; the one-line
# signature, the one-argument `_initialize_metadata_once(metadata)` call and
# `return metadata.wrapper.run(...)` look like the pre-change lines shown next to
# their replacements — confirm against the repository.
def forward(
self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, layer_id: int, batch: Batch
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
layer_id: int,
batch: Batch,
# Everything after `*` is keyword-only.
*,
window_size: tuple[int, int] = (-1, -1),
softmax_scale: float | None = None,
) -> torch.Tensor:
# Views the cache as (-1, 1, shape[2], shape[3]), i.e. one entry per page.
def _flatten_cache(cache: torch.Tensor) -> torch.Tensor: # treat page = 1
return cache.view(-1, 1, cache.shape[2], cache.shape[3])

metadata = batch.attn_metadata
assert isinstance(metadata, FIMetadata)
self._initialize_metadata_once(metadata)
# The default (-1, -1) selects the regular wrapper. The only other accepted value is
# (config.sliding_window - 1, 0), which selects the sliding wrapper.
# NOTE(review): these checks are `assert`s, so they are skipped under `python -O`.
if window_size == (-1, -1):
wrapper = metadata.wrapper
else:
assert self.config.sliding_window is not None
assert window_size == (self.config.sliding_window - 1, 0)
assert metadata.sliding_wrapper is not None
wrapper = metadata.sliding_wrapper

# Plans the wrappers on the first call for this metadata; later calls return early.
self._initialize_metadata_once(metadata, softmax_scale)
# Write the new k/v at batch.out_loc before attention reads the cache for this layer.
self.kvcache.store_kv(k, v, batch.out_loc, layer_id)
kv_cache = (self.kvcache.k_cache(layer_id), self.kvcache.v_cache(layer_id))
kv_cache = (_flatten_cache(kv_cache[0]), _flatten_cache(kv_cache[1]))
return metadata.wrapper.run(q=q, paged_kv_cache=kv_cache)
return wrapper.run(q=q, paged_kv_cache=kv_cache)

def prepare_metadata(self, batch: Batch) -> None:
reqs = batch.padded_reqs
Expand All @@ -208,6 +297,14 @@ def prepare_metadata(self, batch: Batch) -> None:
cu_seqlens_q_cpu = torch.tensor([0] + seqlens_q, **CPU_KWARGS).cumsum_(dim=0)

page_table = get_global_ctx().page_table
wrapper = self.decode_wrappers if batch.is_decode else self.prefill_wrapper
sliding_wrapper = None
if self.config.has_sliding_attention:
sliding_wrapper = (
self.sliding_decode_wrappers if batch.is_decode else self.sliding_prefill_wrapper
)
assert sliding_wrapper is not None

batch.attn_metadata = FIMetadata(
cu_seqlens_q_cpu=cu_seqlens_q_cpu,
cu_seqlens_k_cpu=cu_seqlens_k_cpu,
Expand All @@ -221,7 +318,8 @@ def prepare_metadata(self, batch: Batch) -> None:
pos_encoding_mode="NONE",
seq_lens_cpu=seq_len_cpu,
dtype=self.kvcache.dtype,
wrapper=self.decode_wrappers if batch.is_decode else self.prefill_wrapper,
wrapper=wrapper,
sliding_wrapper=sliding_wrapper,
)

def init_capture_graph(self, max_seq_len: int, bs_list: List[int]) -> None:
Expand All @@ -247,25 +345,45 @@ def prepare_for_capture(self, batch: Batch) -> None:
bs = batch.size
assert bs in self.capture_bs and bs not in self.graph_wrappers and self.capture
capture = self.capture
self.graph_wrappers[bs] = CUDAGraphBatchDecodeWithPagedKVCacheWrapper(
self.float_workspace_buffer,
kv_layout="NHD",
use_tensor_cores=self.use_tensor_cores,
indptr_buffer=capture.cu_seqlens_k[: bs + 1],
indices_buffer=capture.indices,
last_page_len_buffer=capture.one_tensor[:bs],
graph_wrappers = _FIGraphWrappers(
wrapper=CUDAGraphBatchDecodeWithPagedKVCacheWrapper(
self.float_workspace_buffer,
kv_layout="NHD",
use_tensor_cores=self.use_tensor_cores,
indptr_buffer=capture.cu_seqlens_k[: bs + 1],
indices_buffer=capture.indices,
last_page_len_buffer=capture.one_tensor[:bs],
)
)
self.graph_wrappers[bs]._backend = "fa2"
self.graph_wrappers[bs]._int_workspace_buffer = self.int_workspace_buffer
graph_wrappers.wrapper._backend = "fa2"
graph_wrappers.wrapper._int_workspace_buffer = self.int_workspace_buffer
if self.config.has_sliding_attention:
assert self.sliding_prefill_wrapper is not None
graph_wrappers.sliding_wrapper = CUDAGraphBatchDecodeWithPagedKVCacheWrapper(
self.float_workspace_buffer,
kv_layout="NHD",
use_tensor_cores=self.use_tensor_cores,
indptr_buffer=capture.cu_seqlens_k[: bs + 1],
indices_buffer=capture.indices,
last_page_len_buffer=capture.one_tensor[:bs],
)
graph_wrappers.sliding_wrapper._backend = "fa2"
graph_wrappers.sliding_wrapper._int_workspace_buffer = (
self.sliding_prefill_wrapper._int_workspace_buffer
)
self.graph_wrappers[bs] = graph_wrappers
self.prepare_metadata(batch)
metadata = batch.attn_metadata
assert isinstance(metadata, FIMetadata)
metadata.wrapper = self.graph_wrappers[bs]
self._initialize_metadata_once(metadata)
metadata.wrapper = graph_wrappers.wrapper
metadata.sliding_wrapper = graph_wrappers.sliding_wrapper
self._initialize_metadata_once(metadata, self.graph_softmax_scale)

# Point the batch's metadata at the CUDA-graph wrappers stored for its padded batch
# size and plan them before the graph is replayed.
# NOTE(review): this span is a rendered diff with indentation stripped; the direct
# `metadata.wrapper = self.graph_wrappers[bs]` assignment and the one-argument
# `_initialize_metadata_once(metadata)` call look like the pre-change lines shown next
# to their replacements — confirm against the repository.
def prepare_for_replay(self, batch: Batch) -> None:
metadata, bs = batch.attn_metadata, batch.padded_size
# The metadata must still be unplanned: `_initialize_metadata_once` returns early on
# initialized metadata, so the wrappers swapped in below would never be planned.
assert isinstance(metadata, FIMetadata) and not metadata.initialized
assert self.capture is not None and bs in self.capture_bs
metadata.wrapper = self.graph_wrappers[bs]
self._initialize_metadata_once(metadata)
graph_wrappers = self.graph_wrappers[bs]
metadata.wrapper = graph_wrappers.wrapper
# May be None when no sliding wrapper was created for this batch size.
metadata.sliding_wrapper = graph_wrappers.sliding_wrapper
# Plans with graph_softmax_scale, the same value passed in prepare_for_capture.
self._initialize_metadata_once(metadata, self.graph_softmax_scale)
13 changes: 12 additions & 1 deletion python/minisgl/attention/trtllm.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,22 @@ def __init__(self, config: ModelConfig):
)

def forward(
self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, layer_id: int, batch: Batch
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
layer_id: int,
batch: Batch,
*,
window_size: tuple[int, int] = (-1, -1),
softmax_scale: float | None = None,
) -> torch.Tensor:
from flashinfer.decode import trtllm_batch_decode_with_kv_cache
from flashinfer.prefill import trtllm_batch_context_with_kv_cache

if window_size != (-1, -1) or softmax_scale is not None:
raise NotImplementedError

metadata = batch.attn_metadata
assert isinstance(metadata, TRTLLMMetadata)
self.kvcache.store_kv(k, v, batch.out_loc, layer_id)
Expand Down
Loading