[tests] add attention backend tests. #13174
Merged
Changes from 5 commits (20 commits in total):
73b159f  add attention backend tests. (sayakpaul)
7220687  remove existing tests/others/test_attention_backends.py file (sayakpaul)
6648604  Merge branch 'main' into attn-backend-tests (sayakpaul)
e33e162  modify generate_model_tests.py (sayakpaul)
b9dee8e  Merge branch 'main' into attn-backend-tests (sayakpaul)
02fdd95  Merge branch 'main' into attn-backend-tests (sayakpaul)
7eae4a1  remove native. (sayakpaul)
eb776ad  account for _keep_in_fp32_modules (sayakpaul)
4a3b4fd  don't skip when exception is raised. (sayakpaul)
e228d51  use is_kernels_available() (sayakpaul)
44d894c  mark with compile. (sayakpaul)
090eac8  move rtol and atol to methods as defaults. (sayakpaul)
0b71142  Merge branch 'main' into attn-backend-tests (sayakpaul)
a026a28  Apply suggestions from code review (sayakpaul)
c58ff35  up (sayakpaul)
700a6d9  Merge branch 'main' into attn-backend-tests (sayakpaul)
9dc7088  resolve conflicts (sayakpaul)
36722a0  up (sayakpaul)
cbb1878  Merge branch 'main' into attn-backend-tests (sayakpaul)
d81b081  Merge branch 'main' into attn-backend-tests (sayakpaul)
@@ -14,22 +14,105 @@
# limitations under the License.

import gc
import logging

import pytest
import torch

from diffusers.models.attention import AttentionModuleMixin
-from diffusers.models.attention_processor import (
-    AttnProcessor,
from diffusers.models.attention_dispatch import AttentionBackendName, _AttentionBackendRegistry, attention_backend
from diffusers.models.attention_processor import AttnProcessor
from diffusers.utils import is_kernels_available, is_torch_version

from ...testing_utils import assert_tensors_close, backend_empty_cache, is_attention, torch_device


logger = logging.getLogger(__name__)


# ---------------------------------------------------------------------------
# Module-level backend parameter sets for AttentionBackendTesterMixin
# ---------------------------------------------------------------------------

_CUDA_AVAILABLE = torch.cuda.is_available()
_KERNELS_AVAILABLE = is_kernels_available()

_PARAM_NATIVE = pytest.param(AttentionBackendName.NATIVE, id="native")

_PARAM_NATIVE_CUDNN = pytest.param(
    AttentionBackendName._NATIVE_CUDNN,
    id="native_cudnn",
    marks=pytest.mark.skipif(
        not _CUDA_AVAILABLE,
        reason="CUDA is required for _native_cudnn backend.",
    ),
)

_PARAM_FLASH_HUB = pytest.param(
    AttentionBackendName.FLASH_HUB,
    id="flash_hub",
    marks=[
        pytest.mark.skipif(not _CUDA_AVAILABLE, reason="CUDA is required for flash_hub backend."),
        pytest.mark.skipif(
            not _KERNELS_AVAILABLE,
            reason="`kernels` package is required for flash_hub backend. Install with `pip install kernels`.",
        ),
    ],
)

-from ...testing_utils import (
-    assert_tensors_close,
-    backend_empty_cache,
-    is_attention,
-    torch_device,
_PARAM_FLASH_3_HUB = pytest.param(
    AttentionBackendName._FLASH_3_HUB,
    id="flash_3_hub",
    marks=[
        pytest.mark.skipif(not _CUDA_AVAILABLE, reason="CUDA is required for _flash_3_hub backend."),
        pytest.mark.skipif(
            not _KERNELS_AVAILABLE,
            reason="`kernels` package is required for _flash_3_hub backend. Install with `pip install kernels`.",
        ),
    ],
)

# All backends under test.
_ALL_BACKEND_PARAMS = [_PARAM_NATIVE, _PARAM_NATIVE_CUDNN, _PARAM_FLASH_HUB, _PARAM_FLASH_3_HUB]

# Backends that only accept bf16/fp16 inputs; models and inputs must be cast before running them.
_BF16_REQUIRED_BACKENDS = {
    AttentionBackendName._NATIVE_CUDNN,
    AttentionBackendName.FLASH_HUB,
    AttentionBackendName._FLASH_3_HUB,
}

[Collaborator review comment on _BF16_REQUIRED_BACKENDS: "Include Sage as well?"]

# Backends that perform non-deterministic operations and therefore cannot run when
# torch.use_deterministic_algorithms(True) is active (e.g. after enable_full_determinism()).
_NON_DETERMINISTIC_BACKENDS = {AttentionBackendName._NATIVE_CUDNN}


def _maybe_cast_to_bf16(backend, model, inputs_dict):
    """Cast model and floating-point inputs to bfloat16 when the backend requires it."""
    if backend not in _BF16_REQUIRED_BACKENDS:
        return model, inputs_dict
    model = model.to(dtype=torch.bfloat16)
    inputs_dict = {
        k: v.to(dtype=torch.bfloat16) if isinstance(v, torch.Tensor) and v.is_floating_point() else v
        for k, v in inputs_dict.items()
    }
    return model, inputs_dict


def _skip_if_backend_requires_nondeterminism(backend):
    """Skip at runtime when torch.use_deterministic_algorithms(True) blocks the backend.

    This check is intentionally deferred to test execution time because
    enable_full_determinism() is typically called at module level in test files *after*
    the module-level pytest.param() objects in this file have already been evaluated,
    making it impossible to catch via a collection-time skipif condition.
    """
    if backend in _NON_DETERMINISTIC_BACKENDS and torch.are_deterministic_algorithms_enabled():
        pytest.skip(
            f"Backend '{backend.value}' performs non-deterministic operations and cannot run "
            f"while `torch.use_deterministic_algorithms(True)` is active."
        )


@is_attention
class AttentionTesterMixin:

@@ -39,7 +122,6 @@ class AttentionTesterMixin:
    Tests functionality from AttentionModuleMixin including:
    - Attention processor management (set/get)
    - QKV projection fusion/unfusion
-    - Attention backends (XFormers, NPU, etc.)

    Expected from config mixin:
    - model_class: The model class to test

@@ -179,3 +261,208 @@ def test_attention_processor_count_mismatch_raises_error(self):
            model.set_attn_processor(wrong_processors)

        assert "number of processors" in str(exc_info.value).lower(), "Error should mention processor count mismatch"

@is_attention
class AttentionBackendTesterMixin:
    """
    Mixin class for testing attention backends on models. The following things are tested:

    1. Backends can be set with the `attention_backend` context manager and with the
       `set_attention_backend()` method.
    2. Backend outputs do not deviate too much from the native SDPA outputs.
    3. Backends work with (regional) compilation.
    4. Backends can be restored.

    Tests the backends using the model provided by the host test class. The backends to test
    are defined in `_ALL_BACKEND_PARAMS`.

    Expected from the host test class:
    - model_class: The model class to instantiate.

    Expected methods from the host test class:
    - get_init_dict(): Returns dict of kwargs to construct the model.
    - get_dummy_inputs(): Returns dict of inputs for the model's forward pass.

    Pytest mark: attention
    Use `pytest -m "not attention"` to skip these tests.
    """

    # -----------------------------------------------------------------------
    # Tolerance attributes: override in host class to loosen/tighten checks.
    # -----------------------------------------------------------------------

    # test_output_close_to_native: alternate backends (flash, cuDNN) may
    # accumulate small numerical errors vs the reference PyTorch SDPA kernel.
    backend_vs_native_atol: float = 1e-2
    backend_vs_native_rtol: float = 1e-2

    # test_compile: regional compilation introduces the same kind of numerical
    # error as the non-compiled backend path, so the same loose tolerance applies.
    compile_vs_native_atol: float = 1e-2
    compile_vs_native_rtol: float = 1e-2

    def setup_method(self):
        gc.collect()
        backend_empty_cache(torch_device)

    def teardown_method(self):
        gc.collect()
        backend_empty_cache(torch_device)

    @torch.no_grad()
    @pytest.mark.parametrize("backend", _ALL_BACKEND_PARAMS)
    def test_set_attention_backend_matches_context_manager(self, backend):
        """set_attention_backend() and the attention_backend() context manager must yield identical outputs."""
        _skip_if_backend_requires_nondeterminism(backend)

        init_dict = self.get_init_dict()
        inputs_dict = self.get_dummy_inputs()
        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.eval()

        model, inputs_dict = _maybe_cast_to_bf16(backend, model, inputs_dict)

        with attention_backend(backend):
            ctx_output = model(**inputs_dict, return_dict=False)[0]

        initial_registry_backend, _ = _AttentionBackendRegistry.get_active_backend()

        try:
            model.set_attention_backend(backend.value)
        except Exception as e:
            logger.warning("Skipping test for backend '%s': %s", backend.value, e)
            pytest.skip(str(e))

        try:
            set_output = model(**inputs_dict, return_dict=False)[0]
        finally:
            model.reset_attention_backend()
            _AttentionBackendRegistry.set_active_backend(initial_registry_backend)

        assert_tensors_close(
            set_output,
            ctx_output,
            atol=0,
            rtol=0,
            msg=(
                f"Output from model.set_attention_backend('{backend.value}') should be identical "
                f"to the output from `with attention_backend('{backend.value}'):`."
            ),
        )

    @torch.no_grad()
    @pytest.mark.parametrize("backend", _ALL_BACKEND_PARAMS)
    def test_output_close_to_native(self, backend):
        """All backends should produce model output numerically close to the native SDPA reference."""
        _skip_if_backend_requires_nondeterminism(backend)

        init_dict = self.get_init_dict()
        inputs_dict = self.get_dummy_inputs()
        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.eval()

        model, inputs_dict = _maybe_cast_to_bf16(backend, model, inputs_dict)

        with attention_backend(AttentionBackendName.NATIVE):
            native_output = model(**inputs_dict, return_dict=False)[0]

        initial_registry_backend, _ = _AttentionBackendRegistry.get_active_backend()

        try:
            model.set_attention_backend(backend.value)
        except Exception as e:
            logger.warning("Skipping test for backend '%s': %s", backend.value, e)
            pytest.skip(str(e))

        try:
            backend_output = model(**inputs_dict, return_dict=False)[0]
        finally:
            model.reset_attention_backend()
            _AttentionBackendRegistry.set_active_backend(initial_registry_backend)

        assert_tensors_close(
            backend_output,
            native_output,
            atol=self.backend_vs_native_atol,
            rtol=self.backend_vs_native_rtol,
            msg=f"Output from {backend} should be numerically close to native SDPA.",
        )

    @pytest.mark.parametrize("backend", _ALL_BACKEND_PARAMS)
    def test_context_manager_switches_and_restores_backend(self, backend):
        """attention_backend() should activate the requested backend and restore the previous one on exit."""
        initial_backend, _ = _AttentionBackendRegistry.get_active_backend()

        with attention_backend(backend):
            active_backend, _ = _AttentionBackendRegistry.get_active_backend()
            assert active_backend == backend, (
                f"Backend should be {backend} inside the context manager, got {active_backend}."
            )

        restored_backend, _ = _AttentionBackendRegistry.get_active_backend()
        assert restored_backend == initial_backend, (
            f"Backend should be restored to {initial_backend} after exiting the context manager, "
            f"got {restored_backend}."
        )

    @pytest.mark.parametrize("backend", _ALL_BACKEND_PARAMS)
    def test_compile(self, backend):
        """
        `torch.compile` tests that check for recompilations, graph breaks, and that the compiled
        forward pass can run. For speed, we use regional compilation here
        (`model.compile_repeated_blocks()` as opposed to `model.compile()`).
        """
        _skip_if_backend_requires_nondeterminism(backend)
        if getattr(self.model_class, "_repeated_blocks", None) is None:
            pytest.skip("Skipping tests as regional compilation is not supported.")

        if backend == AttentionBackendName.NATIVE and not is_torch_version(">=", "2.9.0"):
            pytest.xfail(
                "test_compile with the native backend requires torch >= 2.9.0 for stable "
                "fullgraph compilation with error_on_recompile=True."
            )

        init_dict = self.get_init_dict()
        inputs_dict = self.get_dummy_inputs()
        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.eval()

        model, inputs_dict = _maybe_cast_to_bf16(backend, model, inputs_dict)

        with torch.no_grad(), attention_backend(AttentionBackendName.NATIVE):
            native_output = model(**inputs_dict, return_dict=False)[0]

        initial_registry_backend, _ = _AttentionBackendRegistry.get_active_backend()

        try:
            model.set_attention_backend(backend.value)
        except Exception as e:
            logger.warning("Skipping test for backend '%s': %s", backend.value, e)
            pytest.skip(str(e))

        try:
            model.compile_repeated_blocks(fullgraph=True)
            torch.compiler.reset()

            with (
                torch._inductor.utils.fresh_inductor_cache(),
                torch._dynamo.config.patch(error_on_recompile=True),
            ):
                with torch.no_grad():
                    compile_output = model(**inputs_dict, return_dict=False)[0]
                    model(**inputs_dict, return_dict=False)
        finally:
            model.reset_attention_backend()
            _AttentionBackendRegistry.set_active_backend(initial_registry_backend)

        assert_tensors_close(
            compile_output,
            native_output,
            atol=self.compile_vs_native_atol,
            rtol=self.compile_vs_native_rtol,
            msg=f"Compiled output with backend '{backend.value}' should be numerically close to eager native SDPA.",
        )
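
For reference, a host test class adopts AttentionBackendTesterMixin by setting `model_class`, implementing `get_init_dict()` and `get_dummy_inputs()`, and optionally overriding the tolerance attributes. The sketch below is only a minimal illustration: the choice of UNet2DModel, the tiny config, the dummy input shapes, and the device handling are assumptions made for this example and are not part of this PR.

# Illustrative sketch only. The model choice, config values, and input shapes are
# assumptions for the example; AttentionBackendTesterMixin is the class added in the
# diff above, imported from wherever the host test file can reach it.
import torch

from diffusers import UNet2DModel

torch_device = "cuda" if torch.cuda.is_available() else "cpu"  # the real suite uses testing_utils.torch_device


class TestUNet2DAttentionBackends(AttentionBackendTesterMixin):
    model_class = UNet2DModel

    # Optionally loosen the comparison against native SDPA for this model.
    backend_vs_native_atol = 2e-2
    backend_vs_native_rtol = 2e-2

    def get_init_dict(self):
        # Tiny config so the parametrized backend tests stay fast.
        return {
            "sample_size": 16,
            "in_channels": 3,
            "out_channels": 3,
            "layers_per_block": 1,
            "block_out_channels": (32, 64),
            "attention_head_dim": 8,
            "down_block_types": ("DownBlock2D", "AttnDownBlock2D"),
            "up_block_types": ("AttnUpBlock2D", "UpBlock2D"),
        }

    def get_dummy_inputs(self):
        return {
            "sample": torch.randn(1, 3, 16, 16, device=torch_device),
            "timestep": torch.tensor([1], device=torch_device),
        }

Because the tests carry the `attention` mark, `pytest -m attention` selects them and `pytest -m "not attention"` skips them, as the mixin docstring notes.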