diff --git a/extension/llm/custom_ops/TARGETS b/extension/llm/custom_ops/TARGETS
index e4d87eb0dd7..baee5afe17c 100644
--- a/extension/llm/custom_ops/TARGETS
+++ b/extension/llm/custom_ops/TARGETS
@@ -19,6 +19,7 @@ runtime.python_test(
     ],
     deps = [
         "//caffe2:torch",
+        "//executorch/extension/pybindings:portable_lib",
     ],
 )
 
diff --git a/extension/llm/custom_ops/test_sdpa_with_kv_cache.py b/extension/llm/custom_ops/test_sdpa_with_kv_cache.py
index 502a6238a7d..d044a4789ff 100644
--- a/extension/llm/custom_ops/test_sdpa_with_kv_cache.py
+++ b/extension/llm/custom_ops/test_sdpa_with_kv_cache.py
@@ -12,6 +12,7 @@
 import torch.nn.functional as F
 
 from executorch.extension.llm.custom_ops import custom_ops  # noqa
+from executorch.extension.pybindings.portable_lib import _unsafe_reset_threadpool
 
 
 def is_fbcode():
@@ -45,7 +46,6 @@ def _sdpa_with_kv_cache_ref(q, k, v, k_cache, v_cache, attn_mask, start_pos, seq
 
 
 class SDPATest(unittest.TestCase):
-
     def setUp(self):
         torch.manual_seed(42)
         self.k_cache = torch.zeros((1, 10, 8, 4))
@@ -233,7 +233,6 @@ def test_sdpa_with_cache_no_mqa_4(self):
 
 
 class SDPAWithAttentionMaskTest(SDPATest):
-
     def setUp(self):
         SDPATest.setUp(self)
         self.mask = torch.full(
@@ -244,7 +243,6 @@ def setUp(self):
 
 
 class SDPAWithAttentionMaskLongSequenceTest(SDPATest):
-
     def setUp(self):
         SDPATest.setUp(self)
         max_context_len = 700
@@ -276,14 +274,12 @@ def setUp(self):
 
 
 class SDPAWithCausalTest(SDPATest):
-
     def setUp(self):
         SDPATest.setUp(self)
         self.is_causal = True
 
 
 class SDPAWithDynamicShapeTest(unittest.TestCase):
-
     def setUp(self):
         torch.manual_seed(42)
         self.k_cache = torch.zeros((1, 10, 8, 4))
@@ -346,7 +342,6 @@ def test_sdpa_with_cache_dynamic_shape_4(self):
 
 
 class SDPATestWithMQA(unittest.TestCase):
-
     def setup_caches(self):
         self.k_cache = torch.zeros((1, 5, self.n_heads_kv, 4))
         self.v_cache = torch.zeros((1, 5, self.n_heads_kv, 4))
@@ -415,7 +410,6 @@ def test_sdpa_with_cache_mqa_3(self):
 
 
 class SDPATestCommon(unittest.TestCase):
-
     def setup_caches(self):
         self.k_cache = torch.zeros(
             (self.n_batch, self.max_seq_len, self.n_heads_kv, self.head_dim)
         )
@@ -437,6 +431,10 @@ def setUp(self):
         self.head_dim = 128
         self.max_seq_len = 2048
         self.setup_caches()
+        # This setting is needed to keep this test from being flaky due to
+        # the OMP error "OMP: Error #131: Thread identifier invalid".
+        # See also test_quantized_sdpa.py for the same workaround.
+        _unsafe_reset_threadpool(3)
 
     def _scale_tensor(self, tensor, min_value, max_value, scale=True):
         normalized_tensor = (tensor - tensor.min()) / (tensor.max() - tensor.min())
@@ -532,7 +530,6 @@ def _test_sdpa_common(
 
 
 class SDPATestForLargeSeqLength(SDPATestCommon):
-
     def test_sdpa_with_cache_seq_len_130(self):
         n_heads_kv = 8
         n_heads_q = 8
@@ -579,7 +576,6 @@ def test_sdpa_with_cache_seq_len_llava_example_gqa(self):
 
 
 class SDPATestForSpeculativeDecode(SDPATestCommon):
-
     def test_sdpa_with_cache_seq_len_130(self):
         n_heads_kv = 32
         n_heads_q = 32
diff --git a/extension/llm/custom_ops/test_update_cross_attn_cache.py b/extension/llm/custom_ops/test_update_cross_attn_cache.py
index dde2da68f51..c36126b3b97 100644
--- a/extension/llm/custom_ops/test_update_cross_attn_cache.py
+++ b/extension/llm/custom_ops/test_update_cross_attn_cache.py
@@ -14,10 +14,13 @@
 # Check CUDA availability once at module level
 CUDA_AVAILABLE = torch.cuda.is_available()
 
+# Check if the CUDA device has a compatible compute capability for Triton kernels.
+# Minimum CC 9.0 (Hopper) is required for the current PyTorch/Triton build.
+CUDA_CC_COMPATIBLE = CUDA_AVAILABLE and torch.cuda.get_device_capability()[0] >= 9
+
 
 class TestUpdateCrossAttnCache(unittest.TestCase):
     def test_update_cross_attn_cache(self):
-
         # Create tensors
         # Cache: [B=2, H=1, S_max=4, D=4]
         cache = torch.zeros(2, 1, 4, 4, dtype=torch.float32)
@@ -101,7 +104,6 @@ def compiled_fn(pred, v1, v2, c):
         )
 
     def test_update_cross_attn_cache_export(self):
-
         # Create tensors
         # Cache: [B=2, H=1, S_max=4, D=4]
         cache = torch.zeros(2, 1, 4, 4, dtype=torch.float32)
@@ -154,7 +156,6 @@ def false_fn(v1, v2, cache):
         )
 
     def test_update_cross_attn_cache_different_shapes(self):
-
         # Test with different batch sizes and sequence lengths
         test_cases = [
             # (B, H, S_max, S, D)
@@ -190,7 +191,6 @@ def fn(v, c):
         )
 
     def test_update_cross_attn_cache_full_sequence(self):
-
         # Cache: [B=2, H=1, S_max=4, D=4]
         cache = torch.zeros(2, 1, 4, 4, dtype=torch.float32)
         # Value: [B=2, H=1, S=4, D=4] (S == S_max)
@@ -207,7 +207,9 @@ def fn(v, c):
             cache, value, msg="Cache not fully updated when S == S_max"
         )
 
-    @unittest.skipUnless(CUDA_AVAILABLE, "CUDA not available")
+    @unittest.skipUnless(
+        CUDA_CC_COMPATIBLE, "Requires CUDA with compute capability >= 9.0"
+    )
     def test_alias_and_update_cross_attn_cache_with_cond_triton(self):
         """Test combining alias and update_cross_attn_cache ops with torch.cond,
         lowered to Triton on CUDA. True branch uses alias, false branch uses
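
Note on the `_unsafe_reset_threadpool` workaround wired into `SDPATestCommon.setUp` above: a minimal sketch of the same pattern in isolation, assuming an ExecuTorch build with pybindings enabled. The test class name is hypothetical; the thread count of 3 mirrors the diff, not a hard requirement.

import unittest

import torch

from executorch.extension.pybindings.portable_lib import _unsafe_reset_threadpool


class MyKernelTest(unittest.TestCase):  # hypothetical test class for illustration
    def setUp(self):
        torch.manual_seed(42)
        # Resize the shared threadpool before the test body runs, so stale
        # OMP thread state left over from earlier tests in the same process
        # cannot trigger "OMP: Error #131: Thread identifier invalid".
        _unsafe_reset_threadpool(3)

    def test_smoke(self):
        # Any op that exercises the threadpool would go here.
        self.assertEqual(torch.ones(2).sum().item(), 2.0)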
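
The new `CUDA_CC_COMPATIBLE` gate relies on `torch.cuda.get_device_capability()`, which returns a `(major, minor)` tuple, e.g. `(9, 0)` on Hopper. A minimal sketch of the same gating pattern, with the threshold of 9 taken from the diff and a hypothetical test class:

import unittest

import torch

CUDA_AVAILABLE = torch.cuda.is_available()
# Short-circuit so get_device_capability() is only queried when a CUDA
# device actually exists; index [0] is the major compute capability.
CUDA_CC_COMPATIBLE = CUDA_AVAILABLE and torch.cuda.get_device_capability()[0] >= 9


class TritonLoweringTest(unittest.TestCase):  # hypothetical class for illustration
    @unittest.skipUnless(
        CUDA_CC_COMPATIBLE, "Requires CUDA with compute capability >= 9.0"
    )
    def test_runs_on_hopper_or_newer(self):
        x = torch.ones(4, device="cuda")
        self.assertEqual(x.sum().item(), 4.0)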
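
The gated test combines the alias and update_cross_attn_cache ops under `torch.cond`. For readers unfamiliar with the control-flow op, a minimal sketch of the branch structure it exercises, using pure stand-in branches instead of the custom ops (all names here are illustrative, not the test's actual code):

import torch


def true_fn(v1, v2):
    # Stand-in for the alias branch: return a copy of the first operand
    # (cond branch outputs must not alias their inputs when exported).
    return v1.clone()


def false_fn(v1, v2):
    # Stand-in for the cache-update branch.
    return v2.clone()


def fn(pred, v1, v2):
    # Both branches receive the same operands and must return tensors of
    # matching shape and dtype.
    return torch.cond(pred, true_fn, false_fn, (v1, v2))


compiled_fn = torch.compile(fn)
out = compiled_fn(torch.tensor(True), torch.ones(2, 4), torch.zeros(2, 4))
assert torch.equal(out, torch.ones(2, 4))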