1 parent b8c3dff commit e00d26b
1 file changed
tests/models/t5gemma2/test_modeling_t5gemma2.py
@@ -621,6 +621,10 @@ def create_and_check_cross_attention_cache_is_not_sliding(
        lm_labels,
        pixel_values,
    ):
+        """
+        Regression test for #45521. Checks whether the cross attention cache is correctly handled, i.e. not a SWA cache.
+        This would previously fail on instances where the sliding window < encoder len.
+        """
        config.decoder.sliding_window = self.encoder_seq_length // 2
        self.parent.assertGreater(self.encoder_seq_length, config.decoder.sliding_window)
        model = self.causal_lm_class(config=config).to(torch_device).eval()
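
For context, here is a minimal, self-contained sketch (not code from this commit or from transformers; the ToyFullCache and ToySlidingCache names are hypothetical) of the failure class the added docstring describes: if a sliding-window cache were applied to cross-attention key/value states, encoder positions beyond the window would be silently dropped whenever the sliding window is smaller than the encoder length.

import torch


class ToyFullCache:
    """Keeps every key/value state -- the correct behaviour for cross attention."""

    def __init__(self):
        self.keys = None

    def update(self, new_keys: torch.Tensor) -> torch.Tensor:
        self.keys = new_keys if self.keys is None else torch.cat([self.keys, new_keys], dim=1)
        return self.keys


class ToySlidingCache:
    """Keeps only the last `window` positions -- fine for decoder self attention,
    wrong for cross attention over the encoder sequence."""

    def __init__(self, window: int):
        self.window = window
        self.keys = None

    def update(self, new_keys: torch.Tensor) -> torch.Tensor:
        keys = new_keys if self.keys is None else torch.cat([self.keys, new_keys], dim=1)
        self.keys = keys[:, -self.window:]  # truncates to the sliding window
        return self.keys


if __name__ == "__main__":
    # sliding_window < encoder_len, mirroring the configuration used in the test
    encoder_len, sliding_window, hidden = 8, 4, 16
    encoder_states = torch.randn(1, encoder_len, hidden)

    full = ToyFullCache().update(encoder_states)
    sliding = ToySlidingCache(sliding_window).update(encoder_states)

    assert full.shape[1] == encoder_len        # all encoder positions remain attendable
    assert sliding.shape[1] == sliding_window  # early encoder positions were dropped
    print("full cache keeps", full.shape[1], "positions; sliding cache keeps", sliding.shape[1])

The toy sliding cache is what a SWA (sliding window attention) cache does by design for decoder self attention; the regression test asserts that the cross-attention cache does not follow that truncating behaviour.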