Skip to content

Commit c33abfc

Browse files
sayakpauldg845
andauthored
[tests] fix anyflow tests (#13855)
* fix anyflow tests * [tests] fix anyflow tests layerwise casting (#13863) Fix AnyFlow FAR causal transformer training layerwise / mixed precision tests --------- Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>
1 parent b3ec481 commit c33abfc

2 files changed

Lines changed: 12 additions & 11 deletions

File tree

src/diffusers/models/transformers/transformer_anyflow_far.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ def __call__(
111111

112112
if encoder_hidden_states is None:
113113
encoder_hidden_states = hidden_states
114+
target_dtype = hidden_states.dtype # Effective compute dtype
114115

115116
query = attn.to_q(hidden_states)
116117
key = attn.to_k(encoder_hidden_states)
@@ -121,6 +122,11 @@ def __call__(
121122
if attn.norm_k is not None:
122123
key = attn.norm_k(key)
123124

125+
# norm_q and norm_k upcast query and key to FP32 due to the use of RMSNorm, so cast them back to the effective
126+
# compute dtype.
127+
query = query.to(target_dtype)
128+
key = key.to(target_dtype)
129+
124130
# Layout (B, H, L, D) is required by KV-cache slicing and rotary application.
125131
query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
126132
key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)

tests/models/transformers/test_models_transformer_anyflow_far.py

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,6 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
import unittest
16-
1715
import pytest
1816
import torch
1917

@@ -46,7 +44,7 @@ def model_class(self):
4644

4745
@property
4846
def output_shape(self) -> tuple[int, ...]:
49-
return (1, 2, 4, 16, 16)
47+
return (4, 4, 16, 16)
5048

5149
@property
5250
def input_shape(self) -> tuple[int, ...]:
@@ -137,15 +135,12 @@ def test_gradient_checkpointing_is_applied(self):
137135
# GPU-only (`torch.nn.attention.flex_attention` raises NotImplementedError on CPU). The
138136
# bidi transformer test file covers training on the SDPA path; FAR training correctness
139137
# is exercised end-to-end on H200 via the pipeline replay (L2=0 against NVlabs/AnyFlow).
140-
@unittest.skipIf(torch_device == "cpu", "FlexAttention has no CPU backward kernel.")
141138
def test_training(self):
142139
super().test_training()
143140

144-
@unittest.skipIf(torch_device == "cpu", "FlexAttention has no CPU backward kernel.")
145141
def test_training_with_ema(self):
146142
super().test_training_with_ema()
147143

148-
@unittest.skipIf(torch_device == "cpu", "FlexAttention has no CPU backward kernel.")
149144
def test_gradient_checkpointing_equivalence(self, loss_tolerance=1e-5, param_grad_tol=5e-5, skip=None):
150145
super().test_gradient_checkpointing_equivalence(loss_tolerance, param_grad_tol, skip)
151146

@@ -186,7 +181,7 @@ def test_compile_works_with_aot(self, tmp_path):
186181
super().test_compile_works_with_aot(tmp_path)
187182

188183

189-
class AnyFlowCausalAttnProcessorTest(unittest.TestCase):
184+
class TestAnyFlowCausalAttnProcessor:
190185
"""Stand-alone smoke tests for the FAR causal attention processor.
191186
192187
These cover behaviors not reached by the generated model mixins:
@@ -196,7 +191,7 @@ class AnyFlowCausalAttnProcessorTest(unittest.TestCase):
196191

197192
def test_default_backend_is_flex(self):
198193
processor = AnyFlowCausalAttnProcessor()
199-
self.assertEqual(processor._attention_backend, "flex")
194+
assert processor._attention_backend == "flex"
200195

201196
def test_unsupported_backend_raises(self):
202197
processor = AnyFlowCausalAttnProcessor()
@@ -217,10 +212,10 @@ def to_v(self, x):
217212

218213
to_out = [lambda x: x, lambda x: x]
219214

220-
with self.assertRaises(ValueError):
215+
with pytest.raises(ValueError):
221216
processor(_DummyAttn(), torch.zeros(1, 4, 4))
222217

223218
def test_output_dataclass_exposed(self):
224219
# Downstream type-checking + autodoc rely on these attributes existing.
225-
self.assertTrue(hasattr(AnyFlowFARTransformerOutput, "sample"))
226-
self.assertTrue(hasattr(AnyFlowFARTransformerOutput, "kv_cache"))
220+
assert hasattr(AnyFlowFARTransformerOutput, "sample")
221+
assert hasattr(AnyFlowFARTransformerOutput, "kv_cache")

0 commit comments

Comments
 (0)