1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414
15- import unittest
16-
1715import pytest
1816import torch
1917
@@ -46,7 +44,7 @@ def model_class(self):
4644
4745 @property
4846 def output_shape (self ) -> tuple [int , ...]:
49- return (1 , 2 , 4 , 16 , 16 )
47+ return (4 , 4 , 16 , 16 )
5048
5149 @property
5250 def input_shape (self ) -> tuple [int , ...]:
@@ -137,15 +135,12 @@ def test_gradient_checkpointing_is_applied(self):
137135 # GPU-only (`torch.nn.attention.flex_attention` raises NotImplementedError on CPU). The
138136 # bidi transformer test file covers training on the SDPA path; FAR training correctness
139137 # is exercised end-to-end on H200 via the pipeline replay (L2=0 against NVlabs/AnyFlow).
140- @unittest .skipIf (torch_device == "cpu" , "FlexAttention has no CPU backward kernel." )
141138 def test_training (self ):
142139 super ().test_training ()
143140
144- @unittest .skipIf (torch_device == "cpu" , "FlexAttention has no CPU backward kernel." )
145141 def test_training_with_ema (self ):
146142 super ().test_training_with_ema ()
147143
148- @unittest .skipIf (torch_device == "cpu" , "FlexAttention has no CPU backward kernel." )
149144 def test_gradient_checkpointing_equivalence (self , loss_tolerance = 1e-5 , param_grad_tol = 5e-5 , skip = None ):
150145 super ().test_gradient_checkpointing_equivalence (loss_tolerance , param_grad_tol , skip )
151146
@@ -186,7 +181,7 @@ def test_compile_works_with_aot(self, tmp_path):
186181 super ().test_compile_works_with_aot (tmp_path )
187182
188183
189- class AnyFlowCausalAttnProcessorTest ( unittest . TestCase ) :
184+ class TestAnyFlowCausalAttnProcessor :
190185 """Stand-alone smoke tests for the FAR causal attention processor.
191186
192187 These cover behaviors not reached by the generated model mixins:
@@ -196,7 +191,7 @@ class AnyFlowCausalAttnProcessorTest(unittest.TestCase):
196191
197192 def test_default_backend_is_flex (self ):
198193 processor = AnyFlowCausalAttnProcessor ()
199- self . assertEqual ( processor ._attention_backend , "flex" )
194+ assert processor ._attention_backend == "flex"
200195
201196 def test_unsupported_backend_raises (self ):
202197 processor = AnyFlowCausalAttnProcessor ()
@@ -217,10 +212,10 @@ def to_v(self, x):
217212
218213 to_out = [lambda x : x , lambda x : x ]
219214
220- with self . assertRaises (ValueError ):
215+ with pytest . raises (ValueError ):
221216 processor (_DummyAttn (), torch .zeros (1 , 4 , 4 ))
222217
223218 def test_output_dataclass_exposed (self ):
224219 # Downstream type-checking + autodoc rely on these attributes existing.
225- self . assertTrue ( hasattr (AnyFlowFARTransformerOutput , "sample" ) )
226- self . assertTrue ( hasattr (AnyFlowFARTransformerOutput , "kv_cache" ) )
220+ assert hasattr (AnyFlowFARTransformerOutput , "sample" )
221+ assert hasattr (AnyFlowFARTransformerOutput , "kv_cache" )
0 commit comments