Skip to content

Commit b201bfa

Browse files
committed
Nits
1 parent d7e50be commit b201bfa

2 files changed

Lines changed: 4 additions & 5 deletions

File tree

src/transformers/generation/continuous_batching/continuous_api.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -570,7 +570,7 @@ def switch_to_paged_attn(self, model: ProtoPretrainedModel) -> None:
570570
@traced
571571
def start(self) -> None:
572572
"""Start the background generation thread."""
573-
if self._generation_thread is not None and self._generation_thread.is_alive():
573+
if self.is_running():
574574
logger.warning("Manager thread is already running.")
575575
return
576576
self.stop_event.clear()

tests/generation/test_continuous_batching.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -367,6 +367,7 @@ def test_continuous_batching_will_allocation_be_successful(
367367
config=AutoConfig.from_pretrained("HuggingFaceTB/SmolLM-1.7B", attn_implementation="sdpa"),
368368
continuous_batching_config=ContinuousBatchingConfig(block_size=16, num_blocks=8, max_batch_tokens=8),
369369
device=torch_device,
370+
tp_plan={},
370371
distributed_helper=DistributedHelper(device_mesh=None),
371372
)
372373

@@ -700,9 +701,7 @@ def test_continuous_batching_config_combinations_no_compile(
700701
attn_implementation=attn_implementation,
701702
)
702703

703-
@parameterized.expand(
704-
[("eager", False), ("sdpa", False), ("sdpa", True), ("flash_attention_2", True)]
705-
)
704+
@parameterized.expand([("eager", False), ("sdpa", False), ("sdpa", True), ("flash_attention_2", True)])
706705
@slow
707706
def test_continuous_batching_config_combinations_with_compile(
708707
self,
@@ -1211,11 +1210,11 @@ def test_per_request_logits_processors(self, use_cuda_graph: bool, use_async_bat
12111210
use_async_batching=use_async_batching,
12121211
per_request_processors=True,
12131212
return_logprobs=True,
1213+
q_padding_interval_size=16, # allows for exact comparison between CB and regular generation
12141214
)
12151215
manager = model.init_continuous_batching(
12161216
generation_config=generation_config,
12171217
continuous_batching_config=continuous_batching_config,
1218-
q_padding_interval_size=16, # allows for exact comparison between CB and regular generation
12191218
)
12201219

12211220
# Trick to have temperature, top-k, top-p ... without randomness: diable sampling after manager creation

0 commit comments

Comments
 (0)