Skip to content

Commit 0dfa6b6

Browse files
authored
core: support logprobs with multi-step scheduling (#963)
* deferred sampler results * fix imports and implement within multistep worker * update tests * fix test * fix sequence test * fix unrelated gguf ruff issue
1 parent 34e8606 commit 0dfa6b6

108 files changed

Lines changed: 917 additions & 424 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

aphrodite/common/sequence.py

Lines changed: 0 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -1046,72 +1046,6 @@ def __repr__(self) -> str:
10461046
return f"IntermediateTensors(tensors={self.tensors})"
10471047

10481048

1049-
class SamplerOutput(
1050-
msgspec.Struct,
1051-
omit_defaults=True, # type: ignore[call-arg]
1052-
array_like=True): # type: ignore[call-arg]
1053-
"""For each sequence group, we generate a list of SequenceOutput object,
1054-
each of which contains one possible candidate for the next token.
1055-
1056-
This data structure implements methods, so it can be used like a list, but
1057-
also has optional fields for device tensors.
1058-
"""
1059-
1060-
outputs: List[CompletionSequenceGroupOutput]
1061-
1062-
# On-device tensor containing probabilities of each token.
1063-
sampled_token_probs: Optional[torch.Tensor] = None
1064-
1065-
# On-device tensor containing the logprobs of each token.
1066-
logprobs: Optional["torch.Tensor"] = None
1067-
1068-
# On-device tensor containing the sampled token ids.
1069-
sampled_token_ids: Optional[torch.Tensor] = None
1070-
# CPU tensor containing the sampled token ids. Used during multi-step to
1071-
# return the sampled token ids from last rank to AsyncLLMEngine to be
1072-
# 'broadcasted' to all other PP ranks for next step.
1073-
sampled_token_ids_cpu: Optional[torch.Tensor] = None
1074-
1075-
# Spec decode metrics populated by workers.
1076-
spec_decode_worker_metrics: Optional[SpecDecodeWorkerMetrics] = None
1077-
1078-
# Optional last hidden states from the model.
1079-
hidden_states: Optional[torch.Tensor] = None
1080-
1081-
# Optional prefill hidden states from the model
1082-
# (used for models like EAGLE).
1083-
prefill_hidden_states: Optional[torch.Tensor] = None
1084-
1085-
# Time taken in the forward pass for this across all workers
1086-
model_forward_time: Optional[float] = None
1087-
1088-
def __getitem__(self, idx: int):
1089-
return self.outputs[idx]
1090-
1091-
def __setitem__(self, idx: int, value):
1092-
self.outputs[idx] = value
1093-
1094-
def __len__(self):
1095-
return len(self.outputs)
1096-
1097-
def __eq__(self, other: object):
1098-
return isinstance(other,
1099-
self.__class__) and self.outputs == other.outputs
1100-
1101-
def __repr__(self) -> str:
1102-
"""Show the shape of a tensor instead of its values to reduce noise.
1103-
"""
1104-
sampled_token_probs_repr = ("None" if self.sampled_token_probs is None
1105-
else self.sampled_token_probs.shape)
1106-
sampled_token_ids_repr = ("None" if self.sampled_token_ids is None else
1107-
self.sampled_token_ids.shape)
1108-
return (
1109-
f"SamplerOutput(outputs={self.outputs}, "
1110-
f"sampled_token_probs={sampled_token_probs_repr}, "
1111-
f"sampled_token_ids={sampled_token_ids_repr}, "
1112-
f"spec_decode_worker_metrics={self.spec_decode_worker_metrics})")
1113-
1114-
11151049
class PoolerOutput(
11161050
msgspec.Struct,
11171051
omit_defaults=True, # type: ignore[call-arg]

aphrodite/engine/aphrodite_engine.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,9 @@
2424
from aphrodite.common.pooling_params import PoolingParams
2525
from aphrodite.common.sampling_params import SamplingParams
2626
from aphrodite.common.sequence import (EmbeddingSequenceGroupOutput,
27-
ExecuteModelRequest, SamplerOutput,
28-
Sequence, SequenceGroup,
29-
SequenceGroupMetadata, SequenceStatus)
27+
ExecuteModelRequest, Sequence,
28+
SequenceGroup, SequenceGroupMetadata,
29+
SequenceStatus)
3030
from aphrodite.common.utils import Counter, Device
3131
from aphrodite.engine.args_tools import EngineArgs
3232
from aphrodite.engine.metrics_types import StatLoggerBase, Stats
@@ -42,6 +42,7 @@
4242
SingletonPromptInputs)
4343
from aphrodite.inputs.parse import is_explicit_encoder_decoder_prompt
4444
from aphrodite.lora.request import LoRARequest
45+
from aphrodite.modeling.layers.sampler import SamplerOutput
4546
from aphrodite.multimodal import MultiModalDataDict
4647
from aphrodite.processing.scheduler import (ScheduledSequenceGroup, Scheduler,
4748
SchedulerOutputs)

aphrodite/engine/async_aphrodite.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from aphrodite.common.outputs import EmbeddingRequestOutput, RequestOutput
1515
from aphrodite.common.pooling_params import PoolingParams
1616
from aphrodite.common.sampling_params import SamplingParams
17-
from aphrodite.common.sequence import ExecuteModelRequest, SamplerOutput
17+
from aphrodite.common.sequence import ExecuteModelRequest
1818
from aphrodite.common.utils import print_warning_once
1919
from aphrodite.engine.aphrodite_engine import (AphroditeEngine,
2020
DecoderPromptComponents,
@@ -29,6 +29,7 @@
2929
SingletonPromptInputs)
3030
from aphrodite.inputs.parse import is_explicit_encoder_decoder_prompt
3131
from aphrodite.lora.request import LoRARequest
32+
from aphrodite.modeling.layers.sampler import SamplerOutput
3233
from aphrodite.processing.scheduler import SchedulerOutputs
3334
from aphrodite.prompt_adapter.request import PromptAdapterRequest
3435
from aphrodite.transformers_utils.tokenizer import AnyTokenizer

aphrodite/engine/output_processor/multi_step.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
from aphrodite.common.utils import Counter
1212
from aphrodite.engine.output_processor.interfaces import (
1313
SequenceGroupOutputProcessor)
14+
from aphrodite.engine.output_processor.single_step import (
15+
single_step_process_prompt_logprob)
1416
from aphrodite.engine.output_processor.stop_checker import StopChecker
1517
from aphrodite.processing.scheduler import Scheduler
1618
from aphrodite.transformers_utils.detokenizer import Detokenizer
@@ -46,9 +48,15 @@ def __init__(
4648

4749
def process_prompt_logprob(self, seq_group: SequenceGroup,
4850
outputs: List[SequenceGroupOutput]) -> None:
49-
# TODO: Prompt logprob currently not implemented in multi step
50-
# workers.
51-
self._log_prompt_logprob_unsupported_warning_once()
51+
"""Process prompt logprobs associated with each step of a multi-step-
52+
scheduled computation.
53+
Args:
54+
seq_group: the outputs are associated with this :class:`SequenceGroup`
55+
outputs: the :class:`SequenceGroupOutput`s for all scheduler steps
56+
"""
57+
for output in outputs:
58+
# Concatenate single-step prompt logprob processing results.
59+
single_step_process_prompt_logprob(self, seq_group, output)
5260

5361
@staticmethod
5462
@functools.lru_cache()

aphrodite/engine/output_processor/single_step.py

Lines changed: 44 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,42 @@
1313
from aphrodite.transformers_utils.detokenizer import Detokenizer
1414

1515

16+
def single_step_process_prompt_logprob(
17+
sg_output_proc: SequenceGroupOutputProcessor, seq_group: SequenceGroup,
18+
output: SequenceGroupOutput) -> None:
19+
"""Process prompt logprobs associated with the :class:`SequenceGroupOutput`
20+
for a given step.
21+
Do nothing if the output has no prompt logprobs.
22+
Account for the fact that transformers do not compute first-token logprobs.
23+
24+
Args:
25+
sg_output_proc: :class:`SequenceGroupOutputProcessor` instance
26+
seq_group: the output is associated with this :class:`SequenceGroup`
27+
output: the :class:`SequenceGroupOutput` for a single scheduler step
28+
"""
29+
prompt_logprobs = output.prompt_logprobs
30+
31+
# If this is the first (or only) "chunk" of the prefill, we need
32+
# to prepend None to the list of prompt logprobs. The reason for this
33+
# is that for N prompt tokens, the Sampler will generate N-1 total
34+
# prompt logprobs during prefill since the token at idx 0 will not
35+
# have a logprob associated with it.
36+
if prompt_logprobs is not None:
37+
if not seq_group.prompt_logprobs:
38+
prompt_logprobs = [None] + prompt_logprobs
39+
seq_group.prompt_logprobs = []
40+
41+
assert hasattr(sg_output_proc, 'detokenizer')
42+
if (seq_group.sampling_params.detokenize
43+
and sg_output_proc.detokenizer):
44+
sg_output_proc.detokenizer.decode_prompt_logprobs_inplace(
45+
seq_group,
46+
prompt_logprobs,
47+
position_offset=len(seq_group.prompt_logprobs))
48+
49+
seq_group.prompt_logprobs.extend(prompt_logprobs)
50+
51+
1652
class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
1753
"""SequenceGroupOutputProcessor which handles "output processing" logic,
1854
which happens after the model returns generated token ids and before
@@ -57,25 +93,16 @@ def process_outputs(self, sequence_group: SequenceGroup,
5793

5894
def process_prompt_logprob(self, seq_group: SequenceGroup,
5995
outputs: List[SequenceGroupOutput]) -> None:
96+
"""Process prompt logprobs associated with one step of a single-step-
97+
scheduled computation.
98+
99+
Args:
100+
seq_group: the output is associated with this :class:`SequenceGroup`
101+
output: the :class:`SequenceGroupOutput` for a single scheduler step
102+
"""
60103
assert len(outputs) == 1, ("Single step should only has 1 output.")
61104
output = outputs[0]
62-
prompt_logprobs = output.prompt_logprobs
63-
64-
# If this is the first (or only) "chunk" of the prefill, we need
65-
# to prepend None to the list of prompt logprobs. The reason for this
66-
# is that for N prompt tokens, the Sampler will generate N-1 total
67-
# prompt logprobs during prefill since the token at idx 0 will not
68-
# have a logprob associated with it.
69-
if prompt_logprobs is not None:
70-
if not seq_group.prompt_logprobs:
71-
prompt_logprobs = [None] + prompt_logprobs
72-
seq_group.prompt_logprobs = []
73-
if seq_group.sampling_params.detokenize and self.detokenizer:
74-
self.detokenizer.decode_prompt_logprobs_inplace(
75-
seq_group,
76-
prompt_logprobs,
77-
position_offset=len(seq_group.prompt_logprobs))
78-
seq_group.prompt_logprobs.extend(prompt_logprobs)
105+
single_step_process_prompt_logprob(self, seq_group, output)
79106

80107
def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
81108
outputs: SequenceGroupOutput,

aphrodite/engine/output_processor/util.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22
from typing import Sequence as GenericSequence
33
from typing import Union
44

5-
from aphrodite.common.sequence import (PoolerOutput, SamplerOutput,
6-
SequenceGroupOutput)
5+
from aphrodite.common.sequence import PoolerOutput, SequenceGroupOutput
6+
from aphrodite.modeling.layers.sampler import SamplerOutput
77

88

99
def create_output_by_sequence_group(

aphrodite/engine/protocol.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@
66
from aphrodite.common.outputs import EmbeddingRequestOutput, RequestOutput
77
from aphrodite.common.pooling_params import PoolingParams
88
from aphrodite.common.sampling_params import SamplingParams
9-
from aphrodite.common.sequence import SamplerOutput
109
from aphrodite.inputs.data import PromptInputs
1110
from aphrodite.lora.request import LoRARequest
11+
from aphrodite.modeling.layers.sampler import SamplerOutput
1212
from aphrodite.processing.scheduler import SchedulerOutputs
1313
from aphrodite.prompt_adapter.request import PromptAdapterRequest
1414

aphrodite/executor/cpu_executor.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
import aphrodite.common.envs as envs
99
from aphrodite.common.config import CacheConfig, ModelConfig, SchedulerConfig
10-
from aphrodite.common.sequence import ExecuteModelRequest, SamplerOutput
10+
from aphrodite.common.sequence import ExecuteModelRequest
1111
from aphrodite.common.utils import (GiB_bytes, get_aphrodite_instance_id,
1212
get_distributed_init_method, get_open_port,
1313
make_async)
@@ -16,6 +16,7 @@
1616
ResultHandler,
1717
WorkerMonitor)
1818
from aphrodite.lora.request import LoRARequest
19+
from aphrodite.modeling.layers.sampler import SamplerOutput
1920
from aphrodite.prompt_adapter.request import PromptAdapterRequest
2021
from aphrodite.task_handler.worker_base import WorkerWrapperBase
2122

aphrodite/executor/distributed_gpu_executor.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,11 @@
44

55
from loguru import logger
66

7-
from aphrodite.common.sequence import ExecuteModelRequest, SamplerOutput
7+
from aphrodite.common.sequence import ExecuteModelRequest
88
from aphrodite.executor.executor_base import ExecutorAsyncBase
99
from aphrodite.executor.gpu_executor import GPUExecutor
1010
from aphrodite.lora.request import LoRARequest
11+
from aphrodite.modeling.layers.sampler import SamplerOutput
1112

1213

1314
class DistributedGPUExecutor(GPUExecutor):

aphrodite/executor/executor_base.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,9 @@
55
LoRAConfig, ModelConfig, ParallelConfig,
66
PromptAdapterConfig, SchedulerConfig,
77
SpeculativeConfig)
8-
from aphrodite.common.sequence import ExecuteModelRequest, SamplerOutput
8+
from aphrodite.common.sequence import ExecuteModelRequest
99
from aphrodite.lora.request import LoRARequest
10+
from aphrodite.modeling.layers.sampler import SamplerOutput
1011
from aphrodite.prompt_adapter.request import PromptAdapterRequest
1112

1213

0 commit comments

Comments
 (0)