File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -117,15 +117,19 @@ def __init__(self, config: ModelConfig) -> None:
117117 self .max_graph_bs = 0
118118 self .graph_wrappers : Dict [int , CUDAGraphBatchDecodeWithPagedKVCacheWrapper ] = {}
119119 self .capture : FICaptureData | None = None
120+ self .last_event = torch .cuda .Event ()
121+ self .last_event .record ()
120122
121- @staticmethod
122- def _initialize_metadata_once (metadata : FIMetadata ) -> None :
123+ def _initialize_metadata_once (self , metadata : FIMetadata ) -> None :
123124 if metadata .initialized :
124125 return
125126
126127 from flashinfer import BatchDecodeWithPagedKVCacheWrapper
127128
128129 metadata .initialized = True
130+ # FlashInfer planning reuses a pinned host staging buffer and launches an
131+ # async H2D copy. Wait here before the next plan mutates that host buffer.
132+ self .last_event .synchronize ()
129133 if isinstance (metadata .wrapper , BatchDecodeWithPagedKVCacheWrapper ):
130134 metadata .wrapper .plan (
131135 indptr = metadata .cu_seqlens_k_cpu ,
@@ -159,6 +163,7 @@ def _initialize_metadata_once(metadata: FIMetadata) -> None:
159163 non_blocking = True ,
160164 causal = True ,
161165 )
166+ self .last_event .record ()
162167
163168 def _get_ones_cpu (self , bs : int ) -> torch .Tensor :
164169 if bs <= len (self .cached_ones_cpu ):
Original file line number Diff line number Diff line change @@ -67,7 +67,6 @@ class EnvClassSingleton:
6767 # backend runtime
6868 FLASHINFER_USE_TENSOR_CORES = EnvOption ()
6969 DISABLE_OVERLAP_SCHEDULING = EnvBool (False )
70- OVERLAP_EXTRA_SYNC = EnvBool (False )
7170 PYNCCL_MAX_BUFFER_SIZE = EnvMem (1024 ** 3 )
7271
7372 def __new__ (cls ):
Original file line number Diff line number Diff line change @@ -227,8 +227,6 @@ def _schedule_next_batch(self) -> ForwardInput | None:
227227 def _forward (self , forward_input : ForwardInput ) -> ForwardOutput :
228228 batch , sample_args , input_mapping , output_mapping = forward_input
229229 batch .input_ids = self .token_pool [input_mapping ]
230- if ENV .OVERLAP_EXTRA_SYNC : # NOTE: https://github.com/sgl-project/mini-sglang/issues/58
231- self .stream .synchronize ()
232230 forward_output = self .engine .forward_batch (batch , sample_args )
233231 self .token_pool [output_mapping ] = forward_output .next_tokens_gpu
234232 self .decode_manager .filter_reqs (forward_input .batch .reqs )
You can’t perform that action at this time.
0 commit comments