diff --git a/examples/auto_deploy/cookbooks/glm_4.7_flash_trtllm_cookbook.ipynb b/examples/auto_deploy/cookbooks/glm_4.7_flash_trtllm_cookbook.ipynb new file mode 100644 index 0000000000..e4733b24be --- /dev/null +++ b/examples/auto_deploy/cookbooks/glm_4.7_flash_trtllm_cookbook.ipynb @@ -0,0 +1,348 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Deploying GLM-4.7-Flash with TensorRT-LLM\n", + "\n", + "This notebook walks you through deploying the `zai-org/GLM-4.7-Flash` model using TensorRT-LLM.\n", + "\n", + "[TensorRT-LLM](https://nvidia.github.io/TensorRT-LLM/) is NVIDIA's open-source library for accelerating and optimizing LLM inference on NVIDIA GPUs. Support for GLM-4.7-Flash is enabled through the AutoDeploy workflow. More details about AutoDeploy can be found [here](https://nvidia.github.io/TensorRT-LLM/torch/auto_deploy/auto-deploy.html).\n", + "\n", + "**Model Resources:**\n", + "- [HuggingFace Model Card](https://huggingface.co/zai-org/GLM-4.7-Flash)\n", + "- [Technical Blog](https://z.ai/blog/glm-4.7)\n", + "- [Technical Report (GLM-4.5)](https://arxiv.org/abs/2508.06471)\n", + "- [Z.ai API Platform](https://docs.z.ai/guides/llm/glm-4.7)\n", + "\n", + "**Model Highlights:**\n", + "- 30B-A3B Mixture of Experts (MoE) architecture\n", + "- 131,072 token context length\n", + "- Tool calling support\n", + "- MIT License\n", + "\n", + "**Prerequisites:**\n", + "- NVIDIA GPU with recent drivers (≥ 64 GB VRAM for BF16) and CUDA 12.x\n", + "- Python 3.10+\n", + "- TensorRT-LLM ([container](https://catalog.ngc.nvidia.com/orgs/nvidia/teams/tensorrt-llm/containers/release) or pip install)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Prerequisites & Environment\n", + "\n", + "Set up a containerized environment for TensorRT-LLM by running the following command in a terminal:\n", + "\n", + "```shell\n", + "docker run --rm -it --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 --gpus=all -p 8000:8000 nvcr.io/nvidia/tensorrt-llm/release:1.3.0rc1\n", + "```\n", + "\n", + "You now have TensorRT-LLM set up!" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# If pip not found\n", + "!python -m ensurepip --default-pip" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%pip install torch openai" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Verify GPU\n", + "\n", + "Check that CUDA is available and the GPU is detected correctly." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Python: 3.12.3 (main, Nov 6 2025, 13:44:16) [GCC 13.3.0]\n", + "CUDA available: True\n", + "Num GPUs: 8\n", + "GPU[0]: NVIDIA H100 80GB HBM3\n", + "GPU[1]: NVIDIA H100 80GB HBM3\n", + "GPU[2]: NVIDIA H100 80GB HBM3\n", + "GPU[3]: NVIDIA H100 80GB HBM3\n", + "GPU[4]: NVIDIA H100 80GB HBM3\n", + "GPU[5]: NVIDIA H100 80GB HBM3\n", + "GPU[6]: NVIDIA H100 80GB HBM3\n", + "GPU[7]: NVIDIA H100 80GB HBM3\n" + ] + } + ], + "source": [ + "# Environment check\n", + "import sys\n", + "\n", + "import torch\n", + "\n", + "print(f\"Python: {sys.version}\")\n", + "print(f\"CUDA available: {torch.cuda.is_available()}\")\n", + "print(f\"Num GPUs: {torch.cuda.device_count()}\")\n", + "\n", + "if torch.cuda.is_available():\n", + " for i in range(torch.cuda.device_count()):\n", + " print(f\"GPU[{i}]: {torch.cuda.get_device_name(i)}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## OpenAI-Compatible Server\n", + "\n", + "Start a local OpenAI-compatible server with TensorRT-LLM via the terminal, within the running docker container.\n", + "\n", + "Ensure that the following commands are executed from the docker terminal.\n", + "\n", + "Start with the GLM 4.7 Flash Yaml here: `examples/auto_deploy/model_registry/configs/glm-4.7-flash.yaml`" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Load the Model\n", + "\n", + "Launch the TensorRT-LLM server with GLM-4.7-Flash:\n", + "\n", + "```shell\n", + "trtllm-serve \"zai-org/GLM-4.7-Flash\" \\\n", + " --host 0.0.0.0 \\\n", + " --port 8000 \\\n", + " --backend _autodeploy \\\n", + " --trust_remote_code \\\n", + " --extra_llm_api_options examples/auto_deploy/model_registry/configs/glm-4.7-flash.yaml\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Your server is now running!" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Use the API\n", + "\n", + "Use the OpenAI-compatible client to send requests to the TensorRT-LLM server." + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "from openai import OpenAI\n", + "\n", + "# Setup client\n", + "BASE_URL = \"http://0.0.0.0:8000/v1\"\n", + "API_KEY = \"null\"\n", + "client = OpenAI(base_url=BASE_URL, api_key=API_KEY)\n", + "\n", + "MODEL_ID = \"zai-org/GLM-4.7-Flash\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Chat Completion Example\n", + "==================================================\n", + "Response:\n", + "1. **Analyze the Request:** The user wants to know 15% of 85 and wants to see the reasoning behind the calculation.\n", + "\n", + "2. **Identify the Core Task:** Calculate $15\\% \\times 85$.\n", + "\n", + "3. **Determine the Mathematical Approach:** There are several ways to solve this:\n", + " * *Method 1: Fraction multiplication.* Convert 15% to a fraction ($\\frac{15}{100}$), then multiply by 85.\n", + " * *Method 2: Decimal multiplication.* Convert 15% to a decimal ($0.15$), then multiply by 85.\n", + " * *Method 3: Decomposition (Breaking it down).* $15\\% = 10\\% + 5\\%$.\n", + " * $10\\%$ of $85 = 8.5$\n", + " * $5\\%$ of $85 = \\frac{8.5}{2} = 4.25$\n", + " * Sum: $8.5 + 4.25 = 12.75$\n", + "\n", + "4. 
**Select the Best Approach for Explanation:** Method 3 is often easiest for a general audience to follow step-by-step because it avoids dealing with decimals until the end or simplifies large multiplications. Method 2 is the most direct standard school method. I will use Method 3 (Splitting 15% into 10% and 5%) as the primary reasoning because it is intuitive, but I might briefly mention the standard formula ($\\frac{\\text{percent}}{100} \\times \\text{number}$).\n", + "\n", + "5. **Execute the Calculation (Method 3):**\n", + " * Step 1: Find 10% of 85.\n", + " * Moving the decimal point one place to the left: 8.5.\n", + " * Step 2: Find 5% of 85.\n", + " * Since 5% is half of 10%, take half of 8.5.\n", + " * $8.5 / 2 = 4.25$.\n", + " * Step 3: Add them together.\n", + " * $8.5 + 4.25$.\n", + " * $8.5 + 4 = 12.5$.\n", + " * $12.5 + 0\n" + ] + } + ], + "source": [ + "# Basic chat completion\n", + "print(\"Chat Completion Example\")\n", + "print(\"=\" * 50)\n", + "\n", + "response = client.chat.completions.create(\n", + " model=MODEL_ID,\n", + " messages=[\n", + " {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n", + " {\"role\": \"user\", \"content\": \"What is 15% of 85? Show your reasoning.\"},\n", + " ],\n", + " temperature=1,\n", + " top_p=0.95,\n", + " max_tokens=512,\n", + ")\n", + "\n", + "print(\"Response:\")\n", + "print(response.choices[0].message.content)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Streaming response:\n", + "==================================================\n", + "1. **Analyze the Request:** The user is asking for the \"first 5 prime numbers\".\n", + "\n", + "2. **Define \"Prime Number\":** A prime number is a natural number greater than 1 that is not a product of two smaller natural numbers. In other words, it has exactly two distinct positive divisors: 1 and itself.\n", + "\n", + "3. **Identify the First Numbers:**\n", + " * Start checking from 1 (exclusive).\n", + " * Check 2: Divisors are 1 and 2. Prime. (1st)\n", + " * Check 3: Divisors are 1 and 3. Prime. (2nd)\n", + " * Check 4: Divisors are 1, 2, 4. Not prime (2 * 2).\n", + " * Check 5: Divisors are 1 and 5. Prime. (3rd)\n", + " * Check 6: Divisors are 1, 2, 3, 6. Not prime.\n", + " * Check 7: Divisors are 1 and 7. Prime. (4th)\n", + " * Check 8: Divisors are 1, 2, 4, 8. Not prime.\n", + " * Check 9: Divisors are 1, 3, 9. Not prime.\n", + " * Check 10: Divisors are 1, 2, 5, 10. Not prime.\n", + " * Check 11: Divisors are 1 and 11. Prime. (5th)\n", + "\n", + "4. **Compile the List:** 2, 3, 5, 7, 11.\n", + "\n", + "5. **Formulate the Output:** Present the list clearly.\n", + "\n", + "6. **Final Review:** Does this answer the user's prompt accurately? Yes.The first 5 prime numbers are **2, 3, 5, 7, and 11**." 
+ ] + } + ], + "source": [ + "# Streaming chat completion\n", + "print(\"Streaming response:\")\n", + "print(\"=\" * 50)\n", + "\n", + "stream = client.chat.completions.create(\n", + " model=MODEL_ID,\n", + " messages=[\n", + " {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n", + " {\"role\": \"user\", \"content\": \"What are the first 5 prime numbers?\"},\n", + " ],\n", + " temperature=0.7,\n", + " max_tokens=1024,\n", + " stream=True,\n", + ")\n", + "\n", + "for chunk in stream:\n", + " if chunk.choices[0].delta.content:\n", + " print(chunk.choices[0].delta.content, end=\"\", flush=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Evaluation Parameters\n", + "\n", + "For optimal results, use the following parameters based on your task:\n", + "\n", + "**Default Settings (Most Tasks)**\n", + "- `temperature`: 1.0\n", + "- `top_p`: 0.95\n", + "- `max_tokens`: 131072\n", + "\n", + "**Agentic Tasks (SWE-bench, Terminal Bench)**\n", + "- `temperature`: 0.7\n", + "- `top_p`: 1.0\n", + "- `max_tokens`: 16384\n", + "\n", + "**Deterministic Tasks**\n", + "- `temperature`: 0\n", + "- `max_tokens`: 16384" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Additional Resources\n", + "\n", + "- [TensorRT-LLM Documentation](https://nvidia.github.io/TensorRT-LLM/)\n", + "- [AutoDeploy Guide](https://nvidia.github.io/TensorRT-LLM/torch/auto_deploy/auto-deploy.html)\n", + "- [GLM-4.7-Flash on HuggingFace](https://huggingface.co/zai-org/GLM-4.7-Flash)\n", + "- [Z.ai Discord Community](https://discord.gg/QR7SARHRxK)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.3" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/examples/auto_deploy/model_registry/configs/glm-4.7-flash.yaml b/examples/auto_deploy/model_registry/configs/glm-4.7-flash.yaml new file mode 100644 index 0000000000..9d8e16234d --- /dev/null +++ b/examples/auto_deploy/model_registry/configs/glm-4.7-flash.yaml @@ -0,0 +1,8 @@ +compile_backend: torch-cudagraph +max_batch_size: 64 +max_seq_len: 4096 +enable_chunked_prefill: true +cuda_graph_batch_sizes: [1, 2, 4, 8, 16, 32, 64] +transforms: + fuse_nvfp4_moe: + allow_different_input_scales: true diff --git a/examples/auto_deploy/model_registry/models.yaml b/examples/auto_deploy/model_registry/models.yaml index fb268d7b39..e1dd4be29f 100644 --- a/examples/auto_deploy/model_registry/models.yaml +++ b/examples/auto_deploy/model_registry/models.yaml @@ -223,3 +223,5 @@ models: yaml_extra: ['dashboard_default.yaml', 'world_size_8.yaml', 'multimodal.yaml', 'llama4_maverick_lite.yaml'] - name: nvidia/NVIDIA-Nemotron-3-Super-120B-BF16-BF16KV-010726 yaml_extra: ['dashboard_default.yaml', 'world_size_4.yaml','super_v3.yaml'] +- name: zai-org/GLM-4.7-Flash + yaml_extra: ['glm-4.7-flash.yaml'] diff --git a/tensorrt_llm/_torch/auto_deploy/config/default.yaml b/tensorrt_llm/_torch/auto_deploy/config/default.yaml index 5957bd4409..84aca711e3 100644 --- a/tensorrt_llm/_torch/auto_deploy/config/default.yaml +++ b/tensorrt_llm/_torch/auto_deploy/config/default.yaml @@ -162,15 +162,16 @@ transforms: visualize_namespace: stage: visualize enabled: false - 
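Referring back to the notebook's Evaluation Parameters section above: those per-task settings map directly onto OpenAI-compatible request fields. Below is a minimal sketch (the `SAMPLING_PRESETS` dict and `ask` helper are illustrative, not part of this PR; the values are copied from the notebook, and `max_tokens` may need to be lowered to fit the server-side `max_seq_len` of 4096 configured in glm-4.7-flash.yaml):

```python
from openai import OpenAI

client = OpenAI(base_url="http://0.0.0.0:8000/v1", api_key="null")

# Presets copied from the notebook's "Evaluation Parameters" section.
# Note: max_tokens may need to be reduced to fit the server's configured
# max_seq_len (4096 in glm-4.7-flash.yaml).
SAMPLING_PRESETS = {
    "default": {"temperature": 1.0, "top_p": 0.95, "max_tokens": 131072},
    "agentic": {"temperature": 0.7, "top_p": 1.0, "max_tokens": 16384},
    "deterministic": {"temperature": 0.0, "max_tokens": 16384},
}


def ask(prompt: str, task: str = "default") -> str:
    """Send a chat request to the local trtllm-serve endpoint using a named preset."""
    response = client.chat.completions.create(
        model="zai-org/GLM-4.7-Flash",
        messages=[{"role": "user", "content": prompt}],
        **SAMPLING_PRESETS[task],
    )
    return response.choices[0].message.content


print(ask("What are the first 5 prime numbers?", task="deterministic"))
```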
############################################################################################ + ########################################################################################### # SWITCH TO CACHED+FLATTENED ATTENTION + INITIALIZE CACHES - ############################################################################################ + ########################################################################################### insert_cached_attention: stage: cache_init backend: flashinfer insert_cached_mla_attention: stage: cache_init - backend: MultiHeadLatentAttention + requires_shape_prop: true + backend: flashinfer_mla insert_cached_ssm_attention: stage: cache_init backend: triton_ssm diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/README.md b/tensorrt_llm/_torch/auto_deploy/custom_ops/README.md index 27f18bed1b..addc6cc222 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/README.md +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/README.md @@ -5,38 +5,3 @@ All AutoDeploy custom operators follow the following naming convention: `torch.ops.auto_deploy.__` The table below lists the operators ordered by their backend. - -### Available Custom Operators - -| Operator Name | Description | -|--------------|-------------| -| `torch.ops.auto_deploy.flashinfer_attention_mha_with_cache` | FlashInfer attention with KV cache support | -| `torch.ops.auto_deploy.flashinfer_rope` | FlashInfer RoPE (Rotary Position Embedding) implementation | -| `torch.ops.auto_deploy.torch_attention_deepseek_fused_mla` | DeepSeek fused MLA (Multi-head Linear Attention) | -| `torch.ops.auto_deploy.torch_attention_deepseek_mla` | DeepSeek MLA implementation | -| `torch.ops.auto_deploy.torch_attention` | Grouped SDPA implementation with `bsnd` and `bnsd` layout supported | -| `torch.ops.auto_deploy.torch_attention_repeat_kv` | KV repetition for attention | -| `torch.ops.auto_deploy.torch_attention_sdpa` | Standard SDPA implementation | -| `torch.ops.auto_deploy.torch_dist_all_gather` | Distributed all-gather operation (PyTorch backend, demollm mode) | -| `torch.ops.auto_deploy.torch_dist_all_reduce` | Distributed all-reduce operation (PyTorch backend, demollm mode) | -| `torch.ops.auto_deploy.torch_linear_simple` | Simple linear layer implementation | -| `torch.ops.auto_deploy.torch_moe` | Mixture of Experts implementation | -| `torch.ops.auto_deploy.torch_moe_fused` | Fused Mixture of Experts implementation | -| `torch.ops.auto_deploy.torch_quant_fn` | Generic quantization function that scales, rounds, and clamps input values | -| `torch.ops.auto_deploy.torch_quant_nvfp4_linear` | FP4 quantized linear layer | -| `torch.ops.auto_deploy.torch_quant_fp8_linear` | FP8 quantized linear layer | -| `torch.ops.auto_deploy.torch_rope_with_complex_freqs` | RoPE with complex frequencies | -| `torch.ops.auto_deploy.torch_rope_with_explicit_cos_sin` | RoPE with explicit cosine/sine | -| `torch.ops.auto_deploy.torch_rope_with_qk_interleaving` | RoPE with QK interleaving | -| `torch.ops.auto_deploy.triton_attention_fused_flattened_mha_with_cache` | Triton fused flattened MHA with cache | -| `torch.ops.auto_deploy.triton_attention_fused_flattened_mha_with_cache_rope_fusion` | Triton fused flattened MHA with cache and RoPE fusion | -| `torch.ops.auto_deploy.triton_attention_fused_mha_with_cache` | Triton fused MHA with cache | -| `torch.ops.auto_deploy.triton_attention_fused_mha_with_paged_cache` | Triton fused MHA with paged cache | -| `torch.ops.auto_deploy.triton_attention_flattened_mha_with_cache` | Triton 
flattened MHA with cache | -| `torch.ops.auto_deploy.triton_attention_fused_flattened_mla_with_cache` | Triton fused flattened Multi-head Latent Attention with cache support | -| `torch.ops.auto_deploy.triton_rope_on_flattened_inputs` | Triton RoPE on flattened inputs | -| `torch.ops.auto_deploy.triton_rope_with_input_pos` | Triton RoPE with input positions | -| `torch.ops.auto_deploy.trtllm_moe_fused` | TensorRT LLM fused MoE implementation | -| `torch.ops.auto_deploy.trtllm_dist_all_gather` | Distributed all-gather operation (TRT-LLM backend, MPI mode) | -| `torch.ops.auto_deploy.trtllm_dist_all_reduce` | Distributed all-reduce operation (TRT-LLM backend, MPI mode) | -| `torch.ops.dist.trtllm_fused_allreduce_residual_rmsnorm` | Fused all-reduce + residual add + RMSNorm (TRT-LLM backend, MPI mode) | diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/attention/torch_attention.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention/torch_attention.py index da76b1e52e..8d0d819300 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/attention/torch_attention.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention/torch_attention.py @@ -19,7 +19,6 @@ import math from typing import Optional import torch -import torch.nn as nn import torch.nn.functional as F @@ -242,299 +241,3 @@ def torch_attention_fake( layout: str = "bnsd", ): return query.new_empty(*query.shape[:-1], value.shape[-1]).contiguous() - - -def update_kv_cache( - key_states: torch.Tensor, - value_states: torch.Tensor, - k_cache: torch.Tensor, - v_cache: torch.Tensor, - seq_len: torch.Tensor, # metadata - input_pos: torch.Tensor, # metadata - slot_idx: torch.Tensor, - seq_start: torch.Tensor, -) -> torch.Tensor: - """ - Reference implementation for update kv cache function. Assumes KV cache layout to be [B,S,N,D]. - This function can be used to build reference attention implementations that use KV cache. - """ - - for idx in range(seq_len.shape[0]): - k_cache[slot_idx[idx], input_pos[idx] : input_pos[idx] + seq_len[idx], :, :] = key_states[ - seq_start[idx] : seq_start[idx] + seq_len[idx], ... - ] - v_cache[slot_idx[idx], input_pos[idx] : input_pos[idx] + seq_len[idx], :, :] = value_states[ - seq_start[idx] : seq_start[idx] + seq_len[idx], ... - ] - - -@torch.library.custom_op("auto_deploy::torch_attention_fused_mla_ref", mutates_args=()) -def fused_mla_ref( - q_nope: torch.Tensor, - q_pe: torch.Tensor, - kv: torch.Tensor, - k_pe: torch.Tensor, - seq_len: torch.Tensor, # metadata - input_pos: torch.Tensor, # metadata - cache_loc: torch.Tensor, - seq_start: torch.Tensor, - k_cache: torch.Tensor, # caches - v_cache: torch.Tensor, # caches - freqs_cis: torch.Tensor, - softmax_scale: Optional[float] = None, -) -> torch.Tensor: - """ - Reference implementation for Fused MLA with KV cache support. - This implementation flattens the inputs and can be used as a reference to debug the triton kernels. 
- """ - # Compute parameters - bs, num_heads, q_len, qk_nope_head_dim = q_nope.shape - qk_rope_head_dim = q_pe.shape[-1] - v_head_dim = kv.shape[-1] - qk_nope_head_dim - q_head_dim = qk_nope_head_dim + qk_rope_head_dim - - # Flatten inputs - bs_view = (bs * q_len,) - - k_nope, value_states = torch.split(kv, [qk_nope_head_dim, v_head_dim], dim=-1) - - q_nope = q_nope.transpose(1, 2).view(*bs_view, num_heads, qk_nope_head_dim).contiguous() - q_pe = q_pe.transpose(1, 2).clone().view(*bs_view, num_heads, qk_rope_head_dim).contiguous() - k_nope = k_nope.clone().transpose(1, 2).view(*bs_view, num_heads, qk_nope_head_dim).contiguous() - k_pe = k_pe.clone().transpose(1, 2).view(*bs_view, -1, qk_rope_head_dim).contiguous() - value_states = value_states.transpose(1, 2).view(*bs_view, -1, v_head_dim).contiguous() - - if freqs_cis is not None: - cos_base = freqs_cis[0, ...] - sin_base = freqs_cis[1, ...] - for i in range(seq_len.shape[0]): - start = seq_start[i] - length = seq_len[i] - if q_len == 1: - idx = (input_pos[i] + length - 1).item() - pos_ids = torch.tensor(idx, device=cos_base.device) - else: - pos_ids = torch.arange(input_pos[i], input_pos[i] + length, device=cos_base.device) - - cos = cos_base[pos_ids] # [..., 1, head_dim] - sin = sin_base[pos_ids] - q_slice = q_pe[start : start + length] - k_slice = k_pe[start : start + length] - - q_rot, k_rot = torch.ops.auto_deploy.torch_rope_with_qk_interleaving( - q_slice, - k_slice, - cos, - sin, - -2, - ) - - q_pe[start : start + length] = q_rot - k_pe[start : start + length] = k_rot - - query_states = k_pe.new_empty(*bs_view, num_heads, q_head_dim) # [b*s,n,d] - query_states[..., :qk_nope_head_dim] = q_nope - query_states[..., qk_nope_head_dim:] = q_pe - - key_states = k_pe.new_empty(*bs_view, num_heads, q_head_dim) - key_states[..., :qk_nope_head_dim] = k_nope - key_states[..., qk_nope_head_dim:] = k_pe - - # Update KV cache - update_kv_cache( - key_states, value_states, k_cache, v_cache, seq_len, input_pos, cache_loc, seq_start - ) - - # Compute attention - attn_outputs = [] - for idx in range(seq_len.shape[0]): - # Get inputs from KV cache - k = k_cache[cache_loc[idx], : input_pos[idx] + seq_len[idx], :, :] # [kv_seq_len, n, d] - v = v_cache[cache_loc[idx], : input_pos[idx] + seq_len[idx], :, :] # [kv_seq_len, n, d] - # Generate attention mask - if q_len == 1: - # Generate phase - single token attention mask - attn_mask = torch.zeros( - 1, input_pos[idx] + 1, device=query_states.device, dtype=query_states.dtype - ) - else: - # Context phase - causal attention mask - temp_mask = torch.ones( - seq_len[idx], - input_pos[idx] + seq_len[idx], - dtype=torch.bool, - device=query_states.device, - ).tril(diagonal=0) - attn_bias = torch.zeros( - seq_len[idx], - input_pos[idx] + seq_len[idx], - device=query_states.device, - dtype=query_states.dtype, - ) - attn_mask = attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")).to( - query_states.device - ) - - # Compute attention weights - attn_weights = ( - torch.matmul( - query_states[seq_start[idx] : seq_start[idx] + seq_len[idx], :, :].transpose(0, 1), - k.transpose(0, 1).transpose(1, 2), - ) - * 1 - / math.sqrt(query_states.size(-1)) - ) - attn_weights = attn_weights + attn_mask - # upcast attention to fp32 - attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to( - query_states.dtype - ) - attn_weights = torch.nn.functional.dropout(attn_weights, p=0.0, training=False) - attn_output = torch.matmul(attn_weights, v.transpose(0, 1)) - attn_outputs.append(attn_output) 
- - if q_len == 1: - attn_output = torch.stack(attn_outputs) - else: - attn_output = torch.cat(attn_outputs, dim=-2).unsqueeze(0) - - return attn_output - - -@fused_mla_ref.register_fake -def fused_mla_ref_fake( - q_nope: torch.Tensor, - q_pe: torch.Tensor, - kv: torch.Tensor, - k_pe: torch.Tensor, - seq_len: torch.Tensor, # metadata - input_pos: torch.Tensor, # metadata - cache_loc: torch.Tensor, - seq_start: torch.Tensor, - k_cache: torch.Tensor, # caches - v_cache: torch.Tensor, # caches - freqs_cis: torch.Tensor, - softmax_scale: Optional[float] = None, -): - """Fake Fused MLA+Rope with KV cache support.""" - v_head_dim = kv.shape[-1] - q_nope.shape[-1] - return torch.empty_like(kv[..., -v_head_dim:]) - - -@torch.library.custom_op("auto_deploy::torch_attention_deepseek_fused_mla", mutates_args=()) -def fused_mla( - q_nope: torch.Tensor, - q_pe: torch.Tensor, - kv: torch.Tensor, - k_pe: torch.Tensor, - cos: torch.Tensor, - sin: torch.Tensor, - position_ids: torch.Tensor, - attention_mask: torch.Tensor, - softmax_scale: float, -) -> torch.Tensor: - """MultiHeadLatentAttention as implemented in DeepSeekV3Attention. This does not capture KV cache use/update.""" - # Did not implement KV cache logic, since we would be writing our own custom op and inserting KV caches later - # Compute parameters - bs, num_heads, q_len, qk_nope_head_dim = q_nope.shape - qk_rope_head_dim = q_pe.shape[-1] - v_head_dim = kv.shape[-1] - qk_nope_head_dim - q_head_dim = qk_nope_head_dim + qk_rope_head_dim - - k_nope, value_states = torch.split(kv, [qk_nope_head_dim, v_head_dim], dim=-1) - kv_seq_len = value_states.shape[-2] - - cos = cos[position_ids] - sin = sin[position_ids] - q_pe, k_pe = torch.ops.auto_deploy.torch_rope_with_qk_interleaving(q_pe, k_pe, cos, sin) - - query_states = k_pe.new_empty(bs, num_heads, q_len, q_head_dim) - query_states[:, :, :, :qk_nope_head_dim] = q_nope - query_states[:, :, :, qk_nope_head_dim:] = q_pe - - key_states = k_pe.new_empty(bs, num_heads, q_len, q_head_dim) - key_states[:, :, :, :qk_nope_head_dim] = k_nope - key_states[:, :, :, qk_nope_head_dim:] = k_pe - - # Use old logic - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * softmax_scale - - if attn_weights.size() != (bs, num_heads, q_len, kv_seq_len): - raise ValueError( - f"Attention weights should be of size {(bs, num_heads, q_len, kv_seq_len)}, but is" - f" {attn_weights.size()}" - ) - assert attention_mask is not None - if attention_mask is not None: - if attention_mask.size() != (bs, 1, q_len, kv_seq_len): - raise ValueError( - f"Attention mask should be of size {(bs, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" - ) - attn_weights = attn_weights + attention_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to( - query_states.dtype - ) - attn_weights = nn.functional.dropout(attn_weights, p=0.0, training=False) - attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (bs, num_heads, q_len, v_head_dim): - raise ValueError( - f"`attn_output` should be of size {(bs, num_heads, q_len, v_head_dim)}, but is" - f" {attn_output.size()}" - ) - - # We do not return attn_weights along with attn_output - return attn_output - - -@fused_mla.register_fake -def fused_mla( - q_nope: torch.Tensor, - q_pe: torch.Tensor, - kv: torch.Tensor, - k_pe: torch.Tensor, - cos: torch.Tensor, - sin: torch.Tensor, - position_ids: torch.Tensor, - attention_mask: torch.Tensor, - softmax_scale: float, -) -> torch.Tensor: - 
v_head_dim = kv.shape[-1] - q_nope.shape[-1] - return torch.empty_like(kv[..., -v_head_dim:]) - - -@torch.library.custom_op("auto_deploy::torch_attention_deepseek_mla", mutates_args=()) -def mla( - q_nope: torch.Tensor, # Down projected q_nope - q_pe: torch.Tensor, # q_pe after applying rope - kv: torch.Tensor, # compressed kv after passing through layernorm - pe: torch.Tensor, # k_pe after applying rope - attention_mask: torch.Tensor, # attention mask - softmax_scale: float, # softmax scale -) -> torch.Tensor: - """ - Reference implementation for MLA style attention that handles compressed kv. - """ - scores = ( - torch.einsum("bhsc,btc->bsht", q_nope, kv) + torch.einsum("bhsr,btr->bsht", q_pe, pe) - ) * softmax_scale - if attention_mask is not None: - scores += attention_mask.unsqueeze(1) - scores = scores.softmax(dim=-1, dtype=torch.float32).type_as(q_nope) - attn_output = torch.einsum("bsht,btc->bshc", scores, kv) - return attn_output - - -@mla.register_fake -def mla( - q_nope: torch.Tensor, # Down projected q_nope - q_pe: torch.Tensor, # q_pe after applying rope - kv: torch.Tensor, # compressed kv after passing through layernorm - k_pe: torch.Tensor, # k_pe after applying rope - attention_mask: torch.Tensor, # attention mask - softmax_scale: float, # softmax scale -) -> torch.Tensor: - """MLA style attention that handles compressed kv.""" - return torch.empty_like(q_nope) diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/attention/torch_backend_attention.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention/torch_backend_attention.py index a8f68574c5..36c8d54d4e 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/attention/torch_backend_attention.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention/torch_backend_attention.py @@ -35,7 +35,31 @@ from ..attention_interface import ( ResourceHandlerDict, UnpagedResourceHandler, ) -from .torch_attention import repeat_kv, update_kv_cache +from .torch_attention import repeat_kv + + +def _update_kv_cache( + key_states: torch.Tensor, + value_states: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + seq_len: torch.Tensor, # metadata + input_pos: torch.Tensor, # metadata + slot_idx: torch.Tensor, + seq_start: torch.Tensor, +) -> torch.Tensor: + """ + Reference implementation for update kv cache function. Assumes KV cache layout to be [B,S,N,D]. + This function can be used to build reference attention implementations that use KV cache. + """ + + for idx in range(seq_len.shape[0]): + k_cache[slot_idx[idx], input_pos[idx] : input_pos[idx] + seq_len[idx], :, :] = key_states[ + seq_start[idx] : seq_start[idx] + seq_len[idx], ... + ] + v_cache[slot_idx[idx], input_pos[idx] : input_pos[idx] + seq_len[idx], :, :] = value_states[ + seq_start[idx] : seq_start[idx] + seq_len[idx], ... 
+ ] def _apply_logit_softcapping(attn_scores: torch.Tensor, logit_cap: Optional[float]) -> torch.Tensor: @@ -149,7 +173,7 @@ def _torch_context_mha( ) -> None: """Context attention (multiple tokens, potentially multiple sequences) using existing torch functions.""" # Update KV cache first using existing function - update_kv_cache(k, v, k_cache, v_cache, seq_len, input_pos, slot_idx, seq_start) + _update_kv_cache(k, v, k_cache, v_cache, seq_len, input_pos, slot_idx, seq_start) # Compute attention for each sequence attn_outputs = [] diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/mla/__init__.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/mla/__init__.py index b2c4737b67..261617de1b 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/mla/__init__.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/mla/__init__.py @@ -1,24 +1,21 @@ -# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +"""MLA (Multi-head Latent Attention) custom ops. -"""Multi-head Latent Attention operations. - -This module provides Multi-head Latent Attention (MLA) implementations: -- mla: MLA operations and attention descriptor +Exports: +- TorchBackendMLAAttention: Attention descriptor for MLA (registered as "torch_mla") +- FlashInferMLAAttention: Attention descriptor for FlashInfer MLA (registered as "flashinfer_mla") +- torch_mla: Source op for MLA attention +- torch_backend_mla_with_cache: Cached backend op with FlashInfer-compatible cache +- flashinfer_mla_with_cache: Cached backend op using FlashInfer MLA kernels """ +from .flashinfer_mla import FlashInferMLAAttention, flashinfer_mla_with_cache +from .torch_backend_mla import TorchBackendMLAAttention, torch_backend_mla_with_cache +from .torch_mla import torch_mla + __all__ = [ - "mla", + "TorchBackendMLAAttention", + "FlashInferMLAAttention", + "torch_mla", + "torch_backend_mla_with_cache", + "flashinfer_mla_with_cache", ] diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/mla/flashinfer_mla.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/mla/flashinfer_mla.py new file mode 100644 index 0000000000..0171713aae --- /dev/null +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/mla/flashinfer_mla.py @@ -0,0 +1,957 @@ +"""FlashInfer-based MLA (Multi-head Latent Attention) backend with paged caching. 
+ +This module provides: +- FlashInferMLAAttention: attention descriptor using FlashInfer MLA kernels +- flashinfer_mla_with_cache: cached backend op with paged KV cache + +FlashInfer MLA uses: +- Regular prefill (input_pos == 0): BatchPrefillWithRaggedKVCacheWrapper with expanded K, V +- Chunked prefill (input_pos > 0): BatchMLAPagedAttentionWrapper with matrix absorption +- Decode: BatchMLAPagedAttentionWrapper with paged compressed KV cache + +FlashInfer MLA Cache Layout (two separate caches): + ckv_cache: [num_pages, page_size, kv_lora_rank] + kpe_cache: [num_pages, page_size, qk_rope_head_dim] + - No num_heads dimension (MLA-specific optimization) + +Reference: https://docs.flashinfer.ai/api/mla.html +""" + +import math +from dataclasses import dataclass, fields +from math import prod +from typing import Dict, List, Literal, Optional, Tuple + +import flashinfer +import torch +from torch._ops import OpOverloadPacket +from torch._subclasses import FakeTensor +from torch.fx import Node + +from .....llmapi.llm_args import KvCacheConfig +from ...utils.cuda_graph import cuda_graph_state +from ..attention_interface import ( + AttentionDescriptor, + AttentionLayout, + AttentionRegistry, + Constant, + MHACallable, + PrepareMetadataCallable, + PrepareMetadataHostCallable, + ResourceHandler, + ResourceHandlerDict, + SequenceInfo, +) + + +@dataclass +class MLADecodePlanParams: + """Parameters that affect the FlashInfer MLA decode execution plan.""" + + num_heads: int + kv_lora_rank: int # head_dim_ckv + qk_rope_head_dim: int # head_dim_kpe + qk_nope_head_dim: int + v_head_dim: int + num_seq: int + page_size: int + q_dtype: torch.dtype + kv_dtype: torch.dtype + sm_scale: Optional[float] = None + + def __hash__(self): + """Convert all fields to a string representation and concatenate them.""" + return hash("_".join([str(getattr(self, f.name)) for f in fields(self)])) + + +@dataclass +class MLAPrefillPlanParams: + """Parameters that affect the FlashInfer MLA prefill execution plan.""" + + num_heads: int + num_kv_heads: int # For MLA with expanded KV, same as num_heads + head_dim_qk: int # qk_nope_head_dim + qk_rope_head_dim + head_dim_vo: int # v_head_dim (value/output head dimension) + num_seq: int + q_dtype: torch.dtype + kv_dtype: torch.dtype + sm_scale: Optional[float] = None + + def __hash__(self): + """Convert all fields to a string representation and concatenate them.""" + return hash("_".join([str(getattr(self, f.name)) for f in fields(self)])) + + +class _FlashInferMLAPlanner: + """A class interface to handle FlashInfer MLA-related planning/wrapping operations. 
+ + For MLA attention: + - Regular prefill uses BatchPrefillWithRaggedKVCacheWrapper with expanded K, V tensors + - Chunked prefill uses BatchMLAPagedAttentionWrapper with matrix absorption (same as decode) + - Decode uses BatchMLAPagedAttentionWrapper with paged compressed KV cache + """ + + workspace_buffer: Optional[torch.Tensor] + prefill_wrapper: Optional[flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper] + decode_wrapper: Optional["flashinfer.mla.BatchMLAPagedAttentionWrapper"] + # Separate wrapper for chunked/incremental prefill (uses same kernel as decode but different planning) + chunked_prefill_wrapper: Optional["flashinfer.mla.BatchMLAPagedAttentionWrapper"] + cached_cuda_graph_decode_wrappers: Dict[ + MLADecodePlanParams, "flashinfer.mla.BatchMLAPagedAttentionWrapper" + ] + plan_params_prefill: Optional[MLAPrefillPlanParams] + plan_params_decode: Optional[MLADecodePlanParams] + plan_params_chunked_prefill: Optional[MLADecodePlanParams] + kv_layout: Literal["NHD", "HND"] = "NHD" + + def __init__(self): + self.workspace_buffer = None + self.prefill_wrapper = None + self.decode_wrapper = None + self.chunked_prefill_wrapper = None + self.cached_cuda_graph_decode_wrappers = {} + self.plan_params_prefill = None + self.plan_params_decode = None + self.plan_params_chunked_prefill = None + + def _init_decode_wrapper( + self, + use_cuda_graph: bool = False, + qo_indptr: Optional[torch.Tensor] = None, + kv_indptr: Optional[torch.Tensor] = None, + kv_indices: Optional[torch.Tensor] = None, + kv_len_arr: Optional[torch.Tensor] = None, + ): + assert self.workspace_buffer is not None + if use_cuda_graph: + return flashinfer.mla.BatchMLAPagedAttentionWrapper( + self.workspace_buffer, + use_cuda_graph=True, + qo_indptr=qo_indptr, + kv_indptr=kv_indptr, + kv_indices=kv_indices, + kv_len_arr=kv_len_arr, + ) + else: + return flashinfer.mla.BatchMLAPagedAttentionWrapper( + self.workspace_buffer, + use_cuda_graph=False, + ) + + def reset(self, device: torch.device) -> None: + self.plan_params_prefill = None + self.plan_params_decode = None + self.plan_params_chunked_prefill = None + + if isinstance(self.workspace_buffer, torch.Tensor): + return + + self.__init__() # reset all state + + # NOTE: avoid OOM for many cudagraphs + self.workspace_buffer = torch.empty(320 * 1024 * 1024, device=device, dtype=torch.uint8) + + # Prefill uses BatchPrefillWithRaggedKVCacheWrapper with expanded K, V + self.prefill_wrapper = flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper( + self.workspace_buffer, + self.kv_layout, + ) + # Decode uses BatchMLAPagedAttentionWrapper with paged compressed KV cache + self.decode_wrapper = self._init_decode_wrapper() + # Chunked prefill uses same kernel as decode but with variable-length queries + self.chunked_prefill_wrapper = self._init_decode_wrapper() + + def plan_prefill( + self, + qo_indptr_host: torch.Tensor, + kv_indptr_host: torch.Tensor, + plan_params: MLAPrefillPlanParams, + ) -> flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper: + """Plan prefill using BatchPrefillWithRaggedKVCacheWrapper. + + For MLA prefill, we expand compressed_kv to get full K, V tensors + and use standard ragged KV cache attention with causal masking. + + Args: + qo_indptr_host: Cumulative query/output lengths on host. + kv_indptr_host: Cumulative key/value lengths on host. + plan_params: Parameters for planning (hashable, no tensors). 
+ """ + if plan_params != self.plan_params_prefill: + self.prefill_wrapper.plan( + qo_indptr_host, + kv_indptr_host, + plan_params.num_heads, + plan_params.num_kv_heads, + plan_params.head_dim_qk, + head_dim_vo=plan_params.head_dim_vo, + use_fp16_qk_reduction=False, + causal=True, + q_data_type=plan_params.q_dtype, + kv_data_type=plan_params.kv_dtype, + sm_scale=plan_params.sm_scale, + ) + self.plan_params_prefill = plan_params + + return self.prefill_wrapper + + def _plan_mla_wrapper( + self, + wrapper: "flashinfer.mla.BatchMLAPagedAttentionWrapper", + qo_indptr: torch.Tensor, + kv_page_indptr: torch.Tensor, + kv_page_indices: torch.Tensor, + kv_last_page_len: torch.Tensor, + plan_params: MLADecodePlanParams, + ): + """Helper to plan a BatchMLAPagedAttentionWrapper.""" + # Compute actual KV lengths from paging metadata: + # kv_len = (num_pages - 1) * page_size + last_page_len + num_pages_per_seq = kv_page_indptr[1:] - kv_page_indptr[:-1] + kv_len_arr = (num_pages_per_seq - 1) * plan_params.page_size + kv_last_page_len + wrapper.plan( + qo_indptr, + kv_page_indptr, + kv_page_indices, + kv_len_arr, + plan_params.num_heads, + plan_params.kv_lora_rank, # head_dim_ckv + plan_params.qk_rope_head_dim, # head_dim_kpe + plan_params.page_size, + causal=True, + q_data_type=plan_params.q_dtype, + kv_data_type=plan_params.kv_dtype, + sm_scale=plan_params.sm_scale, + ) + + def plan_decode( + self, + kv_page_indptr: torch.Tensor, + kv_page_indices: torch.Tensor, + kv_last_page_len: torch.Tensor, + plan_params: MLADecodePlanParams, + ) -> "flashinfer.mla.BatchMLAPagedAttentionWrapper": + """Plan decode using BatchMLAPagedAttentionWrapper. + + For MLA decode, we use the paged compressed KV cache with + FlashInfer's optimized MLA kernels. Each sequence generates 1 token. + + Args: + kv_page_indptr: Cumulative page counts [batch_size + 1]. + kv_page_indices: Page indices for the KV cache. + kv_last_page_len: Length of the last page per sequence. + plan_params: Parameters for planning. + """ + # Decode qo_indptr: [0, 1, 2, ..., batch_size] (1 token per sequence) + batch_size = kv_page_indptr.shape[0] - 1 + qo_indptr = torch.arange(batch_size + 1, device=kv_page_indptr.device, dtype=torch.int32) + + # Compute kv_len_arr for CUDA graph wrapper initialization + num_pages_per_seq = kv_page_indptr[1:] - kv_page_indptr[:-1] + kv_len_arr = (num_pages_per_seq - 1) * plan_params.page_size + kv_last_page_len + + # we want to plan during warm-up of cuda graph capture to ensure we have the plan cached + if ( + cuda_graph_state.in_warm_up() + and plan_params not in self.cached_cuda_graph_decode_wrappers + ): + # During CUDA graph capture, the metadata tensors provided by auto-deploy are stable. 
+ # Pass the buffer tensors to the wrapper for use_cuda_graph=True + wrapper = self._init_decode_wrapper( + use_cuda_graph=True, + qo_indptr=qo_indptr, + kv_indptr=kv_page_indptr, + kv_indices=kv_page_indices, + kv_len_arr=kv_len_arr, + ) + self.cached_cuda_graph_decode_wrappers[plan_params] = wrapper + self._plan_mla_wrapper( + wrapper, qo_indptr, kv_page_indptr, kv_page_indices, kv_last_page_len, plan_params + ) + + # check if we are in cuda graph capture and just return the pre-cached decode wrapper + if torch.cuda.is_current_stream_capturing() or cuda_graph_state.in_warm_up(): + wrapper = self.cached_cuda_graph_decode_wrappers[plan_params] + return wrapper + + # Re-plan if plan_params changed + if plan_params != self.plan_params_decode: + self._plan_mla_wrapper( + self.decode_wrapper, + qo_indptr, + kv_page_indptr, + kv_page_indices, + kv_last_page_len, + plan_params, + ) + self.plan_params_decode = plan_params + + return self.decode_wrapper + + def plan_chunked_prefill( + self, + qo_indptr: torch.Tensor, + kv_page_indptr: torch.Tensor, + kv_page_indices: torch.Tensor, + kv_last_page_len: torch.Tensor, + plan_params: MLADecodePlanParams, + ) -> "flashinfer.mla.BatchMLAPagedAttentionWrapper": + """Plan chunked/incremental prefill using BatchMLAPagedAttentionWrapper. + + For chunked prefill (input_pos > 0), we use the same kernel as decode but with + variable-length queries. Each sequence can have multiple tokens. + + Args: + qo_indptr: Cumulative query lengths [batch_size + 1]. + kv_page_indptr: Cumulative page counts [batch_size + 1]. + kv_page_indices: Page indices for the KV cache. + kv_last_page_len: Length of the last page per sequence. + plan_params: Parameters for planning. + """ + # Re-plan if plan_params changed + if plan_params != self.plan_params_chunked_prefill: + self._plan_mla_wrapper( + self.chunked_prefill_wrapper, + qo_indptr, + kv_page_indptr, + kv_page_indices, + kv_last_page_len, + plan_params, + ) + self.plan_params_chunked_prefill = plan_params + + return self.chunked_prefill_wrapper + + def plan_generate_only( + self, + num_seq: int, + cu_num_pages: torch.Tensor, + cache_loc: torch.Tensor, + last_page_len: torch.Tensor, + ): + """Plan decode-only batches for cached CUDA graph wrappers. + + This is called from the host-side preparation function to plan + the decode wrappers for decode-only batches before the actual + attention op is invoked. + + Args: + num_seq: Number of sequences in the decode batch. + cu_num_pages: Cumulative page counts, already sliced to [: num_seq + 1]. + cache_loc: Page indices for the KV cache. + last_page_len: Length of the last page per sequence, already sliced to [:num_seq]. 
+ """ + for plan_params in self.cached_cuda_graph_decode_wrappers: + if plan_params.num_seq == num_seq: + wrapper = self.cached_cuda_graph_decode_wrappers[plan_params] + + # For a pure decode batch, qo_indptr is just [0, 1, 2, ..., batch_size] + qo_indptr = torch.arange(num_seq + 1, device=cu_num_pages.device, dtype=torch.int32) + + # Compute actual KV lengths from paging metadata: + # kv_len = (num_pages - 1) * page_size + last_page_len + num_pages_per_seq = cu_num_pages[1:] - cu_num_pages[:-1] + kv_len_arr = (num_pages_per_seq - 1) * plan_params.page_size + last_page_len + + wrapper.plan( + qo_indptr, + cu_num_pages, # kv_page_indptr + cache_loc, # kv_page_indices + kv_len_arr, + plan_params.num_heads, + plan_params.kv_lora_rank, # head_dim_ckv + plan_params.qk_rope_head_dim, # head_dim_kpe + plan_params.page_size, + causal=True, + q_data_type=plan_params.q_dtype, + kv_data_type=plan_params.kv_dtype, + sm_scale=plan_params.sm_scale, + ) + + +_GlobalFlashInferMLAPlanner = _FlashInferMLAPlanner() + + +@torch.library.custom_op("auto_deploy::flashinfer_mla_prepare_metadata", mutates_args=()) +def prepare_flashinfer_mla_metadata( + position_ids: torch.Tensor, + batch_info_host: torch.Tensor, + cu_seqlen: torch.Tensor, + seq_len_with_cache: torch.Tensor, +) -> List[torch.Tensor]: + """Prepare metadata for FlashInfer MLA attention. + + This prepares batch_indices and positions for cache appends, similar to + the standard FlashInfer attention preparation. + """ + # retrieve host-side metadata + num_prefill, num_prefill_tokens, num_decode = batch_info_host.tolist() + num_seq = num_prefill + num_decode + num_tokens = num_prefill_tokens + num_decode + + _GlobalFlashInferMLAPlanner.reset(position_ids.device) + + qo_indptr = cu_seqlen[: num_seq + 1] + + # Compute batch_indices and positions for cache appends + batch_indices, positions = flashinfer.get_batch_indices_positions( + qo_indptr, seq_len_with_cache[:num_seq], num_tokens + ) + + return batch_indices, positions + + +@prepare_flashinfer_mla_metadata.register_fake +def prepare_flashinfer_mla_metadata_fake( + position_ids: torch.Tensor, + batch_info_host: torch.Tensor, + cu_seqlen: torch.Tensor, + seq_len_with_cache: torch.Tensor, +): + num_tokens = position_ids.shape[0] * position_ids.shape[1] + return ( + torch.empty(num_tokens, dtype=torch.int32, device=position_ids.device), # batch_indices + torch.empty(num_tokens, dtype=torch.int32, device=position_ids.device), # positions + ) + + +def prepare_flashinfer_mla_metadata_host( + batch_info_host: torch.Tensor, + cu_num_pages_host: torch.Tensor, + cache_loc_host: torch.Tensor, + last_page_len_host: torch.Tensor, +) -> None: + """Host-side preparation for FlashInfer MLA attention. + + For decode-only batches, this function pre-plans the cached CUDA graph + wrappers to avoid planning during graph capture/replay. 
+ """ + num_prefill, num_prefill_tokens, num_decode = batch_info_host.tolist() + + if num_prefill == 0: + _GlobalFlashInferMLAPlanner.plan_generate_only( + num_decode, + cu_num_pages_host[: num_decode + 1], + cache_loc_host, + last_page_len_host[:num_decode], + ) + + +@torch.library.custom_op("auto_deploy::flashinfer_mla_with_cache", mutates_args=()) +def flashinfer_mla_with_cache( + # 5 tensor args (matching torch_mla source op) + q_nope: torch.Tensor, # [B, S, N, qk_nope_head_dim] + q_pe: torch.Tensor, # [B, S, N, qk_rope_head_dim] + compressed_kv: torch.Tensor, # [B, S, kv_lora_rank] + kpe: torch.Tensor, # [B, S, 1, qk_rope_head_dim] + kv_b_proj_weight: torch.Tensor, # [N * (qk_nope_head_dim + v_head_dim), kv_lora_rank] + # Standard paged metadata + batch_info_host: torch.Tensor, + cu_seqlen_host: torch.Tensor, + cu_num_pages: torch.Tensor, + cu_num_pages_host: torch.Tensor, + cache_loc: torch.Tensor, + last_page_len: torch.Tensor, + last_page_len_host: torch.Tensor, + seq_len_with_cache_host: torch.Tensor, + # Extra FlashInfer metadata + flashinfer_batch_indices: torch.Tensor, + flashinfer_positions: torch.Tensor, + # Paged caches (two separate caches) + ckv_cache: torch.Tensor, # [num_pages, page_size, kv_lora_rank] + kpe_cache: torch.Tensor, # [num_pages, page_size, qk_rope_head_dim] + # Constants + scale: Optional[float], + kv_lora_rank: int, +) -> torch.Tensor: + """FlashInfer MLA attention with paged cache. + + Uses FlashInfer's optimized kernels: + - Prefill: BatchPrefillWithRaggedKVCacheWrapper with expanded K, V tensors + - Decode: BatchMLAPagedAttentionWrapper with paged compressed KV cache + + FlashInfer MLA Cache Layout (two separate caches): + ckv_cache: [num_pages, page_size, kv_lora_rank] + kpe_cache: [num_pages, page_size, qk_rope_head_dim] + + Args: + q_nope: Query non-positional component [B, S, N, qk_nope_head_dim] + q_pe: Query positional component [B, S, N, qk_rope_head_dim] + compressed_kv: Compressed KV latent [B, S, kv_lora_rank] + kpe: Key positional encoding [B, S, 1, qk_rope_head_dim] + kv_b_proj_weight: Projection weight [N * (qk_nope_head_dim + v_head_dim), kv_lora_rank] + (metadata args): Standard paged attention metadata + ckv_cache: Paged cache for compressed KV + kpe_cache: Paged cache for key positional encoding + scale: Softmax scale factor + kv_lora_rank: Rank of compressed KV + + Returns: + Attention output [B, S, N, v_head_dim] + """ + # Get dimensions + b, s = q_nope.shape[:2] + num_heads = q_nope.shape[2] + qk_nope_head_dim = q_nope.shape[3] + qk_rope_head_dim = q_pe.shape[3] + qk_head_dim = qk_nope_head_dim + qk_rope_head_dim + + # Infer v_head_dim from kv_b_proj_weight + out_features = kv_b_proj_weight.shape[0] + kv_head_dim = out_features // num_heads + v_head_dim = kv_head_dim - qk_nope_head_dim + + # Get batch info + num_prefill, num_prefill_tokens, num_decode = batch_info_host.tolist() + num_seq = num_prefill + num_decode + num_total_tokens = num_prefill_tokens + num_decode + + # Set scale + if scale is None: + scale = 1.0 / math.sqrt(qk_head_dim) + + page_size = ckv_cache.shape[1] + + # Flatten inputs to [total_tokens, ...] 
format + bs = b * s + q_nope_flat = q_nope.contiguous().view(bs, num_heads, qk_nope_head_dim) + q_pe_flat = q_pe.contiguous().view(bs, num_heads, qk_rope_head_dim) + compressed_kv_flat = compressed_kv.contiguous().view(bs, kv_lora_rank) + kpe_flat = kpe.contiguous().view(bs, qk_rope_head_dim) + + # Convert cache dtype if needed + if ckv_cache.dtype == torch.float8_e4m3fn: + compressed_kv_flat = compressed_kv_flat.to(torch.float8_e4m3fn) + kpe_flat = kpe_flat.to(torch.float8_e4m3fn) + + # Append to paged cache using FlashInfer's append function + # Note: caches are guaranteed contiguous by CachedSequenceInterface._create_kv_cache_manager + flashinfer.page.append_paged_mla_kv_cache( + compressed_kv_flat, + kpe_flat, + flashinfer_batch_indices, + flashinfer_positions, + ckv_cache, + kpe_cache, + cache_loc, + cu_num_pages[: num_seq + 1], + last_page_len[:num_seq], + ) + + # Pre-allocate output + if num_prefill > 0 and num_decode > 0: + y = torch.empty(bs, num_heads, v_head_dim, dtype=q_nope.dtype, device=q_nope.device) + else: + y = None + + # ========================================================================= + # PREFILL phase: Use BatchPrefillWithRaggedKVCacheWrapper for regular prefill + # or BatchMLAPagedAttentionWrapper for chunked prefill + # ========================================================================= + if num_prefill > 0: + q_nope_prefill = q_nope_flat[:num_prefill_tokens] + q_pe_prefill = q_pe_flat[:num_prefill_tokens] + compressed_kv_prefill = compressed_kv_flat[:num_prefill_tokens] + kpe_prefill = kpe_flat[:num_prefill_tokens] + + # Check if any prefill sequence has cached tokens (chunked prefill) + # seq_len_with_cache > current_seq_len means there are cached tokens + q_lens = cu_seqlen_host[1 : num_prefill + 1] - cu_seqlen_host[:num_prefill] + kv_lens = seq_len_with_cache_host[:num_prefill] + is_chunked_prefill = (kv_lens > q_lens).any().item() + + if is_chunked_prefill: + # ================================================================= + # CHUNKED PREFILL: Use BatchMLAPagedAttentionWrapper with absorption + # Same approach as decode, but with variable-length Q sequences + # ================================================================= + + # Extract W_kn and W_v from kv_b_proj_weight + # kv_b_proj_weight: [N * (qk_nope_head_dim + v_head_dim), kv_lora_rank] + # Reshape to [N, qk_nope_head_dim + v_head_dim, kv_lora_rank] + kv_b_proj_reshaped = kv_b_proj_weight.view( + num_heads, qk_nope_head_dim + v_head_dim, kv_lora_rank + ) + # W_kn: [N, qk_nope_head_dim, kv_lora_rank] + w_kn = kv_b_proj_reshaped[:, :qk_nope_head_dim, :] + # W_v: [N, v_head_dim, kv_lora_rank] + w_v = kv_b_proj_reshaped[:, qk_nope_head_dim:, :] + + # Absorb W_kn into q_nope: + # q_nope_prefill: [num_prefill_tokens, N, qk_nope_head_dim] + # w_kn: [N, qk_nope_head_dim, kv_lora_rank] + # q_nope_absorbed: [num_prefill_tokens, N, kv_lora_rank] + q_nope_absorbed = torch.einsum("bnd,ndk->bnk", q_nope_prefill, w_kn).contiguous() + + # Build qo_indptr for variable-length prefill sequences + qo_indptr = cu_seqlen_host[: num_prefill + 1].to( + device=cu_num_pages.device, dtype=torch.int32 + ) + + pp_chunked = MLADecodePlanParams( + num_heads=num_heads, + kv_lora_rank=kv_lora_rank, + qk_rope_head_dim=qk_rope_head_dim, + qk_nope_head_dim=qk_nope_head_dim, + v_head_dim=v_head_dim, + num_seq=num_prefill, + page_size=page_size, + q_dtype=q_nope.dtype, + kv_dtype=ckv_cache.dtype, + sm_scale=scale, + ) + + wrapper_chunked = _GlobalFlashInferMLAPlanner.plan_chunked_prefill( + qo_indptr=qo_indptr, + 
kv_page_indptr=cu_num_pages[: num_prefill + 1], + kv_page_indices=cache_loc, + kv_last_page_len=last_page_len[:num_prefill], + plan_params=pp_chunked, + ) + + # Run paged MLA attention in compressed space + y_prefill_compressed = wrapper_chunked.run( + q_nope_absorbed, + q_pe_prefill, + ckv_cache, + kpe_cache, + ) + + # Project output back from latent space to v_head_dim + # y_prefill_compressed: [num_prefill_tokens, N, kv_lora_rank] + # w_v: [N, v_head_dim, kv_lora_rank] + # y_prefill: [num_prefill_tokens, N, v_head_dim] + y_prefill = torch.einsum("bnk,nvk->bnv", y_prefill_compressed, w_v) + + else: + # ================================================================= + # REGULAR PREFILL: Use BatchPrefillWithRaggedKVCacheWrapper + # Expand compressed_kv to K, V and use ragged attention + # ================================================================= + + # Expand compressed_kv using kv_b_proj_weight to get k_nope and v + # compressed_kv: [num_prefill_tokens, kv_lora_rank] + # kv_b_proj_weight: [N * (qk_nope_head_dim + v_head_dim), kv_lora_rank] + # kv_expanded: [num_prefill_tokens, N * (qk_nope_head_dim + v_head_dim)] + kv_expanded = torch.matmul(compressed_kv_prefill, kv_b_proj_weight.t()) + kv_expanded = kv_expanded.view( + num_prefill_tokens, num_heads, qk_nope_head_dim + v_head_dim + ) + + # Split into k_nope and v + k_nope_prefill = kv_expanded[:, :, :qk_nope_head_dim] # [tokens, N, qk_nope_head_dim] + v_prefill = kv_expanded[:, :, qk_nope_head_dim:].contiguous() # [tokens, N, v_head_dim] + + # Expand kpe to all heads: [tokens, qk_rope_head_dim] -> [tokens, N, qk_rope_head_dim] + kpe_expanded = kpe_prefill.unsqueeze(1).expand(-1, num_heads, -1).contiguous() + + # Concatenate to form full Q and K + # Q: [tokens, N, qk_head_dim] + q_prefill = torch.cat([q_nope_prefill, q_pe_prefill], dim=-1).contiguous() + # K: [tokens, N, qk_head_dim] + k_prefill = torch.cat([k_nope_prefill, kpe_expanded], dim=-1).contiguous() + + pp_prefill = MLAPrefillPlanParams( + num_heads=num_heads, + num_kv_heads=num_heads, # For MLA with expanded KV, same as num_heads + head_dim_qk=qk_head_dim, + head_dim_vo=v_head_dim, + num_seq=num_prefill, + q_dtype=q_nope.dtype, + kv_dtype=k_prefill.dtype, + sm_scale=scale, + ) + + wrapper_prefill = _GlobalFlashInferMLAPlanner.plan_prefill( + qo_indptr_host=cu_seqlen_host[: num_prefill + 1], + kv_indptr_host=cu_seqlen_host[: num_prefill + 1], # Same as qo for self-attention + plan_params=pp_prefill, + ) + + y_prefill = wrapper_prefill.run( + q_prefill, + k_prefill, + v_prefill, + ) + + if y is not None: + y[:num_prefill_tokens] = y_prefill + else: + y = y_prefill + + # ========================================================================= + # DECODE phase: Use BatchMLAPagedAttentionWrapper with paged compressed KV + # ========================================================================= + if num_decode > 0: + q_nope_decode = q_nope_flat[num_prefill_tokens:num_total_tokens].contiguous() + q_pe_decode = q_pe_flat[num_prefill_tokens:num_total_tokens].contiguous() + + # FlashInfer MLA operates in the compressed latent space. + # We need to: + # 1. Absorb W_kn (K-nope projection) into q_nope + # 2. Run attention in compressed space + # 3. 
Project output back using W_v + + # Extract W_kn and W_v from kv_b_proj_weight + # kv_b_proj_weight: [N * (qk_nope_head_dim + v_head_dim), kv_lora_rank] + # Reshape to [N, qk_nope_head_dim + v_head_dim, kv_lora_rank] + kv_b_proj_reshaped = kv_b_proj_weight.view( + num_heads, qk_nope_head_dim + v_head_dim, kv_lora_rank + ) + # W_kn: [N, qk_nope_head_dim, kv_lora_rank] + w_kn = kv_b_proj_reshaped[:, :qk_nope_head_dim, :] + # W_v: [N, v_head_dim, kv_lora_rank] + w_v = kv_b_proj_reshaped[:, qk_nope_head_dim:, :] + + # Absorb W_kn into q_nope: + # q_nope_decode: [num_decode, N, qk_nope_head_dim] + # w_kn: [N, qk_nope_head_dim, kv_lora_rank] + # q_nope_absorbed: [num_decode, N, kv_lora_rank] + q_nope_absorbed = torch.einsum("bnd,ndk->bnk", q_nope_decode, w_kn).contiguous() + + pp_decode = MLADecodePlanParams( + num_heads=num_heads, + kv_lora_rank=kv_lora_rank, + qk_rope_head_dim=qk_rope_head_dim, + qk_nope_head_dim=qk_nope_head_dim, + v_head_dim=v_head_dim, + num_seq=num_decode, + page_size=page_size, + q_dtype=q_nope.dtype, + kv_dtype=ckv_cache.dtype, + sm_scale=scale, + ) + + wrapper_decode = _GlobalFlashInferMLAPlanner.plan_decode( + kv_page_indptr=cu_num_pages[num_prefill : num_seq + 1], + kv_page_indices=cache_loc, + kv_last_page_len=last_page_len[num_prefill:num_seq], + plan_params=pp_decode, + ) + + # Run attention in compressed space + # y_decode_compressed: [num_decode, N, kv_lora_rank] + # Note: caches are guaranteed contiguous by CachedSequenceInterface._create_kv_cache_manager + y_decode_compressed = wrapper_decode.run( + q_nope_absorbed, + q_pe_decode, + ckv_cache, + kpe_cache, + ) + + # Project output back from latent space to v_head_dim + # y_decode_compressed: [num_decode, N, kv_lora_rank] + # w_v: [N, v_head_dim, kv_lora_rank] + # y_decode: [num_decode, N, v_head_dim] + y_decode = torch.einsum("bnk,nvk->bnv", y_decode_compressed, w_v) + + if y is not None: + y[num_prefill_tokens:num_total_tokens] = y_decode + else: + y = y_decode + + return y.view(b, s, num_heads, v_head_dim) + + +@flashinfer_mla_with_cache.register_fake +def flashinfer_mla_with_cache_fake( + q_nope: torch.Tensor, + q_pe: torch.Tensor, + compressed_kv: torch.Tensor, + kpe: torch.Tensor, + kv_b_proj_weight: torch.Tensor, + batch_info_host: torch.Tensor, + cu_seqlen_host: torch.Tensor, + cu_num_pages: torch.Tensor, + cu_num_pages_host: torch.Tensor, + cache_loc: torch.Tensor, + last_page_len: torch.Tensor, + last_page_len_host: torch.Tensor, + seq_len_with_cache_host: torch.Tensor, + flashinfer_batch_indices: torch.Tensor, + flashinfer_positions: torch.Tensor, + ckv_cache: torch.Tensor, + kpe_cache: torch.Tensor, + scale: Optional[float], + kv_lora_rank: int, +) -> torch.Tensor: + """Fake implementation for flashinfer_mla_with_cache.""" + num_heads = q_nope.shape[2] + qk_nope_head_dim = q_nope.shape[-1] + out_features = kv_b_proj_weight.shape[0] + kv_head_dim = out_features // num_heads + v_head_dim = kv_head_dim - qk_nope_head_dim + + return q_nope.new_empty( + q_nope.shape[0], q_nope.shape[1], q_nope.shape[2], v_head_dim + ).contiguous() + + +class MLAPagedResourceHandler(ResourceHandler): + """Handler for paged resources in MLA that require per-layer contiguous memory. + + While MLA uses paged caching, the underlying flashinfer MLA kernel uses a uint32_t to track the + strides for the cache. The KVCacheManager will allocate a contiguous tensor for the cache + across all layers with dim 0 representing the layer index. 
Hence, the per-layer cache has very + large strides to jump between pages which causes overflow in the MLA kernel that uses uint32_t + for strides. + + We use a separate handler for this purpose to avoid registering the cache with the + KVCacheManager and instead rely on local allocation. + """ + + @property + def is_paged(self) -> bool: + """Whether the resource is paged.""" + return True + + def __init__(self, *token_shape: int, dtype: torch.dtype) -> None: + """Initialize the ContiguousPagedResourceHandler. + + Args: + token_shape: The shape of the resource per token. + dtype: The dtype of the resource. + """ + self.token_shape = token_shape + self.dtype = dtype + + def _get_bytes_per_token(self) -> int: + """The size of the resource per token in bytes.""" + return prod(self.token_shape) * self.dtype.itemsize + + def allocate(self, sequence_info: SequenceInfo) -> torch.Tensor: + """Allocate contiguous paged resource. + + Args: + sequence_info: SequenceInfo with device and page information. + + Returns: + Contiguous tensor of shape [num_blocks, tokens_per_block, *token_shape]. + """ + return torch.empty( + sequence_info.num_blocks, + sequence_info.tokens_per_block, + *self.token_shape, + device=sequence_info.device, + dtype=self.dtype, + ) + + +@AttentionRegistry.register("flashinfer_mla") +class FlashInferMLAAttention(AttentionDescriptor): + """Attention descriptor for FlashInfer-based MLA with paged cache. + + This descriptor uses FlashInfer's optimized MLA kernels: + - Source op: torch_mla (same as torch_mla backend) + - Cached op: flashinfer_mla_with_cache with paged cache + + FlashInfer MLA Cache Layout (two separate caches): + ckv_cache: [num_pages, page_size, kv_lora_rank] + kpe_cache: [num_pages, page_size, qk_rope_head_dim] + - No num_heads dimension (MLA-specific optimization) + + Reference: https://docs.flashinfer.ai/api/mla.html + """ + + @classmethod + def _get_planner(cls) -> _FlashInferMLAPlanner: + return _GlobalFlashInferMLAPlanner + + @classmethod + def get_attention_layout(cls) -> AttentionLayout: + """Get the attention layout expected by the backend.""" + return "bsnd" + + @classmethod + def get_num_qkv_args(cls) -> int: + """Get the number of tensor arguments expected by the source op.""" + return 5 # q_nope, q_pe, compressed_kv, kpe, kv_b_proj_weight + + @classmethod + def get_source_attention_op(cls) -> OpOverloadPacket: + """Get the source attention op that we target for replacement.""" + return torch.ops.auto_deploy.torch_mla + + @classmethod + def get_cached_attention_op(cls) -> MHACallable: + """Get the cached attention op.""" + return torch.ops.auto_deploy.flashinfer_mla_with_cache.default + + @classmethod + def get_standard_metadata_args(cls) -> List[str]: + """Get the list of standard metadata arguments for paged attention.""" + return [ + "batch_info_host", + "cu_seqlen_host", + "cu_num_pages", + "cu_num_pages_host", + "cache_loc", + "last_page_len", + "last_page_len_host", + "seq_len_with_cache_host", + ] + + @classmethod + def get_prepare_extra_metadata_info( + cls, any_source_attn_node: Node + ) -> Tuple[Optional[PrepareMetadataCallable], int, List[Constant]]: + """Get the prepare_metadata op for FlashInfer MLA.""" + return (torch.ops.auto_deploy.flashinfer_mla_prepare_metadata.default, 2, []) + + @classmethod + def get_cache_initializers( + cls, source_attn_node: Node, cache_config: KvCacheConfig + ) -> ResourceHandlerDict: + """Get cache initializers using FlashInfer MLA paged cache layout. 
+ + Creates two separate paged caches: + - ckv_cache: [num_pages, page_size, kv_lora_rank] + - kpe_cache: [num_pages, page_size, qk_rope_head_dim] + """ + # Extract dimensions from source node args + # torch_mla signature: q_nope, q_pe, compressed_kv, kpe, kv_b_proj_weight, ... + compressed_kv_fake: FakeTensor = source_attn_node.args[2].meta["val"] + kpe_fake: FakeTensor = source_attn_node.args[3].meta["val"] + + # Get dimensions + # compressed_kv: [B, S, kv_lora_rank] + # kpe: [B, S, 1, qk_rope_head_dim] + kv_lora_rank = compressed_kv_fake.shape[-1] + qk_rope_head_dim = kpe_fake.shape[-1] + + # flashinfer mla requires kv_lora_rank to be 512 and qk_rope_head_dim to be 64 + if kv_lora_rank != 512: + raise ValueError("kv_lora_rank must be 512 for flashinfer_mla") + if qk_rope_head_dim != 64: + raise ValueError("qk_rope_head_dim must be 64 for flashinfer_mla") + + cache_dtype = cls.resolve_cache_dtype(cache_config.dtype, compressed_kv_fake.dtype) + + # FlashInfer MLA uses two separate paged caches with no num_heads dimension + return { + "ckv_cache": MLAPagedResourceHandler( + kv_lora_rank, + dtype=cache_dtype, + ), + "kpe_cache": MLAPagedResourceHandler( + qk_rope_head_dim, + dtype=cache_dtype, + ), + } + + @classmethod + def get_host_prepare_metadata_function(cls) -> Optional[PrepareMetadataHostCallable]: + """Get function for host-side preparation.""" + return prepare_flashinfer_mla_metadata_host + + @classmethod + def get_constants(cls, source_attn_node: Node) -> List[Constant]: + """Get constants to pass to the cached attention op.""" + # Extract kv_lora_rank for cache operations + compressed_kv_fake = source_attn_node.args[2].meta["val"] + kv_lora_rank = compressed_kv_fake.shape[-1] + + # Get scale from kwargs + scale = source_attn_node.kwargs.get("scale", None) + + return [scale, kv_lora_rank] diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/mla/mla.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/mla/mla.py deleted file mode 100644 index f435fc5818..0000000000 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/mla/mla.py +++ /dev/null @@ -1,280 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
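For reference, a minimal sketch of the two-cache layout that `FlashInferMLAAttention.get_cache_initializers` sets up above; `num_pages`, `page_size`, and the `bfloat16` dtype are made-up illustration values, while the trailing 512/64 dims are the sizes the descriptor enforces:

```python
import torch

# Hypothetical sizes; only the trailing dims (kv_lora_rank=512, qk_rope_head_dim=64)
# reflect what the flashinfer_mla descriptor requires.
num_pages, page_size = 16, 32
kv_lora_rank, qk_rope_head_dim = 512, 64

# Mirrors MLAPagedResourceHandler.allocate: [num_blocks, tokens_per_block, *token_shape].
ckv_cache = torch.empty(num_pages, page_size, kv_lora_rank, dtype=torch.bfloat16)
kpe_cache = torch.empty(num_pages, page_size, qk_rope_head_dim, dtype=torch.bfloat16)

# No num_heads dimension: all heads share the compressed latent and the positional key.
print(ckv_cache.shape, kpe_cache.shape)  # torch.Size([16, 32, 512]) torch.Size([16, 32, 64])
```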
- -"""Custom ops for MultiHead Latent attention.""" - -import math -from typing import List, Optional, Union - -import torch -from torch._ops import OpOverloadPacket -from torch.fx import Node - -from .....llmapi.llm_args import KvCacheConfig -from ..attention.triton_attention import _decode_attention, _prefill_attention -from ..attention_interface import ( - AttentionDescriptor, - AttentionLayout, - AttentionRegistry, - MHACallable, - ResourceHandlerDict, - UnpagedResourceHandler, -) - -Constant = Union[int, float, str, None] - - -def _precompute_inv_freq( - max_seq_len: int, head_dim: int, rope_theta: float, device: torch.device -) -> torch.Tensor: - inv_freq = 1.0 / ( - rope_theta ** (torch.arange(0, head_dim, 2, device=device).float() / head_dim) - ) - t = torch.arange(max_seq_len, device=inv_freq.device, dtype=inv_freq.dtype) - - freqs = torch.outer(t, inv_freq.to(t.device)) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - cos_sin_stacked = torch.stack([emb.cos().to(torch.bfloat16), emb.sin().to(torch.bfloat16)]) - return cos_sin_stacked - - -@torch.library.custom_op( - "auto_deploy::triton_attention_fused_flattened_mla_with_cache", mutates_args=() -) -def fused_flattened_mla_with_cache( - # Q, K, V - q_nope: torch.Tensor, - q_pe: torch.Tensor, - kv: torch.Tensor, - k_pe: torch.Tensor, - # STANDARD METADATA - batch_info_host: torch.Tensor, - seq_len: torch.Tensor, - input_pos: torch.Tensor, - cache_loc: torch.Tensor, - cu_seqlen: torch.Tensor, - # EXTRA METADATA - # - # CACHES - k_cache: torch.Tensor, - v_cache: torch.Tensor, - # CONSTANTS - softmax_scale: Optional[float] = None, - rope_theta: Optional[float] = None, -) -> torch.Tensor: - """Flattened & fused MLA with cache with triton kernels.""" - # b, s info - # NOTE: b, s are just the shapes of the input tensor q; not necessarily the number of sequences. - # Generally speaking, we expect one of two cases here: - # 1. b > 0, s==1: this indicates a generate-only batch of tokens. - # 2. b==1, s > 0: this indicates a mixed context+generate phase. The actual number of sequences - # and number of tokens per sequence are encoded in seq_len and seq_start. 
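As a quick illustration of the flattened-batch bookkeeping described in the NOTE above (the same `seq_len` / `seq_start` slicing convention is reused by the replacement torch and FlashInfer MLA backends via `cu_seqlen`), here is a minimal sketch with hypothetical sequence lengths:

```python
import torch

# Two sequences of 3 and 2 tokens flattened into a single b==1, s==5 batch.
seq_len = torch.tensor([3, 2])                      # tokens per sequence
seq_start = torch.cumsum(
    torch.cat([torch.zeros(1, dtype=torch.long), seq_len]), dim=0
)[:-1]                                              # cumulative offsets -> tensor([0, 3])

tokens = torch.arange(5).unsqueeze(-1)              # stand-in for flattened q/k/v: [total_tokens, ...]

for i in range(seq_len.shape[0]):
    start, length = seq_start[i].item(), seq_len[i].item()
    print(i, tokens[start : start + length].flatten().tolist())  # 0 -> [0, 1, 2]; 1 -> [3, 4]
```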
- - # check for sequence info and truncate metadata - num_prefill, num_prefill_tokens, num_decode = batch_info_host.tolist() - num_seq = num_prefill + num_decode - - seq_len = seq_len[:num_seq] - input_pos = input_pos[:num_seq] - cache_loc = cache_loc[:num_seq] - seq_start = cu_seqlen[:num_seq] - - # Get parameters - b, num_heads, s, qk_nope_head_dim = q_nope.shape - qk_rope_head_dim = q_pe.shape[-1] - v_head_dim = kv.shape[-1] - qk_nope_head_dim - - # Get k_nope and value_states - k_nope, value_states = torch.split(kv, [qk_nope_head_dim, v_head_dim], dim=-1) - - # Flatten inputs - if s == 1: - bs_view = (b, s) - else: - bs_view = (b * s,) - - # TODO(suyogg): do something about all these clones, transposes, and contiguous-es - q_nope = q_nope.clone().transpose(1, 2).view(*bs_view, num_heads, qk_nope_head_dim).contiguous() - q_pe = q_pe.clone().transpose(1, 2).view(*bs_view, num_heads, qk_rope_head_dim).contiguous() - k_nope = k_nope.clone().transpose(1, 2).view(*bs_view, num_heads, qk_nope_head_dim).contiguous() - k_pe = k_pe.clone().transpose(1, 2).view(*bs_view, -1, qk_rope_head_dim).contiguous() - value_states = value_states.transpose(1, 2).view(*bs_view, -1, v_head_dim).contiguous() - # Apply RoPE - if rope_theta is not None: - max_seq_len = (input_pos + seq_len).max().item() - cos_sin_stacked = _precompute_inv_freq( - max_seq_len, qk_rope_head_dim, rope_theta, q_pe.device - ) - - # Extract cos and sin from freqs_cis - cos_base = cos_sin_stacked[0, ...] - sin_base = cos_sin_stacked[1, ...] - - # TODO: Use triton kernels for RoPE - # TODO: Add yarn support - for i in range(seq_len.shape[0]): - start = seq_start[i] - length = seq_len[i] - - # build position_ids - if s == 1: - idx = (input_pos[i] + length - 1).item() - pos_ids = torch.tensor(idx, device=cos_base.device) - else: - pos_ids = torch.arange(input_pos[i], input_pos[i] + length, device=cos_base.device) - - cos = cos_base[pos_ids] # [..., 1, head_dim] - sin = sin_base[pos_ids] - q_slice = q_pe[start : start + length] - k_slice = k_pe[start : start + length] - - q_rot, k_rot = torch.ops.auto_deploy.torch_rope_with_qk_interleaving( - q_slice, - k_slice, - cos, - sin, - -2, - ) - - q_pe[start : start + length] = q_rot - k_pe[start : start + length] = k_rot - - # Create query_states, key_states - query_states = torch.cat((q_nope, q_pe), dim=-1) # [b*s,n,d] - key_states = torch.cat((k_nope, k_pe.expand(*bs_view, num_heads, -1)), dim=-1) # [b*s,n,d] - - # Compute scale if not provided - qk_head_dim = qk_nope_head_dim + qk_rope_head_dim - scale = softmax_scale if softmax_scale is not None else 1.0 / math.sqrt(qk_head_dim) - - # Compute attention - y = torch.empty_like(value_states) - if s == 1: - # generate-only phase (decode) - _decode_attention( - query_states.contiguous(), - key_states.contiguous(), - value_states.contiguous(), - k_cache, - v_cache, - cache_loc, - input_pos, - scale, - y, - ) - - else: - # mixed context + generate phase (prefill) - _prefill_attention( - query_states.contiguous(), - key_states.contiguous(), - value_states.contiguous(), - k_cache, - v_cache, - input_pos, - cache_loc, - seq_len, - seq_start, - scale, - y, - ) - - y = ( - y.view(b, s, -1, v_head_dim).transpose(1, 2).contiguous() - ) # BNSD format as expected by the callsite. 
- return y - - -@fused_flattened_mla_with_cache.register_fake -def fused_flattened_mla_with_cache_fake( - # Q, K, V - q_nope: torch.Tensor, - q_pe: torch.Tensor, - kv: torch.Tensor, - k_pe: torch.Tensor, - # STANDARD METADATA - batch_info_host: torch.Tensor, - seq_len: torch.Tensor, - input_pos: torch.Tensor, - cache_loc: torch.Tensor, - cu_seqlen: torch.Tensor, - # EXTRA METADATA - # - # CACHES - k_cache: torch.Tensor, - v_cache: torch.Tensor, - # CONSTANTS - softmax_scale: Optional[float] = None, - rope_theta: Optional[float] = None, -): - v_head_dim = kv.shape[-1] - q_nope.shape[-1] - return torch.empty_like(kv[..., -v_head_dim:]) - - -@AttentionRegistry.register("MultiHeadLatentAttention") -class MultiHeadLatentAttention(AttentionDescriptor): - @classmethod - def get_attention_layout(cls) -> AttentionLayout: - """Get the attention layout expected by the backend.""" - return "bnsd" - - @classmethod - def get_num_qkv_args(cls) -> int: - """Get the number of qkv arguments expected by the source op.""" - return 4 - - @classmethod - def get_source_attention_op(cls) -> OpOverloadPacket: - return torch.ops.auto_deploy.torch_attention_deepseek_fused_mla - - @classmethod - def get_cached_attention_op(cls) -> MHACallable: - return torch.ops.auto_deploy.triton_attention_fused_flattened_mla_with_cache.default - - @classmethod - def get_standard_metadata_args(cls) -> List[str]: - return ["batch_info_host", "seq_len", "input_pos", "cache_loc", "cu_seqlen"] - - @classmethod - def get_cache_initializers( - cls, source_attn_node: Node, cache_config: KvCacheConfig - ) -> ResourceHandlerDict: - q_nope_fake = source_attn_node.args[0].meta["val"] - q_pe_fake = source_attn_node.args[1].meta["val"] - kv_fake = source_attn_node.args[2].meta["val"] - - num_kv_heads = kv_fake.shape[1] - head_dim = q_nope_fake.shape[-1] - rope_dim = q_pe_fake.shape[-1] - - return { - "k_cache": UnpagedResourceHandler( - num_kv_heads, - head_dim + rope_dim, - dtype=cls.resolve_cache_dtype(cache_config.dtype, kv_fake.dtype), - ), - "v_cache": UnpagedResourceHandler( - num_kv_heads, - head_dim, - dtype=cls.resolve_cache_dtype(cache_config.dtype, kv_fake.dtype), - ), - } - - @classmethod - def get_constants(cls, source_attn_node: Node) -> List[Constant]: - softmax_scale = None - rope_theta = 10000.0 # TODO: remove once MLA is unfused - return [softmax_scale, rope_theta] diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/mla/torch_backend_mla.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/mla/torch_backend_mla.py new file mode 100644 index 0000000000..28cda4cb0e --- /dev/null +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/mla/torch_backend_mla.py @@ -0,0 +1,518 @@ +"""Custom ops for MultiHead Latent Attention (MLA) with FlashInfer-compatible cache. 
+ +This module provides: +- torch_cached_mla_with_cache: cached backend op +- TorchBackendMLAAttention: attention descriptor + +FlashInfer MLA Cache Layout: + mla_cache: [max_batch, max_seq, kv_lora_rank + qk_rope_head_dim] + - No num_heads dimension (MLA-specific optimization) + - compressed_kv_cached = mla_cache[:, :, :kv_lora_rank] (zero-copy slice) + - kpe_cached = mla_cache[:, :, kv_lora_rank:] (zero-copy slice) + +The implementation uses: +- Prefill: Expand compressed_kv -> full K, V, compute normal attention +- Generate: Weight absorption for efficiency (Q @ W^T instead of expanding cached KV) + +Reference: https://docs.flashinfer.ai/tutorials/kv_layout.html#mla-page-layout +""" + +import math +from typing import List, Optional + +import torch +from torch._ops import OpOverloadPacket +from torch.fx import Node + +from .....llmapi.llm_args import KvCacheConfig +from ..attention_interface import ( + AttentionDescriptor, + AttentionLayout, + AttentionRegistry, + Constant, + MHACallable, + ResourceHandlerDict, + UnpagedResourceHandler, +) + + +def _update_mla_cache( + compressed_kv: torch.Tensor, # [total_tokens, kv_lora_rank] + kpe: torch.Tensor, # [total_tokens, qk_rope_head_dim] + mla_cache: torch.Tensor, # [max_batch, max_seq, kv_lora_rank + qk_rope_head_dim] + seq_len: torch.Tensor, + input_pos: torch.Tensor, + slot_idx: torch.Tensor, + seq_start: torch.Tensor, + kv_lora_rank: int, +) -> None: + """Update FlashInfer MLA cache with compressed_kv and kpe values. + + FlashInfer MLA cache layout: [max_batch, max_seq, kv_lora_rank + qk_rope_head_dim] + - First kv_lora_rank dims: compressed KV latent (before kv_b_proj) + - Last qk_rope_head_dim dims: key positional encoding + """ + for idx in range(seq_len.shape[0]): + start = seq_start[idx].item() + length = seq_len[idx].item() + cache_idx = slot_idx[idx].item() + pos = input_pos[idx].item() + + # Update compressed_kv portion + mla_cache[cache_idx, pos : pos + length, :kv_lora_rank] = compressed_kv[ + start : start + length + ] + # Update kpe portion + mla_cache[cache_idx, pos : pos + length, kv_lora_rank:] = kpe[start : start + length] + + +def _torch_mla_generate_with_absorption( + q_nope: torch.Tensor, # [B, 1, N, qk_nope_head_dim] + q_pe: torch.Tensor, # [B, 1, N, qk_rope_head_dim] + compressed_kv: torch.Tensor, # [B, 1, kv_lora_rank] + kpe: torch.Tensor, # [B, 1, 1, qk_rope_head_dim] + kv_b_proj_weight: torch.Tensor, # [N * (qk_nope_head_dim + v_head_dim), kv_lora_rank] + mla_cache: torch.Tensor, # [max_batch, max_seq, kv_lora_rank + qk_rope_head_dim] + slot_idx: torch.Tensor, + input_pos: torch.Tensor, + scale: float, + kv_lora_rank: int, + num_heads: int, + qk_nope_head_dim: int, + v_head_dim: int, + out: torch.Tensor, +) -> None: + """Generate-only MLA attention with weight absorption. + + Weight absorption: Instead of expanding all cached KV, we absorb kv_b_proj into Q. + Q_absorbed = Q_nope @ W_k^T where W_k is the k_nope portion of kv_b_proj_weight + + This avoids expanding potentially thousands of cached tokens. 
+ """ + b = q_nope.shape[0] + + # Extract k_nope and v portions from kv_b_proj_weight + # kv_b_proj_weight: [N * (qk_nope_head_dim + v_head_dim), kv_lora_rank] + # Reshape to [N, qk_nope_head_dim + v_head_dim, kv_lora_rank] + weight_reshaped = kv_b_proj_weight.view(num_heads, qk_nope_head_dim + v_head_dim, kv_lora_rank) + w_k_nope = weight_reshaped[:, :qk_nope_head_dim, :] # [N, qk_nope_head_dim, kv_lora_rank] + w_v = weight_reshaped[:, qk_nope_head_dim:, :] # [N, v_head_dim, kv_lora_rank] + + # Update cache with new tokens + compressed_kv_flat = compressed_kv.squeeze(1) # [B, kv_lora_rank] + kpe_flat = kpe.squeeze(1).squeeze(1) # [B, qk_rope_head_dim] + + for i in range(b): + cache_idx = slot_idx[i].item() + pos = input_pos[i].item() + mla_cache[cache_idx, pos, :kv_lora_rank] = compressed_kv_flat[i] + mla_cache[cache_idx, pos, kv_lora_rank:] = kpe_flat[i] + + # Compute attention for each sequence using weight absorption + for i in range(b): + cache_idx = slot_idx[i].item() + pos = input_pos[i].item() + + # Get query for this sequence: [N, qk_nope_head_dim], [N, qk_rope_head_dim] + q_nope_i = q_nope[i, 0] # [N, qk_nope_head_dim] + q_pe_i = q_pe[i, 0] # [N, qk_rope_head_dim] + + # Retrieve cached data up to current position + cached_data = mla_cache[cache_idx, : pos + 1] # [seq_len, kv_lora_rank + qk_rope_head_dim] + compressed_kv_cached = cached_data[:, :kv_lora_rank] # [seq_len, kv_lora_rank] + kpe_cached = cached_data[:, kv_lora_rank:] # [seq_len, qk_rope_head_dim] + + # ===================================================================== + # Weight absorption for Q_nope part + # ===================================================================== + # q_absorbed = q_nope @ w_k_nope^T (absorb k_nope projection into query) + # q_nope_i: [N, qk_nope_head_dim] + # w_k_nope: [N, qk_nope_head_dim, kv_lora_rank] + # q_absorbed: [N, kv_lora_rank] + q_absorbed = torch.einsum("nd,ndk->nk", q_nope_i, w_k_nope) + + # Attention scores from absorbed Q and compressed KV + # Compute in fp32 to match FlashInfer's use_fp16_qk_reduction=False + # q_absorbed: [N, kv_lora_rank], compressed_kv_cached: [seq_len, kv_lora_rank] + # scores_nope: [N, seq_len] + scores_nope = torch.matmul(q_absorbed.float(), compressed_kv_cached.float().t()) + + # ===================================================================== + # Q_pe part - standard attention with kpe + # ===================================================================== + # q_pe_i: [N, qk_rope_head_dim], kpe_cached: [seq_len, qk_rope_head_dim] + # scores_pe: [N, seq_len] + scores_pe = torch.matmul(q_pe_i.float(), kpe_cached.float().t()) + + # Combined attention scores (already in fp32) + attn_scores = (scores_nope + scores_pe) * scale # [N, seq_len] + + # Softmax (already in fp32, convert back to input dtype) + attn_weights = torch.softmax(attn_scores, dim=-1).to(q_nope.dtype) # [N, seq_len] + + # ===================================================================== + # Compute output with absorbed value projection + # ===================================================================== + # v_out = attn_weights @ compressed_kv @ w_v^T + # First: weighted_kv = attn_weights @ compressed_kv_cached -> [N, kv_lora_rank] + weighted_kv = torch.matmul(attn_weights, compressed_kv_cached) # [N, kv_lora_rank] + + # Then: attn_out = weighted_kv @ w_v^T -> [N, v_head_dim] + # w_v: [N, v_head_dim, kv_lora_rank] + # weighted_kv: [N, kv_lora_rank] + attn_out = torch.einsum("nk,nvk->nv", weighted_kv, w_v) # [N, v_head_dim] + + out[i] = attn_out + + +def 
_torch_mla_context_with_expansion( + q_nope: torch.Tensor, # [total_tokens, N, qk_nope_head_dim] + q_pe: torch.Tensor, # [total_tokens, N, qk_rope_head_dim] + compressed_kv: torch.Tensor, # [total_tokens, kv_lora_rank] + kpe: torch.Tensor, # [total_tokens, 1, qk_rope_head_dim] + kv_b_proj_weight: torch.Tensor, # [N * (qk_nope_head_dim + v_head_dim), kv_lora_rank] + mla_cache: torch.Tensor, # [max_batch, max_seq, kv_lora_rank + qk_rope_head_dim] + input_pos: torch.Tensor, + slot_idx: torch.Tensor, + seq_len: torch.Tensor, + seq_start: torch.Tensor, + scale: float, + kv_lora_rank: int, + num_heads: int, + qk_nope_head_dim: int, + v_head_dim: int, + out: torch.Tensor, +) -> None: + """Context MLA attention with kv_b_proj expansion. + + For prefill, we expand compressed_kv using kv_b_proj_weight and compute + standard attention. This is more efficient than absorption for prefill + since we only expand the current tokens, not the full cache. + """ + + # Flatten kpe: [total_tokens, 1, qk_rope_head_dim] -> [total_tokens, qk_rope_head_dim] + kpe_flat = kpe.squeeze(1) + + # Update cache first with compressed representation + _update_mla_cache( + compressed_kv, + kpe_flat, + mla_cache, + seq_len, + input_pos, + slot_idx, + seq_start, + kv_lora_rank, + ) + + # Compute attention for each sequence + attn_outputs = [] + for idx in range(seq_len.shape[0]): + seq_len_i = seq_len[idx].item() + input_pos_i = input_pos[idx].item() + slot_idx_i = slot_idx[idx].item() + seq_start_i = seq_start[idx].item() + + if seq_len_i == 0: + continue + + # Get query for this sequence + q_nope_seq = q_nope[ + seq_start_i : seq_start_i + seq_len_i + ] # [seq_len_i, N, qk_nope_head_dim] + q_pe_seq = q_pe[seq_start_i : seq_start_i + seq_len_i] # [seq_len_i, N, qk_rope_head_dim] + + # Get cached data for attention (includes just-added tokens) + kv_seq_len = input_pos_i + seq_len_i + cached_data = mla_cache[ + slot_idx_i, :kv_seq_len + ] # [kv_seq_len, kv_lora_rank + qk_rope_head_dim] + compressed_kv_cached = cached_data[:, :kv_lora_rank] # [kv_seq_len, kv_lora_rank] + kpe_cached = cached_data[:, kv_lora_rank:] # [kv_seq_len, qk_rope_head_dim] + + # ===================================================================== + # Expand compressed_kv using kv_b_proj_weight for this sequence + # ===================================================================== + # compressed_kv_cached: [kv_seq_len, kv_lora_rank] + # kv_b_proj_weight: [N * (qk_nope_head_dim + v_head_dim), kv_lora_rank] + # kv_expanded: [kv_seq_len, N * (qk_nope_head_dim + v_head_dim)] + kv_expanded = torch.matmul(compressed_kv_cached, kv_b_proj_weight.t()) + + # Reshape to [kv_seq_len, N, qk_nope_head_dim + v_head_dim] + kv_expanded = kv_expanded.view(kv_seq_len, num_heads, qk_nope_head_dim + v_head_dim) + + # Split into k_nope and v + k_nope_expanded = kv_expanded[:, :, :qk_nope_head_dim] # [kv_seq_len, N, qk_nope_head_dim] + v_expanded = kv_expanded[:, :, qk_nope_head_dim:] # [kv_seq_len, N, v_head_dim] + + # Expand kpe to all heads + kpe_expanded = kpe_cached.unsqueeze(1).expand( + -1, num_heads, -1 + ) # [kv_seq_len, N, qk_rope_head_dim] + + # Construct full query and key + query_full = torch.cat([q_nope_seq, q_pe_seq], dim=-1) # [seq_len_i, N, qk_head_dim] + key_full = torch.cat( + [k_nope_expanded, kpe_expanded], dim=-1 + ) # [kv_seq_len, N, qk_head_dim] + + # Transpose for attention: [1, N, seq_len, head_dim] + query_t = query_full.transpose(0, 1).unsqueeze(0) # [1, N, seq_len_i, qk_head_dim] + key_t = key_full.transpose(0, 1).unsqueeze(0) # [1, N, kv_seq_len, 
qk_head_dim] + + # Compute attention scores in fp32 to match FlashInfer's use_fp16_qk_reduction=False + # FlashInfer uses fp32 accumulation for QK^T, so we do the same for numerical consistency + attn_scores = ( + torch.matmul(query_t.float(), key_t.float().transpose(-2, -1)) * scale + ) # [1, N, seq_len_i, kv_seq_len] in fp32 + + # Apply causal mask + causal_mask = torch.triu( + torch.ones(seq_len_i, kv_seq_len, device=q_nope.device, dtype=torch.bool), + diagonal=kv_seq_len - seq_len_i + 1, + ) + attn_scores.masked_fill_(causal_mask.unsqueeze(0).unsqueeze(0), float("-inf")) + + # Softmax (already in fp32, convert back to input dtype) + attn_weights = torch.softmax(attn_scores, dim=-1).to(q_nope.dtype) + + # Value: [1, N, kv_seq_len, v_head_dim] + v_t = v_expanded.transpose(0, 1).unsqueeze(0) + + # Compute output + attn_out = torch.matmul(attn_weights, v_t) # [1, N, seq_len_i, v_head_dim] + attn_out = attn_out[0].transpose(0, 1) # [seq_len_i, N, v_head_dim] + + attn_outputs.append(attn_out) + + # Concatenate all outputs + if len(attn_outputs) == 0: + out.zero_() + elif len(attn_outputs) == 1: + out.copy_(attn_outputs[0]) + else: + out.copy_(torch.cat(attn_outputs, dim=0)) + + +@torch.library.custom_op("auto_deploy::torch_cached_mla_with_cache", mutates_args=()) +def torch_backend_mla_with_cache( + # 5 tensor args (get_num_qkv_args = 5) + q_nope: torch.Tensor, # [B, S, N, qk_nope_head_dim] + q_pe: torch.Tensor, # [B, S, N, qk_rope_head_dim] + compressed_kv: torch.Tensor, # [B, S, kv_lora_rank] + kpe: torch.Tensor, # [B, S, 1, qk_rope_head_dim] + kv_b_proj_weight: torch.Tensor, # [N * (qk_nope_head_dim + v_head_dim), kv_lora_rank] + # Standard metadata + batch_info_host: torch.Tensor, + seq_len: torch.Tensor, + input_pos: torch.Tensor, + slot_idx: torch.Tensor, + cu_seqlen: torch.Tensor, + # Cache (FlashInfer layout) + mla_cache: torch.Tensor, # [max_batch, max_seq, kv_lora_rank + qk_rope_head_dim] + # Constants + scale: Optional[float] = None, + kv_lora_rank: int = 512, +) -> torch.Tensor: + """Torch backend MLA with FlashInfer-compatible compressed cache. 
+ + FlashInfer MLA Cache Layout: + mla_cache: [max_batch, max_seq, kv_lora_rank + qk_rope_head_dim] + - compressed_kv = mla_cache[:, :, :kv_lora_rank] (zero-copy slice) + - kpe = mla_cache[:, :, kv_lora_rank:] (zero-copy slice) + + Prefill (context): Expand compressed_kv, compute normal attention + Generate (decode): Use weight absorption for efficiency + """ + # Get dimensions + b, s = q_nope.shape[:2] + num_heads = q_nope.shape[2] + qk_nope_head_dim = q_nope.shape[3] + qk_rope_head_dim = q_pe.shape[3] + qk_head_dim = qk_nope_head_dim + qk_rope_head_dim + + # Infer v_head_dim from kv_b_proj_weight + out_features = kv_b_proj_weight.shape[0] + kv_head_dim = out_features // num_heads + v_head_dim = kv_head_dim - qk_nope_head_dim + + # Get cleaned up metadata + num_prefill, num_prefill_tokens, num_decode = batch_info_host.tolist() + num_seq = num_prefill + num_decode + seq_len = seq_len[:num_seq] + input_pos = input_pos[:num_seq] + slot_idx = slot_idx[:num_seq] + seq_start = cu_seqlen[:num_seq] + + # Set scale + if scale is None: + scale = 1.0 / math.sqrt(qk_head_dim) + + # Define output shape: [B, S, N, v_head_dim] + output_shape = (b, s, num_heads, v_head_dim) + + if s == 1: + # ===================================================================== + # Generate phase: Use weight absorption + # ===================================================================== + y = q_nope.new_empty(b, num_heads, v_head_dim).contiguous() + + _torch_mla_generate_with_absorption( + q_nope, + q_pe, + compressed_kv, + kpe, + kv_b_proj_weight, + mla_cache, + slot_idx, + input_pos, + scale, + kv_lora_rank, + num_heads, + qk_nope_head_dim, + v_head_dim, + y, + ) + + return y.unsqueeze(1) # [B, 1, N, v_head_dim] + else: + # ===================================================================== + # Context phase: Expand and compute normal attention + # ===================================================================== + bs_view = (b * s,) + + q_nope_flat = q_nope.contiguous().view(*bs_view, num_heads, qk_nope_head_dim) + q_pe_flat = q_pe.contiguous().view(*bs_view, num_heads, qk_rope_head_dim) + compressed_kv_flat = compressed_kv.contiguous().view(*bs_view, kv_lora_rank) + kpe_flat = kpe.contiguous().view(*bs_view, 1, qk_rope_head_dim) + + y = q_nope.new_empty(*bs_view, num_heads, v_head_dim).contiguous() + + _torch_mla_context_with_expansion( + q_nope_flat, + q_pe_flat, + compressed_kv_flat, + kpe_flat, + kv_b_proj_weight, + mla_cache, + input_pos, + slot_idx, + seq_len, + seq_start, + scale, + kv_lora_rank, + num_heads, + qk_nope_head_dim, + v_head_dim, + y, + ) + + return y.view(*output_shape) + + +@torch_backend_mla_with_cache.register_fake +def torch_backend_mla_with_cache_fake( + q_nope: torch.Tensor, + q_pe: torch.Tensor, + compressed_kv: torch.Tensor, + kpe: torch.Tensor, + kv_b_proj_weight: torch.Tensor, + batch_info_host: torch.Tensor, + seq_len: torch.Tensor, + input_pos: torch.Tensor, + slot_idx: torch.Tensor, + cu_seqlen: torch.Tensor, + mla_cache: torch.Tensor, + scale: Optional[float] = None, + kv_lora_rank: int = 512, +) -> torch.Tensor: + """Fake implementation for torch_backend_mla_with_cache.""" + num_heads = q_nope.shape[2] + qk_nope_head_dim = q_nope.shape[-1] + out_features = kv_b_proj_weight.shape[0] + kv_head_dim = out_features // num_heads + v_head_dim = kv_head_dim - qk_nope_head_dim + + return q_nope.new_empty( + q_nope.shape[0], q_nope.shape[1], q_nope.shape[2], v_head_dim + ).contiguous() + + +@AttentionRegistry.register("torch_mla") +class 
TorchBackendMLAAttention(AttentionDescriptor): + """Attention descriptor for Multi-head Latent Attention (MLA). + + This descriptor uses FlashInfer-compatible compressed cache: + - torch_mla: source op that expands compressed_kv for attention + - torch_cached_mla_with_cache: cached op with absorption for generate + + FlashInfer MLA Cache Layout: + mla_cache: [max_batch, max_seq, head_dim_ckv + head_dim_kpe] + - No num_heads dimension (MLA-specific optimization) + - ckv_cached = mla_cache[:, :, :head_dim_ckv] (zero-copy slice) + - kpe_cached = mla_cache[:, :, head_dim_ckv:] (zero-copy slice) + + Reference: https://docs.flashinfer.ai/tutorials/kv_layout.html#mla-page-layout + """ + + @classmethod + def get_attention_layout(cls) -> AttentionLayout: + """Get the attention layout expected by the backend.""" + return "bsnd" + + @classmethod + def get_num_qkv_args(cls) -> int: + """Get the number of tensor arguments expected by the source op.""" + return 5 # q_nope, q_pe, compressed_kv, kpe, kv_b_proj_weight + + @classmethod + def get_source_attention_op(cls) -> OpOverloadPacket: + """Get the source attention op that we target for replacement.""" + return torch.ops.auto_deploy.torch_mla + + @classmethod + def get_cached_attention_op(cls) -> MHACallable: + """Get the cached attention op.""" + return torch.ops.auto_deploy.torch_cached_mla_with_cache.default + + @classmethod + def get_standard_metadata_args(cls) -> List[str]: + """Get the list of standard metadata arguments.""" + return ["batch_info_host", "seq_len", "input_pos", "slot_idx", "cu_seqlen"] + + @classmethod + def get_cache_initializers( + cls, source_attn_node: Node, cache_config: KvCacheConfig + ) -> ResourceHandlerDict: + """Get cache initializers using FlashInfer MLA cache layout.""" + # Extract dimensions from source node args + # torch_mla signature: q_nope, q_pe, compressed_kv, kpe, kv_b_proj_weight, ... + compressed_kv_fake = source_attn_node.args[2].meta["val"] + kpe_fake = source_attn_node.args[3].meta["val"] + + # Get dimensions + # compressed_kv: [B, S, kv_lora_rank] + # kpe: [B, S, 1, qk_rope_head_dim] + kv_lora_rank = compressed_kv_fake.shape[-1] + qk_rope_head_dim = kpe_fake.shape[-1] + + # FlashInfer MLA cache: [max_batch, max_seq, kv_lora_rank + qk_rope_head_dim] + # No num_heads dimension - this is the key MLA optimization + return { + "mla_cache": UnpagedResourceHandler( + kv_lora_rank + qk_rope_head_dim, + dtype=cls.resolve_cache_dtype(cache_config.dtype, compressed_kv_fake.dtype), + ), + } + + @classmethod + def get_constants(cls, source_attn_node: Node) -> List[Constant]: + """Get constants to pass to the cached attention op.""" + # Extract kv_lora_rank for cache slicing + compressed_kv_fake = source_attn_node.args[2].meta["val"] + kv_lora_rank = compressed_kv_fake.shape[-1] + + # Get scale from kwargs + scale = source_attn_node.kwargs.get("scale", None) + + return [scale, kv_lora_rank] diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/mla/torch_mla.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/mla/torch_mla.py new file mode 100644 index 0000000000..5be00a764e --- /dev/null +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/mla/torch_mla.py @@ -0,0 +1,157 @@ +"""Torch reference implementation for Multi-head Latent Attention (MLA). 
+ +This module provides the source op for MLA that: +- Accepts compressed_kv (before kv_b_proj) for FlashInfer-compatible caching +- Expands compressed_kv using kv_b_proj_weight for attention computation +- Computes standard attention with the expanded K, V +""" + +import math +from typing import Optional + +import torch + + +@torch.library.custom_op("auto_deploy::torch_mla", mutates_args=()) +def torch_mla( + q_nope: torch.Tensor, # [B, S, N, qk_nope_head_dim] + q_pe: torch.Tensor, # [B, S, N, qk_rope_head_dim] (RoPE applied) + compressed_kv: torch.Tensor, # [B, S, kv_lora_rank] - BEFORE kv_b_proj + kpe: torch.Tensor, # [B, S, 1, qk_rope_head_dim] (RoPE applied) + kv_b_proj_weight: torch.Tensor, # [num_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank] + is_causal: bool = True, + scale: Optional[float] = None, + layout: str = "bsnd", +) -> torch.Tensor: + """Multi-head Latent Attention (MLA) with FlashInfer-compatible compressed KV. + + This op expands compressed_kv using kv_b_proj_weight and computes attention. + For prefill, this is the standard formulation. For the cached version, + weight absorption is used for efficiency. + + Args: + q_nope: Query non-positional component [B, S, N, qk_nope_head_dim] (bsnd) + q_pe: Query positional component with RoPE applied [B, S, N, qk_rope_head_dim] (bsnd) + compressed_kv: Compressed KV latent [B, S, kv_lora_rank] (before kv_b_proj) + kpe: Key positional encoding with RoPE applied [B, S, 1, qk_rope_head_dim] (bsnd) + kv_b_proj_weight: Projection weights [num_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank] + is_causal: Whether to apply causal masking (default: True) + scale: Softmax scale factor (default: 1/sqrt(qk_head_dim)) + layout: Input/output layout, either "bsnd" or "bnsd" (default: "bsnd") + + Returns: + Attention output with shape [B, S, N, v_head_dim] (bsnd) + """ + if layout not in ("bnsd", "bsnd"): + raise ValueError(f"layout must be 'bnsd' or 'bsnd', got {layout!r}") + + # Get dimensions + if layout == "bsnd": + bs, s_q, num_heads, qk_nope_head_dim = q_nope.shape + qk_rope_head_dim = q_pe.shape[-1] + else: + bs, num_heads, s_q, qk_nope_head_dim = q_nope.shape + qk_rope_head_dim = q_pe.shape[-1] + + s_k = compressed_kv.shape[1] + + # Infer dimensions from kv_b_proj_weight + # kv_b_proj_weight: [num_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank] + out_features = kv_b_proj_weight.shape[0] + kv_head_dim = out_features // num_heads # qk_nope_head_dim + v_head_dim + v_head_dim = kv_head_dim - qk_nope_head_dim + + # Set scale + qk_head_dim = qk_nope_head_dim + qk_rope_head_dim + if scale is None: + scale = 1.0 / math.sqrt(qk_head_dim) + + # ========================================================================= + # Expand compressed_kv using kv_b_proj_weight (this is the prefill path) + # ========================================================================= + # compressed_kv: [B, S, kv_lora_rank] + # kv_b_proj_weight: [num_heads * kv_head_dim, kv_lora_rank] + # kv = compressed_kv @ kv_b_proj_weight.T -> [B, S, num_heads * kv_head_dim] + kv = torch.matmul(compressed_kv, kv_b_proj_weight.t()) + + # Reshape to [B, S, N, kv_head_dim] + kv = kv.view(bs, s_k, num_heads, kv_head_dim) + + # Split into k_nope and value_states + k_nope, value_states = torch.split(kv, [qk_nope_head_dim, v_head_dim], dim=-1) + + # k_nope and value_states are always [B, S, N, D] from the kv reshape above. + # We need them in [B, N, S, D] for attention computation. 
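A minimal shape check of the expansion step above, using small made-up dimensions rather than the real model sizes:

```python
import torch

B, S, N = 2, 4, 3                                        # hypothetical batch, seq, heads
kv_lora_rank, qk_nope_head_dim, v_head_dim = 8, 5, 6     # hypothetical MLA dims

compressed_kv = torch.randn(B, S, kv_lora_rank)
kv_b_proj_weight = torch.randn(N * (qk_nope_head_dim + v_head_dim), kv_lora_rank)

kv = torch.matmul(compressed_kv, kv_b_proj_weight.t())   # [B, S, N * (d_nope + d_v)]
kv = kv.view(B, S, N, qk_nope_head_dim + v_head_dim)     # [B, S, N, d_nope + d_v]
k_nope, v = torch.split(kv, [qk_nope_head_dim, v_head_dim], dim=-1)

assert k_nope.shape == (B, S, N, qk_nope_head_dim)
assert v.shape == (B, S, N, v_head_dim)
```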
+ k_nope = k_nope.transpose(1, 2).contiguous() + value_states = value_states.transpose(1, 2).contiguous() + + # Convert inputs to computation layout [B, N, S, D] if they come in bsnd format + if layout == "bsnd": + # [B, S, N, D] -> [B, N, S, D] + q_nope = q_nope.transpose(1, 2).contiguous() + q_pe = q_pe.transpose(1, 2).contiguous() + kpe = kpe.transpose(1, 2).contiguous() + + # kpe is [B, 1, S, qk_rope_head_dim], expand to num_heads + kpe_expanded = kpe.expand(bs, num_heads, s_k, qk_rope_head_dim) + + # Construct full query and key states + # query_states: [B, N, S, qk_head_dim] + query_states = torch.cat([q_nope, q_pe], dim=-1) + # key_states: [B, N, S, qk_head_dim] + key_states = torch.cat([k_nope, kpe_expanded], dim=-1) + + # Compute attention scores: Q @ K^T + attn_scores = ( + torch.matmul(query_states, key_states.transpose(-2, -1)) * scale + ) # [B, N, s_q, s_k] + + # Apply causal mask if specified + if is_causal and s_q == s_k: + causal_mask = torch.triu( + torch.ones(s_q, s_k, device=q_nope.device, dtype=torch.bool), + diagonal=1, + ) + attn_scores.masked_fill_(causal_mask.unsqueeze(0).unsqueeze(0), float("-inf")) + + # Compute attention weights and output + attn_weights = torch.softmax(attn_scores, dim=-1, dtype=torch.float32).to(q_nope.dtype) + attn_out = torch.matmul(attn_weights, value_states) # [B, N, s_q, v_head_dim] + + # Convert back to requested layout + if layout == "bsnd": + return attn_out.transpose(1, 2).contiguous() # [B, S, N, v_head_dim] + else: + return attn_out.contiguous() # [B, N, S, v_head_dim] + + +@torch_mla.register_fake +def torch_mla_fake( + q_nope: torch.Tensor, + q_pe: torch.Tensor, + compressed_kv: torch.Tensor, + kpe: torch.Tensor, + kv_b_proj_weight: torch.Tensor, + is_causal: bool = True, + scale: Optional[float] = None, + layout: str = "bsnd", +) -> torch.Tensor: + """Fake implementation for torch_mla.""" + # Infer v_head_dim from kv_b_proj_weight + qk_nope_head_dim = q_nope.shape[-1] + num_heads = q_nope.shape[2] if layout == "bsnd" else q_nope.shape[1] + out_features = kv_b_proj_weight.shape[0] + kv_head_dim = out_features // num_heads + v_head_dim = kv_head_dim - qk_nope_head_dim + + # Output shape depends on layout + if layout == "bsnd": + # Input: [B, S, N, D], Output: [B, S, N, v_head_dim] + return q_nope.new_empty( + q_nope.shape[0], q_nope.shape[1], q_nope.shape[2], v_head_dim + ).contiguous() + else: + # Input: [B, N, S, D], Output: [B, N, S, v_head_dim] + return q_nope.new_empty( + q_nope.shape[0], q_nope.shape[1], q_nope.shape[2], v_head_dim + ).contiguous() diff --git a/tensorrt_llm/_torch/auto_deploy/models/custom/__init__.py b/tensorrt_llm/_torch/auto_deploy/models/custom/__init__.py index 4ad9e96cb8..b9fcb1e0f0 100644 --- a/tensorrt_llm/_torch/auto_deploy/models/custom/__init__.py +++ b/tensorrt_llm/_torch/auto_deploy/models/custom/__init__.py @@ -1,9 +1,11 @@ -from .modeling_eagle import Eagle3DrafterForCausalLM +from .modeling_deepseek import DeepSeekV3ForCausalLM +from .modeling_glm4_moe_lite import Glm4MoeLiteForCausalLM from .modeling_nemotron_flash import NemotronFlashForCausalLM, NemotronFlashPreTrainedTokenizerFast from .modeling_nemotron_h import NemotronHForCausalLM __all__ = ( - "Eagle3DrafterForCausalLM", + "DeepSeekV3ForCausalLM", + "Glm4MoeLiteForCausalLM", "NemotronFlashForCausalLM", "NemotronFlashPreTrainedTokenizerFast", "NemotronHForCausalLM", diff --git a/tensorrt_llm/_torch/auto_deploy/models/custom/modeling_deepseek.py b/tensorrt_llm/_torch/auto_deploy/models/custom/modeling_deepseek.py new file mode 100644 
index 0000000000..add0de21fc --- /dev/null +++ b/tensorrt_llm/_torch/auto_deploy/models/custom/modeling_deepseek.py @@ -0,0 +1,655 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + +"""Slimmed down PyTorch DeepSeekV3 model implementation for auto_deploy export. + +Source: +https://huggingface.co/deepseek-ai/DeepSeek-R1/blob/main/modeling_deepseek.py + +This implementation differs from the original in the following ways: +* Simplified for prefill-only inference (no KV caching) +* Uses auto_deploy custom ops for export compatibility +* Removed flash attention variants (uses torch_mla custom op) +* Removed gradient checkpointing and training code paths +* Removed attention dropout (inference only) + +This allows us to have a clean export-ready implementation with auto_deploy custom ops. +""" + +import math +from dataclasses import dataclass +from typing import Optional, Tuple + +import torch +import torch.nn.functional as F +from torch import nn +from transformers.activations import ACT2FN +from transformers.generation import GenerationMixin +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import ModelOutput + +from tensorrt_llm._torch.auto_deploy.models.hf import AutoModelForCausalLMFactory +from tensorrt_llm._torch.utils import ActivationType + + +class DeepSeekV3RMSNorm(nn.Module): + """RMS Normalization for DeepSeekV3.""" + + def __init__(self, hidden_size: int, eps: float = 1e-6): + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + return torch.ops.auto_deploy.triton_rms_norm( + hidden_states, self.weight, self.variance_epsilon + ).to(hidden_states.dtype) + + +class DeepSeekV3RotaryEmbedding(nn.Module): + """Rotary Position Embedding for DeepSeekV3. + + Simplified version that precomputes and caches cos/sin values. + Returns full cached values (not sliced by seq_len) to enable export. 
+ """ + + def __init__(self, dim: int, max_position_embeddings: int = 2048, base: float = 10000.0): + super().__init__() + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float() / self.dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + # Build cos/sin cache + self._set_cos_sin_cache(max_position_embeddings) + + def _set_cos_sin_cache(self, seq_len: int): + self.max_seq_len_cached = seq_len + t = torch.arange(seq_len, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq) + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos(), persistent=False) + self.register_buffer("sin_cached", emb.sin(), persistent=False) + + def forward( + self, x: torch.Tensor, seq_len: Optional[int] = None + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Return full cached cos/sin (not sliced) for export compatibility + return ( + self.cos_cached.to(dtype=x.dtype, device=x.device), + self.sin_cached.to(dtype=x.dtype, device=x.device), + ) + + +class DeepSeekV3YarnRotaryEmbedding(DeepSeekV3RotaryEmbedding): + """YaRN-extended rotary embedding for DeepSeekV3.""" + + def __init__( + self, + dim: int, + max_position_embeddings: int = 2048, + base: float = 10000.0, + scaling_factor: float = 1.0, + original_max_position_embeddings: int = 4096, + beta_fast: int = 32, + beta_slow: int = 1, + mscale: float = 1.0, + mscale_all_dim: float = 0.0, + ): + self.scaling_factor = scaling_factor + self.original_max_position_embeddings = original_max_position_embeddings + self.beta_fast = beta_fast + self.beta_slow = beta_slow + self.mscale = mscale + self.mscale_all_dim = mscale_all_dim + super().__init__(dim, max_position_embeddings, base) + + def _set_cos_sin_cache(self, seq_len: int): + self.max_seq_len_cached = seq_len + dim = self.dim + + freq_extra = 1.0 / (self.base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + freq_inter = 1.0 / ( + self.scaling_factor * self.base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim) + ) + + low, high = self._yarn_find_correction_range( + self.beta_fast, + self.beta_slow, + dim, + self.base, + self.original_max_position_embeddings, + ) + inv_freq_mask = 1.0 - self._yarn_linear_ramp_mask(low, high, dim // 2) + inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask + self.register_buffer("inv_freq", inv_freq, persistent=False) + + t = torch.arange(seq_len, dtype=torch.float32) + freqs = torch.outer(t, inv_freq) + + _mscale = float( + self._yarn_get_mscale(self.scaling_factor, self.mscale) + / self._yarn_get_mscale(self.scaling_factor, self.mscale_all_dim) + ) + + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", (emb.cos() * _mscale), persistent=False) + self.register_buffer("sin_cached", (emb.sin() * _mscale), persistent=False) + + @staticmethod + def _yarn_find_correction_dim( + num_rotations: float, dim: int, base: float = 10000, max_position_embeddings: int = 2048 + ) -> float: + return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / ( + 2 * math.log(base) + ) + + def _yarn_find_correction_range( + self, low_rot: int, high_rot: int, dim: int, base: float, max_position_embeddings: int + ) -> Tuple[int, int]: + low = math.floor( + self._yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings) + ) + high = math.ceil( + self._yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings) + ) + return max(low, 0), 
min(high, dim - 1) + + @staticmethod + def _yarn_get_mscale(scale: float = 1.0, mscale: float = 1.0) -> float: + if scale <= 1: + return 1.0 + return 0.1 * mscale * math.log(scale) + 1.0 + + @staticmethod + def _yarn_linear_ramp_mask(min_val: float, max_val: float, dim: int) -> torch.Tensor: + if min_val == max_val: + max_val += 0.001 + linear_func = (torch.arange(dim, dtype=torch.float32) - min_val) / (max_val - min_val) + return torch.clamp(linear_func, 0, 1) + + +class DeepSeekV3MLP(nn.Module): + """MLP layer for DeepSeekV3 (SwiGLU activation).""" + + def __init__( + self, config, hidden_size: Optional[int] = None, intermediate_size: Optional[int] = None + ): + super().__init__() + self.config = config + self.hidden_size = hidden_size or config.hidden_size + self.intermediate_size = intermediate_size or config.intermediate_size + + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + +class DeepSeekV3MoEGate(nn.Module): + """MoE Gating for DeepSeekV3 with noaux_tc top-k selection.""" + + def __init__(self, config): + super().__init__() + self.config = config + self.top_k = config.num_experts_per_tok + self.n_routed_experts = config.n_routed_experts + self.routed_scaling_factor = config.routed_scaling_factor + self.n_group = config.n_group + self.topk_group = config.topk_group + + self.weight = nn.Parameter( + torch.empty((self.n_routed_experts, config.hidden_size), dtype=torch.float32) + ) + self.register_buffer( + "e_score_correction_bias", + torch.zeros(self.n_routed_experts, dtype=torch.float32), + ) + self.reset_parameters() + + def reset_parameters(self) -> None: + """Initialize gate weights using kaiming uniform (matches original DeepSeek implementation).""" + torch.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) + + def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """Forward pass returning (selected_experts, routing_weights).""" + bsz, seq_len, hidden_dim = hidden_states.shape + hidden_states_flat = hidden_states.view(-1, hidden_dim) + + # Compute router logits + if self.weight.dtype == torch.float32: + router_logits = F.linear(hidden_states_flat.float(), self.weight) + else: + router_logits = torch.ops.trtllm.dsv3_router_gemm_op( + hidden_states_flat, self.weight.t(), bias=None, out_dtype=torch.float32 + ) + + # Use fused noaux_tc_op kernel for top-k selection + topk_weights, topk_indices = torch.ops.trtllm.noaux_tc_op( + router_logits, + self.e_score_correction_bias, + self.n_group, + self.topk_group, + self.top_k, + self.routed_scaling_factor, + ) + + return topk_indices, topk_weights + + +class DeepSeekV3MoE(nn.Module): + """Mixture of Experts layer for DeepSeekV3.""" + + def __init__(self, config): + super().__init__() + self.config = config + self.num_experts_per_tok = config.num_experts_per_tok + + # Routed experts + self.experts = nn.ModuleList( + [ + DeepSeekV3MLP(config, intermediate_size=config.moe_intermediate_size) + for _ in range(config.n_routed_experts) + ] + ) + + # Gate + self.gate = DeepSeekV3MoEGate(config) + + # Shared experts (if configured) + if config.n_shared_experts is not None: + intermediate_size = config.moe_intermediate_size * 
config.n_shared_experts + self.shared_experts = DeepSeekV3MLP(config, intermediate_size=intermediate_size) + else: + self.shared_experts = None + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + identity = hidden_states + orig_shape = hidden_states.shape + + selected_experts, routing_weights = self.gate(hidden_states) + + # Use torch_moe custom op for routed experts + final_hidden_states = torch.ops.auto_deploy.torch_moe( + hidden_states.view(-1, hidden_states.shape[-1]), + selected_experts, + routing_weights, + w1_weight=[expert.gate_proj.weight for expert in self.experts], + w2_weight=[expert.down_proj.weight for expert in self.experts], + w3_weight=[expert.up_proj.weight for expert in self.experts], + is_gated_mlp=True, + act_fn=int(ActivationType.Silu), + ) + + final_hidden_states = final_hidden_states.view(*orig_shape) + + # Add shared experts output if present + if self.shared_experts is not None: + final_hidden_states = final_hidden_states + self.shared_experts(identity) + + return final_hidden_states.to(hidden_states.dtype) + + +class DeepSeekV3Attention(nn.Module): + """Multi-head Latent Attention (MLA) for DeepSeekV3. + + Uses compressed KV representation with latent projections. + """ + + def __init__(self, config, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.q_lora_rank = config.q_lora_rank + self.kv_lora_rank = config.kv_lora_rank + self.qk_nope_head_dim = config.qk_nope_head_dim + self.qk_rope_head_dim = config.qk_rope_head_dim + self.v_head_dim = config.v_head_dim + self.q_head_dim = self.qk_nope_head_dim + self.qk_rope_head_dim + + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + + # Q projection (with optional LoRA) + if self.q_lora_rank is None: + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.q_head_dim, bias=False) + else: + self.q_a_proj = nn.Linear( + self.hidden_size, config.q_lora_rank, bias=config.attention_bias + ) + self.q_a_layernorm = DeepSeekV3RMSNorm(config.q_lora_rank) + self.q_b_proj = nn.Linear( + config.q_lora_rank, self.num_heads * self.q_head_dim, bias=False + ) + + # KV projection with MQA + self.kv_a_proj_with_mqa = nn.Linear( + self.hidden_size, + self.kv_lora_rank + self.qk_rope_head_dim, + bias=config.attention_bias, + ) + self.kv_a_layernorm = DeepSeekV3RMSNorm(self.kv_lora_rank) + self.kv_b_proj = nn.Linear( + self.kv_lora_rank, + self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), + bias=False, + ) + + # Output projection + self.o_proj = nn.Linear( + self.num_heads * self.v_head_dim, self.hidden_size, bias=config.attention_bias + ) + + # Initialize rotary embedding + self._init_rope() + + # Softmax scale + self.softmax_scale = self.q_head_dim ** (-0.5) + if config.rope_scaling is not None: + mscale_all_dim = config.rope_scaling.get("mscale_all_dim", 0) + scaling_factor = config.rope_scaling["factor"] + if mscale_all_dim: + mscale = DeepSeekV3YarnRotaryEmbedding._yarn_get_mscale( + scaling_factor, mscale_all_dim + ) + self.softmax_scale = self.softmax_scale * mscale * mscale + + def _init_rope(self): + if self.config.rope_scaling is None: + self.rotary_emb = DeepSeekV3RotaryEmbedding( + self.qk_rope_head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + else: + scaling_type = self.config.rope_scaling["type"] + scaling_factor = 
self.config.rope_scaling["factor"] + + if scaling_type == "yarn": + kwargs = { + key: self.config.rope_scaling[key] + for key in [ + "original_max_position_embeddings", + "beta_fast", + "beta_slow", + "mscale", + "mscale_all_dim", + ] + if key in self.config.rope_scaling + } + self.rotary_emb = DeepSeekV3YarnRotaryEmbedding( + self.qk_rope_head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + **kwargs, + ) + else: + # Default to base rotary embedding for unsupported types + self.rotary_emb = DeepSeekV3RotaryEmbedding( + self.qk_rope_head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + + def forward( + self, + hidden_states: torch.Tensor, + position_ids: torch.Tensor, + ) -> torch.Tensor: + bsz, q_len, _ = hidden_states.size() + + # Q projection + if self.q_lora_rank is None: + q = self.q_proj(hidden_states) + else: + q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))) + + # Shape: [B, S, N, q_head_dim] (BSND layout) + q = q.view(bsz, q_len, self.num_heads, self.q_head_dim) + q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + + # KV projection - keep compressed form + kv_a_output = self.kv_a_proj_with_mqa(hidden_states) + compressed_kv, k_pe = torch.split( + kv_a_output, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1 + ) + + # Apply layernorm to compressed_kv + compressed_kv = self.kv_a_layernorm(compressed_kv) + + # k_pe: [B, S, 1, qk_rope_head_dim] (BSND layout, shared across heads) + k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim) + + kv_seq_len = q_len + + # Get cos/sin for RoPE + cos, sin = self.rotary_emb(hidden_states, seq_len=kv_seq_len) + cos = cos[position_ids] # [B, S, head_dim] + sin = sin[position_ids] # [B, S, head_dim] + + # Apply RoPE using custom op + q_pe_rotated, kpe = torch.ops.auto_deploy.torch_rope_with_qk_interleaving( + q_pe, + k_pe, + cos, + sin, + 2, # unsqueeze_dim=2 for BSND layout + ) + + # Call MLA with compressed KV + attn_output = torch.ops.auto_deploy.torch_mla( + q_nope, # [B, S, N, qk_nope_head_dim] + q_pe_rotated, # [B, S, N, qk_rope_head_dim] + compressed_kv, # [B, S, kv_lora_rank] + kpe, # [B, S, 1, qk_rope_head_dim] + self.kv_b_proj.weight, # [N*(qk_nope+v), kv_lora_rank] + True, # is_causal + self.softmax_scale, + "bsnd", # layout + ) + + # Output: [B, S, N, v_head_dim] -> [B, S, N * v_head_dim] + attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim) + attn_output = self.o_proj(attn_output) + + return attn_output + + +class DeepSeekV3DecoderLayer(nn.Module): + """Transformer decoder layer for DeepSeekV3.""" + + def __init__(self, config, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + self.layer_idx = layer_idx + + # Attention + self.self_attn = DeepSeekV3Attention(config, layer_idx=layer_idx) + + # MLP or MoE + # MoE layers are used after first_k_dense_replace and at moe_layer_freq intervals + use_moe = ( + config.n_routed_experts is not None + and layer_idx >= config.first_k_dense_replace + and layer_idx % config.moe_layer_freq == 0 + ) + if use_moe: + self.mlp = DeepSeekV3MoE(config) + else: + self.mlp = DeepSeekV3MLP(config) + + # Layer norms + self.input_layernorm = DeepSeekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = DeepSeekV3RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + + def forward( + self, + hidden_states: torch.Tensor, + position_ids: torch.Tensor, + 
) -> torch.Tensor: + # Self attention + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.self_attn(hidden_states, position_ids) + hidden_states = residual + hidden_states + + # MLP/MoE + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +@dataclass +class DeepSeekV3Output(ModelOutput): + """Output for DeepSeekV3Model.""" + + last_hidden_state: Optional[torch.FloatTensor] = None + + +@dataclass +class DeepSeekV3CausalLMOutput(ModelOutput): + """Output for DeepSeekV3ForCausalLM.""" + + logits: Optional[torch.FloatTensor] = None + + +class DeepSeekV3PreTrainedModel(PreTrainedModel): + """Base class for DeepSeekV3 models.""" + + base_model_prefix = "model" + _no_split_modules = ["DeepSeekV3DecoderLayer"] + supports_gradient_checkpointing = False + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +class DeepSeekV3Model(DeepSeekV3PreTrainedModel): + """DeepSeekV3 transformer decoder model.""" + + def __init__(self, config): + super().__init__(config) + self.config = config + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [ + DeepSeekV3DecoderLayer(config, layer_idx=idx) + for idx in range(config.num_hidden_layers) + ] + ) + self.norm = DeepSeekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + **kwargs, + ) -> DeepSeekV3Output: + if input_ids is not None and inputs_embeds is not None: + raise ValueError("Cannot specify both input_ids and inputs_embeds") + elif input_ids is None and inputs_embeds is None: + raise ValueError("Must specify either input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + batch_size, seq_length = inputs_embeds.shape[:2] + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange(seq_length, dtype=torch.long, device=device) + position_ids = position_ids.unsqueeze(0).expand(batch_size, -1) + + hidden_states = inputs_embeds + + for decoder_layer in self.layers: + hidden_states = decoder_layer(hidden_states, position_ids) + + hidden_states = self.norm(hidden_states) + + return DeepSeekV3Output(last_hidden_state=hidden_states) + + +class DeepSeekV3ForCausalLM(DeepSeekV3PreTrainedModel, GenerationMixin): + """DeepSeekV3 model with language modeling head.""" + + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = DeepSeekV3Model(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + 
self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def get_decoder(self): + return self.model + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + **kwargs, + ) -> DeepSeekV3CausalLMOutput: + outputs = self.model( + input_ids=input_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + logits = self.lm_head(hidden_states).float() + + return DeepSeekV3CausalLMOutput(logits=logits) + + +# Register with AutoModelForCausalLMFactory +AutoModelForCausalLMFactory.register_custom_model_cls("DeepseekV3Config", DeepSeekV3ForCausalLM) diff --git a/tensorrt_llm/_torch/auto_deploy/models/custom/modeling_glm4_moe_lite.py b/tensorrt_llm/_torch/auto_deploy/models/custom/modeling_glm4_moe_lite.py new file mode 100644 index 0000000000..17aec04940 --- /dev/null +++ b/tensorrt_llm/_torch/auto_deploy/models/custom/modeling_glm4_moe_lite.py @@ -0,0 +1,830 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + +"""Slimmed down PyTorch GLM4 MoE Lite model implementation for auto_deploy export. + +Source: +https://huggingface.co/zai-org/GLM-4.7-Flash + +This implementation differs from the original HuggingFace version in the following ways: +* Bundled config class to work with transformers v4.57 (model requires v5.0) +* Simplified for prefill-only inference (no KV caching) +* Uses auto_deploy custom ops for export compatibility +* Removed flash attention variants (uses torch_mla custom op) +* Removed gradient checkpointing and training code paths +* Removed attention dropout (inference only) + +The GLM4 MoE Lite model uses Multi-head Latent Attention (MLA), similar to DeepSeek V3. +""" + +import math +from dataclasses import dataclass +from typing import Optional, Tuple + +import torch +import torch.nn.functional as F +from torch import nn +from transformers import AutoConfig, PretrainedConfig +from transformers.activations import ACT2FN +from transformers.generation import GenerationMixin +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import ModelOutput + +from tensorrt_llm._torch.auto_deploy.models.hf import AutoModelForCausalLMFactory +from tensorrt_llm._torch.utils import ActivationType + + +class Glm4MoeLiteConfig(PretrainedConfig): + """Configuration class for GLM4 MoE Lite model. + + This config class is bundled with the custom model implementation to enable + loading on transformers v4.57 (the model requires v5.0 where the config is + natively registered). 
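+
+    Illustrative usage (not from the upstream model card; the defaults are those
+    defined in __init__ below):
+
+        config = Glm4MoeLiteConfig()
+        config.qk_head_dim  # derived as qk_nope_head_dim + qk_rope_head_dim = 192 + 64 = 256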
+ """ + + model_type = "glm4_moe_lite" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size: int = 154880, + hidden_size: int = 2048, + intermediate_size: int = 10240, + num_hidden_layers: int = 47, + num_attention_heads: int = 20, + num_key_value_heads: int = 20, + hidden_act: str = "silu", + max_position_embeddings: int = 202752, + rms_norm_eps: float = 1e-5, + # MLA parameters + q_lora_rank: int = 768, + kv_lora_rank: int = 512, + qk_nope_head_dim: int = 192, + qk_rope_head_dim: int = 64, + v_head_dim: int = 256, + # MoE parameters + n_routed_experts: int = 64, + n_shared_experts: int = 1, + num_experts_per_tok: int = 4, + moe_intermediate_size: int = 1536, + n_group: int = 1, + topk_group: int = 1, + routed_scaling_factor: float = 1.8, + norm_topk_prob: bool = True, + first_k_dense_replace: int = 1, + # RoPE parameters + rope_theta: float = 1000000.0, + rope_scaling: Optional[dict] = None, + # Other parameters + attention_bias: bool = False, + attention_dropout: float = 0.0, + tie_word_embeddings: bool = False, + pad_token_id: int = 154820, + initializer_range: float = 0.02, + **kwargs, + ): + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.max_position_embeddings = max_position_embeddings + self.rms_norm_eps = rms_norm_eps + # MLA + self.q_lora_rank = q_lora_rank + self.kv_lora_rank = kv_lora_rank + self.qk_nope_head_dim = qk_nope_head_dim + self.qk_rope_head_dim = qk_rope_head_dim + self.v_head_dim = v_head_dim + self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim + # MoE + self.n_routed_experts = n_routed_experts + self.n_shared_experts = n_shared_experts + self.num_experts_per_tok = num_experts_per_tok + self.moe_intermediate_size = moe_intermediate_size + self.n_group = n_group + self.topk_group = topk_group + self.routed_scaling_factor = routed_scaling_factor + self.norm_topk_prob = norm_topk_prob + self.first_k_dense_replace = first_k_dense_replace + # RoPE + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + # Other + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.initializer_range = initializer_range + + super().__init__( + pad_token_id=pad_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + +# Register config with transformers' AutoConfig so it can be loaded from HF hub +# Use exist_ok=True to handle cases where transformers already has this model type registered +# (e.g., transformers v5.0+). In those cases, AutoConfig will use the built-in config, +# but AutoModelForCausalLMFactory will still use our custom model implementation. +try: + AutoConfig.register("glm4_moe_lite", Glm4MoeLiteConfig, exist_ok=True) +except TypeError: + # Older transformers versions don't support exist_ok parameter + try: + AutoConfig.register("glm4_moe_lite", Glm4MoeLiteConfig) + except ValueError: + # Already registered by transformers, that's fine + pass + + +class Glm4MoeLiteRMSNorm(nn.Module): + """RMS Normalization for GLM4 MoE Lite. + + Uses standard torch operations so AD fusion passes can replace with + the appropriate backend (flashinfer/triton) based on config. 
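+
+    Concretely, the forward pass below computes, in float32 and then cast back to the
+    input dtype:
+
+        y = weight * x / sqrt(mean(x**2, dim=-1, keepdim=True) + eps)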
+ """ + + def __init__(self, hidden_size: int, eps: float = 1e-6): + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + +class Glm4MoeLiteRotaryEmbedding(nn.Module): + """Rotary Position Embedding for GLM4 MoE Lite. + + Simplified version that precomputes and caches cos/sin values. + Returns full cached values (not sliced by seq_len) to enable export. + + Uses _ad_ prefix for buffer names to work with AutoDeploy's lift_to_meta. + """ + + def __init__( + self, + dim: int, + max_position_embeddings: int = 2048, + base: float = 10000.0, + attention_scaling: float = 1.0, + ): + super().__init__() + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + self.attention_scaling = attention_scaling + + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float() / self.dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + # Build cos/sin cache with AD-specific naming + self._set_cos_sin_cache(max_position_embeddings) + + def _set_cos_sin_cache(self, seq_len: int): + self.max_seq_len_cached = seq_len + t = torch.arange(seq_len, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq) + emb = torch.cat((freqs, freqs), dim=-1) + # Use _ad_ prefix for AutoDeploy compatibility with lift_to_meta + self.register_buffer("_ad_cos_cached", emb.cos() * self.attention_scaling, persistent=False) + self.register_buffer("_ad_sin_cached", emb.sin() * self.attention_scaling, persistent=False) + + def forward( + self, x: torch.Tensor, seq_len: Optional[int] = None + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Return full cached cos/sin (not sliced) for export compatibility + return ( + self._ad_cos_cached.to(dtype=x.dtype, device=x.device), + self._ad_sin_cached.to(dtype=x.dtype, device=x.device), + ) + + +class Glm4MoeLiteYarnRotaryEmbedding(Glm4MoeLiteRotaryEmbedding): + """YaRN-extended rotary embedding for GLM4 MoE Lite.""" + + def __init__( + self, + dim: int, + max_position_embeddings: int = 2048, + base: float = 10000.0, + scaling_factor: float = 1.0, + original_max_position_embeddings: int = 4096, + beta_fast: int = 32, + beta_slow: int = 1, + mscale: float = 1.0, + mscale_all_dim: float = 0.0, + attention_scaling: float = 1.0, + ): + self.scaling_factor = scaling_factor + self.original_max_position_embeddings = original_max_position_embeddings + self.beta_fast = beta_fast + self.beta_slow = beta_slow + self.mscale = mscale + self.mscale_all_dim = mscale_all_dim + super().__init__(dim, max_position_embeddings, base, attention_scaling) + + def _set_cos_sin_cache(self, seq_len: int): + self.max_seq_len_cached = seq_len + dim = self.dim + + freq_extra = 1.0 / (self.base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + freq_inter = 1.0 / ( + self.scaling_factor * self.base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim) + ) + + low, high = self._yarn_find_correction_range( + self.beta_fast, + self.beta_slow, + dim, + self.base, + self.original_max_position_embeddings, + ) + inv_freq_mask = 1.0 - self._yarn_linear_ramp_mask(low, high, dim // 2) + inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask + 
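+        # The blend above follows YaRN: dims below the `low` correction index
+        # (inv_freq_mask == 1, i.e. high-frequency dims) keep the original
+        # frequencies (freq_extra), dims above `high` use the interpolated
+        # frequencies (freq_inter), and dims in between ramp linearly between the two.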
self.register_buffer("inv_freq", inv_freq, persistent=False) + + t = torch.arange(seq_len, dtype=torch.float32) + freqs = torch.outer(t, inv_freq) + + _mscale = float( + self._yarn_get_mscale(self.scaling_factor, self.mscale) + / self._yarn_get_mscale(self.scaling_factor, self.mscale_all_dim) + ) + + emb = torch.cat((freqs, freqs), dim=-1) + # Use _ad_ prefix for AutoDeploy compatibility with lift_to_meta + # Note: attention_scaling is already incorporated in _mscale for YaRN + self.register_buffer("_ad_cos_cached", (emb.cos() * _mscale), persistent=False) + self.register_buffer("_ad_sin_cached", (emb.sin() * _mscale), persistent=False) + + @staticmethod + def _yarn_find_correction_dim( + num_rotations: float, dim: int, base: float = 10000, max_position_embeddings: int = 2048 + ) -> float: + return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / ( + 2 * math.log(base) + ) + + def _yarn_find_correction_range( + self, low_rot: int, high_rot: int, dim: int, base: float, max_position_embeddings: int + ) -> Tuple[int, int]: + low = math.floor( + self._yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings) + ) + high = math.ceil( + self._yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings) + ) + return max(low, 0), min(high, dim - 1) + + @staticmethod + def _yarn_get_mscale(scale: float = 1.0, mscale: float = 1.0) -> float: + if scale <= 1: + return 1.0 + return 0.1 * mscale * math.log(scale) + 1.0 + + @staticmethod + def _yarn_linear_ramp_mask(min_val: float, max_val: float, dim: int) -> torch.Tensor: + if min_val == max_val: + max_val += 0.001 + linear_func = (torch.arange(dim, dtype=torch.float32) - min_val) / (max_val - min_val) + return torch.clamp(linear_func, 0, 1) + + +class Glm4MoeLiteMLP(nn.Module): + """MLP layer for GLM4 MoE Lite (SwiGLU activation).""" + + def __init__( + self, config, hidden_size: Optional[int] = None, intermediate_size: Optional[int] = None + ): + super().__init__() + self.config = config + self.hidden_size = hidden_size or config.hidden_size + self.intermediate_size = intermediate_size or config.intermediate_size + + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + +class Glm4MoeLiteMoEGate(nn.Module): + """MoE Gating for GLM4 MoE Lite with top-k selection. + + Uses fused TensorRT-LLM custom ops for efficient routing: + - dsv3_router_gemm_op: Fused router GEMM for non-float32 weights + - noaux_tc_op: Fused sigmoid + bias + group top-k + normalize + scale + """ + + def __init__(self, config): + super().__init__() + self.config = config + self.top_k = config.num_experts_per_tok + self.n_routed_experts = config.n_routed_experts + self.routed_scaling_factor = config.routed_scaling_factor + self.n_group = config.n_group + self.topk_group = config.topk_group + self.norm_topk_prob = getattr(config, "norm_topk_prob", True) + + # noaux_tc_op always normalizes, so norm_topk_prob must be True + if not self.norm_topk_prob: + raise ValueError( + "Glm4MoeLiteMoEGate requires norm_topk_prob=True when using fused ops. " + "The noaux_tc_op kernel always normalizes routing weights." 
+ ) + + self.weight = nn.Parameter( + torch.empty((self.n_routed_experts, config.hidden_size), dtype=torch.float32) + ) + self.register_buffer( + "e_score_correction_bias", + torch.zeros(self.n_routed_experts, dtype=torch.float32), + ) + self.reset_parameters() + + def reset_parameters(self) -> None: + """Initialize gate weights using kaiming uniform.""" + torch.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) + + def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """Forward pass returning (selected_experts, routing_weights). + + Uses fused TensorRT-LLM ops for efficient routing: + 1. dsv3_router_gemm_op: Router GEMM (when weights are not float32) + 2. noaux_tc_op: Fused sigmoid + bias + group top-k + normalize + scale + """ + bsz, seq_len, hidden_dim = hidden_states.shape + hidden_states_flat = hidden_states.view(-1, hidden_dim) + + # Router GEMM - use fused op when weights are not float32 + if self.weight.dtype == torch.float32: + router_logits = F.linear(hidden_states_flat.float(), self.weight) + else: + router_logits = torch.ops.trtllm.dsv3_router_gemm_op( + hidden_states_flat, self.weight.t(), bias=None, out_dtype=torch.float32 + ) + + # Fused routing: sigmoid + bias + group top-k + normalize + scale + # noaux_tc_op internally applies: + # 1. Sigmoid to router_logits + # 2. Adds e_score_correction_bias + # 3. Group-wise top-2 scoring and top group selection + # 4. Top-k expert selection from selected groups + # 5. Gathers weights from sigmoid scores + # 6. Normalizes and scales by routed_scaling_factor + topk_weights, topk_indices = torch.ops.trtllm.noaux_tc_op( + router_logits, + self.e_score_correction_bias, + self.n_group, + self.topk_group, + self.top_k, + self.routed_scaling_factor, + ) + + return topk_indices, topk_weights + + +class Glm4MoeLiteMoE(nn.Module): + """Mixture of Experts layer for GLM4 MoE Lite.""" + + def __init__(self, config): + super().__init__() + self.config = config + self.num_experts_per_tok = config.num_experts_per_tok + + # Routed experts - use ModuleList with individual expert modules + # This creates state_dict keys like: experts.0.gate_proj.weight + # which matches the checkpoint structure + self.experts = nn.ModuleList( + [ + Glm4MoeLiteMLP(config, intermediate_size=config.moe_intermediate_size) + for _ in range(config.n_routed_experts) + ] + ) + + # Gate + self.gate = Glm4MoeLiteMoEGate(config) + + # Shared experts + if config.n_shared_experts is not None and config.n_shared_experts > 0: + intermediate_size = config.moe_intermediate_size * config.n_shared_experts + self.shared_experts = Glm4MoeLiteMLP(config, intermediate_size=intermediate_size) + else: + self.shared_experts = None + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + identity = hidden_states + orig_shape = hidden_states.shape + + selected_experts, routing_weights = self.gate(hidden_states) + + # Use torch_moe custom op for routed experts + final_hidden_states = torch.ops.auto_deploy.torch_moe( + hidden_states.view(-1, hidden_states.shape[-1]), + selected_experts, + routing_weights, + w1_weight=[expert.gate_proj.weight for expert in self.experts], + w2_weight=[expert.down_proj.weight for expert in self.experts], + w3_weight=[expert.up_proj.weight for expert in self.experts], + is_gated_mlp=True, + act_fn=int(ActivationType.Silu), + ) + + final_hidden_states = final_hidden_states.view(*orig_shape) + + # Add shared experts output if present + if self.shared_experts is not None: + final_hidden_states = final_hidden_states + 
self.shared_experts(identity) + + return final_hidden_states.to(hidden_states.dtype) + + +class Glm4MoeLiteAttention(nn.Module): + """Multi-head Latent Attention (MLA) for GLM4 MoE Lite. + + Uses compressed KV representation with latent projections. + Receives position embeddings from the model level (shared rotary embedding). + """ + + def __init__(self, config, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.q_lora_rank = config.q_lora_rank + self.kv_lora_rank = config.kv_lora_rank + self.qk_nope_head_dim = config.qk_nope_head_dim + self.qk_rope_head_dim = config.qk_rope_head_dim + self.v_head_dim = config.v_head_dim + self.qk_head_dim = self.qk_nope_head_dim + self.qk_rope_head_dim + + # Q projection (with optional LoRA) + if self.q_lora_rank is None: + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.qk_head_dim, bias=False) + else: + self.q_a_proj = nn.Linear( + self.hidden_size, config.q_lora_rank, bias=config.attention_bias + ) + self.q_a_layernorm = Glm4MoeLiteRMSNorm(config.q_lora_rank) + self.q_b_proj = nn.Linear( + config.q_lora_rank, self.num_heads * self.qk_head_dim, bias=False + ) + + # KV projection with MQA + self.kv_a_proj_with_mqa = nn.Linear( + self.hidden_size, + self.kv_lora_rank + self.qk_rope_head_dim, + bias=config.attention_bias, + ) + self.kv_a_layernorm = Glm4MoeLiteRMSNorm(self.kv_lora_rank) + self.kv_b_proj = nn.Linear( + self.kv_lora_rank, + self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), + bias=False, + ) + + # Output projection + self.o_proj = nn.Linear( + self.num_heads * self.v_head_dim, self.hidden_size, bias=config.attention_bias + ) + + # Softmax scale + self.softmax_scale = self.qk_head_dim ** (-0.5) + # Apply mscale adjustment if using YaRN scaling with factor + if ( + config.rope_scaling is not None + and isinstance(config.rope_scaling, dict) + and "factor" in config.rope_scaling + ): + mscale_all_dim = config.rope_scaling.get("mscale_all_dim", 0) + scaling_factor = config.rope_scaling["factor"] + if mscale_all_dim: + mscale = Glm4MoeLiteYarnRotaryEmbedding._yarn_get_mscale( + scaling_factor, mscale_all_dim + ) + self.softmax_scale = self.softmax_scale * mscale * mscale + + def forward( + self, + hidden_states: torch.Tensor, + position_ids: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + ) -> torch.Tensor: + bsz, q_len, _ = hidden_states.size() + + # Q projection + if self.q_lora_rank is None: + q = self.q_proj(hidden_states) + else: + q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))) + + # Shape: [B, S, N, qk_head_dim] (BSND layout) + q = q.view(bsz, q_len, self.num_heads, self.qk_head_dim) + q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + + # KV projection - keep compressed form + kv_a_output = self.kv_a_proj_with_mqa(hidden_states) + compressed_kv, k_pe = torch.split( + kv_a_output, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1 + ) + + # Apply layernorm to compressed_kv + compressed_kv = self.kv_a_layernorm(compressed_kv) + + # k_pe: [B, S, 1, qk_rope_head_dim] (BSND layout, shared across heads) + k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim) + + # Get cos/sin from position_embeddings (full cached from shared rotary embedding) + cos, sin = position_embeddings # Full table: [max_seq_len, head_dim] + cos = cos[position_ids] # [B, S, head_dim] + sin = sin[position_ids] # [B, 
S, head_dim] + + # Apply RoPE using custom op + q_pe_rotated, kpe = torch.ops.auto_deploy.torch_rope_with_qk_interleaving( + q_pe, + k_pe, + cos, + sin, + 2, # unsqueeze_dim=2 for BSND layout + ) + + # Call MLA with compressed KV + attn_output = torch.ops.auto_deploy.torch_mla( + q_nope, # [B, S, N, qk_nope_head_dim] + q_pe_rotated, # [B, S, N, qk_rope_head_dim] + compressed_kv, # [B, S, kv_lora_rank] + kpe, # [B, S, 1, qk_rope_head_dim] + self.kv_b_proj.weight, # [N*(qk_nope+v), kv_lora_rank] + True, # is_causal + self.softmax_scale, + "bsnd", # layout + ) + + # Output: [B, S, N, v_head_dim] -> [B, S, N * v_head_dim] + attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim) + attn_output = self.o_proj(attn_output) + + return attn_output + + +class Glm4MoeLiteDecoderLayer(nn.Module): + """Transformer decoder layer for GLM4 MoE Lite.""" + + def __init__(self, config, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + self.layer_idx = layer_idx + + # Attention + self.self_attn = Glm4MoeLiteAttention(config, layer_idx=layer_idx) + + # MLP or MoE + # Layer 0 to first_k_dense_replace-1 use dense MLP, rest use MoE + use_moe = config.n_routed_experts is not None and layer_idx >= config.first_k_dense_replace + if use_moe: + self.mlp = Glm4MoeLiteMoE(config) + else: + self.mlp = Glm4MoeLiteMLP(config) + + # Layer norms + self.input_layernorm = Glm4MoeLiteRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Glm4MoeLiteRMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + + def forward( + self, + hidden_states: torch.Tensor, + position_ids: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + ) -> torch.Tensor: + # Self attention + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.self_attn(hidden_states, position_ids, position_embeddings) + hidden_states = residual + hidden_states + + # MLP/MoE + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +@dataclass +class Glm4MoeLiteOutput(ModelOutput): + """Output for Glm4MoeLiteModel.""" + + last_hidden_state: Optional[torch.FloatTensor] = None + + +@dataclass +class Glm4MoeLiteCausalLMOutput(ModelOutput): + """Output for Glm4MoeLiteForCausalLM.""" + + logits: Optional[torch.FloatTensor] = None + + +class Glm4MoeLitePreTrainedModel(PreTrainedModel): + """Base class for GLM4 MoE Lite models.""" + + config_class = Glm4MoeLiteConfig + base_model_prefix = "model" + _no_split_modules = ["Glm4MoeLiteDecoderLayer"] + supports_gradient_checkpointing = False + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +class Glm4MoeLiteModel(Glm4MoeLitePreTrainedModel): + """GLM4 MoE Lite transformer decoder model.""" + + def __init__(self, config): + super().__init__(config) + self.config = config + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [ + 
Glm4MoeLiteDecoderLayer(config, layer_idx=idx) + for idx in range(config.num_hidden_layers) + ] + ) + self.norm = Glm4MoeLiteRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + # Shared rotary embedding at model level (not per-layer) + # This creates a single set of cos/sin buffers for all layers + self.rotary_emb = self._init_rope(config) + + self.post_init() + + def _init_rope(self, config): + """Initialize shared rotary embedding for all layers.""" + qk_rope_head_dim = config.qk_rope_head_dim + + # Compute attention_scaling for RoPE (same logic as in attention) + attention_scaling = 1.0 + if ( + config.rope_scaling is not None + and isinstance(config.rope_scaling, dict) + and "factor" in config.rope_scaling + ): + mscale_all_dim = config.rope_scaling.get("mscale_all_dim", 0) + scaling_factor = config.rope_scaling["factor"] + if mscale_all_dim: + mscale = Glm4MoeLiteYarnRotaryEmbedding._yarn_get_mscale( + scaling_factor, mscale_all_dim + ) + attention_scaling = mscale + + # Check if rope_scaling is None, empty, or missing required "factor" key + use_yarn = ( + config.rope_scaling is not None + and isinstance(config.rope_scaling, dict) + and "factor" in config.rope_scaling + ) + + if not use_yarn: + return Glm4MoeLiteRotaryEmbedding( + qk_rope_head_dim, + max_position_embeddings=config.max_position_embeddings, + base=config.rope_theta, + attention_scaling=attention_scaling, + ) + else: + scaling_factor = config.rope_scaling["factor"] + kwargs = { + key: config.rope_scaling[key] + for key in [ + "original_max_position_embeddings", + "beta_fast", + "beta_slow", + "mscale", + "mscale_all_dim", + ] + if key in config.rope_scaling + } + return Glm4MoeLiteYarnRotaryEmbedding( + qk_rope_head_dim, + max_position_embeddings=config.max_position_embeddings, + scaling_factor=scaling_factor, + base=config.rope_theta, + attention_scaling=attention_scaling, + **kwargs, + ) + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + **kwargs, + ) -> Glm4MoeLiteOutput: + if input_ids is not None and inputs_embeds is not None: + raise ValueError("Cannot specify both input_ids and inputs_embeds") + elif input_ids is None and inputs_embeds is None: + raise ValueError("Must specify either input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + batch_size, seq_length = inputs_embeds.shape[:2] + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange(seq_length, dtype=torch.long, device=device) + position_ids = position_ids.unsqueeze(0).expand(batch_size, -1) + + # Compute position embeddings once from shared rotary embedding + # This returns full cached cos/sin tables + position_embeddings = self.rotary_emb(inputs_embeds) + + hidden_states = inputs_embeds + + for decoder_layer in self.layers: + hidden_states = decoder_layer(hidden_states, position_ids, position_embeddings) + + hidden_states = self.norm(hidden_states) + + return Glm4MoeLiteOutput(last_hidden_state=hidden_states) + + +class Glm4MoeLiteForCausalLM(Glm4MoeLitePreTrainedModel, GenerationMixin): + """GLM4 MoE Lite model with language modeling head.""" + + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config, **kwargs): + super().__init__(config) + 
self.model = Glm4MoeLiteModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def get_decoder(self): + return self.model + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + **kwargs, + ) -> Glm4MoeLiteCausalLMOutput: + outputs = self.model( + input_ids=input_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + logits = self.lm_head(hidden_states).float() + + return Glm4MoeLiteCausalLMOutput(logits=logits) + + +# Register with AutoModelForCausalLMFactory +AutoModelForCausalLMFactory.register_custom_model_cls("Glm4MoeLiteConfig", Glm4MoeLiteForCausalLM) diff --git a/tensorrt_llm/_torch/auto_deploy/models/patches/deepseek.py b/tensorrt_llm/_torch/auto_deploy/models/patches/deepseek.py deleted file mode 100644 index f30bc0c6fa..0000000000 --- a/tensorrt_llm/_torch/auto_deploy/models/patches/deepseek.py +++ /dev/null @@ -1,185 +0,0 @@ -import types -import warnings -from typing import Dict, Optional - -import torch -import torch.utils.checkpoint -from transformers import AutoModelForCausalLM -from transformers.cache_utils import Cache - - -def deepseek_v3_attention( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.IntTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - **kwargs, -): - """DeepSeekV3Attention forward function rewritten to wrap MultiheadLatentAttention as a custom op.""" - if "padding_mask" in kwargs: - warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. " - "Please make sure use `attention_mask` instead.`" - ) - bsz, q_len, _ = hidden_states.size() - - # If else paths are determined by config.json - if self.q_lora_rank is None: - q = self.q_proj(hidden_states) - else: - q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))) - q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2) - q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) - - compressed_kv = self.kv_a_proj_with_mqa(hidden_states) - compressed_kv, k_pe = torch.split( - compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1 - ) - k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2) - kv = ( - self.kv_b_proj(self.kv_a_layernorm(compressed_kv)) - .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) - .transpose(1, 2) - ) - _, value_states = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1) - kv_seq_len = value_states.shape[-2] - if past_key_value is not None: - raise ValueError("past_key_value is not supported") - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - - # Use custom op to capture mla. This does not handle KV cache - # as passing transformers Cache into a custom op is throwing an error. 
- # Would not be an issue, cause we intend to replace mla op with our implementation further along the pipeline - attn_output = torch.ops.auto_deploy.torch_attention_deepseek_fused_mla( - q_nope, - q_pe, - kv, - k_pe, - cos, - sin, - position_ids, - attention_mask, - self.softmax_scale, - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim) - attn_output = self.o_proj(attn_output) - - return attn_output, None, past_key_value - - -# This patched module matches exactly with HF generate -@torch.inference_mode() -def deepseek_v3_moe_exact(self, hidden_states): - """DeepSeekV3MoE forward function rewritten to enable torch export. - - This custom implementation matches exactly with the deepseek implementation. There are - some errors in the output tensors when the index_add based implementation is used, leading - to some mismatch in the outputs for some prompts. This ensures exact match between HF output - without custom patch and with custom patch. - """ - identity = hidden_states - batch_size, sequence_length, hidden_dim = hidden_states.shape - - selected_experts, routing_weights, *_ = self.gate(hidden_states) - - hidden_states = hidden_states.view(-1, hidden_dim) - idxs = torch.argsort(selected_experts.view(-1), stable=True) - - expert_mask = torch.nn.functional.one_hot( - selected_experts, num_classes=self.experts_per_rank - ).permute(2, 1, 0) - outputs = [] - for expert_idx in range(len(self.experts)): - expert_layer = self.experts[expert_idx] - _, top_x = torch.where(expert_mask[expert_idx]) - # Sort the top_xs and idx - sorted, _ = torch.sort(top_x) - tokens_for_this_expert = hidden_states[None, sorted].reshape(-1, hidden_dim) - expert_out = expert_layer(tokens_for_this_expert) - outputs.append(expert_out) - - outs = torch.cat(outputs, dim=0) - # Wrap torch.zeros() in a custom op to fix meta device issue during inference. - new_x = torch.zeros( - (*selected_experts.view(-1).shape, hidden_dim), - device=selected_experts.device, - dtype=outs.dtype, - ) - new_x[idxs] = outs - final_hidden_states = ( - new_x.view(*selected_experts.shape, -1) - .type(routing_weights.dtype) - .mul_(routing_weights.unsqueeze(-1)) - .sum(dim=1) - .type(new_x.dtype) - ) - final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) - - if self.config.n_shared_experts is not None: - final_hidden_states = final_hidden_states + self.shared_experts(identity) - - return final_hidden_states.to(hidden_states.dtype) - - -@torch.inference_mode() -def deepseek_v3_moe(self, hidden_states): - """DeepSeekV3MoE forward function rewritten in Mixtral style to enable torch export.""" - - selected_experts, routing_weights, *_ = self.gate(hidden_states) - final_hidden_states = torch.ops.auto_deploy.torch_moe( - hidden_states, - selected_experts, - routing_weights, - w1_weight=[expert.gate_proj.weight for expert in self.experts], - w2_weight=[expert.down_proj.weight for expert in self.experts], - w3_weight=[expert.up_proj.weight for expert in self.experts], - ) - - if self.config.n_shared_experts is not None: - final_hidden_states = final_hidden_states + self.shared_experts(hidden_states) - - return final_hidden_states.to(hidden_states.dtype) - - -def deepseek_v3_rope(self, x, seq_len=None): - """DeepSeekV3 Rotary Embedding forward function rewritten to enable torch export. 
- We return the full cached cos and sin values, instead of slicing them based on seq_len as this - would cause an issue during the generate phase (when seq_len=1 from input_ids). We also move the cos - sin buffers to appropriate device to enable export. - """ - - return ( - self.cos_cached.to(dtype=x.dtype).to(device=x.device), - self.sin_cached.to(dtype=x.dtype).to(device=x.device), - ) - - -_from_config_original = AutoModelForCausalLM.from_config - -CUSTOM_MODULE_PATCHES: Dict[str, callable] = { - "DeepseekV3MoE": deepseek_v3_moe, - "DeepseekV2MoE": deepseek_v3_moe, - "DeepseekV3RotaryEmbedding": deepseek_v3_rope, - "DeepseekV3YarnRotaryEmbedding": deepseek_v3_rope, - "DeepseekV2RotaryEmbedding": deepseek_v3_rope, - "DeepseekV2YarnRotaryEmbedding": deepseek_v3_rope, -} - - -def get_model_from_config_patched(config, **kwargs): - model = _from_config_original(config, **kwargs) - # Patch modules - for _, module in model.named_modules(): - if type(module).__name__ in CUSTOM_MODULE_PATCHES.keys(): - module.forward = types.MethodType(CUSTOM_MODULE_PATCHES[type(module).__name__], module) - - return model - - -# TODO: figure out how this can be incorporated into the export patch system -AutoModelForCausalLM.from_config = get_model_from_config_patched diff --git a/tensorrt_llm/_torch/auto_deploy/models/quant_config_reader.py b/tensorrt_llm/_torch/auto_deploy/models/quant_config_reader.py index 5ec9d69627..898b46c769 100644 --- a/tensorrt_llm/_torch/auto_deploy/models/quant_config_reader.py +++ b/tensorrt_llm/_torch/auto_deploy/models/quant_config_reader.py @@ -83,7 +83,7 @@ class QuantConfigReaderRegistry: @QuantConfigReaderRegistry.register("modelopt") class ModelOPTQuantConfigReader(QuantConfigReader): - _ALWAYS_EXCLUDE = ("lm_head", "model.embed_tokens", "*.mixer.gate*") + _ALWAYS_EXCLUDE = ("lm_head", "model.embed_tokens", "*.mixer.gate*", "*.mlp.gate") def read_config(self, config: Dict) -> Dict: producer = config.get("producer", {}).get("name") diff --git a/tests/integration/defs/accuracy/test_llm_api_autodeploy.py b/tests/integration/defs/accuracy/test_llm_api_autodeploy.py index 3ee8eb2d62..905880b06c 100644 --- a/tests/integration/defs/accuracy/test_llm_api_autodeploy.py +++ b/tests/integration/defs/accuracy/test_llm_api_autodeploy.py @@ -411,3 +411,75 @@ class TestNemotronSuperV3(LlmapiAccuracyTestHarness): task.evaluate(llm, sampling_params=sampling_params) task = GSM8K(self.MODEL_NAME) task.evaluate(llm) + + +class TestGLM4Flash(LlmapiAccuracyTestHarness): + """Accuracy regression tests for GLM-4.7-Flash. 
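+
+    Loads the model through the AutoDeploy backend and checks MMLU and GSM8K accuracy.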
+ + TODO: enable in CI, see https://github.com/NVIDIA/TensorRT-LLM/issues/11117 + + In the meantime, you should run this test locally: + + ``` + cd tests/integration/defs + TRTLLM_ACCURACY_NO_REFERENCE=1 pytest -svv "accuracy/test_llm_api_autodeploy.py::TestGLM4Flash::test_auto_dtype[True]" + ``` + """ + + MODEL_NAME = "zai-org/GLM-4.7-Flash" + MODEL_PATH = MODEL_NAME # Model is in HF_CACHE + # Set minimum possible seq len + small buffer, for test speed & memory usage + MAX_SEQ_LEN = max(MMLU.MAX_INPUT_LEN + MMLU.MAX_OUTPUT_LEN, + GSM8K.MAX_INPUT_LEN + GSM8K.MAX_OUTPUT_LEN) + MAX_NUM_TOKENS = MAX_SEQ_LEN + + def get_default_kwargs(self, enable_chunked_prefill=False): + config = { + "skip_tokenizer_init": False, + "trust_remote_code": True, + "compile_backend": "torch-cudagraph", + "max_batch_size": 128, + "max_seq_len": self.MAX_SEQ_LEN, + "max_num_tokens": self.MAX_NUM_TOKENS, + "skip_loading_weights": False, + "disable_overlap_scheduler": False, + "cuda_graph_batch_sizes": [1, 2, 4, 8, 16, 32, 64, 128], + "kv_cache_config": { + "enable_block_reuse": False, + "free_gpu_memory_fraction": 0.88 + }, + "model_kwargs": { + "torch_dtype": "bfloat16" + }, + "transforms": { + "fuse_nvfp4_moe": { + "allow_different_input_scales": True, + }, + }, + } + if enable_chunked_prefill: + config["enable_chunked_prefill"] = True + config[ + "max_num_tokens"] = 512 # NOTE: must be > max(tokens_per_block, max_batch_size) + return config + + def get_default_sampling_params(self): + eos_id = -1 + beam_width = 1 + return SamplingParams(end_id=eos_id, + pad_id=eos_id, + n=beam_width, + use_beam_search=beam_width > 1) + + @pytest.mark.skip_less_device_memory(32000) + @pytest.mark.parametrize("enable_chunked_prefill", [True, False]) + def test_auto_dtype(self, enable_chunked_prefill): + kwargs = self.get_default_kwargs(enable_chunked_prefill) + sampling_params = self.get_default_sampling_params() + with AutoDeployLLM(model=self.MODEL_PATH, + tokenizer=self.MODEL_PATH, + **kwargs) as llm: + task = MMLU(self.MODEL_NAME) + task.evaluate(llm, sampling_params=sampling_params) + task = GSM8K(self.MODEL_NAME) + task.evaluate(llm) diff --git a/tests/unittest/_torch/auto_deploy/_utils_test/_model_test_utils.py b/tests/unittest/_torch/auto_deploy/_utils_test/_model_test_utils.py index d585e4e088..e1922ae751 100644 --- a/tests/unittest/_torch/auto_deploy/_utils_test/_model_test_utils.py +++ b/tests/unittest/_torch/auto_deploy/_utils_test/_model_test_utils.py @@ -400,7 +400,8 @@ _SMALL_MODEL_CONFIGS = { "num_hidden_layers": 2, "hidden_size": 32, "intermediate_size": 64, - "kv_lora_rank": 128, + "kv_lora_rank": 512, # NOTE: must be 512 (default) for flashinfer_mla to work + "qk_rope_head_dim": 64, # NOTE: must be 64 (default) for flashinfer_mla to work "moe_intermediate_size": 128, "n_group": 2, "topk_group": 2, diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/mla/test_flashinfer_mla_op.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/mla/test_flashinfer_mla_op.py new file mode 100644 index 0000000000..e7c1be66b2 --- /dev/null +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/mla/test_flashinfer_mla_op.py @@ -0,0 +1,2179 @@ +"""Test FlashInfer MLA backend operations. + +Tests the flashinfer_mla_with_cache cached op and compares it with the +torch_backend_mla_with_cache reference implementation. 
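+
+The paged FlashInfer path is validated against the dense torch-cache path by building
+equivalent inputs and caches for both backends and comparing outputs within bf16 tolerances.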
+ +Key features tested: +- 5 tensor arguments: q_nope, q_pe, compressed_kv, kpe, kv_b_proj_weight +- Paged caches: ckv_cache [num_pages, page_size, kv_lora_rank] and kpe_cache [num_pages, page_size, qk_rope_head_dim] +- Prefill: Expand compressed_kv, compute normal attention via BatchPrefillWithRaggedKVCacheWrapper +- Decode: BatchMLAPagedAttentionWrapper with paged compressed KV cache + +Reference: https://docs.flashinfer.ai/api/mla.html +""" + +import flashinfer +import pytest +import torch + +import tensorrt_llm._torch.auto_deploy # noqa: F401 +from tensorrt_llm._torch.auto_deploy.custom_ops.mla.flashinfer_mla import ( + _GlobalFlashInferMLAPlanner, +) +from tensorrt_llm._torch.auto_deploy.utils.cuda_graph import CudaGraphWarmUpPhase + +# Skip all tests in this module on GPUs with compute capability < 9.0 (older than Hopper) +pytestmark = pytest.mark.skipif( + not torch.cuda.is_available() or torch.cuda.get_device_capability(0) < (9, 0), + reason="FlashInfer MLA tests require GPU with compute capability >= 9.0 (at least Hopper architecture)", +) + + +def _create_mla_inputs( + batch_size: int, + seq_len: int, + num_heads: int, + qk_nope_head_dim: int, + qk_rope_head_dim: int, + kv_lora_rank: int, + v_head_dim: int, + dtype: torch.dtype, + device: str, +): + """Create MLA input tensors. + + Args: + batch_size: Batch size + seq_len: Sequence length + num_heads: Number of attention heads + qk_nope_head_dim: Dimension of query/key non-positional part + qk_rope_head_dim: Dimension of query/key positional (RoPE) part + kv_lora_rank: Rank of compressed KV (LoRA rank) + v_head_dim: Dimension of value head + dtype: Data type + device: Device + + Returns: + Dictionary with q_nope, q_pe, compressed_kv, kpe, kv_b_proj_weight + """ + kv_head_dim = qk_nope_head_dim + v_head_dim + + # Scale factor for Xavier-like initialization to keep values bounded + # This helps reduce numerical differences by keeping output magnitudes smaller + q_scale = 1.0 / (qk_nope_head_dim**0.5) + kv_scale = 1.0 / (kv_lora_rank**0.5) + + # q_nope: [B, S, N, qk_nope_head_dim] + q_nope = ( + torch.randn(batch_size, seq_len, num_heads, qk_nope_head_dim, dtype=dtype, device=device) + * q_scale + ) + + # q_pe: [B, S, N, qk_rope_head_dim] + q_pe = ( + torch.randn(batch_size, seq_len, num_heads, qk_rope_head_dim, dtype=dtype, device=device) + * q_scale + ) + + # compressed_kv: [B, S, kv_lora_rank] + compressed_kv = ( + torch.randn(batch_size, seq_len, kv_lora_rank, dtype=dtype, device=device) * kv_scale + ) + + # kpe: [B, S, 1, qk_rope_head_dim] + kpe = ( + torch.randn(batch_size, seq_len, 1, qk_rope_head_dim, dtype=dtype, device=device) * q_scale + ) + + # kv_b_proj_weight: [N * (qk_nope_head_dim + v_head_dim), kv_lora_rank] + # Xavier initialization for the projection weight + weight_scale = 1.0 / (kv_lora_rank**0.5) + kv_b_proj_weight = ( + torch.randn(num_heads * kv_head_dim, kv_lora_rank, dtype=dtype, device=device) + * weight_scale + ) + + return { + "q_nope": q_nope, + "q_pe": q_pe, + "compressed_kv": compressed_kv, + "kpe": kpe, + "kv_b_proj_weight": kv_b_proj_weight, + } + + +def _create_unpaged_cache_and_metadata( + batch_size: int, + max_seq_len: int, + kv_lora_rank: int, + qk_rope_head_dim: int, + dtype: torch.dtype, + device: str, + seq_lengths: list, + input_positions: list, +): + """Create unpaged (torch backend) cache and metadata. 
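+
+    The torch-backend cache is a single dense tensor of shape
+    [max_batch, max_seq, kv_lora_rank + qk_rope_head_dim]; the compressed KV and the
+    RoPE key portion are stored concatenated along the last dimension.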
+ + Args: + batch_size: Batch size + max_seq_len: Maximum sequence length + kv_lora_rank: Rank of compressed KV + qk_rope_head_dim: Dimension of RoPE + dtype: Data type + device: Device + seq_lengths: List of sequence lengths per batch + input_positions: List of input positions (cache offsets) per batch + + Returns: + Dictionary with cache and metadata tensors + """ + # FlashInfer MLA cache (unpaged): [max_batch, max_seq, kv_lora_rank + qk_rope_head_dim] + mla_cache = torch.zeros( + batch_size, max_seq_len, kv_lora_rank + qk_rope_head_dim, dtype=dtype, device=device + ) + + # Metadata + seq_len_tensor = torch.tensor(seq_lengths, dtype=torch.int32, device=device) + input_pos = torch.tensor(input_positions, dtype=torch.int32, device=device) + slot_idx = torch.arange(batch_size, dtype=torch.int32, device=device) + + # Compute cu_seqlen (cumulative sequence lengths) + cu_seqlen = torch.zeros(batch_size + 1, dtype=torch.int32, device=device) + cu_seqlen[1:] = torch.cumsum(seq_len_tensor, dim=0) + + # Determine if this is context (prefill) or generate (decode) + total_tokens = sum(seq_lengths) + is_decode = all(s == 1 for s in seq_lengths) + + if is_decode: + # Decode phase + batch_info_host = torch.tensor([0, 0, batch_size], dtype=torch.int32, device=device) + else: + # Context/prefill phase + batch_info_host = torch.tensor( + [batch_size, total_tokens, 0], dtype=torch.int32, device=device + ) + + return { + "mla_cache": mla_cache, + "batch_info_host": batch_info_host, + "seq_len": seq_len_tensor, + "input_pos": input_pos, + "slot_idx": slot_idx, + "cu_seqlen": cu_seqlen[:-1], # Exclude last element for seq_start + } + + +def _create_paged_cache_and_metadata( + batch_size: int, + max_num_pages: int, + page_size: int, + kv_lora_rank: int, + qk_rope_head_dim: int, + dtype: torch.dtype, + device: str, + seq_lengths: list, + input_positions: list, +): + """Create paged (flashinfer backend) cache and metadata. 
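+
+    As a worked example of the paging math below: with page_size=64, a sequence whose
+    total KV length (input_pos + seq_len) is 100 tokens needs
+    pages_per_seq = (100 - 1) // 64 + 1 = 2 pages, and its last page holds
+    last_page_len = ((100 - 1) % 64) + 1 = 36 valid tokens.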
+ + Args: + batch_size: Batch size + max_num_pages: Maximum number of pages + page_size: Size of each page + kv_lora_rank: Rank of compressed KV + qk_rope_head_dim: Dimension of RoPE + dtype: Data type + device: Device + seq_lengths: List of sequence lengths per batch + input_positions: List of input positions (cache offsets) per batch + + Returns: + Dictionary with paged cache and metadata tensors + """ + # Paged MLA caches (two separate caches) + ckv_cache = torch.zeros(max_num_pages, page_size, kv_lora_rank, dtype=dtype, device=device) + kpe_cache = torch.zeros(max_num_pages, page_size, qk_rope_head_dim, dtype=dtype, device=device) + + # Compute total KV lengths (input_pos + seq_len for each sequence) + kv_lengths = [pos + seq_len for pos, seq_len in zip(input_positions, seq_lengths)] + + # Compute number of pages per sequence + pages_per_seq = [(kv_len - 1) // page_size + 1 if kv_len > 0 else 1 for kv_len in kv_lengths] + + # Assign pages (simple sequential assignment) + page_assignments = [] + next_page = 0 + for num_pages in pages_per_seq: + page_assignments.append(list(range(next_page, next_page + num_pages))) + next_page += num_pages + + # Create FlashInfer paging metadata + seq_len_tensor = torch.tensor(seq_lengths, dtype=torch.int32, device=device) + + # qo_indptr: cumulative query/output lengths + qo_indptr = torch.zeros(batch_size + 1, dtype=torch.int32, device=device) + qo_indptr[1:] = torch.cumsum(seq_len_tensor, dim=0) + + # cu_num_pages: cumulative number of pages per sequence + num_pages_per_seq = torch.tensor(pages_per_seq, dtype=torch.int32, device=device) + cu_num_pages = torch.zeros(batch_size + 1, dtype=torch.int32, device=device) + cu_num_pages[1:] = torch.cumsum(num_pages_per_seq, dim=0) + + # cache_loc (paged_kv_indices): flattened list of page indices + cache_loc = torch.tensor( + [p for pages in page_assignments for p in pages], dtype=torch.int32, device=device + ) + + # last_page_len: number of valid tokens in the last page of each sequence + last_page_len = torch.tensor( + [((kv_len - 1) % page_size) + 1 if kv_len > 0 else 0 for kv_len in kv_lengths], + dtype=torch.int32, + device=device, + ) + + # seq_len_with_cache: total KV lengths + seq_len_with_cache = torch.tensor(kv_lengths, dtype=torch.int32, device=device) + + # Host copies + qo_indptr_host = qo_indptr.cpu() + cu_num_pages_host = cu_num_pages.cpu() + last_page_len_host = last_page_len.cpu() + seq_len_with_cache_host = seq_len_with_cache.cpu() + + # Determine if this is context (prefill) or generate (decode) + total_tokens = sum(seq_lengths) + is_decode = all(s == 1 for s in seq_lengths) + + if is_decode: + # Decode phase + batch_info_host = torch.tensor([0, 0, batch_size], dtype=torch.int32, device=device) + else: + # Context/prefill phase + batch_info_host = torch.tensor( + [batch_size, total_tokens, 0], dtype=torch.int32, device=device + ) + + return { + "ckv_cache": ckv_cache, + "kpe_cache": kpe_cache, + "batch_info_host": batch_info_host, + "cu_seqlen_host": qo_indptr_host, + "cu_num_pages": cu_num_pages, + "cu_num_pages_host": cu_num_pages_host, + "cache_loc": cache_loc, + "last_page_len": last_page_len, + "last_page_len_host": last_page_len_host, + "seq_len_with_cache_host": seq_len_with_cache_host, + "page_size": page_size, + } + + +def _copy_unpaged_to_paged_cache( + unpaged_cache: torch.Tensor, + ckv_cache: torch.Tensor, + kpe_cache: torch.Tensor, + batch_size: int, + tokens_per_seq: list, + page_size: int, + cu_num_pages: torch.Tensor, + cache_loc: torch.Tensor, + kv_lora_rank: int, +): + 
"""Copy unpaged cache data to paged cache format. + + This is used to initialize paged cache with the same data as unpaged cache + for comparison tests. + + Args: + unpaged_cache: Source cache [batch, max_seq, dim] + ckv_cache: Destination paged ckv cache [num_pages, page_size, kv_lora_rank] + kpe_cache: Destination paged kpe cache [num_pages, page_size, qk_rope_head_dim] + batch_size: Number of sequences + tokens_per_seq: Number of tokens to copy per sequence + page_size: Number of tokens per page + cu_num_pages: Cumulative page counts from flashinfer metadata [batch_size + 1] + cache_loc: Page indices from flashinfer metadata + kv_lora_rank: Rank of compressed KV (split dimension) + """ + for batch_idx in range(batch_size): + num_tokens = tokens_per_seq[batch_idx] + if num_tokens == 0: + continue + + # Get page assignments for this sequence from flashinfer metadata + page_start_idx = cu_num_pages[batch_idx].item() + page_end_idx = cu_num_pages[batch_idx + 1].item() + + token_offset = 0 + for i in range(page_start_idx, page_end_idx): + page_num = cache_loc[i].item() + tokens_to_copy = min(page_size, num_tokens - token_offset) + if tokens_to_copy <= 0: + break + + # Split unpaged cache into ckv and kpe portions + unpaged_data = unpaged_cache[batch_idx, token_offset : token_offset + tokens_to_copy] + ckv_cache[page_num, :tokens_to_copy] = unpaged_data[:, :kv_lora_rank] + kpe_cache[page_num, :tokens_to_copy] = unpaged_data[:, kv_lora_rank:] + token_offset += tokens_to_copy + + +@pytest.mark.parametrize("seq_length", [32, 128]) +@pytest.mark.parametrize("num_heads", [1]) +@pytest.mark.parametrize("batch_size", [8, 64]) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("device", ["cuda"]) +def test_flashinfer_mla_op_context(seq_length, num_heads, batch_size, dtype, device): + """Test FlashInfer MLA context (prefill) phase. + + Compares flashinfer_mla_with_cache against torch_backend_mla_with_cache + for context (prefill) operations where seq_length > 1. + """ + # MLA dimensions (similar to DeepSeek) + qk_nope_head_dim = 128 + qk_rope_head_dim = 64 # Must be 64 for FlashInfer MLA. + kv_lora_rank = 512 # Must be 512 for FlashInfer MLA. 
+ v_head_dim = 128 + page_size = seq_length # Use seq_length as page size for simpler comparison + + max_seq_len = 256 + max_num_pages = batch_size * (max_seq_len // page_size + 1) + + # Create input tensors + inputs = _create_mla_inputs( + batch_size, + seq_length, + num_heads, + qk_nope_head_dim, + qk_rope_head_dim, + kv_lora_rank, + v_head_dim, + dtype, + device, + ) + + # Flatten inputs for context phase + total_tokens = batch_size * seq_length + q_nope_flat = inputs["q_nope"].view(1, total_tokens, num_heads, qk_nope_head_dim) + q_pe_flat = inputs["q_pe"].view(1, total_tokens, num_heads, qk_rope_head_dim) + compressed_kv_flat = inputs["compressed_kv"].view(1, total_tokens, kv_lora_rank) + kpe_flat = inputs["kpe"].view(1, total_tokens, 1, qk_rope_head_dim) + + # Sequence lengths and positions + seq_lengths = [seq_length] * batch_size + input_positions = [0] * batch_size # Context starts at position 0 + + # ========================================================================= + # Run torch backend (reference) + # ========================================================================= + torch_meta = _create_unpaged_cache_and_metadata( + batch_size, + max_seq_len, + kv_lora_rank, + qk_rope_head_dim, + dtype, + device, + seq_lengths, + input_positions, + ) + + torch_output = torch.ops.auto_deploy.torch_cached_mla_with_cache( + q_nope_flat, + q_pe_flat, + compressed_kv_flat, + kpe_flat, + inputs["kv_b_proj_weight"], + torch_meta["batch_info_host"], + torch_meta["seq_len"], + torch_meta["input_pos"], + torch_meta["slot_idx"], + torch_meta["cu_seqlen"], + torch_meta["mla_cache"], + None, # scale + kv_lora_rank, + ) + + # ========================================================================= + # Run FlashInfer backend + # ========================================================================= + # Reset planner + _GlobalFlashInferMLAPlanner.reset(torch.device(device)) + + # Create paged metadata + flashinfer_meta = _create_paged_cache_and_metadata( + batch_size, + max_num_pages, + page_size, + kv_lora_rank, + qk_rope_head_dim, + dtype, + device, + seq_lengths, + input_positions, + ) + + # Compute FlashInfer batch indices and positions + qo_indptr = torch.zeros(batch_size + 1, dtype=torch.int32, device=device) + qo_indptr[1:] = torch.cumsum(torch.tensor(seq_lengths, device=device), dim=0).int() + + batch_indices, positions = flashinfer.get_batch_indices_positions( + qo_indptr, + flashinfer.get_seq_lens( + flashinfer_meta["cu_num_pages"], + flashinfer_meta["last_page_len"], + page_size=page_size, + ), + total_tokens, + ) + + flashinfer_output = torch.ops.auto_deploy.flashinfer_mla_with_cache( + inputs["q_nope"], + inputs["q_pe"], + inputs["compressed_kv"], + inputs["kpe"], + inputs["kv_b_proj_weight"], + flashinfer_meta["batch_info_host"], + flashinfer_meta["cu_seqlen_host"], + flashinfer_meta["cu_num_pages"], + flashinfer_meta["cu_num_pages_host"], + flashinfer_meta["cache_loc"], + flashinfer_meta["last_page_len"], + flashinfer_meta["last_page_len_host"], + flashinfer_meta["seq_len_with_cache_host"], + batch_indices, + positions, + flashinfer_meta["ckv_cache"], + flashinfer_meta["kpe_cache"], + None, # scale + kv_lora_rank, + ) + + # ========================================================================= + # Compare outputs + # ========================================================================= + # Reshape for comparison + torch_output_reshaped = torch_output.view(batch_size, seq_length, num_heads, v_head_dim) + flashinfer_output_reshaped = flashinfer_output.view( + batch_size, 
seq_length, num_heads, v_head_dim + ) + + # FlashInfer uses fused kernels with different computation order/precision than the + # torch reference. With bfloat16 and scaled inputs, tighter tolerances are achievable. + assert torch.allclose( + flashinfer_output_reshaped.cpu().to(torch.float32), + torch_output_reshaped.cpu().to(torch.float32), + atol=0.05, + rtol=0.02, + ), ( + f"FlashInfer MLA context output doesn't match torch backend. " + f"Max diff: {(flashinfer_output_reshaped - torch_output_reshaped).abs().max():.6f}" + ) + + +@pytest.mark.parametrize("prefill_seq_length", [64, 128]) +@pytest.mark.parametrize("num_heads", [1]) +@pytest.mark.parametrize("batch_size", [4, 64]) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("device", ["cuda"]) +def test_flashinfer_mla_op_decode(prefill_seq_length, num_heads, batch_size, dtype, device): + """Test FlashInfer MLA decode (generate) phase. + + Compares flashinfer_mla_with_cache against torch_backend_mla_with_cache + for decode operations where seq_length = 1. + """ + # MLA dimensions + qk_nope_head_dim = 128 + qk_rope_head_dim = 64 # Must be 64 for FlashInfer MLA. + kv_lora_rank = 512 # Must be 512 for FlashInfer MLA. + v_head_dim = 192 + page_size = 64 + + seq_length = 1 # Decode phase + + max_seq_len = 256 + max_num_pages = batch_size * (max_seq_len // page_size + 1) + + # Create input tensors + inputs = _create_mla_inputs( + batch_size, + seq_length, + num_heads, + qk_nope_head_dim, + qk_rope_head_dim, + kv_lora_rank, + v_head_dim, + dtype, + device, + ) + + # Sequence lengths and positions + seq_lengths = [seq_length] * batch_size + input_positions = [prefill_seq_length] * batch_size + + # ========================================================================= + # Setup caches with pre-filled data + # ========================================================================= + # Create unpaged cache with prefilled data + torch_meta = _create_unpaged_cache_and_metadata( + batch_size, + max_seq_len, + kv_lora_rank, + qk_rope_head_dim, + dtype, + device, + seq_lengths, + input_positions, + ) + + # Pre-fill cache with random data if prefill_seq_length > 0 + if prefill_seq_length > 0: + torch_meta["mla_cache"][:, :prefill_seq_length, :] = torch.randn( + batch_size, + prefill_seq_length, + kv_lora_rank + qk_rope_head_dim, + dtype=dtype, + device=device, + ) + + # Create paged cache + flashinfer_meta = _create_paged_cache_and_metadata( + batch_size, + max_num_pages, + page_size, + kv_lora_rank, + qk_rope_head_dim, + dtype, + device, + seq_lengths, + input_positions, + ) + + # Copy unpaged cache to paged format + if prefill_seq_length > 0: + _copy_unpaged_to_paged_cache( + torch_meta["mla_cache"], + flashinfer_meta["ckv_cache"], + flashinfer_meta["kpe_cache"], + batch_size, + [prefill_seq_length] * batch_size, # Number of tokens to copy + page_size, + flashinfer_meta["cu_num_pages"], + flashinfer_meta["cache_loc"], + kv_lora_rank, + ) + + # ========================================================================= + # Run torch backend (reference) + # ========================================================================= + torch_output = torch.ops.auto_deploy.torch_cached_mla_with_cache( + inputs["q_nope"], + inputs["q_pe"], + inputs["compressed_kv"], + inputs["kpe"], + inputs["kv_b_proj_weight"], + torch_meta["batch_info_host"], + torch_meta["seq_len"], + torch_meta["input_pos"], + torch_meta["slot_idx"], + torch_meta["cu_seqlen"], + torch_meta["mla_cache"], + None, # scale + kv_lora_rank, + ) + + # 
========================================================================= + # Run FlashInfer backend + # ========================================================================= + _GlobalFlashInferMLAPlanner.reset(torch.device(device)) + + # Compute FlashInfer batch indices and positions + qo_indptr = torch.zeros(batch_size + 1, dtype=torch.int32, device=device) + qo_indptr[1:] = torch.cumsum(torch.tensor(seq_lengths, device=device), dim=0).int() + + batch_indices, positions = flashinfer.get_batch_indices_positions( + qo_indptr, + flashinfer.get_seq_lens( + flashinfer_meta["cu_num_pages"], + flashinfer_meta["last_page_len"], + page_size=page_size, + ), + batch_size * seq_length, + ) + + flashinfer_output = torch.ops.auto_deploy.flashinfer_mla_with_cache( + inputs["q_nope"], + inputs["q_pe"], + inputs["compressed_kv"], + inputs["kpe"], + inputs["kv_b_proj_weight"], + flashinfer_meta["batch_info_host"], + flashinfer_meta["cu_seqlen_host"], + flashinfer_meta["cu_num_pages"], + flashinfer_meta["cu_num_pages_host"], + flashinfer_meta["cache_loc"], + flashinfer_meta["last_page_len"], + flashinfer_meta["last_page_len_host"], + flashinfer_meta["seq_len_with_cache_host"], + batch_indices, + positions, + flashinfer_meta["ckv_cache"], + flashinfer_meta["kpe_cache"], + None, # scale + kv_lora_rank, + ) + + # ========================================================================= + # Compare outputs + # ========================================================================= + # FlashInfer uses fused kernels with different computation order/precision than the + # torch reference. With bfloat16 and scaled inputs, tighter tolerances are achievable. + assert torch.allclose( + flashinfer_output.cpu().to(torch.float32), + torch_output.cpu().to(torch.float32), + atol=0.05, + rtol=0.02, + ), ( + f"FlashInfer MLA decode output doesn't match torch backend. " + f"Max diff: {(flashinfer_output - torch_output).abs().max():.6f}" + ) + + +@pytest.mark.parametrize("prefill_seq_length", [16, 128, 1024]) +@pytest.mark.parametrize("num_heads", [1, 8]) +@pytest.mark.parametrize("batch_size", [4, 64, 256]) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("device", ["cuda"]) +def test_flashinfer_mla_context_and_generate( + prefill_seq_length, num_heads, batch_size, dtype, device +): + """Test FlashInfer MLA context (prefill) followed by generate (decode). + + This test verifies the full workflow: + 1. Context phase: Process initial sequence + 2. Generate phase: Generate additional tokens one at a time + """ + # MLA dimensions + qk_nope_head_dim = 128 + qk_rope_head_dim = 64 # Must be 64 for FlashInfer MLA. + kv_lora_rank = 512 # Must be 512 for FlashInfer MLA. + v_head_dim = 128 + # Use a fixed page_size of 64 for FlashInfer MLA. 
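+    # max_num_pages below budgets max_seq_len // page_size full pages per sequence plus
+    # one extra page of headroom for a partially filled last page.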
+ page_size = 64 + + max_seq_len = 2048 + max_num_pages = batch_size * (max_seq_len // page_size + 1) + + # ========================================================================= + # Context phase + # ========================================================================= + inputs_context = _create_mla_inputs( + batch_size, + prefill_seq_length, + num_heads, + qk_nope_head_dim, + qk_rope_head_dim, + kv_lora_rank, + v_head_dim, + dtype, + device, + ) + + seq_lengths_context = [prefill_seq_length] * batch_size + input_positions_context = [0] * batch_size + + # Flatten context inputs + total_tokens = batch_size * prefill_seq_length + q_nope_flat = inputs_context["q_nope"].view(1, total_tokens, num_heads, qk_nope_head_dim) + q_pe_flat = inputs_context["q_pe"].view(1, total_tokens, num_heads, qk_rope_head_dim) + compressed_kv_flat = inputs_context["compressed_kv"].view(1, total_tokens, kv_lora_rank) + kpe_flat = inputs_context["kpe"].view(1, total_tokens, 1, qk_rope_head_dim) + + # Create caches + torch_meta = _create_unpaged_cache_and_metadata( + batch_size, + max_seq_len, + kv_lora_rank, + qk_rope_head_dim, + dtype, + device, + seq_lengths_context, + input_positions_context, + ) + + flashinfer_meta = _create_paged_cache_and_metadata( + batch_size, + max_num_pages, + page_size, + kv_lora_rank, + qk_rope_head_dim, + dtype, + device, + seq_lengths_context, + input_positions_context, + ) + + # Run torch backend context + torch_output_context = torch.ops.auto_deploy.torch_cached_mla_with_cache( + q_nope_flat, + q_pe_flat, + compressed_kv_flat, + kpe_flat, + inputs_context["kv_b_proj_weight"], + torch_meta["batch_info_host"], + torch_meta["seq_len"], + torch_meta["input_pos"], + torch_meta["slot_idx"], + torch_meta["cu_seqlen"], + torch_meta["mla_cache"], + None, + kv_lora_rank, + ) + + # Run FlashInfer context + _GlobalFlashInferMLAPlanner.reset(torch.device(device)) + + qo_indptr = torch.zeros(batch_size + 1, dtype=torch.int32, device=device) + qo_indptr[1:] = torch.cumsum(torch.tensor(seq_lengths_context, device=device), dim=0).int() + + batch_indices, positions = flashinfer.get_batch_indices_positions( + qo_indptr, + flashinfer.get_seq_lens( + flashinfer_meta["cu_num_pages"], + flashinfer_meta["last_page_len"], + page_size=page_size, + ), + total_tokens, + ) + + flashinfer_output_context = torch.ops.auto_deploy.flashinfer_mla_with_cache( + inputs_context["q_nope"], + inputs_context["q_pe"], + inputs_context["compressed_kv"], + inputs_context["kpe"], + inputs_context["kv_b_proj_weight"], + flashinfer_meta["batch_info_host"], + flashinfer_meta["cu_seqlen_host"], + flashinfer_meta["cu_num_pages"], + flashinfer_meta["cu_num_pages_host"], + flashinfer_meta["cache_loc"], + flashinfer_meta["last_page_len"], + flashinfer_meta["last_page_len_host"], + flashinfer_meta["seq_len_with_cache_host"], + batch_indices, + positions, + flashinfer_meta["ckv_cache"], + flashinfer_meta["kpe_cache"], + None, + kv_lora_rank, + ) + + # Verify context outputs match + torch_output_context_reshaped = torch_output_context.view( + batch_size, prefill_seq_length, num_heads, v_head_dim + ) + flashinfer_output_context_reshaped = flashinfer_output_context.view( + batch_size, prefill_seq_length, num_heads, v_head_dim + ) + + assert torch.allclose( + flashinfer_output_context_reshaped.cpu().to(torch.float32), + torch_output_context_reshaped.cpu().to(torch.float32), + atol=0.01, + rtol=0.01, + ), "Context phase outputs don't match" + + # ========================================================================= + # Generate 
phase (single token) + # ========================================================================= + inputs_gen = _create_mla_inputs( + batch_size, + 1, # seq_length = 1 for generate + num_heads, + qk_nope_head_dim, + qk_rope_head_dim, + kv_lora_rank, + v_head_dim, + dtype, + device, + ) + + seq_lengths_gen = [1] * batch_size + input_positions_gen = [prefill_seq_length] * batch_size + + # Update torch metadata for generate + torch_meta_gen = _create_unpaged_cache_and_metadata( + batch_size, + max_seq_len, + kv_lora_rank, + qk_rope_head_dim, + dtype, + device, + seq_lengths_gen, + input_positions_gen, + ) + # Use the same cache (already filled from context) + torch_meta_gen["mla_cache"] = torch_meta["mla_cache"] + + # Update flashinfer metadata for generate + flashinfer_meta_gen = _create_paged_cache_and_metadata( + batch_size, + max_num_pages, + page_size, + kv_lora_rank, + qk_rope_head_dim, + dtype, + device, + seq_lengths_gen, + input_positions_gen, + ) + # Use the same caches + flashinfer_meta_gen["ckv_cache"] = flashinfer_meta["ckv_cache"] + flashinfer_meta_gen["kpe_cache"] = flashinfer_meta["kpe_cache"] + + # Run torch backend generate + torch_output_gen = torch.ops.auto_deploy.torch_cached_mla_with_cache( + inputs_gen["q_nope"], + inputs_gen["q_pe"], + inputs_gen["compressed_kv"], + inputs_gen["kpe"], + inputs_context["kv_b_proj_weight"], # Use same weights + torch_meta_gen["batch_info_host"], + torch_meta_gen["seq_len"], + torch_meta_gen["input_pos"], + torch_meta_gen["slot_idx"], + torch_meta_gen["cu_seqlen"], + torch_meta_gen["mla_cache"], + None, + kv_lora_rank, + ) + + # Run FlashInfer generate + _GlobalFlashInferMLAPlanner.reset(torch.device(device)) + + qo_indptr_gen = torch.zeros(batch_size + 1, dtype=torch.int32, device=device) + qo_indptr_gen[1:] = torch.cumsum(torch.tensor(seq_lengths_gen, device=device), dim=0).int() + + batch_indices_gen, positions_gen = flashinfer.get_batch_indices_positions( + qo_indptr_gen, + flashinfer.get_seq_lens( + flashinfer_meta_gen["cu_num_pages"], + flashinfer_meta_gen["last_page_len"], + page_size=page_size, + ), + batch_size, + ) + + flashinfer_output_gen = torch.ops.auto_deploy.flashinfer_mla_with_cache( + inputs_gen["q_nope"], + inputs_gen["q_pe"], + inputs_gen["compressed_kv"], + inputs_gen["kpe"], + inputs_context["kv_b_proj_weight"], + flashinfer_meta_gen["batch_info_host"], + flashinfer_meta_gen["cu_seqlen_host"], + flashinfer_meta_gen["cu_num_pages"], + flashinfer_meta_gen["cu_num_pages_host"], + flashinfer_meta_gen["cache_loc"], + flashinfer_meta_gen["last_page_len"], + flashinfer_meta_gen["last_page_len_host"], + flashinfer_meta_gen["seq_len_with_cache_host"], + batch_indices_gen, + positions_gen, + flashinfer_meta_gen["ckv_cache"], + flashinfer_meta_gen["kpe_cache"], + None, + kv_lora_rank, + ) + + # Verify generate outputs match + assert torch.allclose( + flashinfer_output_gen.cpu().to(torch.float32), + torch_output_gen.cpu().to(torch.float32), + atol=0.05, + rtol=0.02, + ), ( + f"Generate phase outputs don't match. " + f"Max diff: {(flashinfer_output_gen - torch_output_gen).abs().max():.6f}" + ) + + +@pytest.mark.parametrize( + "seq_lengths", + [ + [8, 16], + [12, 24, 32], + [4, 8, 16, 32], + ], +) +@pytest.mark.parametrize("num_heads", [8]) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("device", ["cuda"]) +def test_flashinfer_mla_with_variable_seq_lengths(seq_lengths, num_heads, dtype, device): + """Test FlashInfer MLA with variable sequence lengths in a batch. 
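+
+    Only output shape and finiteness are checked here; cross-backend numerical
+    comparison for variable lengths is covered by
+    test_flashinfer_mla_variable_seq_multi_decode below.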
+ + This test verifies that the FlashInfer MLA backend handles batches + with different sequence lengths correctly. + """ + batch_size = len(seq_lengths) + + # MLA dimensions + qk_nope_head_dim = 128 + qk_rope_head_dim = 64 # Must be 64 for FlashInfer MLA. + kv_lora_rank = 512 # Must be 512 for FlashInfer MLA. + v_head_dim = 128 + page_size = 32 + + max_seq_len = 256 + max_num_pages = batch_size * (max_seq_len // page_size + 1) + + # Create individual inputs for each sequence + total_tokens = sum(seq_lengths) + + # Create batched inputs (we'll flatten later) + q_nope_list = [] + q_pe_list = [] + compressed_kv_list = [] + kpe_list = [] + + for seq_len in seq_lengths: + inputs = _create_mla_inputs( + 1, # Single sequence + seq_len, + num_heads, + qk_nope_head_dim, + qk_rope_head_dim, + kv_lora_rank, + v_head_dim, + dtype, + device, + ) + q_nope_list.append(inputs["q_nope"].squeeze(0)) + q_pe_list.append(inputs["q_pe"].squeeze(0)) + compressed_kv_list.append(inputs["compressed_kv"].squeeze(0)) + kpe_list.append(inputs["kpe"].squeeze(0)) + + # Concatenate into flattened format + q_nope_flat = torch.cat(q_nope_list, dim=0).unsqueeze(0) # [1, total_tokens, N, D] + q_pe_flat = torch.cat(q_pe_list, dim=0).unsqueeze(0) + compressed_kv_flat = torch.cat(compressed_kv_list, dim=0).unsqueeze(0) + kpe_flat = torch.cat(kpe_list, dim=0).unsqueeze(0) + + # Common kv_b_proj_weight + kv_head_dim = qk_nope_head_dim + v_head_dim + kv_b_proj_weight = torch.randn( + num_heads * kv_head_dim, kv_lora_rank, dtype=dtype, device=device + ) + + input_positions = [0] * batch_size + + # ========================================================================= + # Run FlashInfer backend + # ========================================================================= + _GlobalFlashInferMLAPlanner.reset(torch.device(device)) + + flashinfer_meta = _create_paged_cache_and_metadata( + batch_size, + max_num_pages, + page_size, + kv_lora_rank, + qk_rope_head_dim, + dtype, + device, + seq_lengths, + input_positions, + ) + + qo_indptr = torch.zeros(batch_size + 1, dtype=torch.int32, device=device) + qo_indptr[1:] = torch.cumsum(torch.tensor(seq_lengths, device=device), dim=0).int() + + batch_indices, positions = flashinfer.get_batch_indices_positions( + qo_indptr, + flashinfer.get_seq_lens( + flashinfer_meta["cu_num_pages"], + flashinfer_meta["last_page_len"], + page_size=page_size, + ), + total_tokens, + ) + + flashinfer_output = torch.ops.auto_deploy.flashinfer_mla_with_cache( + q_nope_flat, + q_pe_flat, + compressed_kv_flat, + kpe_flat, + kv_b_proj_weight, + flashinfer_meta["batch_info_host"], + flashinfer_meta["cu_seqlen_host"], + flashinfer_meta["cu_num_pages"], + flashinfer_meta["cu_num_pages_host"], + flashinfer_meta["cache_loc"], + flashinfer_meta["last_page_len"], + flashinfer_meta["last_page_len_host"], + flashinfer_meta["seq_len_with_cache_host"], + batch_indices, + positions, + flashinfer_meta["ckv_cache"], + flashinfer_meta["kpe_cache"], + None, + kv_lora_rank, + ) + + # Verify output shape + expected_shape = (1, total_tokens, num_heads, v_head_dim) + assert flashinfer_output.shape == expected_shape, ( + f"Expected shape {expected_shape}, got {flashinfer_output.shape}" + ) + + # Verify output is finite + assert torch.isfinite(flashinfer_output).all(), "Output contains NaN or Inf values" + + +@pytest.mark.parametrize( + "seq_lengths", + [ + [8, 16, 32], + [12, 24, 48, 64], + [16, 32, 64, 96, 128], + ], +) +@pytest.mark.parametrize("num_decode_steps", [3, 5]) +@pytest.mark.parametrize("num_heads", [8]) 
+@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("device", ["cuda"]) +def test_flashinfer_mla_variable_seq_multi_decode( + seq_lengths, num_decode_steps, num_heads, dtype, device +): + """Test FlashInfer MLA with variable sequence lengths and multiple decode steps. + + This test verifies the full workflow with variable sequence lengths: + 1. Context phase: Process initial sequences with different lengths + 2. Multiple decode steps: Generate multiple tokens, updating the cache each step + + Compares torch_backend_mla_with_cache against flashinfer_mla_with_cache. + """ + batch_size = len(seq_lengths) + + # MLA dimensions + qk_nope_head_dim = 128 + qk_rope_head_dim = 64 # Must be 64 for FlashInfer MLA. + kv_lora_rank = 512 # Must be 512 for FlashInfer MLA. + v_head_dim = 128 + page_size = 64 + + max_seq_len = max(seq_lengths) + num_decode_steps + 128 # Extra headroom + max_num_pages = batch_size * (max_seq_len // page_size + 2) + + # Create individual inputs for each sequence (context phase) + total_tokens = sum(seq_lengths) + + q_nope_list = [] + q_pe_list = [] + compressed_kv_list = [] + kpe_list = [] + + for seq_len in seq_lengths: + inputs = _create_mla_inputs( + 1, # Single sequence + seq_len, + num_heads, + qk_nope_head_dim, + qk_rope_head_dim, + kv_lora_rank, + v_head_dim, + dtype, + device, + ) + q_nope_list.append(inputs["q_nope"].squeeze(0)) + q_pe_list.append(inputs["q_pe"].squeeze(0)) + compressed_kv_list.append(inputs["compressed_kv"].squeeze(0)) + kpe_list.append(inputs["kpe"].squeeze(0)) + + # Concatenate into flattened format for context phase + q_nope_flat = torch.cat(q_nope_list, dim=0).unsqueeze(0) # [1, total_tokens, N, D] + q_pe_flat = torch.cat(q_pe_list, dim=0).unsqueeze(0) + compressed_kv_flat = torch.cat(compressed_kv_list, dim=0).unsqueeze(0) + kpe_flat = torch.cat(kpe_list, dim=0).unsqueeze(0) + + # Common kv_b_proj_weight + kv_head_dim = qk_nope_head_dim + v_head_dim + weight_scale = 1.0 / (kv_lora_rank**0.5) + kv_b_proj_weight = ( + torch.randn(num_heads * kv_head_dim, kv_lora_rank, dtype=dtype, device=device) + * weight_scale + ) + + input_positions_context = [0] * batch_size + + # ========================================================================= + # Context phase - Setup both backends + # ========================================================================= + + # Create torch unpaged cache + torch_meta = _create_unpaged_cache_and_metadata( + batch_size, + max_seq_len, + kv_lora_rank, + qk_rope_head_dim, + dtype, + device, + seq_lengths, + input_positions_context, + ) + + # Create flashinfer paged cache + flashinfer_meta = _create_paged_cache_and_metadata( + batch_size, + max_num_pages, + page_size, + kv_lora_rank, + qk_rope_head_dim, + dtype, + device, + seq_lengths, + input_positions_context, + ) + + # Run torch backend context + torch_output_context = torch.ops.auto_deploy.torch_cached_mla_with_cache( + q_nope_flat, + q_pe_flat, + compressed_kv_flat, + kpe_flat, + kv_b_proj_weight, + torch_meta["batch_info_host"], + torch_meta["seq_len"], + torch_meta["input_pos"], + torch_meta["slot_idx"], + torch_meta["cu_seqlen"], + torch_meta["mla_cache"], + None, + kv_lora_rank, + ) + + # Run FlashInfer context + _GlobalFlashInferMLAPlanner.reset(torch.device(device)) + + qo_indptr = torch.zeros(batch_size + 1, dtype=torch.int32, device=device) + qo_indptr[1:] = torch.cumsum(torch.tensor(seq_lengths, device=device), dim=0).int() + + batch_indices, positions = flashinfer.get_batch_indices_positions( + qo_indptr, + 
flashinfer.get_seq_lens( + flashinfer_meta["cu_num_pages"], + flashinfer_meta["last_page_len"], + page_size=page_size, + ), + total_tokens, + ) + + flashinfer_output_context = torch.ops.auto_deploy.flashinfer_mla_with_cache( + q_nope_flat, + q_pe_flat, + compressed_kv_flat, + kpe_flat, + kv_b_proj_weight, + flashinfer_meta["batch_info_host"], + flashinfer_meta["cu_seqlen_host"], + flashinfer_meta["cu_num_pages"], + flashinfer_meta["cu_num_pages_host"], + flashinfer_meta["cache_loc"], + flashinfer_meta["last_page_len"], + flashinfer_meta["last_page_len_host"], + flashinfer_meta["seq_len_with_cache_host"], + batch_indices, + positions, + flashinfer_meta["ckv_cache"], + flashinfer_meta["kpe_cache"], + None, + kv_lora_rank, + ) + + # Verify context outputs match + assert torch.allclose( + flashinfer_output_context.cpu().to(torch.float32), + torch_output_context.cpu().to(torch.float32), + atol=0.05, + rtol=0.02, + ), ( + f"Context phase outputs don't match. " + f"Max diff: {(flashinfer_output_context - torch_output_context).abs().max():.6f}" + ) + + # ========================================================================= + # Multiple decode steps + # ========================================================================= + current_positions = list(seq_lengths) # Track current position for each sequence + + for decode_step in range(num_decode_steps): + # Create decode inputs for this step + inputs_decode = _create_mla_inputs( + batch_size, + 1, # seq_length = 1 for decode + num_heads, + qk_nope_head_dim, + qk_rope_head_dim, + kv_lora_rank, + v_head_dim, + dtype, + device, + ) + + seq_lengths_decode = [1] * batch_size + input_positions_decode = current_positions.copy() + + # Update torch metadata for decode + torch_meta_decode = _create_unpaged_cache_and_metadata( + batch_size, + max_seq_len, + kv_lora_rank, + qk_rope_head_dim, + dtype, + device, + seq_lengths_decode, + input_positions_decode, + ) + # Use the same cache (accumulated from context and previous decode steps) + torch_meta_decode["mla_cache"] = torch_meta["mla_cache"] + + # Update flashinfer metadata for decode + flashinfer_meta_decode = _create_paged_cache_and_metadata( + batch_size, + max_num_pages, + page_size, + kv_lora_rank, + qk_rope_head_dim, + dtype, + device, + seq_lengths_decode, + input_positions_decode, + ) + # Use the same caches + flashinfer_meta_decode["ckv_cache"] = flashinfer_meta["ckv_cache"] + flashinfer_meta_decode["kpe_cache"] = flashinfer_meta["kpe_cache"] + + # Run torch backend decode + torch_output_decode = torch.ops.auto_deploy.torch_cached_mla_with_cache( + inputs_decode["q_nope"], + inputs_decode["q_pe"], + inputs_decode["compressed_kv"], + inputs_decode["kpe"], + kv_b_proj_weight, + torch_meta_decode["batch_info_host"], + torch_meta_decode["seq_len"], + torch_meta_decode["input_pos"], + torch_meta_decode["slot_idx"], + torch_meta_decode["cu_seqlen"], + torch_meta_decode["mla_cache"], + None, + kv_lora_rank, + ) + + # Run FlashInfer decode + _GlobalFlashInferMLAPlanner.reset(torch.device(device)) + + qo_indptr_decode = torch.zeros(batch_size + 1, dtype=torch.int32, device=device) + qo_indptr_decode[1:] = torch.cumsum( + torch.tensor(seq_lengths_decode, device=device), dim=0 + ).int() + + batch_indices_decode, positions_decode = flashinfer.get_batch_indices_positions( + qo_indptr_decode, + flashinfer.get_seq_lens( + flashinfer_meta_decode["cu_num_pages"], + flashinfer_meta_decode["last_page_len"], + page_size=page_size, + ), + batch_size, + ) + + flashinfer_output_decode = 
torch.ops.auto_deploy.flashinfer_mla_with_cache( + inputs_decode["q_nope"], + inputs_decode["q_pe"], + inputs_decode["compressed_kv"], + inputs_decode["kpe"], + kv_b_proj_weight, + flashinfer_meta_decode["batch_info_host"], + flashinfer_meta_decode["cu_seqlen_host"], + flashinfer_meta_decode["cu_num_pages"], + flashinfer_meta_decode["cu_num_pages_host"], + flashinfer_meta_decode["cache_loc"], + flashinfer_meta_decode["last_page_len"], + flashinfer_meta_decode["last_page_len_host"], + flashinfer_meta_decode["seq_len_with_cache_host"], + batch_indices_decode, + positions_decode, + flashinfer_meta_decode["ckv_cache"], + flashinfer_meta_decode["kpe_cache"], + None, + kv_lora_rank, + ) + + # Verify decode outputs match + assert torch.allclose( + flashinfer_output_decode.cpu().to(torch.float32), + torch_output_decode.cpu().to(torch.float32), + atol=0.05, + rtol=0.02, + ), ( + f"Decode step {decode_step + 1} outputs don't match. " + f"Max diff: {(flashinfer_output_decode - torch_output_decode).abs().max():.6f}" + ) + # Update positions for next decode step + current_positions = [pos + 1 for pos in current_positions] + + # Final verification: all outputs should be finite + assert torch.isfinite(flashinfer_output_decode).all(), ( + "Final decode output contains NaN or Inf values" + ) + + +@pytest.mark.parametrize( + "chunk_config", + [ + # Each config has list of chunk sizes per sequence + # e.g., [[32, 16, 8], [64, 32, 16]] means 2 sequences with 3 chunks each + {"chunks_per_seq": [[32, 16], [64, 32]]}, # 2 sequences, 2 chunks each + {"chunks_per_seq": [[32, 16, 8], [64, 32, 16]]}, # 2 sequences, 3 chunks each + {"chunks_per_seq": [[64, 32, 16, 8]]}, # 1 sequence, 4 chunks + { + "chunks_per_seq": [[32, 32, 32], [48, 48, 48], [64, 64, 64]] + }, # 3 sequences, 3 chunks each + {"chunks_per_seq": [[16, 16, 16, 16, 16]]}, # 1 sequence, 5 chunks + ], +) +@pytest.mark.parametrize("num_heads", [8]) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("device", ["cuda"]) +def test_flashinfer_mla_chunked_prefill(chunk_config, num_heads, dtype, device): + """Test FlashInfer MLA chunked prefill (incremental prefill) with multiple chunks. + + This test verifies that chunked prefill works correctly when: + 1. First chunk is processed (input_pos == 0) - uses BatchPrefillWithRaggedKVCacheWrapper + 2. Subsequent chunks are processed (input_pos > 0) - uses BatchMLAPagedAttentionWrapper + + In chunked prefill, the Q tokens attend to all KV tokens (cached + current), + which is different from regular prefill where Q and KV lengths are equal. + + Compares flashinfer_mla_with_cache against torch_backend_mla_with_cache. + """ + chunks_per_seq = chunk_config["chunks_per_seq"] + batch_size = len(chunks_per_seq) + num_chunks = len(chunks_per_seq[0]) # Assume all sequences have same number of chunks + + # MLA dimensions + qk_nope_head_dim = 128 + qk_rope_head_dim = 64 # Must be 64 for FlashInfer MLA. + kv_lora_rank = 512 # Must be 512 for FlashInfer MLA. 
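+    # Example schedule for chunks_per_seq=[[32, 16]]: chunk 0 runs as a regular prefill
+    # (input_pos=0, 32 Q tokens vs. 32 KV tokens); chunk 1 runs as a chunked prefill
+    # (input_pos=32, 16 Q tokens attending to all 48 cached + current KV tokens).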
+ v_head_dim = 128 + page_size = 32 + + # Calculate total sequence lengths + total_seq_lengths = [sum(chunks) for chunks in chunks_per_seq] + max_seq_len = max(total_seq_lengths) + 128 + max_num_pages = batch_size * (max_seq_len // page_size + 2) + + # Common kv_b_proj_weight + kv_head_dim = qk_nope_head_dim + v_head_dim + weight_scale = 1.0 / (kv_lora_rank**0.5) + kv_b_proj_weight = ( + torch.randn(num_heads * kv_head_dim, kv_lora_rank, dtype=dtype, device=device) + * weight_scale + ) + + # Initialize caches (will be reused across chunks) + torch_mla_cache = torch.zeros( + batch_size, max_seq_len, kv_lora_rank + qk_rope_head_dim, dtype=dtype, device=device + ) + flashinfer_ckv_cache = torch.zeros( + max_num_pages, page_size, kv_lora_rank, dtype=dtype, device=device + ) + flashinfer_kpe_cache = torch.zeros( + max_num_pages, page_size, qk_rope_head_dim, dtype=dtype, device=device + ) + + # Track cumulative positions per sequence + cumulative_positions = [0] * batch_size + + # Process each chunk + for chunk_idx in range(num_chunks): + # Get current chunk lengths + current_chunk_lengths = [ + chunks_per_seq[seq_idx][chunk_idx] for seq_idx in range(batch_size) + ] + total_tokens = sum(current_chunk_lengths) + input_positions = cumulative_positions.copy() + + # Create inputs for this chunk + q_nope_list = [] + q_pe_list = [] + compressed_kv_list = [] + kpe_list = [] + + for chunk_len in current_chunk_lengths: + inputs = _create_mla_inputs( + 1, + chunk_len, + num_heads, + qk_nope_head_dim, + qk_rope_head_dim, + kv_lora_rank, + v_head_dim, + dtype, + device, + ) + q_nope_list.append(inputs["q_nope"].squeeze(0)) + q_pe_list.append(inputs["q_pe"].squeeze(0)) + compressed_kv_list.append(inputs["compressed_kv"].squeeze(0)) + kpe_list.append(inputs["kpe"].squeeze(0)) + + q_nope_flat = torch.cat(q_nope_list, dim=0).unsqueeze(0) + q_pe_flat = torch.cat(q_pe_list, dim=0).unsqueeze(0) + compressed_kv_flat = torch.cat(compressed_kv_list, dim=0).unsqueeze(0) + kpe_flat = torch.cat(kpe_list, dim=0).unsqueeze(0) + + # ===================================================================== + # Torch backend + # ===================================================================== + torch_meta = _create_unpaged_cache_and_metadata( + batch_size, + max_seq_len, + kv_lora_rank, + qk_rope_head_dim, + dtype, + device, + current_chunk_lengths, + input_positions, + ) + torch_meta["mla_cache"] = torch_mla_cache # Use shared cache + + torch_output = torch.ops.auto_deploy.torch_cached_mla_with_cache( + q_nope_flat, + q_pe_flat, + compressed_kv_flat, + kpe_flat, + kv_b_proj_weight, + torch_meta["batch_info_host"], + torch_meta["seq_len"], + torch_meta["input_pos"], + torch_meta["slot_idx"], + torch_meta["cu_seqlen"], + torch_meta["mla_cache"], + None, + kv_lora_rank, + ) + + # ===================================================================== + # FlashInfer backend + # ===================================================================== + _GlobalFlashInferMLAPlanner.reset(torch.device(device)) + + flashinfer_meta = _create_paged_cache_and_metadata( + batch_size, + max_num_pages, + page_size, + kv_lora_rank, + qk_rope_head_dim, + dtype, + device, + current_chunk_lengths, + input_positions, + ) + # Use shared caches + flashinfer_meta["ckv_cache"] = flashinfer_ckv_cache + flashinfer_meta["kpe_cache"] = flashinfer_kpe_cache + + qo_indptr = torch.zeros(batch_size + 1, dtype=torch.int32, device=device) + qo_indptr[1:] = torch.cumsum( + torch.tensor(current_chunk_lengths, device=device), dim=0 + ).int() + + batch_indices, 
positions = flashinfer.get_batch_indices_positions( + qo_indptr, + flashinfer.get_seq_lens( + flashinfer_meta["cu_num_pages"], + flashinfer_meta["last_page_len"], + page_size=page_size, + ), + total_tokens, + ) + + flashinfer_output = torch.ops.auto_deploy.flashinfer_mla_with_cache( + q_nope_flat, + q_pe_flat, + compressed_kv_flat, + kpe_flat, + kv_b_proj_weight, + flashinfer_meta["batch_info_host"], + flashinfer_meta["cu_seqlen_host"], + flashinfer_meta["cu_num_pages"], + flashinfer_meta["cu_num_pages_host"], + flashinfer_meta["cache_loc"], + flashinfer_meta["last_page_len"], + flashinfer_meta["last_page_len_host"], + flashinfer_meta["seq_len_with_cache_host"], + batch_indices, + positions, + flashinfer_meta["ckv_cache"], + flashinfer_meta["kpe_cache"], + None, + kv_lora_rank, + ) + + # Verify outputs match + is_first_chunk = chunk_idx == 0 + chunk_type = "regular prefill" if is_first_chunk else "chunked prefill" + assert torch.allclose( + flashinfer_output.cpu().to(torch.float32), + torch_output.cpu().to(torch.float32), + atol=0.05, + rtol=0.02, + ), ( + f"Chunk {chunk_idx + 1}/{num_chunks} ({chunk_type}) outputs don't match. " + f"Max diff: {(flashinfer_output - torch_output).abs().max():.6f}" + ) + + # Verify outputs are finite + assert torch.isfinite(flashinfer_output).all(), ( + f"Chunk {chunk_idx + 1}/{num_chunks} ({chunk_type}) output contains NaN or Inf values" + ) + + # Update cumulative positions for next chunk + for seq_idx in range(batch_size): + cumulative_positions[seq_idx] += current_chunk_lengths[seq_idx] + + +# ============================================================================= +# CUDA Graph Tests +# ============================================================================= +# Tests for CUDA graph functionality of the FlashInfer MLA planner to verify +# that wrappers are correctly created and cached with use_cuda_graph=True. + + +@pytest.mark.parametrize("prefill_seq_length", [64, 128]) +@pytest.mark.parametrize("num_heads", [1, 8]) +@pytest.mark.parametrize("batch_size", [4, 16]) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("device", ["cuda"]) +def test_flashinfer_mla_cuda_graph_wrapper_creation( + prefill_seq_length, num_heads, batch_size, dtype, device +): + """Test that CUDA graph wrappers are created with use_cuda_graph=True during warm-up. + + This test verifies that: + 1. During CudaGraphWarmUpPhase, the planner creates a wrapper with use_cuda_graph=True + 2. The wrapper is cached in cached_cuda_graph_decode_wrappers + 3. The wrapper has correct buffer tensors attached + """ + # MLA dimensions + qk_nope_head_dim = 128 + qk_rope_head_dim = 64 # Must be 64 for FlashInfer MLA. + kv_lora_rank = 512 # Must be 512 for FlashInfer MLA. 
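+    # Decode wrappers are cached keyed by MLADecodePlanParams; the assertions below check
+    # num_seq, num_heads, kv_lora_rank, and page_size against these values.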
+ v_head_dim = 128 + page_size = 64 + + seq_length = 1 # Decode phase + + max_seq_len = 256 + max_num_pages = batch_size * (max_seq_len // page_size + 1) + + # Create input tensors + inputs = _create_mla_inputs( + batch_size, + seq_length, + num_heads, + qk_nope_head_dim, + qk_rope_head_dim, + kv_lora_rank, + v_head_dim, + dtype, + device, + ) + + # Sequence lengths and positions + seq_lengths = [seq_length] * batch_size + input_positions = [prefill_seq_length] * batch_size + + # Create paged cache + flashinfer_meta = _create_paged_cache_and_metadata( + batch_size, + max_num_pages, + page_size, + kv_lora_rank, + qk_rope_head_dim, + dtype, + device, + seq_lengths, + input_positions, + ) + + # Pre-fill cache with random data + if prefill_seq_length > 0: + total_prefill_pages = (prefill_seq_length - 1) // page_size + 1 + for batch_idx in range(batch_size): + page_start = batch_idx * total_prefill_pages + for page_offset in range(total_prefill_pages): + page_idx = page_start + page_offset + tokens_in_page = min(page_size, prefill_seq_length - page_offset * page_size) + if tokens_in_page > 0: + flashinfer_meta["ckv_cache"][page_idx, :tokens_in_page] = torch.randn( + tokens_in_page, kv_lora_rank, dtype=dtype, device=device + ) / (kv_lora_rank**0.5) + flashinfer_meta["kpe_cache"][page_idx, :tokens_in_page] = torch.randn( + tokens_in_page, qk_rope_head_dim, dtype=dtype, device=device + ) / (qk_rope_head_dim**0.5) + + # Reset planner + _GlobalFlashInferMLAPlanner.workspace_buffer = None + _GlobalFlashInferMLAPlanner.cached_cuda_graph_decode_wrappers = {} + _GlobalFlashInferMLAPlanner.reset(torch.device(device)) + + qo_indptr = torch.zeros(batch_size + 1, dtype=torch.int32, device=device) + qo_indptr[1:] = torch.cumsum(torch.tensor(seq_lengths, device=device), dim=0).int() + + batch_indices, positions = flashinfer.get_batch_indices_positions( + qo_indptr, + flashinfer.get_seq_lens( + flashinfer_meta["cu_num_pages"], + flashinfer_meta["last_page_len"], + page_size=page_size, + ), + batch_size * seq_length, + ) + + # Verify no wrappers exist before warm-up + assert len(_GlobalFlashInferMLAPlanner.cached_cuda_graph_decode_wrappers) == 0, ( + "Expected no cached wrappers before warm-up" + ) + + # Warm-up phase: This triggers wrapper creation with use_cuda_graph=True + with CudaGraphWarmUpPhase(): + output = torch.ops.auto_deploy.flashinfer_mla_with_cache( + inputs["q_nope"], + inputs["q_pe"], + inputs["compressed_kv"], + inputs["kpe"], + inputs["kv_b_proj_weight"], + flashinfer_meta["batch_info_host"], + flashinfer_meta["cu_seqlen_host"], + flashinfer_meta["cu_num_pages"], + flashinfer_meta["cu_num_pages_host"], + flashinfer_meta["cache_loc"], + flashinfer_meta["last_page_len"], + flashinfer_meta["last_page_len_host"], + flashinfer_meta["seq_len_with_cache_host"], + batch_indices, + positions, + flashinfer_meta["ckv_cache"], + flashinfer_meta["kpe_cache"], + None, # scale + kv_lora_rank, + ) + + # Verify a CUDA graph wrapper was created + assert len(_GlobalFlashInferMLAPlanner.cached_cuda_graph_decode_wrappers) == 1, ( + f"Expected 1 cached wrapper after warm-up, " + f"got {len(_GlobalFlashInferMLAPlanner.cached_cuda_graph_decode_wrappers)}" + ) + + # Verify the wrapper has the correct plan params + for ( + plan_params, + wrapper, + ) in _GlobalFlashInferMLAPlanner.cached_cuda_graph_decode_wrappers.items(): + assert plan_params.num_seq == batch_size, ( + f"Plan params num_seq={plan_params.num_seq} doesn't match batch_size={batch_size}" + ) + assert plan_params.num_heads == num_heads, ( + f"Plan params 
num_heads={plan_params.num_heads} doesn't match num_heads={num_heads}" + ) + assert plan_params.kv_lora_rank == kv_lora_rank, ( + f"Plan params kv_lora_rank={plan_params.kv_lora_rank} doesn't match " + f"kv_lora_rank={kv_lora_rank}" + ) + assert plan_params.page_size == page_size, ( + f"Plan params page_size={plan_params.page_size} doesn't match page_size={page_size}" + ) + + # Verify wrapper is not None + assert wrapper is not None, "CUDA graph wrapper should not be None" + + # Verify output is valid + expected_shape = (batch_size, seq_length, num_heads, v_head_dim) + assert output.shape == expected_shape, f"Expected shape {expected_shape}, got {output.shape}" + assert torch.isfinite(output).all(), "Output contains NaN or Inf values" + + +@pytest.mark.parametrize("batch_size", [1, 4, 8, 16]) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("device", ["cuda"]) +def test_flashinfer_mla_cuda_graph_wrapper_caching_per_batch_size(batch_size, dtype, device): + """Test that CUDA graph wrappers are cached per batch size. + + This test verifies that: + 1. Each batch size gets its own cached wrapper + 2. Wrappers are keyed by MLADecodePlanParams which includes num_seq + """ + # MLA dimensions + num_heads = 8 + qk_nope_head_dim = 128 + qk_rope_head_dim = 64 + kv_lora_rank = 512 + v_head_dim = 128 + page_size = 64 + prefill_seq_length = 64 + + seq_length = 1 # Decode phase + + max_seq_len = 256 + max_num_pages = batch_size * (max_seq_len // page_size + 1) + + # Create inputs + inputs = _create_mla_inputs( + batch_size, + seq_length, + num_heads, + qk_nope_head_dim, + qk_rope_head_dim, + kv_lora_rank, + v_head_dim, + dtype, + device, + ) + + seq_lengths = [seq_length] * batch_size + input_positions = [prefill_seq_length] * batch_size + + flashinfer_meta = _create_paged_cache_and_metadata( + batch_size, + max_num_pages, + page_size, + kv_lora_rank, + qk_rope_head_dim, + dtype, + device, + seq_lengths, + input_positions, + ) + + # Pre-fill cache + if prefill_seq_length > 0: + total_prefill_pages = (prefill_seq_length - 1) // page_size + 1 + for batch_idx in range(batch_size): + page_start = batch_idx * total_prefill_pages + for page_offset in range(total_prefill_pages): + page_idx = page_start + page_offset + tokens_in_page = min(page_size, prefill_seq_length - page_offset * page_size) + if tokens_in_page > 0: + flashinfer_meta["ckv_cache"][page_idx, :tokens_in_page] = torch.randn( + tokens_in_page, kv_lora_rank, dtype=dtype, device=device + ) / (kv_lora_rank**0.5) + flashinfer_meta["kpe_cache"][page_idx, :tokens_in_page] = torch.randn( + tokens_in_page, qk_rope_head_dim, dtype=dtype, device=device + ) / (qk_rope_head_dim**0.5) + + qo_indptr = torch.zeros(batch_size + 1, dtype=torch.int32, device=device) + qo_indptr[1:] = torch.cumsum(torch.tensor(seq_lengths, device=device), dim=0).int() + + batch_indices, positions = flashinfer.get_batch_indices_positions( + qo_indptr, + flashinfer.get_seq_lens( + flashinfer_meta["cu_num_pages"], + flashinfer_meta["last_page_len"], + page_size=page_size, + ), + batch_size * seq_length, + ) + + # Reset planner + _GlobalFlashInferMLAPlanner.workspace_buffer = None + _GlobalFlashInferMLAPlanner.cached_cuda_graph_decode_wrappers = {} + _GlobalFlashInferMLAPlanner.reset(torch.device(device)) + + # Warm-up to create CUDA graph wrapper for this batch size + with CudaGraphWarmUpPhase(): + _ = torch.ops.auto_deploy.flashinfer_mla_with_cache( + inputs["q_nope"], + inputs["q_pe"], + inputs["compressed_kv"], + inputs["kpe"], + 
inputs["kv_b_proj_weight"], + flashinfer_meta["batch_info_host"], + flashinfer_meta["cu_seqlen_host"], + flashinfer_meta["cu_num_pages"], + flashinfer_meta["cu_num_pages_host"], + flashinfer_meta["cache_loc"], + flashinfer_meta["last_page_len"], + flashinfer_meta["last_page_len_host"], + flashinfer_meta["seq_len_with_cache_host"], + batch_indices, + positions, + flashinfer_meta["ckv_cache"], + flashinfer_meta["kpe_cache"], + None, + kv_lora_rank, + ) + + # Verify wrapper was created for this batch size + assert len(_GlobalFlashInferMLAPlanner.cached_cuda_graph_decode_wrappers) == 1, ( + f"Expected 1 cached wrapper for batch_size={batch_size}, " + f"got {len(_GlobalFlashInferMLAPlanner.cached_cuda_graph_decode_wrappers)}" + ) + + # Verify the wrapper has the correct num_seq + for plan_params in _GlobalFlashInferMLAPlanner.cached_cuda_graph_decode_wrappers: + assert plan_params.num_seq == batch_size, ( + f"Plan params num_seq={plan_params.num_seq} doesn't match batch_size={batch_size}" + ) + + +@pytest.mark.parametrize("prefill_seq_length", [64]) +@pytest.mark.parametrize("num_heads", [8]) +@pytest.mark.parametrize("batch_size", [4, 8]) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("device", ["cuda"]) +def test_flashinfer_mla_plan_generate_only( + prefill_seq_length, num_heads, batch_size, dtype, device +): + """Test plan_generate_only function for re-planning decode-only batches. + + This test verifies that: + 1. plan_generate_only can re-plan cached CUDA graph wrappers + 2. The wrappers can be used after re-planning + """ + # MLA dimensions + qk_nope_head_dim = 128 + qk_rope_head_dim = 64 + kv_lora_rank = 512 + v_head_dim = 128 + page_size = 64 + + seq_length = 1 # Decode phase + + max_seq_len = 256 + max_num_pages = batch_size * (max_seq_len // page_size + 1) + + # Create inputs + inputs = _create_mla_inputs( + batch_size, + seq_length, + num_heads, + qk_nope_head_dim, + qk_rope_head_dim, + kv_lora_rank, + v_head_dim, + dtype, + device, + ) + + seq_lengths = [seq_length] * batch_size + input_positions = [prefill_seq_length] * batch_size + + flashinfer_meta = _create_paged_cache_and_metadata( + batch_size, + max_num_pages, + page_size, + kv_lora_rank, + qk_rope_head_dim, + dtype, + device, + seq_lengths, + input_positions, + ) + + # Pre-fill cache + if prefill_seq_length > 0: + total_prefill_pages = (prefill_seq_length - 1) // page_size + 1 + for batch_idx in range(batch_size): + page_start = batch_idx * total_prefill_pages + for page_offset in range(total_prefill_pages): + page_idx = page_start + page_offset + tokens_in_page = min(page_size, prefill_seq_length - page_offset * page_size) + if tokens_in_page > 0: + flashinfer_meta["ckv_cache"][page_idx, :tokens_in_page] = torch.randn( + tokens_in_page, kv_lora_rank, dtype=dtype, device=device + ) / (kv_lora_rank**0.5) + flashinfer_meta["kpe_cache"][page_idx, :tokens_in_page] = torch.randn( + tokens_in_page, qk_rope_head_dim, dtype=dtype, device=device + ) / (qk_rope_head_dim**0.5) + + qo_indptr = torch.zeros(batch_size + 1, dtype=torch.int32, device=device) + qo_indptr[1:] = torch.cumsum(torch.tensor(seq_lengths, device=device), dim=0).int() + + batch_indices, positions = flashinfer.get_batch_indices_positions( + qo_indptr, + flashinfer.get_seq_lens( + flashinfer_meta["cu_num_pages"], + flashinfer_meta["last_page_len"], + page_size=page_size, + ), + batch_size * seq_length, + ) + + # Reset planner + _GlobalFlashInferMLAPlanner.workspace_buffer = None + 
_GlobalFlashInferMLAPlanner.cached_cuda_graph_decode_wrappers = {} + _GlobalFlashInferMLAPlanner.reset(torch.device(device)) + + # First warm-up to create the wrapper + with CudaGraphWarmUpPhase(): + output1 = torch.ops.auto_deploy.flashinfer_mla_with_cache( + inputs["q_nope"], + inputs["q_pe"], + inputs["compressed_kv"], + inputs["kpe"], + inputs["kv_b_proj_weight"], + flashinfer_meta["batch_info_host"], + flashinfer_meta["cu_seqlen_host"], + flashinfer_meta["cu_num_pages"], + flashinfer_meta["cu_num_pages_host"], + flashinfer_meta["cache_loc"], + flashinfer_meta["last_page_len"], + flashinfer_meta["last_page_len_host"], + flashinfer_meta["seq_len_with_cache_host"], + batch_indices, + positions, + flashinfer_meta["ckv_cache"], + flashinfer_meta["kpe_cache"], + None, + kv_lora_rank, + ) + + # Verify wrapper was created + assert len(_GlobalFlashInferMLAPlanner.cached_cuda_graph_decode_wrappers) == 1 + + # Now test plan_generate_only - this is called by the host-side preparation + # to re-plan the cached wrappers before graph replay + _GlobalFlashInferMLAPlanner.plan_generate_only( + batch_size, + flashinfer_meta["cu_num_pages"][: batch_size + 1], + flashinfer_meta["cache_loc"], + flashinfer_meta["last_page_len"][:batch_size], + ) + + # Run again (not in warm-up, so it should use cached wrapper) + # First, update the inputs to simulate new tokens + new_inputs = _create_mla_inputs( + batch_size, + seq_length, + num_heads, + qk_nope_head_dim, + qk_rope_head_dim, + kv_lora_rank, + v_head_dim, + dtype, + device, + ) + + # Create new metadata for position+1 + new_input_positions = [prefill_seq_length + 1] * batch_size + new_flashinfer_meta = _create_paged_cache_and_metadata( + batch_size, + max_num_pages, + page_size, + kv_lora_rank, + qk_rope_head_dim, + dtype, + device, + seq_lengths, + new_input_positions, + ) + # Reuse the same cache + new_flashinfer_meta["ckv_cache"] = flashinfer_meta["ckv_cache"] + new_flashinfer_meta["kpe_cache"] = flashinfer_meta["kpe_cache"] + + new_qo_indptr = torch.zeros(batch_size + 1, dtype=torch.int32, device=device) + new_qo_indptr[1:] = torch.cumsum(torch.tensor(seq_lengths, device=device), dim=0).int() + + new_batch_indices, new_positions = flashinfer.get_batch_indices_positions( + new_qo_indptr, + flashinfer.get_seq_lens( + new_flashinfer_meta["cu_num_pages"], + new_flashinfer_meta["last_page_len"], + page_size=page_size, + ), + batch_size * seq_length, + ) + + output2 = torch.ops.auto_deploy.flashinfer_mla_with_cache( + new_inputs["q_nope"], + new_inputs["q_pe"], + new_inputs["compressed_kv"], + new_inputs["kpe"], + inputs["kv_b_proj_weight"], # Use same weights + new_flashinfer_meta["batch_info_host"], + new_flashinfer_meta["cu_seqlen_host"], + new_flashinfer_meta["cu_num_pages"], + new_flashinfer_meta["cu_num_pages_host"], + new_flashinfer_meta["cache_loc"], + new_flashinfer_meta["last_page_len"], + new_flashinfer_meta["last_page_len_host"], + new_flashinfer_meta["seq_len_with_cache_host"], + new_batch_indices, + new_positions, + new_flashinfer_meta["ckv_cache"], + new_flashinfer_meta["kpe_cache"], + None, + kv_lora_rank, + ) + + # Verify output is valid + expected_shape = (batch_size, seq_length, num_heads, v_head_dim) + assert output2.shape == expected_shape, f"Expected shape {expected_shape}, got {output2.shape}" + assert torch.isfinite(output2).all(), "Output contains NaN or Inf values" + + # Outputs should be different since inputs are different + assert not torch.allclose(output1, output2, atol=1e-6), ( + "Outputs should differ since inputs are 
different" + ) + + +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("device", ["cuda"]) +def test_flashinfer_mla_cuda_graph_multiple_batch_sizes(dtype, device): + """Test that multiple batch sizes can have their own CUDA graph wrappers. + + This test verifies that the planner correctly caches wrappers for + different batch sizes, which is important for supporting multiple + CUDA graph configurations. + """ + # MLA dimensions + num_heads = 8 + qk_nope_head_dim = 128 + qk_rope_head_dim = 64 + kv_lora_rank = 512 + v_head_dim = 128 + page_size = 64 + prefill_seq_length = 64 + + seq_length = 1 # Decode phase + + batch_sizes = [4, 8, 16] + + # Reset planner + _GlobalFlashInferMLAPlanner.workspace_buffer = None + _GlobalFlashInferMLAPlanner.cached_cuda_graph_decode_wrappers = {} + _GlobalFlashInferMLAPlanner.reset(torch.device(device)) + + for batch_size in batch_sizes: + max_num_pages = batch_size * (256 // page_size + 1) + + inputs = _create_mla_inputs( + batch_size, + seq_length, + num_heads, + qk_nope_head_dim, + qk_rope_head_dim, + kv_lora_rank, + v_head_dim, + dtype, + device, + ) + + seq_lengths = [seq_length] * batch_size + input_positions = [prefill_seq_length] * batch_size + + flashinfer_meta = _create_paged_cache_and_metadata( + batch_size, + max_num_pages, + page_size, + kv_lora_rank, + qk_rope_head_dim, + dtype, + device, + seq_lengths, + input_positions, + ) + + qo_indptr = torch.zeros(batch_size + 1, dtype=torch.int32, device=device) + qo_indptr[1:] = torch.cumsum(torch.tensor(seq_lengths, device=device), dim=0).int() + + batch_indices, positions = flashinfer.get_batch_indices_positions( + qo_indptr, + flashinfer.get_seq_lens( + flashinfer_meta["cu_num_pages"], + flashinfer_meta["last_page_len"], + page_size=page_size, + ), + batch_size * seq_length, + ) + + # Warm-up to create wrapper for this batch size + with CudaGraphWarmUpPhase(): + _ = torch.ops.auto_deploy.flashinfer_mla_with_cache( + inputs["q_nope"], + inputs["q_pe"], + inputs["compressed_kv"], + inputs["kpe"], + inputs["kv_b_proj_weight"], + flashinfer_meta["batch_info_host"], + flashinfer_meta["cu_seqlen_host"], + flashinfer_meta["cu_num_pages"], + flashinfer_meta["cu_num_pages_host"], + flashinfer_meta["cache_loc"], + flashinfer_meta["last_page_len"], + flashinfer_meta["last_page_len_host"], + flashinfer_meta["seq_len_with_cache_host"], + batch_indices, + positions, + flashinfer_meta["ckv_cache"], + flashinfer_meta["kpe_cache"], + None, + kv_lora_rank, + ) + + # Verify we have a wrapper for each batch size + assert len(_GlobalFlashInferMLAPlanner.cached_cuda_graph_decode_wrappers) == len(batch_sizes), ( + f"Expected {len(batch_sizes)} cached wrappers, " + f"got {len(_GlobalFlashInferMLAPlanner.cached_cuda_graph_decode_wrappers)}" + ) + + # Verify each batch size has a wrapper + cached_num_seqs = { + params.num_seq for params in _GlobalFlashInferMLAPlanner.cached_cuda_graph_decode_wrappers + } + assert cached_num_seqs == set(batch_sizes), ( + f"Expected wrappers for batch_sizes {batch_sizes}, got {cached_num_seqs}" + ) + + +@pytest.mark.parametrize("batch_size", [4]) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("device", ["cuda"]) +def test_flashinfer_mla_init_decode_wrapper_with_buffers(batch_size, dtype, device): + """Test that _init_decode_wrapper correctly passes buffer tensors with use_cuda_graph=True. + + This test directly tests the _init_decode_wrapper method to verify buffer handling. 
+ """ + # Reset planner + _GlobalFlashInferMLAPlanner.workspace_buffer = None + _GlobalFlashInferMLAPlanner.cached_cuda_graph_decode_wrappers = {} + _GlobalFlashInferMLAPlanner.reset(torch.device(device)) + + # Create buffer tensors + qo_indptr = torch.arange(batch_size + 1, device=device, dtype=torch.int32) + kv_indptr = torch.arange(batch_size + 1, device=device, dtype=torch.int32) * 2 + kv_indices = torch.arange(batch_size * 2, device=device, dtype=torch.int32) + kv_len_arr = torch.ones(batch_size, device=device, dtype=torch.int32) * 64 + + # Test creating wrapper without CUDA graph (no buffers needed) + wrapper_no_cg = _GlobalFlashInferMLAPlanner._init_decode_wrapper(use_cuda_graph=False) + assert wrapper_no_cg is not None, "Should create wrapper without CUDA graph" + + # Test creating wrapper with CUDA graph (buffers required) + wrapper_with_cg = _GlobalFlashInferMLAPlanner._init_decode_wrapper( + use_cuda_graph=True, + qo_indptr=qo_indptr, + kv_indptr=kv_indptr, + kv_indices=kv_indices, + kv_len_arr=kv_len_arr, + ) + assert wrapper_with_cg is not None, "Should create wrapper with CUDA graph" + + # Both wrappers should be valid BatchMLAPagedAttentionWrapper instances + assert isinstance(wrapper_no_cg, flashinfer.mla.BatchMLAPagedAttentionWrapper), ( + "Should be BatchMLAPagedAttentionWrapper" + ) + assert isinstance(wrapper_with_cg, flashinfer.mla.BatchMLAPagedAttentionWrapper), ( + "Should be BatchMLAPagedAttentionWrapper" + ) diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/mla/test_torch_mla_op.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/mla/test_torch_mla_op.py new file mode 100644 index 0000000000..369f2b181f --- /dev/null +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/mla/test_torch_mla_op.py @@ -0,0 +1,780 @@ +"""Comprehensive test suite for torch MLA backend operations. + +Tests the torch_mla source op and torch_backend_mla_with_cache cached op +with FlashInfer-compatible compressed cache layout. + +Key features: +- 5 tensor arguments: q_nope, q_pe, compressed_kv, kpe, kv_b_proj_weight +- Compressed cache: [max_batch, max_seq, kv_lora_rank + qk_rope_head_dim] +- Prefill: Expand compressed_kv, compute normal attention +- Generate: Weight absorption for efficiency +""" + +import math + +import numpy as np +import pytest +import torch + +import tensorrt_llm._torch.auto_deploy # noqa: F401 + + +def numpy_mla_reference_with_expansion( + q_nope: np.ndarray, + q_pe: np.ndarray, + compressed_kv: np.ndarray, + kpe: np.ndarray, + kv_b_proj_weight: np.ndarray, + mla_cache: np.ndarray, + seq_len: np.ndarray, + input_pos: np.ndarray, + cache_loc: np.ndarray, + seq_start: np.ndarray, + scale: float = None, + kv_lora_rank: int = None, + is_generate: bool = False, +): + """Numpy reference implementation of MLA attention with FlashInfer cache layout. + + This expands compressed_kv using kv_b_proj_weight for attention computation. 
+ """ + # Get dimensions + if is_generate: + batch_size = q_nope.shape[0] + num_heads = q_nope.shape[2] + qk_nope_head_dim = q_nope.shape[3] + qk_rope_head_dim = q_pe.shape[3] + else: + batch_size = len(seq_len) + num_heads = q_nope.shape[2] + qk_nope_head_dim = q_nope.shape[3] + qk_rope_head_dim = q_pe.shape[3] + + qk_head_dim = qk_nope_head_dim + qk_rope_head_dim + + if kv_lora_rank is None: + kv_lora_rank = compressed_kv.shape[-1] + + # Infer v_head_dim from kv_b_proj_weight + out_features = kv_b_proj_weight.shape[0] + kv_head_dim = out_features // num_heads + v_head_dim = kv_head_dim - qk_nope_head_dim + + if scale is None: + scale = 1.0 / math.sqrt(qk_head_dim) + + # Update MLA cache first + if is_generate: + for i in range(batch_size): + cache_idx = cache_loc[i] + pos = input_pos[i] + mla_cache[cache_idx, pos, :kv_lora_rank] = compressed_kv[i, 0] + mla_cache[cache_idx, pos, kv_lora_rank:] = kpe[i, 0, 0] + else: + for i in range(batch_size): + cache_idx = cache_loc[i] + pos = input_pos[i] + seq_len_i = seq_len[i] + seq_start_i = seq_start[i] + for j in range(seq_len_i): + mla_cache[cache_idx, pos + j, :kv_lora_rank] = compressed_kv[seq_start_i + j] + mla_cache[cache_idx, pos + j, kv_lora_rank:] = kpe[seq_start_i + j, 0] + + # Compute attention for each sequence + outputs = [] + + for i in range(batch_size): + cache_idx = cache_loc[i] + pos = input_pos[i] + seq_len_i = seq_len[i] + seq_start_i = seq_start[i] + + if seq_len_i == 0: + continue + + # Get query for this sequence + if is_generate: + q_nope_seq = q_nope[i, 0] # [N, qk_nope_head_dim] + q_pe_seq = q_pe[i, 0] # [N, qk_rope_head_dim] + else: + q_nope_seq = q_nope[seq_start_i : seq_start_i + seq_len_i] + q_pe_seq = q_pe[seq_start_i : seq_start_i + seq_len_i] + + # Get cached compressed_kv and kpe + kv_seq_len = pos + seq_len_i + cached_data = mla_cache[cache_idx, :kv_seq_len] + compressed_kv_cached = cached_data[:, :kv_lora_rank] + kpe_cached = cached_data[:, kv_lora_rank:] + + # Expand compressed_kv using kv_b_proj_weight + # compressed_kv_cached: [kv_seq_len, kv_lora_rank] + # kv_b_proj_weight: [N * kv_head_dim, kv_lora_rank] + kv_expanded = np.matmul(compressed_kv_cached, kv_b_proj_weight.T) + kv_expanded = kv_expanded.reshape(kv_seq_len, num_heads, kv_head_dim) + + k_nope = kv_expanded[:, :, :qk_nope_head_dim] + v = kv_expanded[:, :, qk_nope_head_dim:] + + # Expand kpe to all heads + kpe_expanded = np.broadcast_to( + kpe_cached[:, None, :], (kv_seq_len, num_heads, qk_rope_head_dim) + ) + + # Construct full query and key + if is_generate: + query_full = np.concatenate([q_nope_seq, q_pe_seq], axis=-1) + else: + query_full = np.concatenate([q_nope_seq, q_pe_seq], axis=-1) + + key_full = np.concatenate([k_nope, kpe_expanded], axis=-1) + + # Compute attention scores + if is_generate: + attn_scores = np.einsum("nh,knh->nk", query_full, key_full) * scale + else: + attn_scores = np.einsum("snh,knh->snk", query_full, key_full) * scale + causal_mask = np.triu(np.ones((seq_len_i, kv_seq_len)), k=kv_seq_len - seq_len_i + 1) + attn_scores = np.where(causal_mask[:, None, :], -np.inf, attn_scores) + + # Apply softmax + attn_scores_max = np.max(attn_scores, axis=-1, keepdims=True) + attn_scores_exp = np.exp(attn_scores - attn_scores_max) + attn_weights = attn_scores_exp / np.sum(attn_scores_exp, axis=-1, keepdims=True) + + # Compute output + if is_generate: + attn_out = np.einsum("nk,knh->nh", attn_weights, v) + else: + attn_out = np.einsum("snk,knh->snh", attn_weights, v) + + outputs.append(attn_out) + + # Concatenate outputs + if 
len(outputs) == 0: + return np.zeros((1, 0, num_heads, v_head_dim), dtype=np.float32) + elif is_generate: + result = np.stack(outputs, axis=0) + return result[:, None, :, :] + else: + result = np.concatenate(outputs, axis=0) + return result[None, :, :, :] + + +class TestTorchMLASourceOp: + """Test torch_mla source op (without cache).""" + + @pytest.fixture(autouse=True) + def setup_method(self): + """Setup test configuration.""" + self.device = "cuda" + self.dtype = torch.bfloat16 + self.atol = 1e-2 + self.rtol = 1e-2 + + torch.cuda.empty_cache() + torch.manual_seed(42) + np.random.seed(42) + + def _create_mla_data( + self, + batch_size: int, + seq_len: int, + num_heads: int, + qk_nope_head_dim: int, + qk_rope_head_dim: int, + kv_lora_rank: int, + v_head_dim: int, + layout: str = "bsnd", + ): + """Create test data for MLA source op with compressed_kv.""" + kv_head_dim = qk_nope_head_dim + v_head_dim + + if layout == "bsnd": + q_nope = torch.randn( + batch_size, + seq_len, + num_heads, + qk_nope_head_dim, + dtype=self.dtype, + device=self.device, + ) + q_pe = torch.randn( + batch_size, + seq_len, + num_heads, + qk_rope_head_dim, + dtype=self.dtype, + device=self.device, + ) + compressed_kv = torch.randn( + batch_size, seq_len, kv_lora_rank, dtype=self.dtype, device=self.device + ) + kpe = torch.randn( + batch_size, seq_len, 1, qk_rope_head_dim, dtype=self.dtype, device=self.device + ) + else: # bnsd + q_nope = torch.randn( + batch_size, + num_heads, + seq_len, + qk_nope_head_dim, + dtype=self.dtype, + device=self.device, + ) + q_pe = torch.randn( + batch_size, + num_heads, + seq_len, + qk_rope_head_dim, + dtype=self.dtype, + device=self.device, + ) + compressed_kv = torch.randn( + batch_size, seq_len, kv_lora_rank, dtype=self.dtype, device=self.device + ) + kpe = torch.randn( + batch_size, 1, seq_len, qk_rope_head_dim, dtype=self.dtype, device=self.device + ) + + # kv_b_proj_weight: [num_heads * kv_head_dim, kv_lora_rank] + kv_b_proj_weight = torch.randn( + num_heads * kv_head_dim, kv_lora_rank, dtype=self.dtype, device=self.device + ) + + return { + "q_nope": q_nope, + "q_pe": q_pe, + "compressed_kv": compressed_kv, + "kpe": kpe, + "kv_b_proj_weight": kv_b_proj_weight, + } + + def test_basic_functionality(self): + """Test basic MLA source op functionality.""" + batch_size, seq_len, num_heads = 2, 4, 8 + qk_nope_head_dim, qk_rope_head_dim = 128, 64 + kv_lora_rank = 512 + v_head_dim = 128 + + data = self._create_mla_data( + batch_size, + seq_len, + num_heads, + qk_nope_head_dim, + qk_rope_head_dim, + kv_lora_rank, + v_head_dim, + ) + + output = torch.ops.auto_deploy.torch_mla( + data["q_nope"], + data["q_pe"], + data["compressed_kv"], + data["kpe"], + data["kv_b_proj_weight"], + True, # is_causal + None, # scale + "bsnd", # layout + ) + + # Verify output shape: [B, S, N, v_head_dim] + expected_shape = (batch_size, seq_len, num_heads, v_head_dim) + assert output.shape == expected_shape, f"Expected {expected_shape}, got {output.shape}" + + # Verify output is finite + assert torch.isfinite(output).all(), "Output contains NaN or Inf values" + + def test_both_layouts(self): + """Test MLA source op with both bsnd and bnsd layouts.""" + batch_size, seq_len, num_heads = 2, 4, 8 + qk_nope_head_dim, qk_rope_head_dim = 64, 32 + kv_lora_rank = 256 + v_head_dim = 64 + + for layout in ["bsnd", "bnsd"]: + data = self._create_mla_data( + batch_size, + seq_len, + num_heads, + qk_nope_head_dim, + qk_rope_head_dim, + kv_lora_rank, + v_head_dim, + layout, + ) + + output = torch.ops.auto_deploy.torch_mla( + 
data["q_nope"], + data["q_pe"], + data["compressed_kv"], + data["kpe"], + data["kv_b_proj_weight"], + True, + None, + layout, + ) + + if layout == "bsnd": + expected_shape = (batch_size, seq_len, num_heads, v_head_dim) + else: + expected_shape = (batch_size, num_heads, seq_len, v_head_dim) + + assert output.shape == expected_shape, ( + f"Layout {layout}: Expected {expected_shape}, got {output.shape}" + ) + + def test_custom_scale(self): + """Test MLA source op with custom scale.""" + batch_size, seq_len, num_heads = 1, 2, 4 + qk_nope_head_dim, qk_rope_head_dim = 32, 16 + kv_lora_rank = 128 + v_head_dim = 32 + + data = self._create_mla_data( + batch_size, + seq_len, + num_heads, + qk_nope_head_dim, + qk_rope_head_dim, + kv_lora_rank, + v_head_dim, + ) + + # Test with default scale + output_default = torch.ops.auto_deploy.torch_mla( + data["q_nope"], + data["q_pe"], + data["compressed_kv"], + data["kpe"], + data["kv_b_proj_weight"], + True, + None, + "bsnd", + ) + + # Test with custom scale + custom_scale = 0.5 + output_custom = torch.ops.auto_deploy.torch_mla( + data["q_nope"], + data["q_pe"], + data["compressed_kv"], + data["kpe"], + data["kv_b_proj_weight"], + True, + custom_scale, + "bsnd", + ) + + # Outputs should be different + assert not torch.allclose(output_default, output_custom, atol=1e-3), ( + "Custom scale should affect output" + ) + + +class TestTorchBackendMLAWithCache: + """Test torch_backend_mla_with_cache cached op.""" + + @pytest.fixture(autouse=True) + def setup_method(self): + """Setup test configuration.""" + self.device = "cuda" + self.dtype = torch.bfloat16 + self.atol = 5e-2 + self.rtol = 5e-2 + + torch.cuda.empty_cache() + torch.manual_seed(42) + np.random.seed(42) + + def _create_cached_mla_data( + self, + batch_size: int, + seq_len: int, + num_heads: int, + qk_nope_head_dim: int, + qk_rope_head_dim: int, + kv_lora_rank: int, + v_head_dim: int, + max_seq_len: int, + cache_offset: int = 0, + ): + """Create test data for cached MLA op with FlashInfer layout.""" + kv_head_dim = qk_nope_head_dim + v_head_dim + + # Create input tensors (BSND layout) + q_nope = torch.randn( + batch_size, seq_len, num_heads, qk_nope_head_dim, dtype=self.dtype, device=self.device + ) + q_pe = torch.randn( + batch_size, seq_len, num_heads, qk_rope_head_dim, dtype=self.dtype, device=self.device + ) + compressed_kv = torch.randn( + batch_size, seq_len, kv_lora_rank, dtype=self.dtype, device=self.device + ) + kpe = torch.randn( + batch_size, seq_len, 1, qk_rope_head_dim, dtype=self.dtype, device=self.device + ) + + # kv_b_proj_weight: [num_heads * kv_head_dim, kv_lora_rank] + kv_b_proj_weight = torch.randn( + num_heads * kv_head_dim, kv_lora_rank, dtype=self.dtype, device=self.device + ) + + # Create FlashInfer MLA cache: [max_batch, max_seq, kv_lora_rank + qk_rope_head_dim] + mla_cache = torch.zeros( + batch_size, + max_seq_len, + kv_lora_rank + qk_rope_head_dim, + dtype=self.dtype, + device=self.device, + ) + + # Pre-fill cache with random data if cache_offset > 0 + if cache_offset > 0: + mla_cache[:, :cache_offset, :] = torch.randn( + batch_size, + cache_offset, + kv_lora_rank + qk_rope_head_dim, + dtype=self.dtype, + device=self.device, + ) + + # Setup metadata + seq_len_tensor = torch.full((batch_size,), seq_len, device=self.device, dtype=torch.int32) + input_pos = torch.full((batch_size,), cache_offset, device=self.device, dtype=torch.int32) + cache_loc = torch.arange(batch_size, device=self.device, dtype=torch.int32) + + if seq_len == 1: + # Generate phase + batch_info_host = 
torch.tensor( + [0, 0, batch_size], device=self.device, dtype=torch.int32 + ) + cu_seqlen = torch.arange(batch_size, device=self.device, dtype=torch.int32) + else: + # Context phase + batch_info_host = torch.tensor( + [batch_size, batch_size * seq_len, 0], device=self.device, dtype=torch.int32 + ) + cu_seqlen = torch.arange( + 0, batch_size * seq_len, seq_len, device=self.device, dtype=torch.int32 + ) + # Flatten inputs for context phase + q_nope = q_nope.view(1, batch_size * seq_len, num_heads, qk_nope_head_dim) + q_pe = q_pe.view(1, batch_size * seq_len, num_heads, qk_rope_head_dim) + compressed_kv = compressed_kv.view(1, batch_size * seq_len, kv_lora_rank) + kpe = kpe.view(1, batch_size * seq_len, 1, qk_rope_head_dim) + + return { + "q_nope": q_nope, + "q_pe": q_pe, + "compressed_kv": compressed_kv, + "kpe": kpe, + "kv_b_proj_weight": kv_b_proj_weight, + "batch_info_host": batch_info_host, + "seq_len": seq_len_tensor, + "input_pos": input_pos, + "cache_loc": cache_loc, + "cu_seqlen": cu_seqlen, + "mla_cache": mla_cache, + "kv_lora_rank": kv_lora_rank, + } + + def _run_cached_mla(self, data, scale=None): + """Run cached MLA operation.""" + return torch.ops.auto_deploy.torch_cached_mla_with_cache( + data["q_nope"], + data["q_pe"], + data["compressed_kv"], + data["kpe"], + data["kv_b_proj_weight"], + data["batch_info_host"], + data["seq_len"], + data["input_pos"], + data["cache_loc"], + data["cu_seqlen"], + data["mla_cache"], + scale, + data["kv_lora_rank"], + ) + + def test_generate_phase_basic(self): + """Test generate phase (single token) basic functionality.""" + batch_size, seq_len, num_heads = 2, 1, 8 + qk_nope_head_dim, qk_rope_head_dim = 64, 32 + kv_lora_rank = 256 + v_head_dim = 64 + max_seq_len = 128 + cache_offset = 5 + + data = self._create_cached_mla_data( + batch_size, + seq_len, + num_heads, + qk_nope_head_dim, + qk_rope_head_dim, + kv_lora_rank, + v_head_dim, + max_seq_len, + cache_offset, + ) + + output = self._run_cached_mla(data) + + # Verify output shape + expected_shape = (batch_size, seq_len, num_heads, v_head_dim) + assert output.shape == expected_shape, f"Expected {expected_shape}, got {output.shape}" + + # Verify output is finite + assert torch.isfinite(output).all(), "Output contains NaN or Inf values" + + def test_context_phase_basic(self): + """Test context phase (multi-token) basic functionality.""" + batch_size, seq_len, num_heads = 2, 4, 8 + qk_nope_head_dim, qk_rope_head_dim = 64, 32 + kv_lora_rank = 256 + v_head_dim = 64 + max_seq_len = 128 + + data = self._create_cached_mla_data( + batch_size, + seq_len, + num_heads, + qk_nope_head_dim, + qk_rope_head_dim, + kv_lora_rank, + v_head_dim, + max_seq_len, + ) + + output = self._run_cached_mla(data) + + # Verify output shape + expected_shape = (1, batch_size * seq_len, num_heads, v_head_dim) + assert output.shape == expected_shape, f"Expected {expected_shape}, got {output.shape}" + + # Verify output is finite + assert torch.isfinite(output).all(), "Output contains NaN or Inf values" + + def test_cache_update_correctness(self): + """Test that cache is updated correctly during forward pass.""" + batch_size, seq_len, num_heads = 1, 1, 4 + qk_nope_head_dim, qk_rope_head_dim = 32, 16 + kv_lora_rank = 128 + v_head_dim = 32 + max_seq_len = 32 + cache_offset = 5 + + data = self._create_cached_mla_data( + batch_size, + seq_len, + num_heads, + qk_nope_head_dim, + qk_rope_head_dim, + kv_lora_rank, + v_head_dim, + max_seq_len, + cache_offset, + ) + + # Store original cache values at target position + original_cache_at_pos = 
data["mla_cache"][0, cache_offset].clone() + + # Run forward pass + _ = self._run_cached_mla(data) + + # Check cache was updated at the correct position + updated_cache_at_pos = data["mla_cache"][0, cache_offset] + + # The cache should have been updated + assert not torch.allclose(original_cache_at_pos, updated_cache_at_pos, atol=1e-6), ( + "Cache should have been updated at the target position" + ) + + def test_cache_layout_flashinfer_compatible(self): + """Test that cache layout matches FlashInfer spec (no num_heads dimension).""" + batch_size, seq_len, num_heads = 2, 1, 4 + qk_nope_head_dim, qk_rope_head_dim = 64, 32 + kv_lora_rank = 512 # DeepSeek-style + v_head_dim = 128 + max_seq_len = 64 + + data = self._create_cached_mla_data( + batch_size, + seq_len, + num_heads, + qk_nope_head_dim, + qk_rope_head_dim, + kv_lora_rank, + v_head_dim, + max_seq_len, + ) + + # Verify cache shape matches FlashInfer layout: [batch, seq, kv_lora_rank + rope_dim] + expected_cache_shape = (batch_size, max_seq_len, kv_lora_rank + qk_rope_head_dim) + assert data["mla_cache"].shape == expected_cache_shape, ( + f"Cache shape {data['mla_cache'].shape} doesn't match FlashInfer layout {expected_cache_shape}" + ) + + # Verify zero-copy slicing works + compressed_kv_slice = data["mla_cache"][:, :, :kv_lora_rank] + kpe_slice = data["mla_cache"][:, :, kv_lora_rank:] + + assert compressed_kv_slice.shape == (batch_size, max_seq_len, kv_lora_rank) + assert kpe_slice.shape == (batch_size, max_seq_len, qk_rope_head_dim) + + # Verify slices share memory (zero-copy) + assert compressed_kv_slice.data_ptr() == data["mla_cache"].data_ptr(), ( + "compressed_kv slice should be zero-copy" + ) + + def test_generate_with_reference(self): + """Test generate phase against numpy reference.""" + batch_size, seq_len, num_heads = 2, 1, 4 + qk_nope_head_dim, qk_rope_head_dim = 32, 16 + kv_lora_rank = 64 + v_head_dim = 32 + max_seq_len = 64 + cache_offset = 3 + + data = self._create_cached_mla_data( + batch_size, + seq_len, + num_heads, + qk_nope_head_dim, + qk_rope_head_dim, + kv_lora_rank, + v_head_dim, + max_seq_len, + cache_offset, + ) + + # Run backend + output = self._run_cached_mla(data) + + # Run numpy reference + reference = numpy_mla_reference_with_expansion( + data["q_nope"].cpu().float().numpy(), + data["q_pe"].cpu().float().numpy(), + data["compressed_kv"].cpu().float().numpy(), + data["kpe"].cpu().float().numpy(), + data["kv_b_proj_weight"].cpu().float().numpy(), + data["mla_cache"].cpu().float().numpy(), + data["seq_len"].cpu().numpy(), + data["input_pos"].cpu().numpy(), + data["cache_loc"].cpu().numpy(), + data["cu_seqlen"].cpu().numpy(), + None, + kv_lora_rank, + is_generate=True, + ) + + reference_torch = torch.from_numpy(reference).to(output.device, output.dtype) + assert torch.allclose(output, reference_torch, atol=self.atol, rtol=self.rtol), ( + f"Generate phase output doesn't match reference. 
" + f"Max diff: {(output - reference_torch).abs().max():.6f}" + ) + + def test_dtype_preservation(self): + """Test that output dtype matches input dtype.""" + batch_size, seq_len, num_heads = 1, 1, 4 + qk_nope_head_dim, qk_rope_head_dim = 32, 16 + kv_lora_rank = 64 + v_head_dim = 32 + max_seq_len = 32 + + for dtype in [torch.float16, torch.bfloat16]: + self.dtype = dtype + data = self._create_cached_mla_data( + batch_size, + seq_len, + num_heads, + qk_nope_head_dim, + qk_rope_head_dim, + kv_lora_rank, + v_head_dim, + max_seq_len, + ) + + output = self._run_cached_mla(data) + assert output.dtype == dtype, f"Expected dtype {dtype}, got {output.dtype}" + + def test_memory_efficiency(self): + """Test that cache uses compressed dimensions (no num_heads).""" + batch_size = 1 + max_seq_len = 1024 + kv_lora_rank = 512 + qk_rope_head_dim = 64 + num_heads = 128 # DeepSeek V3 + + # FlashInfer compressed cache size + compressed_cache_size = batch_size * max_seq_len * (kv_lora_rank + qk_rope_head_dim) + + # Expanded per-head cache size (what we avoid) + qk_nope_head_dim = 128 + v_head_dim = 128 + expanded_cache_size = ( + batch_size + * max_seq_len + * num_heads + * (qk_nope_head_dim + v_head_dim + qk_rope_head_dim) + ) + + # Verify compression ratio + compression_ratio = expanded_cache_size / compressed_cache_size + assert compression_ratio > 50, f"Expected >50x compression, got {compression_ratio:.1f}x" + + +class TestMLADescriptor: + """Test MultiHeadLatentAttention descriptor configuration.""" + + def _get_mla_descriptor(self): + """Get MLA descriptor from registry.""" + from tensorrt_llm._torch.auto_deploy.custom_ops.attention_interface import AttentionRegistry + + return AttentionRegistry.get("torch_mla") + + def test_descriptor_registration(self): + """Test that MLA descriptor is properly registered.""" + from tensorrt_llm._torch.auto_deploy.custom_ops.attention_interface import AttentionRegistry + + assert AttentionRegistry.has("torch_mla"), "torch_mla should be registered" + + def test_descriptor_layout(self): + """Test that MLA descriptor uses correct layout.""" + mla_descriptor = self._get_mla_descriptor() + + assert mla_descriptor.get_attention_layout() == "bsnd", "MLA should use bsnd layout" + + def test_descriptor_num_qkv_args(self): + """Test that MLA descriptor expects 5 tensor args.""" + mla_descriptor = self._get_mla_descriptor() + + assert mla_descriptor.get_num_qkv_args() == 5, ( + "MLA should expect 5 tensor args (q_nope, q_pe, compressed_kv, kpe, kv_b_proj_weight)" + ) + + def test_descriptor_source_op(self): + """Test that MLA descriptor points to correct source op.""" + mla_descriptor = self._get_mla_descriptor() + + source_op = mla_descriptor.get_source_attention_op() + assert source_op == torch.ops.auto_deploy.torch_mla, "MLA should use torch_mla as source op" + + def test_descriptor_cached_op(self): + """Test that MLA descriptor points to correct cached op.""" + mla_descriptor = self._get_mla_descriptor() + + cached_op = mla_descriptor.get_cached_attention_op() + assert cached_op == torch.ops.auto_deploy.torch_cached_mla_with_cache.default, ( + "MLA should use torch_cached_mla_with_cache as cached op" + ) + + def test_descriptor_standard_metadata(self): + """Test that MLA descriptor uses standard metadata args.""" + mla_descriptor = self._get_mla_descriptor() + + expected_args = ["batch_info_host", "seq_len", "input_pos", "slot_idx", "cu_seqlen"] + actual_args = mla_descriptor.get_standard_metadata_args() + assert actual_args == expected_args, ( + f"Expected standard metadata 
{expected_args}, got {actual_args}" + ) diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_update_kv_cache.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_update_kv_cache.py index 34baf42d43..47681b5596 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_update_kv_cache.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_update_kv_cache.py @@ -1,6 +1,8 @@ import torch -from tensorrt_llm._torch.auto_deploy.custom_ops.attention.torch_attention import update_kv_cache +from tensorrt_llm._torch.auto_deploy.custom_ops.attention.torch_backend_attention import ( + _update_kv_cache, +) def test_update_kv_cache(): @@ -26,7 +28,7 @@ def test_update_kv_cache(): print("slot_idx: " + str(torch.tensor([0, 1]))) print("seq_start: " + str(torch.tensor([0, 3]))) - update_kv_cache( + _update_kv_cache( k.view(batch_size * seq_length, n_heads, K_D_HEAD), v.view(batch_size * seq_length, n_heads, V_D_HEAD), k_cache, diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_deepseek_custom.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_deepseek_custom.py new file mode 100644 index 0000000000..7e8779941c --- /dev/null +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_deepseek_custom.py @@ -0,0 +1,494 @@ +"""Testing custom DeepSeekV3 model implementation for auto_deploy export.""" + +import pytest +import torch +from transformers import PretrainedConfig + +from tensorrt_llm._torch.auto_deploy.models.custom.modeling_deepseek import ( + DeepSeekV3Attention, + DeepSeekV3DecoderLayer, + DeepSeekV3ForCausalLM, + DeepSeekV3MLP, + DeepSeekV3Model, + DeepSeekV3MoE, + DeepSeekV3RMSNorm, + DeepSeekV3RotaryEmbedding, + DeepSeekV3YarnRotaryEmbedding, +) + + +class MockDeepSeekConfig(PretrainedConfig): + """Mock DeepSeek config for testing the custom model components.""" + + model_type = "deepseek_v3" + + def __init__(self): + super().__init__() + # Attention config + self.num_attention_heads = 8 + self.qk_nope_head_dim = 64 + self.qk_rope_head_dim = 32 + self.v_head_dim = 64 + self.kv_lora_rank = 128 + self.q_lora_rank = None # No LoRA for Q in tests + self.hidden_size = 256 + self.rope_theta = 10000.0 + self.max_position_embeddings = 512 + self.attention_bias = False + self.rope_scaling = None + self.rms_norm_eps = 1e-6 + + # MLP config + self.intermediate_size = 512 + self.hidden_act = "silu" + + # MoE config + self.n_routed_experts = 4 + self.num_experts_per_tok = 2 + self.moe_intermediate_size = 256 + self.n_shared_experts = 1 + self.routed_scaling_factor = 1.0 + self.n_group = 1 + self.topk_group = 1 + + # Model config + self.num_hidden_layers = 2 + self.first_k_dense_replace = 1 # First layer is dense, second is MoE + self.moe_layer_freq = 1 + self.vocab_size = 1000 + self.pad_token_id = 0 + self.initializer_range = 0.02 + + +class TestDeepSeekV3RMSNorm: + """Test DeepSeekV3RMSNorm implementation.""" + + @pytest.fixture(autouse=True) + def setup_method(self): + self.device = "cuda" if torch.cuda.is_available() else "cpu" + self.dtype = torch.bfloat16 + torch.manual_seed(42) + + def test_forward_shape(self): + """Test that RMSNorm preserves input shape.""" + hidden_size = 256 + norm = DeepSeekV3RMSNorm(hidden_size).to(self.device, self.dtype) + + x = torch.randn(2, 4, hidden_size, dtype=self.dtype, device=self.device) + output = norm(x) + + assert output.shape == x.shape + assert torch.isfinite(output).all() + + def test_output_normalized(self): + """Test that output has 
approximately unit variance.""" + hidden_size = 256 + norm = DeepSeekV3RMSNorm(hidden_size).to(self.device, torch.float32) + + x = torch.randn(2, 4, hidden_size, dtype=torch.float32, device=self.device) + output = norm(x) + + # RMS should be close to 1 after normalization (scaled by weight) + rms = torch.sqrt((output**2).mean(-1)) + assert torch.allclose(rms, torch.ones_like(rms), atol=0.1) + + +class TestDeepSeekV3RotaryEmbedding: + """Test DeepSeekV3 Rotary Embedding implementations.""" + + @pytest.fixture(autouse=True) + def setup_method(self): + self.device = "cuda" if torch.cuda.is_available() else "cpu" + self.dtype = torch.bfloat16 + torch.manual_seed(42) + + def test_base_rope_shape(self): + """Test base rotary embedding output shape.""" + dim = 32 + max_pos = 512 + rope = DeepSeekV3RotaryEmbedding(dim, max_pos).to(self.device) + + x = torch.randn(2, 4, 8, dim, dtype=self.dtype, device=self.device) + cos, sin = rope(x) + + # Should return full cached values + assert cos.shape == (max_pos, dim) + assert sin.shape == (max_pos, dim) + + def test_yarn_rope_shape(self): + """Test YaRN rotary embedding output shape.""" + dim = 32 + max_pos = 512 + rope = DeepSeekV3YarnRotaryEmbedding( + dim, + max_pos, + scaling_factor=2.0, + original_max_position_embeddings=256, + ).to(self.device) + + x = torch.randn(2, 4, 8, dim, dtype=self.dtype, device=self.device) + cos, sin = rope(x) + + assert cos.shape == (max_pos, dim) + assert sin.shape == (max_pos, dim) + + +class TestDeepSeekV3MLP: + """Test DeepSeekV3MLP implementation.""" + + @pytest.fixture(autouse=True) + def setup_method(self): + self.device = "cuda" if torch.cuda.is_available() else "cpu" + self.dtype = torch.bfloat16 + torch.manual_seed(42) + + def test_forward_shape(self): + """Test MLP output shape.""" + config = MockDeepSeekConfig() + mlp = DeepSeekV3MLP(config).to(self.device, self.dtype) + + x = torch.randn(2, 4, config.hidden_size, dtype=self.dtype, device=self.device) + output = mlp(x) + + assert output.shape == x.shape + assert torch.isfinite(output).all() + + +class TestDeepSeekV3Attention: + """Test DeepSeekV3Attention (MLA) implementation.""" + + @pytest.fixture(autouse=True) + def setup_method(self): + self.device = "cuda" if torch.cuda.is_available() else "cpu" + self.dtype = torch.bfloat16 + torch.manual_seed(42) + + def test_forward_shape(self): + """Test attention output shape.""" + config = MockDeepSeekConfig() + attn = DeepSeekV3Attention(config, layer_idx=0).to(self.device, self.dtype) + + batch_size, seq_len = 2, 4 + hidden_states = torch.randn( + batch_size, seq_len, config.hidden_size, dtype=self.dtype, device=self.device + ) + position_ids = torch.arange(seq_len, device=self.device).unsqueeze(0).expand(batch_size, -1) + + output = attn(hidden_states, position_ids) + + assert output.shape == hidden_states.shape + assert torch.isfinite(output).all() + + def test_different_batch_sizes(self): + """Test attention with different batch sizes.""" + config = MockDeepSeekConfig() + attn = DeepSeekV3Attention(config, layer_idx=0).to(self.device, self.dtype) + + for batch_size in [1, 2, 4]: + seq_len = 4 + hidden_states = torch.randn( + batch_size, seq_len, config.hidden_size, dtype=self.dtype, device=self.device + ) + position_ids = ( + torch.arange(seq_len, device=self.device).unsqueeze(0).expand(batch_size, -1) + ) + + output = attn(hidden_states, position_ids) + assert output.shape == (batch_size, seq_len, config.hidden_size) + + def test_different_sequence_lengths(self): + """Test attention with different sequence 
lengths.""" + config = MockDeepSeekConfig() + attn = DeepSeekV3Attention(config, layer_idx=0).to(self.device, self.dtype) + + for seq_len in [1, 4, 16]: + batch_size = 2 + hidden_states = torch.randn( + batch_size, seq_len, config.hidden_size, dtype=self.dtype, device=self.device + ) + position_ids = ( + torch.arange(seq_len, device=self.device).unsqueeze(0).expand(batch_size, -1) + ) + + output = attn(hidden_states, position_ids) + assert output.shape == (batch_size, seq_len, config.hidden_size) + + +class TestDeepSeekV3MoE: + """Test DeepSeekV3MoE implementation.""" + + @pytest.fixture(autouse=True) + def setup_method(self): + self.device = "cuda" if torch.cuda.is_available() else "cpu" + self.dtype = torch.bfloat16 + torch.manual_seed(42) + + def test_forward_shape(self): + """Test MoE output shape.""" + config = MockDeepSeekConfig() + moe = DeepSeekV3MoE(config).to(self.device, self.dtype) + + x = torch.randn(2, 4, config.hidden_size, dtype=self.dtype, device=self.device) + output = moe(x) + + assert output.shape == x.shape + assert torch.isfinite(output).all() + + def test_with_shared_experts(self): + """Test MoE with shared experts.""" + config = MockDeepSeekConfig() + config.n_shared_experts = 2 + moe = DeepSeekV3MoE(config).to(self.device, self.dtype) + + assert moe.shared_experts is not None + + x = torch.randn(2, 4, config.hidden_size, dtype=self.dtype, device=self.device) + output = moe(x) + + assert output.shape == x.shape + + def test_without_shared_experts(self): + """Test MoE without shared experts.""" + config = MockDeepSeekConfig() + config.n_shared_experts = None + moe = DeepSeekV3MoE(config).to(self.device, self.dtype) + + assert moe.shared_experts is None + + x = torch.randn(2, 4, config.hidden_size, dtype=self.dtype, device=self.device) + output = moe(x) + + assert output.shape == x.shape + + +class TestDeepSeekV3DecoderLayer: + """Test DeepSeekV3DecoderLayer implementation.""" + + @pytest.fixture(autouse=True) + def setup_method(self): + self.device = "cuda" if torch.cuda.is_available() else "cpu" + self.dtype = torch.bfloat16 + torch.manual_seed(42) + + def test_dense_layer(self): + """Test decoder layer with dense MLP.""" + config = MockDeepSeekConfig() + # Layer 0 should be dense (before first_k_dense_replace) + layer = DeepSeekV3DecoderLayer(config, layer_idx=0).to(self.device, self.dtype) + + assert isinstance(layer.mlp, DeepSeekV3MLP) + + batch_size, seq_len = 2, 4 + hidden_states = torch.randn( + batch_size, seq_len, config.hidden_size, dtype=self.dtype, device=self.device + ) + position_ids = torch.arange(seq_len, device=self.device).unsqueeze(0).expand(batch_size, -1) + + output = layer(hidden_states, position_ids) + + assert output.shape == hidden_states.shape + + def test_moe_layer(self): + """Test decoder layer with MoE.""" + config = MockDeepSeekConfig() + # Layer 1 should be MoE (at first_k_dense_replace) + layer = DeepSeekV3DecoderLayer(config, layer_idx=1).to(self.device, self.dtype) + + assert isinstance(layer.mlp, DeepSeekV3MoE) + + batch_size, seq_len = 2, 4 + hidden_states = torch.randn( + batch_size, seq_len, config.hidden_size, dtype=self.dtype, device=self.device + ) + position_ids = torch.arange(seq_len, device=self.device).unsqueeze(0).expand(batch_size, -1) + + output = layer(hidden_states, position_ids) + + assert output.shape == hidden_states.shape + + +class TestDeepSeekV3Model: + """Test DeepSeekV3Model implementation.""" + + @pytest.fixture(autouse=True) + def setup_method(self): + self.device = "cuda" if torch.cuda.is_available() else 
"cpu" + self.dtype = torch.bfloat16 + torch.manual_seed(42) + + def test_forward_with_input_ids(self): + """Test model forward with input_ids.""" + config = MockDeepSeekConfig() + model = DeepSeekV3Model(config).to(self.device, self.dtype) + + batch_size, seq_len = 2, 4 + input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len), device=self.device) + + output = model(input_ids=input_ids) + + assert output.last_hidden_state.shape == (batch_size, seq_len, config.hidden_size) + + def test_forward_with_inputs_embeds(self): + """Test model forward with inputs_embeds.""" + config = MockDeepSeekConfig() + model = DeepSeekV3Model(config).to(self.device, self.dtype) + + batch_size, seq_len = 2, 4 + inputs_embeds = torch.randn( + batch_size, seq_len, config.hidden_size, dtype=self.dtype, device=self.device + ) + + output = model(inputs_embeds=inputs_embeds) + + assert output.last_hidden_state.shape == inputs_embeds.shape + + +class TestDeepSeekV3ForCausalLM: + """Test DeepSeekV3ForCausalLM implementation.""" + + @pytest.fixture(autouse=True) + def setup_method(self): + self.device = "cuda" if torch.cuda.is_available() else "cpu" + self.dtype = torch.bfloat16 + torch.manual_seed(42) + + def test_forward(self): + """Test causal LM forward pass.""" + config = MockDeepSeekConfig() + model = DeepSeekV3ForCausalLM(config).to(self.device, self.dtype) + + batch_size, seq_len = 2, 4 + input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len), device=self.device) + + output = model(input_ids=input_ids) + + assert output.logits.shape == (batch_size, seq_len, config.vocab_size) + + def test_output_dtype(self): + """Test that logits are float32.""" + config = MockDeepSeekConfig() + model = DeepSeekV3ForCausalLM(config).to(self.device, self.dtype) + + batch_size, seq_len = 2, 4 + input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len), device=self.device) + + output = model(input_ids=input_ids) + + # Logits should be float32 for numerical stability + assert output.logits.dtype == torch.float32 + + +class TestMLAOpRegistration: + """Test that MLA ops are properly registered.""" + + def test_torch_mla_registered(self): + """Test that torch_mla op is registered.""" + assert hasattr(torch.ops.auto_deploy, "torch_mla"), "torch_mla op should be registered" + + def test_torch_cached_mla_registered(self): + """Test that torch_cached_mla_with_cache op is registered.""" + assert hasattr(torch.ops.auto_deploy, "torch_cached_mla_with_cache"), ( + "torch_cached_mla_with_cache op should be registered" + ) + + def test_torch_mla_callable(self): + """Test that torch_mla op is callable.""" + device = "cuda" if torch.cuda.is_available() else "cpu" + dtype = torch.bfloat16 + + batch_size, seq_len, num_heads = 1, 2, 4 + qk_nope_head_dim, qk_rope_head_dim = 64, 32 + kv_lora_rank = 128 + v_head_dim = 64 + + q_nope = torch.randn( + batch_size, seq_len, num_heads, qk_nope_head_dim, dtype=dtype, device=device + ) + q_pe = torch.randn( + batch_size, seq_len, num_heads, qk_rope_head_dim, dtype=dtype, device=device + ) + compressed_kv = torch.randn(batch_size, seq_len, kv_lora_rank, dtype=dtype, device=device) + kpe = torch.randn(batch_size, seq_len, 1, qk_rope_head_dim, dtype=dtype, device=device) + kv_b_proj_weight = torch.randn( + num_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank, dtype=dtype, device=device + ) + + # Should not raise + output = torch.ops.auto_deploy.torch_mla( + q_nope, q_pe, compressed_kv, kpe, kv_b_proj_weight, True, None, "bsnd" + ) + assert output is not None + + def 
test_torch_cached_mla_callable(self): + """Test that torch_cached_mla_with_cache op is callable.""" + device = "cuda" if torch.cuda.is_available() else "cpu" + dtype = torch.bfloat16 + + batch_size, seq_len, num_heads = 1, 1, 4 + qk_nope_head_dim, qk_rope_head_dim = 32, 16 + kv_lora_rank = 64 + v_head_dim = 32 + max_seq_len = 32 + + q_nope = torch.randn( + batch_size, seq_len, num_heads, qk_nope_head_dim, dtype=dtype, device=device + ) + q_pe = torch.randn( + batch_size, seq_len, num_heads, qk_rope_head_dim, dtype=dtype, device=device + ) + compressed_kv = torch.randn(batch_size, seq_len, kv_lora_rank, dtype=dtype, device=device) + kpe = torch.randn(batch_size, seq_len, 1, qk_rope_head_dim, dtype=dtype, device=device) + kv_b_proj_weight = torch.randn( + num_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank, dtype=dtype, device=device + ) + + batch_info_host = torch.tensor([0, 0, batch_size], dtype=torch.int32, device=device) + seq_len_tensor = torch.tensor([seq_len], dtype=torch.int32, device=device) + input_pos = torch.tensor([0], dtype=torch.int32, device=device) + cache_loc = torch.tensor([0], dtype=torch.int32, device=device) + cu_seqlen = torch.tensor([0], dtype=torch.int32, device=device) + + mla_cache = torch.zeros( + batch_size, max_seq_len, kv_lora_rank + qk_rope_head_dim, dtype=dtype, device=device + ) + + # Should not raise + output = torch.ops.auto_deploy.torch_cached_mla_with_cache( + q_nope, + q_pe, + compressed_kv, + kpe, + kv_b_proj_weight, + batch_info_host, + seq_len_tensor, + input_pos, + cache_loc, + cu_seqlen, + mla_cache, + None, + kv_lora_rank, + ) + assert output is not None + + +class TestMoEOpRegistration: + """Test that MoE ops are properly registered.""" + + def test_torch_moe_registered(self): + """Test that torch_moe op is registered.""" + assert hasattr(torch.ops.auto_deploy, "torch_moe"), "torch_moe op should be registered" + + +class TestCustomModelRegistration: + """Test that custom model is properly registered.""" + + def test_model_registered(self): + """Test that DeepSeekV3ForCausalLM is registered with the factory.""" + from tensorrt_llm._torch.auto_deploy.models.hf import AutoModelForCausalLMFactory + + # Check that the model is registered in the custom model mapping + assert "DeepseekV3Config" in AutoModelForCausalLMFactory._custom_model_mapping + assert ( + AutoModelForCausalLMFactory._custom_model_mapping["DeepseekV3Config"] + == DeepSeekV3ForCausalLM + ) diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_deepseek_patches.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_deepseek_patches.py deleted file mode 100644 index 070d85a598..0000000000 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_deepseek_patches.py +++ /dev/null @@ -1,110 +0,0 @@ -"""Testing module patches that enable export of deepseek model.""" - -import types - -import pytest -import torch -from test_common.llm_data import hf_id_to_local_model_dir -from transformers import AutoConfig, AutoModelForCausalLM - -from tensorrt_llm._torch.auto_deploy.models.patches.deepseek import ( - deepseek_v3_attention, - deepseek_v3_moe_exact, -) - - -def _load_layer_from_model(model_name_or_path, layer_name): - """ - Loads a specific layer/module from a model without loading the entire model. - - Parameters: - model_name_or_path (str): Path or name of the pretrained model. - layer_name (str): Name of the layer to extract. - - Returns: - module: The specified layer/module if available, otherwise None. 
- """ - try: - # Load only the model configuration - config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True) - - # Load a subset of layers of the model and configure yarn - config.num_hidden_layers = 1 - config.use_cache = False - config.first_k_dense_replace = 0 - config.n_routed_experts = 2 - config.num_experts_per_tok = 1 - config.n_group = 1 - config.topk_group = 1 - config.hidden_size = 8 - config.moe_intermediate_size = 8 - config.num_attention_heads = 2 - config.num_key_value_heads = 2 - config.qk_nope_head_dim = 4 - config.qk_rope_head_dim = 2 - config.v_head_dim = 4 - config.intermediate_size = 8 - config.max_position_embeddings = 7 - - config.rope_scaling = None - - # Build the model architecture (no weights loaded yet) - model = AutoModelForCausalLM.from_config(config, trust_remote_code=True) - model.eval() - - # Access the specific layer by its name - module = dict(model.named_modules()).get(layer_name) - if module is None: - print(f"Layer '{layer_name}' not found in the model.") - else: - print(f"Successfully extracted layer '{layer_name}'.") - return module - except Exception as e: - print(f"Error extracting layer: {e}") - return None - - -def _generate_ds_attention_mask(b, s): - return torch.where( - torch.tril(torch.full((s, s), float("-inf"))).unsqueeze(0).unsqueeze(0).expand(b, 1, s, s) - == float("-inf"), - torch.tensor(0.0), - torch.tensor(float(-3.4028e38)), - ) - - -@pytest.mark.parametrize( - "model_name, module_name, patch, inputs", - [ - pytest.param( - hf_id_to_local_model_dir("deepseek-ai/DeepSeek-R1"), - "model.layers.0.self_attn", - deepseek_v3_attention, - [ - torch.randn(2, 6, 8, dtype=torch.bfloat16), - _generate_ds_attention_mask(2, 6), - torch.tensor([[0, 1, 2, 3, 4, 5]]), - ], - ), # attention requires inputs [hidden_states, attention_mask, position_ids] - pytest.param( - hf_id_to_local_model_dir("deepseek-ai/DeepSeek-R1"), - "model.layers.0.mlp", - deepseek_v3_moe_exact, - [torch.randn(2, 6, 8, dtype=torch.bfloat16)], - ), # moe requires inputs [hidden_states] - ], -) -def test_module_patches(model_name, module_name, patch, inputs): - # Get module - module = _load_layer_from_model(model_name, module_name) - - # Pass test inputs to generate reference - ref, *_ = module(*inputs) - - # Patch layer - module.forward = types.MethodType(patch, module) - - # Generate test output - test, *_ = module(*inputs) - - torch.allclose(ref, test, atol=0, rtol=0) diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_glm4_moe_lite_modeling.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_glm4_moe_lite_modeling.py new file mode 100644 index 0000000000..486ffe7d31 --- /dev/null +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_glm4_moe_lite_modeling.py @@ -0,0 +1,652 @@ +"""Tests for GLM4 MoE Lite custom model implementation. + +This module tests the custom GLM4 MoE Lite model implementation which uses +auto_deploy custom ops (torch_mla, torch_moe, etc.) for export compatibility. 
+""" + +import pytest +import torch +from torch.export import Dim + +from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm +from tensorrt_llm._torch.auto_deploy.models.custom.modeling_glm4_moe_lite import ( + Glm4MoeLiteConfig, + Glm4MoeLiteForCausalLM, + Glm4MoeLiteMLP, + Glm4MoeLiteMoE, +) +from tensorrt_llm._torch.auto_deploy.utils._graph import move_to_device + +_BATCH_AND_SEQUENCE_TEST_CASES = ((2, 6), (1, 8)) + + +@pytest.fixture(scope="function", autouse=True) +def set_seed(): + torch.manual_seed(42) + + +def _create_small_config() -> Glm4MoeLiteConfig: + """Create a small GLM4 MoE Lite config for testing.""" + return Glm4MoeLiteConfig( + vocab_size=1000, + hidden_size=64, + intermediate_size=128, + num_hidden_layers=3, # Layer 0 dense, layers 1-2 MoE + num_attention_heads=4, + num_key_value_heads=4, + hidden_act="silu", + max_position_embeddings=512, + rms_norm_eps=1e-5, + # MLA params (scaled down) + q_lora_rank=32, + kv_lora_rank=32, + qk_nope_head_dim=8, + qk_rope_head_dim=8, + v_head_dim=16, + # MoE params (scaled down) + n_routed_experts=4, + n_shared_experts=1, + num_experts_per_tok=2, + moe_intermediate_size=32, + n_group=1, + topk_group=1, + routed_scaling_factor=1.0, + norm_topk_prob=True, + first_k_dense_replace=1, + # RoPE + rope_theta=10000.0, + rope_scaling=None, + # Other + attention_bias=False, + attention_dropout=0.0, + pad_token_id=0, + ) + + +def _create_moe_layer(config: Glm4MoeLiteConfig) -> Glm4MoeLiteMoE: + """Create a MoE layer from config.""" + moe = Glm4MoeLiteMoE(config) + # Initialize gate weights with randn for reproducibility + # (gate weight is initialized with torch.empty which isn't seeded) + moe.gate.weight = torch.nn.Parameter(torch.randn_like(moe.gate.weight)) + return moe + + +@pytest.mark.parametrize("B,S", _BATCH_AND_SEQUENCE_TEST_CASES) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@torch.no_grad() +def test_glm4_moe_lite_moe_layer(B, S, dtype): + """Test that MoE layer produces valid output.""" + device = "cuda" + config = _create_small_config() + + moe = _create_moe_layer(config) + moe.to(device=device, dtype=dtype) + moe.eval() + + H = config.hidden_size + x = torch.randn(B, S, H, device=device, dtype=dtype) + + output = moe(x) + + # Check output shape matches input shape + assert output.shape == x.shape, f"Expected shape {x.shape}, got {output.shape}" + + # Check output is not all zeros (MoE should transform the input) + assert not torch.allclose(output, torch.zeros_like(output)), "Output should not be all zeros" + + # Check output doesn't have NaN or Inf values + assert not torch.isnan(output).any(), "Output contains NaN values" + assert not torch.isinf(output).any(), "Output contains Inf values" + + +@pytest.mark.parametrize("B,S", _BATCH_AND_SEQUENCE_TEST_CASES) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@torch.no_grad() +def test_glm4_moe_lite_full_model(B, S, dtype): + """Test that full model produces valid output.""" + device = "cuda" + config = _create_small_config() + + model = Glm4MoeLiteForCausalLM(config) + model.to(device=device, dtype=dtype) + model.eval() + + # Create input + input_ids = torch.randint(0, config.vocab_size, (B, S), device=device) + position_ids = torch.arange(S, device=device).unsqueeze(0).expand(B, -1) + + output = model(input_ids=input_ids, position_ids=position_ids) + + # Check output shape + assert output.logits.shape == (B, S, config.vocab_size), ( + f"Expected logits shape {(B, S, config.vocab_size)}, got {output.logits.shape}" + ) + + # Check output doesn't have NaN or 
Inf values + assert not torch.isnan(output.logits).any(), "Logits contain NaN values" + assert not torch.isinf(output.logits).any(), "Logits contain Inf values" + + +def test_glm4_moe_lite_model_can_be_exported(): + """Test that the custom model can be exported with torch_export_to_gm. + + This test verifies: + 1. The model exports successfully without graph breaks + 2. The exported graph module produces outputs with correct shape + 3. The outputs contain finite values (no NaN/Inf) + + Note: We don't test numerical equivalence between original and exported model + here because torch.export lifts parameters into the graph, creating a different + parameter structure that doesn't match load_state_dict. The numerical correctness + of the model itself is already validated by test_glm4_moe_lite_full_model. + """ + device = "cuda" + dtype = torch.bfloat16 + config = _create_small_config() + + model = Glm4MoeLiteForCausalLM(config) + model.to(device=device, dtype=dtype) + model.eval() + + # Create input + B, S = 2, 8 + input_ids = torch.randint(0, config.vocab_size, (B, S), device=device) + position_ids = torch.arange(S, device=device).unsqueeze(0).expand(B, -1) + + # Define dynamic shapes + batch_size_dynamic = Dim.DYNAMIC + seq_len_dynamic = Dim.DYNAMIC + dynamic_shapes = ( + {0: batch_size_dynamic, 1: seq_len_dynamic}, + {0: batch_size_dynamic, 1: seq_len_dynamic}, + ) + + # Export the model - this is the main test: verify no graph breaks + gm = torch_export_to_gm( + model, + args=tuple(), + kwargs={"input_ids": input_ids, "position_ids": position_ids}, + dynamic_shapes=dynamic_shapes, + ) + + # Move graph module to device + move_to_device(gm, device) + + # Verify the exported model produces valid output + with torch.inference_mode(): + out_gm = gm(input_ids=input_ids, position_ids=position_ids) + + # Check output structure and shape + assert "logits" in out_gm, "Output should contain 'logits' key" + logits = out_gm["logits"] + assert logits.shape == (B, S, config.vocab_size), ( + f"Expected shape {(B, S, config.vocab_size)}, got {logits.shape}" + ) + assert torch.isfinite(logits).all(), "Logits should not contain NaN or Inf" + + # Test with different input shape to verify dynamic shapes work + B2, S2 = 1, 4 + input_ids2 = torch.randint(0, config.vocab_size, (B2, S2), device=device) + position_ids2 = torch.arange(S2, device=device).unsqueeze(0).expand(B2, -1) + + with torch.inference_mode(): + out_gm2 = gm(input_ids=input_ids2, position_ids=position_ids2) + + logits2 = out_gm2["logits"] + expected_shape = (B2, S2, config.vocab_size) + assert logits2.shape == expected_shape, ( + f"Dynamic shape test failed: expected {expected_shape}, got {logits2.shape}" + ) + assert torch.isfinite(logits2).all(), "Logits should not contain NaN or Inf" + + +def test_glm4_moe_lite_config_registration(): + """Test that the config is properly registered or model_type is correct.""" + # Create a config and verify model_type + config = _create_small_config() + assert config.model_type == "glm4_moe_lite" + + # Verify our config class can be instantiated with expected attributes + assert hasattr(config, "hidden_size") + assert hasattr(config, "num_attention_heads") + assert hasattr(config, "n_routed_experts") + assert hasattr(config, "kv_lora_rank") + assert hasattr(config, "qk_rope_head_dim") + + +def test_glm4_moe_lite_layer_types(): + """Test that layer 0 uses dense MLP and later layers use MoE.""" + config = _create_small_config() + model = Glm4MoeLiteForCausalLM(config) + + # Check layer 0 (should be dense MLP, 
not MoE) + layer0_mlp = model.model.layers[0].mlp + assert type(layer0_mlp).__name__ == "Glm4MoeLiteMLP", ( + f"Layer 0 should use Glm4MoeLiteMLP, got {type(layer0_mlp).__name__}" + ) + + # Check layer 1+ (should be MoE) + for i in range(1, config.num_hidden_layers): + layer_mlp = model.model.layers[i].mlp + assert type(layer_mlp).__name__ == "Glm4MoeLiteMoE", ( + f"Layer {i} should use Glm4MoeLiteMoE, got {type(layer_mlp).__name__}" + ) + + +def test_glm4_moe_lite_expert_structure(): + """Test that experts have correct structure for checkpoint loading.""" + config = _create_small_config() + moe = Glm4MoeLiteMoE(config) + + # Check that experts is a ModuleList + assert isinstance(moe.experts, torch.nn.ModuleList), "experts should be nn.ModuleList" + + # Check number of experts + assert len(moe.experts) == config.n_routed_experts, ( + f"Expected {config.n_routed_experts} experts, got {len(moe.experts)}" + ) + + # Check each expert has the correct structure + for i, expert in enumerate(moe.experts): + assert hasattr(expert, "gate_proj"), f"Expert {i} missing gate_proj" + assert hasattr(expert, "up_proj"), f"Expert {i} missing up_proj" + assert hasattr(expert, "down_proj"), f"Expert {i} missing down_proj" + + # Check state_dict keys match expected checkpoint format + state_dict = moe.state_dict() + expected_keys = [ + "experts.0.gate_proj.weight", + "experts.0.up_proj.weight", + "experts.0.down_proj.weight", + ] + for key in expected_keys: + assert key in state_dict, ( + f"Expected key '{key}' in state_dict, got keys: {list(state_dict.keys())[:10]}..." + ) + + +# ============================================================================= +# Numerical Equivalence Tests +# These tests compare our custom implementation against the HF implementation +# ============================================================================= + + +def _get_hf_model_class(): + """Get the HF model class for GLM4 MoE Lite. + + Returns None if transformers doesn't have glm4_moe_lite (older versions). + """ + try: + from transformers.models.glm4_moe_lite.modeling_glm4_moe_lite import ( + Glm4MoeLiteForCausalLM as HFGlm4MoeLiteForCausalLM, + ) + + return HFGlm4MoeLiteForCausalLM + except ImportError: + return None + + +def _get_hf_moe_class(): + """Get the HF MoE class for GLM4 MoE Lite.""" + try: + from transformers.models.glm4_moe_lite.modeling_glm4_moe_lite import ( + Glm4MoeLiteMoE as HFGlm4MoeLiteMoE, + ) + + return HFGlm4MoeLiteMoE + except ImportError: + return None + + +def _get_hf_attention_class(): + """Get the HF Attention class for GLM4 MoE Lite.""" + try: + from transformers.models.glm4_moe_lite.modeling_glm4_moe_lite import ( + Glm4MoeLiteAttention as HFGlm4MoeLiteAttention, + ) + + return HFGlm4MoeLiteAttention + except ImportError: + return None + + +def _get_hf_mlp_class(): + """Get the HF MLP class for GLM4 MoE Lite.""" + try: + from transformers.models.glm4_moe_lite.modeling_glm4_moe_lite import ( + Glm4MoeLiteMLP as HFGlm4MoeLiteMLP, + ) + + return HFGlm4MoeLiteMLP + except ImportError: + return None + + +def _get_hf_config_class(): + """Get the HF Config class for GLM4 MoE Lite.""" + try: + from transformers.models.glm4_moe_lite.configuration_glm4_moe_lite import ( + Glm4MoeLiteConfig as HFGlm4MoeLiteConfig, + ) + + return HFGlm4MoeLiteConfig + except ImportError: + return None + + +def _convert_hf_moe_state_dict_to_custom(hf_state_dict: dict, n_experts: int) -> dict: + """Convert HF MoE state dict to our custom format. 
+ + HF format (stacked): + experts.gate_up_proj: [n_experts, 2 * intermediate_size, hidden_size] + experts.down_proj: [n_experts, hidden_size, intermediate_size] + + Custom format (per-expert): + experts.0.gate_proj.weight: [intermediate_size, hidden_size] + experts.0.up_proj.weight: [intermediate_size, hidden_size] + experts.0.down_proj.weight: [hidden_size, intermediate_size] + """ + custom_state_dict = {} + + for key, value in hf_state_dict.items(): + if key == "experts.gate_up_proj": + # Split stacked gate_up into individual gate and up per expert + # Shape: [n_experts, 2 * intermediate_size, hidden_size] + intermediate_size = value.shape[1] // 2 + for i in range(n_experts): + gate_up = value[i] # [2 * intermediate_size, hidden_size] + gate_weight = gate_up[:intermediate_size] # [intermediate_size, hidden_size] + up_weight = gate_up[intermediate_size:] # [intermediate_size, hidden_size] + custom_state_dict[f"experts.{i}.gate_proj.weight"] = gate_weight + custom_state_dict[f"experts.{i}.up_proj.weight"] = up_weight + elif key == "experts.down_proj": + # Split stacked down into individual down per expert + # Shape: [n_experts, hidden_size, intermediate_size] + for i in range(n_experts): + custom_state_dict[f"experts.{i}.down_proj.weight"] = value[i] + else: + # Copy other keys as-is + custom_state_dict[key] = value + + return custom_state_dict + + +def _convert_hf_full_model_state_dict_to_custom(hf_state_dict: dict, config) -> dict: + """Convert full HF model state dict to custom format. + + Handles MoE expert weight conversion for all MoE layers. + """ + custom_state_dict = {} + n_experts = config.n_routed_experts + + for key, value in hf_state_dict.items(): + # Check if this is an MoE expert weight + if ".mlp.experts.gate_up_proj" in key: + # Extract layer prefix (e.g., "model.layers.1.mlp.") + prefix = key.replace("experts.gate_up_proj", "") + intermediate_size = value.shape[1] // 2 + for i in range(n_experts): + gate_up = value[i] + gate_weight = gate_up[:intermediate_size] + up_weight = gate_up[intermediate_size:] + custom_state_dict[f"{prefix}experts.{i}.gate_proj.weight"] = gate_weight + custom_state_dict[f"{prefix}experts.{i}.up_proj.weight"] = up_weight + elif ".mlp.experts.down_proj" in key: + prefix = key.replace("experts.down_proj", "") + for i in range(n_experts): + custom_state_dict[f"{prefix}experts.{i}.down_proj.weight"] = value[i] + else: + # Copy other keys as-is + custom_state_dict[key] = value + + return custom_state_dict + + +def _create_hf_config(): + """Create HF config that matches our test config.""" + HFConfig = _get_hf_config_class() + if HFConfig is None: + return None + + config = HFConfig( + vocab_size=1000, + hidden_size=64, + intermediate_size=128, + num_hidden_layers=3, + num_attention_heads=4, + num_key_value_heads=4, + hidden_act="silu", + max_position_embeddings=512, + rms_norm_eps=1e-5, + # MLA params + q_lora_rank=32, + kv_lora_rank=32, + qk_nope_head_dim=8, + qk_rope_head_dim=8, + v_head_dim=16, + # MoE params + n_routed_experts=4, + n_shared_experts=1, + num_experts_per_tok=2, + moe_intermediate_size=32, + n_group=1, + topk_group=1, + routed_scaling_factor=1.0, + norm_topk_prob=True, + first_k_dense_replace=1, + # RoPE + rope_theta=10000.0, + rope_scaling=None, + # Other + attention_bias=False, + attention_dropout=0.0, + pad_token_id=0, + ) + + # Set internal attributes needed by HF's MoE implementation + # _experts_implementation tells HF which expert forward function to use + config._experts_implementation = "eager" + + return config + + 
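For reference, the stacked-to-per-expert weight split performed by `_convert_hf_moe_state_dict_to_custom` above can be illustrated on dummy tensors. This is a minimal standalone sketch (not part of the test suite in this diff), assuming the small test dimensions used here (4 experts, `moe_intermediate_size=32`, `hidden_size=64`):

```python
import torch

# Toy dimensions matching the small test config above.
n_experts, intermediate_size, hidden_size = 4, 32, 64

# HF-style stacked expert weights:
#   gate_up_proj: [n_experts, 2 * intermediate_size, hidden_size]
#   down_proj:    [n_experts, hidden_size, intermediate_size]
gate_up_proj = torch.randn(n_experts, 2 * intermediate_size, hidden_size)
down_proj = torch.randn(n_experts, hidden_size, intermediate_size)

# Per-expert split matching the custom model's nn.ModuleList of experts.
converted = {}
for i in range(n_experts):
    converted[f"experts.{i}.gate_proj.weight"] = gate_up_proj[i, :intermediate_size]  # [32, 64]
    converted[f"experts.{i}.up_proj.weight"] = gate_up_proj[i, intermediate_size:]    # [32, 64]
    converted[f"experts.{i}.down_proj.weight"] = down_proj[i]                         # [64, 32]

assert converted["experts.0.gate_proj.weight"].shape == (intermediate_size, hidden_size)
assert converted["experts.0.down_proj.weight"].shape == (hidden_size, intermediate_size)
```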
+@pytest.mark.parametrize("B,S", _BATCH_AND_SEQUENCE_TEST_CASES) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@torch.no_grad() +def test_glm4_moe_lite_moe_numerical_equivalence(B, S, dtype): + """Test MoE layer produces numerically equivalent output to HF implementation.""" + HFMoE = _get_hf_moe_class() + if HFMoE is None: + pytest.skip("transformers doesn't have glm4_moe_lite (requires v5.0+)") + + device = "cuda" + config = _create_small_config() + hf_config = _create_hf_config() + + # Create HF MoE + hf_moe = HFMoE(hf_config) + hf_moe.to(device=device, dtype=dtype) + hf_moe.eval() + + # Initialize weights with randn for reproducibility + # (HF uses torch.empty which may be zeros, causing NaN in computation) + hf_moe.gate.weight = torch.nn.Parameter(torch.randn_like(hf_moe.gate.weight)) + hf_moe.experts.gate_up_proj = torch.nn.Parameter(torch.randn_like(hf_moe.experts.gate_up_proj)) + hf_moe.experts.down_proj = torch.nn.Parameter(torch.randn_like(hf_moe.experts.down_proj)) + + # Create custom MoE and load converted weights + custom_moe = Glm4MoeLiteMoE(config) + custom_moe.to(device=device, dtype=dtype) + + # Convert HF stacked expert weights to our per-expert format + hf_state_dict = hf_moe.state_dict() + + # Debug: print state dict keys and shapes + print("\n=== HF MoE state_dict keys and shapes ===") + for k, v in hf_state_dict.items(): + print(f" {k}: {v.shape}") + + custom_state_dict = _convert_hf_moe_state_dict_to_custom(hf_state_dict, config.n_routed_experts) + + print("\n=== Converted custom state_dict keys and shapes ===") + for k, v in custom_state_dict.items(): + print(f" {k}: {v.shape}") + + print("\n=== Expected custom MoE state_dict keys ===") + for k, v in custom_moe.state_dict().items(): + print(f" {k}: {v.shape}") + + custom_moe.load_state_dict(custom_state_dict) + custom_moe.eval() + + # Sanity check: verify expert weights match after conversion + # HF has stacked weights: experts.gate_up_proj [n_experts, 2*intermediate, hidden] + # Our model has per-expert: experts.{i}.gate_proj.weight, experts.{i}.up_proj.weight + hf_gate_up = hf_moe.experts.gate_up_proj # [n_experts, 2*intermediate, hidden] + hf_down = hf_moe.experts.down_proj # [n_experts, hidden, intermediate] + intermediate_size = config.moe_intermediate_size + + print(f"\n=== Debug: intermediate_size = {intermediate_size} ===") + print(f"hf_gate_up shape: {hf_gate_up.shape}") + print(f"hf_gate_up[0, :2, :2]: {hf_gate_up[0, :2, :2]}") + + # Get the converted state dict values for comparison + converted_gate_0 = custom_state_dict["experts.0.gate_proj.weight"] + print(f"converted_gate_0 shape: {converted_gate_0.shape}") + print(f"converted_gate_0[:2, :2]: {converted_gate_0[:2, :2]}") + + # After load_state_dict + loaded_gate_0 = custom_moe.experts[0].gate_proj.weight + print(f"loaded_gate_0 shape: {loaded_gate_0.shape}") + print(f"loaded_gate_0[:2, :2]: {loaded_gate_0[:2, :2]}") + + for i in range(config.n_routed_experts): + # Check gate_proj + hf_gate = hf_gate_up[i, :intermediate_size, :] # [intermediate, hidden] + custom_gate = custom_moe.experts[i].gate_proj.weight + torch.testing.assert_close( + custom_gate, hf_gate, msg=f"Expert {i} gate_proj weights don't match" + ) + + # Check up_proj + hf_up = hf_gate_up[i, intermediate_size:, :] # [intermediate, hidden] + custom_up = custom_moe.experts[i].up_proj.weight + torch.testing.assert_close(custom_up, hf_up, msg=f"Expert {i} up_proj weights don't match") + + # Check down_proj + hf_down_i = hf_down[i] # [hidden, intermediate] + custom_down = 
custom_moe.experts[i].down_proj.weight + torch.testing.assert_close( + custom_down, hf_down_i, msg=f"Expert {i} down_proj weights don't match" + ) + + # Also verify gate weights match + torch.testing.assert_close( + custom_moe.gate.weight, hf_moe.gate.weight, msg="Gate weights don't match" + ) + + # Create input + H = config.hidden_size + x = torch.randn(B, S, H, device=device, dtype=dtype) + + # Run both + hf_out = hf_moe(x) + custom_out = custom_moe(x) + + # Handle tuple output from HF (output, router_logits) + if isinstance(hf_out, tuple): + hf_out = hf_out[0] + + # Compare + rtol, atol = 0.05, 0.05 + torch.testing.assert_close(custom_out, hf_out, rtol=rtol, atol=atol) + + +@pytest.mark.parametrize("B,S", _BATCH_AND_SEQUENCE_TEST_CASES) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@torch.no_grad() +def test_glm4_moe_lite_mlp_numerical_equivalence(B, S, dtype): + """Test MLP layer produces numerically equivalent output to HF implementation.""" + HFMLP = _get_hf_mlp_class() + if HFMLP is None: + pytest.skip("transformers doesn't have glm4_moe_lite (requires v5.0+)") + + device = "cuda" + config = _create_small_config() + hf_config = _create_hf_config() + + # Create HF MLP + hf_mlp = HFMLP(hf_config) + hf_mlp.to(device=device, dtype=dtype) + hf_mlp.eval() + + # Create custom MLP and load same weights + custom_mlp = Glm4MoeLiteMLP(config) + custom_mlp.to(device=device, dtype=dtype) + custom_mlp.load_state_dict(hf_mlp.state_dict()) + custom_mlp.eval() + + # Create input + H = config.hidden_size + x = torch.randn(B, S, H, device=device, dtype=dtype) + + # Run both + hf_out = hf_mlp(x) + custom_out = custom_mlp(x) + + # Compare + rtol, atol = 1e-3, 1e-3 + torch.testing.assert_close(custom_out, hf_out, rtol=rtol, atol=atol) + + +@pytest.mark.parametrize("B,S", _BATCH_AND_SEQUENCE_TEST_CASES) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@torch.no_grad() +def test_glm4_moe_lite_full_model_numerical_equivalence(B, S, dtype): + """Test full model produces numerically equivalent output to HF implementation.""" + HFModel = _get_hf_model_class() + if HFModel is None: + pytest.skip("transformers doesn't have glm4_moe_lite (requires v5.0+)") + + device = "cuda" + config = _create_small_config() + hf_config = _create_hf_config() + + # Create HF model + hf_model = HFModel(hf_config) + hf_model.to(device=device, dtype=dtype) + hf_model.eval() + + # Initialize all gate weights for reproducibility + for module in hf_model.modules(): + if hasattr(module, "gate") and hasattr(module.gate, "weight"): + module.gate.weight = torch.nn.Parameter(torch.randn_like(module.gate.weight)) + + # Create custom model and load converted weights + custom_model = Glm4MoeLiteForCausalLM(config) + custom_model.to(device=device, dtype=dtype) + + # Convert HF stacked expert weights to our per-expert format + hf_state_dict = hf_model.state_dict() + custom_state_dict = _convert_hf_full_model_state_dict_to_custom(hf_state_dict, config) + custom_model.load_state_dict(custom_state_dict) + custom_model.eval() + + # Create input + input_ids = torch.randint(0, config.vocab_size, (B, S), device=device) + position_ids = torch.arange(S, device=device).unsqueeze(0).expand(B, -1) + + # Run both + hf_out = hf_model(input_ids=input_ids, position_ids=position_ids) + custom_out = custom_model(input_ids=input_ids, position_ids=position_ids) + + # Compare logits - cast to same dtype for comparison + # (HF model may output float32 for numerical stability in lm_head) + rtol, atol = 0.05, 0.05 + torch.testing.assert_close( + 
custom_out.logits.float(), + hf_out.logits.float(), + rtol=rtol, + atol=atol, + ) diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_sdpa_mla.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_sdpa_mla.py deleted file mode 100644 index ffa2594d90..0000000000 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_sdpa_mla.py +++ /dev/null @@ -1,97 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F - -import tensorrt_llm._torch.auto_deploy # noqa: F401 - - -def naive_attn(q_nope, q_pe, compressed_kv, k_pe, wkv_b, softmax_scale): - bsz = q_nope.shape[0] - q_len = q_nope.shape[2] - num_heads = q_nope.shape[1] - qk_nope_head_dim = q_nope.shape[-1] - v_head_dim = wkv_b.weight.shape[-1] - qk_nope_head_dim - qk_head_dim = qk_nope_head_dim + q_pe.shape[-1] - k_pe = k_pe.view(bsz, q_len, 1, q_pe.shape[-1]).transpose(1, 2) - - # Up project compressed_kv - kv = ( - wkv_b(compressed_kv) - .view(bsz, q_len, num_heads, qk_nope_head_dim + v_head_dim) - .transpose(1, 2) - ) - - k_nope, value_states = torch.split(kv, [qk_nope_head_dim, v_head_dim], dim=-1) - - query_states = k_pe.new_empty(bsz, num_heads, q_len, qk_head_dim) - query_states[:, :, :, :qk_nope_head_dim] = q_nope - query_states[:, :, :, qk_nope_head_dim:] = q_pe - - key_states = k_pe.new_empty(bsz, num_heads, q_len, qk_head_dim) - key_states[:, :, :, :qk_nope_head_dim] = k_nope - key_states[:, :, :, qk_nope_head_dim:] = k_pe - - x = F.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=None, - is_causal=False, - dropout_p=0.0, - scale=softmax_scale, - ).transpose(1, 2) - - return x - - -def mla_attn(q_nope, q_pe, compressed_kv, k_pe, wkv_b, softmax_scale): - num_heads = q_nope.shape[1] - qk_nope_head_dim = q_nope.shape[-1] - v_head_dim = wkv_b.weight.shape[-1] - qk_nope_head_dim - kv_lora_rank = compressed_kv.shape[-1] - - # Down project q_nope - wkv_b_weight = wkv_b.weight.view(num_heads, -1, kv_lora_rank) - q_nope_proj = torch.einsum("bhsd,hdc->bhsc", q_nope, wkv_b_weight[:, :qk_nope_head_dim]) - - # MLA ref operation - x = torch.ops.auto_deploy.torch_attention_deepseek_mla( - q_nope_proj, q_pe, compressed_kv, k_pe, None, softmax_scale - ) - - # Up project attention scores - x = torch.einsum("bshc,hdc->bshd", x, wkv_b_weight[:, -v_head_dim:]) - return x - - -def test_attn(): - # Define test configurations - kv_lora_rank = 4 - bsz = 2 - q_len = 6 - v_head_dim = 2 - qk_nope_head_dim = 2 - qk_rope_head_dim = 1 - qk_head_dim = qk_nope_head_dim + qk_rope_head_dim - num_heads = 4 - softmax_scale = qk_head_dim**-0.5 - - # Generate inputs - q_nope = torch.randn(bsz, num_heads, q_len, qk_nope_head_dim) - q_pe = torch.randn(bsz, num_heads, q_len, qk_rope_head_dim) - compressed_kv = torch.randn(bsz, q_len, kv_lora_rank) - k_pe = torch.randn(bsz, q_len, qk_rope_head_dim) - - # Define w_kv_b projection matrix - wkv_b = nn.Linear( - kv_lora_rank, num_heads * (qk_head_dim - qk_rope_head_dim + v_head_dim), bias=False - ) - - # Compute naive attention - out_naive = naive_attn(q_nope, q_pe, compressed_kv, k_pe, wkv_b, softmax_scale) - - # Compute MLA attention - out_mla = mla_attn(q_nope, q_pe, compressed_kv, k_pe, wkv_b, softmax_scale) - - # Check if the two outputs are close - assert torch.allclose(out_naive, out_mla, rtol=1e-5, atol=1e-5)
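As a side note, the >50x compression ratio asserted in `test_memory_efficiency` earlier in this diff follows directly from the cache dimensions used there. A minimal sketch of that arithmetic, using the DeepSeek-V3-style sizes from the test (this is an illustration only, not code from the diff):

```python
# DeepSeek-V3-style dimensions from test_memory_efficiency.
kv_lora_rank, qk_rope_head_dim = 512, 64
num_heads, qk_nope_head_dim, v_head_dim = 128, 128, 128

# The FlashInfer-style MLA cache stores only the compressed latent plus the
# shared rope key per token: 512 + 64 = 576 values.
compressed_per_token = kv_lora_rank + qk_rope_head_dim

# An expanded per-head KV cache would store k_nope, v, and k_pe for every head:
# 128 * (128 + 128 + 64) = 40960 values per token.
expanded_per_token = num_heads * (qk_nope_head_dim + v_head_dim + qk_rope_head_dim)

print(expanded_per_token / compressed_per_token)  # ~71.1x, hence the > 50x assertion
```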