Skip to content

Commit 20fcd7f

Browse files
[Fix] Fix race condition in flashinfer backend (#103)
1 parent db31896 commit 20fcd7f

3 files changed

Lines changed: 7 additions & 5 deletions

File tree

python/minisgl/attention/fi.py

Lines changed: 7 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -117,15 +117,19 @@ def __init__(self, config: ModelConfig) -> None:
117117
self.max_graph_bs = 0
118118
self.graph_wrappers: Dict[int, CUDAGraphBatchDecodeWithPagedKVCacheWrapper] = {}
119119
self.capture: FICaptureData | None = None
120+
self.last_event = torch.cuda.Event()
121+
self.last_event.record()
120122

121-
@staticmethod
122-
def _initialize_metadata_once(metadata: FIMetadata) -> None:
123+
def _initialize_metadata_once(self, metadata: FIMetadata) -> None:
123124
if metadata.initialized:
124125
return
125126

126127
from flashinfer import BatchDecodeWithPagedKVCacheWrapper
127128

128129
metadata.initialized = True
130+
# FlashInfer planning reuses a pinned host staging buffer and launches an
131+
# async H2D copy. Wait here before the next plan mutates that host buffer.
132+
self.last_event.synchronize()
129133
if isinstance(metadata.wrapper, BatchDecodeWithPagedKVCacheWrapper):
130134
metadata.wrapper.plan(
131135
indptr=metadata.cu_seqlens_k_cpu,
@@ -159,6 +163,7 @@ def _initialize_metadata_once(metadata: FIMetadata) -> None:
159163
non_blocking=True,
160164
causal=True,
161165
)
166+
self.last_event.record()
162167

163168
def _get_ones_cpu(self, bs: int) -> torch.Tensor:
164169
if bs <= len(self.cached_ones_cpu):

python/minisgl/env.py

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -67,7 +67,6 @@ class EnvClassSingleton:
6767
# backend runtime
6868
FLASHINFER_USE_TENSOR_CORES = EnvOption()
6969
DISABLE_OVERLAP_SCHEDULING = EnvBool(False)
70-
OVERLAP_EXTRA_SYNC = EnvBool(False)
7170
PYNCCL_MAX_BUFFER_SIZE = EnvMem(1024**3)
7271

7372
def __new__(cls):

python/minisgl/scheduler/scheduler.py

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -227,8 +227,6 @@ def _schedule_next_batch(self) -> ForwardInput | None:
227227
def _forward(self, forward_input: ForwardInput) -> ForwardOutput:
228228
batch, sample_args, input_mapping, output_mapping = forward_input
229229
batch.input_ids = self.token_pool[input_mapping]
230-
if ENV.OVERLAP_EXTRA_SYNC: # NOTE: https://github.com/sgl-project/mini-sglang/issues/58
231-
self.stream.synchronize()
232230
forward_output = self.engine.forward_batch(batch, sample_args)
233231
self.token_pool[output_mapping] = forward_output.next_tokens_gpu
234232
self.decode_manager.filter_reqs(forward_input.batch.reqs)

0 commit comments

Comments
 (0)