Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions openfold3/core/kernels/triton/evoformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -904,6 +904,22 @@ def forward(ctx, Q, K, V, res_mask, pair_bias, has_pair_bias=True):
).contiguous() # (BATCH_SIZE, N_SEQ, HEAD, SEQ_LEN, DIM)

BATCH_SIZE, N_SEQ, HEAD, SEQ_LEN, DIM = Q.shape

assert res_mask.shape == (
BATCH_SIZE,
N_SEQ,
1,
1,
SEQ_LEN,
), f"{tuple(res_mask.shape)} != {(BATCH_SIZE, N_SEQ, 1, 1, SEQ_LEN)}"
assert pair_bias.shape == (
BATCH_SIZE,
1,
HEAD,
SEQ_LEN,
SEQ_LEN,
), f"{tuple(pair_bias.shape)} != {(BATCH_SIZE, 1, HEAD, SEQ_LEN, SEQ_LEN)}"

softmax_scale = DIM**-0.5
BLOCK_DIM = max(triton.next_power_of_2(DIM), 32)

Expand Down
12 changes: 10 additions & 2 deletions openfold3/core/model/heads/head_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,8 +187,16 @@ def forward(
apply_per_sample = (
not torch.is_grad_enabled()
and num_samples > 1
and self.per_sample_token_cutoff is not None
and repr_x_pred.shape[-2] > self.per_sample_token_cutoff
and (
(self.per_sample_token_cutoff is not None
and repr_x_pred.shape[-2] > self.per_sample_token_cutoff)
# The optimized attention kernels do not support cross-sample
# chunking because it requires expanding the pair bias. For now
# we just always apply per sample if these kernels are in use.
or use_deepspeed_evo_attention
or use_cueq_triangle_kernels
or use_triton_triangle_kernels
)
)
out_device = atom_positions_predicted.device

Expand Down
9 changes: 5 additions & 4 deletions openfold3/core/model/heads/prediction_heads.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,10 +209,11 @@ def reshape_outputs(x: torch.Tensor, feat_dims: list):
single_mask = reshape_inputs(x=single_mask, feat_dims=single_mask.shape[-1:])
pair_mask = reshape_inputs(x=pair_mask, feat_dims=pair_mask.shape[-2:])

# Using the DS kernel with chunk tuning and multiple samples causes shape issues
# in the DS kernel. To avoid this, chunk tuning is disabled in this case.
# TODO: cuEq seems to fail comparison unit tests with the same settings,
# disable for now and verify behavior
# The optimized kernels all require that pair bias have size 1 in the
# second dimension and cross-sample chunking has to combine the batch
# dimensions and expand it. We mostly avoid this path entirely by
# splitting per-sample when using the optimized kernels, but this avoids
# a potential correctness issue here.
use_kernels = (
use_deepspeed_evo_attention
or use_cueq_triangle_kernels
Expand Down
1 change: 0 additions & 1 deletion openfold3/core/utils/chunk_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,6 @@ def _determine_favorable_chunk_size(fn, args, min_chunk_size, max_chunk_size):
candidates = [2**l for l in range(int(math.log(max_chunk_size, 2)) + 1)]
candidates = [c for c in candidates if c > min_chunk_size]
candidates = [min_chunk_size] + candidates
candidates[-1] += 4

def test_chunk_size(chunk_size):
try:
Expand Down
22 changes: 5 additions & 17 deletions openfold3/tests/test_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -719,13 +719,13 @@ def _compare_template_stack(
chunk_size=None,
):
"""
Compare Template Stack output with and without using DeepSpeed Evoformer
attention kernel. Kernel can be used for Triangle Attention in the Template Pair
Stack.
Compare Template Stack output with and without using different optimized
attention kernels. Kernel can be used for Triangle Attention in the
Template Pair Stack.
"""
batch_size = consts.batch_size
if chunk_size is not None and use_deepspeed_evo_attention:
# Chunk tuning is not supported with batch size > 1 for DeepSpeed kernel
if chunk_size is not None:
# Chunking is not supported with batch size > 1 for optimized kernels
batch_size = 1

n_templ = 3
Expand Down Expand Up @@ -793,47 +793,41 @@ def to_device(t):
def test_compare_template_stack_dsk_fp32(self):
self._compare_template_stack(
use_deepspeed_evo_attention=True,
use_cueq_triangle_kernels=False,
dtype=torch.float32,
)

@compare_utils.skip_unless_ds4s_installed()
def test_compare_template_stack_dsk_bf16(self):
self._compare_template_stack(
use_deepspeed_evo_attention=True,
use_cueq_triangle_kernels=False,
dtype=torch.bfloat16,
)

@compare_utils.skip_unless_ds4s_installed()
def test_compare_template_stack_dsk_fp32_chunk(self):
self._compare_template_stack(
use_deepspeed_evo_attention=True,
use_cueq_triangle_kernels=False,
dtype=torch.float32,
chunk_size=4,
)

@compare_utils.skip_unless_cueq_installed()
def test_compare_template_stack_cueq_fp32(self):
self._compare_template_stack(
use_deepspeed_evo_attention=False,
use_cueq_triangle_kernels=True,
dtype=torch.float32,
)

@compare_utils.skip_unless_cueq_installed()
def test_compare_template_stack_cueq_bf16(self):
self._compare_template_stack(
use_deepspeed_evo_attention=False,
use_cueq_triangle_kernels=True,
dtype=torch.bfloat16,
)

@compare_utils.skip_unless_cueq_installed()
def test_compare_template_stack_cueq_fp32_chunk(self):
self._compare_template_stack(
use_deepspeed_evo_attention=False,
use_cueq_triangle_kernels=True,
dtype=torch.float32,
chunk_size=4,
Expand All @@ -842,8 +836,6 @@ def test_compare_template_stack_cueq_fp32_chunk(self):
@compare_utils.skip_unless_triton_installed()
def test_compare_template_stack_triton_fp32_chunk(self):
self._compare_template_stack(
use_deepspeed_evo_attention=False,
use_cueq_triangle_kernels=False,
use_triton_triangle_kernels=True,
dtype=torch.float32,
chunk_size=4,
Expand All @@ -852,17 +844,13 @@ def test_compare_template_stack_triton_fp32_chunk(self):
@compare_utils.skip_unless_triton_installed()
def test_compare_template_stack_triton_fp32(self):
self._compare_template_stack(
use_deepspeed_evo_attention=False,
use_cueq_triangle_kernels=False,
use_triton_triangle_kernels=True,
dtype=torch.float32,
)

@compare_utils.skip_unless_triton_installed()
def test_compare_template_stack_triton_bf16(self):
self._compare_template_stack(
use_deepspeed_evo_attention=False,
use_cueq_triangle_kernels=False,
use_triton_triangle_kernels=True,
dtype=torch.bfloat16,
)
Expand Down