From 84ae39f7af5ed17f9b1eb32728ea95deb817c774 Mon Sep 17 00:00:00 2001 From: Geoffrey Martin-Noble Date: Fri, 1 May 2026 15:43:23 -0700 Subject: [PATCH 1/2] Apply power-of-two chunking consistently - Avoid weird addition of 4 to power-of-two chunk sizes. This was added in https://github.com/aqlaboratory/openfold-3/commit/a9a12890d without explanation. We can hypothesize that it was related to adding 4 to an input dimension in trace_utils.py (trying to get a test case to fit in one chunk?), but that file was long ago deleted. This just looks like a bug and makes us hit unhappy paths all over the place. Fixes https://github.com/aqlaboratory/openfold-3/issues/203 - Enable chunking for AuxiliaryHeadsAllAtom pairformer embedding when using optimized kernels. Without chunking, this is the first call to cause OOMs because its `diffusion_samples*sequence_length` batches. Chunking gets turned off in prediction_heads.py due to batch size > 1 and use of optimized kernels because cross-sample chunking requires expanding out pair bias and they all require it to have size 1 in the second dimension with implicit broadcasting. So we turn on `apply_per_sample` when optimized kernels are in use. This splits the > 1 batch dimension, which avoids this problematic path and then we can do normal chunking for the rest if it's still too large. We could do something more elaborate (see suggestions in linked issue), but this is an improvement for now. Fixes https://github.com/aqlaboratory/openfold-3/issues/206 --- openfold3/core/kernels/triton/evoformer.py | 4 ++++ openfold3/core/model/heads/head_modules.py | 12 ++++++++-- .../core/model/heads/prediction_heads.py | 9 ++++---- openfold3/core/utils/chunk_utils.py | 1 - openfold3/tests/test_kernels.py | 22 +++++-------------- 5 files changed, 24 insertions(+), 24 deletions(-) diff --git a/openfold3/core/kernels/triton/evoformer.py b/openfold3/core/kernels/triton/evoformer.py index 03a40ff11..310a9cb68 100644 --- a/openfold3/core/kernels/triton/evoformer.py +++ b/openfold3/core/kernels/triton/evoformer.py @@ -904,6 +904,10 @@ def forward(ctx, Q, K, V, res_mask, pair_bias, has_pair_bias=True): ).contiguous() # (BATCH_SIZE, N_SEQ, HEAD, SEQ_LEN, DIM) BATCH_SIZE, N_SEQ, HEAD, SEQ_LEN, DIM = Q.shape + + assert res_mask.shape == (BATCH_SIZE, N_SEQ, 1, 1, SEQ_LEN), f"{tuple(res_mask.shape)} != {(BATCH_SIZE, N_SEQ, 1, 1, SEQ_LEN)}" + assert pair_bias.shape == (BATCH_SIZE, 1, HEAD, SEQ_LEN, SEQ_LEN), f"{tuple(pair_bias.shape)} != {(BATCH_SIZE, 1, HEAD, SEQ_LEN, SEQ_LEN)}" + softmax_scale = DIM**-0.5 BLOCK_DIM = max(triton.next_power_of_2(DIM), 32) diff --git a/openfold3/core/model/heads/head_modules.py b/openfold3/core/model/heads/head_modules.py index 46a132ac4..92830d48a 100644 --- a/openfold3/core/model/heads/head_modules.py +++ b/openfold3/core/model/heads/head_modules.py @@ -187,8 +187,16 @@ def forward( apply_per_sample = ( not torch.is_grad_enabled() and num_samples > 1 - and self.per_sample_token_cutoff is not None - and repr_x_pred.shape[-2] > self.per_sample_token_cutoff + and ( + (self.per_sample_token_cutoff is not None + and repr_x_pred.shape[-2] > self.per_sample_token_cutoff) + # The optimized attention kernels do not support cross-sample + # chunking because it requires expanding the pair bias. For now + # we just always apply per sample if these kernels are in use. + or use_deepspeed_evo_attention + or use_cueq_triangle_kernels + or use_triton_triangle_kernels + ) ) out_device = atom_positions_predicted.device diff --git a/openfold3/core/model/heads/prediction_heads.py b/openfold3/core/model/heads/prediction_heads.py index e957bc68c..d89de0fda 100644 --- a/openfold3/core/model/heads/prediction_heads.py +++ b/openfold3/core/model/heads/prediction_heads.py @@ -209,10 +209,11 @@ def reshape_outputs(x: torch.Tensor, feat_dims: list): single_mask = reshape_inputs(x=single_mask, feat_dims=single_mask.shape[-1:]) pair_mask = reshape_inputs(x=pair_mask, feat_dims=pair_mask.shape[-2:]) - # Using the DS kernel with chunk tuning and multiple samples causes shape issues - # in the DS kernel. To avoid this, chunk tuning is disabled in this case. - # TODO: cuEq seems to fail comparison unit tests with the same settings, - # disable for now and verify behavior + # The optimized kernels all require that pair bias have size 1 in the + # second dimension and cross-sample chunking has to combine the batch + # dimensions and expand it. We mostly avoid this path entirely by + # splitting per-sample when using the optimized kernels, but this avoids + # a potential correctness issue here. use_kernels = ( use_deepspeed_evo_attention or use_cueq_triangle_kernels diff --git a/openfold3/core/utils/chunk_utils.py b/openfold3/core/utils/chunk_utils.py index c89e1d1c3..af870cc72 100644 --- a/openfold3/core/utils/chunk_utils.py +++ b/openfold3/core/utils/chunk_utils.py @@ -363,7 +363,6 @@ def _determine_favorable_chunk_size(fn, args, min_chunk_size, max_chunk_size): candidates = [2**l for l in range(int(math.log(max_chunk_size, 2)) + 1)] candidates = [c for c in candidates if c > min_chunk_size] candidates = [min_chunk_size] + candidates - candidates[-1] += 4 def test_chunk_size(chunk_size): try: diff --git a/openfold3/tests/test_kernels.py b/openfold3/tests/test_kernels.py index ed5a01111..490068fcd 100644 --- a/openfold3/tests/test_kernels.py +++ b/openfold3/tests/test_kernels.py @@ -719,13 +719,13 @@ def _compare_template_stack( chunk_size=None, ): """ - Compare Template Stack output with and without using DeepSpeed Evoformer - attention kernel. Kernel can be used for Triangle Attention in the Template Pair - Stack. + Compare Template Stack output with and without using different optimized + attention kernels. Kernel can be used for Triangle Attention in the + Template Pair Stack. """ batch_size = consts.batch_size - if chunk_size is not None and use_deepspeed_evo_attention: - # Chunk tuning is not supported with batch size > 1 for DeepSpeed kernel + if chunk_size is not None: + # Chunking is not supported with batch size > 1 for optimized kernels batch_size = 1 n_templ = 3 @@ -793,7 +793,6 @@ def to_device(t): def test_compare_template_stack_dsk_fp32(self): self._compare_template_stack( use_deepspeed_evo_attention=True, - use_cueq_triangle_kernels=False, dtype=torch.float32, ) @@ -801,7 +800,6 @@ def test_compare_template_stack_dsk_fp32(self): def test_compare_template_stack_dsk_bf16(self): self._compare_template_stack( use_deepspeed_evo_attention=True, - use_cueq_triangle_kernels=False, dtype=torch.bfloat16, ) @@ -809,7 +807,6 @@ def test_compare_template_stack_dsk_bf16(self): def test_compare_template_stack_dsk_fp32_chunk(self): self._compare_template_stack( use_deepspeed_evo_attention=True, - use_cueq_triangle_kernels=False, dtype=torch.float32, chunk_size=4, ) @@ -817,7 +814,6 @@ def test_compare_template_stack_dsk_fp32_chunk(self): @compare_utils.skip_unless_cueq_installed() def test_compare_template_stack_cueq_fp32(self): self._compare_template_stack( - use_deepspeed_evo_attention=False, use_cueq_triangle_kernels=True, dtype=torch.float32, ) @@ -825,7 +821,6 @@ def test_compare_template_stack_cueq_fp32(self): @compare_utils.skip_unless_cueq_installed() def test_compare_template_stack_cueq_bf16(self): self._compare_template_stack( - use_deepspeed_evo_attention=False, use_cueq_triangle_kernels=True, dtype=torch.bfloat16, ) @@ -833,7 +828,6 @@ def test_compare_template_stack_cueq_bf16(self): @compare_utils.skip_unless_cueq_installed() def test_compare_template_stack_cueq_fp32_chunk(self): self._compare_template_stack( - use_deepspeed_evo_attention=False, use_cueq_triangle_kernels=True, dtype=torch.float32, chunk_size=4, @@ -842,8 +836,6 @@ def test_compare_template_stack_cueq_fp32_chunk(self): @compare_utils.skip_unless_triton_installed() def test_compare_template_stack_triton_fp32_chunk(self): self._compare_template_stack( - use_deepspeed_evo_attention=False, - use_cueq_triangle_kernels=False, use_triton_triangle_kernels=True, dtype=torch.float32, chunk_size=4, @@ -852,8 +844,6 @@ def test_compare_template_stack_triton_fp32_chunk(self): @compare_utils.skip_unless_triton_installed() def test_compare_template_stack_triton_fp32(self): self._compare_template_stack( - use_deepspeed_evo_attention=False, - use_cueq_triangle_kernels=False, use_triton_triangle_kernels=True, dtype=torch.float32, ) @@ -861,8 +851,6 @@ def test_compare_template_stack_triton_fp32(self): @compare_utils.skip_unless_triton_installed() def test_compare_template_stack_triton_bf16(self): self._compare_template_stack( - use_deepspeed_evo_attention=False, - use_cueq_triangle_kernels=False, use_triton_triangle_kernels=True, dtype=torch.bfloat16, ) From 42700051396cf8c0bbbc9268b29344c1d6d7e02b Mon Sep 17 00:00:00 2001 From: Geoffrey Martin-Noble Date: Mon, 18 May 2026 12:25:07 -0700 Subject: [PATCH 2/2] Fix formatting --- openfold3/core/kernels/triton/evoformer.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/openfold3/core/kernels/triton/evoformer.py b/openfold3/core/kernels/triton/evoformer.py index 310a9cb68..297795e64 100644 --- a/openfold3/core/kernels/triton/evoformer.py +++ b/openfold3/core/kernels/triton/evoformer.py @@ -905,8 +905,20 @@ def forward(ctx, Q, K, V, res_mask, pair_bias, has_pair_bias=True): BATCH_SIZE, N_SEQ, HEAD, SEQ_LEN, DIM = Q.shape - assert res_mask.shape == (BATCH_SIZE, N_SEQ, 1, 1, SEQ_LEN), f"{tuple(res_mask.shape)} != {(BATCH_SIZE, N_SEQ, 1, 1, SEQ_LEN)}" - assert pair_bias.shape == (BATCH_SIZE, 1, HEAD, SEQ_LEN, SEQ_LEN), f"{tuple(pair_bias.shape)} != {(BATCH_SIZE, 1, HEAD, SEQ_LEN, SEQ_LEN)}" + assert res_mask.shape == ( + BATCH_SIZE, + N_SEQ, + 1, + 1, + SEQ_LEN, + ), f"{tuple(res_mask.shape)} != {(BATCH_SIZE, N_SEQ, 1, 1, SEQ_LEN)}" + assert pair_bias.shape == ( + BATCH_SIZE, + 1, + HEAD, + SEQ_LEN, + SEQ_LEN, + ), f"{tuple(pair_bias.shape)} != {(BATCH_SIZE, 1, HEAD, SEQ_LEN, SEQ_LEN)}" softmax_scale = DIM**-0.5 BLOCK_DIM = max(triton.next_power_of_2(DIM), 32)