Skip to content

Commit 9faeefa

Browse files
committed
fix: update test_spec_agent for guided decoding support
- Add guided_decoding_manager=None to SpecModelAgent mock objects - Make _DummyProposer.get_outputs async and accept guided_processors kwarg
1 parent 2b25488 commit 9faeefa

1 file changed

Lines changed: 3 additions & 1 deletion

File tree

tests/pytorch/spec_decode/test_spec_agent.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def __init__(self):
6060
self.update_inputs_decoding_calls = 0
6161
self.model = _DummyDraftModel()
6262

63-
def get_outputs(self, outputs, inputs, extra_inputs=None):
63+
async def get_outputs(self, outputs, inputs, extra_inputs=None, guided_processors=None):
6464
batch_size = inputs.seq_length.size(0)
6565
draft_token_ids = inputs.input_ids.new_full((batch_size, 1), self.get_outputs_calls)
6666
self.get_outputs_calls += 1
@@ -204,6 +204,7 @@ def test_async_model_forward_dp1_non_last_chunk_skips_remaining_spec_forwards():
204204
agent.num_spec_tokens = 3
205205
agent.rank = 0
206206
agent.proposer = _DummyProposer()
207+
agent.guided_decoding_manager = None
207208
forward_calls = 0
208209

209210
def _forward_impl(_inputs):
@@ -235,6 +236,7 @@ def test_async_model_forward_dp_non_last_chunk_runs_all_spec_forwards(monkeypatc
235236
agent.num_spec_tokens = 3
236237
agent.rank = 0
237238
agent.proposer = _DummyProposer()
239+
agent.guided_decoding_manager = None
238240
forward_calls = 0
239241

240242
def _forward_impl(_inputs):

0 commit comments

Comments
 (0)