Skip to content

Commit 3b0472d

Browse files
authored
chore: optimize deepstack buffer handling for MM Qwen3 models (#1643)
Signed-off-by: AlpinDale <alpindale@gmail.com>
1 parent 6b19cf8 commit 3b0472d

2 files changed

Lines changed: 41 additions & 0 deletions

File tree

aphrodite/model_executor/models/qwen3_omni_moe_thinker.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1643,6 +1643,9 @@ def __init__(self, *, aphrodite_config: AphroditeConfig, prefix: str = ""):
16431643
)
16441644
for _ in range(self.deepstack_num_level)
16451645
]
1646+
# Tracks the valid token span currently stored in the buffer.
1647+
# Zero means there is no active deepstack payload to consume.
1648+
self.deepstack_input_embeds_num_tokens = 0
16461649

16471650
with self._mark_language_model(aphrodite_config):
16481651
self.language_model = Qwen3MoeLLMForCausalLM(
@@ -1661,6 +1664,13 @@ def _get_deepstack_input_embeds(
16611664
) -> IntermediateTensors | None:
16621665
if not getattr(self, "deepstack_input_embeds", None):
16631666
return None # If vision tower is skipped
1667+
if getattr(self, "deepstack_input_embeds_num_tokens", 0) == 0:
1668+
return None
1669+
if num_tokens > self.deepstack_input_embeds_num_tokens:
1670+
raise ValueError(
1671+
"Requested more deepstack tokens than available in buffer: "
1672+
f"{num_tokens=} > {self.deepstack_input_embeds_num_tokens=}"
1673+
)
16641674

16651675
# get deepstack_input_embeds from buffer, and clear the buffer
16661676
return IntermediateTensors(
@@ -1689,14 +1699,25 @@ def _set_deepstack_input_embeds(self, deepstack_input_embeds: torch.Tensor) -> N
16891699
for idx in range(self.deepstack_num_level):
16901700
self.deepstack_input_embeds[idx][:num_tokens].copy_(deepstack_input_embeds[idx])
16911701

1702+
self.deepstack_input_embeds_num_tokens = num_tokens
1703+
16921704
def _clear_deepstack_input_embeds(self, num_tokens: int) -> None:
16931705
if not getattr(self, "deepstack_input_embeds", None):
16941706
return
1707+
if getattr(self, "deepstack_input_embeds_num_tokens", 0) == 0:
1708+
return
16951709

16961710
# clear deepstack_input_embeds in buffer
16971711
if num_tokens > 0:
1712+
if num_tokens > self.deepstack_input_embeds_num_tokens:
1713+
raise ValueError(
1714+
"Requested to clear more deepstack tokens than available in "
1715+
"buffer: "
1716+
f"{num_tokens=} > {self.deepstack_input_embeds_num_tokens=}"
1717+
)
16981718
for idx in range(self.deepstack_num_level):
16991719
self.deepstack_input_embeds[idx][:num_tokens].zero_()
1720+
self.deepstack_input_embeds_num_tokens = 0
17001721

17011722
def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
17021723
mm_input_by_modality = {}

aphrodite/model_executor/models/qwen3_vl.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1592,6 +1592,9 @@ def __init__(self, *, aphrodite_config: AphroditeConfig, prefix: str = "model"):
15921592
)
15931593
for _ in range(self.deepstack_num_level)
15941594
]
1595+
# Tracks the valid token span currently stored in the buffer.
1596+
# Zero means there is no active deepstack payload to consume.
1597+
self.deepstack_input_embeds_num_tokens = 0
15951598

15961599
with self._mark_language_model(aphrodite_config):
15971600
self.language_model = Qwen3LLMForCausalLM(
@@ -1612,6 +1615,13 @@ def _get_deepstack_input_embeds(
16121615
) -> IntermediateTensors | None:
16131616
if not getattr(self, "deepstack_input_embeds", None):
16141617
return None # If vision tower is skipped
1618+
if getattr(self, "deepstack_input_embeds_num_tokens", 0) == 0:
1619+
return None
1620+
if num_tokens > self.deepstack_input_embeds_num_tokens:
1621+
raise ValueError(
1622+
"Requested more deepstack tokens than available in buffer: "
1623+
f"{num_tokens=} > {self.deepstack_input_embeds_num_tokens=}"
1624+
)
16151625

16161626
# get deepstack_input_embeds from buffer, and clear the buffer
16171627
return IntermediateTensors(
@@ -1639,15 +1649,25 @@ def _set_deepstack_input_embeds(self, deepstack_input_embeds: torch.Tensor) -> N
16391649
]
16401650
for idx in range(self.deepstack_num_level):
16411651
self.deepstack_input_embeds[idx][:num_tokens].copy_(deepstack_input_embeds[idx])
1652+
self.deepstack_input_embeds_num_tokens = num_tokens
16421653

16431654
def _clear_deepstack_input_embeds(self, num_tokens: int) -> None:
16441655
if not getattr(self, "deepstack_input_embeds", None):
16451656
return
1657+
if getattr(self, "deepstack_input_embeds_num_tokens", 0) == 0:
1658+
return
16461659

16471660
# clear deepstack_input_embeds in buffer
16481661
if num_tokens > 0:
1662+
if num_tokens > self.deepstack_input_embeds_num_tokens:
1663+
raise ValueError(
1664+
"Requested to clear more deepstack tokens than available in "
1665+
"buffer: "
1666+
f"{num_tokens=} > {self.deepstack_input_embeds_num_tokens=}"
1667+
)
16491668
for idx in range(self.deepstack_num_level):
16501669
self.deepstack_input_embeds[idx][:num_tokens].zero_()
1670+
self.deepstack_input_embeds_num_tokens = 0
16511671

16521672
# -- SupportsEncoderCudaGraph protocol methods --
16531673

0 commit comments

Comments
 (0)