
Commit 70234a7

hooks/pyramid_attention_broadcast: remove redundant iteration==0 guard and fix stale cache VRAM leak
1 parent c8c8401

1 file changed: 1 addition & 2 deletions

File tree

src/diffusers/hooks/pyramid_attention_broadcast.py

@@ -159,7 +159,6 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs) -> Any:
         )
         should_compute_attention = (
             self.state.cache is None
-            or self.state.iteration == 0
             or not is_within_timestep_range
             or self.state.iteration % self.block_skip_range == 0
         )
@@ -169,7 +168,7 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs) -> Any:
         else:
             output = self.state.cache

-        self.state.cache = output
+        self.state.cache = output if is_within_timestep_range else None
         self.state.iteration += 1
         return output
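
The rationale for both changes can be seen in a minimal, self-contained sketch of the hook's control flow. This is illustrative only: `State`, the `new_forward` signature, and the `compute_attention` callback below are stand-ins, not the actual diffusers classes. On the very first call the cache is still None, so the `self.state.cache is None` clause already forces attention to be computed, making the `iteration == 0` guard redundant; and the old unconditional `self.state.cache = output` kept the last attention output pinned in the cache even after sampling left the broadcast timestep window, where it could never be reused.

# Minimal sketch of the hook's caching logic after this commit (illustrative
# placeholders, not the real diffusers API).

class State:
    def __init__(self):
        self.cache = None   # last attention output held for broadcasting
        self.iteration = 0  # how many times the hooked forward has run


def new_forward(state, is_within_timestep_range, block_skip_range, compute_attention):
    should_compute_attention = (
        state.cache is None
        # `or state.iteration == 0` was redundant: on the first call the
        # cache is still None, so the clause above already forces a compute.
        or not is_within_timestep_range
        or state.iteration % block_skip_range == 0
    )
    output = compute_attention() if should_compute_attention else state.cache

    # The fix: outside the broadcast window, drop the cached value instead
    # of pinning the last attention output in VRAM for the rest of sampling.
    state.cache = output if is_within_timestep_range else None
    state.iteration += 1
    return output


# Toy denoising loop: the broadcast window spans steps 1-3; attention is
# recomputed every 2nd in-window step and reused (broadcast) in between.
state = State()
for step in range(6):
    in_range = 1 <= step <= 3
    out = new_forward(state, in_range, block_skip_range=2,
                      compute_attention=lambda: f"attn@{step}")
    print(step, out, "cache:", state.cache)

With the pre-commit assignment (`state.cache = output`), steps 4 and 5 of this toy loop would have left `attn@4` and then `attn@5` in the cache even though neither can ever be reused; for a real attention output tensor that means VRAM held until the hook state is reset.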
