@@ -124,6 +124,14 @@ def error(self, msg):
124124 self .messages .append (("error" , msg ))
125125
126126
127+ class _QueueRecorder :
128+ def __init__ (self ):
129+ self .items = []
130+
131+ def put (self , item ):
132+ self .items .append (item )
133+
134+
127135class _DummySignalValue :
128136 def __init__ (self , sequence ):
129137 self .sequence = list (sequence )
@@ -390,6 +398,111 @@ def test_cache_messager_v1_add_cache_task_thread(monkeypatch):
390398 assert messager .cache_info ["req-2" ]["status" ] == "init"
391399
392400
def test_cache_messager_v1_recovers_pending_layer0_signal(monkeypatch):
    """Check that ``_add_cache_task_thread`` forwards a matching buffered layer-0 signal.

    Signals for ids 3 and 4 are placed in ``pending_layer0_signals`` up front,
    while the only queued cache info has ``current_id`` 3. After the thread
    body has run, the entry for id 3 must sit on
    ``cache_prefilled_engine_ids_queue`` and the entry for id 4 must still be
    buffered.
    """

    def _pending_request():
        # Build a fresh dict (and fresh lists) per call so the queued payload
        # and the pre-registered cache_info entry never share objects.
        return {
            "request_id": "req-pending",
            "src_block_ids": [0, 1],
            "dest_block_ids": [2],
            "current_id": 3,
            "need_prefill_tokens": 128,
            "transfer_protocol": "rdma",
        }

    dummy_queue = _DummyEngineWorkerQueue(cache_info_sequence=[[_pending_request()]])

    def _queue_factory(*args, **kwargs):
        # Every construction attempt yields the same pre-seeded dummy queue.
        return dummy_queue

    monkeypatch.setattr(cache_messager, "EngineWorkerQueue", _queue_factory)
    monkeypatch.setattr(cache_messager, "RDMACommManager", _DummyRDMACommManager)
    monkeypatch.setattr(cache_messager, "logger", _DummyLogger(), raising=False)

    messager_kwargs = {
        "splitwise_role": "mixed",
        "transfer_protocol": "rdma",
        "pod_ip": "0.0.0.0",
        "engine_worker_queue_port": 9000,
        "local_data_parallel_id": 0,
        "gpu_cache_kvs": _build_cache_kvs(dtype="float16", include_value_cache=True, num_layers=1),
        "rank": 0,
        "nranks": 1,
        "num_layers": 1,
        "gpu_id": 0,
        "block_size": 64,
        "rdma_port": "2222",
    }
    messager = cache_messager.CacheMessagerV1(**messager_kwargs)
    messager.cache_prefilled_engine_ids_queue = _QueueRecorder()
    messager.cache_info["req-pending"] = _pending_request()
    for engine_id in (3, 4):
        messager.pending_layer0_signals[engine_id] = (engine_id, 64)

    # The thread body is expected to terminate via SystemExit
    # (presumably raised by the dummy queue once its sequence is used up — confirm).
    with pytest.raises(SystemExit):
        messager._add_cache_task_thread()

    assert messager.pending_layer0_signals == {4: (4, 64)}
    assert messager.cache_prefilled_engine_ids_queue.items == [[(3, 64)]]
452+
453+
def test_cache_messager_v1_drops_invalid_pending_layer0_signal(monkeypatch):
    """Check that ``_add_cache_task_thread`` discards a buffered signal it cannot use.

    A signal ``(3, 256)`` is buffered for id 3 while the queued cache info for
    that id has ``need_prefill_tokens`` 128. After the thread body has run,
    the buffer must be empty and nothing may have been put on
    ``cache_prefilled_engine_ids_queue``.
    NOTE(review): presumably the signal counts as invalid because 256 exceeds
    the 128 tokens to prefill — confirm against the implementation.
    """

    def _pending_request():
        # Build a fresh dict (and fresh lists) per call so the queued payload
        # and the pre-registered cache_info entry never share objects.
        return {
            "request_id": "req-pending",
            "src_block_ids": [0, 1],
            "dest_block_ids": [2],
            "current_id": 3,
            "need_prefill_tokens": 128,
            "transfer_protocol": "rdma",
        }

    dummy_queue = _DummyEngineWorkerQueue(cache_info_sequence=[[_pending_request()]])

    def _queue_factory(*args, **kwargs):
        # Every construction attempt yields the same pre-seeded dummy queue.
        return dummy_queue

    monkeypatch.setattr(cache_messager, "EngineWorkerQueue", _queue_factory)
    monkeypatch.setattr(cache_messager, "RDMACommManager", _DummyRDMACommManager)
    monkeypatch.setattr(cache_messager, "logger", _DummyLogger(), raising=False)

    messager_kwargs = {
        "splitwise_role": "mixed",
        "transfer_protocol": "rdma",
        "pod_ip": "0.0.0.0",
        "engine_worker_queue_port": 9000,
        "local_data_parallel_id": 0,
        "gpu_cache_kvs": _build_cache_kvs(dtype="float16", include_value_cache=True, num_layers=1),
        "rank": 0,
        "nranks": 1,
        "num_layers": 1,
        "gpu_id": 0,
        "block_size": 64,
        "rdma_port": "2222",
    }
    messager = cache_messager.CacheMessagerV1(**messager_kwargs)
    messager.cache_prefilled_engine_ids_queue = _QueueRecorder()
    messager.cache_info["req-pending"] = _pending_request()
    messager.pending_layer0_signals[3] = (3, 256)

    # The thread body is expected to terminate via SystemExit
    # (presumably raised by the dummy queue once its sequence is used up — confirm).
    with pytest.raises(SystemExit):
        messager._add_cache_task_thread()

    assert messager.pending_layer0_signals == {}
    assert messager.cache_prefilled_engine_ids_queue.items == []
504+
505+
393506def test_cache_messager_v1_prefill_layerwise_send_cache_thread (monkeypatch ):
394507 class _OneShotQueue :
395508 def __init__ (self ):
@@ -435,10 +548,12 @@ def get(self):
435548 }
436549 messager .engine_cache_tasks [0 ] = {"prefilled_layer_idx" : 1 , "prefilled_token_num" : 64 }
437550 messager .cache_info ["req-3" ] = messager .idx_cache_task_dict [0 ]
551+ messager .pending_layer0_signals = {0 : (0 , 64 ), 1 : (1 , 64 )}
438552 with pytest .raises (SystemExit ):
439553 messager .prefill_layerwise_send_cache_thread ()
440554 assert dummy_queue .finished_req_payloads
441555 assert dummy_queue .finished_req_payloads [0 ][0 ][0 ] == "req-3"
556+ assert messager .pending_layer0_signals == {1 : (1 , 64 )}
442557
443558
444559def test_cache_messager_v1_handle_connect_task (monkeypatch ):
@@ -562,13 +677,6 @@ def test_cache_messager_v1_consume_signals(monkeypatch):
562677 monkeypatch .setattr (cache_messager , "RDMACommManager" , _DummyRDMACommManager )
563678 monkeypatch .setattr (cache_messager , "logger" , _DummyLogger (), raising = False )
564679
565- class _QueueRecorder :
566- def __init__ (self ):
567- self .items = []
568-
569- def put (self , item ):
570- self .items .append (item )
571-
572680 counter = {"calls" : 0 }
573681
574682 def _fake_get_output_kv_signal (kv_signal_data , rank_id , wait_flag ):
@@ -600,12 +708,57 @@ def _fake_get_output_kv_signal(kv_signal_data, rank_id, wait_flag):
600708 rdma_port = "2222" ,
601709 )
602710 messager .cache_info ["req-4" ] = {"request_id" : "req-4" }
711+ messager .idx_cache_task_dict [2 ] = {"request_id" : "req-4" , "current_id" : 2 }
603712 messager .cache_prefilled_engine_ids_queue = _QueueRecorder ()
604713 with pytest .raises (SystemExit ):
605714 messager .consume_signals ()
606715 assert messager .cache_prefilled_engine_ids_queue .items == [[(2 , 9 )]]
607716
608717
def test_cache_messager_v1_consume_signals_buffers_early_layer0(monkeypatch):
    """Check that ``consume_signals`` buffers signals when no cache task is registered.

    Two signals for engine index 5 are fed through a fake
    ``get_output_kv_signal``. No cache task is registered on the messager, so
    the test expects nothing on ``cache_prefilled_engine_ids_queue`` and a
    single buffered entry ``{5: (5, 36)}``.
    NOTE(review): 36 equals 17 + 19 from the second signal, so presumably the
    later signal overwrites the earlier one — confirm against the implementation.
    """
    monkeypatch.setattr(cache_messager, "EngineWorkerQueue", _DummyEngineWorkerQueue)
    monkeypatch.setattr(cache_messager, "RDMACommManager", _DummyRDMACommManager)
    monkeypatch.setattr(cache_messager, "logger", _DummyLogger(), raising=False)

    # Each entry: (engine index, chunk token offset, current sequence length).
    signals = [(5, 7, 9), (5, 17, 19)]

    def _fake_get_output_kv_signal(kv_signal_data, rank_id, wait_flag):
        # Once every scripted signal has been delivered, stop the consumer loop.
        if not signals:
            raise SystemExit
        engine_idx, chunk_token_offset, current_seq_len = signals.pop(0)
        data = np.full(kv_signal_data.shape, -1, dtype="int32")
        # Slots 0-4 carry the signal fields; all remaining slots stay at -1.
        slot_values = (1, 0, engine_idx, chunk_token_offset, current_seq_len)
        for slot, value in enumerate(slot_values):
            data[slot] = value
        kv_signal_data.set_value(data)

    monkeypatch.setattr(cache_messager, "get_output_kv_signal", _fake_get_output_kv_signal)

    messager_kwargs = {
        "splitwise_role": "mixed",
        "transfer_protocol": "rdma",
        "pod_ip": "0.0.0.0",
        "engine_worker_queue_port": 9000,
        "local_data_parallel_id": 0,
        "gpu_cache_kvs": _build_cache_kvs(dtype="float16", include_value_cache=False, num_layers=1),
        "rank": 0,
        "nranks": 1,
        "num_layers": 1,
        "gpu_id": 0,
        "block_size": 64,
        "rdma_port": "2222",
    }
    messager = cache_messager.CacheMessagerV1(**messager_kwargs)
    messager.cache_prefilled_engine_ids_queue = _QueueRecorder()

    with pytest.raises(SystemExit):
        messager.consume_signals()

    assert messager.pending_layer0_signals == {5: (5, 36)}
    assert messager.cache_prefilled_engine_ids_queue.items == []
760+
761+
609762def test_main_initializes_cache_and_exits (monkeypatch ):
610763 monkeypatch .setattr (cache_messager , "set_device" , lambda device : None )
611764 monkeypatch .setattr (cache_messager , "set_data_ipc" , lambda tensor , name : None )
0 commit comments