Skip to content

Commit e53613f

Browse files
authored
Fix executorch/extension/llm/custom_op/... (#16665)
Reviewed By: rascani · Differential Revision: D90888544
1 parent 6cd2589 commit e53613f

3 files changed

Lines changed: 13 additions & 14 deletions

File tree

extension/llm/custom_ops/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -19,6 +19,7 @@ runtime.python_test(
1919
],
2020
deps = [
2121
"//caffe2:torch",
22+
"//executorch/extension/pybindings:portable_lib",
2223
],
2324
)
2425

extension/llm/custom_ops/test_sdpa_with_kv_cache.py

Lines changed: 5 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -12,6 +12,7 @@
1212
import torch.nn.functional as F
1313

1414
from executorch.extension.llm.custom_ops import custom_ops # noqa
15+
from executorch.extension.pybindings.portable_lib import _unsafe_reset_threadpool
1516

1617

1718
def is_fbcode():
@@ -45,7 +46,6 @@ def _sdpa_with_kv_cache_ref(q, k, v, k_cache, v_cache, attn_mask, start_pos, seq
4546

4647

4748
class SDPATest(unittest.TestCase):
48-
4949
def setUp(self):
5050
torch.manual_seed(42)
5151
self.k_cache = torch.zeros((1, 10, 8, 4))
@@ -233,7 +233,6 @@ def test_sdpa_with_cache_no_mqa_4(self):
233233

234234

235235
class SDPAWithAttentionMaskTest(SDPATest):
236-
237236
def setUp(self):
238237
SDPATest.setUp(self)
239238
self.mask = torch.full(
@@ -244,7 +243,6 @@ def setUp(self):
244243

245244

246245
class SDPAWithAttentionMaskLongSequenceTest(SDPATest):
247-
248246
def setUp(self):
249247
SDPATest.setUp(self)
250248
max_context_len = 700
@@ -276,14 +274,12 @@ def setUp(self):
276274

277275

278276
class SDPAWithCausalTest(SDPATest):
279-
280277
def setUp(self):
281278
SDPATest.setUp(self)
282279
self.is_causal = True
283280

284281

285282
class SDPAWithDynamicShapeTest(unittest.TestCase):
286-
287283
def setUp(self):
288284
torch.manual_seed(42)
289285
self.k_cache = torch.zeros((1, 10, 8, 4))
@@ -346,7 +342,6 @@ def test_sdpa_with_cache_dynamic_shape_4(self):
346342

347343

348344
class SDPATestWithMQA(unittest.TestCase):
349-
350345
def setup_caches(self):
351346
self.k_cache = torch.zeros((1, 5, self.n_heads_kv, 4))
352347
self.v_cache = torch.zeros((1, 5, self.n_heads_kv, 4))
@@ -415,7 +410,6 @@ def test_sdpa_with_cache_mqa_3(self):
415410

416411

417412
class SDPATestCommon(unittest.TestCase):
418-
419413
def setup_caches(self):
420414
self.k_cache = torch.zeros(
421415
(self.n_batch, self.max_seq_len, self.n_heads_kv, self.head_dim)
@@ -437,6 +431,10 @@ def setUp(self):
437431
self.head_dim = 128
438432
self.max_seq_len = 2048
439433
self.setup_caches()
434+
# This setting is needed to make this test not flaky due to OMP
435+
# error of "OMP: Error #131: Thread identifier invalid"
436+
# See also test_quantized_sdpa.py for the same workaround
437+
_unsafe_reset_threadpool(3)
440438

441439
def _scale_tensor(self, tensor, min_value, max_value, scale=True):
442440
normalized_tensor = (tensor - tensor.min()) / (tensor.max() - tensor.min())
@@ -532,7 +530,6 @@ def _test_sdpa_common(
532530

533531

534532
class SDPATestForLargeSeqLength(SDPATestCommon):
535-
536533
def test_sdpa_with_cache_seq_len_130(self):
537534
n_heads_kv = 8
538535
n_heads_q = 8
@@ -579,7 +576,6 @@ def test_sdpa_with_cache_seq_len_llava_example_gqa(self):
579576

580577

581578
class SDPATestForSpeculativeDecode(SDPATestCommon):
582-
583579
def test_sdpa_with_cache_seq_len_130(self):
584580
n_heads_kv = 32
585581
n_heads_q = 32

extension/llm/custom_ops/test_update_cross_attn_cache.py

Lines changed: 7 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -14,10 +14,13 @@
1414
# Check CUDA availability once at module level
1515
CUDA_AVAILABLE = torch.cuda.is_available()
1616

17+
# Check if CUDA device has compatible compute capability for Triton kernels
18+
# Minimum CC 9.0 (Hopper) required for current PyTorch/Triton build
19+
CUDA_CC_COMPATIBLE = CUDA_AVAILABLE and torch.cuda.get_device_capability()[0] >= 9
20+
1721

1822
class TestUpdateCrossAttnCache(unittest.TestCase):
1923
def test_update_cross_attn_cache(self):
20-
2124
# Create tensors
2225
# Cache: [B=2, H=1, S_max=4, D=4]
2326
cache = torch.zeros(2, 1, 4, 4, dtype=torch.float32)
@@ -101,7 +104,6 @@ def compiled_fn(pred, v1, v2, c):
101104
)
102105

103106
def test_update_cross_attn_cache_export(self):
104-
105107
# Create tensors
106108
# Cache: [B=2, H=1, S_max=4, D=4]
107109
cache = torch.zeros(2, 1, 4, 4, dtype=torch.float32)
@@ -154,7 +156,6 @@ def false_fn(v1, v2, cache):
154156
)
155157

156158
def test_update_cross_attn_cache_different_shapes(self):
157-
158159
# Test with different batch sizes and sequence lengths
159160
test_cases = [
160161
# (B, H, S_max, S, D)
@@ -190,7 +191,6 @@ def fn(v, c):
190191
)
191192

192193
def test_update_cross_attn_cache_full_sequence(self):
193-
194194
# Cache: [B=2, H=1, S_max=4, D=4]
195195
cache = torch.zeros(2, 1, 4, 4, dtype=torch.float32)
196196
# Value: [B=2, H=1, S=4, D=4] (S == S_max)
@@ -207,7 +207,9 @@ def fn(v, c):
207207
cache, value, msg="Cache not fully updated when S == S_max"
208208
)
209209

210-
@unittest.skipUnless(CUDA_AVAILABLE, "CUDA not available")
210+
@unittest.skipUnless(
211+
CUDA_CC_COMPATIBLE, "Requires CUDA with compute capability >= 9.0"
212+
)
211213
def test_alias_and_update_cross_attn_cache_with_cond_triton(self):
212214
"""Test combining alias and update_cross_attn_cache ops with torch.cond,
213215
lowered to Triton on CUDA. True branch uses alias, false branch uses

0 commit comments

Comments (0)