1 change: 1 addition & 0 deletions extension/llm/custom_ops/TARGETS
@@ -19,6 +19,7 @@ runtime.python_test(
],
deps = [
"//caffe2:torch",
"//executorch/extension/pybindings:portable_lib",

Copilot AI Jan 16, 2026


The TARGETS file is missing a corresponding dependency update for test_update_cross_attn_cache. While test_update_cross_attn_cache.py doesn't currently use _unsafe_reset_threadpool, the file is importing from the same module family and may benefit from the same OMP error workaround that was applied to test_sdpa_with_kv_cache and test_quantized_sdpa. Consider adding the dependency if similar threading issues arise in the future.

],
)

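Referring to the review comment above: if the same OMP flakiness ever surfaces in test_update_cross_attn_cache.py, the workaround from test_sdpa_with_kv_cache.py could be mirrored there. A minimal, hypothetical sketch (not part of this PR), assuming //executorch/extension/pybindings:portable_lib is also added to that test target's deps:

# Hypothetical sketch only; not part of this diff.
import unittest

from executorch.extension.pybindings.portable_lib import _unsafe_reset_threadpool


class TestUpdateCrossAttnCache(unittest.TestCase):
    def setUp(self):
        # Shrink the threadpool to avoid the flaky
        # "OMP: Error #131: Thread identifier invalid" failure, as done in
        # test_sdpa_with_kv_cache.py and test_quantized_sdpa.py.
        _unsafe_reset_threadpool(3)
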
14 changes: 5 additions & 9 deletions extension/llm/custom_ops/test_sdpa_with_kv_cache.py
@@ -12,6 +12,7 @@
import torch.nn.functional as F

from executorch.extension.llm.custom_ops import custom_ops # noqa
from executorch.extension.pybindings.portable_lib import _unsafe_reset_threadpool


def is_fbcode():
@@ -45,7 +46,6 @@ def _sdpa_with_kv_cache_ref(q, k, v, k_cache, v_cache, attn_mask, start_pos, seq


class SDPATest(unittest.TestCase):

def setUp(self):
torch.manual_seed(42)
self.k_cache = torch.zeros((1, 10, 8, 4))
@@ -233,7 +233,6 @@ def test_sdpa_with_cache_no_mqa_4(self):


class SDPAWithAttentionMaskTest(SDPATest):

def setUp(self):
SDPATest.setUp(self)
self.mask = torch.full(
@@ -244,7 +243,6 @@ def setUp(self):


class SDPAWithAttentionMaskLongSequenceTest(SDPATest):

def setUp(self):
SDPATest.setUp(self)
max_context_len = 700
@@ -276,14 +274,12 @@ def setUp(self):


class SDPAWithCausalTest(SDPATest):

def setUp(self):
SDPATest.setUp(self)
self.is_causal = True


class SDPAWithDynamicShapeTest(unittest.TestCase):

def setUp(self):
torch.manual_seed(42)
self.k_cache = torch.zeros((1, 10, 8, 4))
@@ -346,7 +342,6 @@ def test_sdpa_with_cache_dynamic_shape_4(self):


class SDPATestWithMQA(unittest.TestCase):

def setup_caches(self):
self.k_cache = torch.zeros((1, 5, self.n_heads_kv, 4))
self.v_cache = torch.zeros((1, 5, self.n_heads_kv, 4))
@@ -415,7 +410,6 @@ def test_sdpa_with_cache_mqa_3(self):


class SDPATestCommon(unittest.TestCase):

def setup_caches(self):
self.k_cache = torch.zeros(
(self.n_batch, self.max_seq_len, self.n_heads_kv, self.head_dim)
@@ -437,6 +431,10 @@ def setUp(self):
self.head_dim = 128
self.max_seq_len = 2048
self.setup_caches()
# This setting is needed to keep this test from being flaky due to the OMP
# error "OMP: Error #131: Thread identifier invalid".
# See also test_quantized_sdpa.py for the same workaround.
_unsafe_reset_threadpool(3)

def _scale_tensor(self, tensor, min_value, max_value, scale=True):
normalized_tensor = (tensor - tensor.min()) / (tensor.max() - tensor.min())
@@ -532,7 +530,6 @@ def _test_sdpa_common(


class SDPATestForLargeSeqLength(SDPATestCommon):

def test_sdpa_with_cache_seq_len_130(self):
n_heads_kv = 8
n_heads_q = 8
@@ -579,7 +576,6 @@ def test_sdpa_with_cache_seq_len_llava_example_gqa(self):


class SDPATestForSpeculativeDecode(SDPATestCommon):

def test_sdpa_with_cache_seq_len_130(self):
n_heads_kv = 32
n_heads_q = 32
12 changes: 7 additions & 5 deletions extension/llm/custom_ops/test_update_cross_attn_cache.py
@@ -14,10 +14,13 @@
# Check CUDA availability once at module level
CUDA_AVAILABLE = torch.cuda.is_available()

# Check if CUDA device has compatible compute capability for Triton kernels
# Minimum CC 9.0 (Hopper) required for current PyTorch/Triton build
CUDA_CC_COMPATIBLE = CUDA_AVAILABLE and torch.cuda.get_device_capability()[0] >= 9

Copilot AI Jan 20, 2026


The compute capability check may raise an IndexError when CUDA is available but no GPU device is present. Consider adding a device count check or wrapping in a try-except block. For example, the condition should handle the case where torch.cuda.device_count() is 0.

Suggested change
-CUDA_CC_COMPATIBLE = CUDA_AVAILABLE and torch.cuda.get_device_capability()[0] >= 9
+CUDA_CC_COMPATIBLE = (
+    CUDA_AVAILABLE
+    and torch.cuda.device_count() > 0
+    and torch.cuda.get_device_capability()[0] >= 9
+)

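As an alternative to the suggested condition, the try/except route the comment mentions could look roughly like this (a hypothetical sketch, not part of this PR, assuming torch and CUDA_AVAILABLE from the top of the module):

# Hypothetical try/except variant of the capability check; not part of this diff.
try:
    CUDA_CC_COMPATIBLE = (
        CUDA_AVAILABLE and torch.cuda.get_device_capability()[0] >= 9
    )
except Exception:
    # No usable CUDA device (e.g. torch.cuda.device_count() == 0);
    # treat the environment as incompatible so the Triton tests skip.
    CUDA_CC_COMPATIBLE = False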


class TestUpdateCrossAttnCache(unittest.TestCase):
def test_update_cross_attn_cache(self):

# Create tensors
# Cache: [B=2, H=1, S_max=4, D=4]
cache = torch.zeros(2, 1, 4, 4, dtype=torch.float32)
@@ -101,7 +104,6 @@ def compiled_fn(pred, v1, v2, c):
)

def test_update_cross_attn_cache_export(self):

# Create tensors
# Cache: [B=2, H=1, S_max=4, D=4]
cache = torch.zeros(2, 1, 4, 4, dtype=torch.float32)
@@ -154,7 +156,6 @@ def false_fn(v1, v2, cache):
)

def test_update_cross_attn_cache_different_shapes(self):

# Test with different batch sizes and sequence lengths
test_cases = [
# (B, H, S_max, S, D)
@@ -190,7 +191,6 @@ def fn(v, c):
)

def test_update_cross_attn_cache_full_sequence(self):

# Cache: [B=2, H=1, S_max=4, D=4]
cache = torch.zeros(2, 1, 4, 4, dtype=torch.float32)
# Value: [B=2, H=1, S=4, D=4] (S == S_max)
@@ -207,7 +207,9 @@ def fn(v, c):
cache, value, msg="Cache not fully updated when S == S_max"
)

@unittest.skipUnless(CUDA_AVAILABLE, "CUDA not available")
@unittest.skipUnless(
CUDA_CC_COMPATIBLE, "Requires CUDA with compute capability >= 9.0"
)
def test_alias_and_update_cross_attn_cache_with_cond_triton(self):
"""Test combining alias and update_cross_attn_cache ops with torch.cond,
lowered to Triton on CUDA. True branch uses alias, false branch uses
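
For context on the torch.cond pattern the docstring above describes (the rest of the test is collapsed in this view), a generic, hypothetical sketch with two branches writing into a cache follows; the real test uses the custom alias and update_cross_attn_cache ops, which are not reproduced here:

# Illustrative torch.cond sketch only; stand-ins for the custom ops used in the test.
import torch


def true_branch(value, cache):
    # Stand-in for the alias branch: hand the cache back unchanged.
    return cache.clone()


def false_branch(value, cache):
    # Stand-in for the cache-update branch: copy `value` into the first S slots.
    out = cache.clone()
    out[:, :, : value.shape[2], :] = value
    return out


cache = torch.zeros(2, 1, 4, 4)  # [B, H, S_max, D], matching the test's shapes
value = torch.ones(2, 1, 2, 4)   # [B, H, S, D]
pred = torch.tensor(False)
result = torch.cond(pred, true_branch, false_branch, (value, cache))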