Skip to content

Commit a9855c4

Browse files
authored
[tests] fix audioldm2 tests. (#13293)
fix audioldm2 tests.
1 parent 0b35834 commit a9855c4

File tree

1 file changed

+11
-10
lines changed

1 file changed

+11
-10
lines changed

src/diffusers/pipelines/audioldm2/pipeline_audioldm2.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -324,17 +324,18 @@ def generate_language_model(
324324
`inputs_embeds (`torch.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
325325
The sequence of generated hidden-states.
326326
"""
327-
cache_position_kwargs = {}
328-
if is_transformers_version("<", "4.52.1"):
329-
cache_position_kwargs["input_ids"] = inputs_embeds
330-
else:
331-
cache_position_kwargs["seq_length"] = inputs_embeds.shape[0]
332-
cache_position_kwargs["device"] = (
333-
self.language_model.device if getattr(self, "language_model", None) is not None else self.device
334-
)
335-
cache_position_kwargs["model_kwargs"] = model_kwargs
336327
max_new_tokens = max_new_tokens if max_new_tokens is not None else self.language_model.config.max_new_tokens
337-
model_kwargs = self.language_model._get_initial_cache_position(**cache_position_kwargs)
328+
if hasattr(self.language_model, "_get_initial_cache_position"):
329+
cache_position_kwargs = {}
330+
if is_transformers_version("<", "4.52.1"):
331+
cache_position_kwargs["input_ids"] = inputs_embeds
332+
else:
333+
cache_position_kwargs["seq_length"] = inputs_embeds.shape[0]
334+
cache_position_kwargs["device"] = (
335+
self.language_model.device if getattr(self, "language_model", None) is not None else self.device
336+
)
337+
cache_position_kwargs["model_kwargs"] = model_kwargs
338+
model_kwargs = self.language_model._get_initial_cache_position(**cache_position_kwargs)
338339

339340
for _ in range(max_new_tokens):
340341
# prepare model inputs

0 commit comments

Comments
 (0)