diff --git a/TensorRT-Edge-LLM b/TensorRT-Edge-LLM new file mode 120000 index 0000000000..3cde4f042b --- /dev/null +++ b/TensorRT-Edge-LLM @@ -0,0 +1 @@ +/workspace/TensorRT-Edge-LLM \ No newline at end of file diff --git a/examples/dynamo/attention_plugin_example.py b/examples/dynamo/attention_plugin_example.py new file mode 100644 index 0000000000..d0498eeeaa --- /dev/null +++ b/examples/dynamo/attention_plugin_example.py @@ -0,0 +1,756 @@ +""" +.. _attention_plugin_example: + +Custom Attention Plugin with KV Cache Management +================================================= + +This example demonstrates how to use a custom TensorRT AttentionPlugin that implements +efficient multi-head attention with Rotary Position Embedding (RoPE) and KV cache management +for autoregressive generation in Large Language Models (LLMs). + +**Plugin Library:** + +This example uses a custom TensorRT plugin shared library (``libNvInfer_edgellm_plugin.so``) +that replaces standard transformer attention operations and RoPE computations with optimized +CUDA kernels. The plugin source code is available at: + +https://github.com/chohk88/TensorRT-Edge-LLM/tree/feature/torch-tensorrt-python-runtime + +Build instructions and implementation details can be found in the repository above. +This implementation has been verified with TensorRT-Edge-LLM release 0.4.0. + +**Key Features:** + +- **Dual Kernel Support:** + + - **FMHA (Fused Multi-Head Attention)** for context phase when ``seq_len > 1`` (processing multiple tokens) + - **XQA (Extended Query Attention)** for decode phase when ``seq_len = 1`` (single token generation) + +- **KV Cache Management:** Efficiently manages key-value cache for autoregressive generation +- **Validated Accuracy:** Matches PyTorch's ``scaled_dot_product_attention`` with cosine similarity ≥ 0.99 (the threshold asserted by the tests in this example) +- **Grouped Query Attention (GQA):** Supports efficient attention with fewer KV heads + +**What This Example Tests:** + +1. 
**XQA Kernel (seq_len=1):** Single token generation, with and without past context +2. **FMHA Kernel (seq_len>1):** Context processing with multiple tokens +3. **Multi-Step Generation:** Realistic LLM scenario - process prompt (FMHA), then generate tokens (XQA) +4. **Perfect Accuracy:** All tests achieve ``cosine_similarity ≥ 0.99`` with PyTorch SDPA + +**Installation Requirements:** + +.. code-block:: bash + + pip install torch torch_tensorrt tensorrt + +Build the AttentionPlugin shared library following instructions at the GitHub repository above. +The compiled library should be located at: ``/path/to/tensorrt-edgellm/build/libNvInfer_edgellm_plugin.so`` +""" + +# %% +# Imports and Setup +# ----------------- + +import ctypes +import os +from typing import Tuple + +import numpy as np +import tensorrt as trt +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch_tensorrt +from torch_tensorrt.dynamo.conversion import ( + ConversionContext, + dynamo_tensorrt_converter, +) +from torch_tensorrt.dynamo.conversion.converter_utils import get_trt_tensor + +# %% +# Enable plugin debug logging +# ---------------------------- +os.environ["TRT_EDGELLM_DEBUG_PLUGIN"] = "1" + +# %% +# Initialize CUDA and Load Plugin +# -------------------------------- +# CUDA must be initialized before loading the TensorRT plugin library + +print("Initializing CUDA context...") +DEVICE = torch.device("cuda:0") +_ = torch.zeros(1, device=DEVICE) # Initialize CUDA +print(f"CUDA initialized on {DEVICE}\n") + +PLUGIN_PATH = os.path.join( + os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))), + "TensorRT-Edge-LLM", + "build", + "libNvInfer_edgellm_plugin.so", +) +ctypes.CDLL(PLUGIN_PATH) +print(f"Loaded plugin: {PLUGIN_PATH}\n") + +# %% +# Model Configuration +# ------------------- +# These hyperparameters match typical LLM architectures with Grouped Query Attention (GQA) + +BATCH_SIZE = 1 +NUM_Q_HEADS = 4 # Number of query heads +NUM_KV_HEADS = 2 
# Number of key/value heads (GQA: fewer than query heads) +HEAD_DIM = 64 # Dimension per head +KV_CACHE_CAPACITY = 128 # Maximum sequence length +HIDDEN_DIM = NUM_Q_HEADS * HEAD_DIM # 256 +NUM_KV_GROUPS = NUM_Q_HEADS // NUM_KV_HEADS # 2 + +DTYPE = torch.float16 + +# %% +# RoPE (Rotary Position Embedding) Utilities +# ------------------------------------------- +# RoPE encodes positional information through rotation in complex space + + +def precompute_rope(head_dim: int, max_seq_len: int = 128, base: float = 10000.0): + """ + Precompute RoPE cos/sin for all positions. + + Returns: + Tensor of shape [1, max_seq_len, head_dim] in FP32 + """ + inv_freq = 1.0 / ( + base ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim) + ) + t = torch.arange(max_seq_len, dtype=torch.float32) + freqs = torch.outer(t, inv_freq) + cos = freqs.cos() + sin = freqs.sin() + rope = torch.cat([cos, sin], dim=-1) + return rope.unsqueeze(0).to(DEVICE) + + +def apply_rope(x, rope_cache, position_ids): + """ + Apply RoPE to input tensor. 
+ + Args: + x: [batch, num_heads, seq_len, head_dim] + rope_cache: [1, max_seq_len, head_dim] + position_ids: [seq_len] position indices + """ + seq_len = x.shape[2] + rope = rope_cache[:, position_ids, :] # [1, seq_len, head_dim] + rope = rope.unsqueeze(1) # [1, 1, seq_len, head_dim] + + half_dim = x.shape[-1] // 2 + cos = rope[..., :half_dim] + sin = rope[..., half_dim:] + + x_fp32 = x.float() + x1 = x_fp32[..., :half_dim] + x2 = x_fp32[..., half_dim:] + + rotated = torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1) + + return rotated.to(x.dtype) # cast back to the input dtype instead of forcing fp16 + + +def repeat_kv(x, n_rep): + """Repeat KV heads for Grouped Query Attention""" + if n_rep == 1: + return x + bs, n_kv_heads, slen, head_dim = x.shape + return ( + x[:, :, None, :, :] + .expand(bs, n_kv_heads, n_rep, slen, head_dim) + .reshape(bs, n_kv_heads * n_rep, slen, head_dim) + ) + + +# %% +# PyTorch SDPA Reference Implementation +# ------------------------------------- +# This serves as the ground truth for correctness validation + + +class SDPAModel(nn.Module): + """Reference attention using PyTorch's scaled_dot_product_attention""" + + def __init__(self): + super().__init__() + self.num_q_heads = NUM_Q_HEADS + self.num_kv_heads = NUM_KV_HEADS + self.head_dim = HEAD_DIM + self.num_key_value_groups = NUM_KV_GROUPS + + self.qkv = nn.Linear( + HIDDEN_DIM, HIDDEN_DIM + 2 * NUM_KV_HEADS * HEAD_DIM, bias=True + ) + self.out = nn.Linear(HIDDEN_DIM, HIDDEN_DIM, bias=False) + + def forward(self, x, kv_cache, ctx_len_tensor, rope): + """ + Args: + x: [batch, seq_len, hidden_dim] + kv_cache: [batch, 2, num_kv_heads, capacity, head_dim] + ctx_len_tensor: [batch] - total context length including current tokens + rope: [1, max_seq_len, head_dim] + """ + batch_size, seq_len, _ = x.shape + ctx_len = ctx_len_tensor[0].item() + past_len = ctx_len - seq_len + + # QKV projection + qkv = self.qkv(x) + q_size = self.num_q_heads * self.head_dim + kv_size = self.num_kv_heads * self.head_dim + query, key, value = 
torch.split(qkv, [q_size, kv_size, kv_size], dim=-1) + + # Reshape to multi-head format + query = query.view( + batch_size, seq_len, self.num_q_heads, self.head_dim + ).transpose(1, 2) + key = key.view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose( + 1, 2 + ) + value = value.view( + batch_size, seq_len, self.num_kv_heads, self.head_dim + ).transpose(1, 2) + + # Apply RoPE + position_ids = torch.arange(past_len, past_len + seq_len, device=x.device) + query = apply_rope(query, rope, position_ids) + key = apply_rope(key, rope, position_ids) + + # Update KV cache + kv_cache[:, 0, :, past_len : past_len + seq_len, :] = key + kv_cache[:, 1, :, past_len : past_len + seq_len, :] = value + + # Get full K/V from cache + full_key = kv_cache[:, 0, :, :ctx_len, :] + full_value = kv_cache[:, 1, :, :ctx_len, :] + + # Expand for GQA + full_key = repeat_kv(full_key, self.num_key_value_groups) + full_value = repeat_kv(full_value, self.num_key_value_groups) + + # Scaled dot-product attention + is_causal = seq_len > 1 + attn_out = F.scaled_dot_product_attention( + query.contiguous(), + full_key.contiguous(), + full_value.contiguous(), + attn_mask=None, + dropout_p=0.0, + is_causal=is_causal, + ) + + # Output projection + attn_out = ( + attn_out.transpose(1, 2).contiguous().view(batch_size, seq_len, HIDDEN_DIM) + ) + output = self.out(attn_out) + + return output, kv_cache + + +# %% +# TensorRT Plugin Integration +# ---------------------------- +# Register custom operation and converter for TensorRT plugin + + +def register_plugin_op(): + """ + Register custom attention operation. 
+ + Note: The release version of TensorRT-Edge-LLM requires 5 inputs: + - qkv: [B, S, (Hq+Hk+Hv)*D] fused QKV tensor + - kv: [B, 2, Hkv, Capacity, D] KV cache tensor + - ctx_len: [B] context length per batch + - rope: [S, D] rotary position encoding + - kv_cache_start_idx: [B] starting index in KV cache (required for release version) + """ + + @torch.library.custom_op("tensorrt_edge_llm::xqa_attn", mutates_args=()) + def attn( + qkv: torch.Tensor, + kv: torch.Tensor, + ctx_len: torch.Tensor, + rope: torch.Tensor, + kv_cache_start_idx: torch.Tensor, # Required 5th input for release plugin + nq: int, + nkv: int, + d: int, + ) -> Tuple[torch.Tensor, torch.Tensor]: + batch_size = qkv.shape[0] + seq_len = qkv.shape[1] + attn_out = torch.zeros( + batch_size, seq_len, nq, d, dtype=qkv.dtype, device=qkv.device + ) + updated_kv = kv.clone() + return attn_out, updated_kv + + @torch.library.register_fake("tensorrt_edge_llm::xqa_attn") + def _(qkv, kv, ctx_len, rope, kv_cache_start_idx, nq, nkv, d): + batch_size = qkv.shape[0] + seq_len = qkv.shape[1] + attn_out = torch.empty( + batch_size, seq_len, nq, d, dtype=qkv.dtype, device=qkv.device + ) + updated_kv = kv.clone() + return attn_out, updated_kv + + +register_plugin_op() + + +@dynamo_tensorrt_converter( + torch.ops.tensorrt_edge_llm.xqa_attn.default, supports_dynamic_shapes=True +) +def convert_attn(ctx: ConversionContext, target, args, kwargs, name): + """ + Convert PyTorch custom op to TensorRT plugin. 
+ + Release version of TensorRT-Edge-LLM requires 5 inputs: + - qkv, kv, ctx_len, rope, kv_cache_start_idx + + Plugin fields for release version: + - num_q_heads, num_kv_heads, head_size, enable_tree_attention, enable_delta_kv_output + """ + # args: qkv, kv, ctx_len, rope, kv_cache_start_idx, nq, nkv, d + qkv, kv, ctx_len, rope, kv_cache_start_idx, nq, nkv, d = args[:8] + + # Get plugin creator + creator = trt.get_plugin_registry().get_plugin_creator("AttentionPlugin", "1", "") + if creator is None: + raise RuntimeError("AttentionPlugin not found! Make sure plugin is loaded.") + + # Plugin fields for TensorRT-Edge-LLM AttentionPlugin + # Required: num_q_heads, num_kv_heads, head_size, enable_tree_attention + # enable_delta_kv_output=1 enables delta KV output for Python/torch_tensorrt compatibility + field_list = [ + trt.PluginField( + field_name, np.array([field_val], dtype=np.int32), trt.PluginFieldType.INT32 + ) + for field_name, field_val in [ + ("num_q_heads", nq), + ("num_kv_heads", nkv), + ("head_size", d), + ("enable_tree_attention", 0), + ("enable_delta_kv_output", 1), + ] + ] + + fields = trt.PluginFieldCollection(field_list) + plugin = creator.create_plugin(name, fields) + + if plugin is None: + raise RuntimeError("Failed to create plugin") + + # 5 inputs for release version: qkv, kv, ctx_len, rope, kv_cache_start_idx + inputs = [ + ( + get_trt_tensor(ctx, i, f"{name}_i{idx}") + if not isinstance(i, trt.ITensor) + else i + ) + for idx, i in enumerate([qkv, kv, ctx_len, rope, kv_cache_start_idx]) + ] + + # Handle kv_cache_start_idx shape if needed (squeeze if [B, 1] -> [B]) + if len(inputs[4].shape) == 2 and inputs[4].shape[1] == 1: + shuffle_layer = ctx.net.add_shuffle(inputs[4]) + shuffle_layer.reshape_dims = (inputs[4].shape[0],) + inputs[4] = shuffle_layer.get_output(0) + + layer = ctx.net.add_plugin_v2(inputs, plugin) + + return layer.get_output(0), layer.get_output(1) + + +class PluginModel(nn.Module): + """Attention model using TensorRT plugin""" + 
+ def __init__(self): + super().__init__() + self.qkv = nn.Linear( + HIDDEN_DIM, HIDDEN_DIM + 2 * NUM_KV_HEADS * HEAD_DIM, bias=True + ) + self.out = nn.Linear(HIDDEN_DIM, HIDDEN_DIM, bias=False) + + def forward(self, x, kv_cache, ctx_len_tensor, rope): + bsz, seq_len, _ = x.shape + qkv = self.qkv(x) + + # kv_cache_start_idx: starting position in KV cache for each batch + # For normal inference, this is 0 (start from beginning) + kv_cache_start_idx = torch.zeros(bsz, dtype=torch.int32, device=x.device) + + # Custom plugin call (5 inputs for release version) + attn_out, updated_kv = torch.ops.tensorrt_edge_llm.xqa_attn.default( + qkv, + kv_cache, + ctx_len_tensor, + rope, + kv_cache_start_idx, + NUM_Q_HEADS, + NUM_KV_HEADS, + HEAD_DIM, + ) + + # Reshape from [B, S, num_heads, head_dim] to [B, S, hidden_dim] + attn_out = attn_out.reshape(bsz, seq_len, HIDDEN_DIM) + + return self.out(attn_out), updated_kv + + +# %% +# Test Functions +# -------------- + + +def test_case( + name: str, seq_len: int, has_past_context: bool, sdpa_model, trt_model, rope +): + """ + Run a single test case and validate correctness. + + Args: + name: Test case name + seq_len: Sequence length (1 for XQA, >1 for FMHA) + has_past_context: Whether to initialize KV cache with past tokens + sdpa_model: PyTorch SDPA reference model + trt_model: Compiled TensorRT model + rope: Precomputed RoPE cache + + Note: + With enable_delta_kv_output=1, TRT plugin outputs only the delta KV: + - Context Phase: [B, 2, H, seq_len, D] (newly processed tokens) + - Generation Phase: [B, 2, H, 1, D] (single new token) + Python runtime must merge this delta into the main KV cache. 
+ """ + print(f"\n{name}") + + # Determine context length + past_len = 10 if has_past_context else 0 + ctx_len = torch.tensor([past_len + seq_len], dtype=torch.int32, device=DEVICE) + + # Initialize KV caches + sdpa_kv = torch.zeros( + BATCH_SIZE, + 2, + NUM_KV_HEADS, + KV_CACHE_CAPACITY, + HEAD_DIM, + dtype=DTYPE, + device=DEVICE, + ) + trt_kv = torch.zeros( + BATCH_SIZE, + 2, + NUM_KV_HEADS, + KV_CACHE_CAPACITY, + HEAD_DIM, + dtype=DTYPE, + device=DEVICE, + ) + + # Add past context if needed + if has_past_context: + past_values = torch.randn( + BATCH_SIZE, 2, NUM_KV_HEADS, past_len, HEAD_DIM, dtype=DTYPE, device=DEVICE + ) + sdpa_kv[:, :, :, :past_len, :] = past_values + trt_kv[:, :, :, :past_len, :] = past_values + print(f" Input: {seq_len} new tokens + {past_len} past tokens in cache") + else: + print(f" Input: {seq_len} tokens (empty KV cache)") + + # Generate input tokens + x = torch.randn(BATCH_SIZE, seq_len, HIDDEN_DIM, dtype=DTYPE, device=DEVICE) + + # Run both models + with torch.no_grad(): + sdpa_out, sdpa_kv_new = sdpa_model(x, sdpa_kv, ctx_len, rope) + trt_out, trt_kv_delta = trt_model(x, trt_kv, ctx_len, rope) + + # TRT plugin with enable_delta_kv_output=1 returns only delta KV + # Merge delta into main KV cache at the correct position + delta_seq_len = trt_kv_delta.shape[3] # Should be seq_len + trt_kv[:, :, :, past_len : past_len + delta_seq_len, :] = trt_kv_delta + + # Compute similarities + attn_sim = F.cosine_similarity( + sdpa_out.flatten().float(), trt_out.flatten().float(), dim=0 + ).item() + + # Compare the newly updated portion of KV cache (after merge) + new_kv_sim = F.cosine_similarity( + sdpa_kv_new[:, :, :, past_len : past_len + seq_len, :].flatten().float(), + trt_kv[:, :, :, past_len : past_len + seq_len, :].flatten().float(), + dim=0, + ).item() + + # Determine which kernel was used + kernel_type = "XQA (decode)" if seq_len == 1 else "FMHA (context)" + + # Print results + print(f" Kernel Used: {kernel_type}") + print(f" Attention 
Output: cosine_similarity = {attn_sim:.6f}") + print(f" Updated KV Cache: cosine_similarity = {new_kv_sim:.6f}") + + # If there's past context, verify it's preserved in our main buffer + if has_past_context: + past_sim = F.cosine_similarity( + sdpa_kv_new[:, :, :, :past_len, :].flatten().float(), + trt_kv[:, :, :, :past_len, :].flatten().float(), + dim=0, + ).item() + print(f" Past KV Preserved: cosine_similarity = {past_sim:.6f}") + passed = attn_sim >= 0.99 and new_kv_sim >= 0.99 and past_sim >= 0.99 + else: + passed = attn_sim >= 0.99 and new_kv_sim >= 0.99 + + status = "PASS" if passed else "FAIL" + print(f" Result: {status}") + + return passed, attn_sim, new_kv_sim + + +# %% +# Main Execution +# -------------- + +if __name__ == "__main__": + print("\nCustom Attention Plugin - Correctness Validation") + + # Precompute RoPE + rope = precompute_rope(HEAD_DIM, KV_CACHE_CAPACITY) + + # Create models + print("\nCreating models...") + sdpa_model = SDPAModel().to(DEVICE).to(DTYPE).eval() + plugin_model = PluginModel().to(DEVICE).to(DTYPE).eval() + + # Share weights + plugin_model.qkv.weight.data.copy_(sdpa_model.qkv.weight.data) + plugin_model.qkv.bias.data.copy_(sdpa_model.qkv.bias.data) + plugin_model.out.weight.data.copy_(sdpa_model.out.weight.data) + print("Weights shared between models") + + # Compile with Torch-TensorRT (with dynamic shapes for seq_len) + print("\nCompiling with Torch-TensorRT...") + x_example = torch.randn(BATCH_SIZE, 1, HIDDEN_DIM, dtype=DTYPE, device=DEVICE) + kv_example = torch.zeros( + BATCH_SIZE, + 2, + NUM_KV_HEADS, + KV_CACHE_CAPACITY, + HEAD_DIM, + dtype=DTYPE, + device=DEVICE, + ) + ctx_example = torch.tensor([1], dtype=torch.int32, device=DEVICE) + + # Enable dynamic shapes for seq_len dimension + inputs_spec = [ + torch_tensorrt.Input( + min_shape=(BATCH_SIZE, 1, HIDDEN_DIM), + opt_shape=(BATCH_SIZE, 8, HIDDEN_DIM), + max_shape=(BATCH_SIZE, 32, HIDDEN_DIM), + dtype=DTYPE, + ), + kv_example, + ctx_example, + rope, + ] + + with 
torch_tensorrt.logging.errors(): + trt_model = torch_tensorrt.compile( + plugin_model, + inputs=inputs_spec, + enabled_precisions={torch.float32}, + use_explicit_typing=True, + use_fp32_acc=True, + min_block_size=1, + device=DEVICE, + ) + print("Compilation complete") + + # %% + # Run Test Cases + # -------------- + # Test all 4 combinations: {seq_len=1, seq_len>1} × {empty cache, with past} + + print("\nRunning Test Cases") + + results = [] + + # Test 1: Single token, empty cache (XQA kernel, cold start) + results.append( + test_case( + "Test 1: Single Token Generation (XQA) - Empty Cache", + seq_len=1, + has_past_context=False, + sdpa_model=sdpa_model, + trt_model=trt_model, + rope=rope, + ) + ) + + # Test 2: Single token, with past context (XQA kernel, typical decode) + results.append( + test_case( + "Test 2: Single Token Generation (XQA) - With Past Context", + seq_len=1, + has_past_context=True, + sdpa_model=sdpa_model, + trt_model=trt_model, + rope=rope, + ) + ) + + # Test 3: Multiple tokens, empty cache (FMHA kernel, prefill phase) + results.append( + test_case( + "Test 3: Context Processing (FMHA) - Empty Cache", + seq_len=16, + has_past_context=False, + sdpa_model=sdpa_model, + trt_model=trt_model, + rope=rope, + ) + ) + + # %% + # Multi-Step Generation Test + # --------------------------- + # Realistic test: Process initial context (FMHA), then generate tokens one by one (XQA) + # Note: With enable_delta_kv_output=1, we must merge delta KV into main buffer + + print("\nTest 4: Multi-Step Generation (FMHA -> XQA x 3)") + print("Simulating real LLM generation:") + print(" 1. Process initial prompt with FMHA (seq_len=16)") + print(" 2. 
Generate tokens one by one with XQA (seq_len=1)") + + # Step 1: Process initial prompt (FMHA) + initial_seq_len = 16 + x_init = torch.randn( + BATCH_SIZE, initial_seq_len, HIDDEN_DIM, dtype=DTYPE, device=DEVICE + ) + ctx_len_init = torch.tensor([initial_seq_len], dtype=torch.int32, device=DEVICE) + + sdpa_kv_multi = torch.zeros( + BATCH_SIZE, + 2, + NUM_KV_HEADS, + KV_CACHE_CAPACITY, + HEAD_DIM, + dtype=DTYPE, + device=DEVICE, + ) + trt_kv_multi = torch.zeros( + BATCH_SIZE, + 2, + NUM_KV_HEADS, + KV_CACHE_CAPACITY, + HEAD_DIM, + dtype=DTYPE, + device=DEVICE, + ) + + with torch.no_grad(): + sdpa_out_init, sdpa_kv_multi = sdpa_model( + x_init, sdpa_kv_multi, ctx_len_init, rope + ) + trt_out_init, trt_kv_delta = trt_model(x_init, trt_kv_multi, ctx_len_init, rope) + + # Merge delta KV into main buffer (context phase: delta has shape [B, 2, H, seq_len, D]) + delta_len = trt_kv_delta.shape[3] + trt_kv_multi[:, :, :, :delta_len, :] = trt_kv_delta + + init_sim = F.cosine_similarity( + sdpa_out_init.flatten().float(), trt_out_init.flatten().float(), dim=0 + ).item() + + print(f"\nStep 1: Initial prompt (FMHA, seq_len={initial_seq_len})") + print(f" Similarity: {init_sim:.6f}") + + # Step 2: Generate tokens one by one (XQA) + num_gen_tokens = 3 + all_passed_multi = init_sim > 0.99 + current_pos = initial_seq_len # Track current position in KV cache + + for gen_step in range(num_gen_tokens): + current_ctx_len = initial_seq_len + gen_step + 1 + x_gen = torch.randn(BATCH_SIZE, 1, HIDDEN_DIM, dtype=DTYPE, device=DEVICE) + ctx_len_gen = torch.tensor([current_ctx_len], dtype=torch.int32, device=DEVICE) + + with torch.no_grad(): + sdpa_out_gen, sdpa_kv_multi = sdpa_model( + x_gen, sdpa_kv_multi, ctx_len_gen, rope + ) + trt_out_gen, trt_kv_delta = trt_model( + x_gen, trt_kv_multi, ctx_len_gen, rope + ) + + # Merge delta KV into main buffer (generation phase: delta has shape [B, 2, H, 1, D]) + trt_kv_multi[:, :, :, current_pos : current_pos + 1, :] = trt_kv_delta + current_pos += 1 + 
+ gen_sim = F.cosine_similarity( + sdpa_out_gen.flatten().float(), trt_out_gen.flatten().float(), dim=0 + ).item() + + kv_sim_gen = F.cosine_similarity( + sdpa_kv_multi[:, :, :, :current_ctx_len, :].flatten().float(), + trt_kv_multi[:, :, :, :current_ctx_len, :].flatten().float(), + dim=0, + ).item() + + passed = gen_sim > 0.99 and kv_sim_gen > 0.99 + all_passed_multi = all_passed_multi and passed + + print(f"\nStep {gen_step + 2}: Generate token {gen_step + 1} (XQA, seq_len=1)") + print(f" Attn similarity: {gen_sim:.6f}") + print(f" KV similarity: {kv_sim_gen:.6f}") + + results.append( + ( + all_passed_multi, + 1.0 if all_passed_multi else 0.0, + 1.0 if all_passed_multi else 0.0, + ) + ) + + print(f"\nResult: {'PASS - All steps matched!' if all_passed_multi else 'FAIL'}") + + # %% + # Summary + # ------- + + print("\nSUMMARY") + + test_names = [ + "Test 1: XQA - Empty Cache", + "Test 2: XQA - With Past", + "Test 3: FMHA - Empty Cache", + "Test 4: Multi-Step (FMHA->XQA)", + ] + + for name, (passed, attn_sim, kv_sim) in zip(test_names, results): + status = "PASS" if passed else "FAIL" + print(f"{name}: {status}") + print(f" Attention: {attn_sim:.4f}, KV Cache: {kv_sim:.4f}") + + all_passed = all(r[0] for r in results) + + if all_passed: + print("SUCCESS: All tests passed!") + print("Both FMHA and XQA kernels work correctly") + print("KV cache management is accurate") + print("Perfect agreement with PyTorch SDPA (cosine similarity >= 0.99)") + else: + print("FAILURE: Some tests failed") diff --git a/examples/dynamo/end_to_end_llm_generation_example.py b/examples/dynamo/end_to_end_llm_generation_example.py new file mode 100644 index 0000000000..dd2a23f721 --- /dev/null +++ b/examples/dynamo/end_to_end_llm_generation_example.py @@ -0,0 +1,404 @@ +""" +End-to-End LLM Generation Example with TensorRT Attention Plugin + +This example demonstrates how to use the TensorRT attention plugin for +efficient LLM inference with KV caching. 
+ +The plugin utilities are shared with tools/llm/run_llm.py for consistency. + +This implementation has been verified with TensorRT-Edge-LLM release 0.4.0. +The plugin source code is available at: +https://github.com/chohk88/TensorRT-Edge-LLM/tree/feature/torch-tensorrt-python-runtime +""" + +import os +import sys +import time + +import torch +import torch.nn as nn +from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer + +# Add tools/llm to path for shared utilities +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../../tools/llm")) + +from plugin_utils import ( + LLMPluginWrapper, + PluginAttention, + benchmark_plugin_generation, + compile_plugin_model, + create_kv_caches, + generate_with_plugin, + get_plugin_config, + get_plugin_rope_cache, + load_plugin, + register_plugin_op, + replace_attention_with_plugin, + set_plugin_config_from_model, +) + +# Configuration +MODEL_NAME = "Qwen/Qwen2.5-0.5B-Instruct" +MAX_SEQ_LEN = 2048 +DTYPE = torch.float16 +DEVICE = torch.device("cuda:0") + +# Load the plugin +load_plugin() +register_plugin_op() + + +# ----------------------------------------------------------------------------- +# Backward Compatibility Exports +# ----------------------------------------------------------------------------- + +# These are exported for backward compatibility with any code that imports +# from this module directly. + +# Re-export Qwen2Wrapper as an alias for LLMPluginWrapper +Qwen2Wrapper = LLMPluginWrapper + + +# Re-export replace_attention for backward compatibility +def replace_attention(model, config): + """ + Replace attention modules with plugin attention. + + This is a backward-compatible wrapper around replace_attention_with_plugin. + """ + return replace_attention_with_plugin(model, config, MAX_SEQ_LEN, DEVICE, DTYPE) + + +def compile_model(model, input_ids, position_ids, kv_caches, ctx_len): + """ + Compile a model for TensorRT inference. 
+ + This is a backward-compatible wrapper that extracts config from the model. + """ + # Get config from the wrapped model + if hasattr(model, "model"): + inner_model = model.model + if hasattr(inner_model, "config"): + config = inner_model.config + else: + config = inner_model.model.config + else: + config = model.config + + return compile_plugin_model(model, config, MAX_SEQ_LEN, DEVICE, DTYPE) + + +# Global config for backward compatibility with converter +TARGET_CONFIG = None + + +# ----------------------------------------------------------------------------- +# Helpers +# ----------------------------------------------------------------------------- + + +def apply_repetition_penalty(logits, generated_ids, penalty): + """Apply repetition penalty to logits.""" + if penalty == 1.0: + return logits + + score = torch.gather(logits, 1, generated_ids) + score = torch.where(score < 0, score * penalty, score / penalty) + logits.scatter_(1, generated_ids, score) + return logits + + +# ----------------------------------------------------------------------------- +# Benchmarking +# ----------------------------------------------------------------------------- + + +def benchmark_generation(model_func, isl, osl, config, run_name="Model"): + """ + Benchmark generation with the plugin model. + + This wraps benchmark_plugin_generation for backward compatibility. 
+ """ + return benchmark_plugin_generation( + model_func, config, isl, osl, MAX_SEQ_LEN, DEVICE, DTYPE, run_name + ) + + +def run_pytorch_benchmark_manual(model, config, isl, osl): + """Run PyTorch benchmark with manual loop (no KV cache).""" + input_ids = torch.randint(0, config.vocab_size, (1, isl), device=DEVICE) + + torch.cuda.synchronize() + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + start_event.record() + + with torch.no_grad(): + generated_ids = input_ids + + for _ in range(osl): + outputs = model(generated_ids, use_cache=False) + next_token = torch.argmax(outputs.logits[:, -1, :], dim=-1).unsqueeze(0) + generated_ids = torch.cat([generated_ids, next_token], dim=1) + + end_event.record() + torch.cuda.synchronize() + + elapsed_ms = start_event.elapsed_time(end_event) + print( + f"PyTorch (Manual - No Cache) | ISL: {isl}, OSL: {osl} | Total Time: {elapsed_ms:.2f} ms | Tokens/sec: {osl / (elapsed_ms / 1000.0):.2f}" + ) + return elapsed_ms + + +def run_pytorch_benchmark_generate(model, config, isl, osl): + """Run PyTorch benchmark with model.generate() API.""" + input_ids = torch.randint(0, config.vocab_size, (1, isl), device=DEVICE) + + torch.cuda.synchronize() + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + start_event.record() + + with torch.no_grad(): + _ = model.generate( + input_ids, + max_new_tokens=osl, + min_new_tokens=osl, + do_sample=False, + use_cache=True, + pad_token_id=config.eos_token_id, + ) + + end_event.record() + torch.cuda.synchronize() + + elapsed_ms = start_event.elapsed_time(end_event) + print( + f"PyTorch (Generate) | ISL: {isl}, OSL: {osl} | Total Time: {elapsed_ms:.2f} ms | Tokens/sec: {osl / (elapsed_ms / 1000.0):.2f}" + ) + return elapsed_ms + + +def generate_reference(model, tokenizer, prompt, max_new_tokens=20): + """ + Generate reference output with PyTorch (greedy, no cache). 
+ """ + input_ids = tokenizer.encode(prompt, return_tensors="pt").to(DEVICE) + generated_ids = input_ids + + repetition_penalty = getattr(model.generation_config, "repetition_penalty", 1.0) + print( + f"DEBUG: Using repetition_penalty={repetition_penalty} for Reference Generation" + ) + + for _ in range(max_new_tokens): + current_seq_len = generated_ids.shape[1] + position_ids = torch.arange( + current_seq_len, dtype=torch.long, device=DEVICE + ).unsqueeze(0) + + outputs = model(generated_ids, position_ids=position_ids, use_cache=False) + next_token_logits = outputs.logits[:, -1, :] + + next_token_logits = apply_repetition_penalty( + next_token_logits, generated_ids, repetition_penalty + ) + next_token = torch.argmax(next_token_logits, dim=-1).unsqueeze(0) + + if next_token.item() == tokenizer.eos_token_id: + break + + generated_ids = torch.cat([generated_ids, next_token], dim=1) + + return tokenizer.decode(generated_ids[0], skip_special_tokens=True) + + +def verify_output(trt_model_func, model_pytorch, tokenizer, prompt, max_new_tokens=20): + """Verify TensorRT output matches PyTorch reference.""" + print(f"\nPrompt: '{prompt}'") + + # 1. PyTorch Reference Generation + print("\n=== PyTorch Reference Generation ===") + inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE) + input_ids = inputs.input_ids + + with torch.no_grad(): + pyt_outputs = generate_reference( + model_pytorch, tokenizer, prompt, max_new_tokens=max_new_tokens + ) + print(f"PyTorch Reference Text Output: {pyt_outputs}") + + with torch.no_grad(): + pyt_outputs_generate_ids = model_pytorch.generate( + input_ids, + max_new_tokens=max_new_tokens, + do_sample=False, + use_cache=True, + pad_token_id=tokenizer.eos_token_id, + ) + pyt_outputs_generate_text = tokenizer.decode( + pyt_outputs_generate_ids[0], skip_special_tokens=True + ) + print(f"PyTorch Generate Text Output: {pyt_outputs_generate_text}") + + pyt_text = pyt_outputs + print(f"PyTorch Output: {pyt_text}") + + # 2. 
TensorRT Plugin Generation + print("\n=== TensorRT Plugin Generation ===") + + repetition_penalty = getattr( + model_pytorch.generation_config, "repetition_penalty", 1.0 + ) + print( + f"DEBUG: Using repetition_penalty={repetition_penalty} for TensorRT Generation" + ) + + seq_len = input_ids.shape[1] + position_ids = torch.arange(seq_len, dtype=torch.long, device=DEVICE).unsqueeze(0) + + config = model_pytorch.config + kv_caches = create_kv_caches(config, MAX_SEQ_LEN, 1, DEVICE, DTYPE) + + generated_ids = input_ids + + # Prefill + ctx_len = torch.tensor([seq_len], dtype=torch.int32, device=DEVICE) + logits, kv_caches_delta = trt_model_func( + input_ids, position_ids, kv_caches, ctx_len + ) + + for i, delta in enumerate(kv_caches_delta): + seq_len_out = delta.shape[3] + kv_caches[i][:, :, :, :seq_len_out, :] = delta + + next_token_logits = logits[:, -1, :] + next_token_logits = apply_repetition_penalty( + next_token_logits, generated_ids, repetition_penalty + ) + next_token = torch.argmax(next_token_logits, dim=-1).unsqueeze(0) + + generated_ids = torch.cat([generated_ids, next_token], dim=1) + + # Decode + cur_pos = seq_len + + if next_token.item() != tokenizer.eos_token_id: + for _ in range(max_new_tokens - 1): + input_ids_step = next_token + position_ids_step = torch.tensor( + [[cur_pos]], dtype=torch.long, device=DEVICE + ) + ctx_len_step = torch.tensor([cur_pos + 1], dtype=torch.int32, device=DEVICE) + + logits, kv_caches_delta = trt_model_func( + input_ids_step, position_ids_step, kv_caches, ctx_len_step + ) + + for i, delta in enumerate(kv_caches_delta): + kv_caches[i][:, :, :, cur_pos : cur_pos + 1, :] = delta + + next_token_logits = logits[:, -1, :] + next_token_logits = apply_repetition_penalty( + next_token_logits, generated_ids, repetition_penalty + ) + next_token = torch.argmax(next_token_logits, dim=-1).unsqueeze(0) + + if next_token.item() == tokenizer.eos_token_id: + break + + generated_ids = torch.cat([generated_ids, next_token], dim=1) + cur_pos += 
1 + + trt_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True) + print(f"TensorRT Output: {trt_text}") + + # 3. Comparison + print("\n=== Comparison ===") + if pyt_text == trt_text: + print("SUCCESS: Outputs match exactly!") + else: + print("FAILURE: Outputs differ.") + print(f"PyTorch: {pyt_text}") + print(f"TensorRT: {trt_text}") + + +# ----------------------------------------------------------------------------- +# Main +# ----------------------------------------------------------------------------- + +if __name__ == "__main__": + torch.manual_seed(42) + + print(f"Loading {MODEL_NAME}...") + config = AutoConfig.from_pretrained(MODEL_NAME) + tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) + + # Set global config for backward compatibility + # Note: TARGET_CONFIG is defined at module level for backward compatibility + globals()["TARGET_CONFIG"] = config + + # Set plugin config + set_plugin_config_from_model(config, MAX_SEQ_LEN) + + # 1. PyTorch Model + model_pytorch = AutoModelForCausalLM.from_pretrained( + MODEL_NAME, + torch_dtype=DTYPE, + ).to(DEVICE) + model_pytorch.eval() + + # 2. TensorRT Plugin Model + model_trt = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=DTYPE).to( + DEVICE + ) + model_trt.eval() + + model_trt = replace_attention(model_trt, config) + wrapper = LLMPluginWrapper(model_trt) + + # Compilation + print("Compiling TensorRT model...") + + dummy_input_ids = torch.tensor([[1, 2, 3]], device=DEVICE) + dummy_pos_ids = torch.tensor([[0, 1, 2]], device=DEVICE) + dummy_ctx_len = torch.tensor([3], dtype=torch.int32, device=DEVICE) + dummy_kvs = create_kv_caches(config, MAX_SEQ_LEN, 1, DEVICE, DTYPE) + + trt_model_func = compile_model( + wrapper, dummy_input_ids, dummy_pos_ids, dummy_kvs, dummy_ctx_len + ) + + # 3. Verification + print("\n=== Verifying Output Accuracy ===") + verify_output( + trt_model_func, + model_pytorch, + tokenizer, + "What is parallel programming?", + max_new_tokens=30, + ) + + # 4. 
Benchmarks + benchmarks = [ + (128, 128), + (256, 128), + (512, 256), + ] + + print("\n=== Starting Benchmarks ===") + print(f"Device: {torch.cuda.get_device_name(0)}") + + for isl, osl in benchmarks: + print("-" * 60) + # PyTorch Manual Loop + run_pytorch_benchmark_manual(model_pytorch, config, isl, osl) + + # PyTorch Generate API + run_pytorch_benchmark_generate(model_pytorch, config, isl, osl) + + # TensorRT + benchmark_generation(trt_model_func, isl, osl, config, run_name="TensorRT") diff --git a/examples/dynamo/end_to_end_vit_attention_plugin_example.py b/examples/dynamo/end_to_end_vit_attention_plugin_example.py new file mode 100644 index 0000000000..777e058fbb --- /dev/null +++ b/examples/dynamo/end_to_end_vit_attention_plugin_example.py @@ -0,0 +1,765 @@ +""" +End-to-End ViT Attention Plugin Example + +This mirrors the structure of end_to_end_llm_generation_example.py, but for the +visual towers used by production multimodal models: + +1. load the TensorRT-Edge-LLM plugin shared library +2. register the PyTorch custom op and Torch-TensorRT converter +3. load a PyTorch reference model +4. load a second model, replace visual attention with plugin attention +5. wrap and compile the visual model +6. verify output shape and run a small latency benchmark + +Supported end-to-end paths: +- Qwen2.5-VL visual tower +- Meta Llama 3.2 Vision / HuggingFace Mllama vision tower +- NVIDIA GR00T N1.5 configured Eagle/SigLIP2 vision tower + +By default this script runs all benchmarks. +""" + +import os +import sys +import json + +import torch +from huggingface_hub import hf_hub_download +from huggingface_hub.errors import RepositoryNotFoundError +from transformers import ( + AutoConfig, + AutoModel, + MllamaVisionModel, + Qwen2_5_VLForConditionalGeneration, +) + +# Add tools/llm to path for shared plugin utilities, matching the LLM example style. 
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../../tools/llm")) + +from plugin_utils_vit import ( + ViTPluginWrapper, + VIT_INPUT_CONTRACT_NATIVE, + VIT_INPUT_CONTRACT_TILED_ASPECT_RATIO, + VIT_INPUT_CONTRACT_WINDOWED_ROPE, + compile_vit_plugin_model, + get_vit_plugin_config, + load_plugin, + measure_vit_latency, + register_vit_plugin_op, + replace_vit_attention_with_plugin, + set_vit_plugin_config, +) + +# Configuration +QWEN_MODEL_NAME = "Qwen/Qwen2.5-VL-3B-Instruct" +MLLAMA_MODEL_NAME = "meta-llama/Llama-3.2-11B-Vision" +GROOT_MODEL_NAME = "nvidia/GR00T-N1-2B" + +DTYPE = torch.float16 +DEVICE = torch.device("cuda:0") + +# Optional model-agnostic remaps for policy configs that reference placeholder +# or internal backbone paths. Add entries as: +# "": "" +BACKBONE_OVERRIDES = { + # Example for nvidia/GR00T-N1-2B. + # TODO: This public SigLIP checkpoint is only a smoke-test stand-in for exercising + # policy -> backbone resolution. It is not the real GR00T Eagle backbone as it is not publicly accessible. + "$GR00T_BACKBONE_PATH/eagle2_hg_model": "google/siglip-base-patch16-224", +} + +# One image grid used for the compile example: t, h, w. +# Qwen2.5-VL visual input is already a patch-vector tensor: [t*h*w, patch_dim]. +IMAGE_GRID_THW = (1, 8, 16) + +# Load the plugin and register the op/converter path. +load_plugin() +register_vit_plugin_op() + +# Global config for compatibility with converter-style imports. 
+TARGET_CONFIG = None + +# ----------------------------------------------------------------------------- +# Helpers +# ----------------------------------------------------------------------------- + +def get_vision_config(config): + """Return the vision config from a top-level or vision-only config.""" + if isinstance(config, dict): + if "vision_config" in config: + return config["vision_config"] + if "visual" in config: + return config["visual"] + if "attention_heads" in config and "max_num_tiles" in config: + return config + if hasattr(config, "vision_config"): + return config.vision_config + if hasattr(config, "visual"): + return config.visual + if hasattr(config, "attention_heads") and hasattr(config, "max_num_tiles"): + return config + raise ValueError("Cannot find vision config") + + +def expand_backbone_model_name(backbone_name): + """ + Resolve nested backbone names from policy configs. + + A policy config can point to a public HF repo, a local absolute path, an + env-var based path, or a placeholder path that this example remaps through + BACKBONE_OVERRIDES. GR00T-N1-2B, for example, uses + "$GR00T_BACKBONE_PATH/eagle2_hg_model"; add that string to + BACKBONE_OVERRIDES when the backbone is available locally. + """ + resolved = BACKBONE_OVERRIDES.get(backbone_name, os.path.expandvars(backbone_name)) + if "$" in resolved: + raise RuntimeError( + f"Backbone checkpoint path '{backbone_name}' contains an unresolved " + "environment variable. Add it to BACKBONE_OVERRIDES, set the env var, " + "or replace the config value with a local path/public Hugging Face repo id." + ) + if os.path.isabs(resolved) and not os.path.exists(os.path.join(resolved, "config.json")): + raise RuntimeError( + f"Backbone checkpoint path '{resolved}' does not contain config.json." + ) + return resolved + + +def get_backbone_model_name(config): + """ + Return a nested vision/VLM backbone path from policy-style configs. 
+ + In this file, "backbone" means the reusable perception/VLM feature extractor + inside a larger model. GR00T is a robotics policy, but the ViT attention we + want to compile lives in its Eagle/SigLIP2 backbone, not in the action head. + Different GR00T releases expose that backbone differently: + + - N1.5: backbone_cfg.eagle_path -> an Eagle checkpoint repo/path + - N1-2B: backbone_cfg.model_name -> often an env-var based local path + - N1.6/N1.7/H: backbone_model_type="eagle" without a loadable path + """ + if isinstance(config, dict): + backbone_cfg = config.get("backbone_cfg") + backbone_model_type = config.get("backbone_model_type") + else: + backbone_cfg = getattr(config, "backbone_cfg", None) + backbone_model_type = getattr(config, "backbone_model_type", None) + if backbone_cfg is None: + if backbone_model_type is not None: + raise ValueError( + f"Config declares backbone_model_type='{backbone_model_type}' " + "but does not include a loadable backbone checkpoint path." + ) + return None + + if isinstance(backbone_cfg, dict): + backbone_name = backbone_cfg.get("eagle_path") or backbone_cfg.get("model_name") + else: + backbone_name = ( + getattr(backbone_cfg, "eagle_path", None) + or getattr(backbone_cfg, "model_name", None) + ) + if backbone_name is None: + return None + return expand_backbone_model_name(backbone_name) + + +def load_config_or_raise(model_name, *, source_model_name=None): + """Load a config and raise a concise error when a nested checkpoint is missing.""" + try: + return AutoConfig.from_pretrained(model_name, trust_remote_code=True) + except (OSError, RepositoryNotFoundError) as exc: + source = f" referenced by {source_model_name}" if source_model_name else "" + raise RuntimeError( + f"Checkpoint '{model_name}'{source} is not available from Hugging Face. " + "Confirm the repo id is public, authenticate with a token that has access, " + "or provide a local checkpoint path." 
+ ) + + +def resolve_vision_config(model_name): + """ + Load the config that owns the visual transformer. + + Standard VLMs expose a vision_config directly. Robotics policy checkpoints + such as GR00T can instead point at a reusable Eagle/SigLIP2 backbone; in + that case we lower the top-level policy config to the backbone config before + extracting the vision settings. + """ + config = load_config_or_raise(model_name) + backbone_name = get_backbone_model_name(config) + if backbone_name is None: + return model_name, get_vision_config(config) + + backbone_config = load_config_or_raise( + backbone_name, + source_model_name=model_name, + ) + return backbone_name, get_vision_config(backbone_config) + + +def resolve_policy_backbone_config(model_name): + """ + Resolve configs that AutoConfig cannot load because they are policy models. + + GR00T-N1.5, for example, has model_type="gr00t_n1_5", which vanilla + Transformers does not recognize. We therefore read config.json as plain JSON, + find the nested Eagle/SigLIP2 backbone, and then use AutoConfig on that + backbone because it is the model that owns the visual transformer layers. 
+ """ + config = load_hf_config_dict(model_name) + backbone_name = get_backbone_model_name(config) + if backbone_name is None: + return model_name, get_vision_config(config) + + backbone_config = load_config_or_raise( + backbone_name, + source_model_name=model_name, + ) + return backbone_name, get_vision_config(backbone_config) + +def get_visual_num_patches(vision_config, image_grid_thw=None, include_cls=True): + """Return the visual-token extent used by the ViT plugin config.""" + if image_grid_thw is not None: + grid_t, grid_h, grid_w = image_grid_thw + return grid_t * grid_h * grid_w + + image_size = vision_config.image_size + patch_size = vision_config.patch_size + if isinstance(image_size, (tuple, list)): + image_h, image_w = image_size + else: + image_h = image_w = image_size + if isinstance(patch_size, (tuple, list)): + patch_h, patch_w = patch_size + else: + patch_h = patch_w = patch_size + + num_patches = (image_h // patch_h) * (image_w // patch_w) + if include_cls: + num_patches += 1 + if hasattr(vision_config, "max_num_tiles"): + target_length = num_patches + (8 - (num_patches % 8)) % 8 + return vision_config.max_num_tiles * target_length + return num_patches + +def set_plugin_config_from_vision_config(vision_config, num_patches): + """Set ViT plugin fields from a generic vision config.""" + num_heads = ( + getattr(vision_config, "num_heads", None) + or getattr(vision_config, "num_attention_heads", None) + or getattr(vision_config, "attention_heads", None) + ) + if num_heads is None: + raise ValueError("Cannot infer number of attention heads from vision config") + + head_dim = getattr(vision_config, "head_dim", None) + if head_dim is None: + head_dim = vision_config.hidden_size // num_heads + + set_vit_plugin_config( + num_attention_heads=num_heads, + head_dim=head_dim, + num_patches=num_patches, + ) + + +def load_hf_config_dict(model_name): + """Load config.json directly for repos without an AutoConfig registration.""" + config_path = 
hf_hub_download(model_name, "config.json") + with open(config_path, "r", encoding="utf-8") as f: + return json.load(f) + + +def get_vision_model(model): + """Return the visual tower from a multimodal model or the model itself.""" + if hasattr(model, "vision_model"): + return model.vision_model + if hasattr(model, "visual"): + return model.visual + return model + +def create_windowed_rope_metadata(visual_model, pixel_values, image_grid_thw): + """ + Lower windowed visual-attention metadata to raw tensors. + + Qwen-VL derives RoPE positions and window boundaries from image_grid_thw + using Python list/index logic that is awkward for torch.export. We compute + it once here and pass the compiled wrapper only tensor inputs. + """ + with torch.no_grad(): + rotary_pos_emb = visual_model.rot_pos_emb(image_grid_thw) + window_index, cu_window_seqlens = visual_model.get_window_index(image_grid_thw) + + window_index = window_index.to(device=DEVICE, dtype=torch.long) + reverse_window_index = torch.argsort(window_index) + + seq_len = pixel_values.shape[0] + attention_mask = torch.zeros(1, seq_len, seq_len, dtype=DTYPE, device=DEVICE) + window_attention_mask = torch.full( + (1, seq_len, seq_len), + torch.finfo(DTYPE).min, + dtype=DTYPE, + device=DEVICE, + ) + if isinstance(cu_window_seqlens, torch.Tensor): + cu_window_seqlens = cu_window_seqlens.to(device="cpu", dtype=torch.long).tolist() + max_window_seq_len = max(end - start for start, end in zip(cu_window_seqlens[:-1], cu_window_seqlens[1:])) + for start, end in zip(cu_window_seqlens[:-1], cu_window_seqlens[1:]): + window_attention_mask[:, start:end, start:end] = 0 + + return { + "rotary_pos_emb": rotary_pos_emb.to(device=DEVICE), + "attention_mask": attention_mask, + "window_attention_mask": window_attention_mask, + "cu_window_seqlens": torch.tensor(cu_window_seqlens, dtype=torch.int32, device=DEVICE), + "max_window_seq_len": max_window_seq_len, + "window_index": window_index, + "reverse_window_index": reverse_window_index, + 
} + +def create_patch_vector_inputs(vision_config): + """ + Create native PyTorch args for flattened patch-vector visual input. + + Input style: + patch_vector_inputs -> [num_patches, patch_vector_dim] + + These are the public inputs for models like Qwen-VL: + visual(pixel_values, image_grid_thw). + """ + grid_t, grid_h, grid_w = IMAGE_GRID_THW + num_patches = grid_t * grid_h * grid_w + patch_dim = ( + vision_config.in_chans + * vision_config.temporal_patch_size + * vision_config.patch_size + * vision_config.patch_size + ) + pixel_values = torch.randn( + num_patches, + patch_dim, + dtype=DTYPE, + device=DEVICE, + ) + image_grid_thw = torch.tensor([IMAGE_GRID_THW], dtype=torch.long, device=DEVICE) + return pixel_values, image_grid_thw + + +def create_tiled_vision_inputs(vision_config): + """ + Create native PyTorch args for a tiled visual input. + + Input style: + tiled_image_inputs -> [B, images, tiles, C, H, W] + + These are the public HuggingFace Mllama/Llama Vision visual inputs. + - pixel_values: [batch, max_num_images, max_num_tiles, channels, H, W] + - aspect_ratio_ids: [batch, max_num_images] + - aspect_ratio_mask: [batch, max_num_images, max_num_tiles] + """ + batch_size = 1 + max_num_images = 1 + num_tiles = vision_config.max_num_tiles + pixel_values = torch.randn( + batch_size, + max_num_images, + num_tiles, + vision_config.num_channels, + vision_config.image_size, + vision_config.image_size, + dtype=DTYPE, + device=DEVICE, + ) + aspect_ratio_ids = torch.ones( + batch_size, + max_num_images, + dtype=torch.long, + device=DEVICE, + ) + aspect_ratio_mask = torch.ones( + batch_size, + max_num_images, + num_tiles, + dtype=torch.long, + device=DEVICE, + ) + aspect_ratio_mask[:, :, 1:] = 0 + return pixel_values, aspect_ratio_ids, aspect_ratio_mask + + +def create_raw_image_inputs(vision_config): + """ + Create native PyTorch args for raw image visual input. 
+ + Input style: + raw_image_inputs -> [B, C, H, W] + """ + image_size = vision_config.image_size + if isinstance(image_size, (tuple, list)): + image_h, image_w = image_size + else: + image_h = image_w = image_size + num_channels = getattr(vision_config, "num_channels", 3) + pixel_values = torch.randn( + 1, + num_channels, + image_h, + image_w, + dtype=DTYPE, + device=DEVICE, + ) + return (pixel_values,) + +def create_tiled_aspect_ratio_attention_mask(vision_config, aspect_ratio_mask): + """ + Create an expanded additive mask from tiled aspect-ratio validity metadata. + + This lowers Mllama's compact tile-validity mask into the raw attention mask + tensor consumed by the export-friendly plugin wrapper. + + This mirrors HuggingFace _prepare_aspect_ratio_attention_mask but avoids + compiling its in-place padding-mask update through TensorRT. + """ + batch_size, max_num_images, max_num_tiles = aspect_ratio_mask.shape + flat_mask = aspect_ratio_mask.reshape(batch_size * max_num_images, max_num_tiles) + num_patches = (vision_config.image_size // vision_config.patch_size) ** 2 + 1 + target_length = num_patches + (8 - (num_patches % 8)) % 8 + + attention_mask = flat_mask.view( + batch_size * max_num_images, + max_num_tiles, + 1, + 1, + ).to(DTYPE) + attention_mask = attention_mask.repeat(1, 1, target_length, 1) + + pad_patches = target_length - num_patches + if pad_patches > 0: + attention_mask[:, :, -pad_patches:] = 0 + + attention_mask = 1 - attention_mask + attention_mask = attention_mask.reshape( + batch_size * max_num_images, + max_num_tiles * target_length, + 1, + ) + attention_mask = ( + attention_mask + @ attention_mask.transpose(-1, -2) + * torch.finfo(DTYPE).min + ) + return attention_mask.unsqueeze(1).to(device=DEVICE, dtype=DTYPE) + +def create_visual_plugin_metadata(model_type, visual_model, vision_config, pytorch_args): + """ + Create raw plugin-only metadata tensors for a visual model. + + The public PyTorch visual APIs are intentionally model-specific. 
This + function lowers those public inputs into the tensor-only metadata expected + by ViTPluginWrapper. + """ + if model_type == "qwen_vl": + pixel_values, image_grid_thw = pytorch_args + return create_windowed_rope_metadata( + visual_model, + pixel_values, + image_grid_thw, + ) + if model_type == "mllama": + _, _, aspect_ratio_mask = pytorch_args + return { + "attention_mask": create_tiled_aspect_ratio_attention_mask( + vision_config, + aspect_ratio_mask, + ) + } + if model_type in ("raw_image", "groot"): + return {} + raise ValueError(f"Unsupported visual model type: {model_type}") + +def create_visual_inputs(model_type, visual_model, vision_config): + """ + Create native PyTorch args and raw plugin kwargs for a visual model. + + Input styles this example currently covers: + - raw_image_inputs -> [B, C, H, W] + - patch_vector_inputs -> [num_patches, patch_vector_dim] + - tiled_image_inputs -> [B, images, tiles, C, H, W] + + Returns: + - pytorch_args: inputs for the original HuggingFace visual tower + - plugin_kwargs: lowered raw tensors for the compiled plugin wrapper + """ + if model_type == "qwen_vl": + pytorch_args = create_patch_vector_inputs(vision_config) + elif model_type == "mllama": + pytorch_args = create_tiled_vision_inputs(vision_config) + elif model_type in ("raw_image", "groot"): + pytorch_args = create_raw_image_inputs(vision_config) + else: + raise ValueError(f"Unsupported visual model type: {model_type}") + + plugin_kwargs = { + "pixel_values": pytorch_args[0], + **create_visual_plugin_metadata( + model_type, + visual_model, + vision_config, + pytorch_args, + ), + } + if model_type == "mllama": + plugin_kwargs["aspect_ratio_ids"] = pytorch_args[1] + return pytorch_args, plugin_kwargs + +def get_last_hidden_state(output): + """Normalize HF model outputs to a tensor for verification and benchmark.""" + if hasattr(output, "last_hidden_state"): + return output.last_hidden_state + if isinstance(output, (tuple, list)): + return output[0] + return 
output + +def benchmark_visual(visual_model, input_kwargs, run_name="TensorRT"): + """Benchmark the compiled visual model.""" + pixel_values = input_kwargs["pixel_values"] + if pixel_values.dim() == 2: + input_desc = f"Patches: {pixel_values.shape[0]}" + elif pixel_values.dim() == 6: + input_desc = ( + f"Images: {pixel_values.shape[0]} | " + f"Tiles: {pixel_values.shape[2]}" + ) + else: + input_desc = f"Input shape: {tuple(pixel_values.shape)}" + + def forward(): + with torch.no_grad(): + return get_last_hidden_state(visual_model(**input_kwargs)) + + mean_ms, std_ms, median_ms = measure_vit_latency(forward) + print( + f"{run_name} | {input_desc} | " + f"Mean: {mean_ms:.3f} ms | Median: {median_ms:.3f} ms | Std: {std_ms:.3f} ms" + ) + return mean_ms + + +def verify_visual_output(model_name, pytorch_model, pytorch_args, trt_model, trt_kwargs): + """Compare PyTorch and TensorRT-plugin visual outputs.""" + print(f"\n=== Verifying {model_name} Visual Output ===") + + with torch.no_grad(): + pyt_output = get_last_hidden_state(pytorch_model(*pytorch_args)) + trt_output = get_last_hidden_state(trt_model(**trt_kwargs)) + + print(f"PyTorch output shape: {tuple(pyt_output.shape)}") + print(f"TensorRT output shape: {tuple(trt_output.shape)}") + + if pyt_output.shape == trt_output.shape: + print("SUCCESS: Output shapes match.") + else: + print("FAILURE: Output shapes differ.") + return + + max_abs_diff = (pyt_output - trt_output).abs().max().item() + print(f"Max absolute difference: {max_abs_diff:.6f}") + print( + "Note: small numerical differences are expected across PyTorch SDPA, " + "TensorRT, and the custom CUDA plugin." 
+ ) + + +def run_qwen(model_name): + print(f"Loading {model_name}...") + config = AutoConfig.from_pretrained(model_name) + vision_config = get_vision_config(config) + + globals()["TARGET_CONFIG"] = vision_config + set_plugin_config_from_vision_config( + vision_config, + get_visual_num_patches(vision_config, IMAGE_GRID_THW), + ) + print(f"Plugin config: {get_vit_plugin_config()}") + + model_pytorch = Qwen2_5_VLForConditionalGeneration.from_pretrained( + model_name, + torch_dtype=DTYPE, + ).to(DEVICE) + model_pytorch.eval() + + model_trt = Qwen2_5_VLForConditionalGeneration.from_pretrained( + model_name, + torch_dtype=DTYPE, + ).to(DEVICE) + model_trt.eval() + + visual_trt = replace_vit_attention_with_plugin(model_trt.visual, vision_config) + pytorch_args, plugin_kwargs = create_visual_inputs( + "qwen_vl", + model_trt.visual, + vision_config, + ) + max_window_seq_len = plugin_kwargs.pop("max_window_seq_len") + wrapper = ViTPluginWrapper( + visual_trt, + input_contract=VIT_INPUT_CONTRACT_WINDOWED_ROPE, + max_window_seq_len=max_window_seq_len, + ).eval() + + print("Compiling TensorRT visual model...") + trt_visual_model = compile_vit_plugin_model( + wrapper, + (), + DEVICE, + example_kwargs=plugin_kwargs, + dynamic_shapes={name: {} for name in plugin_kwargs}, + ) + + verify_visual_output( + "Qwen2.5-VL", + model_pytorch.visual, + pytorch_args, + trt_visual_model, + plugin_kwargs, + ) + + print("\n=== Starting Qwen2.5-VL Visual Benchmark ===") + print(f"Device: {torch.cuda.get_device_name(0)}") + benchmark_visual(trt_visual_model, plugin_kwargs, run_name="TensorRT") + + +def run_mllama(model_name): + print(f"Loading {model_name}...") + config = AutoConfig.from_pretrained(model_name) + vision_config = get_vision_config(config) + + globals()["TARGET_CONFIG"] = vision_config + set_plugin_config_from_vision_config( + vision_config, + get_visual_num_patches(vision_config), + ) + print(f"Plugin config: {get_vit_plugin_config()}") + + visual_pytorch = 
MllamaVisionModel.from_pretrained( + model_name, + torch_dtype=DTYPE, + ).to(DEVICE) + visual_pytorch.eval() + + visual_trt = MllamaVisionModel.from_pretrained( + model_name, + torch_dtype=DTYPE, + ).to(DEVICE) + visual_trt.eval() + + visual_trt = replace_vit_attention_with_plugin( + visual_trt, + vision_config, + ) + wrapper = ViTPluginWrapper( + visual_trt, + input_contract=VIT_INPUT_CONTRACT_TILED_ASPECT_RATIO, + ).eval() + + pytorch_args, plugin_kwargs = create_visual_inputs( + "mllama", + visual_trt, + vision_config + ) + + print("Compiling TensorRT Llama Vision/Mllama visual model...") + trt_visual_model = compile_vit_plugin_model( + wrapper, + (), + DEVICE, + example_kwargs=plugin_kwargs, + dynamic_shapes={name: {} for name in plugin_kwargs}, + ) + + verify_visual_output( + "Llama 3.2 Vision/Mllama", + visual_pytorch, + pytorch_args, + trt_visual_model, + plugin_kwargs, + ) + + print("Starting Llama 3.2 Vision/Mllama Visual Benchmark...") + print(f"Device: {torch.cuda.get_device_name(0)}") + benchmark_visual(trt_visual_model, plugin_kwargs, run_name="TensorRT") + + +def run_groot(model_name): + print(f"Loading {model_name}...") + backbone_name, vision_config = resolve_policy_backbone_config(model_name) + print(f"Using GR00T visual backbone: {backbone_name}") + + globals()["TARGET_CONFIG"] = vision_config + set_plugin_config_from_vision_config( + vision_config, + get_visual_num_patches(vision_config, include_cls=False), + ) + print(f"Plugin config: {get_vit_plugin_config()}") + + model_pytorch = AutoModel.from_pretrained( + backbone_name, + trust_remote_code=True, + torch_dtype=DTYPE, + ).to(DEVICE) + model_pytorch.eval() + + model_trt = AutoModel.from_pretrained( + backbone_name, + trust_remote_code=True, + torch_dtype=DTYPE, + ).to(DEVICE) + model_trt.eval() + + visual_pytorch = get_vision_model(model_pytorch) + visual_trt = get_vision_model(model_trt) + visual_trt = replace_vit_attention_with_plugin(visual_trt, vision_config) + wrapper = ViTPluginWrapper( 
+ visual_trt, + input_contract=VIT_INPUT_CONTRACT_NATIVE, + ).eval() + + pytorch_args, plugin_kwargs = create_visual_inputs( + "groot", + visual_trt, + vision_config, + ) + + print("Compiling TensorRT GR00T/Eagle visual model...") + trt_visual_model = compile_vit_plugin_model( + wrapper, + (), + DEVICE, + example_kwargs=plugin_kwargs, + dynamic_shapes={name: {} for name in plugin_kwargs}, + ) + + verify_visual_output( + "GR00T/Eagle/SigLIP2", + visual_pytorch, + pytorch_args, + trt_visual_model, + plugin_kwargs, + ) + + print("Starting GR00T/Eagle/SigLIP2 Visual Benchmark...") + print(f"Device: {torch.cuda.get_device_name(0)}") + benchmark_visual(trt_visual_model, plugin_kwargs, run_name="TensorRT") + + +# ----------------------------------------------------------------------------- +# Main +# ----------------------------------------------------------------------------- + +if __name__ == "__main__": + torch.manual_seed(42) + + run_qwen(QWEN_MODEL_NAME) + run_mllama(MLLAMA_MODEL_NAME) + run_groot(GROOT_MODEL_NAME) diff --git a/examples/dynamo/vit_attention_plugin_example.py b/examples/dynamo/vit_attention_plugin_example.py new file mode 100644 index 0000000000..9f3eeac13f --- /dev/null +++ b/examples/dynamo/vit_attention_plugin_example.py @@ -0,0 +1,649 @@ +""" +Example: ViT Attention TensorRT Plugin Integration +================================================ + +This example shows how to use custom TensorRT plugin that implements +ViT self-attention using a fused QKV input. 
+ +The Python code demonstrates: +- loading the TensorRT plugin shared library +- registering a placeholder custom op for TorchDynamo conversion +- converting that custom op to a TensorRT plugin layer +- comparing a PyTorch reference self-attention implementation with the plugin model + +attention_plugin_example.py is model-agnostic around this invariant: +LLM decode attention = RoPE + KV cache + causal/GQA attention + +vit_attention_plugin_example.py is model-agnostic around this invariant: +ViT/VLA visual attention = full/window bidirectional attention over image tokens +""" + +import os +import ctypes +from dataclasses import dataclass +from typing import Callable, Dict + +import numpy as np +import tensorrt as trt +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch_tensorrt +from torch_tensorrt.dynamo.conversion import ConversionContext, dynamo_tensorrt_converter +from torch_tensorrt.dynamo.conversion.converter_utils import get_trt_tensor + +# Enable plugin debug logging +# ---------------------------- +#os.environ["TRT_EDGELLM_DEBUG_PLUGIN"] = "1" + +# Initialize CUDA and Load Plugin +# -------------------------------- +# CUDA must be initialized before loading the TensorRT plugin library +print("Initializing CUDA context...") +DEVICE = torch.device("cuda:0") +_ = torch.zeros(1, device=DEVICE) # Initialize CUDA +print(f"CUDA initialized on {DEVICE}\n") + +PLUGIN_PATH = os.path.join( + os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))), + "TensorRT-Edge-LLM", + "build", + "libNvInfer_edgellm_plugin.so", +) +if not os.path.exists(PLUGIN_PATH): + PLUGIN_PATH = os.path.join( + os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))), + "TensorRT-Edge-LLM", + "build", + "libNvInfer_edgellm_plugin.so", + ) +ctypes.CDLL(PLUGIN_PATH) +print(f"Loaded plugin: {PLUGIN_PATH}\n") + +# ----------------------------------------------------------------------------- +# Model 
configuration +# ----------------------------------------------------------------------------- + +BATCH_SIZE = 1 +NUM_HEADS = 4 +HEAD_DIM = 64 +EMBED_DIM = NUM_HEADS * HEAD_DIM +SEQ_LEN = 256 +WINDOW_SIZE = 64 +MASK_TYPE_DENSE = 0 +MASK_TYPE_CU_SEQLENS = 1 +DTYPE = torch.float16 +DEVICE = torch.device("cuda:0") + +def apply_qwen_vl_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor: + """ + Apply Qwen-VL rotate-half RoPE to query/key tensors in [B, H, S, D]. + + Qwen-style RoPE stores the first half and second half of each head as the + two rotation components. This must match the C++ plugin's RoPE prepass for + the dense and cu_seqlens paths to compare against the same reference. + """ + cos = cos.view(1, 1, cos.shape[0], cos.shape[1]).to(dtype=x.dtype) + sin = sin.view(1, 1, sin.shape[0], sin.shape[1]).to(dtype=x.dtype) + half_dim = x.shape[-1] // 2 + x1 = x[..., :half_dim] + x2 = x[..., half_dim:] + rotated = torch.cat((-x2, x1), dim=-1) + return x * cos + rotated * sin + + +def create_identity_rope(seq_len: int, dtype: torch.dtype, device: torch.device): + """ + Create a no-op RoPE table. + + The plugin always receives cos/sin tensors, so plain ViT attention uses + cos=1 and sin=0 instead of a special "no RoPE" code path. + """ + cos = torch.ones(seq_len, HEAD_DIM, dtype=dtype, device=device) + sin = torch.zeros_like(cos) + return cos, sin + + +def create_qwen_vl_rope(seq_len: int, dtype: torch.dtype, device: torch.device): + """ + Build Qwen-VL-style RoPE tables with shape [S, D]. + + The embedding duplicates the frequency matrix across the head dimension so + it aligns with rotate-half RoPE: [freqs, freqs]. 
+ """ + inv_freq = 1.0 / ( + 10000.0 ** (torch.arange(0, HEAD_DIM, 2, dtype=torch.float32, device=device) / HEAD_DIM) + ) + positions = torch.arange(seq_len, dtype=torch.float32, device=device) + freqs = torch.outer(positions, inv_freq) + emb = torch.cat((freqs, freqs), dim=-1) + return emb.cos().to(dtype=dtype), emb.sin().to(dtype=dtype) + + +def create_zero_mask(batch_size: int, seq_len: int, dtype: torch.dtype, device: torch.device): + """ + Create a dense additive mask for full bidirectional attention. + + A value of zero means every token can attend to every other token. + """ + return torch.zeros(batch_size, seq_len, seq_len, dtype=dtype, device=device) + + +def create_window_mask(batch_size: int, seq_len: int, dtype: torch.dtype, device: torch.device, window_size: int = WINDOW_SIZE): + """ + Create a dense additive mask for independent local attention windows. + + Values outside each window are set to the minimum representable value for + the dtype, which behaves like -inf after the attention score add. + """ + mask = torch.full((batch_size, seq_len, seq_len), torch.finfo(dtype).min, dtype=dtype, device=device) + for start in range(0, seq_len, window_size): + end = min(start + window_size, seq_len) + mask[:, start:end, start:end] = 0 + return mask + + +def create_full_cu_seqlens(seq_len: int, device: torch.device): + """ + Create cu_seqlens for one packed full-attention segment. + + cu_seqlens is a prefix-sum array. [0, S] means one independent sequence + containing S tokens. + """ + return torch.tensor([0, seq_len], dtype=torch.int32, device=device) + + +def create_window_cu_seqlens(seq_len: int, device: torch.device, window_size: int = WINDOW_SIZE): + """ + Create cu_seqlens for independent contiguous attention windows. + + For S=256 and window=64 this returns [0, 64, 128, 192, 256], meaning the + plugin should launch four independent 64-token attention problems over one + packed QKV tensor. 
+ """ + boundaries = list(range(0, seq_len, window_size)) + if boundaries[-1] != seq_len: + boundaries.append(seq_len) + return torch.tensor(boundaries, dtype=torch.int32, device=device) + + +# ----------------------------------------------------------------------------- +# PyTorch SDPA and TensorRT plugin models +# ----------------------------------------------------------------------------- + +class ViTSDPAModel(nn.Module): + """ + PyTorch SDPA reference model for ViT/VLA attention. + + Projection layout is selected explicitly: + - ``fused_qkv`` => one Linear produces Q, K, and V together + - ``separate_qkv`` => separate q_proj, k_proj, and v_proj layers + + This model is the correctness baseline. It uses torch.nn.functional + scaled_dot_product_attention with the same dense additive mask semantics + that the plugin should reproduce. + """ + + def __init__(self, projection_layout: str): + """ + Create projection layers for the requested model-style layout. + + The layer names intentionally match ViTPluginModel so load_state_dict() + can copy the exact same weights into the plugin model. + """ + super().__init__() + self.projection_layout = projection_layout + + if self.projection_layout == "fused_qkv": + self.qkv = nn.Linear(EMBED_DIM, 3 * EMBED_DIM) + self.output_proj = nn.Linear(EMBED_DIM, EMBED_DIM) + elif self.projection_layout == "separate_qkv": + self.q_proj = nn.Linear(EMBED_DIM, EMBED_DIM) + self.k_proj = nn.Linear(EMBED_DIM, EMBED_DIM) + self.v_proj = nn.Linear(EMBED_DIM, EMBED_DIM) + self.output_proj = nn.Linear(EMBED_DIM, EMBED_DIM) + else: + raise ValueError(f"Unsupported projection layout: {projection_layout}") + + def project_qkv(self, x: torch.Tensor) -> torch.Tensor: + """ + Produce a fused [B, S, 3 * H * D] QKV tensor. + + The plugin consumes fused QKV, so the SDPA reference uses the same + representation before splitting into [Q, K, V]. 
+ """ + if self.projection_layout == "fused_qkv": + return self.qkv(x) + + if self.projection_layout == "separate_qkv": + q = self.q_proj(x) + k = self.k_proj(x) + v = self.v_proj(x) + else: + raise ValueError(f"Unknown projection layout: {self.projection_layout}") + return torch.cat([q, k, v], dim=-1) + + def forward( + self, + x: torch.Tensor, + attention_mask: torch.Tensor = None, + cu_seqlens: torch.Tensor = None, + cos: torch.Tensor = None, + sin: torch.Tensor = None, + ) -> torch.Tensor: + """ + Run PyTorch SDPA using the dense equivalent of the requested attention. + + cu_seqlens is accepted so the reference and plugin models can share the + same call signature, but SDPA consumes the dense additive mask. + """ + batch_size, seq_len, _ = x.shape + qkv = self.project_qkv(x) + has_position_embeddings = cos is not None and sin is not None + if not has_position_embeddings: + cos, sin = create_identity_rope(seq_len, x.dtype, x.device) + else: + cos = cos.to(dtype=x.dtype) + sin = sin.to(dtype=x.dtype) + + if attention_mask is None: + attention_mask = create_zero_mask(batch_size, seq_len, x.dtype, x.device) + else: + attention_mask = attention_mask.to(dtype=x.dtype) + + batch_size, seq_len, _ = qkv.shape + qkv = qkv.view(batch_size, seq_len, 3, NUM_HEADS, HEAD_DIM).permute(2, 0, 3, 1, 4) + query, key, value = qkv[0], qkv[1], qkv[2] + + if has_position_embeddings: + query = apply_qwen_vl_rope(query, cos, sin) + key = apply_qwen_vl_rope(key, cos, sin) + + attn = F.scaled_dot_product_attention( + query.contiguous(), + key.contiguous(), + value.contiguous(), + attn_mask=attention_mask.unsqueeze(1), + dropout_p=0.0, + is_causal=False, + ) + attn_out = attn.transpose(1, 2).contiguous().view(batch_size, seq_len, EMBED_DIM) + return self.output_proj(attn_out) + + +class ViTPluginModel(nn.Module): + """ + TensorRT plugin model for ViT/VLA attention. + + The projection layers mirror ViTSDPAModel. 
The attention math is replaced by + the custom Torch op that convert_vit_attention lowers to ViTAttentionPlugin. + """ + + def __init__( + self, + projection_layout: str, + plugin_mask_type: int = MASK_TYPE_DENSE, + plugin_max_seq_len: int = 0, + ): + """ + Create projection layers and remember plugin execution settings. + + plugin_max_seq_len is only used for cu_seqlens FMHA. It is the maximum + segment length represented by the prefix-sum array, which can be smaller + than the total packed token count for windowed visual attention. + """ + super().__init__() + self.projection_layout = projection_layout + self.plugin_mask_type = plugin_mask_type + self.plugin_max_seq_len = plugin_max_seq_len + + if self.projection_layout == "fused_qkv": + self.qkv = nn.Linear(EMBED_DIM, 3 * EMBED_DIM) + self.output_proj = nn.Linear(EMBED_DIM, EMBED_DIM) + elif self.projection_layout == "separate_qkv": + self.q_proj = nn.Linear(EMBED_DIM, EMBED_DIM) + self.k_proj = nn.Linear(EMBED_DIM, EMBED_DIM) + self.v_proj = nn.Linear(EMBED_DIM, EMBED_DIM) + self.output_proj = nn.Linear(EMBED_DIM, EMBED_DIM) + else: + raise ValueError(f"Unsupported projection layout: {projection_layout}") + + def project_qkv(self, x: torch.Tensor) -> torch.Tensor: + """ + Produce the fused [B, S, 3 * H * D] tensor consumed by the plugin. + + Separate projection layouts are concatenated here so the plugin sees the + same input layout regardless of upstream model style. 
+ """ + if self.projection_layout == "fused_qkv": + return self.qkv(x) + + if self.projection_layout == "separate_qkv": + q = self.q_proj(x) + k = self.k_proj(x) + v = self.v_proj(x) + else: + raise ValueError(f"Unknown projection layout: {self.projection_layout}") + return torch.cat([q, k, v], dim=-1) + + def forward( + self, + x: torch.Tensor, + attention_mask: torch.Tensor = None, + cu_seqlens: torch.Tensor = None, + cos: torch.Tensor = None, + sin: torch.Tensor = None, + ) -> torch.Tensor: + """ + Run the TensorRT ViTAttentionPlugin through a Torch custom op. + + Dense mode passes an additive mask. cu_seqlens mode passes prefix-sum + boundaries that describe independent packed attention regions. + """ + batch_size, seq_len, _ = x.shape + qkv = self.project_qkv(x) + has_position_embeddings = cos is not None and sin is not None + if not has_position_embeddings: + cos, sin = create_identity_rope(seq_len, x.dtype, x.device) + else: + cos = cos.to(dtype=x.dtype) + sin = sin.to(dtype=x.dtype) + + if attention_mask is None: + attention_mask = create_zero_mask(batch_size, seq_len, x.dtype, x.device) + else: + attention_mask = attention_mask.to(dtype=x.dtype) + + mask_or_cu_seqlens = attention_mask + if self.plugin_mask_type == MASK_TYPE_CU_SEQLENS: + if cu_seqlens is None: + cu_seqlens = create_full_cu_seqlens(seq_len, x.device) + mask_or_cu_seqlens = cu_seqlens + + attn_out = torch.ops.tensorrt_vit_attention.attn.default( + qkv, + cos, + sin, + mask_or_cu_seqlens, + NUM_HEADS, + HEAD_DIM, + 1, + self.plugin_mask_type, + self.plugin_max_seq_len, + ) + + return self.output_proj(attn_out) + + +@dataclass(frozen=True) +class AttentionCase: + """Bundle one model-style attention layout with its runtime inputs.""" + + name: str + projection_layout: str + kwargs_factory: Callable[[torch.Tensor], Dict[str, torch.Tensor]] + +def no_extra_inputs(x: torch.Tensor) -> Dict[str, torch.Tensor]: + """ + Create inputs for full bidirectional attention with identity RoPE. 
+ + The dense path will synthesize a zero mask in forward(). The cu_seqlens path + uses [0, S] to represent the same full attention region. + """ + _, seq_len, _ = x.shape + return { + "cu_seqlens": create_full_cu_seqlens(seq_len, x.device), + } + +def qwen_vl_inputs(x: torch.Tensor) -> Dict[str, torch.Tensor]: + """ + Create inputs for QwenVL-style windowed visual attention. + + This is the key validation case for packed cu_seqlens: the dense mask and + cu_seqlens tensor represent the same four independent windows. + """ + batch_size, seq_len, _ = x.shape + cos, sin = create_qwen_vl_rope(seq_len, x.dtype, x.device) + return { + "attention_mask": create_window_mask(batch_size, seq_len, x.dtype, x.device), + "cu_seqlens": create_window_cu_seqlens(seq_len, x.device), + "cos": cos, + "sin": sin, + } + + +ATTENTION_CASES = [ + AttentionCase( + name="Plain ViT Attention", + projection_layout="fused_qkv", + kwargs_factory=no_extra_inputs, + ), + AttentionCase( + name="QwenVL-Style Attention", + projection_layout="fused_qkv", + kwargs_factory=qwen_vl_inputs, + ), + AttentionCase( + name="LlamaVision-Style Attention", + projection_layout="separate_qkv", + kwargs_factory=no_extra_inputs, + ), + AttentionCase( + name="GR00T/SigLip2-Style Attention", + projection_layout="separate_qkv", + kwargs_factory=no_extra_inputs, + ), +] + +# ----------------------------------------------------------------------------- +# Plugin operation registration +# ----------------------------------------------------------------------------- + +def register_vit_attention_op(): + """ + Register a Torch custom op that TorchDynamo can trace. + + The Python implementation is only a placeholder shape function. During + Torch-TensorRT conversion, convert_vit_attention replaces this op with the + real TensorRT plugin layer. 
+ """ + @torch.library.custom_op("tensorrt_vit_attention::attn", mutates_args=()) + def vit_attention( + qkv: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + mask_or_cu_seqlens: torch.Tensor, + num_heads: int, + head_dim: int, + qkv_fused: int = 1, + mask_type: int = MASK_TYPE_DENSE, + max_seq_len: int = 0, + ) -> torch.Tensor: + """Return an empty-shaped output for eager/tracing fallback.""" + batch_size = qkv.shape[0] + seq_len = qkv.shape[1] + output_dim = num_heads * head_dim + return torch.zeros(batch_size, seq_len, output_dim, dtype=qkv.dtype, device=qkv.device) + + @torch.library.register_fake("tensorrt_vit_attention::attn") + def _(qkv, cos, sin, mask_or_cu_seqlens, num_heads, head_dim, qkv_fused=1, mask_type=MASK_TYPE_DENSE, max_seq_len=0): + """Provide fake tensor propagation for torch.export.""" + batch_size = qkv.shape[0] + seq_len = qkv.shape[1] + output_dim = num_heads * head_dim + return torch.empty(batch_size, seq_len, output_dim, dtype=qkv.dtype, device=qkv.device) + +register_vit_attention_op() + +@dynamo_tensorrt_converter(torch.ops.tensorrt_vit_attention.attn.default, supports_dynamic_shapes=True) +def convert_vit_attention(ctx: ConversionContext, target, args, kwargs, name): + """ + Convert the traced custom op into a ViTAttentionPlugin layer. + + Scalar arguments become TensorRT plugin fields. Tensor arguments become + plugin inputs. max_seq_len is a field because FMHA needs it when cu_seqlens + packs several smaller attention regions into one QKV tensor. + """ + qkv, cos, sin, mask_or_cu_seqlens, num_heads, head_dim = args[:6] + qkv_fused = args[6] if len(args) > 6 else kwargs.get("qkv_fused", 1) + mask_type = args[7] if len(args) > 7 else kwargs.get("mask_type", MASK_TYPE_DENSE) + max_seq_len = args[8] if len(args) > 8 else kwargs.get("max_seq_len", 0) + + creator = trt.get_plugin_registry().get_plugin_creator("ViTAttentionPlugin", "1", "") + if creator is None: + raise RuntimeError("ViTAttentionPlugin not found! 
Make sure the plugin library is loaded.") + + field_list = [ + trt.PluginField("num_heads", np.array([num_heads], dtype=np.int32), trt.PluginFieldType.INT32), + trt.PluginField("head_size", np.array([head_dim], dtype=np.int32), trt.PluginFieldType.INT32), + trt.PluginField("qkv_fused", np.array([qkv_fused], dtype=np.int32), trt.PluginFieldType.INT32), + trt.PluginField("mask_type", np.array([mask_type], dtype=np.int32), trt.PluginFieldType.INT32), + trt.PluginField("max_seq_len", np.array([max_seq_len], dtype=np.int32), trt.PluginFieldType.INT32), + ] + fields = trt.PluginFieldCollection(field_list) + plugin = creator.create_plugin(name, fields) + if plugin is None: + raise RuntimeError("Failed to create ViTAttentionPlugin") + + input_tensors = [ + get_trt_tensor(ctx, qkv, "qkv"), + get_trt_tensor(ctx, cos, "cos"), + get_trt_tensor(ctx, sin, "sin"), + get_trt_tensor(ctx, mask_or_cu_seqlens, "mask_or_cu_seqlens"), + ] + layer = ctx.net.add_plugin_v2(input_tensors, plugin) + return layer.get_output(0) + +# ----------------------------------------------------------------------------- +# Correctness validation +# ----------------------------------------------------------------------------- + +def run_attention_case( + case_name: str, + plugin_label: str, + plugin_mask_type: int, + reference_model: nn.Module, + plugin_model: nn.Module, + x: torch.Tensor, + kwargs, +): + """ + Compile and validate one attention case for one plugin mask mode. + + The reference model and plugin model share weights. Inputs are normalized to + positional tensors so this example can use the same torch_tensorrt.compile() + style as attention_plugin_example.py. 
+ """ + plugin_model.load_state_dict(reference_model.state_dict()) + batch_size, seq_len, _ = x.shape + attention_mask = kwargs.get( + "attention_mask", + create_zero_mask(batch_size, seq_len, x.dtype, x.device), + ) + cu_seqlens = kwargs.get("cu_seqlens", create_full_cu_seqlens(seq_len, x.device)) + cos, sin = kwargs.get("cos"), kwargs.get("sin") + if cos is None or sin is None: + cos, sin = create_identity_rope(seq_len, x.dtype, x.device) + + runtime_inputs = (x, attention_mask, cu_seqlens, cos, sin) + + with torch.no_grad(): + ref_out = reference_model(*runtime_inputs) + + print(f"\n=== {case_name} | {plugin_label} ===") + print("Compiling TensorRT ViT attention plugin model...") + inputs_spec = [ + torch_tensorrt.Input(shape=tuple(x.shape), dtype=x.dtype), + torch_tensorrt.Input(shape=tuple(attention_mask.shape), dtype=attention_mask.dtype), + torch_tensorrt.Input(shape=tuple(cu_seqlens.shape), dtype=cu_seqlens.dtype), + torch_tensorrt.Input(shape=tuple(cos.shape), dtype=cos.dtype), + torch_tensorrt.Input(shape=tuple(sin.shape), dtype=sin.dtype), + ] + with torch_tensorrt.logging.errors(): + trt_model = torch_tensorrt.compile( + plugin_model, + inputs=inputs_spec, + use_explicit_typing=True, + use_fp32_acc=True, + device=DEVICE, + disable_tf32=True, + min_block_size=1, + ) + + with torch.no_grad(): + plugin_out = trt_model(*runtime_inputs) + + print("Reference output shape:", ref_out.shape) + print("Plugin output shape:", plugin_out.shape) + max_abs_diff = (ref_out - plugin_out).abs().max().item() + cosine = F.cosine_similarity(ref_out.flatten().float(), plugin_out.flatten().float(), dim=0).item() + passed = cosine >= 0.99 + print(f"Max absolute difference: {max_abs_diff:.6f}") + print(f"Cosine similarity: {cosine:.6f}") + print(f"Result: {'PASS' if passed else 'FAIL'}") + + return passed, cosine, max_abs_diff + + +def get_plugin_max_seq_len(plugin_mask_type: int, kwargs: Dict[str, torch.Tensor]) -> int: + """ + Return the maximum segment length required by the 
FMHA cu_seqlens path. + + For full attention this is S. For windowed attention this is the window + length. Dense mode returns 0 so the plugin falls back to runtime S. + """ + if plugin_mask_type != MASK_TYPE_CU_SEQLENS or "cu_seqlens" not in kwargs: + return 0 + cu_seqlens = kwargs["cu_seqlens"] + return int((cu_seqlens[1:] - cu_seqlens[:-1]).max().item()) + + +if __name__ == "__main__": + torch.manual_seed(0) + x = torch.randn(BATCH_SIZE, SEQ_LEN, EMBED_DIM, dtype=DTYPE, device=DEVICE) + + print("\nViT Attention Plugin - Dense Mask vs cu_seqlens Correctness Validation") + print(f"Config: B={BATCH_SIZE}, S={SEQ_LEN}, H={NUM_HEADS}, D={HEAD_DIM}, window={WINDOW_SIZE}") + results = [] + for attention_case in ATTENTION_CASES: + reference_model = ViTSDPAModel( + attention_case.projection_layout, + ).to(device=DEVICE, dtype=DTYPE).eval() + kwargs = attention_case.kwargs_factory(x) + + for plugin_label, plugin_mask_type in ( + ("Dense additive mask", MASK_TYPE_DENSE), + ("cu_seqlens FMHA", MASK_TYPE_CU_SEQLENS), + ): + plugin_model = ViTPluginModel( + attention_case.projection_layout, + plugin_mask_type=plugin_mask_type, + plugin_max_seq_len=get_plugin_max_seq_len(plugin_mask_type, kwargs), + ).to(device=DEVICE, dtype=DTYPE).eval() + results.append( + ( + attention_case.name, + plugin_label, + run_attention_case( + attention_case.name, + plugin_label, + plugin_mask_type, + reference_model, + plugin_model, + x, + kwargs, + ), + ) + ) + + print("\nSUMMARY") + for name, plugin_label, (passed, cosine, max_abs_diff) in results: + status = "PASS" if passed else "FAIL" + print(f"{name} | {plugin_label}: {status}") + print(f" Cosine: {cosine:.4f}, Max abs diff: {max_abs_diff:.6f}") + + all_passed = all(result[0] for _, _, result in results) + if all_passed: + print("SUCCESS: All ViT attention plugin tests passed!") + else: + print("FAILURE: Some ViT attention plugin tests failed") diff --git a/py/torch_tensorrt/dynamo/conversion/impl/unsqueeze.py 
b/py/torch_tensorrt/dynamo/conversion/impl/unsqueeze.py index 54235e9e3d..9c3116d356 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/unsqueeze.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/unsqueeze.py @@ -24,16 +24,17 @@ def unsqueeze( dim: int, ) -> TRTTensor: # tensorrt version < 10.7.0, use the old unsqueeze implementation - if is_tensorrt_version_supported("10.7.0"): + if is_tensorrt_version_supported("10.7.0") and -1 not in input.shape: # use the new unsqueeze implementation axes = get_trt_tensor(ctx, dim, f"{name}_axes") layer = ctx.net.add_unsqueeze(input, axes) set_layer_name(layer, target, name, source_ir) return layer.get_output(0) else: - logger.warning( - "IUnsqueezeLayer is supported starting from TensorRT 10.7.0, using the old unsqueeze implementation in the current TensorRT version" - ) + if not is_tensorrt_version_supported("10.7.0"): + logger.warning( + "IUnsqueezeLayer is supported starting from TensorRT 10.7.0, using the old unsqueeze implementation in the current TensorRT version" + ) return unsqueeze_old(ctx, target, source_ir, name, input, dim) diff --git a/tests/py/dynamo/conversion/test_expand_aten.py b/tests/py/dynamo/conversion/test_expand_aten.py index ce96bc05f7..805598e647 100644 --- a/tests/py/dynamo/conversion/test_expand_aten.py +++ b/tests/py/dynamo/conversion/test_expand_aten.py @@ -81,6 +81,27 @@ def forward(self, x): ExpandTargetDynamic(), input_specs, use_dynamo_tracer=True ) + def test_expand_dynamic_gqa_pattern(self): + class ExpandGroupedQueryAttention(torch.nn.Module): + def forward(self, x): + seq_len = x.shape[2] + x = torch.ops.aten.unsqueeze.default(x, 2) + return torch.ops.aten.expand.default(x, (1, 2, 8, seq_len, 128)) + + input_specs = [ + Input( + dtype=torch.float32, + min_shape=(1, 2, 4, 128), + opt_shape=(1, 2, 8, 128), + max_shape=(1, 2, 16, 128), + ), + ] + self.run_test_with_dynamic_shape( + ExpandGroupedQueryAttention(), + input_specs, + use_dynamo_tracer=True, + ) + if __name__ == "__main__": 
run_tests() diff --git a/tools/llm/plugin_converter.py b/tools/llm/plugin_converter.py new file mode 100644 index 0000000000..89b31d952a --- /dev/null +++ b/tools/llm/plugin_converter.py @@ -0,0 +1,96 @@ +""" +TensorRT converter for Edge-LLM attention plugin ops. + +This module contains the TensorRT converter for the tensorrt_edge_llm::xqa_attn +custom op. It is kept in a separate file from plugin_utils.py for maintainability. +""" + +import numpy as np +import tensorrt as trt +from plugin_utils import get_plugin_config, register_plugin_op +from torch_tensorrt.dynamo.conversion import ( + ConversionContext, + dynamo_tensorrt_converter, +) +from torch_tensorrt.dynamo.conversion.converter_utils import get_trt_tensor + +# Ensure the custom op is registered before the converter decorator runs +register_plugin_op() + +import torch # noqa: E402 (must be after register_plugin_op so the op exists) + +@dynamo_tensorrt_converter( + torch.ops.tensorrt_edge_llm.xqa_attn.default, supports_dynamic_shapes=True +) +def convert_attn(ctx: ConversionContext, target, args, kwargs, name): + """ + Convert tensorrt_edge_llm::xqa_attn op to TensorRT AttentionPlugin. 
+ + TensorRT-Edge-LLM (0.4.0) plugin requires 5 inputs: + - qkv, kv, ctx_len, rope, kv_cache_start_idx + + Plugin fields: + - num_q_heads, num_kv_heads, head_size, enable_tree_attention, enable_delta_kv_output + """ + # args: qkv, kv, ctx_len, rope, kv_cache_start_idx, nq, nkv, d + qkv, kv, ctx_len, rope, kv_cache_start_idx, nq, nkv, d = args[:8] + + creator = trt.get_plugin_registry().get_plugin_creator("AttentionPlugin", "1", "") + if creator is None: + raise RuntimeError("AttentionPlugin not found in TensorRT plugin registry!") + + # Get config from global settings + config = get_plugin_config() + if config: + nq_val = config["num_attention_heads"] + nkv_val = config["num_key_value_heads"] + d_val = config["head_dim"] + else: + # Fallback to values from args (may not work correctly) + nq_val = nq if isinstance(nq, int) else 14 + nkv_val = nkv if isinstance(nkv, int) else 2 + d_val = d if isinstance(d, int) else 64 + + # Plugin fields for TensorRT-Edge-LLM AttentionPlugin + # Required: num_q_heads, num_kv_heads, head_size, enable_tree_attention + # enable_delta_kv_output=1 enables delta KV output for Python/torch_tensorrt compatibility + field_list = [ + trt.PluginField( + field_name, np.array([field_val], dtype=np.int32), trt.PluginFieldType.INT32 + ) + for field_name, field_val in [ + ("num_q_heads", nq_val), + ("num_kv_heads", nkv_val), + ("head_size", d_val), + ("enable_tree_attention", 0), + ("enable_delta_kv_output", 1), + ] + ] + + fields = trt.PluginFieldCollection(field_list) + plugin = creator.create_plugin(name, fields) + + # 5 inputs for release version: qkv, kv, ctx_len, rope, kv_cache_start_idx + inputs = [ + ( + get_trt_tensor(ctx, i, f"{name}_i{idx}") + if not isinstance(i, trt.ITensor) + else i + ) + for idx, i in enumerate([qkv, kv, ctx_len, rope, kv_cache_start_idx]) + ] + + # Handle ctx_len shape if needed (squeeze if [B, 1] -> [B]) + if len(inputs[2].shape) == 2 and inputs[2].shape[1] == 1: + shuffle_layer = ctx.net.add_shuffle(inputs[2]) + 
shuffle_layer.reshape_dims = (inputs[2].shape[0],) + inputs[2] = shuffle_layer.get_output(0) + + # Handle kv_cache_start_idx shape if needed (squeeze if [B, 1] -> [B]) + if len(inputs[4].shape) == 2 and inputs[4].shape[1] == 1: + shuffle_layer = ctx.net.add_shuffle(inputs[4]) + shuffle_layer.reshape_dims = (inputs[4].shape[0],) + inputs[4] = shuffle_layer.get_output(0) + + layer = ctx.net.add_plugin_v2(inputs, plugin) + return layer.get_output(0), layer.get_output(1) diff --git a/tools/llm/plugin_converter_vit.py b/tools/llm/plugin_converter_vit.py new file mode 100644 index 0000000000..921c278363 --- /dev/null +++ b/tools/llm/plugin_converter_vit.py @@ -0,0 +1,145 @@ +""" +TensorRT converter for ViT attention plugin ops. + +This module contains the TensorRT converter for the tensorrt_edge_llm::xqa_attn +custom op. It is kept in a separate file from plugin_utils.py for maintainability. +""" + +import numpy as np +import tensorrt as trt + +from plugin_utils_vit import get_vit_plugin_config, register_vit_plugin_op +from torch_tensorrt.dynamo.conversion import ( + ConversionContext, + dynamo_tensorrt_converter, +) + +from torch_tensorrt.dynamo.conversion.converter_utils import get_trt_tensor + +register_vit_plugin_op() + +import torch # noqa: E402 (must be after register_vit_plugin_op so the op exists) + +_VIT_PLUGIN_CONVERSION_COUNT = 0 + + +def reset_vit_plugin_conversion_count() -> None: + """Reset the number of ViT plugin nodes lowered during TRT conversion.""" + global _VIT_PLUGIN_CONVERSION_COUNT + _VIT_PLUGIN_CONVERSION_COUNT = 0 + + +def get_vit_plugin_conversion_count() -> int: + """Return the number of ViT plugin nodes lowered during TRT conversion.""" + return _VIT_PLUGIN_CONVERSION_COUNT + + +def _as_int(value, default): + """Return value as a Python int when export preserved it as a scalar.""" + if isinstance(value, int): + return value + try: + return int(value) + except (TypeError, ValueError): + return default + + +def _infer_mask_type(mask_arg, 
default): + """Infer cu_seqlens mode when the fourth plugin input is an INT32 tensor.""" + dtype = getattr(mask_arg, "dtype", None) + dtype_name = str(dtype).lower() + if ( + dtype == torch.int32 + or dtype == trt.int32 + or dtype == trt.DataType.INT32 + or "int32" in dtype_name + ): + return 1 + return default + + +@dynamo_tensorrt_converter( + torch.ops.tensorrt_vit.attention.default, supports_dynamic_shapes=True +) +def convert_vit_attention(ctx: ConversionContext, target, args, kwargs, name): + """Convert tensorrt_vit::attention to TensorRT ViTAttentionPlugin.""" + global _VIT_PLUGIN_CONVERSION_COUNT + _VIT_PLUGIN_CONVERSION_COUNT += 1 + + qkv, cos, sin, mask_or_cu_seqlens, num_heads, head_dim = args[:6] + qkv_fused = args[6] if len(args) > 6 else kwargs.get("qkv_fused", 1) + mask_type = args[7] if len(args) > 7 else kwargs.get("mask_type", 0) + max_seq_len = args[8] if len(args) > 8 else kwargs.get("max_seq_len", 0) + + creator = trt.get_plugin_registry().get_plugin_creator( + "ViTAttentionPlugin", "1", "" + ) + if creator is None: + raise RuntimeError( + "ViTAttentionPlugin not found in TensorRT plugin registry!" 
+ ) + + qkv_tensor = ( + get_trt_tensor(ctx, qkv, f"{name}_qkv") + if not isinstance(qkv, trt.ITensor) + else qkv + ) + cos_tensor = ( + get_trt_tensor(ctx, cos, f"{name}_cos") + if not isinstance(cos, trt.ITensor) + else cos + ) + sin_tensor = ( + get_trt_tensor(ctx, sin, f"{name}_sin") + if not isinstance(sin, trt.ITensor) + else sin + ) + mask_tensor = ( + get_trt_tensor(ctx, mask_or_cu_seqlens, f"{name}_mask") + if not isinstance(mask_or_cu_seqlens, trt.ITensor) + else mask_or_cu_seqlens + ) + + config = get_vit_plugin_config() + num_heads_val = config.get("num_attention_heads", num_heads) + head_dim_val = config.get("head_dim", head_dim) + qkv_fused_val = _as_int(qkv_fused, 1) + inferred_mask_type = _infer_mask_type(mask_tensor, config.get("mask_type", 0)) + mask_type_val = _as_int(mask_type, inferred_mask_type) + if inferred_mask_type == 1: + mask_type_val = 1 + max_seq_len_val = _as_int(max_seq_len, config.get("max_seq_len", 0)) + field_list = [ + trt.PluginField( + "num_heads", + np.array([num_heads_val], dtype=np.int32), + trt.PluginFieldType.INT32, + ), + trt.PluginField( + "head_size", + np.array([head_dim_val], dtype=np.int32), + trt.PluginFieldType.INT32, + ), + trt.PluginField( + "qkv_fused", + np.array([qkv_fused_val], dtype=np.int32), + trt.PluginFieldType.INT32, + ), + trt.PluginField( + "mask_type", + np.array([mask_type_val], dtype=np.int32), + trt.PluginFieldType.INT32, + ), + trt.PluginField( + "max_seq_len", + np.array([max_seq_len_val], dtype=np.int32), + trt.PluginFieldType.INT32, + ), + ] + plugin = creator.create_plugin(name, trt.PluginFieldCollection(field_list)) + if plugin is None: + raise RuntimeError("Failed to create ViTAttentionPlugin") + + layer = ctx.net.add_plugin_v2([qkv_tensor, cos_tensor, sin_tensor, mask_tensor], plugin) + layer.name = name + return layer.get_output(0) diff --git a/tools/llm/plugin_utils.py b/tools/llm/plugin_utils.py new file mode 100644 index 0000000000..c27e30db24 --- /dev/null +++ 
b/tools/llm/plugin_utils.py @@ -0,0 +1,867 @@ +""" +Plugin utilities for TensorRT LLM inference with custom attention plugins. + +This module provides model-agnostic utilities for using TensorRT attention plugins +with various LLM architectures (Qwen, Llama, etc.). +""" + +import ctypes +import inspect +import os +from contextlib import nullcontext +from typing import Any, Callable, Dict, List, Optional, Tuple, Type + +import numpy as np +import tensorrt as trt +import torch +import torch.nn as nn +import torch_tensorrt + +# Default plugin path - can be overridden +# Built from: https://github.com/chohk88/TensorRT-Edge-LLM/tree/feature/torch-tensorrt-python-runtime +DEFAULT_PLUGIN_PATH = os.path.join( + os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))), + "TensorRT-Edge-LLM", + "build", + "libNvInfer_edgellm_plugin.so", +) + +# Global configuration for plugin converter +_PLUGIN_CONFIG: Dict[str, Any] = {} + + +def load_plugin(plugin_path: Optional[str] = None) -> bool: + """ + Load the TensorRT attention plugin library. + + Args: + plugin_path: Path to the plugin .so file. If None, uses DEFAULT_PLUGIN_PATH. + + Returns: + True if plugin was loaded successfully, False otherwise. + + Raises: + RuntimeError: If plugin file does not exist. + """ + path = plugin_path or DEFAULT_PLUGIN_PATH + if not os.path.exists(path): + raise RuntimeError(f"Plugin not found at {path}") + ctypes.CDLL(path) + return True + + +def set_plugin_config( + num_attention_heads: int, + num_key_value_heads: int, + head_dim: int, + max_seq_len: int = 2048, + max_batch_size: int = 4, +) -> None: + """ + Set global configuration for the plugin converter. + + Args: + num_attention_heads: Number of query attention heads. + num_key_value_heads: Number of key/value attention heads (for GQA). + head_dim: Dimension of each attention head. + max_seq_len: Maximum sequence length for KV cache. + max_batch_size: Maximum batch size. 
+ """ + global _PLUGIN_CONFIG + _PLUGIN_CONFIG = { + "num_attention_heads": num_attention_heads, + "num_key_value_heads": num_key_value_heads, + "head_dim": head_dim, + "max_seq_len": max_seq_len, + "max_batch_size": max_batch_size, + } + + +def get_plugin_config() -> Dict[str, Any]: + """Get the current plugin configuration.""" + return _PLUGIN_CONFIG.copy() + + +def set_plugin_config_from_model(model_config: Any, max_seq_len: int = 2048) -> None: + """ + Set plugin configuration from a HuggingFace model config. + + Args: + model_config: HuggingFace model configuration object. + max_seq_len: Maximum sequence length for KV cache. + """ + # Qwen3 has explicit head_dim in config that differs from hidden_size // num_attention_heads + if hasattr(model_config, "head_dim") and model_config.head_dim is not None: + head_dim = model_config.head_dim + else: + head_dim = model_config.hidden_size // model_config.num_attention_heads + + set_plugin_config( + num_attention_heads=model_config.num_attention_heads, + num_key_value_heads=model_config.num_key_value_heads, + head_dim=head_dim, + max_seq_len=max_seq_len, + ) + + +# ----------------------------------------------------------------------------- +# Plugin Op Registration +# ----------------------------------------------------------------------------- + + +def _register_plugin_op_impl() -> None: + """ + Internal implementation to register the tensorrt_edge_llm::xqa_attn custom op for PyTorch. + + The TensorRT-Edge-LLM plugin (0.4.0-based) requires 5 inputs: + - qkv: [B, S, (Hq+Hk+Hv)*D] fused QKV tensor + - kv: [B, 2, Hkv, Capacity, D] KV cache tensor + - ctx_len: [B] context length per batch + - rope: [1, MaxSeqLen, RotaryDim] rotary position encoding + - kv_cache_start_idx: [B] starting index in KV cache + + With enable_delta_kv_output=1, output KV shape is [B, 2, H, SeqLen, D] (delta only). 
+ """ + + @torch.library.custom_op("tensorrt_edge_llm::xqa_attn", mutates_args=()) + def attn( + qkv: torch.Tensor, + kv: torch.Tensor, + ctx_len: torch.Tensor, + rope: torch.Tensor, + kv_cache_start_idx: torch.Tensor, + nq: int, + nkv: int, + d: int, + ) -> Tuple[torch.Tensor, torch.Tensor]: + batch_size = qkv.shape[0] + seq_len = qkv.shape[1] + attn_out = torch.zeros( + batch_size, seq_len, nq, d, dtype=qkv.dtype, device=qkv.device + ) + # Delta KV output: shape is [B, 2, H, SeqLen, D] + updated_kv = torch.zeros( + batch_size, 2, nkv, seq_len, d, dtype=qkv.dtype, device=qkv.device + ) + return attn_out, updated_kv + + @torch.library.register_fake("tensorrt_edge_llm::xqa_attn") + def _(qkv, kv, ctx_len, rope, kv_cache_start_idx, nq, nkv, d): + batch_size = qkv.shape[0] + seq_len = qkv.shape[1] + attn_out = torch.empty( + batch_size, seq_len, nq, d, dtype=qkv.dtype, device=qkv.device + ) + # Delta KV output + updated_kv = torch.empty( + batch_size, 2, nkv, seq_len, d, dtype=qkv.dtype, device=qkv.device + ) + return attn_out, updated_kv + + +def register_plugin_op() -> None: + """ + Register the tensorrt_edge_llm::xqa_attn custom op for PyTorch. + + This function is idempotent - safe to call multiple times. + """ + if hasattr(torch.ops, "tensorrt_edge_llm") and hasattr( + torch.ops.tensorrt_edge_llm, "xqa_attn" + ): + return + _register_plugin_op_impl() + + +# Register the op at module import time so the converter decorator works +# This is safe because the op registration is idempotent +if not ( + hasattr(torch.ops, "tensorrt_edge_llm") + and hasattr(torch.ops.tensorrt_edge_llm, "xqa_attn") +): + _register_plugin_op_impl() + +# The converter for tensorrt_edge_llm::xqa_attn is defined in plugin_converter.py. +# Import it here so that importing plugin_utils still registers the converter. 
+from plugin_converter import convert_attn # noqa: F401 + +# ----------------------------------------------------------------------------- +# RoPE Cache Generation +# ----------------------------------------------------------------------------- + +def get_plugin_rope_cache( + rotary_emb: nn.Module, + max_seq_len: int, + head_dim: int, + device: torch.device, +) -> torch.Tensor: + """ + Generate RoPE cache tensor for the plugin from a rotary embedding module. + + Args: + rotary_emb: The rotary embedding module from the model. + max_seq_len: Maximum sequence length. + head_dim: Dimension of each attention head. + device: Device to create the cache on. + + Returns: + RoPE cache tensor of shape [1, max_seq_len, head_dim]. + """ + inv_freq = rotary_emb.inv_freq.to(device).float() + attention_scaling = getattr(rotary_emb, "attention_scaling", 1.0) + t = torch.arange(max_seq_len, device=device, dtype=torch.float32) + freqs = torch.outer(t, inv_freq) + cos_half = freqs.cos() * attention_scaling + sin_half = freqs.sin() * attention_scaling + rope = torch.cat([cos_half, sin_half], dim=-1) + return rope.unsqueeze(0) + + +# ----------------------------------------------------------------------------- +# Plugin Attention Module +# ----------------------------------------------------------------------------- + + +class PluginAttention(nn.Module): + """ + Model-agnostic Plugin Attention module that replaces standard attention. + + This module wraps the projection layers from the original attention module + and uses the tensorrt_edge_llm::xqa_attn plugin op for the attention computation. + + Supports: + - Qwen2.5, Llama: Standard attention + - Qwen3: Attention with QK Normalization (q_norm, k_norm) + """ + + def __init__( + self, + original_attn: nn.Module, + config: Any, + layer_idx: int, + rope_cache: torch.Tensor, + ): + """ + Initialize PluginAttention. + + Args: + original_attn: The original attention module to wrap. + config: Model configuration. 
+ layer_idx: Index of this layer in the model. + rope_cache: Pre-computed RoPE cache tensor. + """ + super().__init__() + self.q_proj = original_attn.q_proj + self.k_proj = original_attn.k_proj + self.v_proj = original_attn.v_proj + self.o_proj = original_attn.o_proj + + # Qwen3 has QK Normalization + self.q_norm = getattr(original_attn, "q_norm", None) + self.k_norm = getattr(original_attn, "k_norm", None) + + self.num_heads = config.num_attention_heads + self.num_key_value_heads = config.num_key_value_heads + + # Qwen3 has explicit head_dim that may differ from hidden_size // num_attention_heads + if hasattr(config, "head_dim") and config.head_dim is not None: + self.head_dim = config.head_dim + else: + self.head_dim = config.hidden_size // config.num_attention_heads + + # For Qwen3, attention output size is num_heads * head_dim, not hidden_size + self.attn_hidden_size = self.num_heads * self.head_dim + self.hidden_size = config.hidden_size + self.layer_idx = layer_idx + self.register_buffer("rope_cache", rope_cache) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + past_key_value: Optional[torch.Tensor] = None, + ctx_len: Optional[torch.Tensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Forward pass using the plugin attention. + + Args: + hidden_states: Input tensor of shape [batch, seq_len, hidden_size]. + attention_mask: Unused (plugin handles masking internally). + position_ids: Position IDs (unused, plugin uses RoPE cache). + past_key_value: KV cache tensor of shape [batch, 2, num_kv_heads, capacity, head_dim]. + ctx_len: Context length tensor for each batch item. + + Returns: + Tuple of (output tensor, updated KV cache). 
+ """ + batch_size, seq_len, _ = hidden_states.shape + + q = self.q_proj(hidden_states) + k = self.k_proj(hidden_states) + v = self.v_proj(hidden_states) + + # Qwen3: Apply QK Normalization if available + if self.q_norm is not None: + # Reshape for per-head normalization: [B, S, num_heads, head_dim] + q = q.view(batch_size, seq_len, self.num_heads, self.head_dim) + q = self.q_norm(q) + q = q.view(batch_size, seq_len, -1) + + if self.k_norm is not None: + # Reshape for per-head normalization: [B, S, num_kv_heads, head_dim] + k = k.view(batch_size, seq_len, self.num_key_value_heads, self.head_dim) + k = self.k_norm(k) + k = k.view(batch_size, seq_len, -1) + + qkv = torch.cat([q, k, v], dim=-1) + + if ctx_len is None: + ctx_len = torch.tensor( + [seq_len], dtype=torch.int32, device=hidden_states.device + ).expand(batch_size) + + rope_fp32 = self.rope_cache.float() + + if past_key_value is None: + raise ValueError("past_key_value (KV cache tensor) must be provided") + + # kv_cache_start_idx: starting position in KV cache for each batch + # For normal inference, this is 0 (start from beginning) + kv_cache_start_idx = torch.zeros( + batch_size, dtype=torch.int32, device=hidden_states.device + ) + + attn_out, updated_kv = torch.ops.tensorrt_edge_llm.xqa_attn.default( + qkv, + past_key_value, + ctx_len, + rope_fp32, + kv_cache_start_idx, + self.num_heads, + self.num_key_value_heads, + self.head_dim, + ) + + # Use attn_hidden_size for reshape (may differ from hidden_size in Qwen3) + attn_out = attn_out.reshape(batch_size, seq_len, self.attn_hidden_size) + output = self.o_proj(attn_out) + return output, updated_kv + + +# ----------------------------------------------------------------------------- +# Model Wrappers +# ----------------------------------------------------------------------------- + + +class LLMPluginWrapper(nn.Module): + """ + Generic wrapper for LLM models with plugin attention. 
+ + This wrapper handles the forward pass for models with replaced attention modules, + managing KV caches and context lengths appropriately. + """ + + def __init__(self, model: nn.Module, model_type: str = "auto"): + """ + Initialize the wrapper. + + Args: + model: The model with replaced attention modules. + model_type: Type of model ("qwen", "llama", or "auto" for auto-detection). + """ + super().__init__() + self.model = model + self.model_type = ( + self._detect_model_type(model) if model_type == "auto" else model_type + ) + + def _detect_model_type(self, model: nn.Module) -> str: + """Auto-detect model type from model structure.""" + model_class = model.__class__.__name__.lower() + if "qwen" in model_class: + return "qwen" + elif "llama" in model_class or "mistral" in model_class: + return "llama" + else: + # Default to generic transformer structure + return "generic" + + def _get_transformer(self) -> nn.Module: + """Get the transformer backbone based on model type.""" + if self.model_type == "qwen": + return self.model.model + elif self.model_type == "llama": + return self.model.model + else: + # Try common attribute names + for attr in ["model", "transformer", "backbone"]: + if hasattr(self.model, attr): + return getattr(self.model, attr) + raise ValueError( + f"Cannot find transformer backbone for model type: {self.model_type}" + ) + + def _get_layers(self, transformer: nn.Module) -> nn.ModuleList: + """Get the list of transformer layers.""" + for attr in ["layers", "h", "blocks"]: + if hasattr(transformer, attr): + return getattr(transformer, attr) + raise ValueError("Cannot find transformer layers") + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + kv_caches: List[torch.Tensor], + ctx_len: torch.Tensor, + ) -> Tuple[torch.Tensor, List[torch.Tensor]]: + """ + Forward pass with plugin attention. + + Args: + input_ids: Input token IDs [batch, seq_len]. + position_ids: Position IDs [batch, seq_len]. 
+ kv_caches: List of KV cache tensors, one per layer. + ctx_len: Context length tensor [batch]. + + Returns: + Tuple of (logits, list of updated KV caches). + """ + transformer = self._get_transformer() + hidden_states = transformer.embed_tokens(input_ids) + + layers = self._get_layers(transformer) + new_kv_caches = [] + + for i, layer in enumerate(layers): + past_key_value = kv_caches[i] + residual = hidden_states + + # Input layer norm + if hasattr(layer, "input_layernorm"): + hidden_states = layer.input_layernorm(hidden_states) + elif hasattr(layer, "ln_1"): + hidden_states = layer.ln_1(hidden_states) + + # Self attention + hidden_states, updated_kv = layer.self_attn( + hidden_states=hidden_states, + attention_mask=None, + position_ids=position_ids, + past_key_value=past_key_value, + ctx_len=ctx_len, + ) + hidden_states = residual + hidden_states + + # Post attention layer norm + MLP + residual = hidden_states + if hasattr(layer, "post_attention_layernorm"): + hidden_states = layer.post_attention_layernorm(hidden_states) + elif hasattr(layer, "ln_2"): + hidden_states = layer.ln_2(hidden_states) + hidden_states = layer.mlp(hidden_states) + hidden_states = residual + hidden_states + + new_kv_caches.append(updated_kv) + + # Final layer norm + if hasattr(transformer, "norm"): + hidden_states = transformer.norm(hidden_states) + elif hasattr(transformer, "ln_f"): + hidden_states = transformer.ln_f(hidden_states) + + # LM head + logits = self.model.lm_head(hidden_states) + + return logits, new_kv_caches + + +# ----------------------------------------------------------------------------- +# Model Modification Functions +# ----------------------------------------------------------------------------- + + +def replace_attention_with_plugin( + model: nn.Module, + config: Any, + max_seq_len: int, + device: torch.device, + dtype: torch.dtype = torch.float16, +) -> nn.Module: + """ + Replace all attention modules in a model with PluginAttention. 
+ + Args: + model: The HuggingFace model to modify. + config: Model configuration. + max_seq_len: Maximum sequence length for RoPE cache. + device: Device for the model. + dtype: Data type for the model. + + Returns: + The modified model with plugin attention. + """ + # Get rotary embedding from model + transformer = model.model if hasattr(model, "model") else model + + # Try to find rotary embedding + rotary_emb = None + if hasattr(transformer, "rotary_emb"): + rotary_emb = transformer.rotary_emb + elif hasattr(transformer, "layers") and len(transformer.layers) > 0: + first_layer = transformer.layers[0] + if hasattr(first_layer, "self_attn") and hasattr( + first_layer.self_attn, "rotary_emb" + ): + rotary_emb = first_layer.self_attn.rotary_emb + + if rotary_emb is None: + raise ValueError("Cannot find rotary embedding in model") + + # Qwen3 has explicit head_dim that may differ from hidden_size // num_attention_heads + if hasattr(config, "head_dim") and config.head_dim is not None: + head_dim = config.head_dim + else: + head_dim = config.hidden_size // config.num_attention_heads + rope_cache = get_plugin_rope_cache(rotary_emb, max_seq_len, head_dim, device) + + # Get layers + if hasattr(transformer, "layers"): + layers = transformer.layers + elif hasattr(transformer, "h"): + layers = transformer.h + else: + raise ValueError("Cannot find transformer layers") + + # Replace attention modules + for i, layer in enumerate(layers): + layer.self_attn = PluginAttention(layer.self_attn, config, i, rope_cache) + + return model + + +# ----------------------------------------------------------------------------- +# Compilation +# ----------------------------------------------------------------------------- + + +def compile_plugin_model( + model: nn.Module, + config: Any, + max_seq_len: int, + device: torch.device, + dtype: torch.dtype = torch.float16, + debug: bool = False, +) -> Callable: + """ + Compile a model with plugin attention for TensorRT inference. 
+ + Args: + model: The wrapped model (should be LLMPluginWrapper or similar). + config: Model configuration. + max_seq_len: Maximum sequence length. + device: Device for compilation. + dtype: Data type. + debug: Whether to enable debug logging. + + Returns: + Compiled TensorRT model function. + """ + # Prepare dummy inputs + num_layers = config.num_hidden_layers + num_kv_heads = config.num_key_value_heads + # Qwen3 has explicit head_dim that may differ from hidden_size // num_attention_heads + if hasattr(config, "head_dim") and config.head_dim is not None: + head_dim = config.head_dim + else: + head_dim = config.hidden_size // config.num_attention_heads + + dummy_input_ids = torch.tensor([[1, 2, 3]], device=device) + dummy_pos_ids = torch.tensor([[0, 1, 2]], device=device) + dummy_ctx_len = torch.tensor([3], dtype=torch.int32, device=device) + dummy_kvs = [ + torch.zeros( + 1, 2, num_kv_heads, max_seq_len, head_dim, dtype=dtype, device=device + ) + for _ in range(num_layers) + ] + + # Dynamic shapes + seq_len_dim = torch.export.Dim("seq_len", min=1, max=max_seq_len) + kv_cache_dynamics = [{}] * num_layers + dynamic_shapes = { + "input_ids": {1: seq_len_dim}, + "position_ids": {1: seq_len_dim}, + "kv_caches": kv_cache_dynamics, + "ctx_len": {}, + } + + # Export + ep = torch.export.export( + model, + args=(dummy_input_ids, dummy_pos_ids, dummy_kvs, dummy_ctx_len), + dynamic_shapes=dynamic_shapes, + strict=False, + ) + + # Compile + with torch_tensorrt.dynamo.Debugger() if debug else nullcontext(): + trt_model = torch_tensorrt.dynamo.compile( + ep, + inputs=[dummy_input_ids, dummy_pos_ids, dummy_kvs, dummy_ctx_len], + use_explicit_typing=True, + use_fp32_acc=True, + device=device, + disable_tf32=True, + min_block_size=1, + ) + + return trt_model + + +# ----------------------------------------------------------------------------- +# KV Cache Utilities +# ----------------------------------------------------------------------------- + + +def create_kv_caches( + config: 
Any, + max_seq_len: int, + batch_size: int, + device: torch.device, + dtype: torch.dtype = torch.float16, +) -> List[torch.Tensor]: + """ + Create empty KV cache tensors for all layers. + + Args: + config: Model configuration. + max_seq_len: Maximum sequence length (capacity). + batch_size: Batch size. + device: Device to create tensors on. + dtype: Data type for the tensors. + + Returns: + List of KV cache tensors, one per layer. + """ + num_layers = config.num_hidden_layers + num_kv_heads = config.num_key_value_heads + # Qwen3 has explicit head_dim that may differ from hidden_size // num_attention_heads + if hasattr(config, "head_dim") and config.head_dim is not None: + head_dim = config.head_dim + else: + head_dim = config.hidden_size // config.num_attention_heads + + return [ + torch.zeros( + batch_size, + 2, + num_kv_heads, + max_seq_len, + head_dim, + dtype=dtype, + device=device, + ) + for _ in range(num_layers) + ] + + +# ----------------------------------------------------------------------------- +# Generation Utilities +# ----------------------------------------------------------------------------- + + +def generate_with_plugin( + model_func: Callable, + input_ids: torch.Tensor, + kv_caches: List[torch.Tensor], + max_new_tokens: int, + eos_token_id: Optional[int] = None, + device: torch.device = torch.device("cuda:0"), +) -> Tuple[torch.Tensor, List[torch.Tensor]]: + """ + Generate tokens using the plugin model. + + Args: + model_func: The compiled model function. + input_ids: Input token IDs [batch, seq_len]. + kv_caches: List of KV cache tensors. + max_new_tokens: Maximum number of new tokens to generate. + eos_token_id: EOS token ID for early stopping (optional). + device: Device for computation. + + Returns: + Tuple of (generated token IDs, updated KV caches). 
+ """ + generated_ids = input_ids.clone() + seq_len = input_ids.shape[1] + + # Prefill + position_ids = torch.arange(seq_len, dtype=torch.long, device=device).unsqueeze(0) + ctx_len = torch.tensor([seq_len], dtype=torch.int32, device=device) + + output = model_func(input_ids, position_ids, kv_caches, ctx_len) + + if isinstance(output, (tuple, list)): + if len(output) == 2: + logits, delta_kvs = output + else: + logits = output[0] + delta_kvs = output[1:] + else: + logits = output + delta_kvs = [] + + # Update KV caches + if len(delta_kvs) > 0: + for i, delta in enumerate(delta_kvs): + seq_len_out = delta.shape[3] + kv_caches[i][:, :, :, :seq_len_out, :] = delta + + next_token = torch.argmax(logits[:, -1, :], dim=-1).unsqueeze(0) + generated_ids = torch.cat([generated_ids, next_token], dim=1) + + # Check for EOS + if eos_token_id is not None and next_token.item() == eos_token_id: + return generated_ids, kv_caches + + # Decode + cur_pos = seq_len + + for _ in range(max_new_tokens - 1): + input_ids_step = next_token + position_ids_step = torch.tensor([[cur_pos]], dtype=torch.long, device=device) + ctx_len_step = torch.tensor([cur_pos + 1], dtype=torch.int32, device=device) + + output = model_func(input_ids_step, position_ids_step, kv_caches, ctx_len_step) + + if isinstance(output, (tuple, list)): + if len(output) == 2: + logits, delta_kvs = output + else: + logits = output[0] + delta_kvs = output[1:] + + # Update KV caches + if len(delta_kvs) > 0: + for i, delta in enumerate(delta_kvs): + kv_caches[i][:, :, :, cur_pos : cur_pos + 1, :] = delta + + next_token = torch.argmax(logits[:, -1, :], dim=-1).unsqueeze(0) + generated_ids = torch.cat([generated_ids, next_token], dim=1) + cur_pos += 1 + + # Check for EOS + if eos_token_id is not None and next_token.item() == eos_token_id: + break + + return generated_ids, kv_caches + + +def benchmark_plugin_generation( + model_func: Callable, + config: Any, + isl: int, + osl: int, + max_seq_len: int, + device: torch.device, + 
dtype: torch.dtype = torch.float16, + run_name: str = "Plugin", +) -> float: + """ + Benchmark plugin model generation. + + Args: + model_func: The compiled model function. + config: Model configuration. + isl: Input sequence length. + osl: Output sequence length (number of tokens to generate). + max_seq_len: Maximum sequence length for KV cache. + device: Device for computation. + dtype: Data type. + run_name: Name for logging. + + Returns: + Elapsed time in milliseconds. + """ + # Check for extra kwargs the model might need + extra_kwargs = {} + if hasattr(model_func, "forward"): + sig = inspect.signature(model_func.forward) + if "arg_start_idx" in sig.parameters: + extra_kwargs["arg_start_idx"] = 0 + if "arg_end_idx" in sig.parameters: + extra_kwargs["arg_end_idx"] = 0 + + # Prepare inputs + input_ids = torch.randint(0, config.vocab_size, (1, isl), device=device) + kv_caches = create_kv_caches(config, max_seq_len, 1, device, dtype) + + torch.cuda.synchronize() + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + start_event.record() + + # Prefill + seq_len = isl + position_ids = torch.arange(seq_len, dtype=torch.long, device=device).unsqueeze(0) + ctx_len = torch.tensor([seq_len], dtype=torch.int32, device=device) + + output = model_func(input_ids, position_ids, kv_caches, ctx_len, **extra_kwargs) + + if isinstance(output, (tuple, list)): + if len(output) == 2: + logits, delta_kvs = output + else: + logits = output[0] + delta_kvs = output[1:] + else: + logits = output + delta_kvs = [] + + # Update KV caches + if len(delta_kvs) > 0: + for i, delta in enumerate(delta_kvs): + seq_len_out = delta.shape[3] + kv_caches[i][:, :, :, :seq_len_out, :] = delta + + next_token = torch.argmax(logits[:, -1, :], dim=-1).unsqueeze(0) + + # Decode + cur_pos = seq_len + + for _ in range(osl - 1): + input_ids_step = next_token + position_ids_step = torch.tensor([[cur_pos]], dtype=torch.long, device=device) + ctx_len_step = 
torch.tensor([cur_pos + 1], dtype=torch.int32, device=device) + + output = model_func( + input_ids_step, position_ids_step, kv_caches, ctx_len_step, **extra_kwargs + ) + + if isinstance(output, (tuple, list)): + if len(output) == 2: + logits, delta_kvs = output + else: + logits = output[0] + delta_kvs = output[1:] + + # Update KV caches + if len(delta_kvs) > 0: + for i, delta in enumerate(delta_kvs): + kv_caches[i][:, :, :, cur_pos : cur_pos + 1, :] = delta + + next_token = torch.argmax(logits[:, -1, :], dim=-1).unsqueeze(0) + cur_pos += 1 + + end_event.record() + torch.cuda.synchronize() + + elapsed_ms = start_event.elapsed_time(end_event) + print( + f"{run_name} | ISL: {isl}, OSL: {osl} | Total Time: {elapsed_ms:.2f} ms | Tokens/sec: {osl / (elapsed_ms / 1000.0):.2f}" + ) + return elapsed_ms diff --git a/tools/llm/plugin_utils_vit.py b/tools/llm/plugin_utils_vit.py new file mode 100644 index 0000000000..cb6b804e1f --- /dev/null +++ b/tools/llm/plugin_utils_vit.py @@ -0,0 +1,1025 @@ +""" +Plugin utilities for TensorRT ViT inference with custom attention plugins. + +This module provides Vision Transformer-specific utilities for using TensorRT +attention plugins with ViT models. Unlike LLMs, ViT models: +- Do not use KV caching (full bidirectional attention) +- Do not use RoPE (learnable/absolute position embeddings) +- Process fixed-size image patches at once +""" + +import ctypes +import os +from contextlib import nullcontext +from typing import Any, Callable, Dict, List, Optional, Tuple + +import tensorrt as trt +import torch +import torch.nn as nn +import torch_tensorrt + +_TENSORRT_REPO_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +_WORKSPACE_ROOT = os.path.dirname(_TENSORRT_REPO_ROOT) + +# Default plugin path for ViT attention plugin. TensorRT-Edge-LLM is checked out +# next to TensorRT in this workspace, not inside the TensorRT repository. 
+DEFAULT_PLUGIN_PATH = os.path.join( + _WORKSPACE_ROOT, + "TensorRT-Edge-LLM", + "build", + "libNvInfer_edgellm_plugin.so", +) + +# Global configuration for ViT plugin converter +_VIT_PLUGIN_CONFIG: Dict[str, Any] = {} + +def load_plugin(plugin_path: Optional[str] = None) -> bool: + """ + Load the TensorRT attention plugin library. + + Args: + plugin_path: Path to the plugin .so file. If None, uses DEFAULT_PLUGIN_PATH. + + Returns: + True if plugin was loaded successfully, False otherwise. + + Raises: + RuntimeError: If plugin file does not exist. + """ + path = plugin_path or os.environ.get("TRT_EDGE_LLM_PLUGIN_PATH") or DEFAULT_PLUGIN_PATH + if not os.path.exists(path): + raise RuntimeError(f"Plugin not found at {path}") + ctypes.CDLL(path) + print(f"Loaded plugin: {path}") + return True + + +def set_vit_plugin_config( + num_attention_heads: int, + head_dim: int, + num_patches: int, + max_batch_size: int = 4, + mask_type: int = 0, + max_seq_len: int = 0, +) -> None: + """ + Set global configuration for the ViT plugin converter. + + Args: + num_attention_heads: Number of attention heads. + head_dim: Dimension of each attention head. + num_patches: Number of image patches (including [CLS] token). + max_batch_size: Maximum batch size. + mask_type: Plugin mask mode. 0=dense additive mask, 1=packed cu_seqlens. + max_seq_len: Maximum packed segment length for cu_seqlens FMHA. + """ + global _VIT_PLUGIN_CONFIG + _VIT_PLUGIN_CONFIG = { + "num_attention_heads": num_attention_heads, + "head_dim": head_dim, + "num_patches": num_patches, + "max_batch_size": max_batch_size, + "mask_type": mask_type, + "max_seq_len": max_seq_len, + } + +def get_vit_plugin_config() -> Dict[str, Any]: + """Get the current ViT plugin configuration.""" + return _VIT_PLUGIN_CONFIG.copy() + +def set_vit_plugin_config_from_model(model_config: Any) -> None: + """ + Set ViT plugin configuration from a HuggingFace vision model config. + + Args: + model_config: HuggingFace model configuration object. 
+ """ + # HuggingFace vision configs use slightly different field names across + # families. Plain ViT uses num_attention_heads; Mllama/Llama Vision uses + # attention_heads. + num_heads = getattr(model_config, "num_attention_heads", None) or getattr( + model_config, "attention_heads" + ) + head_dim = model_config.hidden_size // num_heads + + # Calculate number of patches from image size + image_size = model_config.image_size + patch_size = model_config.patch_size + if isinstance(image_size, (tuple, list)): + image_h, image_w = image_size + else: + image_h = image_w = image_size + if isinstance(patch_size, (tuple, list)): + patch_h, patch_w = patch_size + else: + patch_h = patch_w = patch_size + num_patches = (image_h // patch_h) * (image_w // patch_w) + 1 # +1 for [CLS] + + set_vit_plugin_config( + num_attention_heads=num_heads, + head_dim=head_dim, + num_patches=num_patches, + ) + + +# ----------------------------------------------------------------------------- +# Plugin Op Registration +# ----------------------------------------------------------------------------- + +def _register_vit_plugin_op_impl() -> None: + """ + Internal implementation to register the tensorrt_vit::attention custom op for PyTorch. + + ViT attention differs from LLM attention: + - No KV cache - full bidirectional attention + - Simple fused QKV input + - Single output - no separate KV output + """ + + @torch.library.custom_op("tensorrt_vit::attention", mutates_args=()) + def attn( + qkv: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + mask_or_cu_seqlens: torch.Tensor, + num_heads: int, + head_dim: int, + qkv_fused: int = 1, + mask_type: int = 0, + max_seq_len: int = 0, + ) -> torch.Tensor: + """ + ViT attention operation. + + Args: + qkv: Fused [Q, K, V] tensor of shape [B, S, (H*D*3)]. + cos: RoPE cosine tensor of shape [S, D]. + sin: RoPE sine tensor of shape [S, D]. + mask_or_cu_seqlens: Dense additive mask or INT32 cu_seqlens. + num_heads: Number of attention heads. 
+ head_dim: Dimension per head. + qkv_fused: Whether QKV is fused (1=yes, 0=no). + mask_type: 0 for dense additive mask, 1 for packed cu_seqlens. + max_seq_len: Max segment length when mask_type=1. + + Returns: + Attention output of shape [B, S, H*D]. + """ + batch_size, seq_len, _ = qkv.shape + output_dim = num_heads * head_dim + attn_out = torch.zeros( + batch_size, seq_len, output_dim, dtype=qkv.dtype, device=qkv.device + ) + return attn_out + + @torch.library.register_fake("tensorrt_vit::attention") + def _(qkv, cos, sin, mask_or_cu_seqlens, num_heads, head_dim, qkv_fused=1, mask_type=0, max_seq_len=0): + batch_size, seq_len, _ = qkv.shape + output_dim = num_heads * head_dim + attn_out = torch.empty( + batch_size, seq_len, output_dim, dtype=qkv.dtype, device=qkv.device + ) + return attn_out + + +def register_vit_plugin_op() -> None: + """ + Register the tensorrt_vit::attention custom op for PyTorch. + + This function is idempotent - safe to call multiple times. + """ + if hasattr(torch.ops, "tensorrt_vit") and hasattr( + torch.ops.tensorrt_vit, "attention" + ): + return + _register_vit_plugin_op_impl() + + +# Register the op at module import time so the converter decorator works +if not ( + hasattr(torch.ops, "tensorrt_vit") + and hasattr(torch.ops.tensorrt_vit, "attention") +): + _register_vit_plugin_op_impl() + +# Importing plugin_converter_vit at the bottom of this file registers the +# Torch-TensorRT converter for tensorrt_vit::attention. + +from plugin_converter_vit import ( # noqa: E402 (must be after op registration) + convert_vit_attention, + get_vit_plugin_conversion_count, + reset_vit_plugin_conversion_count, +) + +# ----------------------------------------------------------------------------- +# Plugin Attention Module +# ----------------------------------------------------------------------------- + +class ViTPluginAttention(nn.Module): + """ + Model-agnostic ViT attention wrapper using the TensorRT ViT attention plugin. 
+ + The wrapper follows the same idea as the LLM plugin path: infer the attention + module layout from the original module instead of requiring a separate hand + written implementation for every model family. It supports common vision + attention layouts: + - fused QKV projection: qkv + proj (Qwen-VL style) + - separate Q/K/V: q_proj/k_proj/v_proj + o_proj (Mllama/Llama Vision style) + - HuggingFace ViT: query/key/value + output.dense + + RoPE is also inferred from the forward inputs. Models that pass + position_embeddings=(cos, sin) use those tensors; models without visual RoPE + get identity cos/sin tensors. + """ + + def __init__( + self, + original_attn: nn.Module, + config: Any, + layer_idx: int, + return_tuple: bool = False, + use_plugin_op: bool = True, + ): + super().__init__() + self.original_attn = original_attn + self.layer_idx = layer_idx + self.return_tuple = return_tuple + self.use_plugin_op = use_plugin_op + + self.projection_layout = self._detect_projection_layout(original_attn) + self.output_proj = self._detect_output_projection(original_attn) + self.num_heads = self._detect_num_heads(original_attn, config) + self.head_dim = self._detect_head_dim(original_attn, config, self.num_heads) + + def _detect_projection_layout(self, original_attn: nn.Module) -> str: + if hasattr(original_attn, "qkv"): + return "fused_qkv" + if all(hasattr(original_attn, name) for name in ("q_proj", "k_proj", "v_proj")): + return "separate_qkv" + if all(hasattr(original_attn, name) for name in ("query", "key", "value")): + return "hf_vit_qkv" + raise ValueError( + "Unsupported ViT attention projection layout. Expected qkv, " + "q_proj/k_proj/v_proj, or query/key/value projections." 
+ ) + + def _detect_output_projection(self, original_attn: nn.Module) -> nn.Module: + for name in ("proj", "o_proj", "out_proj", "out"): + if hasattr(original_attn, name): + return getattr(original_attn, name) + if hasattr(original_attn, "output"): + output = original_attn.output + return output.dense if hasattr(output, "dense") else output + raise ValueError( + "Unsupported ViT attention output projection layout. Expected proj, " + "o_proj, out_proj, out, or output(.dense)." + ) + + def _detect_num_heads(self, original_attn: nn.Module, config: Any) -> int: + for source in (original_attn, config): + for name in ("num_heads", "attention_heads", "num_attention_heads"): + value = getattr(source, name, None) + if value is not None: + return int(value) + raise ValueError("Could not infer number of attention heads for ViT plugin.") + + def _detect_head_dim( + self, original_attn: nn.Module, config: Any, num_heads: int + ) -> int: + for source in (original_attn, config): + value = getattr(source, "head_dim", None) + if value is not None: + return int(value) + + hidden_size = None + for source in (config, original_attn): + for name in ("hidden_size", "embed_dim", "dim"): + value = getattr(source, name, None) + if value is not None: + hidden_size = int(value) + break + if hidden_size is not None: + break + if hidden_size is None: + raise ValueError("Could not infer hidden size for ViT plugin head_dim.") + return hidden_size // num_heads + + def _project_qkv(self, hidden_states: torch.Tensor) -> torch.Tensor: + if self.projection_layout == "fused_qkv": + return self.original_attn.qkv(hidden_states) + if self.projection_layout == "separate_qkv": + q = self.original_attn.q_proj(hidden_states) + k = self.original_attn.k_proj(hidden_states) + v = self.original_attn.v_proj(hidden_states) + return torch.cat([q, k, v], dim=-1) + + q = self.original_attn.query(hidden_states) + k = self.original_attn.key(hidden_states) + v = self.original_attn.value(hidden_states) + return 
torch.cat([q, k, v], dim=-1) + + def _get_rope_tensors( + self, + hidden_states: torch.Tensor, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]], + ) -> Tuple[torch.Tensor, torch.Tensor]: + seq_len = hidden_states.shape[-2] + if position_embeddings is None: + cos = torch.ones( + seq_len, + self.head_dim, + dtype=hidden_states.dtype, + device=hidden_states.device, + ) + sin = torch.zeros_like(cos) + return cos, sin + + cos, sin = position_embeddings + if not self.use_plugin_op: + return cos.to(device=hidden_states.device), sin.to(device=hidden_states.device) + return ( + cos.to(device=hidden_states.device, dtype=hidden_states.dtype), + sin.to(device=hidden_states.device, dtype=hidden_states.dtype), + ) + + def _normalize_attention_mask( + self, + attention_mask: Optional[torch.Tensor], + batch_size: int, + seq_len: int, + dtype: torch.dtype, + device: torch.device, + ) -> torch.Tensor: + if attention_mask is None: + return torch.zeros(batch_size, seq_len, seq_len, dtype=dtype, device=device) + + if attention_mask.dim() == 1 and attention_mask.dtype == torch.int32: + return attention_mask + + attention_mask = attention_mask.to(dtype=dtype) + if attention_mask.dim() == 4: + if attention_mask.shape[1] == 1: + attention_mask = attention_mask[:, 0, :, :] + else: + attention_mask = attention_mask.reshape( + attention_mask.shape[0] * attention_mask.shape[1], + attention_mask.shape[2], + attention_mask.shape[3], + ) + return attention_mask + + def _rotate_half(self, x: torch.Tensor) -> torch.Tensor: + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + def _torch_attention( + self, + qkv: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + attention_mask: torch.Tensor, + ) -> torch.Tensor: + batch_size, seq_len, _ = qkv.shape + qkv = qkv.reshape(batch_size, seq_len, 3, self.num_heads, self.head_dim) + q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0) + + cos = cos.unsqueeze(0).unsqueeze(0) + sin = 
sin.unsqueeze(0).unsqueeze(0) + q = (q * cos) + (self._rotate_half(q) * sin) + k = (k * cos) + (self._rotate_half(k) * sin) + + attn_weights = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim**0.5) + if attention_mask.dim() == 3: + attention_mask = attention_mask.unsqueeze(1) + attn_weights = attn_weights + attention_mask + attn_weights = torch.nn.functional.softmax( + attn_weights, dim=-1, dtype=torch.float32 + ).to(v.dtype) + attn_out = torch.matmul(attn_weights, v) + attn_out = attn_out.transpose(1, 2).reshape( + batch_size, seq_len, self.num_heads * self.head_dim + ) + return attn_out + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + max_seq_len: int = 0, + **kwargs, + ) -> torch.Tensor: + squeeze_batch = False + if hidden_states.dim() == 2: + hidden_states = hidden_states.unsqueeze(0) + squeeze_batch = True + + batch_size, seq_len, _ = hidden_states.shape + qkv = self._project_qkv(hidden_states) + cos, sin = self._get_rope_tensors(hidden_states, position_embeddings) + attention_mask = self._normalize_attention_mask( + attention_mask, + batch_size, + seq_len, + hidden_states.dtype, + hidden_states.device, + ) + mask_type = 0 + if attention_mask.dim() == 1 and attention_mask.dtype == torch.int32: + mask_type = 1 + if max_seq_len <= 0: + max_seq_len = seq_len + + if self.use_plugin_op: + attn_out = torch.ops.tensorrt_vit.attention.default( + qkv, + cos, + sin, + attention_mask, + self.num_heads, + self.head_dim, + 1, + mask_type, + max_seq_len, + ) + else: + if mask_type == 1: + raise ValueError("PyTorch reference attention requires a dense mask.") + attn_out = self._torch_attention(qkv, cos, sin, attention_mask) + output = self.output_proj(attn_out) + output = output.squeeze(0) if squeeze_batch else output + if self.return_tuple: + return output, None + return output + + +# 
----------------------------------------------------------------------------- +# Model Wrappers +# ----------------------------------------------------------------------------- + +VIT_INPUT_CONTRACT_NATIVE = "native" +VIT_INPUT_CONTRACT_WINDOWED_ROPE = "windowed_rope" +VIT_INPUT_CONTRACT_TILED_ASPECT_RATIO = "tiled_aspect_ratio" + +def _require_tensor(value: Optional[torch.Tensor], name: str) -> torch.Tensor: + if value is None: + raise ValueError(f"ViT plugin forward requires {name}.") + return value + + +def _get_windowed_rope_visual_model(model: nn.Module) -> nn.Module: + if hasattr(model, "visual"): + return model.visual + if hasattr(model, "patch_embed") and hasattr(model, "blocks"): + return model + raise ValueError("Cannot find a windowed-RoPE visual backbone.") + + +def _get_windowed_rope_blocks(visual_model: nn.Module) -> nn.ModuleList: + if hasattr(visual_model, "blocks"): + return visual_model.blocks + raise ValueError("Cannot find windowed-RoPE visual blocks.") + + +def _forward_windowed_rope_vision( + model: nn.Module, + pixel_values: torch.Tensor, + rotary_pos_emb: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + window_attention_mask: Optional[torch.Tensor] = None, + cu_window_seqlens: Optional[torch.Tensor] = None, + window_index: Optional[torch.Tensor] = None, + reverse_window_index: Optional[torch.Tensor] = None, + max_window_seq_len: int = 0, + **kwargs, +) -> torch.Tensor: + visual = _get_windowed_rope_visual_model(model) + rotary_pos_emb = _require_tensor(rotary_pos_emb, "rotary_pos_emb") + attention_mask = _require_tensor(attention_mask, "attention_mask") + window_attention_mask = _require_tensor( + window_attention_mask, "window_attention_mask" + ) + cu_window_seqlens = _require_tensor(cu_window_seqlens, "cu_window_seqlens") + window_index = _require_tensor(window_index, "window_index") + reverse_window_index = _require_tensor(reverse_window_index, "reverse_window_index") + + hidden_states = 
visual.patch_embed(pixel_values) + + seq_len, _ = hidden_states.size() + hidden_states = hidden_states.reshape( + seq_len // visual.spatial_merge_unit, + visual.spatial_merge_unit, + -1, + ) + hidden_states = hidden_states[window_index, :, :] + hidden_states = hidden_states.reshape(seq_len, -1) + + rotary_pos_emb = rotary_pos_emb.reshape( + seq_len // visual.spatial_merge_unit, + visual.spatial_merge_unit, + -1, + ) + rotary_pos_emb = rotary_pos_emb[window_index, :, :] + rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1) + emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) + position_embeddings = (emb.cos(), emb.sin()) + + blocks = _get_windowed_rope_blocks(visual) + for layer_idx, block in enumerate(blocks): + full_attention = layer_idx in visual.fullatt_block_indexes + attention_mask_now = attention_mask if full_attention else window_attention_mask + + residual = hidden_states + hidden_states = block.norm1(hidden_states) + hidden_states = block.attn( + hidden_states, + attention_mask=attention_mask_now, + position_embeddings=position_embeddings, + max_seq_len=0, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = block.norm2(hidden_states) + hidden_states = block.mlp(hidden_states) + hidden_states = residual + hidden_states + + hidden_states = visual.merger(hidden_states) + hidden_states = hidden_states[reverse_window_index, :] + return hidden_states + + +def _forward_tiled_aspect_ratio_vision( + model: nn.Module, + pixel_values: torch.Tensor, + aspect_ratio_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + **kwargs, +) -> torch.Tensor: + aspect_ratio_ids = _require_tensor(aspect_ratio_ids, "aspect_ratio_ids") + attention_mask = _require_tensor(attention_mask, "attention_mask") + vision = model.vision_model if hasattr(model, "vision_model") else model + batch_size, num_concurrent_media, num_tiles, num_channels, height, width = ( + pixel_values.shape + ) + pixel_values = 
pixel_values.reshape( + batch_size * num_concurrent_media * num_tiles, + num_channels, + height, + width, + ) + aspect_ratio_ids = aspect_ratio_ids.reshape( + batch_size * num_concurrent_media, -1 + ) + + target_dtype = vision.patch_embedding.weight.dtype + target_device = vision.patch_embedding.weight.device + patch_embeds = vision.patch_embedding( + pixel_values.to(target_device, target_dtype) + ) + hidden_state = patch_embeds.flatten(2).transpose(1, 2) + + _, num_patches, dim = hidden_state.shape + hidden_state = hidden_state.reshape( + batch_size * num_concurrent_media, + num_tiles, + -1, + dim, + ) + hidden_state = vision.pre_tile_positional_embedding( + hidden_state, + aspect_ratio_ids, + ) + + hidden_state = hidden_state.reshape( + batch_size * num_concurrent_media * num_tiles, + num_patches, + dim, + ) + hidden_state = vision.apply_class_embedding(hidden_state) + num_patches += 1 + + hidden_state = hidden_state.reshape( + batch_size * num_concurrent_media, + num_tiles, + num_patches, + dim, + ) + hidden_state = vision.gated_positional_embedding( + hidden_state, + aspect_ratio_ids, + ) + hidden_state = vision.layernorm_pre(hidden_state) + + num_padding_patches = (8 - (hidden_state.shape[-2] % 8)) % 8 + hidden_state = torch.nn.functional.pad( + hidden_state, + (0, 0, 0, num_padding_patches), + mode="constant", + value=0, + ) + slice_index = -num_padding_patches if num_padding_patches > 0 else None + + hidden_state = hidden_state.view(batch_size * num_concurrent_media, -1, dim) + output = vision.transformer( + hidden_state, + attention_mask=attention_mask, + ) + hidden_state = output.last_hidden_state + hidden_state = vision.layernorm_post(hidden_state) + + hidden_state = hidden_state.reshape( + batch_size * num_concurrent_media, + num_tiles, + num_patches + num_padding_patches, + dim, + ) + hidden_state = vision.post_tile_positional_embedding( + hidden_state, + aspect_ratio_ids, + ) + hidden_state = hidden_state.reshape( + batch_size * num_concurrent_media, + 
num_tiles * (num_patches + num_padding_patches), + dim, + ) + global_output = vision.global_transformer( + hidden_state, + attention_mask=attention_mask, + ) + hidden_state = global_output.last_hidden_state + + hidden_state = hidden_state.reshape( + batch_size * num_concurrent_media, + num_tiles, + num_patches + num_padding_patches, + dim, + ) + hidden_state = hidden_state[:, :, :slice_index] + hidden_state = hidden_state.reshape( + batch_size, + num_concurrent_media, + num_tiles, + num_patches, + dim, + ) + + all_intermediate_hidden_states = [ + output.hidden_states[i] for i in vision.intermediate_layers_indices + ] + intermediate_hidden_states = torch.stack( + all_intermediate_hidden_states, + dim=-1, + ) + intermediate_hidden_states = intermediate_hidden_states.reshape( + batch_size * num_concurrent_media, + num_tiles, + num_patches + num_padding_patches, + -1, + ) + intermediate_hidden_states = intermediate_hidden_states[:, :, :slice_index] + intermediate_hidden_states = intermediate_hidden_states.reshape( + batch_size, + num_concurrent_media, + num_tiles, + num_patches, + -1, + ) + + return torch.cat([hidden_state, intermediate_hidden_states], dim=-1) + + +def _forward_native_vision( + model: nn.Module, + pixel_values: torch.Tensor, + **kwargs, +) -> torch.Tensor: + output = model(pixel_values) + if hasattr(output, "last_hidden_state"): + return output.last_hidden_state + if isinstance(output, (tuple, list)): + return output[0] + return output + + +class ViTPluginWrapper(nn.Module): + """ + Generic wrapper for vision models with plugin attention. + + The caller chooses the tensor input contract and provides the corresponding + tensors during export/runtime. The contract describes the vision tower's + tensor interface, not a concrete model name. 
+ """ + + def __init__( + self, + model: nn.Module, + input_contract: str = VIT_INPUT_CONTRACT_NATIVE, + max_window_seq_len: int = 0, + ): + super().__init__() + self.model = model + self.input_contract = input_contract + self.max_window_seq_len = max_window_seq_len + + def forward( + self, + pixel_values: torch.Tensor, + rotary_pos_emb: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + window_attention_mask: Optional[torch.Tensor] = None, + cu_window_seqlens: Optional[torch.Tensor] = None, + window_index: Optional[torch.Tensor] = None, + reverse_window_index: Optional[torch.Tensor] = None, + aspect_ratio_ids: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if self.input_contract == VIT_INPUT_CONTRACT_WINDOWED_ROPE: + return _forward_windowed_rope_vision( + self.model, + pixel_values, + rotary_pos_emb=rotary_pos_emb, + attention_mask=attention_mask, + window_attention_mask=window_attention_mask, + cu_window_seqlens=cu_window_seqlens, + window_index=window_index, + reverse_window_index=reverse_window_index, + max_window_seq_len=self.max_window_seq_len, + ) + if self.input_contract == VIT_INPUT_CONTRACT_TILED_ASPECT_RATIO: + return _forward_tiled_aspect_ratio_vision( + self.model, + pixel_values, + aspect_ratio_ids=aspect_ratio_ids, + attention_mask=attention_mask, + ) + if self.input_contract == VIT_INPUT_CONTRACT_NATIVE: + return _forward_native_vision(self.model, pixel_values) + raise ValueError(f"Unsupported ViT plugin input contract: {self.input_contract}") + + +# ----------------------------------------------------------------------------- +# Model Modification Functions +# ----------------------------------------------------------------------------- + + +def replace_vit_attention_with_plugin( + model: nn.Module, + config: Any, + use_plugin_op: bool = True, +) -> nn.Module: + """ + Replace all supported vision attention modules with plugin attention. 
+ + This is the vision-side equivalent of the LLM helper: callers use one + replacement entry point, and the function detects the model structure: + - Qwen-VL visual blocks: ``blocks[*].attn`` + - Mllama/Llama Vision stacks: ``transformer/global_transformer.layers[*].self_attn`` + - HF ViT-style encoders: ``encoder.layer[*].attention.self`` + + Args: + model: The HuggingFace vision model or visual tower to modify. + config: Model configuration. + + Returns: + The modified model with plugin attention. + """ + replacement_count = 0 + + # Qwen2.5-VL visual tower: model.visual.blocks or visual.blocks. + visual_model = model.visual if hasattr(model, "visual") else model + if hasattr(visual_model, "blocks"): + for i, block in enumerate(visual_model.blocks): + if hasattr(block, "attn"): + block.attn = ViTPluginAttention( + block.attn, config, i, use_plugin_op=use_plugin_op + ) + replacement_count += 1 + if replacement_count: + return model + + # Mllama is HuggingFace's architecture name for official Meta Llama 3.2 + # Vision models. Its self_attn returns (hidden_state, attn_weights), so + # self_attn replacements ask the generic plugin wrapper to return a tuple. + vision_model = model.vision_model if hasattr(model, "vision_model") else model + layer_idx = 0 + for encoder_name in ("transformer", "global_transformer"): + encoder = getattr(vision_model, encoder_name, None) + if encoder is None or not hasattr(encoder, "layers"): + continue + + for layer in encoder.layers: + if hasattr(layer, "self_attn"): + layer.self_attn = ViTPluginAttention( + layer.self_attn, + config, + layer_idx, + return_tuple=True, + ) + layer_idx += 1 + replacement_count += 1 + + if layer_idx: + return model + + # HF SigLIP/SigLIP2-style tower: model.vision_model.encoder.layers or + # model.encoder.layers. These attention modules return + # (hidden_state, attn_weights), so the replacement returns a tuple. 
+ if hasattr(vision_model, "encoder") and hasattr(vision_model.encoder, "layers"): + for i, layer in enumerate(vision_model.encoder.layers): + if hasattr(layer, "self_attn"): + layer.self_attn = ViTPluginAttention( + layer.self_attn, + config, + i, + return_tuple=True, + use_plugin_op=use_plugin_op, + ) + replacement_count += 1 + + if replacement_count: + return model + + # HF ViT-style tower: model.vision_model.encoder.layer or model.encoder.layer. + if hasattr(vision_model, "encoder") and hasattr(vision_model.encoder, "layer"): + for i, layer in enumerate(vision_model.encoder.layer): + if hasattr(layer, "attention"): + layer.attention.self = ViTPluginAttention( + layer.attention.self, config, i, use_plugin_op=use_plugin_op + ) + replacement_count += 1 + + if replacement_count == 0: + raise ValueError("Cannot find supported ViT attention modules") + + return model + + +def count_vit_plugin_attention_modules(model: nn.Module) -> int: + """Count ViT attention modules replaced with the plugin wrapper.""" + return sum(1 for module in model.modules() if isinstance(module, ViTPluginAttention)) + +# ----------------------------------------------------------------------------- +# Compilation +# ----------------------------------------------------------------------------- + +def compile_vit_plugin_model( + model: nn.Module, + example_inputs: Optional[Tuple[torch.Tensor, ...]], + device: torch.device, + example_kwargs: Optional[Dict[str, torch.Tensor]] = None, + dynamic_shapes: Optional[Dict[str, Any]] = None, + debug: bool = False, +) -> Callable: + """ + Compile a ViT/VLM visual wrapper with plugin attention. + + Model-specific wrappers own input preparation and forward signatures. This + helper owns the shared torch.export -> Torch-TensorRT compile path. + + Args: + model: The vision wrapper or model to export. + example_inputs: Example tensor inputs matching ``model.forward``. + example_kwargs: Optional named tensor inputs matching ``model.forward``. 
+ dynamic_shapes: Optional torch.export dynamic shape spec. + device: Device for compilation. + debug: Whether to enable debug logging. + + Returns: + Compiled TensorRT model function. + """ + if dynamic_shapes is None: + dynamic_shapes = {} + if example_inputs is None: + example_inputs = () + if example_kwargs is None: + example_kwargs = {} + if dynamic_shapes: + dynamic_shapes = { + name: shape + for name, shape in dynamic_shapes.items() + if isinstance(example_kwargs.get(name), torch.Tensor) + } + + ep = torch.export.export( + model, + args=example_inputs, + kwargs=example_kwargs, + dynamic_shapes=dynamic_shapes, + strict=False, + ) + + compile_inputs = list(example_inputs) + [ + value for value in example_kwargs.values() if isinstance(value, torch.Tensor) + ] + with torch_tensorrt.dynamo.Debugger() if debug else nullcontext(): + trt_model = torch_tensorrt.dynamo.compile( + ep, + inputs=compile_inputs, + use_explicit_typing=True, + use_fp32_acc=True, + device=device, + disable_tf32=True, + min_block_size=1, + ) + + return trt_model + + +# ----------------------------------------------------------------------------- +# Inference Utilities +# ----------------------------------------------------------------------------- + + +def inference_vit_plugin( + model_func: Callable, + pixel_values: torch.Tensor, +) -> torch.Tensor: + """ + Run inference on a compiled ViT plugin model. + + Args: + model_func: The compiled TensorRT model function. + pixel_values: Input images [batch, channels, height, width]. + + Returns: + Model output (logits or embeddings depending on model). + """ + return model_func(pixel_values) + + +# Benchmark utilities + +def measure_vit_latency( + fn: Callable, + num_warmup: int = 5, + num_runs: int = 10, +) -> Tuple[float, float, float]: + """ + Measure function latency with GPU synchronization. + + Args: + fn: Function to benchmark. + num_warmup: Number of warmup runs. + num_runs: Number of timing runs. 
+ + Returns: + Tuple of (mean_latency_ms, std_latency_ms, median_latency_ms). + """ + import statistics + + # Warmup + for _ in range(num_warmup): + fn() + + torch.cuda.synchronize() + times = [] + + for _ in range(num_runs): + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + start_event.record() + fn() + end_event.record() + torch.cuda.synchronize() + + times.append(start_event.elapsed_time(end_event)) + + mean_time = statistics.mean(times) + stdev_time = statistics.stdev(times) if len(times) > 1 else 0.0 + median_time = statistics.median(times) + + return mean_time, stdev_time, median_time + + +def measure_vit_memory( + model: nn.Module, + pixel_values: torch.Tensor, +) -> Tuple[float, float]: + """ + Measure model memory usage. + + Args: + model: The model. + pixel_values: Sample input. + + Returns: + Tuple of (peak_memory_mb, reserved_memory_mb). + """ + torch.cuda.reset_peak_memory_stats() + torch.cuda.synchronize() + + with torch.no_grad(): + _ = model(pixel_values) + + torch.cuda.synchronize() + peak_memory = torch.cuda.max_memory_allocated() / 1e6 + reserved_memory = torch.cuda.memory_reserved() / 1e6 + + return peak_memory, reserved_memory + + +# Importing this module registers the Torch-TensorRT converter for +# tensorrt_vit::attention, matching the LLM plugin_utils/plugin_converter split. 
+from plugin_converter_vit import convert_vit_attention # noqa: F401,E402 diff --git a/tools/llm/run_llm.py b/tools/llm/run_llm.py index 7f2ecf1273..9a78b68cb7 100644 --- a/tools/llm/run_llm.py +++ b/tools/llm/run_llm.py @@ -19,12 +19,18 @@ # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ import torch import torch_tensorrt -from modelopt.torch.quantization.utils import export_torch_mode -from quantize_utils import ( - convert_linear_to_tensorrt_quantized, - load_quantization_config, - quantize_model, -) + +try: + from modelopt.torch.quantization.utils import export_torch_mode + from quantize_utils import ( + convert_linear_to_tensorrt_quantized, + load_quantization_config, + quantize_model, + ) + + QUANTIZATION_AVAILABLE = True +except ImportError: + QUANTIZATION_AVAILABLE = False from torchtrt_ext import register_sdpa from transformers import AutoModelForCausalLM, AutoTokenizer from utils import ( @@ -35,6 +41,24 @@ time_generate, ) +# Import plugin utilities (optional) +try: + from plugin_utils import ( + LLMPluginWrapper, + benchmark_plugin_generation, + compile_plugin_model, + create_kv_caches, + generate_with_plugin, + load_plugin, + register_plugin_op, + replace_attention_with_plugin, + set_plugin_config_from_model, + ) + + PLUGIN_AVAILABLE = True +except ImportError as e: + PLUGIN_AVAILABLE = False + DEVICE = torch.device("cuda:0") @@ -56,31 +80,40 @@ def get_model(args): moved to CUDA device with the specified precision """ with torch.no_grad(): + # For plugin backend, we don't set attn_implementation + attn_impl_kwargs = {} + if args.backend in ("sdpa", "iattention"): + attn_impl_kwargs["attn_implementation"] = "sdpa" + model = ( AutoModelForCausalLM.from_pretrained( args.model, use_cache=False, - attn_implementation="sdpa", ignore_mismatched_sizes=True, + **attn_impl_kwargs, ) .eval() .cuda() ) - # register SDPA variant for the model - register_sdpa.enable_sdpa_converter(args.model, model.config) - - hf_quant_config = load_quantization_config(args.model) - if 
hf_quant_config: - model = convert_linear_to_tensorrt_quantized( - model, args.model_precision, hf_quant_config - ).cuda() - print( - f"Model is {hf_quant_config['quant_algo']} pre-quantized hf model. Quantized linear layers are applied" - ) - if args.quant_format: - raise RuntimeError( - f"Quantization cannot be applied for pre-quantized hf model" + # Register SDPA lowering pass only for sdpa backend. + # For iattention backend, the core TRT IAttention converters handle SDPA ops + # directly without needing the custom lowering pass. + if args.backend == "sdpa": + register_sdpa.enable_sdpa_converter(args.model, model.config) + + if QUANTIZATION_AVAILABLE: + hf_quant_config = load_quantization_config(args.model) + if hf_quant_config: + model = convert_linear_to_tensorrt_quantized( + model, args.model_precision, hf_quant_config + ).cuda() + print( + f"Model is {hf_quant_config['quant_algo']} pre-quantized hf model. Quantized linear layers are applied" ) + if args.quant_format: + raise RuntimeError( + f"Quantization cannot be applied for pre-quantized hf model" + ) if args.model_precision == "FP16": model = model.to(torch.float16) @@ -114,19 +147,32 @@ def compile_torchtrt(model, input_ids, args): for optimized inference """ max_seq_len = input_ids.shape[1] + args.num_tokens - with export_torch_mode(): + if QUANTIZATION_AVAILABLE: + with export_torch_mode(): + ep = export_llm(model, input_ids, max_seq_len=max_seq_len) + else: ep = export_llm(model, input_ids, max_seq_len=max_seq_len) position_ids = torch.arange(input_ids.shape[1]).unsqueeze(0).to(DEVICE) # Set precision specific flags use_fp32_acc = False + use_explicit_typing = False if args.model_precision == "FP16": + enabled_precisions = {torch.float32} use_fp32_acc = True + use_explicit_typing = True + elif args.model_precision == "BF16": + enabled_precisions = {torch.bfloat16} + use_fp32_acc = False + else: + enabled_precisions = {torch.float32} with torch_tensorrt.logging.debug() if args.debug else 
nullcontext(): trt_model = torch_tensorrt.dynamo.compile( ep, inputs=[input_ids, position_ids], + enabled_precisions=enabled_precisions, # truncate_double=True, + use_explicit_typing=use_explicit_typing, use_fp32_acc=use_fp32_acc, device=DEVICE, disable_tf32=True, @@ -207,6 +253,15 @@ def measure_perf(trt_model, input_signature, backend_name): default="FP16", help="Precision to use in the model. Options: FP16, BF16, FP32", ) + arg_parser.add_argument( + "--backend", + type=str, + default="sdpa", + help="Backend to use. Options: sdpa, iattention, plugin. " + "'sdpa' uses custom SDPA lowering pass + converter (matmul+softmax+matmul). " + "'iattention' uses TensorRT native IAttention layer (no KV cache support yet). " + "'plugin' uses TensorRT Edge-LLM attention plugin with built-in KV cache.", + ) arg_parser.add_argument( "--iterations", type=int, default=5, help="no. of iterations to run" ) @@ -265,6 +320,26 @@ def measure_perf(trt_model, input_signature, backend_name): ) args = arg_parser.parse_args() + # Validate arguments + if args.backend not in ("sdpa", "iattention", "plugin"): + raise ValueError( + f"Unknown backend '{args.backend}'. Options: sdpa, iattention, plugin" + ) + if args.backend == "plugin" and not PLUGIN_AVAILABLE: + raise RuntimeError( + "Plugin backend requested but plugin utilities are not available." + ) + if args.cache and args.backend == "plugin": + print("Warning: --cache is only applicable with 'sdpa' backend. Ignoring.") + args.cache = "" + if args.cache and args.backend == "iattention": + print( + "Warning: --cache is not supported with 'iattention' backend " + "(static_cache passes are incompatible with native IAttention converters). " + "Ignoring --cache." 
+ ) + args.cache = "" + with torch.inference_mode(): model = get_model(args) @@ -289,7 +364,11 @@ def measure_perf(trt_model, input_signature, backend_name): pyt_timings = None pyt_stats = None - if args.quant_format != None: + if args.quant_format is not None: + if not QUANTIZATION_AVAILABLE: + raise RuntimeError( + "Quantization requested but modelopt is not installed." + ) model = quantize_model(model, args, tokenizer) if args.enable_pytorch_run: pyt_gen_tokens = generate( @@ -312,54 +391,117 @@ def measure_perf(trt_model, input_signature, backend_name): compile_time_s=None, ) - if args.cache == "static_v1": - # This import is required to register static v1 KV cache transformations as lowering passes - import static_cache_v1 - if args.cache == "static_v2": - # This import is required to register static v2 KV cache transformations as lowering passes - import static_cache_v2 + # Backend selection: sdpa, iattention, or plugin + if args.backend == "plugin": + # Plugin backend + if not PLUGIN_AVAILABLE: + raise RuntimeError("Plugin backend requested but not available") + + dtype = ( + torch.float16 + if args.model_precision == "FP16" + else ( + torch.bfloat16 if args.model_precision == "BF16" else torch.float32 + ) + ) + config = model.config + max_seq_len = max(2048, MAX_OUTPUT_SEQ_LENGTH) + + # Load plugin and register op + load_plugin() + register_plugin_op() + set_plugin_config_from_model(config, max_seq_len) - # Compile the model with Torch-TensorRT - trt_model = compile_torchtrt(model, input_ids, args) + # Replace attention with plugin + model = replace_attention_with_plugin( + model, config, max_seq_len, DEVICE, dtype + ) + wrapper = LLMPluginWrapper(model) - if args.cache == "static_v1" or args.cache == "static_v2": - if args.cudagraph: - # Run a decoding loop with prefill and generate phases so that the CUDAGraph is recorded for both of these phases. 
- # trt_input_signature = (input_ids.clone(),) + get_zeroed_kv_cache_inputs(trt_model) - torch_tensorrt.runtime.set_cudagraphs_mode(True) + # Compile plugin model + trt_model = compile_plugin_model( + wrapper, config, max_seq_len, DEVICE, dtype, args.debug + ) - trt_gen_tokens = generate_with_static_cache( + # Create KV caches + kv_caches = create_kv_caches( + config, max_seq_len, args.batch_size, DEVICE, dtype + ) + + # Generate + trt_gen_tokens, _ = generate_with_plugin( trt_model, input_ids.clone(), - MAX_OUTPUT_SEQ_LENGTH, + kv_caches, + args.num_tokens, tokenizer.eos_token_id, + DEVICE, ) if args.benchmark: - trt_timings = time_generate( - generate_with_static_cache, + trt_timings = [] + for i in range(args.iterations): + elapsed_ms = benchmark_plugin_generation( + trt_model, + config, + input_ids.shape[1], + args.num_tokens, + max_seq_len, + DEVICE, + dtype, + ) + trt_timings.append(elapsed_ms / 1000.0) + else: + # SDPA or IAttention backend + # For iattention, args.cache is already cleared by validation above. + if args.cache == "static_v1": + # This import is required to register static v1 KV cache transformations as lowering passes + import static_cache_v1 + if args.cache == "static_v2": + # This import is required to register static v2 KV cache transformations as lowering passes + import static_cache_v2 + + # Compile the model with Torch-TensorRT + trt_model = compile_torchtrt(model, input_ids, args) + + if args.cache == "static_v1" or args.cache == "static_v2": + if args.cudagraph: + # Run a decoding loop with prefill and generate phases so that the CUDAGraph is recorded for both of these phases. 
+ # trt_input_signature = (input_ids.clone(),) + get_zeroed_kv_cache_inputs(trt_model) + torch_tensorrt.runtime.set_cudagraphs_mode(True) + + trt_gen_tokens = generate_with_static_cache( trt_model, input_ids.clone(), MAX_OUTPUT_SEQ_LENGTH, tokenizer.eos_token_id, - iterations=args.iterations, ) - else: - trt_gen_tokens = generate( - trt_model, - input_ids.clone(), - MAX_OUTPUT_SEQ_LENGTH, - tokenizer.eos_token_id, - ) - if args.benchmark: - trt_timings = time_generate( - generate, + + if args.benchmark: + trt_timings = time_generate( + generate_with_static_cache, + trt_model, + input_ids.clone(), + MAX_OUTPUT_SEQ_LENGTH, + tokenizer.eos_token_id, + iterations=args.iterations, + ) + else: + trt_gen_tokens = generate( trt_model, input_ids.clone(), MAX_OUTPUT_SEQ_LENGTH, tokenizer.eos_token_id, - iterations=args.iterations, ) + if args.benchmark: + trt_timings = time_generate( + generate, + trt_model, + input_ids.clone(), + MAX_OUTPUT_SEQ_LENGTH, + tokenizer.eos_token_id, + iterations=args.iterations, + ) if args.benchmark: trt_stats = record_stats( @@ -387,4 +529,4 @@ def measure_perf(trt_model, input_signature, backend_name): print(pyt_stats) print("===================== \n") print("=========TensorRT PERFORMANCE============ \n") - print(trt_stats) + print(trt_stats) \ No newline at end of file diff --git a/tools/llm/run_vlm.py b/tools/llm/run_vlm.py index 88aa21eeea..caeb674cb3 100644 --- a/tools/llm/run_vlm.py +++ b/tools/llm/run_vlm.py @@ -30,15 +30,15 @@ import argparse import copy import os -import sys +from types import SimpleNamespace from contextlib import nullcontext -from typing import Tuple +from typing import Any, Dict, Tuple, TypedDict import requests import torch import torch_tensorrt from PIL import Image -from transformers import AutoModel, AutoProcessor +from transformers import AutoConfig, AutoModel, AutoProcessor, PreTrainedModel from transformers.models.qwen2 import modeling_qwen2 as mq from transformers.models.siglip import modeling_siglip as ms 
from utils import ( @@ -47,9 +47,34 @@ generate_mm_qwen2_5_vl, generate_mm_qwen2_5_vl_with_static_cache, generate_mm_with_static_cache, + get_qwen_image_embeds, + get_qwen_position_ids, record_stats, ) +# Import ViT plugin utilities (optional) +try: + from plugin_utils_vit import ( + ViTPluginAttention, + ViTPluginWrapper, + VIT_INPUT_CONTRACT_NATIVE, + VIT_INPUT_CONTRACT_TILED_ASPECT_RATIO, + VIT_INPUT_CONTRACT_WINDOWED_ROPE, + compile_vit_plugin_model, + count_vit_plugin_attention_modules, + get_vit_plugin_conversion_count, + get_vit_plugin_config, + load_plugin as load_vit_plugin, + register_vit_plugin_op, + replace_vit_attention_with_plugin, + reset_vit_plugin_conversion_count, + set_vit_plugin_config, + ) + + VIT_PLUGIN_AVAILABLE = True +except ImportError: + VIT_PLUGIN_AVAILABLE = False + # --- WORKAROUND FOR EAGLE2 SDPA COMPILATION --- # Eagle2's language model (Qwen2) implicitly defaults to "flash_attention_2" # due to settings in its remote code and config.json. This prevents direct @@ -79,77 +104,494 @@ # -----------------------------------------------------------------------------# -def _load_eagle2(device: torch.device, torch_dtype: torch.dtype): - """ - Load nvidia/Eagle2-2B model and processor, ensuring the language model uses SDPA. +def _is_qwen2_5_vl(model_name: str) -> bool: + return "qwen2.5-vl" in model_name.lower() - Returns - ------- - tuple[torch.nn.Module, transformers.AutoProcessor, torch.nn.Embedding] - The model, its processor and the language-model input embedding layer. + +def _is_eagle2(model_name: str) -> bool: + return "eagle2" in model_name.lower() + + +def _patch_transformers_tied_weights_compat() -> None: """ - model_id = "nvidia/Eagle2-2B" - try: - with torch.no_grad(): - model = ( - AutoModel.from_pretrained( - model_id, - trust_remote_code=True, - torch_dtype=torch_dtype, - # attn_implementation="sdpa" is ignored due to the model's remote code. 
- ) - .eval() - .to(device) + Some remote VLM classes still expose the older `_tied_weights_keys` + metadata, while newer Transformers loading code expects + `all_tied_weights_keys`. Add a conservative compatibility property on the + base class before model construction. + """ + existing_attr = getattr(PreTrainedModel, "all_tied_weights_keys", None) + if isinstance(existing_attr, property) and existing_attr.fset is not None: + return + if existing_attr is not None and not isinstance(existing_attr, property): + return + + def normalize_tied_weights_keys(value): + if value is None: + return {} + if isinstance(value, dict): + return value + return {key: None for key in value} + + def all_tied_weights_keys(self): + if "_all_tied_weights_keys_compat" in self.__dict__: + return normalize_tied_weights_keys( + self.__dict__["_all_tied_weights_keys_compat"] ) - except ImportError as e: - if "flash_attn" in str(e): - raise ImportError( - "FlashAttention2 is required for Eagle2 models but not installed. " - "Please install it using: pip install flash-attn --no-build-isolation -v" - ) from e - raise + return normalize_tied_weights_keys(getattr(self, "_tied_weights_keys", None)) - processor = AutoProcessor.from_pretrained( - model_id, trust_remote_code=True, use_fast=True + def set_all_tied_weights_keys(self, value): + self.__dict__["_all_tied_weights_keys_compat"] = value + + PreTrainedModel.all_tied_weights_keys = property( + all_tied_weights_keys, set_all_tied_weights_keys ) - if hasattr(processor, "tokenizer"): - processor.tokenizer.padding_side = "left" - emb_layer = model.language_model.get_input_embeddings().to(torch_dtype).to(device) - return model, processor, emb_layer + +def _patch_transformers_image_utils_compat() -> None: + """ + TODO: Eagle2 remote processor code imports image/video helper APIs that are + not exported by the Transformers builds used in this environment. 
Remove + this once Eagle pins/updates its remote code or we pin a known-compatible + Transformers version. If this remains necessary, raise an upstream bug. + """ + import transformers.image_utils as image_utils + + if not hasattr(image_utils, "VideoInput"): + image_utils.VideoInput = Any + if not hasattr(image_utils, "make_batched_videos"): + def make_batched_videos(videos): + if videos is None: + return None + if isinstance(videos, (list, tuple)): + return list(videos) + return [videos] + + image_utils.make_batched_videos = make_batched_videos -def _load_qwen2_5_vl(device, torch_dtype: torch.dtype): +def _patch_transformers_fast_image_processor_compat() -> None: """ - Load Qwen2.5-VL model and processor. + TODO: Eagle2's fast image processor remote code imports docstring constants + from Transformers that are not exported by the builds used here. Remove + this after the Eagle remote-code / Transformers API mismatch is resolved. """ - from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration + import transformers.image_processing_utils_fast as fast_image_utils - model_id = "Qwen/Qwen2.5-VL-3B-Instruct" - model = Qwen2_5_VLForConditionalGeneration.from_pretrained( - model_id, torch_dtype=torch_dtype, device_map=device - ).eval() - processor = AutoProcessor.from_pretrained(model_id) - emb_layer = model.model.get_input_embeddings().to(torch_dtype).to(device) - return model, processor, emb_layer + if not hasattr(fast_image_utils, "BASE_IMAGE_PROCESSOR_FAST_DOCSTRING"): + fast_image_utils.BASE_IMAGE_PROCESSOR_FAST_DOCSTRING = "" + if not hasattr(fast_image_utils, "BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS"): + fast_image_utils.BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS = "" + if not hasattr(fast_image_utils, "DefaultFastImageProcessorKwargs"): + class DefaultFastImageProcessorKwargs(TypedDict, total=False): + pass + + fast_image_utils.DefaultFastImageProcessorKwargs = DefaultFastImageProcessorKwargs + + +def 
_patch_siglip_flash_attention_compat() -> None: + """ + Eagle2's remote config can still request FlashAttention2 for the SigLIP + vision tower even after the parent config is patched. The SigLIP attention + registry above maps flash_attention_2 to SDPA, so bypass only the package + availability check that runs during model construction. + """ + + def flash_attention_check_compat(self, *args, **kwargs): + config = getattr(self, "config", None) + if config is not None: + _set_config_attn_implementation(config, "sdpa") + return True + + def get_correct_attn_implementation_compat(self, *args, **kwargs): + config = getattr(self, "config", None) + if config is not None: + _set_config_attn_implementation(config, "sdpa") + return "sdpa" + + for class_name in ("SiglipPreTrainedModel", "SiglipVisionModel"): + cls = getattr(ms, class_name, None) + if cls is None: + continue + if hasattr(cls, "get_correct_attn_implementation"): + setattr( + cls, + "get_correct_attn_implementation", + get_correct_attn_implementation_compat, + ) + for method_name in ("_flash_attn_2_can_dispatch", "_flash_attn_can_dispatch"): + if hasattr(cls, method_name): + setattr(cls, method_name, flash_attention_check_compat) + + +def _model_loader_candidates(): + candidates = [] + for class_name in ("AutoModelForImageTextToText", "AutoModelForVision2Seq"): + try: + module = __import__("transformers", fromlist=[class_name]) + candidates.append(getattr(module, class_name)) + except (ImportError, AttributeError): + pass + candidates.append(AutoModel) + return candidates + + +def _effective_attn_implementation(args: argparse.Namespace): + if args.attn_implementation: + return args.attn_implementation + if args.vision_backend == "torchtrt" or _is_eagle2(args.model): + return "sdpa" + return None + + +def _set_config_attn_implementation(config, attn_implementation: str) -> None: + visited = set() + attn_attr_names = { + "attn_implementation", + "_attn_implementation", + "_attn_implementation_internal", + } + + 
def is_attn_attr(name: str) -> bool: + return name in attn_attr_names or name.endswith("_attn_implementation") + + def set_attn_attr(config_obj, attr_name: str) -> None: + try: + setattr(config_obj, attr_name, attn_implementation) + except Exception: + pass + try: + config_obj.__dict__[attr_name] = attn_implementation + except Exception: + pass + + def visit(config_obj): + if config_obj is None or id(config_obj) in visited: + return + visited.add(id(config_obj)) + + if isinstance(config_obj, dict): + for key, value in list(config_obj.items()): + if is_attn_attr(str(key)): + config_obj[key] = attn_implementation + else: + visit(value) + return + + if isinstance(config_obj, (list, tuple)): + for value in config_obj: + visit(value) + return + + if isinstance(config_obj, (str, bytes, int, float, bool)): + return + + for attr_name in attn_attr_names: + set_attn_attr(config_obj, attr_name) + + config_dict = getattr(config_obj, "__dict__", None) + if not isinstance(config_dict, dict): + return + + for child_name, child_value in list(config_dict.items()): + if is_attn_attr(child_name): + set_attn_attr(config_obj, child_name) + else: + visit(child_value) + + visit(config) + + +def _load_model_config(args: argparse.Namespace): + config = AutoConfig.from_pretrained( + args.model, + trust_remote_code=args.trust_remote_code, + ) + attn_implementation = _effective_attn_implementation(args) + if attn_implementation: + _set_config_attn_implementation(config, attn_implementation) + return config + + +def _from_pretrained_kwargs(args: argparse.Namespace, torch_dtype: torch.dtype): + kwargs = { + "torch_dtype": torch_dtype, + "trust_remote_code": args.trust_remote_code, + "config": _load_model_config(args), + } + attn_implementation = _effective_attn_implementation(args) + if attn_implementation: + kwargs["attn_implementation"] = attn_implementation + return kwargs + + +_VISION_MODULE_ATTRS = ("vision_model", "visual", "vision_tower", "vision_encoder") + + +def 
_is_windowed_rope_vision_module(module: torch.nn.Module) -> bool: + return ( + hasattr(module, "patch_embed") + and hasattr(module, "blocks") + and hasattr(module, "rot_pos_emb") + and hasattr(module, "get_window_index") + ) + + +def _is_merged_windowed_rope_vision_module(module: torch.nn.Module) -> bool: + return _is_windowed_rope_vision_module(module) and hasattr(module, "merger") + + +def _is_tiled_aspect_ratio_vision_module(module: torch.nn.Module) -> bool: + return ( + hasattr(module, "patch_embedding") + and hasattr(module, "global_transformer") + and hasattr(module, "pre_tile_positional_embedding") + ) + + +def _is_native_vit_vision_module(module: torch.nn.Module) -> bool: + has_patch_embedding = any( + hasattr(module, attr_name) + for attr_name in ("patch_embed", "patch_embedding", "embeddings") + ) + has_transformer = any( + hasattr(module, attr_name) + for attr_name in ("blocks", "encoder", "transformer", "global_transformer") + ) + return has_patch_embedding and has_transformer + + +def _is_vision_module(module: torch.nn.Module) -> bool: + return ( + _is_windowed_rope_vision_module(module) + or _is_tiled_aspect_ratio_vision_module(module) + or _is_native_vit_vision_module(module) + ) + + +def _contains_vision_module(module: torch.nn.Module) -> bool: + for attr_name in _VISION_MODULE_ATTRS: + if isinstance(getattr(module, attr_name, None), torch.nn.Module): + return True + + for _, child in module.named_modules(): + if child is module: + continue + if _is_vision_module(child): + return True + return False + + +_LANGUAGE_MODULE_ATTRS = ( + "language_model", + "text_model", + "llm", + "decoder", + "model", +) -def load_model( - model_name: str, device: torch.device, torch_dtype: torch.dtype +def _find_language_module(model: torch.nn.Module) -> Tuple[str, torch.nn.Module]: + for attr_name in _LANGUAGE_MODULE_ATTRS: + candidate = getattr(model, attr_name, None) + if not isinstance(candidate, torch.nn.Module): + continue + if attr_name == "model" and 
_contains_vision_module(candidate): + continue + return attr_name, candidate + + for module_name, module in model.named_modules(): + if module is model: + continue + if module_name.rsplit(".", 1)[-1] not in _LANGUAGE_MODULE_ATTRS: + continue + if _contains_vision_module(module): + continue + return module_name, module + + raise ValueError( + "Cannot find a language-model submodule. Expected a language/text " + "model leaf module that does not also contain the vision tower." + ) + + +def get_language_model(model: torch.nn.Module) -> torch.nn.Module: + _, language_model = _find_language_module(model) + return language_model + + +def set_language_model(model: torch.nn.Module, language_model: torch.nn.Module) -> None: + module_name, _ = _find_language_module(model) + _set_module_by_name(model, module_name, language_model) + + +def _find_vision_module(model: torch.nn.Module) -> Tuple[str, torch.nn.Module]: + visual = getattr(model, "visual", None) + if isinstance(visual, torch.nn.Module): + return "visual", visual + + for parent_attr in ("model", "language_model"): + parent = getattr(model, parent_attr, None) + if not isinstance(parent, torch.nn.Module): + continue + visual = getattr(parent, "visual", None) + if isinstance(visual, torch.nn.Module): + return f"{parent_attr}.visual", visual + + for module_name, module in model.named_modules(): + if module is model: + continue + if _is_merged_windowed_rope_vision_module(module): + return module_name, module + + for attr_name in _VISION_MODULE_ATTRS: + if attr_name == "visual": + continue + candidate = getattr(model, attr_name, None) + if isinstance(candidate, torch.nn.Module): + return attr_name, candidate + + for module_name, module in model.named_modules(): + if module is model: + continue + if _is_vision_module(module): + return module_name, module + + raise ValueError( + "Cannot find a vision-model submodule. 
Expected a Hugging Face vision " + "tower alias or a module with ViT-like patch embedding and transformer " + "blocks/encoder." + ) + + +def _set_child_module(parent: torch.nn.Module, child_name: str, child: torch.nn.Module): + if child_name.isdigit() and isinstance( + parent, (torch.nn.ModuleList, torch.nn.Sequential) + ): + parent[int(child_name)] = child + else: + setattr(parent, child_name, child) + + +def _set_module_by_name( + model: torch.nn.Module, module_name: str, module: torch.nn.Module +) -> None: + if not module_name: + raise ValueError("Cannot replace the root model with a vision module.") + + path = module_name.split(".") + parent = model + for child_name in path[:-1]: + parent = getattr(parent, child_name) + _set_child_module(parent, path[-1], module) + + +def get_vision_model(model: torch.nn.Module) -> torch.nn.Module: + _, vision_model = _find_vision_module(model) + return vision_model + + +def set_vision_model(model: torch.nn.Module, vision_model: torch.nn.Module) -> None: + module_name, _ = _find_vision_module(model) + _set_module_by_name(model, module_name, vision_model) + + +def get_input_embedding_layer( + model: torch.nn.Module, torch_dtype: torch.dtype, device: torch.device +) -> torch.nn.Embedding: + language_model = get_language_model(model) + if hasattr(language_model, "get_input_embeddings"): + return language_model.get_input_embeddings().to(torch_dtype).to(device) + if hasattr(model, "get_input_embeddings"): + return model.get_input_embeddings().to(torch_dtype).to(device) + raise ValueError("Cannot find an input embedding layer for this VLM.") + + +def get_model( + args: argparse.Namespace, device: torch.device, torch_dtype: torch.dtype ) -> Tuple[torch.nn.Module, AutoProcessor, torch.nn.Embedding]: - """Dispatch helper for supported VLMs.""" - if model_name == "nvidia/Eagle2-2B": - return _load_eagle2(device, torch_dtype) - elif model_name == "Qwen/Qwen2.5-VL-3B-Instruct": - return _load_qwen2_5_vl(device, torch_dtype) - msg = 
f"Unsupported model: '{model_name}'. Supported models are: ['nvidia/Eagle2-2B', 'Qwen/Qwen2.5-VL-3B-Instruct']" - raise ValueError(msg) + """Load and configure a VLM model, processor, and input embedding layer.""" + _patch_transformers_tied_weights_compat() + if _is_eagle2(args.model): + _patch_siglip_flash_attention_compat() + model_kwargs = _from_pretrained_kwargs(args, torch_dtype) + last_error = None + + with torch.no_grad(): + for loader in _model_loader_candidates(): + try: + model = ( + loader.from_pretrained(args.model, **model_kwargs).eval().to(device) + ) + break + except (KeyError, ValueError) as exc: + last_error = exc + else: + raise ValueError( + f"Could not load '{args.model}' with available AutoModel classes." + ) from last_error + + processor_name = args.processor or args.model + if _is_eagle2(args.model): + _patch_transformers_image_utils_compat() + _patch_transformers_fast_image_processor_compat() + processor_use_fast = True + processor = AutoProcessor.from_pretrained( + processor_name, + trust_remote_code=args.trust_remote_code, + use_fast=processor_use_fast, + ) + if hasattr(processor, "tokenizer"): + processor.tokenizer.padding_side = "left" + emb_layer = get_input_embedding_layer(model, torch_dtype, device) + return model, processor, emb_layer # -----------------------------------------------------------------------------# # Input loading helpers # -----------------------------------------------------------------------------# +def _get_message_content_value(content_item, keys): + for key in keys: + if key in content_item and content_item[key] is not None: + return content_item[key] + return None + +def extract_vision_inputs(processor, messages): + """ + Extract image/video payloads from chat-style messages for VLM processors. + + Some processors own this logic directly. 
For processors that do not, fall + back to the common chat content schema used by Hugging Face VLM examples: + {"type": "image", "image": ...} and {"type": "video", "video": ...}. + """ + if hasattr(processor, "process_vision_info"): + return processor.process_vision_info(messages) + + image_inputs = [] + video_inputs = [] + for message in messages: + content = message.get("content", []) + if isinstance(content, dict): + content = [content] + for content_item in content: + if not isinstance(content_item, dict): + continue + + content_type = content_item.get("type") + if content_type == "image": + image = _get_message_content_value( + content_item, ("image", "url", "path") + ) + if image is not None: + image_inputs.append(image) + elif content_type == "video": + video = _get_message_content_value( + content_item, ("video", "url", "path") + ) + if video is not None: + video_inputs.append(video) + + return image_inputs or None, video_inputs or None def load_inputs(args: argparse.Namespace, processor, device: torch.device): """ @@ -165,9 +607,9 @@ def load_inputs(args: argparse.Namespace, processor, device: torch.device): image = Image.open(requests.get(url, stream=True).raw) if args.benchmark: - model_constants = MODEL_CONSTANTS[args.model] - image_tokens = model_constants["IMAGE_TOKENS"] - wrapper_tokens = model_constants["PROMPT_WRAPPER_TOKENS"] + model_constants = MODEL_CONSTANTS.get(args.model, {}) + image_tokens = model_constants.get("IMAGE_TOKENS", 0) + wrapper_tokens = model_constants.get("PROMPT_WRAPPER_TOKENS", 0) prompt_len = args.isl - image_tokens - wrapper_tokens prompt_txt = " ".join(["token"] * max(prompt_len, 0)) @@ -190,19 +632,7 @@ def load_inputs(args: argparse.Namespace, processor, device: torch.device): ) ] - # --- Model-specific vision processing --- - if args.model == "Qwen/Qwen2.5-VL-3B-Instruct": - try: - from qwen_vl_utils import process_vision_info - except ImportError: - raise ImportError( - "The 'qwen-vl-utils' package is required for Qwen 
VLM models. " - "Please install it using: pip install qwen-vl-utils" - ) - - image_inputs, video_inputs = process_vision_info(messages) - else: # eagle2 - image_inputs, video_inputs = processor.process_vision_info(messages) + image_inputs, video_inputs = extract_vision_inputs(processor, messages) inputs = processor( text=text, @@ -241,6 +671,7 @@ def forward(self, inputs_embeds, position_ids): def _compile_lm( language_model: torch.nn.Module, input_embeds: torch.Tensor, + position_ids: torch.Tensor, args: argparse.Namespace, device: torch.device, ) -> torch.nn.Module: @@ -251,14 +682,16 @@ def _compile_lm( max_seq_len = input_embeds.shape[1] + args.num_tokens seq_len = torch.export.Dim("seq", min=1, max=max_seq_len) - position_ids = torch.arange(input_embeds.shape[1]).unsqueeze(0).to(device) - use_fp32_acc = False if args.precision == "FP16": use_fp32_acc = True exported_program = export_llm( - lm_wrap, input_embeds, min_seq_len=1, max_seq_len=2560 + lm_wrap, + input_embeds, + min_seq_len=1, + max_seq_len=2560, + position_ids=position_ids, ) with torch_tensorrt.dynamo.Debugger() if args.debug else nullcontext(): @@ -275,6 +708,20 @@ def _compile_lm( return trt_mod +def _make_lm_compile_position_ids( + args: argparse.Namespace, + input_embeds: torch.Tensor, + device: torch.device, +) -> torch.Tensor: + base_position_ids = torch.arange( + input_embeds.shape[1], dtype=torch.long, device=device + ).unsqueeze(0) + base_position_ids = base_position_ids.expand(input_embeds.shape[0], -1) + if _is_qwen2_5_vl(args.model): + return base_position_ids.unsqueeze(0).expand(3, -1, -1).contiguous() + return base_position_ids + + def compile_lm_torchtrt( model: torch.nn.Module, args: argparse.Namespace, device: torch.device ) -> torch.nn.Module: @@ -286,43 +733,35 @@ def compile_lm_torchtrt( "BF16": torch.bfloat16, }.get(args.precision, torch.float32) - lm_model = ( - model.model - if args.model == "Qwen/Qwen2.5-VL-3B-Instruct" - else model.language_model - ) + lm_model = 
get_language_model(model) model_constants = MODEL_CONSTANTS.get( - args.model, {"EXAMPLE_SEQLEN": args.num_tokens} + args.model, {"EXAMPLE_SEQLEN": args.isl} ) example_seq_len = model_constants["EXAMPLE_SEQLEN"] example_embeds = torch.randn( args.batch_size, example_seq_len, - lm_model.config.hidden_size, + _get_lm_hidden_size(lm_model.config), dtype=torch_dtype, device=device, ) - # All supported models use the same compilation helper. - if args.model in ["nvidia/Eagle2-2B", "Qwen/Qwen2.5-VL-3B-Instruct"]: - return _compile_lm(lm_model, example_embeds, args, device) - else: - msg = f"Unsupported model: '{args.model}'. Supported models are: ['nvidia/Eagle2-2B', 'Qwen/Qwen2.5-VL-3B-Instruct']" - raise ValueError(msg) + position_ids = _make_lm_compile_position_ids(args, example_embeds, device) + + return _compile_lm(lm_model, example_embeds, position_ids, args, device) -def _compile_eagle2_vision( +def _compile_vision_model( vision_model: torch.nn.Module, example_pixel_values: torch.Tensor, args: argparse.Namespace, device: torch.device, ) -> torch.nn.Module: """ - Compile Eagle2 vision model with Torch-TensorRT. + Compile a vision tower with Torch-TensorRT. 
""" - # Set precision-specific flags use_fp32_acc = False if args.precision == "FP16": use_fp32_acc = True @@ -348,28 +787,391 @@ def _compile_eagle2_vision( return trt_mod +def _get_vision_config(config): + if hasattr(config, "vision_config"): + return config.vision_config + if hasattr(config, "visual"): + return config.visual + return config + + +def _get_config_attr(config, names): + for name in names: + value = getattr(config, name, None) + if value is not None: + return value + return None + + +def _get_lm_hidden_size(config): + hidden_size = _get_config_attr(config, ("hidden_size", "n_embd", "d_model")) + if hidden_size is not None: + return int(hidden_size) + + text_config = getattr(config, "text_config", None) + if text_config is not None: + hidden_size = _get_config_attr( + text_config, ("hidden_size", "n_embd", "d_model") + ) + if hidden_size is not None: + return int(hidden_size) + + raise ValueError( + "Cannot infer language-model hidden size from config. Expected one of " + "config.hidden_size, config.n_embd, config.d_model, or the same fields " + "under config.text_config." 
+ ) + + +def _infer_patch_count(vision_config, pixel_values: torch.Tensor) -> int: + if pixel_values.dim() in (2, 3): + return int(pixel_values.shape[-2]) + + image_size = _get_config_attr(vision_config, ("image_size",)) + patch_size = _get_config_attr(vision_config, ("patch_size",)) + if image_size is None or patch_size is None: + return 0 + + if isinstance(image_size, (tuple, list)): + image_h, image_w = image_size[:2] + else: + image_h = image_w = image_size + if isinstance(patch_size, (tuple, list)): + patch_h, patch_w = patch_size[:2] + else: + patch_h = patch_w = patch_size + return int((image_h // patch_h) * (image_w // patch_w) + 1) + + +def _set_vit_plugin_config_from_vision(vision_config, pixel_values): + num_heads = _get_config_attr( + vision_config, ("num_heads", "num_attention_heads", "attention_heads") + ) + if num_heads is None: + raise ValueError("Cannot infer ViT plugin num_attention_heads from config.") + + head_dim = _get_config_attr(vision_config, ("head_dim",)) + if head_dim is None: + hidden_size = _get_config_attr( + vision_config, ("hidden_size", "embed_dim", "dim") + ) + if hidden_size is None: + raise ValueError("Cannot infer ViT plugin hidden_size from config.") + head_dim = int(hidden_size) // int(num_heads) + + set_vit_plugin_config( + num_attention_heads=int(num_heads), + head_dim=int(head_dim), + num_patches=_infer_patch_count(vision_config, pixel_values), + ) + + +def _create_windowed_rope_vit_plugin_core_inputs( + visual_model: torch.nn.Module, + pixel_values: torch.Tensor, + image_grid_thw: torch.Tensor, + device: torch.device, + dtype: torch.dtype, +): + with torch.no_grad(): + rotary_pos_emb = visual_model.rot_pos_emb(image_grid_thw) + window_index, cu_window_seqlens = visual_model.get_window_index(image_grid_thw) + + window_index = window_index.to(device=device, dtype=torch.long) + reverse_window_index = torch.argsort(window_index) + + seq_len = pixel_values.shape[0] + attention_mask = torch.zeros(1, seq_len, seq_len, 
dtype=dtype, device=device) + window_attention_mask = torch.full( + (1, seq_len, seq_len), + torch.finfo(dtype).min, + dtype=dtype, + device=device, + ) + + cu_window_seqlens = torch.as_tensor( + cu_window_seqlens, device="cpu", dtype=torch.long + ) + cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens).tolist() + max_window_seq_len = max( + end - start + for start, end in zip(cu_window_seqlens[:-1], cu_window_seqlens[1:]) + ) + + for start, end in zip(cu_window_seqlens[:-1], cu_window_seqlens[1:]): + window_attention_mask[:, start:end, start:end] = 0 + + return { + "rotary_pos_emb": rotary_pos_emb.to(device=device), + "attention_mask": attention_mask, + "window_attention_mask": window_attention_mask, + "cu_window_seqlens": torch.tensor( + cu_window_seqlens, dtype=torch.int32, device=device + ), + "max_window_seq_len": max_window_seq_len, + "window_index": window_index, + "reverse_window_index": reverse_window_index, + } + + +class _VITPluginVisualAdapter(torch.nn.Module): + """ + Preserve a vision tower's native call signature while using the compiled + ViT plugin wrapper internally. 
+ """ + + def __init__( + self, + compiled_visual, + input_contract: str, + core_inputs: Dict[str, Any], + original_visual: torch.nn.Module | None = None, + ): + super().__init__() + self.compiled_visual = compiled_visual + self.input_contract = input_contract + self.core_inputs = core_inputs + self.original_visual = original_visual + self.return_pooler_output = False + + if original_visual is not None: + for attr_name in ("dtype", "spatial_merge_size", "spatial_merge_unit"): + if hasattr(original_visual, attr_name): + setattr(self, attr_name, getattr(original_visual, attr_name)) + if hasattr(original_visual, "merger"): + self.merger = original_visual.merger + self.return_pooler_output = True + + def _call_compiled(self, positional_args, keyword_args): + try: + return self.compiled_visual(**keyword_args) + except TypeError: + return self.compiled_visual(*positional_args) + + def _compiled_kwargs(self, pixel_values, *args, **kwargs): + if self.input_contract == VIT_INPUT_CONTRACT_WINDOWED_ROPE: + return {"pixel_values": pixel_values, **self.core_inputs} + + if self.input_contract == VIT_INPUT_CONTRACT_TILED_ASPECT_RATIO: + aspect_ratio_ids = _get_runtime_tensor( + args, kwargs, self.core_inputs, ("aspect_ratio_ids",), position=0 + ) + attention_mask = _get_runtime_tensor( + args, + kwargs, + self.core_inputs, + ("attention_mask", "aspect_ratio_mask"), + position=1, + ) + return { + "pixel_values": pixel_values, + "aspect_ratio_ids": aspect_ratio_ids, + "attention_mask": attention_mask, + } + + if self.input_contract == VIT_INPUT_CONTRACT_NATIVE: + return {"pixel_values": pixel_values} + + raise ValueError(f"Unsupported ViT plugin input contract: {self.input_contract}") + + def forward(self, pixel_values, *args, **kwargs): + keyword_args = self._compiled_kwargs(pixel_values, *args, **kwargs) + output = self._call_compiled(tuple(keyword_args.values()), keyword_args) + if self.return_pooler_output: + return SimpleNamespace(pooler_output=output, 
last_hidden_state=output) + return output + + +def _get_required_input(inputs, names, purpose: str): + for name in names: + value = inputs.get(name) + if isinstance(value, torch.Tensor): + return value + raise ValueError( + f"ViT plugin path requires {purpose}. Expected one of: {', '.join(names)}." + ) + + +def _get_optional_tensor(inputs, names): + for name in names: + value = inputs.get(name) + if isinstance(value, torch.Tensor): + return value + return None + + +def _get_runtime_tensor(args, kwargs, core_inputs, names, position=None): + for name in names: + value = kwargs.get(name) + if isinstance(value, torch.Tensor): + return value + if ( + position is not None + and position < len(args) + and isinstance(args[position], torch.Tensor) + ): + return args[position] + for name in names: + value = core_inputs.get(name) + if isinstance(value, torch.Tensor): + return value + return None + + +def _has_windowed_rope_contract(visual_model, inputs) -> bool: + return ( + isinstance(inputs.get("image_grid_thw"), torch.Tensor) + and hasattr(visual_model, "get_window_index") + and hasattr(visual_model, "rot_pos_emb") + and hasattr(visual_model, "patch_embed") + and hasattr(visual_model, "blocks") + ) + + +def _has_tiled_aspect_ratio_contract(inputs) -> bool: + return isinstance(inputs.get("aspect_ratio_ids"), torch.Tensor) and ( + _get_optional_tensor(inputs, ("aspect_ratio_mask", "attention_mask")) + is not None + ) + + +def _prepare_vit_plugin_inputs( + visual_model, + inputs, + pixel_values: torch.Tensor, + device: torch.device, + torch_dtype: torch.dtype, +) -> Tuple[str, Dict[str, Any], int]: + if _has_windowed_rope_contract(visual_model, inputs): + core_inputs = _create_windowed_rope_vit_plugin_core_inputs( + visual_model, + pixel_values, + inputs["image_grid_thw"], + device, + torch_dtype, + ) + max_window_seq_len = core_inputs.pop("max_window_seq_len") + return ( + VIT_INPUT_CONTRACT_WINDOWED_ROPE, + core_inputs, + max_window_seq_len, + ) + + if 
_has_tiled_aspect_ratio_contract(inputs): + core_inputs = { + "aspect_ratio_ids": _get_required_input( + inputs, ("aspect_ratio_ids",), "tiled vision aspect ratio ids" + ), + "attention_mask": _get_required_input( + inputs, + ("aspect_ratio_mask", "attention_mask"), + "tiled vision attention mask", + ), + } + return ( + VIT_INPUT_CONTRACT_TILED_ASPECT_RATIO, + core_inputs, + 0, + ) + + return ( + VIT_INPUT_CONTRACT_NATIVE, + {}, + 0, + ) + + +def _compile_vision_with_vit_plugin( + model: torch.nn.Module, + inputs, + args: argparse.Namespace, + device: torch.device, +) -> torch.nn.Module: + torch_dtype = { + "FP16": torch.float16, + "BF16": torch.bfloat16, + }.get(args.precision, torch.float32) + + vision_config = _get_vision_config(model.config) + pixel_values = inputs["pixel_values"].to(dtype=torch_dtype) + + load_vit_plugin() + register_vit_plugin_op() + _set_vit_plugin_config_from_vision(vision_config, pixel_values) + print(f"ViT plugin config: {get_vit_plugin_config()}") + + visual_model = replace_vit_attention_with_plugin( + get_vision_model(model), vision_config + ) + plugin_module_count = count_vit_plugin_attention_modules(visual_model) + print(f"ViT plugin attention modules inserted: {plugin_module_count}") + + input_contract, core_inputs, max_window_seq_len = _prepare_vit_plugin_inputs( + visual_model, inputs, pixel_values, device, torch_dtype + ) + wrapper = ViTPluginWrapper( + visual_model, + input_contract=input_contract, + max_window_seq_len=max_window_seq_len, + ).eval() + + compile_kwargs = {"pixel_values": pixel_values, **core_inputs} + reset_vit_plugin_conversion_count() + compiled_visual = compile_vit_plugin_model( + wrapper, + (), + device, + example_kwargs=compile_kwargs, + dynamic_shapes={name: {} for name in compile_kwargs}, + debug=args.debug, + ) + plugin_conversion_count = get_vit_plugin_conversion_count() + print(f"ViT plugin TensorRT conversions: {plugin_conversion_count}") + if plugin_conversion_count == 0: + raise RuntimeError( + "ViT 
plugin backend was requested, but no ViTAttentionPlugin nodes " + "were lowered into the TensorRT network." + ) + + return _VITPluginVisualAdapter( + compiled_visual, input_contract, core_inputs, original_visual=visual_model + ).eval() + + def compile_vision_torchtrt( model: torch.nn.Module, args: argparse.Namespace, - example_pixel_values: torch.Tensor, + inputs, device: torch.device, ) -> torch.nn.Module: """ Dispatcher function for vision model compilation. """ - if args.model == "nvidia/Eagle2-2B": - return _compile_eagle2_vision( - model.vision_model, example_pixel_values, args, device - ) - elif args.model == "Qwen/Qwen2.5-VL-3B-Instruct": + example_pixel_values = inputs["pixel_values"] + if getattr(args, "vision_backend", "torchtrt") == "plugin": + if not VIT_PLUGIN_AVAILABLE: + raise RuntimeError( + "ViT plugin vision backend requested but plugin utilities are not available." + ) + return _compile_vision_with_vit_plugin(model, inputs, args, device) + + if _is_qwen2_5_vl(args.model): # TODO: Vision model compilation for Qwen2.5-VL is currently skipped. # The model's `get_window_index` method uses dynamic Python list operations # (e.g., .tolist(), .extend()) to process variable-sized image grids for # windowed attention. These operations are incompatible with torch.export's # static graph tracing, preventing successful compilation. - return model.visual - else: - raise ValueError(f"Unsupported model: {args.model}") + return get_vision_model(model) + + try: + return _compile_vision_model( + get_vision_model(model), example_pixel_values, args, device + ) + except ValueError as exc: + raise ValueError( + f"Cannot compile the vision tower for '{args.model}' with the generic " + "path. Add a model adapter or use a supported architecture." 
+ ) from exc # -----------------------------------------------------------------------------# @@ -378,7 +1180,7 @@ def compile_vision_torchtrt( def print_outputs(backend_name: str, gen_tokens: torch.Tensor, tokenizer): - """Pretty-print generated text for comparison.""" + """Print the generated tokens from the model.""" print(f"========= {backend_name} =========") print( f"{backend_name} model generated text: ", @@ -387,59 +1189,1546 @@ def print_outputs(backend_name: str, gen_tokens: torch.Tensor, tokenizer): print("===================================") -# -----------------------------------------------------------------------------# -# Main driver -# -----------------------------------------------------------------------------# -if __name__ == "__main__": - parser = argparse.ArgumentParser( - description="Run VLM inference (PyTorch & TensorRT back-ends)" +def _extract_hidden_states(outputs): + if isinstance(outputs, torch.Tensor): + return outputs + if hasattr(outputs, "last_hidden_state"): + return outputs.last_hidden_state + if isinstance(outputs, (tuple, list)): + return outputs[0] + return outputs + + +def _extract_vision_output(outputs): + if isinstance(outputs, torch.Tensor): + return outputs + if hasattr(outputs, "pooler_output"): + return outputs.pooler_output + if hasattr(outputs, "last_hidden_state"): + return outputs.last_hidden_state + if isinstance(outputs, (tuple, list)): + return outputs[0] + return outputs + + +def _module_device(module: torch.nn.Module, fallback: torch.device) -> torch.device: + try: + return next(module.parameters()).device + except StopIteration: + return fallback + + +def _call_language_model_for_verify(language_model, inputs_embeds, position_ids): + target_device = _module_device(language_model, inputs_embeds.device) + inputs_embeds = inputs_embeds.to(target_device) + position_ids = position_ids.to(target_device) + try: + outputs = language_model( + inputs_embeds=inputs_embeds, + position_ids=position_ids, + ) + except 
TypeError: + outputs = language_model(inputs_embeds, position_ids) + return _extract_hidden_states(outputs) + + +def _qwen_full_cu_seqlens(image_grid_thw: torch.Tensor) -> torch.Tensor: + grid = image_grid_thw.to(device="cpu", dtype=torch.int64) + seq_lens = torch.repeat_interleave(grid[:, 1] * grid[:, 2], grid[:, 0]) + cu_seqlens = torch.nn.functional.pad(seq_lens.cumsum(dim=0), (1, 0), value=0) + return cu_seqlens.to(device=image_grid_thw.device, dtype=torch.int32) + + +def _call_qwen_visual_block_reference( + block, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: torch.Tensor, + position_embeddings, +) -> torch.Tensor: + call_attempts = ( + lambda: block( + hidden_states, + cu_seqlens=cu_seqlens, + rotary_pos_emb=rotary_pos_emb, + position_embeddings=position_embeddings, + ), + lambda: block( + hidden_states, + cu_seqlens=cu_seqlens, + position_embeddings=position_embeddings, + ), + lambda: block(hidden_states, cu_seqlens=cu_seqlens), + lambda: block(hidden_states, cu_seqlens), ) - parser.add_argument( - "--model", - default="nvidia/Eagle2-2B", - choices=["nvidia/Eagle2-2B", "Qwen/Qwen2.5-VL-3B-Instruct"], - help="VLM model name", + last_error = None + for call in call_attempts: + try: + return _extract_hidden_states(call()) + except TypeError as exc: + last_error = exc + raise last_error + + +def _call_qwen_attention_reference( + attn, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: torch.Tensor, + position_embeddings, +) -> torch.Tensor: + call_attempts = ( + lambda: attn( + hidden_states, + attention_mask=attention_mask, + position_embeddings=position_embeddings, + ), + lambda: attn( + hidden_states, + cu_seqlens=cu_seqlens, + rotary_pos_emb=rotary_pos_emb, + position_embeddings=position_embeddings, + ), + lambda: attn( + hidden_states, + cu_seqlens=cu_seqlens, + position_embeddings=position_embeddings, + ), + lambda: attn(hidden_states, cu_seqlens=cu_seqlens), + lambda: 
attn(hidden_states, cu_seqlens), ) - parser.add_argument("--prompt", default="Describe this image.", help="Prompt text") - parser.add_argument( - "--precision", - default="FP16", - choices=["FP16", "FP32"], - help="Computation precision", + last_error = None + for call in call_attempts: + try: + return _extract_hidden_states(call()) + except TypeError as exc: + last_error = exc + raise last_error + + +def _qwen_windowed_rope_attention_inputs( + visual, + pixel_values: torch.Tensor, + image_grid_thw: torch.Tensor, + layer_idx: int, + dtype: torch.dtype, +): + pixel_values = pixel_values.to(device=_module_device(visual, pixel_values.device)) + if hasattr(visual, "dtype"): + pixel_values = pixel_values.to(dtype=visual.dtype) + + hidden_states = visual.patch_embed(pixel_values) + rotary_pos_emb = visual.rot_pos_emb(image_grid_thw) + window_index, cu_window_seqlens = visual.get_window_index(image_grid_thw) + + window_index = torch.as_tensor( + window_index, device=hidden_states.device, dtype=torch.long ) - parser.add_argument("--iterations", type=int, default=5, help="# iterations") - parser.add_argument("--min_block_size", type=int, default=1, help="Min block size") - parser.add_argument("--num_tokens", type=int, default=128, help="# new tokens") - parser.add_argument("--batch_size", type=int, default=1, help="Batch size") - parser.add_argument("--isl", type=int, default=2048, help="Input seq length") - parser.add_argument( - "--enable_pytorch_run", - action="store_true", - help="Run the PyTorch baseline as well", + cu_window_seqlens = torch.as_tensor( + cu_window_seqlens, device=hidden_states.device, dtype=torch.int32 ) - parser.add_argument( - "--cache", - default="", - choices=["", "static_v1"], - help="KV-cache variant to use", + cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens) + cu_seqlens = _qwen_full_cu_seqlens(image_grid_thw).to(hidden_states.device) + + seq_len, _ = hidden_states.size() + hidden_states = hidden_states.reshape( + seq_len // 
visual.spatial_merge_unit, + visual.spatial_merge_unit, + -1, ) - parser.add_argument( - "--debug", action="store_true", help="Enable Torch-TensorRT debug logs" + hidden_states = hidden_states[window_index, :, :] + hidden_states = hidden_states.reshape(seq_len, -1) + + rotary_pos_emb = rotary_pos_emb.to(device=hidden_states.device) + rotary_pos_emb = rotary_pos_emb.reshape( + seq_len // visual.spatial_merge_unit, + visual.spatial_merge_unit, + -1, ) - parser.add_argument( - "--benchmark", action="store_true", help="Enable benchmarking mode" + rotary_pos_emb = rotary_pos_emb[window_index, :, :] + rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1) + emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) + position_embeddings = (emb.cos(), emb.sin()) + + attention_mask = torch.zeros(1, seq_len, seq_len, dtype=dtype, device=hidden_states.device) + window_attention_mask = torch.full( + (1, seq_len, seq_len), + torch.finfo(dtype).min, + dtype=dtype, + device=hidden_states.device, ) - parser.add_argument( - "--image_path", - type=str, - default=None, - help="Path to local image file. 
If not provided, uses default URL image.", + for start, end in zip(cu_window_seqlens[:-1], cu_window_seqlens[1:]): + window_attention_mask[:, start:end, start:end] = 0 + + full_attention = layer_idx in visual.fullatt_block_indexes + mask_now = attention_mask if full_attention else window_attention_mask + cu_seqlens_now = cu_seqlens if full_attention else cu_window_seqlens + hidden_states = visual.blocks[layer_idx].norm1(hidden_states) + return hidden_states, mask_now, cu_seqlens_now, rotary_pos_emb, position_embeddings + + +def _qwen_windowed_rope_block_inputs( + visual, + pixel_values: torch.Tensor, + image_grid_thw: torch.Tensor, + layer_idx: int, + dtype: torch.dtype, +): + pixel_values = pixel_values.to(device=_module_device(visual, pixel_values.device)) + if hasattr(visual, "dtype"): + pixel_values = pixel_values.to(dtype=visual.dtype) + + hidden_states = visual.patch_embed(pixel_values) + rotary_pos_emb = visual.rot_pos_emb(image_grid_thw) + window_index, cu_window_seqlens = visual.get_window_index(image_grid_thw) + + window_index = torch.as_tensor( + window_index, device=hidden_states.device, dtype=torch.long ) - parser.add_argument( - "--device", - type=str, - default="cuda:0", - help="Device to run inference on (e.g., 'cuda:0', 'cuda:1')", + cu_window_seqlens = torch.as_tensor( + cu_window_seqlens, device=hidden_states.device, dtype=torch.int32 + ) + cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens) + cu_seqlens = _qwen_full_cu_seqlens(image_grid_thw).to(hidden_states.device) + + seq_len, _ = hidden_states.size() + hidden_states = hidden_states.reshape( + seq_len // visual.spatial_merge_unit, + visual.spatial_merge_unit, + -1, + ) + hidden_states = hidden_states[window_index, :, :] + hidden_states = hidden_states.reshape(seq_len, -1) + + rotary_pos_emb = rotary_pos_emb.to(device=hidden_states.device) + rotary_pos_emb = rotary_pos_emb.reshape( + seq_len // visual.spatial_merge_unit, + visual.spatial_merge_unit, + -1, + ) + rotary_pos_emb = 
rotary_pos_emb[window_index, :, :] + rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1) + emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) + position_embeddings = (emb.cos(), emb.sin()) + + attention_mask = torch.zeros( + 1, seq_len, seq_len, dtype=dtype, device=hidden_states.device + ) + window_attention_mask = torch.full( + (1, seq_len, seq_len), + torch.finfo(dtype).min, + dtype=dtype, + device=hidden_states.device, + ) + for start, end in zip(cu_window_seqlens[:-1], cu_window_seqlens[1:]): + window_attention_mask[:, start:end, start:end] = 0 + + for prior_layer_idx in range(layer_idx): + full_attention = prior_layer_idx in visual.fullatt_block_indexes + cu_seqlens_now = cu_seqlens if full_attention else cu_window_seqlens + hidden_states = _call_qwen_visual_block_reference( + visual.blocks[prior_layer_idx], + hidden_states, + cu_seqlens_now, + rotary_pos_emb, + position_embeddings, + ) + + full_attention = layer_idx in visual.fullatt_block_indexes + mask_now = attention_mask if full_attention else window_attention_mask + cu_seqlens_now = cu_seqlens if full_attention else cu_window_seqlens + return hidden_states, mask_now, cu_seqlens_now, rotary_pos_emb, position_embeddings + + +def _call_qwen_block_with_attention_wrapper( + block, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + position_embeddings, + vision_config, + layer_idx: int, +) -> torch.Tensor: + attention = ViTPluginAttention( + block.attn, + vision_config, + layer_idx, + use_plugin_op=False, + ).eval() + + residual = hidden_states + candidate = block.norm1(hidden_states) + candidate = attention( + candidate, + attention_mask=attention_mask, + position_embeddings=position_embeddings, + ) + candidate = residual + candidate + + residual = candidate + candidate = block.norm2(candidate) + candidate = block.mlp(candidate) + candidate = residual + candidate + return candidate + + +def _forward_qwen_windowed_rope_reference( + visual, + pixel_values: torch.Tensor, + image_grid_thw: 
torch.Tensor, +) -> torch.Tensor: + pixel_values = pixel_values.to(device=_module_device(visual, pixel_values.device)) + if hasattr(visual, "dtype"): + pixel_values = pixel_values.to(dtype=visual.dtype) + + hidden_states = visual.patch_embed(pixel_values) + rotary_pos_emb = visual.rot_pos_emb(image_grid_thw) + window_index, cu_window_seqlens = visual.get_window_index(image_grid_thw) + + window_index = torch.as_tensor( + window_index, device=hidden_states.device, dtype=torch.long + ) + reverse_window_index = torch.argsort(window_index) + cu_window_seqlens = torch.as_tensor( + cu_window_seqlens, device=hidden_states.device, dtype=torch.int32 + ) + cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens) + cu_seqlens = _qwen_full_cu_seqlens(image_grid_thw).to(hidden_states.device) + + seq_len, _ = hidden_states.size() + hidden_states = hidden_states.reshape( + seq_len // visual.spatial_merge_unit, + visual.spatial_merge_unit, + -1, + ) + hidden_states = hidden_states[window_index, :, :] + hidden_states = hidden_states.reshape(seq_len, -1) + + rotary_pos_emb = rotary_pos_emb.to(device=hidden_states.device) + rotary_pos_emb = rotary_pos_emb.reshape( + seq_len // visual.spatial_merge_unit, + visual.spatial_merge_unit, + -1, + ) + rotary_pos_emb = rotary_pos_emb[window_index, :, :] + rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1) + emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) + position_embeddings = (emb.cos(), emb.sin()) + + for layer_idx, block in enumerate(visual.blocks): + cu_seqlens_now = ( + cu_seqlens + if layer_idx in visual.fullatt_block_indexes + else cu_window_seqlens + ) + hidden_states = _call_qwen_visual_block_reference( + block, + hidden_states, + cu_seqlens_now, + rotary_pos_emb, + position_embeddings, + ) + + hidden_states = visual.merger(hidden_states) + hidden_states = hidden_states[reverse_window_index, :] + return hidden_states + + +def _call_qwen_visual_direct( + visual, + pixel_values: torch.Tensor, + image_grid_thw: 
torch.Tensor, +) -> torch.Tensor: + target_device = _module_device(visual, pixel_values.device) + pixel_values = pixel_values.to(device=target_device) + image_grid_thw = image_grid_thw.to(device=target_device) + if hasattr(visual, "dtype"): + pixel_values = pixel_values.to(dtype=visual.dtype) + + call_attempts = ( + lambda: visual(pixel_values, grid_thw=image_grid_thw), + lambda: visual(pixel_values, image_grid_thw), + lambda: visual(pixel_values), + ) + last_error = None + for call in call_attempts: + try: + return _extract_vision_output(call()) + except TypeError as exc: + last_error = exc + raise last_error + + +def _call_qwen_get_image_features_owner( + owner, + pixel_values: torch.Tensor, + image_grid_thw: torch.Tensor, +): + get_image_features = getattr(owner, "get_image_features", None) + if not callable(get_image_features): + return None + + call_attempts = ( + lambda: get_image_features( + pixel_values=pixel_values, + image_grid_thw=image_grid_thw, + ), + lambda: get_image_features(pixel_values, image_grid_thw), + ) + for call in call_attempts: + try: + return _extract_vision_output(call()) + except TypeError: + continue + return None + + +def _print_tensor_comparison( + name: str, + reference: torch.Tensor, + candidate: torch.Tensor, + atol: float, + rtol: float, +) -> bool: + reference = reference.detach() + candidate = candidate.detach() + if reference.shape != candidate.shape: + print(f"{name}: shape mismatch, ref={tuple(reference.shape)}, trt={tuple(candidate.shape)}") + return False + + ref_float = reference.float() + candidate_float = candidate.float() + diff = (ref_float - candidate_float).abs() + is_close = torch.allclose(ref_float, candidate_float, atol=atol, rtol=rtol) + ref_finite = torch.isfinite(ref_float) + candidate_finite = torch.isfinite(candidate_float) + print( + f"{name}: allclose={is_close}, " + f"max_abs={diff.max().item():.6f}, " + f"mean_abs={diff.mean().item():.6f}, " + f"ref_norm={ref_float.norm().item():.6f}, " + 
f"trt_norm={candidate_float.norm().item():.6f}, " + f"ref_finite={ref_finite.sum().item()}/{ref_finite.numel()}, " + f"trt_finite={candidate_finite.sum().item()}/{candidate_finite.numel()}" + ) + return is_close + + +@torch.inference_mode() +def verify_qwen_vision_wrapper( + model: torch.nn.Module, + inputs, + atol: float, + rtol: float, +) -> None: + print("========= Vision Wrapper Verification =========") + pixel_values = inputs["pixel_values"] + image_grid_thw = inputs["image_grid_thw"] + input_ids = inputs["input_ids"] + image_mask = input_ids == model.config.image_token_id + num_image_tokens = image_mask.sum().item() + visual = get_vision_model(model) + + ref_image_embeds = get_qwen_image_embeds( + model, + pixel_values, + image_grid_thw, + expected_tokens=num_image_tokens, + ) + direct_visual_embeds = _call_qwen_visual_direct( + visual, + pixel_values, + image_grid_thw, + ) + wrapper_image_embeds = _forward_qwen_windowed_rope_reference( + visual, + pixel_values, + image_grid_thw, + ) + + parent_feature_embeds = None + parent = getattr(model, "model", None) + if isinstance(parent, torch.nn.Module): + parent_feature_embeds = _call_qwen_get_image_features_owner( + parent, + pixel_values, + image_grid_thw, + ) + + helper_close = _print_tensor_comparison( + "HF get_image_features vs direct visual", + ref_image_embeds, + direct_visual_embeds, + atol, + rtol, + ) + parent_close = None + if isinstance(parent_feature_embeds, torch.Tensor): + parent_close = _print_tensor_comparison( + "HF model.get_image_features vs direct visual", + parent_feature_embeds, + direct_visual_embeds, + atol, + rtol, + ) + wrapper_close = _print_tensor_comparison( + "direct visual vs reconstructed PyTorch visual", + direct_visual_embeds, + wrapper_image_embeds, + atol, + rtol, + ) + print( + "vision wrapper verification summary: " + f"helper={helper_close}, parent_helper={parent_close}, " + f"wrapper={wrapper_close}" + ) + print("===============================================") + + 
+@torch.inference_mode() +def verify_qwen_attention_wrapper( + model: torch.nn.Module, + inputs, + args: argparse.Namespace, + device: torch.device, + atol: float, + rtol: float, +) -> None: + """ + Compare direct Qwen visual output against the same reconstructed visual + path with attention modules replaced by ViTPluginAttention running its + PyTorch fallback instead of the TensorRT plugin op. + """ + print("========= Vision Attention Wrapper Verification =========") + + torch_dtype = { + "FP16": torch.float16, + "BF16": torch.bfloat16, + }.get(args.precision, torch.float32) + + pixel_values = inputs["pixel_values"] + image_grid_thw = inputs["image_grid_thw"] + visual_model = get_vision_model(model) + + direct_visual_embeds = _call_qwen_visual_direct( + visual_model, + pixel_values, + image_grid_thw, + ) + + vision_config = _get_vision_config(model.config) + replace_vit_attention_with_plugin( + visual_model, + vision_config, + use_plugin_op=False, + ) + replacement_count = count_vit_plugin_attention_modules(visual_model) + print(f"PyTorch attention replacement modules inserted: {replacement_count}") + + plugin_pixel_values = pixel_values.to(dtype=torch_dtype) + input_contract, core_inputs, max_window_seq_len = _prepare_vit_plugin_inputs( + visual_model, + inputs, + plugin_pixel_values, + device, + torch_dtype, + ) + wrapper = ViTPluginWrapper( + visual_model, + input_contract=input_contract, + max_window_seq_len=max_window_seq_len, + ).eval() + replacement_visual_embeds = _extract_vision_output( + wrapper(plugin_pixel_values, **core_inputs) + ) + + replacement_close = _print_tensor_comparison( + "direct visual vs PyTorch attention wrapper visual", + direct_visual_embeds, + replacement_visual_embeds, + atol, + rtol, + ) + print( + "vision attention wrapper verification summary: " + f"attention_wrapper={replacement_close}" + ) + print("=========================================================") + + +@torch.inference_mode() +def verify_qwen_attention_module( + model: 
torch.nn.Module, + inputs, + args: argparse.Namespace, + atol: float, + rtol: float, + layer_idx: int = 0, +) -> None: + """ + Compare one original Qwen visual attention module against ViTPluginAttention + running its PyTorch fallback. + """ + print("========= Vision Attention Module Verification =========") + + torch_dtype = { + "FP16": torch.float16, + "BF16": torch.bfloat16, + }.get(args.precision, torch.float32) + + visual = get_vision_model(model) + ( + hidden_states, + attention_mask, + cu_seqlens, + rotary_pos_emb, + position_embeddings, + ) = _qwen_windowed_rope_attention_inputs( + visual, + inputs["pixel_values"], + inputs["image_grid_thw"], + layer_idx, + torch_dtype, + ) + + original_attn = visual.blocks[layer_idx].attn + ref_attn = _call_qwen_attention_reference( + original_attn, + hidden_states, + attention_mask, + cu_seqlens, + rotary_pos_emb, + position_embeddings, + ) + wrapper_attn = ViTPluginAttention( + original_attn, + _get_vision_config(model.config), + layer_idx, + use_plugin_op=False, + ).eval() + candidate_attn = wrapper_attn( + hidden_states, + attention_mask=attention_mask, + position_embeddings=position_embeddings, + ) + + attn_close = _print_tensor_comparison( + f"layer {layer_idx} original attention vs PyTorch attention wrapper", + ref_attn, + candidate_attn, + atol, + rtol, + ) + print( + "vision attention module verification summary: " + f"layer={layer_idx}, attention={attn_close}" + ) + print("========================================================") + + +@torch.inference_mode() +def verify_qwen_attention_plugin_module( + model: torch.nn.Module, + inputs, + args: argparse.Namespace, + device: torch.device, + atol: float, + rtol: float, + layer_idx: int = 0, +) -> None: + """ + Compare one original Qwen visual attention module against the real TensorRT + ViTAttentionPlugin lowering. This isolates plugin math from the rest of the + vision tower. 
+ """ + print("========= Vision Attention Plugin Module Verification =========") + + torch_dtype = { + "FP16": torch.float16, + "BF16": torch.bfloat16, + }.get(args.precision, torch.float32) + + visual = get_vision_model(model) + ( + hidden_states, + attention_mask, + cu_seqlens, + rotary_pos_emb, + position_embeddings, + ) = _qwen_windowed_rope_attention_inputs( + visual, + inputs["pixel_values"], + inputs["image_grid_thw"], + layer_idx, + torch_dtype, + ) + + original_attn = visual.blocks[layer_idx].attn + ref_attn = _call_qwen_attention_reference( + original_attn, + hidden_states, + attention_mask, + cu_seqlens, + rotary_pos_emb, + position_embeddings, + ) + + vision_config = _get_vision_config(model.config) + fallback_attn = ViTPluginAttention( + original_attn, + vision_config, + layer_idx, + use_plugin_op=False, + ).eval() + plugin_rope_position_embeddings = tuple( + value.to(dtype=hidden_states.dtype) for value in position_embeddings + ) + fallback_with_plugin_rope = fallback_attn( + hidden_states, + attention_mask=attention_mask, + position_embeddings=plugin_rope_position_embeddings, + ) + fallback_close = _print_tensor_comparison( + f"layer {layer_idx} original attention vs PyTorch wrapper with plugin RoPE dtype", + ref_attn, + fallback_with_plugin_rope, + atol, + rtol, + ) + + class _AttentionPluginModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.attn = ViTPluginAttention( + original_attn, + vision_config, + layer_idx, + use_plugin_op=True, + ).eval() + + def forward(self, hidden_states, attention_mask, cos, sin): + return self.attn( + hidden_states, + attention_mask=attention_mask, + position_embeddings=(cos, sin), + ) + + load_vit_plugin() + register_vit_plugin_op() + _set_vit_plugin_config_from_vision( + vision_config, + inputs["pixel_values"].to(dtype=torch_dtype), + ) + print(f"ViT plugin config: {get_vit_plugin_config()}") + + reset_vit_plugin_conversion_count() + compiled_attn = compile_vit_plugin_model( + 
_AttentionPluginModule().eval(), + ( + hidden_states, + attention_mask, + position_embeddings[0], + position_embeddings[1], + ), + device, + debug=args.debug, + ) + plugin_attn = _extract_hidden_states( + compiled_attn( + hidden_states, + attention_mask, + position_embeddings[0], + position_embeddings[1], + ) + ) + plugin_close = _print_tensor_comparison( + f"layer {layer_idx} original attention vs TensorRT plugin attention", + ref_attn, + plugin_attn, + atol, + rtol, + ) + print( + "vision attention plugin module verification summary: " + f"layer={layer_idx}, fallback_plugin_rope={fallback_close}, " + f"plugin={plugin_close}, conversions={get_vit_plugin_conversion_count()}" + ) + print("===============================================================") + + +@torch.inference_mode() +def verify_qwen_block_plugin_module( + model: torch.nn.Module, + inputs, + args: argparse.Namespace, + device: torch.device, + atol: float, + rtol: float, + layer_idx: int = 0, +) -> None: + """ + Compare one original Qwen visual block against the same block compiled with + the real TensorRT ViTAttentionPlugin inside it. 
+ """ + print("========= Vision Block Plugin Module Verification =========") + + torch_dtype = { + "FP16": torch.float16, + "BF16": torch.bfloat16, + }.get(args.precision, torch.float32) + + visual = get_vision_model(model) + ( + hidden_states, + attention_mask, + cu_seqlens, + rotary_pos_emb, + position_embeddings, + ) = _qwen_windowed_rope_block_inputs( + visual, + inputs["pixel_values"], + inputs["image_grid_thw"], + layer_idx, + torch_dtype, + ) + + block = visual.blocks[layer_idx] + ref_block = _call_qwen_visual_block_reference( + block, + hidden_states, + cu_seqlens, + rotary_pos_emb, + position_embeddings, + ) + + vision_config = _get_vision_config(model.config) + + class _BlockPluginModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.norm1 = block.norm1 + self.attn = ViTPluginAttention( + block.attn, + vision_config, + layer_idx, + use_plugin_op=True, + ).eval() + self.norm2 = block.norm2 + self.mlp = block.mlp + + def forward(self, hidden_states, attention_mask, cos, sin): + residual = hidden_states + candidate = self.norm1(hidden_states) + candidate = self.attn( + candidate, + attention_mask=attention_mask, + position_embeddings=(cos, sin), + ) + candidate = residual + candidate + + residual = candidate + candidate = self.norm2(candidate) + candidate = self.mlp(candidate) + return residual + candidate + + load_vit_plugin() + register_vit_plugin_op() + _set_vit_plugin_config_from_vision( + vision_config, + inputs["pixel_values"].to(dtype=torch_dtype), + ) + print(f"ViT plugin config: {get_vit_plugin_config()}") + + reset_vit_plugin_conversion_count() + compiled_block = compile_vit_plugin_model( + _BlockPluginModule().eval(), + ( + hidden_states, + attention_mask, + position_embeddings[0], + position_embeddings[1], + ), + device, + debug=args.debug, + ) + plugin_block = _extract_hidden_states( + compiled_block( + hidden_states, + attention_mask, + position_embeddings[0], + position_embeddings[1], + ) + ) + block_close = 
_print_tensor_comparison( + f"layer {layer_idx} original block vs TensorRT plugin block", + ref_block, + plugin_block, + atol, + rtol, + ) + print( + "vision block plugin module verification summary: " + f"layer={layer_idx}, block={block_close}, " + f"conversions={get_vit_plugin_conversion_count()}" + ) + print("===========================================================") + + +@torch.inference_mode() +def verify_qwen_block_plugin_parts( + model: torch.nn.Module, + inputs, + args: argparse.Namespace, + device: torch.device, + atol: float, + rtol: float, + layer_idx: int = 0, +) -> None: + """ + Split one Qwen visual block into attention-residual and MLP-residual halves + and compile each half. This isolates late-block drift to the surrounding + block ops instead of the attention plugin itself. + """ + print("========= Vision Block Plugin Parts Verification =========") + + torch_dtype = { + "FP16": torch.float16, + "BF16": torch.bfloat16, + }.get(args.precision, torch.float32) + + visual = get_vision_model(model) + ( + hidden_states, + attention_mask, + cu_seqlens, + rotary_pos_emb, + position_embeddings, + ) = _qwen_windowed_rope_block_inputs( + visual, + inputs["pixel_values"], + inputs["image_grid_thw"], + layer_idx, + torch_dtype, + ) + + block = visual.blocks[layer_idx] + norm1_hidden_states = block.norm1(hidden_states) + ref_attn = _call_qwen_attention_reference( + block.attn, + norm1_hidden_states, + attention_mask, + cu_seqlens, + rotary_pos_emb, + position_embeddings, + ) + ref_after_attn = hidden_states + ref_attn + ref_mlp = block.mlp(block.norm2(ref_after_attn)) + ref_after_mlp = ref_after_attn + ref_mlp + + vision_config = _get_vision_config(model.config) + load_vit_plugin() + register_vit_plugin_op() + _set_vit_plugin_config_from_vision( + vision_config, + inputs["pixel_values"].to(dtype=torch_dtype), + ) + print(f"ViT plugin config: {get_vit_plugin_config()}") + + class _AttentionResidualPluginModule(torch.nn.Module): + def __init__(self): + 
super().__init__() + self.norm1 = block.norm1 + self.attn = ViTPluginAttention( + block.attn, + vision_config, + layer_idx, + use_plugin_op=True, + ).eval() + + def forward(self, hidden_states, attention_mask, cos, sin): + residual = hidden_states + candidate = self.norm1(hidden_states) + candidate = self.attn( + candidate, + attention_mask=attention_mask, + position_embeddings=(cos, sin), + ) + return residual + candidate + + reset_vit_plugin_conversion_count() + compiled_attn_residual = compile_vit_plugin_model( + _AttentionResidualPluginModule().eval(), + ( + hidden_states, + attention_mask, + position_embeddings[0], + position_embeddings[1], + ), + device, + debug=args.debug, + ) + plugin_after_attn = _extract_hidden_states( + compiled_attn_residual( + hidden_states, + attention_mask, + position_embeddings[0], + position_embeddings[1], + ) + ) + attention_residual_close = _print_tensor_comparison( + f"layer {layer_idx} attention residual half", + ref_after_attn, + plugin_after_attn, + atol, + rtol, + ) + attention_conversions = get_vit_plugin_conversion_count() + + class _MlpResidualModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.norm2 = block.norm2 + self.mlp = block.mlp + + def forward(self, hidden_states): + residual = hidden_states + candidate = self.norm2(hidden_states) + candidate = self.mlp(candidate) + return residual + candidate + + reset_vit_plugin_conversion_count() + compiled_mlp_residual = compile_vit_plugin_model( + _MlpResidualModule().eval(), + (ref_after_attn,), + device, + debug=args.debug, + ) + plugin_after_mlp = _extract_hidden_states(compiled_mlp_residual(ref_after_attn)) + mlp_residual_close = _print_tensor_comparison( + f"layer {layer_idx} MLP residual half", + ref_after_mlp, + plugin_after_mlp, + atol, + rtol, + ) + mlp_conversions = get_vit_plugin_conversion_count() + + chained_after_mlp = _extract_hidden_states( + compiled_mlp_residual(plugin_after_attn) + ) + chained_close = _print_tensor_comparison( + 
f"layer {layer_idx} chained compiled attention half into MLP half", + ref_after_mlp, + chained_after_mlp, + atol, + rtol, + ) + + print( + "vision block plugin parts verification summary: " + f"layer={layer_idx}, attention_residual={attention_residual_close}, " + f"mlp_residual={mlp_residual_close}, " + f"chained={chained_close}, " + f"attention_conversions={attention_conversions}, " + f"mlp_conversions={mlp_conversions}" + ) + print("===========================================================") + + +@torch.inference_mode() +def verify_qwen_block_wrapper( + model: torch.nn.Module, + inputs, + args: argparse.Namespace, + atol: float, + rtol: float, + layer_idx: int = 0, +) -> None: + """ + Compare one original Qwen visual block against the same block manually + wired with ViTPluginAttention running its PyTorch fallback. + """ + print("========= Vision Block Wrapper Verification =========") + + torch_dtype = { + "FP16": torch.float16, + "BF16": torch.bfloat16, + }.get(args.precision, torch.float32) + + visual = get_vision_model(model) + ( + hidden_states, + attention_mask, + cu_seqlens, + rotary_pos_emb, + position_embeddings, + ) = _qwen_windowed_rope_block_inputs( + visual, + inputs["pixel_values"], + inputs["image_grid_thw"], + layer_idx, + torch_dtype, + ) + + block = visual.blocks[layer_idx] + ref_block = _call_qwen_visual_block_reference( + block, + hidden_states, + cu_seqlens, + rotary_pos_emb, + position_embeddings, + ) + candidate_block = _call_qwen_block_with_attention_wrapper( + block, + hidden_states, + attention_mask, + position_embeddings, + _get_vision_config(model.config), + layer_idx, + ) + + block_close = _print_tensor_comparison( + f"layer {layer_idx} original block vs PyTorch attention-wrapper block", + ref_block, + candidate_block, + atol, + rtol, + ) + print( + "vision block wrapper verification summary: " + f"layer={layer_idx}, block={block_close}" + ) + print("===================================================") + + +@torch.inference_mode() 
+def verify_qwen_vlm_components( + model: torch.nn.Module, + trt_model: torch.nn.Module, + inputs, + emb_layer: torch.nn.Embedding, + tokenizer, + atol: float, + rtol: float, + verify_stage: str = "all", +) -> None: + """ + Compare the compiled Qwen VLM components against the PyTorch reference. + This checks the vision tower, then isolates the LM by feeding both models + the same multimodal embeddings and Qwen multimodal RoPE position ids. + """ + print("========= Component Verification =========") + + pixel_values = inputs["pixel_values"] + image_grid_thw = inputs["image_grid_thw"] + input_ids = inputs["input_ids"] + attention_mask = inputs.get("attention_mask") + image_mask = input_ids == model.config.image_token_id + num_image_tokens = image_mask.sum().item() + device = input_ids.device + + ref_image_embeds = get_qwen_image_embeds( + model, + pixel_values, + image_grid_thw, + expected_tokens=num_image_tokens, + ).to(device=device) + trt_image_embeds = get_qwen_image_embeds( + trt_model, + pixel_values, + image_grid_thw, + expected_tokens=num_image_tokens, + ).to(device=device) + direct_visual_embeds = _call_qwen_visual_direct( + get_vision_model(model), + pixel_values, + image_grid_thw, + ).to(device=device) + trt_direct_visual_embeds = _call_qwen_visual_direct( + get_vision_model(trt_model), + pixel_values, + image_grid_thw, + ).to(device=device) + direct_vision_close = _print_tensor_comparison( + "direct visual embeddings", direct_visual_embeds, trt_direct_visual_embeds, atol, rtol + ) + helper_vision_close = _print_tensor_comparison( + "get_image_features embeddings", ref_image_embeds, trt_image_embeds, atol, rtol + ) + if verify_stage == "vision": + print( + "component verification summary: " + f"direct_vision={direct_vision_close}, " + f"get_image_features={helper_vision_close}" + ) + print("==========================================") + return + + seq_tokens = input_ids.clone() + seq_embeds = emb_layer(seq_tokens) + mask_expanded = 
image_mask.unsqueeze(-1).expand_as(seq_embeds) + seq_embeds = seq_embeds.masked_scatter( + mask_expanded, + ref_image_embeds.to(device=seq_embeds.device, dtype=seq_embeds.dtype), + ) + position_ids = get_qwen_position_ids( + model, + seq_tokens, + image_grid_thw=image_grid_thw, + attention_mask=attention_mask, + ).to(device=seq_embeds.device) + + ref_hidden = _call_language_model_for_verify( + get_language_model(model), seq_embeds, position_ids + ).to(device) + trt_hidden = _call_language_model_for_verify( + get_language_model(trt_model), seq_embeds, position_ids + ).to(device) + lm_close = _print_tensor_comparison( + "LM hidden states", ref_hidden, trt_hidden, atol, rtol + ) + + ref_logits = model.lm_head( + ref_hidden[:, -1, :].to(_module_device(model.lm_head, device)) + ).to(device) + trt_logits = trt_model.lm_head( + trt_hidden[:, -1, :].to(_module_device(trt_model.lm_head, device)) + ).to(device) + logits_close = _print_tensor_comparison("next-token logits", ref_logits, trt_logits, atol, rtol) + + ref_next = ref_logits.argmax(dim=-1) + trt_next = trt_logits.argmax(dim=-1) + print( + "next-token argmax: " + f"ref={ref_next.tolist()} ({tokenizer.batch_decode(ref_next[:, None])}), " + f"trt={trt_next.tolist()} ({tokenizer.batch_decode(trt_next[:, None])}), " + f"match={torch.equal(ref_next, trt_next)}" + ) + print( + "component verification summary: " + f"direct_vision={direct_vision_close}, " + f"get_image_features={helper_vision_close}, " + f"lm={lm_close}, logits={logits_close}" + ) + print("==========================================") + + +def _qwen_inputs_embeds_with_vision( + model_for_vision: torch.nn.Module, + reference_model: torch.nn.Module, + inputs, + emb_layer: torch.nn.Embedding, +) -> Tuple[torch.Tensor, torch.Tensor]: + input_ids = inputs["input_ids"] + pixel_values = inputs["pixel_values"] + image_grid_thw = inputs["image_grid_thw"] + image_mask = input_ids == reference_model.config.image_token_id + num_image_tokens = image_mask.sum().item() + + 
image_embeds = get_qwen_image_embeds( + model_for_vision, + pixel_values, + image_grid_thw, + expected_tokens=num_image_tokens, + ) + seq_tokens = input_ids.clone() + seq_embeds = emb_layer(seq_tokens) + mask_expanded = image_mask.unsqueeze(-1).expand_as(seq_embeds) + seq_embeds = seq_embeds.masked_scatter( + mask_expanded, + image_embeds.to(device=seq_embeds.device, dtype=seq_embeds.dtype), + ) + return seq_tokens, seq_embeds + + +def _print_generated_token_divergence( + ref_gen_tokens: torch.Tensor, + trt_gen_tokens: torch.Tensor, + tokenizer, +) -> None: + min_len = min(ref_gen_tokens.shape[1], trt_gen_tokens.shape[1]) + mismatch_idx = None + for idx in range(min_len): + if not torch.equal(ref_gen_tokens[:, idx], trt_gen_tokens[:, idx]): + mismatch_idx = idx + break + + if mismatch_idx is None and ref_gen_tokens.shape[1] == trt_gen_tokens.shape[1]: + print("generated divergence: none") + return + + if mismatch_idx is None: + mismatch_idx = min_len + print( + "generated divergence: length mismatch after shared prefix " + f"of {min_len} tokens" + ) + else: + ref_tok = ref_gen_tokens[:, mismatch_idx] + trt_tok = trt_gen_tokens[:, mismatch_idx] + print( + "generated divergence: " + f"step={mismatch_idx}, " + f"ref={ref_tok.tolist()} ({tokenizer.batch_decode(ref_tok[:, None])}), " + f"trt={trt_tok.tolist()} ({tokenizer.batch_decode(trt_tok[:, None])})" + ) + + context_start = max(0, mismatch_idx - 4) + context_end = min( + max(ref_gen_tokens.shape[1], trt_gen_tokens.shape[1]), + mismatch_idx + 5, + ) + ref_context = ref_gen_tokens[:, context_start : min(context_end, ref_gen_tokens.shape[1])] + trt_context = trt_gen_tokens[:, context_start : min(context_end, trt_gen_tokens.shape[1])] + print( + "generated divergence context: " + f"ref={tokenizer.batch_decode(ref_context, skip_special_tokens=True)}, " + f"trt={tokenizer.batch_decode(trt_context, skip_special_tokens=True)}" + ) + + +@torch.inference_mode() +def verify_qwen_vision_semantics( + model: torch.nn.Module, + 
trt_model: torch.nn.Module, + inputs, + emb_layer: torch.nn.Embedding, + tokenizer, + args: argparse.Namespace, + atol: float, + rtol: float, +) -> None: + """ + Check whether compiled/plugin vision changes the language-model behavior + before compiling the LM. The first-token comparison uses the original + PyTorch LM for both paths, so the only intentional input difference is the + vision embedding source. + """ + print("========= Vision Semantic Verification =========") + + attention_mask = inputs.get("attention_mask") + ref_tokens, ref_embeds = _qwen_inputs_embeds_with_vision( + model, model, inputs, emb_layer + ) + trt_tokens, trt_embeds = _qwen_inputs_embeds_with_vision( + trt_model, model, inputs, emb_layer + ) + position_ids = get_qwen_position_ids( + model, + ref_tokens, + image_grid_thw=inputs["image_grid_thw"], + attention_mask=attention_mask, + ).to(device=ref_embeds.device) + + language_model = get_language_model(model) + ref_hidden = _call_language_model_for_verify( + language_model, + ref_embeds, + position_ids, + ) + trt_hidden = _call_language_model_for_verify( + language_model, + trt_embeds, + position_ids, + ) + ref_logits = model.lm_head( + ref_hidden[:, -1, :].to(_module_device(model.lm_head, ref_hidden.device)) + ).to(ref_embeds.device) + trt_logits = model.lm_head( + trt_hidden[:, -1, :].to(_module_device(model.lm_head, trt_hidden.device)) + ).to(ref_embeds.device) + logits_close = _print_tensor_comparison( + "first-token logits with PyTorch LM", + ref_logits, + trt_logits, + atol, + rtol, + ) + + ref_next = ref_logits.argmax(dim=-1) + trt_next = trt_logits.argmax(dim=-1) + print( + "first-token argmax: " + f"ref={ref_next.tolist()} ({tokenizer.batch_decode(ref_next[:, None])}), " + f"trt={trt_next.tolist()} ({tokenizer.batch_decode(trt_next[:, None])}), " + f"match={torch.equal(ref_next, trt_next)}" + ) + + ref_gen_tokens = generate_mm_qwen2_5_vl( + model, + inputs["pixel_values"], + inputs["input_ids"], + inputs["image_grid_thw"], + 
attention_mask, + tokenizer.eos_token_id, + emb_layer, + max_new_tokens=args.num_tokens, + ) + trt_gen_tokens = generate_mm_qwen2_5_vl( + trt_model, + inputs["pixel_values"], + inputs["input_ids"], + inputs["image_grid_thw"], + attention_mask, + tokenizer.eos_token_id, + emb_layer, + max_new_tokens=args.num_tokens, + ) + generated_match = torch.equal(ref_gen_tokens, trt_gen_tokens) + print( + "generated tokens: " + f"ref={ref_gen_tokens.tolist()} " + f"({tokenizer.batch_decode(ref_gen_tokens, skip_special_tokens=True)}), " + f"trt={trt_gen_tokens.tolist()} " + f"({tokenizer.batch_decode(trt_gen_tokens, skip_special_tokens=True)}), " + f"match={generated_match}" + ) + _print_generated_token_divergence(ref_gen_tokens, trt_gen_tokens, tokenizer) + print( + "vision semantic verification summary: " + f"logits={logits_close}, first_token={torch.equal(ref_next, trt_next)}, " + f"generated={generated_match}" + ) + print("==============================================") + + +# -----------------------------------------------------------------------------# +# Main driver +# -----------------------------------------------------------------------------# +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Run VLM inference (PyTorch & TensorRT back-ends)" + ) + parser.add_argument( + "--model", + default="nvidia/Eagle2-2B", + help="VLM model name", + ) + parser.add_argument( + "--processor", + default="", + help="Processor name/path. Defaults to --model.", + ) + parser.add_argument( + "--trust_remote_code", + action=argparse.BooleanOptionalAction, + default=True, + help="Allow Hugging Face remote model/processor code.", + ) + parser.add_argument( + "--attn_implementation", + default="", + help=( + "Attention implementation passed to from_pretrained. Defaults to " + "SDPA for the torchtrt vision backend and model default for plugin." 
+ ), + ) + parser.add_argument("--prompt", default="Describe this image.", help="Prompt text") + parser.add_argument( + "--precision", + default="FP16", + choices=["FP16", "BF16", "FP32"], + help="Computation precision", + ) + parser.add_argument("--iterations", type=int, default=5, help="# iterations") + parser.add_argument("--min_block_size", type=int, default=1, help="Min block size") + parser.add_argument("--num_tokens", type=int, default=128, help="# new tokens") + parser.add_argument("--batch_size", type=int, default=1, help="Batch size") + parser.add_argument("--isl", type=int, default=2048, help="Input seq length") + parser.add_argument( + "--enable_pytorch_run", + action="store_true", + help="Run the PyTorch baseline as well", + ) + parser.add_argument( + "--verify_accuracy", + action="store_true", + help="Compare PyTorch and TensorRT VLM component outputs before generation.", + ) + parser.add_argument( + "--verify_stage", + default="all", + choices=[ + "vision_wrapper", + "attention_module", + "attention_plugin_module", + "block_plugin_module", + "block_plugin_parts", + "block_wrapper", + "attention_wrapper", + "vision", + "vision_semantic", + "all", + ], + help=( + "Component verification stage. 
'vision_wrapper' compares the HF " + "visual output against the reconstructed PyTorch visual path before " + "TRT/plugin compilation; 'attention_module' checks one attention " + "module with the PyTorch fallback; 'attention_plugin_module' checks " + "one attention module with the real TensorRT plugin; " + "'block_plugin_module' checks one full visual block with the real " + "TensorRT plugin; " + "'block_plugin_parts' checks one block split into attention and MLP " + "halves; " + "'block_wrapper' checks one full visual block; " + "'attention_wrapper' replaces attention with the plugin wrapper but " + "runs PyTorch attention; 'vision' stops after comparing visual " + "embeddings; 'vision_semantic' checks whether compiled vision changes " + "PyTorch LM logits or generated tokens; 'all' also checks LM hidden " + "states and logits." + ), + ) + parser.add_argument( + "--verify_atol", + type=float, + default=5e-2, + help="Absolute tolerance for --verify_accuracy tensor comparisons.", + ) + parser.add_argument( + "--verify_rtol", + type=float, + default=5e-2, + help="Relative tolerance for --verify_accuracy tensor comparisons.", + ) + parser.add_argument( + "--verify_layer", + type=int, + default=0, + help="Layer index used by --verify_stage attention_module.", + ) + parser.add_argument( + "--cache", + default="", + choices=["", "static_v1"], + help="KV-cache variant to use", + ) + parser.add_argument( + "--debug", action="store_true", help="Enable Torch-TensorRT debug logs" + ) + parser.add_argument( + "--benchmark", action="store_true", help="Enable benchmarking mode" + ) + parser.add_argument( + "--vision_backend", + default="torchtrt", + choices=["torchtrt", "plugin"], + help=( + "Vision backend. 'torchtrt' keeps the existing component compiler; " + "'plugin' uses the TensorRT-Edge-LLM ViT attention plugin where supported." + ), + ) + parser.add_argument( + "--image_path", + type=str, + default=None, + help="Path to local image file. 
If not provided, uses default URL image.", + ) + parser.add_argument( + "--device", + type=str, + default="cuda:0", + help="Device to run inference on (e.g., 'cuda:0', 'cuda:1')", ) parser.add_argument( "--disable_tf32", @@ -462,6 +2751,11 @@ def print_outputs(backend_name: str, gen_tokens: torch.Tensor, tokenizer): args = parser.parse_args() + if args.vision_backend == "plugin" and not VIT_PLUGIN_AVAILABLE: + raise RuntimeError( + "ViT plugin vision backend requested but plugin utilities are not available." + ) + device = torch.device(args.device) if device.type == "cuda": torch.cuda.set_device(device) @@ -474,7 +2768,7 @@ def print_outputs(backend_name: str, gen_tokens: torch.Tensor, tokenizer): "BF16": torch.bfloat16, }.get(args.precision, torch.float32) - model, processor, emb_layer = load_model(args.model, device, dtype) + model, processor, emb_layer = get_model(args, device, dtype) # -------------------------------------------------------------------------# # 2. Input construction (image + text prompt) @@ -483,15 +2777,136 @@ def print_outputs(backend_name: str, gen_tokens: torch.Tensor, tokenizer): max_output_len = inputs["input_ids"].shape[1] + args.num_tokens + if args.verify_accuracy and args.verify_stage == "vision_wrapper": + if _is_qwen2_5_vl(args.model): + verify_qwen_vision_wrapper( + model, + inputs, + args.verify_atol, + args.verify_rtol, + ) + raise SystemExit(0) + print( + "--verify_stage vision_wrapper currently supports Qwen2.5-VL " + "windowed-RoPE visual towers." + ) + raise SystemExit(0) + + if args.verify_accuracy and args.verify_stage == "attention_wrapper": + if _is_qwen2_5_vl(args.model): + verify_qwen_attention_wrapper( + model, + inputs, + args, + device, + args.verify_atol, + args.verify_rtol, + ) + raise SystemExit(0) + print( + "--verify_stage attention_wrapper currently supports Qwen2.5-VL " + "windowed-RoPE visual towers." 
+ ) + raise SystemExit(0) + + if args.verify_accuracy and args.verify_stage == "attention_module": + if _is_qwen2_5_vl(args.model): + verify_qwen_attention_module( + model, + inputs, + args, + args.verify_atol, + args.verify_rtol, + args.verify_layer, + ) + raise SystemExit(0) + print( + "--verify_stage attention_module currently supports Qwen2.5-VL " + "windowed-RoPE visual towers." + ) + raise SystemExit(0) + + if args.verify_accuracy and args.verify_stage == "attention_plugin_module": + if _is_qwen2_5_vl(args.model): + verify_qwen_attention_plugin_module( + model, + inputs, + args, + device, + args.verify_atol, + args.verify_rtol, + args.verify_layer, + ) + raise SystemExit(0) + print( + "--verify_stage attention_plugin_module currently supports " + "Qwen2.5-VL windowed-RoPE visual towers." + ) + raise SystemExit(0) + + if args.verify_accuracy and args.verify_stage == "block_plugin_module": + if _is_qwen2_5_vl(args.model): + verify_qwen_block_plugin_module( + model, + inputs, + args, + device, + args.verify_atol, + args.verify_rtol, + args.verify_layer, + ) + raise SystemExit(0) + print( + "--verify_stage block_plugin_module currently supports " + "Qwen2.5-VL windowed-RoPE visual towers." + ) + raise SystemExit(0) + + if args.verify_accuracy and args.verify_stage == "block_plugin_parts": + if _is_qwen2_5_vl(args.model): + verify_qwen_block_plugin_parts( + model, + inputs, + args, + device, + args.verify_atol, + args.verify_rtol, + args.verify_layer, + ) + raise SystemExit(0) + print( + "--verify_stage block_plugin_parts currently supports " + "Qwen2.5-VL windowed-RoPE visual towers." + ) + raise SystemExit(0) + + if args.verify_accuracy and args.verify_stage == "block_wrapper": + if _is_qwen2_5_vl(args.model): + verify_qwen_block_wrapper( + model, + inputs, + args, + args.verify_atol, + args.verify_rtol, + args.verify_layer, + ) + raise SystemExit(0) + print( + "--verify_stage block_wrapper currently supports Qwen2.5-VL " + "windowed-RoPE visual towers." 
+ ) + raise SystemExit(0) + # -------------------------------------------------------------------------# # 3. Optional: PyTorch baseline # -------------------------------------------------------------------------# + pyt_gen_tokens = pyt_timings = pyt_stats = None if args.enable_pytorch_run: # For benchmarking, we run the generation with timing enabled. # For regular runs, we run without timing for a single output. if args.benchmark: - if args.model == "Qwen/Qwen2.5-VL-3B-Instruct": + if _is_qwen2_5_vl(args.model): ( pyt_gen_tokens, _, @@ -503,6 +2918,7 @@ def print_outputs(backend_name: str, gen_tokens: torch.Tensor, tokenizer): inputs["pixel_values"], inputs["input_ids"], inputs["image_grid_thw"], + inputs.get("attention_mask"), processor.tokenizer.eos_token_id, emb_layer, max_new_tokens=args.num_tokens, @@ -531,12 +2947,13 @@ def print_outputs(backend_name: str, gen_tokens: torch.Tensor, tokenizer): batch_size=args.batch_size, ) else: - if args.model == "Qwen/Qwen2.5-VL-3B-Instruct": + if _is_qwen2_5_vl(args.model): pyt_gen_tokens = generate_mm_qwen2_5_vl( model, inputs["pixel_values"], inputs["input_ids"], inputs["image_grid_thw"], + inputs.get("attention_mask"), processor.tokenizer.eos_token_id, emb_layer, max_new_tokens=args.num_tokens, @@ -558,12 +2975,45 @@ def print_outputs(backend_name: str, gen_tokens: torch.Tensor, tokenizer): trt_model = copy.deepcopy(model) # 4.1. 
Vision model compilation # --- Add vision model compilation --- # - example_pixel_values = inputs["pixel_values"] - trt_vision = compile_vision_torchtrt(model, args, example_pixel_values, device) - if args.model == "Qwen/Qwen2.5-VL-3B-Instruct": - trt_model.visual = trt_vision - else: - trt_model.vision_model = trt_vision + trt_vision = compile_vision_torchtrt(trt_model, args, inputs, device) + set_vision_model(trt_model, trt_vision) + + if args.verify_accuracy and args.verify_stage == "vision": + if _is_qwen2_5_vl(args.model): + verify_qwen_vlm_components( + model, + trt_model, + inputs, + emb_layer, + processor.tokenizer, + args.verify_atol, + args.verify_rtol, + args.verify_stage, + ) + raise SystemExit(0) + print( + "--verify_stage vision currently has detailed component checks for " + "Qwen2.5-VL." + ) + raise SystemExit(0) + + if args.verify_accuracy and args.verify_stage == "vision_semantic": + if _is_qwen2_5_vl(args.model): + verify_qwen_vision_semantics( + model, + trt_model, + inputs, + emb_layer, + processor.tokenizer, + args, + args.verify_atol, + args.verify_rtol, + ) + raise SystemExit(0) + print( + "--verify_stage vision_semantic currently supports Qwen2.5-VL." + ) + raise SystemExit(0) # -------------------------------------------------------------------------# # 4.2. Language model compilation @@ -582,23 +3032,38 @@ def print_outputs(backend_name: str, gen_tokens: torch.Tensor, tokenizer): f"Cache mode '{args.cache}' is not supported. Only 'static_v1' is supported." 
) - trt_lm = compile_lm_torchtrt(model, args, device) - if args.model == "Qwen/Qwen2.5-VL-3B-Instruct": - trt_model.model = trt_lm - else: - trt_model.language_model = trt_lm + trt_lm = compile_lm_torchtrt(trt_model, args, device) + set_language_model(trt_model, trt_lm) emb_layer = emb_layer.to(device) - if args.model == "Qwen/Qwen2.5-VL-3B-Instruct": + if _is_qwen2_5_vl(args.model) and hasattr(trt_model, "lm_head"): trt_model.lm_head = trt_model.lm_head.to(device) + if args.verify_accuracy: + if _is_qwen2_5_vl(args.model): + verify_qwen_vlm_components( + model, + trt_model, + inputs, + emb_layer, + processor.tokenizer, + args.verify_atol, + args.verify_rtol, + args.verify_stage, + ) + else: + print( + "--verify_accuracy currently has detailed component checks for " + "Qwen2.5-VL. Falling back to generated-token comparison." + ) + if args.cache == "static_v1": - if args.model == "Qwen/Qwen2.5-VL-3B-Instruct": + if _is_qwen2_5_vl(args.model): trt_generate = generate_mm_qwen2_5_vl_with_static_cache else: # eagle2 trt_generate = generate_mm_with_static_cache else: - if args.model == "Qwen/Qwen2.5-VL-3B-Instruct": + if _is_qwen2_5_vl(args.model): trt_generate = generate_mm_qwen2_5_vl else: # eagle2 trt_generate = generate_mm @@ -612,8 +3077,9 @@ def print_outputs(backend_name: str, gen_tokens: torch.Tensor, tokenizer): "emb_layer": emb_layer, "max_new_tokens": args.num_tokens, } - if args.model == "Qwen/Qwen2.5-VL-3B-Instruct": + if _is_qwen2_5_vl(args.model): generate_args["image_grid_thw"] = inputs["image_grid_thw"] + generate_args["attention_mask"] = inputs.get("attention_mask") if args.cache == "static_v1" or args.benchmark: generate_args["with_timing"] = True diff --git a/tools/llm/utils.py b/tools/llm/utils.py index 5c3197356d..8937ded47a 100644 --- a/tools/llm/utils.py +++ b/tools/llm/utils.py @@ -10,7 +10,364 @@ ) -def export_llm(model, inputs, min_seq_len=1, max_seq_len=16): +_VISION_MODULE_ATTRS = ("visual", "vision_model", "vision_tower", "vision_encoder") + 
+ +def _is_windowed_rope_vision_module(module: torch.nn.Module) -> bool: + return ( + hasattr(module, "patch_embed") + and hasattr(module, "blocks") + and hasattr(module, "rot_pos_emb") + and hasattr(module, "get_window_index") + ) + + +def _is_merged_windowed_rope_vision_module(module: torch.nn.Module) -> bool: + return _is_windowed_rope_vision_module(module) and hasattr(module, "merger") + + +def _is_tiled_aspect_ratio_vision_module(module: torch.nn.Module) -> bool: + return ( + hasattr(module, "patch_embedding") + and hasattr(module, "global_transformer") + and hasattr(module, "pre_tile_positional_embedding") + ) + + +def _is_native_vit_vision_module(module: torch.nn.Module) -> bool: + has_patch_embedding = any( + hasattr(module, attr_name) + for attr_name in ("patch_embed", "patch_embedding", "embeddings") + ) + has_transformer = any( + hasattr(module, attr_name) + for attr_name in ("blocks", "encoder", "transformer", "global_transformer") + ) + return has_patch_embedding and has_transformer + + +def _is_compiled_vit_plugin_adapter(module: torch.nn.Module) -> bool: + return hasattr(module, "compiled_visual") and hasattr(module, "input_contract") + + +def _is_vision_module(module: torch.nn.Module) -> bool: + return ( + _is_windowed_rope_vision_module(module) + or _is_tiled_aspect_ratio_vision_module(module) + or _is_native_vit_vision_module(module) + or _is_compiled_vit_plugin_adapter(module) + ) + + +def _contains_vision_module(module: torch.nn.Module) -> bool: + for attr_name in _VISION_MODULE_ATTRS: + if isinstance(getattr(module, attr_name, None), torch.nn.Module): + return True + + for _, child in module.named_modules(): + if child is module: + continue + if _is_vision_module(child): + return True + return False + + +_LANGUAGE_MODULE_ATTRS = ( + "language_model", + "text_model", + "llm", + "decoder", + "model", +) + + +def get_language_model(model: torch.nn.Module) -> torch.nn.Module: + for attr_name in _LANGUAGE_MODULE_ATTRS: + candidate = getattr(model, 
attr_name, None) + if not isinstance(candidate, torch.nn.Module): + continue + if attr_name == "model" and _contains_vision_module(candidate): + continue + return candidate + + for module_name, module in model.named_modules(): + if module is model: + continue + if module_name.rsplit(".", 1)[-1] not in _LANGUAGE_MODULE_ATTRS: + continue + if _contains_vision_module(module): + continue + return module + + raise ValueError( + "Cannot find a language-model submodule. Expected a language/text " + "model leaf module that does not also contain the vision tower." + ) + + +def get_vision_model(model: torch.nn.Module) -> torch.nn.Module: + visual = getattr(model, "visual", None) + if isinstance(visual, torch.nn.Module): + return visual + + for parent_attr in ("model", "language_model"): + parent = getattr(model, parent_attr, None) + if not isinstance(parent, torch.nn.Module): + continue + visual = getattr(parent, "visual", None) + if isinstance(visual, torch.nn.Module): + return visual + + for _, module in model.named_modules(): + if module is model: + continue + if _is_merged_windowed_rope_vision_module(module): + return module + + for attr_name in _VISION_MODULE_ATTRS: + if attr_name == "visual": + continue + candidate = getattr(model, attr_name, None) + if isinstance(candidate, torch.nn.Module): + return candidate + + for parent_attr in ("model", "language_model"): + parent = getattr(model, parent_attr, None) + if not isinstance(parent, torch.nn.Module): + continue + for attr_name in _VISION_MODULE_ATTRS: + candidate = getattr(parent, attr_name, None) + if isinstance(candidate, torch.nn.Module): + return candidate + + for _, module in model.named_modules(): + if module is model: + continue + if _is_vision_module(module): + return module + + raise ValueError( + "Cannot find a vision-model submodule. Expected a Hugging Face vision " + "tower alias or a module with ViT-like patch embedding and transformer " + "blocks/encoder." 
+ ) + + +def extract_vision_tensor(vision_output) -> torch.Tensor: + """Normalize Hugging Face vision outputs to a tensor of image embeddings.""" + if isinstance(vision_output, torch.Tensor): + tensor = vision_output + elif hasattr(vision_output, "last_hidden_state"): + tensor = vision_output.last_hidden_state + elif hasattr(vision_output, "pooler_output"): + tensor = vision_output.pooler_output + elif isinstance(vision_output, (tuple, list)): + tensor = vision_output[0] + else: + raise TypeError( + "Vision model returned an unsupported output type: " + f"{type(vision_output).__name__}" + ) + + if tensor.dim() == 3 and tensor.shape[0] == 1: + tensor = tensor.squeeze(0) + return tensor + + +def _find_qwen_visual_with_merger(model: torch.nn.Module): + candidates = [] + for parent in (model, getattr(model, "model", None), getattr(model, "language_model", None)): + if isinstance(parent, torch.nn.Module): + candidates.append(getattr(parent, "visual", None)) + + try: + candidates.append(get_vision_model(model)) + except ValueError: + pass + + for candidate in candidates: + if isinstance(candidate, torch.nn.Module) and hasattr(candidate, "merger"): + return candidate + + for _, module in model.named_modules(): + if hasattr(module, "merger"): + return module + return None + + +def _maybe_merge_qwen_image_embeds( + model: torch.nn.Module, + image_embeds: torch.Tensor, + expected_tokens: int | None, +) -> torch.Tensor: + if expected_tokens is None or image_embeds.shape[0] == expected_tokens: + return image_embeds + if image_embeds.dim() != 2: + return image_embeds + + visual = _find_qwen_visual_with_merger(model) + if visual is None: + return image_embeds + + spatial_merge_unit = getattr(visual, "spatial_merge_unit", None) + if spatial_merge_unit is not None: + expected_raw_tokens = expected_tokens * int(spatial_merge_unit) + if image_embeds.shape[0] != expected_raw_tokens: + return image_embeds + + try: + merged = visual.merger(image_embeds) + except Exception: + return 
image_embeds + + return merged if isinstance(merged, torch.Tensor) else image_embeds + + +def get_qwen_image_embeds( + model: torch.nn.Module, + pixel_values: torch.Tensor, + image_grid_thw: torch.Tensor, + expected_tokens: int | None = None, +) -> torch.Tensor: + """ + Return Qwen image embeddings at the token level expected by the LM. + + Prefer the full Hugging Face VLM helper when available, since it owns the + exact visual merge/projector path. Fall back to the resolved visual module + for compiled adapters and older model implementations. + """ + get_image_features = getattr(model, "get_image_features", None) + if callable(get_image_features): + call_attempts = ( + lambda: get_image_features( + pixel_values=pixel_values, + image_grid_thw=image_grid_thw, + ), + lambda: get_image_features(pixel_values, image_grid_thw), + lambda: get_image_features(pixel_values), + ) + for call in call_attempts: + try: + image_embeds = extract_vision_tensor(call()) + return _maybe_merge_qwen_image_embeds( + model, image_embeds, expected_tokens + ) + except TypeError: + continue + + image_embeds = extract_vision_tensor( + get_vision_model(model)(pixel_values, image_grid_thw) + ) + return _maybe_merge_qwen_image_embeds(model, image_embeds, expected_tokens) + + +def _get_qwen_rope_owner(model: torch.nn.Module): + if hasattr(model, "get_rope_index"): + return model + parent = getattr(model, "model", None) + if hasattr(parent, "get_rope_index"): + return parent + return None + + +def _get_qwen_config_attr(model: torch.nn.Module, attr_name: str): + for owner in (model, getattr(model, "model", None)): + config = getattr(owner, "config", None) + value = getattr(config, attr_name, None) + if value is not None: + return value + return None + + +def get_qwen_mm_token_type_ids( + model: torch.nn.Module, + input_ids: torch.Tensor, +) -> torch.Tensor: + """ + Build Qwen multimodal token type ids from special image/video tokens. 
+ + Qwen uses 0 for text tokens, 1 for image tokens, and 2 for video tokens. + """ + mm_token_type_ids = torch.zeros_like(input_ids, dtype=torch.int32) + + image_token_id = _get_qwen_config_attr(model, "image_token_id") + if image_token_id is not None: + mm_token_type_ids = torch.where( + input_ids == int(image_token_id), + torch.ones_like(mm_token_type_ids), + mm_token_type_ids, + ) + + video_token_id = _get_qwen_config_attr(model, "video_token_id") + if video_token_id is not None: + mm_token_type_ids = torch.where( + input_ids == int(video_token_id), + torch.full_like(mm_token_type_ids, 2), + mm_token_type_ids, + ) + + return mm_token_type_ids + + +def get_qwen_position_ids( + model: torch.nn.Module, + input_ids: torch.Tensor, + image_grid_thw: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + video_grid_thw: torch.Tensor | None = None, + second_per_grid_ts: torch.Tensor | None = None, + mm_token_type_ids: torch.Tensor | None = None, +) -> torch.Tensor: + """ + Return Qwen2.5-VL multimodal RoPE position ids when the model exposes + get_rope_index, otherwise fall back to plain text position ids. 
+ """ + rope_owner = _get_qwen_rope_owner(model) + if rope_owner is None: + return torch.arange( + input_ids.shape[1], dtype=torch.long, device=input_ids.device + ).unsqueeze(0) + + kwargs = {"input_ids": input_ids} + if image_grid_thw is not None: + kwargs["image_grid_thw"] = image_grid_thw + if video_grid_thw is not None: + kwargs["video_grid_thw"] = video_grid_thw + if second_per_grid_ts is not None: + kwargs["second_per_grid_ts"] = second_per_grid_ts + if attention_mask is not None: + kwargs["attention_mask"] = attention_mask + if mm_token_type_ids is None: + mm_token_type_ids = get_qwen_mm_token_type_ids(model, input_ids) + kwargs["mm_token_type_ids"] = mm_token_type_ids + + try: + position_ids = rope_owner.get_rope_index(**kwargs) + except TypeError: + try: + position_ids = rope_owner.get_rope_index( + input_ids, + image_grid_thw, + video_grid_thw, + second_per_grid_ts, + attention_mask, + mm_token_type_ids, + ) + except TypeError: + position_ids = rope_owner.get_rope_index( + input_ids, + image_grid_thw, + video_grid_thw, + attention_mask, + mm_token_type_ids, + ) + + if isinstance(position_ids, tuple): + position_ids = position_ids[0] + return position_ids.to(device=input_ids.device, dtype=torch.long) + + +def export_llm(model, inputs, min_seq_len=1, max_seq_len=16, position_ids=None): """ Exports the LLM model into an ExportedProgram with dynamic shapes. In the case of guard failures due to some PyTorch kernel implements, we also @@ -19,7 +376,9 @@ def export_llm(model, inputs, min_seq_len=1, max_seq_len=16): with torch.no_grad(): # max=1024 has contraint violation error. 
https://github.com/pytorch/pytorch/issues/125604 seq_len = torch.export.Dim("seq_len", min=min_seq_len, max=max_seq_len) - position_ids = torch.arange(inputs.shape[1]).unsqueeze(0).to(inputs.device) + if position_ids is None: + position_ids = torch.arange(inputs.shape[1]).unsqueeze(0).to(inputs.device) + position_seq_dim = position_ids.dim() - 1 try: print("Trying to export the model using torch.export.export()..") # strict=False only enables aotautograd tracing and excludes dynamo. @@ -27,7 +386,7 @@ def export_llm(model, inputs, min_seq_len=1, max_seq_len=16): model, args=(inputs,), kwargs={"position_ids": position_ids}, - dynamic_shapes=({1: seq_len}, {1: seq_len}), + dynamic_shapes=({1: seq_len}, {position_seq_dim: seq_len}), strict=False, ) except: @@ -39,7 +398,7 @@ def export_llm(model, inputs, min_seq_len=1, max_seq_len=16): model, args=(inputs,), kwargs={"position_ids": position_ids}, - dynamic_shapes=({1: seq_len}, {1: seq_len}), + dynamic_shapes=({1: seq_len}, {position_seq_dim: seq_len}), strict=False, prefer_deferred_runtime_asserts_over_guards=True, ) @@ -587,6 +946,10 @@ def _prepare_qwen_mm_inputs( """ vision_time = 0.0 image_embeds = None + seq_tokens = input_ids.clone() + seq_embeds = emb_layer(seq_tokens) + image_mask = seq_tokens == model.config.image_token_id + num_image_tokens = image_mask.sum().item() if pixel_values is not None: if with_timing: @@ -594,24 +957,27 @@ def _prepare_qwen_mm_inputs( vision_end = torch.cuda.Event(enable_timing=True) vision_start.record() - image_embeds = model.visual(pixel_values, image_grid_thw) + image_embeds = get_qwen_image_embeds( + model, + pixel_values, + image_grid_thw, + expected_tokens=num_image_tokens, + ) if with_timing: vision_end.record() torch.cuda.synchronize() vision_time = vision_start.elapsed_time(vision_end) - seq_tokens = input_ids.clone() - seq_embeds = emb_layer(seq_tokens) - if image_embeds is not None: - mask = seq_tokens == model.config.image_token_id - num_image_tokens = 
mask.sum().item() if num_image_tokens != image_embeds.shape[0]: raise ValueError( - f"Number of image tokens ({num_image_tokens}) does not match number of image embeddings ({image_embeds.shape[0]})." + "Number of image tokens " + f"({num_image_tokens}) does not match number of image embeddings " + f"({image_embeds.shape[0]}). Image embedding shape: " + f"{tuple(image_embeds.shape)}." ) - mask_expanded = mask.unsqueeze(-1).expand_as(seq_embeds) + mask_expanded = image_mask.unsqueeze(-1).expand_as(seq_embeds) seq_embeds = seq_embeds.masked_scatter( mask_expanded, image_embeds.to(seq_embeds.dtype) ) @@ -627,6 +993,7 @@ def generate_mm_qwen2_5_vl( pixel_values: torch.Tensor | None, input_ids: torch.Tensor, image_grid_thw: torch.Tensor, + attention_mask: torch.Tensor | None, eos_token_id: int, emb_layer: torch.nn.Embedding, max_new_tokens: int = 64, @@ -635,6 +1002,8 @@ def generate_mm_qwen2_5_vl( """ Custom generation function for the Qwen2_5_VLForConditionalGeneration model, with optional timing. 
""" + language_model = get_language_model(model) + if with_timing: overall_start = torch.cuda.Event(enable_timing=True) overall_end = torch.cuda.Event(enable_timing=True) @@ -663,20 +1032,24 @@ def generate_mm_qwen2_5_vl( step_times = [] generated = 0 + seq_attention_mask = ( + attention_mask.clone() + if attention_mask is not None + else torch.ones_like(seq_tokens, dtype=torch.long) + ) while generated < max_new_tokens: if with_timing: lm_start.record() - position_ids = ( - torch.arange( - 0, seq_tokens.size(1), dtype=torch.long, device=seq_tokens.device - ) - .unsqueeze(0) - .expand(seq_embeds.size(0), seq_embeds.size(1)) + position_ids = get_qwen_position_ids( + model, + seq_tokens, + image_grid_thw=image_grid_thw, + attention_mask=seq_attention_mask, ) with torch.no_grad(): - outputs = model.model( + outputs = language_model( inputs_embeds=seq_embeds, position_ids=position_ids, ) @@ -695,6 +1068,18 @@ def generate_mm_qwen2_5_vl( step_times.append(lm_start.elapsed_time(lm_end)) seq_tokens = torch.cat([seq_tokens, next_tok[:, None]], dim=1) + seq_attention_mask = torch.cat( + [ + seq_attention_mask, + torch.ones( + seq_attention_mask.shape[0], + 1, + dtype=seq_attention_mask.dtype, + device=seq_attention_mask.device, + ), + ], + dim=1, + ) next_emb = emb_layer(next_tok)[:, None, :] seq_embeds = torch.cat([seq_embeds, next_emb], dim=1) @@ -722,6 +1107,7 @@ def generate_mm_qwen2_5_vl_with_static_cache( pixel_values: torch.Tensor | None, input_ids: torch.Tensor, image_grid_thw: torch.Tensor, + attention_mask: torch.Tensor | None, eos_token_id: int, emb_layer: torch.nn.Embedding, max_new_tokens: int = 64, @@ -731,6 +1117,8 @@ def generate_mm_qwen2_5_vl_with_static_cache( """ Greedy Decoder for Qwen-2.5-VL using static KV-cache, with optional timing. 
""" + language_model = get_language_model(model) + if with_timing: overall_start = torch.cuda.Event(enable_timing=True) overall_end = torch.cuda.Event(enable_timing=True) @@ -758,7 +1146,7 @@ def generate_mm_qwen2_5_vl_with_static_cache( ) kv_cache = get_zeroed_static_cache_inputs( - model.model, device=device, has_position_ids=True + language_model, device=device, has_position_ids=True ) start_idx = 0 end_idx = seq_embeds.size(1) @@ -766,6 +1154,11 @@ def generate_mm_qwen2_5_vl_with_static_cache( max_total_len = end_idx + max_new_tokens output_tokens = seq_tokens.clone() step_times = [] + seq_attention_mask = ( + attention_mask.clone() + if attention_mask is not None + else torch.ones_like(output_tokens, dtype=torch.long) + ) while output_tokens.size(1) < max_total_len: if with_timing: @@ -773,14 +1166,13 @@ def generate_mm_qwen2_5_vl_with_static_cache( cur_embeds = seq_embeds if generated == 0 else seq_embeds[:, -1:, :] - if generated == 0: - position_ids = ( - torch.arange(cur_embeds.shape[1]).unsqueeze(0).to(cur_embeds.device) - ) - else: - position_ids = torch.tensor([[start_idx]], dtype=torch.int64).to( - cur_embeds.device - ) + full_position_ids = get_qwen_position_ids( + model, + output_tokens, + image_grid_thw=image_grid_thw, + attention_mask=seq_attention_mask, + ) + position_ids = full_position_ids if generated == 0 else full_position_ids[..., -1:] input_signature = ( cur_embeds, @@ -790,7 +1182,7 @@ def generate_mm_qwen2_5_vl_with_static_cache( end_idx, ) - outputs_and_kv = model.model(*input_signature) + outputs_and_kv = language_model(*input_signature) hidden_states, kv_cache = outputs_and_kv[0], outputs_and_kv[1:] logits = model.lm_head(hidden_states[:, -1, :]) @@ -799,6 +1191,18 @@ def generate_mm_qwen2_5_vl_with_static_cache( next_embed = emb_layer(next_tok)[:, None, :] seq_embeds = next_embed + seq_attention_mask = torch.cat( + [ + seq_attention_mask, + torch.ones( + seq_attention_mask.shape[0], + 1, + dtype=seq_attention_mask.dtype, + 
device=seq_attention_mask.device, + ), + ], + dim=1, + ) generated += 1 start_idx = end_idx