Skip to content

Commit a70e329

Browse files
committed
[DFlash] add num_timesteps property for parity with LLaDA2
1 parent 7bc8161 commit a70e329

1 file changed

Lines changed: 5 additions & 0 deletions

File tree

src/diffusers/pipelines/dflash/pipeline_dflash.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,10 @@ def __init__(
9595
draft_model=draft_model, target_model=target_model, tokenizer=tokenizer, scheduler=scheduler
9696
)
9797

98+
@property
99+
def num_timesteps(self):
100+
return self._num_timesteps
101+
98102
# --- Prompt encoding ---
99103

100104
def _prepare_input_ids(
@@ -391,6 +395,7 @@ def _new_cache(cfg):
391395
start = num_input_tokens
392396
global_step = 0
393397
num_blocks = (max_length - num_input_tokens + block_size - 1) // block_size
398+
self._num_timesteps = int(num_blocks)
394399

395400
# 5. Block-wise speculative decoding loop
396401
block_progress_bar_config = getattr(self, "_progress_bar_config", {}).copy()

0 commit comments

Comments
 (0)