1212import torch .nn .functional as F
1313
1414from executorch .extension .llm .custom_ops import custom_ops # noqa
15+ from executorch .extension .pybindings .portable_lib import _unsafe_reset_threadpool
1516
1617
1718def is_fbcode ():
@@ -45,7 +46,6 @@ def _sdpa_with_kv_cache_ref(q, k, v, k_cache, v_cache, attn_mask, start_pos, seq
4546
4647
4748class SDPATest (unittest .TestCase ):
48-
4949 def setUp (self ):
5050 torch .manual_seed (42 )
5151 self .k_cache = torch .zeros ((1 , 10 , 8 , 4 ))
@@ -233,7 +233,6 @@ def test_sdpa_with_cache_no_mqa_4(self):
233233
234234
235235class SDPAWithAttentionMaskTest (SDPATest ):
236-
237236 def setUp (self ):
238237 SDPATest .setUp (self )
239238 self .mask = torch .full (
@@ -244,7 +243,6 @@ def setUp(self):
244243
245244
246245class SDPAWithAttentionMaskLongSequenceTest (SDPATest ):
247-
248246 def setUp (self ):
249247 SDPATest .setUp (self )
250248 max_context_len = 700
@@ -276,14 +274,12 @@ def setUp(self):
276274
277275
278276class SDPAWithCausalTest (SDPATest ):
279-
280277 def setUp (self ):
281278 SDPATest .setUp (self )
282279 self .is_causal = True
283280
284281
285282class SDPAWithDynamicShapeTest (unittest .TestCase ):
286-
287283 def setUp (self ):
288284 torch .manual_seed (42 )
289285 self .k_cache = torch .zeros ((1 , 10 , 8 , 4 ))
@@ -346,7 +342,6 @@ def test_sdpa_with_cache_dynamic_shape_4(self):
346342
347343
348344class SDPATestWithMQA (unittest .TestCase ):
349-
350345 def setup_caches (self ):
351346 self .k_cache = torch .zeros ((1 , 5 , self .n_heads_kv , 4 ))
352347 self .v_cache = torch .zeros ((1 , 5 , self .n_heads_kv , 4 ))
@@ -415,7 +410,6 @@ def test_sdpa_with_cache_mqa_3(self):
415410
416411
417412class SDPATestCommon (unittest .TestCase ):
418-
419413 def setup_caches (self ):
420414 self .k_cache = torch .zeros (
421415 (self .n_batch , self .max_seq_len , self .n_heads_kv , self .head_dim )
@@ -437,6 +431,10 @@ def setUp(self):
437431 self .head_dim = 128
438432 self .max_seq_len = 2048
439433 self .setup_caches ()
434+ # This setting is needed to make this test not flaky due to OMP
435+ # error of "OMP: Error #131: Thread identifier invalid"
436+ # See also test_quantized_sdpa.py for the same workaround
437+ _unsafe_reset_threadpool (3 )
440438
441439 def _scale_tensor (self , tensor , min_value , max_value , scale = True ):
442440 normalized_tensor = (tensor - tensor .min ()) / (tensor .max () - tensor .min ())
@@ -532,7 +530,6 @@ def _test_sdpa_common(
532530
533531
534532class SDPATestForLargeSeqLength (SDPATestCommon ):
535-
536533 def test_sdpa_with_cache_seq_len_130 (self ):
537534 n_heads_kv = 8
538535 n_heads_q = 8
@@ -579,7 +576,6 @@ def test_sdpa_with_cache_seq_len_llava_example_gqa(self):
579576
580577
581578class SDPATestForSpeculativeDecode (SDPATestCommon ):
582-
583579 def test_sdpa_with_cache_seq_len_130 (self ):
584580 n_heads_kv = 32
585581 n_heads_q = 32
0 commit comments