
Commit 784cc26

Fix Evoformer's multi-arch dispatch root cause (#7881)
Fixes #7863
Replaces #7872
@Flamefire

Issue #7863 reports order-dependent failures in Evoformer when building for mixed CUDA architectures. The guard-only approach prevents some bad outputs but does not solve multi-generation packaging requirements. This PR takes the root-cause direction: produce a correct multi-arch binary that can run on pre-Ampere and Ampere+ and select the right kernel family at runtime.

With TORCH_CUDA_ARCH_LIST='7.0;8.0':

1. The build is no longer pinned by -DGPU_ARCH; it uses runtime arch dispatch (evoformer_attn.py:33, gemm_kernel_utils.h:53).
2. The runtime chooses the implementation by device compute capability (CC):
   - CC >= 80 -> Sm80 (Ampere+ path)
   - CC >= 75 -> Sm75
   - CC >= 70 -> Sm70
3. As a result, pre-Ampere GPUs use pre-Ampere kernels, and Ampere+ GPUs use the Ampere-family kernel path.

---------

Signed-off-by: Masahiro Tanaka <mtanaka@anyscale.com>
Co-authored-by: Olatunji Ruwase <tunji.ruwase@snowflake.com>
1 parent f88d0f8 commit 784cc26

4 files changed

Lines changed: 116 additions & 56 deletions

csrc/deepspeed4science/evoformer_attn/gemm_kernel_utils.h

Lines changed: 17 additions & 34 deletions
```diff
@@ -42,45 +42,28 @@
 template <typename arch, typename scalar_t>
 struct CheckArch {
     static constexpr bool isPreVolta = arch::kMinComputeCapability < 70;
-    static constexpr bool isPreAmpere =
-        arch::kMinComputeCapability < 80 && arch::kMinComputeCapability >= 70;
+    // DISPATCH_ARCHTAG only binds Sm70/Sm75/Sm80+, so overlap with isPreVolta is unreachable.
+    static constexpr bool isPreAmpere = arch::kMinComputeCapability < 80;
     static constexpr bool isAmpere = arch::kMinComputeCapability >= 80;
-#if defined(__CUDA_ARCH__)
-    static constexpr bool compiler_cc = arch::kMinComputeCapability * 10 <= __CUDA_ARCH__;
-#else
-    static constexpr bool compiler_cc = true;
-#endif
     static constexpr bool value = (isPreVolta && std::is_same_v<scalar_t, float>) ||
                                   (isPreAmpere && !std::is_same_v<scalar_t, cutlass::bfloat16_t>) ||
-                                  isAmpere && compiler_cc;
+                                  isAmpere;
 };
 
-#define DISPATCH_ARCHTAG(CC, func)                                                      \
-    {                                                                                   \
-        if constexpr (GPU_ARCH >= 80) {                                                 \
-            if (CC >= 80) {                                                             \
-                using ArchTag = cutlass::arch::Sm80;                                    \
-                func;                                                                   \
-            } else {                                                                    \
-                EVOFORMER_CHECK(false, "Compile flag error. Unexpected GPU");           \
-            }                                                                           \
-        } else if constexpr (GPU_ARCH >= 75) {                                          \
-            if (CC >= 75) {                                                             \
-                using ArchTag = cutlass::arch::Sm75;                                    \
-                func;                                                                   \
-            } else {                                                                    \
-                EVOFORMER_CHECK(false, "Compile flag error. Unexpected GPU");           \
-            }                                                                           \
-        } else if constexpr (GPU_ARCH >= 70) {                                          \
-            if (CC >= 70) {                                                             \
-                using ArchTag = cutlass::arch::Sm70;                                    \
-                func;                                                                   \
-            } else {                                                                    \
-                EVOFORMER_CHECK(false, "Compile flag error. Unexpected GPU");           \
-            }                                                                           \
-        } else {                                                                        \
-            EVOFORMER_CHECK(false, "Only GPUs with Tensor Core are supported for now"); \
-        }                                                                               \
+#define DISPATCH_ARCHTAG(CC, func)                                                         \
+    {                                                                                      \
+        if ((CC) >= 80) {                                                                  \
+            using ArchTag = cutlass::arch::Sm80;                                           \
+            func;                                                                          \
+        } else if ((CC) >= 75) {                                                           \
+            using ArchTag = cutlass::arch::Sm75;                                           \
+            func;                                                                          \
+        } else if ((CC) >= 70) {                                                           \
+            using ArchTag = cutlass::arch::Sm70;                                           \
+            func;                                                                          \
+        } else {                                                                           \
+            EVOFORMER_CHECK(false, "Only GPUs with Tensor Core (SM >= 70) are supported"); \
+        }                                                                                  \
     }
 
 #define DISPATCH_TYPES(tensor, func) \
```
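The macro's arch selection is now purely runtime: thresholds are checked highest-first against the device's compute capability, with no compile-time `GPU_ARCH` pin. A minimal Python sketch of the same branch logic (hypothetical helper name; the real macro binds a `cutlass::arch` tag and invokes `func` in C++):

```python
def select_arch_tag(cc: int) -> str:
    """Map a device compute capability (e.g. 80 for sm_80) to a kernel family,
    mirroring the runtime branches of DISPATCH_ARCHTAG."""
    if cc >= 80:
        return "Sm80"  # Ampere and newer reuse the Ampere-family kernels
    if cc >= 75:
        return "Sm75"  # Turing
    if cc >= 70:
        return "Sm70"  # Volta
    # Matches the EVOFORMER_CHECK failure branch for pre-Tensor-Core GPUs
    raise RuntimeError("Only GPUs with Tensor Core (SM >= 70) are supported")
```

Because the thresholds are checked highest-first, an SM 8.6 or 9.0 device falls into the `Sm80` branch rather than failing, which is what makes a single binary serviceable across generations.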

docs/_tutorials/ds4sci_evoformerattention.md

Lines changed: 36 additions & 3 deletions
````diff
@@ -26,9 +26,42 @@ export CUTLASS_PATH=/path/to/cutlass
 ```
 The kernels will be compiled when `DS4Sci_EvoformerAttention` is called for the first time.
 
-`DS4Sci_EvoformerAttention` requires GPUs with compute capability 7.0 or higher (NVIDIA V100 or later GPUs) and the minimal CUDA version is 11.3. It is recommended to use CUDA 11.7 or later for better performance. Besides, the performance of backward kernel on V100 kernel is not as good as that on A100 for now.
-The extension checks both requirements and fails if any is not met. To disable the check, for example for cross-compiling in a system without GPUs, you can set the environment variable ```DS_IGNORE_CUDA_DETECTION=TRUE```
-and the environment value ```DS_EVOFORMER_GPU_ARCH={70|75|80}```, which controls the target GPU (80 being the last supported and meaning NVIDIA Ampere and later).
+`DS4Sci_EvoformerAttention` requires GPUs with compute capability 7.0 or higher
+(NVIDIA V100 or later GPUs) and the minimal CUDA version is 11.3. It is
+recommended to use CUDA 11.7 or later for better performance. Besides, the
+performance of backward kernel on V100 is not as good as on A100 for now.
+
+The extension checks both requirements and fails if any is not met. To disable
+the check (for example cross-compiling in a system without GPUs), set
+`DS_IGNORE_CUDA_DETECTION=TRUE`.
+
+### Multi-Arch Build Behavior
+
+Evoformer now supports mixed-architecture packaging directly via
+`TORCH_CUDA_ARCH_LIST`.
+
+Example:
+
+```shell
+CUTLASS_PATH=/path/to/cutlass \
+TORCH_CUDA_ARCH_LIST='7.0;8.0' \
+DS_BUILD_OPS=0 DS_BUILD_EVOFORMER_ATTN=1 \
+pip install -e .
+```
+
+- `TORCH_CUDA_ARCH_LIST` controls generated CUDA slices (order-independent).
+- Targets below `sm_70` are pruned for Evoformer because Tensor Cores are
+  required.
+- `DS_EVOFORMER_GPU_ARCH` is **deprecated** and ignored for Evoformer builds.
+  Use `TORCH_CUDA_ARCH_LIST` instead.
+
+Supported dtype matrix by architecture family:
+
+| Arch family | fp16 | bf16 |
+|-------------|------|------|
+| Sm70 (Volta) | Yes | No |
+| Sm75 (Turing) | Yes | No |
+| Sm80+ (Ampere/Ada/Hopper) | Yes | Yes |
 
 ### 3.2 Unit test and benchmark
 
````
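The dtype matrix added to the docs mirrors the `CheckArch::value` predicate in `gemm_kernel_utils.h`: bf16 requires the Ampere-family path, while fp16 is available on any dispatched (SM >= 70) target. A small sketch of that rule (hypothetical function name; dtypes as strings for illustration):

```python
def dtype_supported(arch_min_cc: int, dtype: str) -> bool:
    """Sketch of the CheckArch dtype rule: bf16 needs SM >= 80, fp16 needs SM >= 70."""
    if arch_min_cc >= 80:       # Sm80+: Ampere/Ada/Hopper support fp16 and bf16
        return dtype in ("fp16", "bf16")
    if arch_min_cc >= 70:       # Sm70/Sm75: pre-Ampere Tensor Cores, fp16 only
        return dtype == "fp16"
    return False                # pre-Volta targets are never dispatched
```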

op_builder/evoformer_attn.py

Lines changed: 17 additions & 19 deletions
```diff
@@ -16,12 +16,6 @@ def __init__(self, name=None):
         name = self.NAME if name is None else name
         super().__init__(name=name)
         self.cutlass_path = os.environ.get("CUTLASS_PATH")
-        # Target GPU architecture.
-        # Current useful values are: 70, 75, 80.
-        # For modern GPUs, 80 is the right value.
-        # No specializations of the kernel beyond Ampere are implemented
-        # See gemm_kernel_utils.h (also in cutlass example for fused attention) and cutlass/arch/arch.h
-        self.gpu_arch = os.environ.get("DS_EVOFORMER_GPU_ARCH")
 
     def absolute_name(self):
         return f"deepspeed.ops.{self.NAME}_op"
@@ -37,19 +31,23 @@ def sources(self):
         return [f"{src_dir}/attention.cpp", f"{src_dir}/attention_back.cu", f"{src_dir}/attention_cu.cu"]
 
     def nvcc_args(self):
-        args = super().nvcc_args()
-        if not self.gpu_arch:
-            try:
-                import torch
-            except ImportError:
-                self.warning("Please install torch if trying to pre-compile kernels")
-                return args
-            major = torch.cuda.get_device_properties(0).major  #ignore-cuda
-            minor = torch.cuda.get_device_properties(0).minor  #ignore-cuda
-            args.append(f"-DGPU_ARCH={major}{minor}")
-        else:
-            args.append(f"-DGPU_ARCH={self.gpu_arch}")
-        return args
+        if os.environ.get("DS_EVOFORMER_GPU_ARCH"):
+            self.warning("DS_EVOFORMER_GPU_ARCH is deprecated and ignored for Evoformer builds. "
+                         "Use TORCH_CUDA_ARCH_LIST to control build targets.")
+        return super().nvcc_args()
+
+    def filter_ccs(self, ccs):
+        """Keep only Tensor Core capable targets (>= 7.0)."""
+        retained = []
+        pruned = []
+        for cc in [cc.split('.') for cc in ccs]:
+            if int(cc[0]) >= 7:
+                retained.append(cc)
+            else:
+                pruned.append(cc)
+        if pruned:
+            self.warning(f"Evoformer: excluding targets below SM 7.0: {pruned}. Tensor Core required.")
+        return retained
 
     def is_compatible(self, verbose=False):
         try:
```
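The pruning performed by the new `filter_ccs` hook can be illustrated standalone. The sketch below re-implements the same logic outside the builder class (a hypothetical free function, since constructing the real builder needs a CUDA toolchain); it splits each `major.minor` entry and keeps only SM 7.0+ targets, preserving any `+PTX` suffix on the minor component:

```python
def filter_ccs(ccs, warn=print):
    """Drop compute capabilities below SM 7.0, as the Evoformer builder now does."""
    retained, pruned = [], []
    for cc in (c.split('.') for c in ccs):
        # cc is ["major", "minor"]; a "+PTX" suffix stays attached to the minor part
        (retained if int(cc[0]) >= 7 else pruned).append(cc)
    if pruned:
        warn(f"Evoformer: excluding targets below SM 7.0: {pruned}. Tensor Core required.")
    return retained
```

For example, with `TORCH_CUDA_ARCH_LIST='6.1;7.0;8.0+PTX'` this retains the 7.0 and 8.0+PTX slices and warns that 6.1 was excluded.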
Lines changed: 46 additions & 0 deletions
```diff
@@ -0,0 +1,46 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+from pathlib import Path
+from unittest.mock import patch
+
+from deepspeed.ops.op_builder.builder import CUDAOpBuilder
+# Import the concrete builder class instead of the accelerator-dispatched alias.
+from deepspeed.ops.op_builder.evoformer_attn import EvoformerAttnBuilder
+
+
+def test_filter_ccs_removes_below_70_and_keeps_ptx_suffix():
+    builder = EvoformerAttnBuilder()
+    result = builder.filter_ccs(["6.0", "6.1", "7.0", "8.0+PTX"])
+
+    majors = [int(cc[0]) for cc in result]
+    assert 6 not in majors
+    assert 7 in majors
+    assert 8 in majors
+
+    ptx_entries = [cc for cc in result if cc[1].endswith("+PTX")]
+    assert len(ptx_entries) == 1
+    assert ptx_entries[0] == ["8", "0+PTX"]
+
+
+def test_nvcc_args_deprecates_env_and_omits_gpu_arch_define():
+    builder = EvoformerAttnBuilder()
+    with patch.dict("os.environ", {"DS_EVOFORMER_GPU_ARCH": "80"}, clear=False):
+        with patch.object(builder, "warning") as warn:
+            with patch.object(CUDAOpBuilder, "nvcc_args", return_value=["-O3", "-lineinfo"]):
+                args = builder.nvcc_args()
+
+    warning_messages = [call.args[0] for call in warn.call_args_list if call.args]
+    assert any("DS_EVOFORMER_GPU_ARCH is deprecated and ignored" in msg for msg in warning_messages)
+    assert all("-DGPU_ARCH=" not in arg for arg in args)
+
+
+def test_no_cuda_arch_in_checkarch():
+    header = Path(__file__).resolve().parents[4] / "csrc/deepspeed4science/evoformer_attn/gemm_kernel_utils.h"
+    text = header.read_text()
+    start = text.index("struct CheckArch")
+    end = text.index("};", start) + 2
+    block = text[start:end]
+    assert "__CUDA_ARCH__" not in block
```
