@@ -1643,6 +1643,9 @@ def __init__(self, *, aphrodite_config: AphroditeConfig, prefix: str = ""):
16431643 )
16441644 for _ in range (self .deepstack_num_level )
16451645 ]
1646+ # Tracks the valid token span currently stored in the buffer.
1647+ # Zero means there is no active deepstack payload to consume.
1648+ self .deepstack_input_embeds_num_tokens = 0
16461649
16471650 with self ._mark_language_model (aphrodite_config ):
16481651 self .language_model = Qwen3MoeLLMForCausalLM (
@@ -1661,6 +1664,13 @@ def _get_deepstack_input_embeds(
16611664 ) -> IntermediateTensors | None :
16621665 if not getattr (self , "deepstack_input_embeds" , None ):
16631666 return None # If vision tower is skipped
1667+ if getattr (self , "deepstack_input_embeds_num_tokens" , 0 ) == 0 :
1668+ return None
1669+ if num_tokens > self .deepstack_input_embeds_num_tokens :
1670+ raise ValueError (
1671+ "Requested more deepstack tokens than available in buffer: "
1672+ f"{ num_tokens = } > { self .deepstack_input_embeds_num_tokens = } "
1673+ )
16641674
16651675 # get deepstack_input_embeds from buffer, and clear the buffer
16661676 return IntermediateTensors (
@@ -1689,14 +1699,25 @@ def _set_deepstack_input_embeds(self, deepstack_input_embeds: torch.Tensor) -> N
16891699 for idx in range (self .deepstack_num_level ):
16901700 self .deepstack_input_embeds [idx ][:num_tokens ].copy_ (deepstack_input_embeds [idx ])
16911701
1702+ self .deepstack_input_embeds_num_tokens = num_tokens
1703+
16921704 def _clear_deepstack_input_embeds (self , num_tokens : int ) -> None :
# Zero the leading `num_tokens` entries of every per-level deepstack buffer
# and reset the tracked token count (`deepstack_input_embeds_num_tokens`).
#
# Args:
#     num_tokens: length of the leading slice `[:num_tokens]` that is zeroed
#         in each `self.deepstack_input_embeds[idx]`.
#
# Raises:
#     ValueError: if `num_tokens` is positive and exceeds the tracked count.
#
# Early exit 1: the buffer attribute is missing or falsy (getattr defaults to
# None), so there is nothing to clear.
16931705 if not getattr (self , "deepstack_input_embeds" , None ):
16941706 return
# Early exit 2: the tracked count is 0 (or the attribute is missing, since
# getattr defaults to 0), so no clearing or validation is done.
1707+ if getattr (self , "deepstack_input_embeds_num_tokens" , 0 ) == 0 :
1708+ return
16951709
16961710 # clear deepstack_input_embeds in buffer
16971711 if num_tokens > 0 :
# Refuse to clear a span longer than the tracked count; the message reports
# both values via the f-string `=` specifier.
1712+ if num_tokens > self .deepstack_input_embeds_num_tokens :
1713+ raise ValueError (
1714+ "Requested to clear more deepstack tokens than available in "
1715+ "buffer: "
1716+ f"{ num_tokens = } > { self .deepstack_input_embeds_num_tokens = } "
1717+ )
# Only `[:num_tokens]` is zeroed in place for each of the
# `deepstack_num_level` buffers; entries past `num_tokens` are not touched here.
16981718 for idx in range (self .deepstack_num_level ):
16991719 self .deepstack_input_embeds [idx ][:num_tokens ].zero_ ()
# Reset the tracked count to 0.
# NOTE(review): indentation is not preserved in this view, so it cannot be
# told whether this reset is inside the `if num_tokens > 0` branch or runs
# unconditionally - confirm against the real file. Also confirm that callers
# always clear the full stored span: a smaller `num_tokens` would leave rows
# between `num_tokens` and the old count non-zero while the count reads 0.
1720+ self .deepstack_input_embeds_num_tokens = 0
17001721
17011722 def _parse_and_validate_multimodal_inputs (self , ** kwargs : object ) -> dict :
17021723 mm_input_by_modality = {}
0 commit comments