[#11032][feat] MLA revisited and GLM 4.7 Flash support (#11324)

Lucas Liebenwein 2026-02-09 23:26:51 -05:00 committed by GitHub
parent d50f010fa9
commit a2fb5afecf
26 changed files with 7709 additions and 1034 deletions

View File

@@ -0,0 +1,348 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Deploying GLM-4.7-Flash with TensorRT-LLM\n",
"\n",
"This notebook walks you through deploying the `zai-org/GLM-4.7-Flash` model using TensorRT-LLM.\n",
"\n",
"[TensorRT-LLM](https://nvidia.github.io/TensorRT-LLM/) is NVIDIA's open-source library for accelerating and optimizing LLM inference on NVIDIA GPUs. Support for GLM-4.7-Flash is enabled through the AutoDeploy workflow. More details about AutoDeploy can be found [here](https://nvidia.github.io/TensorRT-LLM/torch/auto_deploy/auto-deploy.html).\n",
"\n",
"**Model Resources:**\n",
"- [HuggingFace Model Card](https://huggingface.co/zai-org/GLM-4.7-Flash)\n",
"- [Technical Blog](https://z.ai/blog/glm-4.7)\n",
"- [Technical Report (GLM-4.5)](https://arxiv.org/abs/2508.06471)\n",
"- [Z.ai API Platform](https://docs.z.ai/guides/llm/glm-4.7)\n",
"\n",
"**Model Highlights:**\n",
"- 30B-A3B Mixture of Experts (MoE) architecture\n",
"- 131,072 token context length\n",
"- Tool calling support\n",
"- MIT License\n",
"\n",
"**Prerequisites:**\n",
"- NVIDIA GPU with recent drivers (≥ 64 GB VRAM for BF16) and CUDA 12.x\n",
"- Python 3.10+\n",
"- TensorRT-LLM ([container](https://catalog.ngc.nvidia.com/orgs/nvidia/teams/tensorrt-llm/containers/release) or pip install)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Prerequisites & Environment\n",
"\n",
"Set up a containerized environment for TensorRT-LLM by running the following command in a terminal:\n",
"\n",
"```shell\n",
"docker run --rm -it --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 --gpus=all -p 8000:8000 nvcr.io/nvidia/tensorrt-llm/release:1.3.0rc1\n",
"```\n",
"\n",
"You now have TensorRT-LLM set up!"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# If pip not found\n",
"!python -m ensurepip --default-pip"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%pip install torch openai"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Verify GPU\n",
"\n",
"Check that CUDA is available and the GPU is detected correctly."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Python: 3.12.3 (main, Nov 6 2025, 13:44:16) [GCC 13.3.0]\n",
"CUDA available: True\n",
"Num GPUs: 8\n",
"GPU[0]: NVIDIA H100 80GB HBM3\n",
"GPU[1]: NVIDIA H100 80GB HBM3\n",
"GPU[2]: NVIDIA H100 80GB HBM3\n",
"GPU[3]: NVIDIA H100 80GB HBM3\n",
"GPU[4]: NVIDIA H100 80GB HBM3\n",
"GPU[5]: NVIDIA H100 80GB HBM3\n",
"GPU[6]: NVIDIA H100 80GB HBM3\n",
"GPU[7]: NVIDIA H100 80GB HBM3\n"
]
}
],
"source": [
"# Environment check\n",
"import sys\n",
"\n",
"import torch\n",
"\n",
"print(f\"Python: {sys.version}\")\n",
"print(f\"CUDA available: {torch.cuda.is_available()}\")\n",
"print(f\"Num GPUs: {torch.cuda.device_count()}\")\n",
"\n",
"if torch.cuda.is_available():\n",
" for i in range(torch.cuda.device_count()):\n",
" print(f\"GPU[{i}]: {torch.cuda.get_device_name(i)}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## OpenAI-Compatible Server\n",
"\n",
"Start a local OpenAI-compatible server with TensorRT-LLM from a terminal inside the running Docker container; the commands below must be run there.\n",
"\n",
"The server is configured with the GLM-4.7-Flash YAML at `examples/auto_deploy/model_registry/configs/glm-4.7-flash.yaml`."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Load the Model\n",
"\n",
"Launch the TensorRT-LLM server with GLM-4.7-Flash:\n",
"\n",
"```shell\n",
"trtllm-serve \"zai-org/GLM-4.7-Flash\" \\\n",
" --host 0.0.0.0 \\\n",
" --port 8000 \\\n",
" --backend _autodeploy \\\n",
" --trust_remote_code \\\n",
" --extra_llm_api_options examples/auto_deploy/model_registry/configs/glm-4.7-flash.yaml\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Your server is now running! Optionally verify that it is reachable, as shown below, before sending requests."
]
},
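{
"cell_type": "markdown",
"metadata": {},
"source": [
"The cell below is a minimal reachability check. It assumes the server exposes the standard OpenAI-compatible model listing route (`/v1/models`) on the mapped port 8000; adjust the host or port if you changed them."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Quick reachability check (assumes the standard /v1/models route is exposed)\n",
"import urllib.request\n",
"\n",
"try:\n",
"    with urllib.request.urlopen(\"http://localhost:8000/v1/models\", timeout=10) as resp:\n",
"        print(resp.status, resp.read().decode()[:200])\n",
"except Exception as exc:  # broad catch: this is only a connectivity probe\n",
"    print(f\"Server not reachable yet: {exc}\")"
]
},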
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Use the API\n",
"\n",
"Use the OpenAI-compatible client to send requests to the TensorRT-LLM server."
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"from openai import OpenAI\n",
"\n",
"# Setup client\n",
"BASE_URL = \"http://0.0.0.0:8000/v1\"\n",
"API_KEY = \"null\"\n",
"client = OpenAI(base_url=BASE_URL, api_key=API_KEY)\n",
"\n",
"MODEL_ID = \"zai-org/GLM-4.7-Flash\""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Chat Completion Example\n",
"==================================================\n",
"Response:\n",
"1. **Analyze the Request:** The user wants to know 15% of 85 and wants to see the reasoning behind the calculation.\n",
"\n",
"2. **Identify the Core Task:** Calculate $15\\% \\times 85$.\n",
"\n",
"3. **Determine the Mathematical Approach:** There are several ways to solve this:\n",
" * *Method 1: Fraction multiplication.* Convert 15% to a fraction ($\\frac{15}{100}$), then multiply by 85.\n",
" * *Method 2: Decimal multiplication.* Convert 15% to a decimal ($0.15$), then multiply by 85.\n",
" * *Method 3: Decomposition (Breaking it down).* $15\\% = 10\\% + 5\\%$.\n",
" * $10\\%$ of $85 = 8.5$\n",
" * $5\\%$ of $85 = \\frac{8.5}{2} = 4.25$\n",
" * Sum: $8.5 + 4.25 = 12.75$\n",
"\n",
"4. **Select the Best Approach for Explanation:** Method 3 is often easiest for a general audience to follow step-by-step because it avoids dealing with decimals until the end or simplifies large multiplications. Method 2 is the most direct standard school method. I will use Method 3 (Splitting 15% into 10% and 5%) as the primary reasoning because it is intuitive, but I might briefly mention the standard formula ($\\frac{\\text{percent}}{100} \\times \\text{number}$).\n",
"\n",
"5. **Execute the Calculation (Method 3):**\n",
" * Step 1: Find 10% of 85.\n",
" * Moving the decimal point one place to the left: 8.5.\n",
" * Step 2: Find 5% of 85.\n",
" * Since 5% is half of 10%, take half of 8.5.\n",
" * $8.5 / 2 = 4.25$.\n",
" * Step 3: Add them together.\n",
" * $8.5 + 4.25$.\n",
" * $8.5 + 4 = 12.5$.\n",
" * $12.5 + 0\n"
]
}
],
"source": [
"# Basic chat completion\n",
"print(\"Chat Completion Example\")\n",
"print(\"=\" * 50)\n",
"\n",
"response = client.chat.completions.create(\n",
" model=MODEL_ID,\n",
" messages=[\n",
" {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n",
" {\"role\": \"user\", \"content\": \"What is 15% of 85? Show your reasoning.\"},\n",
" ],\n",
" temperature=1,\n",
" top_p=0.95,\n",
" max_tokens=512,\n",
")\n",
"\n",
"print(\"Response:\")\n",
"print(response.choices[0].message.content)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Streaming response:\n",
"==================================================\n",
"1. **Analyze the Request:** The user is asking for the \"first 5 prime numbers\".\n",
"\n",
"2. **Define \"Prime Number\":** A prime number is a natural number greater than 1 that is not a product of two smaller natural numbers. In other words, it has exactly two distinct positive divisors: 1 and itself.\n",
"\n",
"3. **Identify the First Numbers:**\n",
" * Start checking from 1 (exclusive).\n",
" * Check 2: Divisors are 1 and 2. Prime. (1st)\n",
" * Check 3: Divisors are 1 and 3. Prime. (2nd)\n",
" * Check 4: Divisors are 1, 2, 4. Not prime (2 * 2).\n",
" * Check 5: Divisors are 1 and 5. Prime. (3rd)\n",
" * Check 6: Divisors are 1, 2, 3, 6. Not prime.\n",
" * Check 7: Divisors are 1 and 7. Prime. (4th)\n",
" * Check 8: Divisors are 1, 2, 4, 8. Not prime.\n",
" * Check 9: Divisors are 1, 3, 9. Not prime.\n",
" * Check 10: Divisors are 1, 2, 5, 10. Not prime.\n",
" * Check 11: Divisors are 1 and 11. Prime. (5th)\n",
"\n",
"4. **Compile the List:** 2, 3, 5, 7, 11.\n",
"\n",
"5. **Formulate the Output:** Present the list clearly.\n",
"\n",
"6. **Final Review:** Does this answer the user's prompt accurately? Yes.</think>The first 5 prime numbers are **2, 3, 5, 7, and 11**."
]
}
],
"source": [
"# Streaming chat completion\n",
"print(\"Streaming response:\")\n",
"print(\"=\" * 50)\n",
"\n",
"stream = client.chat.completions.create(\n",
" model=MODEL_ID,\n",
" messages=[\n",
" {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n",
" {\"role\": \"user\", \"content\": \"What are the first 5 prime numbers?\"},\n",
" ],\n",
" temperature=0.7,\n",
" max_tokens=1024,\n",
" stream=True,\n",
")\n",
"\n",
"for chunk in stream:\n",
" if chunk.choices[0].delta.content:\n",
" print(chunk.choices[0].delta.content, end=\"\", flush=True)"
]
},
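{
"cell_type": "markdown",
"metadata": {},
"source": [
"GLM-4.7-Flash lists tool calling among its highlights. The cell below is a minimal sketch using the standard OpenAI `tools` parameter; `get_weather` is a hypothetical placeholder function, and depending on your trtllm-serve configuration, additional server-side options may be needed for tool-call parsing."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Tool-calling sketch: `get_weather` is a hypothetical placeholder tool\n",
"tools = [\n",
"    {\n",
"        \"type\": \"function\",\n",
"        \"function\": {\n",
"            \"name\": \"get_weather\",\n",
"            \"description\": \"Get the current weather for a city.\",\n",
"            \"parameters\": {\n",
"                \"type\": \"object\",\n",
"                \"properties\": {\"city\": {\"type\": \"string\"}},\n",
"                \"required\": [\"city\"],\n",
"            },\n",
"        },\n",
"    }\n",
"]\n",
"\n",
"response = client.chat.completions.create(\n",
"    model=MODEL_ID,\n",
"    messages=[{\"role\": \"user\", \"content\": \"What is the weather in Paris right now?\"}],\n",
"    tools=tools,\n",
"    temperature=0.7,\n",
"    max_tokens=1024,\n",
")\n",
"\n",
"# The model may reply with plain text or with structured tool calls\n",
"print(response.choices[0].message.tool_calls or response.choices[0].message.content)"
]
},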
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Evaluation Parameters\n",
"\n",
"For optimal results, use the following sampling parameters based on your task; a minimal example follows below:\n",
"\n",
"**Default Settings (Most Tasks)**\n",
"- `temperature`: 1.0\n",
"- `top_p`: 0.95\n",
"- `max_tokens`: 131072\n",
"\n",
"**Agentic Tasks (SWE-bench, Terminal Bench)**\n",
"- `temperature`: 0.7\n",
"- `top_p`: 1.0\n",
"- `max_tokens`: 16384\n",
"\n",
"**Deterministic Tasks**\n",
"- `temperature`: 0\n",
"- `max_tokens`: 16384"
]
},
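{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a minimal sketch, the cell below applies the deterministic settings from the table above, reusing the `client` and `MODEL_ID` defined earlier; the prompt is only a placeholder."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Deterministic (greedy) decoding example, following the table above\n",
"response = client.chat.completions.create(\n",
"    model=MODEL_ID,\n",
"    messages=[\n",
"        {\"role\": \"user\", \"content\": \"List three prime numbers greater than 100.\"},\n",
"    ],\n",
"    temperature=0,\n",
"    max_tokens=16384,\n",
")\n",
"\n",
"print(response.choices[0].message.content)"
]
},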
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Additional Resources\n",
"\n",
"- [TensorRT-LLM Documentation](https://nvidia.github.io/TensorRT-LLM/)\n",
"- [AutoDeploy Guide](https://nvidia.github.io/TensorRT-LLM/torch/auto_deploy/auto-deploy.html)\n",
"- [GLM-4.7-Flash on HuggingFace](https://huggingface.co/zai-org/GLM-4.7-Flash)\n",
"- [Z.ai Discord Community](https://discord.gg/QR7SARHRxK)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.3"
}
},
"nbformat": 4,
"nbformat_minor": 4
}

View File

@@ -0,0 +1,8 @@
compile_backend: torch-cudagraph
max_batch_size: 64
max_seq_len: 4096
enable_chunked_prefill: true
cuda_graph_batch_sizes: [1, 2, 4, 8, 16, 32, 64]
transforms:
fuse_nvfp4_moe:
allow_different_input_scales: true

View File

@@ -223,3 +223,5 @@ models:
yaml_extra: ['dashboard_default.yaml', 'world_size_8.yaml', 'multimodal.yaml', 'llama4_maverick_lite.yaml']
- name: nvidia/NVIDIA-Nemotron-3-Super-120B-BF16-BF16KV-010726
yaml_extra: ['dashboard_default.yaml', 'world_size_4.yaml','super_v3.yaml']
- name: zai-org/GLM-4.7-Flash
yaml_extra: ['glm-4.7-flash.yaml']

View File

@@ -162,15 +162,16 @@ transforms:
visualize_namespace:
stage: visualize
enabled: false
############################################################################################
###########################################################################################
# SWITCH TO CACHED+FLATTENED ATTENTION + INITIALIZE CACHES
############################################################################################
###########################################################################################
insert_cached_attention:
stage: cache_init
backend: flashinfer
insert_cached_mla_attention:
stage: cache_init
backend: MultiHeadLatentAttention
requires_shape_prop: true
backend: flashinfer_mla
insert_cached_ssm_attention:
stage: cache_init
backend: triton_ssm

View File

@@ -5,38 +5,3 @@ All AutoDeploy custom operators follow the following naming convention:
`torch.ops.auto_deploy.<kernel_backend>_<op_category>_<op_name>`
The table below lists the operators ordered by their backend.
### Available Custom Operators
| Operator Name | Description |
|--------------|-------------|
| `torch.ops.auto_deploy.flashinfer_attention_mha_with_cache` | FlashInfer attention with KV cache support |
| `torch.ops.auto_deploy.flashinfer_rope` | FlashInfer RoPE (Rotary Position Embedding) implementation |
| `torch.ops.auto_deploy.torch_attention_deepseek_fused_mla` | DeepSeek fused MLA (Multi-head Latent Attention) |
| `torch.ops.auto_deploy.torch_attention_deepseek_mla` | DeepSeek MLA implementation |
| `torch.ops.auto_deploy.torch_attention` | Grouped SDPA implementation with `bsnd` and `bnsd` layouts supported |
| `torch.ops.auto_deploy.torch_attention_repeat_kv` | KV repetition for attention |
| `torch.ops.auto_deploy.torch_attention_sdpa` | Standard SDPA implementation |
| `torch.ops.auto_deploy.torch_dist_all_gather` | Distributed all-gather operation (PyTorch backend, demollm mode) |
| `torch.ops.auto_deploy.torch_dist_all_reduce` | Distributed all-reduce operation (PyTorch backend, demollm mode) |
| `torch.ops.auto_deploy.torch_linear_simple` | Simple linear layer implementation |
| `torch.ops.auto_deploy.torch_moe` | Mixture of Experts implementation |
| `torch.ops.auto_deploy.torch_moe_fused` | Fused Mixture of Experts implementation |
| `torch.ops.auto_deploy.torch_quant_fn` | Generic quantization function that scales, rounds, and clamps input values |
| `torch.ops.auto_deploy.torch_quant_nvfp4_linear` | FP4 quantized linear layer |
| `torch.ops.auto_deploy.torch_quant_fp8_linear` | FP8 quantized linear layer |
| `torch.ops.auto_deploy.torch_rope_with_complex_freqs` | RoPE with complex frequencies |
| `torch.ops.auto_deploy.torch_rope_with_explicit_cos_sin` | RoPE with explicit cosine/sine |
| `torch.ops.auto_deploy.torch_rope_with_qk_interleaving` | RoPE with QK interleaving |
| `torch.ops.auto_deploy.triton_attention_fused_flattened_mha_with_cache` | Triton fused flattened MHA with cache |
| `torch.ops.auto_deploy.triton_attention_fused_flattened_mha_with_cache_rope_fusion` | Triton fused flattened MHA with cache and RoPE fusion |
| `torch.ops.auto_deploy.triton_attention_fused_mha_with_cache` | Triton fused MHA with cache |
| `torch.ops.auto_deploy.triton_attention_fused_mha_with_paged_cache` | Triton fused MHA with paged cache |
| `torch.ops.auto_deploy.triton_attention_flattened_mha_with_cache` | Triton flattened MHA with cache |
| `torch.ops.auto_deploy.triton_attention_fused_flattened_mla_with_cache` | Triton fused flattened Multi-head Latent Attention with cache support |
| `torch.ops.auto_deploy.triton_rope_on_flattened_inputs` | Triton RoPE on flattened inputs |
| `torch.ops.auto_deploy.triton_rope_with_input_pos` | Triton RoPE with input positions |
| `torch.ops.auto_deploy.trtllm_moe_fused` | TensorRT LLM fused MoE implementation |
| `torch.ops.auto_deploy.trtllm_dist_all_gather` | Distributed all-gather operation (TRT-LLM backend, MPI mode) |
| `torch.ops.auto_deploy.trtllm_dist_all_reduce` | Distributed all-reduce operation (TRT-LLM backend, MPI mode) |
| `torch.ops.dist.trtllm_fused_allreduce_residual_rmsnorm` | Fused all-reduce + residual add + RMSNorm (TRT-LLM backend, MPI mode) |

View File

@@ -19,7 +19,6 @@ import math
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
@@ -242,299 +241,3 @@ def torch_attention_fake(
layout: str = "bnsd",
):
return query.new_empty(*query.shape[:-1], value.shape[-1]).contiguous()
def update_kv_cache(
key_states: torch.Tensor,
value_states: torch.Tensor,
k_cache: torch.Tensor,
v_cache: torch.Tensor,
seq_len: torch.Tensor, # metadata
input_pos: torch.Tensor, # metadata
slot_idx: torch.Tensor,
seq_start: torch.Tensor,
) -> torch.Tensor:
"""
Reference implementation of the KV-cache update function. Assumes a KV cache layout of [B,S,N,D].
This function can be used to build reference attention implementations that use a KV cache.
"""
for idx in range(seq_len.shape[0]):
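# Write this sequence's new K/V tokens into its cache slot, starting at its current write position (input_pos)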
k_cache[slot_idx[idx], input_pos[idx] : input_pos[idx] + seq_len[idx], :, :] = key_states[
seq_start[idx] : seq_start[idx] + seq_len[idx], ...
]
v_cache[slot_idx[idx], input_pos[idx] : input_pos[idx] + seq_len[idx], :, :] = value_states[
seq_start[idx] : seq_start[idx] + seq_len[idx], ...
]
@torch.library.custom_op("auto_deploy::torch_attention_fused_mla_ref", mutates_args=())
def fused_mla_ref(
q_nope: torch.Tensor,
q_pe: torch.Tensor,
kv: torch.Tensor,
k_pe: torch.Tensor,
seq_len: torch.Tensor, # metadata
input_pos: torch.Tensor, # metadata
cache_loc: torch.Tensor,
seq_start: torch.Tensor,
k_cache: torch.Tensor, # caches
v_cache: torch.Tensor, # caches
freqs_cis: torch.Tensor,
softmax_scale: Optional[float] = None,
) -> torch.Tensor:
"""
Reference implementation for Fused MLA with KV cache support.
This implementation flattens the inputs and can be used as a reference to debug the triton kernels.
"""
# Compute parameters
bs, num_heads, q_len, qk_nope_head_dim = q_nope.shape
qk_rope_head_dim = q_pe.shape[-1]
v_head_dim = kv.shape[-1] - qk_nope_head_dim
q_head_dim = qk_nope_head_dim + qk_rope_head_dim
# Flatten inputs
bs_view = (bs * q_len,)
k_nope, value_states = torch.split(kv, [qk_nope_head_dim, v_head_dim], dim=-1)
q_nope = q_nope.transpose(1, 2).view(*bs_view, num_heads, qk_nope_head_dim).contiguous()
q_pe = q_pe.transpose(1, 2).clone().view(*bs_view, num_heads, qk_rope_head_dim).contiguous()
k_nope = k_nope.clone().transpose(1, 2).view(*bs_view, num_heads, qk_nope_head_dim).contiguous()
k_pe = k_pe.clone().transpose(1, 2).view(*bs_view, -1, qk_rope_head_dim).contiguous()
value_states = value_states.transpose(1, 2).view(*bs_view, -1, v_head_dim).contiguous()
if freqs_cis is not None:
cos_base = freqs_cis[0, ...]
sin_base = freqs_cis[1, ...]
for i in range(seq_len.shape[0]):
start = seq_start[i]
length = seq_len[i]
if q_len == 1:
idx = (input_pos[i] + length - 1).item()
pos_ids = torch.tensor(idx, device=cos_base.device)
else:
pos_ids = torch.arange(input_pos[i], input_pos[i] + length, device=cos_base.device)
cos = cos_base[pos_ids] # [..., 1, head_dim]
sin = sin_base[pos_ids]
q_slice = q_pe[start : start + length]
k_slice = k_pe[start : start + length]
q_rot, k_rot = torch.ops.auto_deploy.torch_rope_with_qk_interleaving(
q_slice,
k_slice,
cos,
sin,
-2,
)
q_pe[start : start + length] = q_rot
k_pe[start : start + length] = k_rot
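# Assemble full Q/K heads by concatenating the non-positional (nope) and rotary (pe) components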
query_states = k_pe.new_empty(*bs_view, num_heads, q_head_dim) # [b*s,n,d]
query_states[..., :qk_nope_head_dim] = q_nope
query_states[..., qk_nope_head_dim:] = q_pe
key_states = k_pe.new_empty(*bs_view, num_heads, q_head_dim)
key_states[..., :qk_nope_head_dim] = k_nope
key_states[..., qk_nope_head_dim:] = k_pe
# Update KV cache
update_kv_cache(
key_states, value_states, k_cache, v_cache, seq_len, input_pos, cache_loc, seq_start
)
# Compute attention
attn_outputs = []
for idx in range(seq_len.shape[0]):
# Get inputs from KV cache
k = k_cache[cache_loc[idx], : input_pos[idx] + seq_len[idx], :, :] # [kv_seq_len, n, d]
v = v_cache[cache_loc[idx], : input_pos[idx] + seq_len[idx], :, :] # [kv_seq_len, n, d]
# Generate attention mask
if q_len == 1:
# Generate phase - single token attention mask
attn_mask = torch.zeros(
1, input_pos[idx] + 1, device=query_states.device, dtype=query_states.dtype
)
else:
# Context phase - causal attention mask
temp_mask = torch.ones(
seq_len[idx],
input_pos[idx] + seq_len[idx],
dtype=torch.bool,
device=query_states.device,
).tril(diagonal=0)
attn_bias = torch.zeros(
seq_len[idx],
input_pos[idx] + seq_len[idx],
device=query_states.device,
dtype=query_states.dtype,
)
attn_mask = attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")).to(
query_states.device
)
# Compute attention weights
attn_weights = (
torch.matmul(
query_states[seq_start[idx] : seq_start[idx] + seq_len[idx], :, :].transpose(0, 1),
k.transpose(0, 1).transpose(1, 2),
)
* 1
/ math.sqrt(query_states.size(-1))
)
attn_weights = attn_weights + attn_mask
# upcast attention to fp32
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(
query_states.dtype
)
attn_weights = torch.nn.functional.dropout(attn_weights, p=0.0, training=False)
attn_output = torch.matmul(attn_weights, v.transpose(0, 1))
attn_outputs.append(attn_output)
if q_len == 1:
attn_output = torch.stack(attn_outputs)
else:
attn_output = torch.cat(attn_outputs, dim=-2).unsqueeze(0)
return attn_output
@fused_mla_ref.register_fake
def fused_mla_ref_fake(
q_nope: torch.Tensor,
q_pe: torch.Tensor,
kv: torch.Tensor,
k_pe: torch.Tensor,
seq_len: torch.Tensor, # metadata
input_pos: torch.Tensor, # metadata
cache_loc: torch.Tensor,
seq_start: torch.Tensor,
k_cache: torch.Tensor, # caches
v_cache: torch.Tensor, # caches
freqs_cis: torch.Tensor,
softmax_scale: Optional[float] = None,
):
"""Fake Fused MLA+Rope with KV cache support."""
v_head_dim = kv.shape[-1] - q_nope.shape[-1]
return torch.empty_like(kv[..., -v_head_dim:])
@torch.library.custom_op("auto_deploy::torch_attention_deepseek_fused_mla", mutates_args=())
def fused_mla(
q_nope: torch.Tensor,
q_pe: torch.Tensor,
kv: torch.Tensor,
k_pe: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
position_ids: torch.Tensor,
attention_mask: torch.Tensor,
softmax_scale: float,
) -> torch.Tensor:
"""MultiHeadLatentAttention as implemented in DeepSeekV3Attention. This does not capture KV cache use/update."""
# Did not implement KV cache logic, since we would be writing our own custom op and inserting KV caches later
# Compute parameters
bs, num_heads, q_len, qk_nope_head_dim = q_nope.shape
qk_rope_head_dim = q_pe.shape[-1]
v_head_dim = kv.shape[-1] - qk_nope_head_dim
q_head_dim = qk_nope_head_dim + qk_rope_head_dim
k_nope, value_states = torch.split(kv, [qk_nope_head_dim, v_head_dim], dim=-1)
kv_seq_len = value_states.shape[-2]
cos = cos[position_ids]
sin = sin[position_ids]
q_pe, k_pe = torch.ops.auto_deploy.torch_rope_with_qk_interleaving(q_pe, k_pe, cos, sin)
query_states = k_pe.new_empty(bs, num_heads, q_len, q_head_dim)
query_states[:, :, :, :qk_nope_head_dim] = q_nope
query_states[:, :, :, qk_nope_head_dim:] = q_pe
key_states = k_pe.new_empty(bs, num_heads, q_len, q_head_dim)
key_states[:, :, :, :qk_nope_head_dim] = k_nope
key_states[:, :, :, qk_nope_head_dim:] = k_pe
# Use old logic
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * softmax_scale
if attn_weights.size() != (bs, num_heads, q_len, kv_seq_len):
raise ValueError(
f"Attention weights should be of size {(bs, num_heads, q_len, kv_seq_len)}, but is"
f" {attn_weights.size()}"
)
assert attention_mask is not None
if attention_mask is not None:
if attention_mask.size() != (bs, 1, q_len, kv_seq_len):
raise ValueError(
f"Attention mask should be of size {(bs, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
)
attn_weights = attn_weights + attention_mask
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(
query_states.dtype
)
attn_weights = nn.functional.dropout(attn_weights, p=0.0, training=False)
attn_output = torch.matmul(attn_weights, value_states)
if attn_output.size() != (bs, num_heads, q_len, v_head_dim):
raise ValueError(
f"`attn_output` should be of size {(bs, num_heads, q_len, v_head_dim)}, but is"
f" {attn_output.size()}"
)
# We do not return attn_weights along with attn_output
return attn_output
@fused_mla.register_fake
def fused_mla(
q_nope: torch.Tensor,
q_pe: torch.Tensor,
kv: torch.Tensor,
k_pe: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
position_ids: torch.Tensor,
attention_mask: torch.Tensor,
softmax_scale: float,
) -> torch.Tensor:
v_head_dim = kv.shape[-1] - q_nope.shape[-1]
return torch.empty_like(kv[..., -v_head_dim:])
@torch.library.custom_op("auto_deploy::torch_attention_deepseek_mla", mutates_args=())
def mla(
q_nope: torch.Tensor, # Down projected q_nope
q_pe: torch.Tensor, # q_pe after applying rope
kv: torch.Tensor, # compressed kv after passing through layernorm
pe: torch.Tensor, # k_pe after applying rope
attention_mask: torch.Tensor, # attention mask
softmax_scale: float, # softmax scale
) -> torch.Tensor:
"""
Reference implementation for MLA style attention that handles compressed kv.
"""
scores = (
torch.einsum("bhsc,btc->bsht", q_nope, kv) + torch.einsum("bhsr,btr->bsht", q_pe, pe)
) * softmax_scale
if attention_mask is not None:
scores += attention_mask.unsqueeze(1)
scores = scores.softmax(dim=-1, dtype=torch.float32).type_as(q_nope)
attn_output = torch.einsum("bsht,btc->bshc", scores, kv)
return attn_output
@mla.register_fake
def mla(
q_nope: torch.Tensor, # Down projected q_nope
q_pe: torch.Tensor, # q_pe after applying rope
kv: torch.Tensor, # compressed kv after passing through layernorm
k_pe: torch.Tensor, # k_pe after applying rope
attention_mask: torch.Tensor, # attention mask
softmax_scale: float, # softmax scale
) -> torch.Tensor:
"""MLA style attention that handles compressed kv."""
return torch.empty_like(q_nope)

View File

@@ -35,7 +35,31 @@ from ..attention_interface import (
ResourceHandlerDict,
UnpagedResourceHandler,
)
from .torch_attention import repeat_kv, update_kv_cache
from .torch_attention import repeat_kv
def _update_kv_cache(
key_states: torch.Tensor,
value_states: torch.Tensor,
k_cache: torch.Tensor,
v_cache: torch.Tensor,
seq_len: torch.Tensor, # metadata
input_pos: torch.Tensor, # metadata
slot_idx: torch.Tensor,
seq_start: torch.Tensor,
) -> torch.Tensor:
"""
Reference implementation of the KV-cache update function. Assumes a KV cache layout of [B,S,N,D].
This function can be used to build reference attention implementations that use a KV cache.
"""
for idx in range(seq_len.shape[0]):
k_cache[slot_idx[idx], input_pos[idx] : input_pos[idx] + seq_len[idx], :, :] = key_states[
seq_start[idx] : seq_start[idx] + seq_len[idx], ...
]
v_cache[slot_idx[idx], input_pos[idx] : input_pos[idx] + seq_len[idx], :, :] = value_states[
seq_start[idx] : seq_start[idx] + seq_len[idx], ...
]
def _apply_logit_softcapping(attn_scores: torch.Tensor, logit_cap: Optional[float]) -> torch.Tensor:
@@ -149,7 +173,7 @@ def _torch_context_mha(
) -> None:
"""Context attention (multiple tokens, potentially multiple sequences) using existing torch functions."""
# Update KV cache first using existing function
update_kv_cache(k, v, k_cache, v_cache, seq_len, input_pos, slot_idx, seq_start)
_update_kv_cache(k, v, k_cache, v_cache, seq_len, input_pos, slot_idx, seq_start)
# Compute attention for each sequence
attn_outputs = []

View File

@@ -1,24 +1,21 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MLA (Multi-head Latent Attention) custom ops.
"""Multi-head Latent Attention operations.
This module provides Multi-head Latent Attention (MLA) implementations:
- mla: MLA operations and attention descriptor
Exports:
- TorchBackendMLAAttention: Attention descriptor for MLA (registered as "torch_mla")
- FlashInferMLAAttention: Attention descriptor for FlashInfer MLA (registered as "flashinfer_mla")
- torch_mla: Source op for MLA attention
- torch_backend_mla_with_cache: Cached backend op with FlashInfer-compatible cache
- flashinfer_mla_with_cache: Cached backend op using FlashInfer MLA kernels
"""
from .flashinfer_mla import FlashInferMLAAttention, flashinfer_mla_with_cache
from .torch_backend_mla import TorchBackendMLAAttention, torch_backend_mla_with_cache
from .torch_mla import torch_mla
__all__ = [
"mla",
"TorchBackendMLAAttention",
"FlashInferMLAAttention",
"torch_mla",
"torch_backend_mla_with_cache",
"flashinfer_mla_with_cache",
]

View File

@@ -0,0 +1,957 @@
"""FlashInfer-based MLA (Multi-head Latent Attention) backend with paged caching.
This module provides:
- FlashInferMLAAttention: attention descriptor using FlashInfer MLA kernels
- flashinfer_mla_with_cache: cached backend op with paged KV cache
FlashInfer MLA uses:
- Regular prefill (input_pos == 0): BatchPrefillWithRaggedKVCacheWrapper with expanded K, V
- Chunked prefill (input_pos > 0): BatchMLAPagedAttentionWrapper with matrix absorption
- Decode: BatchMLAPagedAttentionWrapper with paged compressed KV cache
FlashInfer MLA Cache Layout (two separate caches):
ckv_cache: [num_pages, page_size, kv_lora_rank]
kpe_cache: [num_pages, page_size, qk_rope_head_dim]
- No num_heads dimension (MLA-specific optimization)
Reference: https://docs.flashinfer.ai/api/mla.html
"""
import math
from dataclasses import dataclass, fields
from math import prod
from typing import Dict, List, Literal, Optional, Tuple
import flashinfer
import torch
from torch._ops import OpOverloadPacket
from torch._subclasses import FakeTensor
from torch.fx import Node
from .....llmapi.llm_args import KvCacheConfig
from ...utils.cuda_graph import cuda_graph_state
from ..attention_interface import (
AttentionDescriptor,
AttentionLayout,
AttentionRegistry,
Constant,
MHACallable,
PrepareMetadataCallable,
PrepareMetadataHostCallable,
ResourceHandler,
ResourceHandlerDict,
SequenceInfo,
)
@dataclass
class MLADecodePlanParams:
"""Parameters that affect the FlashInfer MLA decode execution plan."""
num_heads: int
kv_lora_rank: int # head_dim_ckv
qk_rope_head_dim: int # head_dim_kpe
qk_nope_head_dim: int
v_head_dim: int
num_seq: int
page_size: int
q_dtype: torch.dtype
kv_dtype: torch.dtype
sm_scale: Optional[float] = None
def __hash__(self):
"""Convert all fields to a string representation and concatenate them."""
return hash("_".join([str(getattr(self, f.name)) for f in fields(self)]))
@dataclass
class MLAPrefillPlanParams:
"""Parameters that affect the FlashInfer MLA prefill execution plan."""
num_heads: int
num_kv_heads: int # For MLA with expanded KV, same as num_heads
head_dim_qk: int # qk_nope_head_dim + qk_rope_head_dim
head_dim_vo: int # v_head_dim (value/output head dimension)
num_seq: int
q_dtype: torch.dtype
kv_dtype: torch.dtype
sm_scale: Optional[float] = None
def __hash__(self):
"""Convert all fields to a string representation and concatenate them."""
return hash("_".join([str(getattr(self, f.name)) for f in fields(self)]))
class _FlashInferMLAPlanner:
"""A class interface to handle FlashInfer MLA-related planning/wrapping operations.
For MLA attention:
- Regular prefill uses BatchPrefillWithRaggedKVCacheWrapper with expanded K, V tensors
- Chunked prefill uses BatchMLAPagedAttentionWrapper with matrix absorption (same as decode)
- Decode uses BatchMLAPagedAttentionWrapper with paged compressed KV cache
"""
workspace_buffer: Optional[torch.Tensor]
prefill_wrapper: Optional[flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper]
decode_wrapper: Optional["flashinfer.mla.BatchMLAPagedAttentionWrapper"]
# Separate wrapper for chunked/incremental prefill (uses same kernel as decode but different planning)
chunked_prefill_wrapper: Optional["flashinfer.mla.BatchMLAPagedAttentionWrapper"]
cached_cuda_graph_decode_wrappers: Dict[
MLADecodePlanParams, "flashinfer.mla.BatchMLAPagedAttentionWrapper"
]
plan_params_prefill: Optional[MLAPrefillPlanParams]
plan_params_decode: Optional[MLADecodePlanParams]
plan_params_chunked_prefill: Optional[MLADecodePlanParams]
kv_layout: Literal["NHD", "HND"] = "NHD"
def __init__(self):
self.workspace_buffer = None
self.prefill_wrapper = None
self.decode_wrapper = None
self.chunked_prefill_wrapper = None
self.cached_cuda_graph_decode_wrappers = {}
self.plan_params_prefill = None
self.plan_params_decode = None
self.plan_params_chunked_prefill = None
def _init_decode_wrapper(
self,
use_cuda_graph: bool = False,
qo_indptr: Optional[torch.Tensor] = None,
kv_indptr: Optional[torch.Tensor] = None,
kv_indices: Optional[torch.Tensor] = None,
kv_len_arr: Optional[torch.Tensor] = None,
):
assert self.workspace_buffer is not None
if use_cuda_graph:
return flashinfer.mla.BatchMLAPagedAttentionWrapper(
self.workspace_buffer,
use_cuda_graph=True,
qo_indptr=qo_indptr,
kv_indptr=kv_indptr,
kv_indices=kv_indices,
kv_len_arr=kv_len_arr,
)
else:
return flashinfer.mla.BatchMLAPagedAttentionWrapper(
self.workspace_buffer,
use_cuda_graph=False,
)
def reset(self, device: torch.device) -> None:
self.plan_params_prefill = None
self.plan_params_decode = None
self.plan_params_chunked_prefill = None
if isinstance(self.workspace_buffer, torch.Tensor):
return
self.__init__() # reset all state
# NOTE: avoid OOM for many cudagraphs
self.workspace_buffer = torch.empty(320 * 1024 * 1024, device=device, dtype=torch.uint8)
# Prefill uses BatchPrefillWithRaggedKVCacheWrapper with expanded K, V
self.prefill_wrapper = flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper(
self.workspace_buffer,
self.kv_layout,
)
# Decode uses BatchMLAPagedAttentionWrapper with paged compressed KV cache
self.decode_wrapper = self._init_decode_wrapper()
# Chunked prefill uses same kernel as decode but with variable-length queries
self.chunked_prefill_wrapper = self._init_decode_wrapper()
def plan_prefill(
self,
qo_indptr_host: torch.Tensor,
kv_indptr_host: torch.Tensor,
plan_params: MLAPrefillPlanParams,
) -> flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper:
"""Plan prefill using BatchPrefillWithRaggedKVCacheWrapper.
For MLA prefill, we expand compressed_kv to get full K, V tensors
and use standard ragged KV cache attention with causal masking.
Args:
qo_indptr_host: Cumulative query/output lengths on host.
kv_indptr_host: Cumulative key/value lengths on host.
plan_params: Parameters for planning (hashable, no tensors).
"""
if plan_params != self.plan_params_prefill:
self.prefill_wrapper.plan(
qo_indptr_host,
kv_indptr_host,
plan_params.num_heads,
plan_params.num_kv_heads,
plan_params.head_dim_qk,
head_dim_vo=plan_params.head_dim_vo,
use_fp16_qk_reduction=False,
causal=True,
q_data_type=plan_params.q_dtype,
kv_data_type=plan_params.kv_dtype,
sm_scale=plan_params.sm_scale,
)
self.plan_params_prefill = plan_params
return self.prefill_wrapper
def _plan_mla_wrapper(
self,
wrapper: "flashinfer.mla.BatchMLAPagedAttentionWrapper",
qo_indptr: torch.Tensor,
kv_page_indptr: torch.Tensor,
kv_page_indices: torch.Tensor,
kv_last_page_len: torch.Tensor,
plan_params: MLADecodePlanParams,
):
"""Helper to plan a BatchMLAPagedAttentionWrapper."""
# Compute actual KV lengths from paging metadata:
# kv_len = (num_pages - 1) * page_size + last_page_len
num_pages_per_seq = kv_page_indptr[1:] - kv_page_indptr[:-1]
kv_len_arr = (num_pages_per_seq - 1) * plan_params.page_size + kv_last_page_len
wrapper.plan(
qo_indptr,
kv_page_indptr,
kv_page_indices,
kv_len_arr,
plan_params.num_heads,
plan_params.kv_lora_rank, # head_dim_ckv
plan_params.qk_rope_head_dim, # head_dim_kpe
plan_params.page_size,
causal=True,
q_data_type=plan_params.q_dtype,
kv_data_type=plan_params.kv_dtype,
sm_scale=plan_params.sm_scale,
)
def plan_decode(
self,
kv_page_indptr: torch.Tensor,
kv_page_indices: torch.Tensor,
kv_last_page_len: torch.Tensor,
plan_params: MLADecodePlanParams,
) -> "flashinfer.mla.BatchMLAPagedAttentionWrapper":
"""Plan decode using BatchMLAPagedAttentionWrapper.
For MLA decode, we use the paged compressed KV cache with
FlashInfer's optimized MLA kernels. Each sequence generates 1 token.
Args:
kv_page_indptr: Cumulative page counts [batch_size + 1].
kv_page_indices: Page indices for the KV cache.
kv_last_page_len: Length of the last page per sequence.
plan_params: Parameters for planning.
"""
# Decode qo_indptr: [0, 1, 2, ..., batch_size] (1 token per sequence)
batch_size = kv_page_indptr.shape[0] - 1
qo_indptr = torch.arange(batch_size + 1, device=kv_page_indptr.device, dtype=torch.int32)
# Compute kv_len_arr for CUDA graph wrapper initialization
num_pages_per_seq = kv_page_indptr[1:] - kv_page_indptr[:-1]
kv_len_arr = (num_pages_per_seq - 1) * plan_params.page_size + kv_last_page_len
# we want to plan during warm-up of cuda graph capture to ensure we have the plan cached
if (
cuda_graph_state.in_warm_up()
and plan_params not in self.cached_cuda_graph_decode_wrappers
):
# During CUDA graph capture, the metadata tensors provided by auto-deploy are stable.
# Pass the buffer tensors to the wrapper for use_cuda_graph=True
wrapper = self._init_decode_wrapper(
use_cuda_graph=True,
qo_indptr=qo_indptr,
kv_indptr=kv_page_indptr,
kv_indices=kv_page_indices,
kv_len_arr=kv_len_arr,
)
self.cached_cuda_graph_decode_wrappers[plan_params] = wrapper
self._plan_mla_wrapper(
wrapper, qo_indptr, kv_page_indptr, kv_page_indices, kv_last_page_len, plan_params
)
# check if we are in cuda graph capture and just return the pre-cached decode wrapper
if torch.cuda.is_current_stream_capturing() or cuda_graph_state.in_warm_up():
wrapper = self.cached_cuda_graph_decode_wrappers[plan_params]
return wrapper
# Re-plan if plan_params changed
if plan_params != self.plan_params_decode:
self._plan_mla_wrapper(
self.decode_wrapper,
qo_indptr,
kv_page_indptr,
kv_page_indices,
kv_last_page_len,
plan_params,
)
self.plan_params_decode = plan_params
return self.decode_wrapper
def plan_chunked_prefill(
self,
qo_indptr: torch.Tensor,
kv_page_indptr: torch.Tensor,
kv_page_indices: torch.Tensor,
kv_last_page_len: torch.Tensor,
plan_params: MLADecodePlanParams,
) -> "flashinfer.mla.BatchMLAPagedAttentionWrapper":
"""Plan chunked/incremental prefill using BatchMLAPagedAttentionWrapper.
For chunked prefill (input_pos > 0), we use the same kernel as decode but with
variable-length queries. Each sequence can have multiple tokens.
Args:
qo_indptr: Cumulative query lengths [batch_size + 1].
kv_page_indptr: Cumulative page counts [batch_size + 1].
kv_page_indices: Page indices for the KV cache.
kv_last_page_len: Length of the last page per sequence.
plan_params: Parameters for planning.
"""
# Re-plan if plan_params changed
if plan_params != self.plan_params_chunked_prefill:
self._plan_mla_wrapper(
self.chunked_prefill_wrapper,
qo_indptr,
kv_page_indptr,
kv_page_indices,
kv_last_page_len,
plan_params,
)
self.plan_params_chunked_prefill = plan_params
return self.chunked_prefill_wrapper
def plan_generate_only(
self,
num_seq: int,
cu_num_pages: torch.Tensor,
cache_loc: torch.Tensor,
last_page_len: torch.Tensor,
):
"""Plan decode-only batches for cached CUDA graph wrappers.
This is called from the host-side preparation function to plan
the decode wrappers for decode-only batches before the actual
attention op is invoked.
Args:
num_seq: Number of sequences in the decode batch.
cu_num_pages: Cumulative page counts, already sliced to [: num_seq + 1].
cache_loc: Page indices for the KV cache.
last_page_len: Length of the last page per sequence, already sliced to [:num_seq].
"""
for plan_params in self.cached_cuda_graph_decode_wrappers:
if plan_params.num_seq == num_seq:
wrapper = self.cached_cuda_graph_decode_wrappers[plan_params]
# For a pure decode batch, qo_indptr is just [0, 1, 2, ..., batch_size]
qo_indptr = torch.arange(num_seq + 1, device=cu_num_pages.device, dtype=torch.int32)
# Compute actual KV lengths from paging metadata:
# kv_len = (num_pages - 1) * page_size + last_page_len
num_pages_per_seq = cu_num_pages[1:] - cu_num_pages[:-1]
kv_len_arr = (num_pages_per_seq - 1) * plan_params.page_size + last_page_len
wrapper.plan(
qo_indptr,
cu_num_pages, # kv_page_indptr
cache_loc, # kv_page_indices
kv_len_arr,
plan_params.num_heads,
plan_params.kv_lora_rank, # head_dim_ckv
plan_params.qk_rope_head_dim, # head_dim_kpe
plan_params.page_size,
causal=True,
q_data_type=plan_params.q_dtype,
kv_data_type=plan_params.kv_dtype,
sm_scale=plan_params.sm_scale,
)
_GlobalFlashInferMLAPlanner = _FlashInferMLAPlanner()
@torch.library.custom_op("auto_deploy::flashinfer_mla_prepare_metadata", mutates_args=())
def prepare_flashinfer_mla_metadata(
position_ids: torch.Tensor,
batch_info_host: torch.Tensor,
cu_seqlen: torch.Tensor,
seq_len_with_cache: torch.Tensor,
) -> List[torch.Tensor]:
"""Prepare metadata for FlashInfer MLA attention.
This prepares batch_indices and positions for cache appends, similar to
the standard FlashInfer attention preparation.
"""
# retrieve host-side metadata
num_prefill, num_prefill_tokens, num_decode = batch_info_host.tolist()
num_seq = num_prefill + num_decode
num_tokens = num_prefill_tokens + num_decode
_GlobalFlashInferMLAPlanner.reset(position_ids.device)
qo_indptr = cu_seqlen[: num_seq + 1]
# Compute batch_indices and positions for cache appends
batch_indices, positions = flashinfer.get_batch_indices_positions(
qo_indptr, seq_len_with_cache[:num_seq], num_tokens
)
return batch_indices, positions
@prepare_flashinfer_mla_metadata.register_fake
def prepare_flashinfer_mla_metadata_fake(
position_ids: torch.Tensor,
batch_info_host: torch.Tensor,
cu_seqlen: torch.Tensor,
seq_len_with_cache: torch.Tensor,
):
num_tokens = position_ids.shape[0] * position_ids.shape[1]
return (
torch.empty(num_tokens, dtype=torch.int32, device=position_ids.device), # batch_indices
torch.empty(num_tokens, dtype=torch.int32, device=position_ids.device), # positions
)
def prepare_flashinfer_mla_metadata_host(
batch_info_host: torch.Tensor,
cu_num_pages_host: torch.Tensor,
cache_loc_host: torch.Tensor,
last_page_len_host: torch.Tensor,
) -> None:
"""Host-side preparation for FlashInfer MLA attention.
For decode-only batches, this function pre-plans the cached CUDA graph
wrappers to avoid planning during graph capture/replay.
"""
num_prefill, num_prefill_tokens, num_decode = batch_info_host.tolist()
if num_prefill == 0:
_GlobalFlashInferMLAPlanner.plan_generate_only(
num_decode,
cu_num_pages_host[: num_decode + 1],
cache_loc_host,
last_page_len_host[:num_decode],
)
@torch.library.custom_op("auto_deploy::flashinfer_mla_with_cache", mutates_args=())
def flashinfer_mla_with_cache(
# 5 tensor args (matching torch_mla source op)
q_nope: torch.Tensor, # [B, S, N, qk_nope_head_dim]
q_pe: torch.Tensor, # [B, S, N, qk_rope_head_dim]
compressed_kv: torch.Tensor, # [B, S, kv_lora_rank]
kpe: torch.Tensor, # [B, S, 1, qk_rope_head_dim]
kv_b_proj_weight: torch.Tensor, # [N * (qk_nope_head_dim + v_head_dim), kv_lora_rank]
# Standard paged metadata
batch_info_host: torch.Tensor,
cu_seqlen_host: torch.Tensor,
cu_num_pages: torch.Tensor,
cu_num_pages_host: torch.Tensor,
cache_loc: torch.Tensor,
last_page_len: torch.Tensor,
last_page_len_host: torch.Tensor,
seq_len_with_cache_host: torch.Tensor,
# Extra FlashInfer metadata
flashinfer_batch_indices: torch.Tensor,
flashinfer_positions: torch.Tensor,
# Paged caches (two separate caches)
ckv_cache: torch.Tensor, # [num_pages, page_size, kv_lora_rank]
kpe_cache: torch.Tensor, # [num_pages, page_size, qk_rope_head_dim]
# Constants
scale: Optional[float],
kv_lora_rank: int,
) -> torch.Tensor:
"""FlashInfer MLA attention with paged cache.
Uses FlashInfer's optimized kernels:
- Prefill: BatchPrefillWithRaggedKVCacheWrapper with expanded K, V tensors
- Decode: BatchMLAPagedAttentionWrapper with paged compressed KV cache
FlashInfer MLA Cache Layout (two separate caches):
ckv_cache: [num_pages, page_size, kv_lora_rank]
kpe_cache: [num_pages, page_size, qk_rope_head_dim]
Args:
q_nope: Query non-positional component [B, S, N, qk_nope_head_dim]
q_pe: Query positional component [B, S, N, qk_rope_head_dim]
compressed_kv: Compressed KV latent [B, S, kv_lora_rank]
kpe: Key positional encoding [B, S, 1, qk_rope_head_dim]
kv_b_proj_weight: Projection weight [N * (qk_nope_head_dim + v_head_dim), kv_lora_rank]
(metadata args): Standard paged attention metadata
ckv_cache: Paged cache for compressed KV
kpe_cache: Paged cache for key positional encoding
scale: Softmax scale factor
kv_lora_rank: Rank of compressed KV
Returns:
Attention output [B, S, N, v_head_dim]
"""
# Get dimensions
b, s = q_nope.shape[:2]
num_heads = q_nope.shape[2]
qk_nope_head_dim = q_nope.shape[3]
qk_rope_head_dim = q_pe.shape[3]
qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
# Infer v_head_dim from kv_b_proj_weight
out_features = kv_b_proj_weight.shape[0]
kv_head_dim = out_features // num_heads
v_head_dim = kv_head_dim - qk_nope_head_dim
# Get batch info
num_prefill, num_prefill_tokens, num_decode = batch_info_host.tolist()
num_seq = num_prefill + num_decode
num_total_tokens = num_prefill_tokens + num_decode
# Set scale
if scale is None:
scale = 1.0 / math.sqrt(qk_head_dim)
page_size = ckv_cache.shape[1]
# Flatten inputs to [total_tokens, ...] format
bs = b * s
q_nope_flat = q_nope.contiguous().view(bs, num_heads, qk_nope_head_dim)
q_pe_flat = q_pe.contiguous().view(bs, num_heads, qk_rope_head_dim)
compressed_kv_flat = compressed_kv.contiguous().view(bs, kv_lora_rank)
kpe_flat = kpe.contiguous().view(bs, qk_rope_head_dim)
# Convert cache dtype if needed
if ckv_cache.dtype == torch.float8_e4m3fn:
compressed_kv_flat = compressed_kv_flat.to(torch.float8_e4m3fn)
kpe_flat = kpe_flat.to(torch.float8_e4m3fn)
# Append to paged cache using FlashInfer's append function
# Note: caches are guaranteed contiguous by CachedSequenceInterface._create_kv_cache_manager
flashinfer.page.append_paged_mla_kv_cache(
compressed_kv_flat,
kpe_flat,
flashinfer_batch_indices,
flashinfer_positions,
ckv_cache,
kpe_cache,
cache_loc,
cu_num_pages[: num_seq + 1],
last_page_len[:num_seq],
)
# Pre-allocate output
if num_prefill > 0 and num_decode > 0:
y = torch.empty(bs, num_heads, v_head_dim, dtype=q_nope.dtype, device=q_nope.device)
else:
y = None
# =========================================================================
# PREFILL phase: Use BatchPrefillWithRaggedKVCacheWrapper for regular prefill
# or BatchMLAPagedAttentionWrapper for chunked prefill
# =========================================================================
if num_prefill > 0:
q_nope_prefill = q_nope_flat[:num_prefill_tokens]
q_pe_prefill = q_pe_flat[:num_prefill_tokens]
compressed_kv_prefill = compressed_kv_flat[:num_prefill_tokens]
kpe_prefill = kpe_flat[:num_prefill_tokens]
# Check if any prefill sequence has cached tokens (chunked prefill)
# seq_len_with_cache > current_seq_len means there are cached tokens
q_lens = cu_seqlen_host[1 : num_prefill + 1] - cu_seqlen_host[:num_prefill]
kv_lens = seq_len_with_cache_host[:num_prefill]
is_chunked_prefill = (kv_lens > q_lens).any().item()
if is_chunked_prefill:
# =================================================================
# CHUNKED PREFILL: Use BatchMLAPagedAttentionWrapper with absorption
# Same approach as decode, but with variable-length Q sequences
# =================================================================
# Extract W_kn and W_v from kv_b_proj_weight
# kv_b_proj_weight: [N * (qk_nope_head_dim + v_head_dim), kv_lora_rank]
# Reshape to [N, qk_nope_head_dim + v_head_dim, kv_lora_rank]
kv_b_proj_reshaped = kv_b_proj_weight.view(
num_heads, qk_nope_head_dim + v_head_dim, kv_lora_rank
)
# W_kn: [N, qk_nope_head_dim, kv_lora_rank]
w_kn = kv_b_proj_reshaped[:, :qk_nope_head_dim, :]
# W_v: [N, v_head_dim, kv_lora_rank]
w_v = kv_b_proj_reshaped[:, qk_nope_head_dim:, :]
# Absorb W_kn into q_nope:
# q_nope_prefill: [num_prefill_tokens, N, qk_nope_head_dim]
# w_kn: [N, qk_nope_head_dim, kv_lora_rank]
# q_nope_absorbed: [num_prefill_tokens, N, kv_lora_rank]
q_nope_absorbed = torch.einsum("bnd,ndk->bnk", q_nope_prefill, w_kn).contiguous()
# Build qo_indptr for variable-length prefill sequences
qo_indptr = cu_seqlen_host[: num_prefill + 1].to(
device=cu_num_pages.device, dtype=torch.int32
)
pp_chunked = MLADecodePlanParams(
num_heads=num_heads,
kv_lora_rank=kv_lora_rank,
qk_rope_head_dim=qk_rope_head_dim,
qk_nope_head_dim=qk_nope_head_dim,
v_head_dim=v_head_dim,
num_seq=num_prefill,
page_size=page_size,
q_dtype=q_nope.dtype,
kv_dtype=ckv_cache.dtype,
sm_scale=scale,
)
wrapper_chunked = _GlobalFlashInferMLAPlanner.plan_chunked_prefill(
qo_indptr=qo_indptr,
kv_page_indptr=cu_num_pages[: num_prefill + 1],
kv_page_indices=cache_loc,
kv_last_page_len=last_page_len[:num_prefill],
plan_params=pp_chunked,
)
# Run paged MLA attention in compressed space
y_prefill_compressed = wrapper_chunked.run(
q_nope_absorbed,
q_pe_prefill,
ckv_cache,
kpe_cache,
)
# Project output back from latent space to v_head_dim
# y_prefill_compressed: [num_prefill_tokens, N, kv_lora_rank]
# w_v: [N, v_head_dim, kv_lora_rank]
# y_prefill: [num_prefill_tokens, N, v_head_dim]
y_prefill = torch.einsum("bnk,nvk->bnv", y_prefill_compressed, w_v)
else:
# =================================================================
# REGULAR PREFILL: Use BatchPrefillWithRaggedKVCacheWrapper
# Expand compressed_kv to K, V and use ragged attention
# =================================================================
# Expand compressed_kv using kv_b_proj_weight to get k_nope and v
# compressed_kv: [num_prefill_tokens, kv_lora_rank]
# kv_b_proj_weight: [N * (qk_nope_head_dim + v_head_dim), kv_lora_rank]
# kv_expanded: [num_prefill_tokens, N * (qk_nope_head_dim + v_head_dim)]
kv_expanded = torch.matmul(compressed_kv_prefill, kv_b_proj_weight.t())
kv_expanded = kv_expanded.view(
num_prefill_tokens, num_heads, qk_nope_head_dim + v_head_dim
)
# Split into k_nope and v
k_nope_prefill = kv_expanded[:, :, :qk_nope_head_dim] # [tokens, N, qk_nope_head_dim]
v_prefill = kv_expanded[:, :, qk_nope_head_dim:].contiguous() # [tokens, N, v_head_dim]
# Expand kpe to all heads: [tokens, qk_rope_head_dim] -> [tokens, N, qk_rope_head_dim]
kpe_expanded = kpe_prefill.unsqueeze(1).expand(-1, num_heads, -1).contiguous()
# Concatenate to form full Q and K
# Q: [tokens, N, qk_head_dim]
q_prefill = torch.cat([q_nope_prefill, q_pe_prefill], dim=-1).contiguous()
# K: [tokens, N, qk_head_dim]
k_prefill = torch.cat([k_nope_prefill, kpe_expanded], dim=-1).contiguous()
pp_prefill = MLAPrefillPlanParams(
num_heads=num_heads,
num_kv_heads=num_heads, # For MLA with expanded KV, same as num_heads
head_dim_qk=qk_head_dim,
head_dim_vo=v_head_dim,
num_seq=num_prefill,
q_dtype=q_nope.dtype,
kv_dtype=k_prefill.dtype,
sm_scale=scale,
)
wrapper_prefill = _GlobalFlashInferMLAPlanner.plan_prefill(
qo_indptr_host=cu_seqlen_host[: num_prefill + 1],
kv_indptr_host=cu_seqlen_host[: num_prefill + 1], # Same as qo for self-attention
plan_params=pp_prefill,
)
y_prefill = wrapper_prefill.run(
q_prefill,
k_prefill,
v_prefill,
)
if y is not None:
y[:num_prefill_tokens] = y_prefill
else:
y = y_prefill
# =========================================================================
# DECODE phase: Use BatchMLAPagedAttentionWrapper with paged compressed KV
# =========================================================================
if num_decode > 0:
q_nope_decode = q_nope_flat[num_prefill_tokens:num_total_tokens].contiguous()
q_pe_decode = q_pe_flat[num_prefill_tokens:num_total_tokens].contiguous()
# FlashInfer MLA operates in the compressed latent space.
# We need to:
# 1. Absorb W_kn (K-nope projection) into q_nope
# 2. Run attention in compressed space
# 3. Project output back using W_v
# Extract W_kn and W_v from kv_b_proj_weight
# kv_b_proj_weight: [N * (qk_nope_head_dim + v_head_dim), kv_lora_rank]
# Reshape to [N, qk_nope_head_dim + v_head_dim, kv_lora_rank]
kv_b_proj_reshaped = kv_b_proj_weight.view(
num_heads, qk_nope_head_dim + v_head_dim, kv_lora_rank
)
# W_kn: [N, qk_nope_head_dim, kv_lora_rank]
w_kn = kv_b_proj_reshaped[:, :qk_nope_head_dim, :]
# W_v: [N, v_head_dim, kv_lora_rank]
w_v = kv_b_proj_reshaped[:, qk_nope_head_dim:, :]
# Absorb W_kn into q_nope:
# q_nope_decode: [num_decode, N, qk_nope_head_dim]
# w_kn: [N, qk_nope_head_dim, kv_lora_rank]
# q_nope_absorbed: [num_decode, N, kv_lora_rank]
q_nope_absorbed = torch.einsum("bnd,ndk->bnk", q_nope_decode, w_kn).contiguous()
pp_decode = MLADecodePlanParams(
num_heads=num_heads,
kv_lora_rank=kv_lora_rank,
qk_rope_head_dim=qk_rope_head_dim,
qk_nope_head_dim=qk_nope_head_dim,
v_head_dim=v_head_dim,
num_seq=num_decode,
page_size=page_size,
q_dtype=q_nope.dtype,
kv_dtype=ckv_cache.dtype,
sm_scale=scale,
)
wrapper_decode = _GlobalFlashInferMLAPlanner.plan_decode(
kv_page_indptr=cu_num_pages[num_prefill : num_seq + 1],
kv_page_indices=cache_loc,
kv_last_page_len=last_page_len[num_prefill:num_seq],
plan_params=pp_decode,
)
# Run attention in compressed space
# y_decode_compressed: [num_decode, N, kv_lora_rank]
# Note: caches are guaranteed contiguous by CachedSequenceInterface._create_kv_cache_manager
y_decode_compressed = wrapper_decode.run(
q_nope_absorbed,
q_pe_decode,
ckv_cache,
kpe_cache,
)
# Project output back from latent space to v_head_dim
# y_decode_compressed: [num_decode, N, kv_lora_rank]
# w_v: [N, v_head_dim, kv_lora_rank]
# y_decode: [num_decode, N, v_head_dim]
y_decode = torch.einsum("bnk,nvk->bnv", y_decode_compressed, w_v)
if y is not None:
y[num_prefill_tokens:num_total_tokens] = y_decode
else:
y = y_decode
return y.view(b, s, num_heads, v_head_dim)
@flashinfer_mla_with_cache.register_fake
def flashinfer_mla_with_cache_fake(
q_nope: torch.Tensor,
q_pe: torch.Tensor,
compressed_kv: torch.Tensor,
kpe: torch.Tensor,
kv_b_proj_weight: torch.Tensor,
batch_info_host: torch.Tensor,
cu_seqlen_host: torch.Tensor,
cu_num_pages: torch.Tensor,
cu_num_pages_host: torch.Tensor,
cache_loc: torch.Tensor,
last_page_len: torch.Tensor,
last_page_len_host: torch.Tensor,
seq_len_with_cache_host: torch.Tensor,
flashinfer_batch_indices: torch.Tensor,
flashinfer_positions: torch.Tensor,
ckv_cache: torch.Tensor,
kpe_cache: torch.Tensor,
scale: Optional[float],
kv_lora_rank: int,
) -> torch.Tensor:
"""Fake implementation for flashinfer_mla_with_cache."""
num_heads = q_nope.shape[2]
qk_nope_head_dim = q_nope.shape[-1]
out_features = kv_b_proj_weight.shape[0]
kv_head_dim = out_features // num_heads
v_head_dim = kv_head_dim - qk_nope_head_dim
return q_nope.new_empty(
q_nope.shape[0], q_nope.shape[1], q_nope.shape[2], v_head_dim
).contiguous()
class MLAPagedResourceHandler(ResourceHandler):
"""Handler for paged resources in MLA that require per-layer contiguous memory.
While MLA uses paged caching, the underlying FlashInfer MLA kernel tracks cache strides with a
uint32_t. The KVCacheManager allocates one contiguous tensor for the cache across all layers,
with dim 0 indexing the layer, so the per-layer view has very large strides between pages,
which can overflow the kernel's uint32_t stride arithmetic.
We use a separate handler for this purpose to avoid registering the cache with the
KVCacheManager and instead rely on local allocation.
"""
@property
def is_paged(self) -> bool:
"""Whether the resource is paged."""
return True
def __init__(self, *token_shape: int, dtype: torch.dtype) -> None:
"""Initialize the MLAPagedResourceHandler.
Args:
token_shape: The shape of the resource per token.
dtype: The dtype of the resource.
"""
self.token_shape = token_shape
self.dtype = dtype
def _get_bytes_per_token(self) -> int:
"""The size of the resource per token in bytes."""
return prod(self.token_shape) * self.dtype.itemsize
def allocate(self, sequence_info: SequenceInfo) -> torch.Tensor:
"""Allocate contiguous paged resource.
Args:
sequence_info: SequenceInfo with device and page information.
Returns:
Contiguous tensor of shape [num_blocks, tokens_per_block, *token_shape].
"""
return torch.empty(
sequence_info.num_blocks,
sequence_info.tokens_per_block,
*self.token_shape,
device=sequence_info.device,
dtype=self.dtype,
)
@AttentionRegistry.register("flashinfer_mla")
class FlashInferMLAAttention(AttentionDescriptor):
"""Attention descriptor for FlashInfer-based MLA with paged cache.
This descriptor uses FlashInfer's optimized MLA kernels:
- Source op: torch_mla (same as torch_mla backend)
- Cached op: flashinfer_mla_with_cache with paged cache
FlashInfer MLA Cache Layout (two separate caches):
ckv_cache: [num_pages, page_size, kv_lora_rank]
kpe_cache: [num_pages, page_size, qk_rope_head_dim]
- No num_heads dimension (MLA-specific optimization)
Reference: https://docs.flashinfer.ai/api/mla.html
"""
@classmethod
def _get_planner(cls) -> _FlashInferMLAPlanner:
return _GlobalFlashInferMLAPlanner
@classmethod
def get_attention_layout(cls) -> AttentionLayout:
"""Get the attention layout expected by the backend."""
return "bsnd"
@classmethod
def get_num_qkv_args(cls) -> int:
"""Get the number of tensor arguments expected by the source op."""
return 5 # q_nope, q_pe, compressed_kv, kpe, kv_b_proj_weight
@classmethod
def get_source_attention_op(cls) -> OpOverloadPacket:
"""Get the source attention op that we target for replacement."""
return torch.ops.auto_deploy.torch_mla
@classmethod
def get_cached_attention_op(cls) -> MHACallable:
"""Get the cached attention op."""
return torch.ops.auto_deploy.flashinfer_mla_with_cache.default
@classmethod
def get_standard_metadata_args(cls) -> List[str]:
"""Get the list of standard metadata arguments for paged attention."""
return [
"batch_info_host",
"cu_seqlen_host",
"cu_num_pages",
"cu_num_pages_host",
"cache_loc",
"last_page_len",
"last_page_len_host",
"seq_len_with_cache_host",
]
@classmethod
def get_prepare_extra_metadata_info(
cls, any_source_attn_node: Node
) -> Tuple[Optional[PrepareMetadataCallable], int, List[Constant]]:
"""Get the prepare_metadata op for FlashInfer MLA."""
return (torch.ops.auto_deploy.flashinfer_mla_prepare_metadata.default, 2, [])
@classmethod
def get_cache_initializers(
cls, source_attn_node: Node, cache_config: KvCacheConfig
) -> ResourceHandlerDict:
"""Get cache initializers using FlashInfer MLA paged cache layout.
Creates two separate paged caches:
- ckv_cache: [num_pages, page_size, kv_lora_rank]
- kpe_cache: [num_pages, page_size, qk_rope_head_dim]
"""
# Extract dimensions from source node args
# torch_mla signature: q_nope, q_pe, compressed_kv, kpe, kv_b_proj_weight, ...
compressed_kv_fake: FakeTensor = source_attn_node.args[2].meta["val"]
kpe_fake: FakeTensor = source_attn_node.args[3].meta["val"]
# Get dimensions
# compressed_kv: [B, S, kv_lora_rank]
# kpe: [B, S, 1, qk_rope_head_dim]
kv_lora_rank = compressed_kv_fake.shape[-1]
qk_rope_head_dim = kpe_fake.shape[-1]
# flashinfer mla requires kv_lora_rank to be 512 and qk_rope_head_dim to be 64
if kv_lora_rank != 512:
raise ValueError("kv_lora_rank must be 512 for flashinfer_mla")
if qk_rope_head_dim != 64:
raise ValueError("qk_rope_head_dim must be 64 for flashinfer_mla")
cache_dtype = cls.resolve_cache_dtype(cache_config.dtype, compressed_kv_fake.dtype)
# FlashInfer MLA uses two separate paged caches with no num_heads dimension
return {
"ckv_cache": MLAPagedResourceHandler(
kv_lora_rank,
dtype=cache_dtype,
),
"kpe_cache": MLAPagedResourceHandler(
qk_rope_head_dim,
dtype=cache_dtype,
),
}
@classmethod
def get_host_prepare_metadata_function(cls) -> Optional[PrepareMetadataHostCallable]:
"""Get function for host-side preparation."""
return prepare_flashinfer_mla_metadata_host
@classmethod
def get_constants(cls, source_attn_node: Node) -> List[Constant]:
"""Get constants to pass to the cached attention op."""
# Extract kv_lora_rank for cache operations
compressed_kv_fake = source_attn_node.args[2].meta["val"]
kv_lora_rank = compressed_kv_fake.shape[-1]
# Get scale from kwargs
scale = source_attn_node.kwargs.get("scale", None)
return [scale, kv_lora_rank]
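# Illustrative sketch (not part of the original file): the shapes of the two paged caches that
# get_cache_initializers sets up, following the layout in the class docstring. num_pages and
# page_size are arbitrary example values; the last dims (512 and 64) are the sizes this
# descriptor enforces for FlashInfer MLA.
def _flashinfer_mla_cache_shape_sketch():
    num_pages, page_size = 16, 32
    ckv_cache = torch.empty(num_pages, page_size, 512, dtype=torch.bfloat16)  # latent KV
    kpe_cache = torch.empty(num_pages, page_size, 64, dtype=torch.bfloat16)  # RoPE keys
    # Unlike a standard KV cache, neither cache carries a num_heads dimension.
    assert ckv_cache.dim() == kpe_cache.dim() == 3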

View File

@@ -1,280 +0,0 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Custom ops for MultiHead Latent attention."""
import math
from typing import List, Optional, Union
import torch
from torch._ops import OpOverloadPacket
from torch.fx import Node
from .....llmapi.llm_args import KvCacheConfig
from ..attention.triton_attention import _decode_attention, _prefill_attention
from ..attention_interface import (
AttentionDescriptor,
AttentionLayout,
AttentionRegistry,
MHACallable,
ResourceHandlerDict,
UnpagedResourceHandler,
)
Constant = Union[int, float, str, None]
def _precompute_inv_freq(
max_seq_len: int, head_dim: int, rope_theta: float, device: torch.device
) -> torch.Tensor:
inv_freq = 1.0 / (
rope_theta ** (torch.arange(0, head_dim, 2, device=device).float() / head_dim)
)
t = torch.arange(max_seq_len, device=inv_freq.device, dtype=inv_freq.dtype)
freqs = torch.outer(t, inv_freq.to(t.device))
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
cos_sin_stacked = torch.stack([emb.cos().to(torch.bfloat16), emb.sin().to(torch.bfloat16)])
return cos_sin_stacked
@torch.library.custom_op(
"auto_deploy::triton_attention_fused_flattened_mla_with_cache", mutates_args=()
)
def fused_flattened_mla_with_cache(
# Q, K, V
q_nope: torch.Tensor,
q_pe: torch.Tensor,
kv: torch.Tensor,
k_pe: torch.Tensor,
# STANDARD METADATA
batch_info_host: torch.Tensor,
seq_len: torch.Tensor,
input_pos: torch.Tensor,
cache_loc: torch.Tensor,
cu_seqlen: torch.Tensor,
# EXTRA METADATA
#
# CACHES
k_cache: torch.Tensor,
v_cache: torch.Tensor,
# CONSTANTS
softmax_scale: Optional[float] = None,
rope_theta: Optional[float] = None,
) -> torch.Tensor:
"""Flattened & fused MLA with cache with triton kernels."""
# b, s info
# NOTE: b, s are just the shapes of the input tensor q; not necessarily the number of sequences.
# Generally speaking, we expect one of two cases here:
# 1. b > 0, s==1: this indicates a generate-only batch of tokens.
# 2. b==1, s > 0: this indicates a mixed context+generate phase. The actual number of sequences
# and number of tokens per sequence are encoded in seq_len and seq_start.
# check for sequence info and truncate metadata
num_prefill, num_prefill_tokens, num_decode = batch_info_host.tolist()
num_seq = num_prefill + num_decode
seq_len = seq_len[:num_seq]
input_pos = input_pos[:num_seq]
cache_loc = cache_loc[:num_seq]
seq_start = cu_seqlen[:num_seq]
# Get parameters
b, num_heads, s, qk_nope_head_dim = q_nope.shape
qk_rope_head_dim = q_pe.shape[-1]
v_head_dim = kv.shape[-1] - qk_nope_head_dim
# Get k_nope and value_states
k_nope, value_states = torch.split(kv, [qk_nope_head_dim, v_head_dim], dim=-1)
# Flatten inputs
if s == 1:
bs_view = (b, s)
else:
bs_view = (b * s,)
# TODO(suyogg): do something about all these clones, transposes, and contiguous-es
q_nope = q_nope.clone().transpose(1, 2).view(*bs_view, num_heads, qk_nope_head_dim).contiguous()
q_pe = q_pe.clone().transpose(1, 2).view(*bs_view, num_heads, qk_rope_head_dim).contiguous()
k_nope = k_nope.clone().transpose(1, 2).view(*bs_view, num_heads, qk_nope_head_dim).contiguous()
k_pe = k_pe.clone().transpose(1, 2).view(*bs_view, -1, qk_rope_head_dim).contiguous()
value_states = value_states.transpose(1, 2).view(*bs_view, -1, v_head_dim).contiguous()
# Apply RoPE
if rope_theta is not None:
max_seq_len = (input_pos + seq_len).max().item()
cos_sin_stacked = _precompute_inv_freq(
max_seq_len, qk_rope_head_dim, rope_theta, q_pe.device
)
# Extract cos and sin from freqs_cis
cos_base = cos_sin_stacked[0, ...]
sin_base = cos_sin_stacked[1, ...]
# TODO: Use triton kernels for RoPE
# TODO: Add yarn support
for i in range(seq_len.shape[0]):
start = seq_start[i]
length = seq_len[i]
# build position_ids
if s == 1:
idx = (input_pos[i] + length - 1).item()
pos_ids = torch.tensor(idx, device=cos_base.device)
else:
pos_ids = torch.arange(input_pos[i], input_pos[i] + length, device=cos_base.device)
cos = cos_base[pos_ids] # [..., 1, head_dim]
sin = sin_base[pos_ids]
q_slice = q_pe[start : start + length]
k_slice = k_pe[start : start + length]
q_rot, k_rot = torch.ops.auto_deploy.torch_rope_with_qk_interleaving(
q_slice,
k_slice,
cos,
sin,
-2,
)
q_pe[start : start + length] = q_rot
k_pe[start : start + length] = k_rot
# Create query_states, key_states
query_states = torch.cat((q_nope, q_pe), dim=-1) # [b*s,n,d]
key_states = torch.cat((k_nope, k_pe.expand(*bs_view, num_heads, -1)), dim=-1) # [b*s,n,d]
# Compute scale if not provided
qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
scale = softmax_scale if softmax_scale is not None else 1.0 / math.sqrt(qk_head_dim)
# Compute attention
y = torch.empty_like(value_states)
if s == 1:
# generate-only phase (decode)
_decode_attention(
query_states.contiguous(),
key_states.contiguous(),
value_states.contiguous(),
k_cache,
v_cache,
cache_loc,
input_pos,
scale,
y,
)
else:
# mixed context + generate phase (prefill)
_prefill_attention(
query_states.contiguous(),
key_states.contiguous(),
value_states.contiguous(),
k_cache,
v_cache,
input_pos,
cache_loc,
seq_len,
seq_start,
scale,
y,
)
y = (
y.view(b, s, -1, v_head_dim).transpose(1, 2).contiguous()
) # BNSD format as expected by the callsite.
return y
@fused_flattened_mla_with_cache.register_fake
def fused_flattened_mla_with_cache_fake(
# Q, K, V
q_nope: torch.Tensor,
q_pe: torch.Tensor,
kv: torch.Tensor,
k_pe: torch.Tensor,
# STANDARD METADATA
batch_info_host: torch.Tensor,
seq_len: torch.Tensor,
input_pos: torch.Tensor,
cache_loc: torch.Tensor,
cu_seqlen: torch.Tensor,
# EXTRA METADATA
#
# CACHES
k_cache: torch.Tensor,
v_cache: torch.Tensor,
# CONSTANTS
softmax_scale: Optional[float] = None,
rope_theta: Optional[float] = None,
):
v_head_dim = kv.shape[-1] - q_nope.shape[-1]
return torch.empty_like(kv[..., -v_head_dim:])
@AttentionRegistry.register("MultiHeadLatentAttention")
class MultiHeadLatentAttention(AttentionDescriptor):
@classmethod
def get_attention_layout(cls) -> AttentionLayout:
"""Get the attention layout expected by the backend."""
return "bnsd"
@classmethod
def get_num_qkv_args(cls) -> int:
"""Get the number of qkv arguments expected by the source op."""
return 4
@classmethod
def get_source_attention_op(cls) -> OpOverloadPacket:
return torch.ops.auto_deploy.torch_attention_deepseek_fused_mla
@classmethod
def get_cached_attention_op(cls) -> MHACallable:
return torch.ops.auto_deploy.triton_attention_fused_flattened_mla_with_cache.default
@classmethod
def get_standard_metadata_args(cls) -> List[str]:
return ["batch_info_host", "seq_len", "input_pos", "cache_loc", "cu_seqlen"]
@classmethod
def get_cache_initializers(
cls, source_attn_node: Node, cache_config: KvCacheConfig
) -> ResourceHandlerDict:
q_nope_fake = source_attn_node.args[0].meta["val"]
q_pe_fake = source_attn_node.args[1].meta["val"]
kv_fake = source_attn_node.args[2].meta["val"]
num_kv_heads = kv_fake.shape[1]
head_dim = q_nope_fake.shape[-1]
rope_dim = q_pe_fake.shape[-1]
return {
"k_cache": UnpagedResourceHandler(
num_kv_heads,
head_dim + rope_dim,
dtype=cls.resolve_cache_dtype(cache_config.dtype, kv_fake.dtype),
),
"v_cache": UnpagedResourceHandler(
num_kv_heads,
head_dim,
dtype=cls.resolve_cache_dtype(cache_config.dtype, kv_fake.dtype),
),
}
@classmethod
def get_constants(cls, source_attn_node: Node) -> List[Constant]:
softmax_scale = None
rope_theta = 10000.0 # TODO: remove once MLA is unfused
return [softmax_scale, rope_theta]

View File

@@ -0,0 +1,518 @@
"""Custom ops for MultiHead Latent Attention (MLA) with FlashInfer-compatible cache.
This module provides:
- torch_cached_mla_with_cache: cached backend op
- TorchBackendMLAAttention: attention descriptor
FlashInfer MLA Cache Layout:
mla_cache: [max_batch, max_seq, kv_lora_rank + qk_rope_head_dim]
- No num_heads dimension (MLA-specific optimization)
- compressed_kv_cached = mla_cache[:, :, :kv_lora_rank] (zero-copy slice)
- kpe_cached = mla_cache[:, :, kv_lora_rank:] (zero-copy slice)
The implementation uses:
- Prefill: Expand compressed_kv -> full K, V, compute normal attention
- Generate: Weight absorption for efficiency (Q @ W^T instead of expanding cached KV)
Reference: https://docs.flashinfer.ai/tutorials/kv_layout.html#mla-page-layout
"""
import math
from typing import List, Optional
import torch
from torch._ops import OpOverloadPacket
from torch.fx import Node
from .....llmapi.llm_args import KvCacheConfig
from ..attention_interface import (
AttentionDescriptor,
AttentionLayout,
AttentionRegistry,
Constant,
MHACallable,
ResourceHandlerDict,
UnpagedResourceHandler,
)
def _update_mla_cache(
compressed_kv: torch.Tensor, # [total_tokens, kv_lora_rank]
kpe: torch.Tensor, # [total_tokens, qk_rope_head_dim]
mla_cache: torch.Tensor, # [max_batch, max_seq, kv_lora_rank + qk_rope_head_dim]
seq_len: torch.Tensor,
input_pos: torch.Tensor,
slot_idx: torch.Tensor,
seq_start: torch.Tensor,
kv_lora_rank: int,
) -> None:
"""Update FlashInfer MLA cache with compressed_kv and kpe values.
FlashInfer MLA cache layout: [max_batch, max_seq, kv_lora_rank + qk_rope_head_dim]
- First kv_lora_rank dims: compressed KV latent (before kv_b_proj)
- Last qk_rope_head_dim dims: key positional encoding
"""
for idx in range(seq_len.shape[0]):
start = seq_start[idx].item()
length = seq_len[idx].item()
cache_idx = slot_idx[idx].item()
pos = input_pos[idx].item()
# Update compressed_kv portion
mla_cache[cache_idx, pos : pos + length, :kv_lora_rank] = compressed_kv[
start : start + length
]
# Update kpe portion
mla_cache[cache_idx, pos : pos + length, kv_lora_rank:] = kpe[start : start + length]
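# Illustrative sketch (not part of the original file): the cache keeps the latent KV and the
# positional keys side by side in the last dim, so the slices used throughout this module are
# zero-copy views of the same storage. max_batch / max_seq are arbitrary example values;
# kv_lora_rank=512 and qk_rope_head_dim=64 match the defaults used below.
def _mla_cache_layout_sketch():
    max_batch, max_seq, kv_lora_rank, qk_rope_head_dim = 2, 16, 512, 64
    mla_cache = torch.zeros(max_batch, max_seq, kv_lora_rank + qk_rope_head_dim)
    ckv_view = mla_cache[:, :, :kv_lora_rank]  # [2, 16, 512]
    kpe_view = mla_cache[:, :, kv_lora_rank:]  # [2, 16, 64]
    # Both views alias the cache storage; slicing the last dim copies nothing.
    assert ckv_view.data_ptr() == mla_cache.data_ptr()
    assert kpe_view.untyped_storage().data_ptr() == mla_cache.untyped_storage().data_ptr()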
def _torch_mla_generate_with_absorption(
q_nope: torch.Tensor, # [B, 1, N, qk_nope_head_dim]
q_pe: torch.Tensor, # [B, 1, N, qk_rope_head_dim]
compressed_kv: torch.Tensor, # [B, 1, kv_lora_rank]
kpe: torch.Tensor, # [B, 1, 1, qk_rope_head_dim]
kv_b_proj_weight: torch.Tensor, # [N * (qk_nope_head_dim + v_head_dim), kv_lora_rank]
mla_cache: torch.Tensor, # [max_batch, max_seq, kv_lora_rank + qk_rope_head_dim]
slot_idx: torch.Tensor,
input_pos: torch.Tensor,
scale: float,
kv_lora_rank: int,
num_heads: int,
qk_nope_head_dim: int,
v_head_dim: int,
out: torch.Tensor,
) -> None:
"""Generate-only MLA attention with weight absorption.
Weight absorption: Instead of expanding all cached KV, we absorb kv_b_proj into Q.
Q_absorbed = Q_nope @ W_k^T where W_k is the k_nope portion of kv_b_proj_weight
This avoids expanding potentially thousands of cached tokens.
"""
b = q_nope.shape[0]
# Extract k_nope and v portions from kv_b_proj_weight
# kv_b_proj_weight: [N * (qk_nope_head_dim + v_head_dim), kv_lora_rank]
# Reshape to [N, qk_nope_head_dim + v_head_dim, kv_lora_rank]
weight_reshaped = kv_b_proj_weight.view(num_heads, qk_nope_head_dim + v_head_dim, kv_lora_rank)
w_k_nope = weight_reshaped[:, :qk_nope_head_dim, :] # [N, qk_nope_head_dim, kv_lora_rank]
w_v = weight_reshaped[:, qk_nope_head_dim:, :] # [N, v_head_dim, kv_lora_rank]
# Update cache with new tokens
compressed_kv_flat = compressed_kv.squeeze(1) # [B, kv_lora_rank]
kpe_flat = kpe.squeeze(1).squeeze(1) # [B, qk_rope_head_dim]
for i in range(b):
cache_idx = slot_idx[i].item()
pos = input_pos[i].item()
mla_cache[cache_idx, pos, :kv_lora_rank] = compressed_kv_flat[i]
mla_cache[cache_idx, pos, kv_lora_rank:] = kpe_flat[i]
# Compute attention for each sequence using weight absorption
for i in range(b):
cache_idx = slot_idx[i].item()
pos = input_pos[i].item()
# Get query for this sequence: [N, qk_nope_head_dim], [N, qk_rope_head_dim]
q_nope_i = q_nope[i, 0] # [N, qk_nope_head_dim]
q_pe_i = q_pe[i, 0] # [N, qk_rope_head_dim]
# Retrieve cached data up to current position
cached_data = mla_cache[cache_idx, : pos + 1] # [seq_len, kv_lora_rank + qk_rope_head_dim]
compressed_kv_cached = cached_data[:, :kv_lora_rank] # [seq_len, kv_lora_rank]
kpe_cached = cached_data[:, kv_lora_rank:] # [seq_len, qk_rope_head_dim]
# =====================================================================
# Weight absorption for Q_nope part
# =====================================================================
# q_absorbed = q_nope @ w_k_nope^T (absorb k_nope projection into query)
# q_nope_i: [N, qk_nope_head_dim]
# w_k_nope: [N, qk_nope_head_dim, kv_lora_rank]
# q_absorbed: [N, kv_lora_rank]
q_absorbed = torch.einsum("nd,ndk->nk", q_nope_i, w_k_nope)
# Attention scores from absorbed Q and compressed KV
# Compute in fp32 to match FlashInfer's use_fp16_qk_reduction=False
# q_absorbed: [N, kv_lora_rank], compressed_kv_cached: [seq_len, kv_lora_rank]
# scores_nope: [N, seq_len]
scores_nope = torch.matmul(q_absorbed.float(), compressed_kv_cached.float().t())
# =====================================================================
# Q_pe part - standard attention with kpe
# =====================================================================
# q_pe_i: [N, qk_rope_head_dim], kpe_cached: [seq_len, qk_rope_head_dim]
# scores_pe: [N, seq_len]
scores_pe = torch.matmul(q_pe_i.float(), kpe_cached.float().t())
# Combined attention scores (already in fp32)
attn_scores = (scores_nope + scores_pe) * scale # [N, seq_len]
# Softmax (already in fp32, convert back to input dtype)
attn_weights = torch.softmax(attn_scores, dim=-1).to(q_nope.dtype) # [N, seq_len]
# =====================================================================
# Compute output with absorbed value projection
# =====================================================================
# v_out = attn_weights @ compressed_kv @ w_v^T
# First: weighted_kv = attn_weights @ compressed_kv_cached -> [N, kv_lora_rank]
weighted_kv = torch.matmul(attn_weights, compressed_kv_cached) # [N, kv_lora_rank]
# Then: attn_out = weighted_kv @ w_v^T -> [N, v_head_dim]
# w_v: [N, v_head_dim, kv_lora_rank]
# weighted_kv: [N, kv_lora_rank]
attn_out = torch.einsum("nk,nvk->nv", weighted_kv, w_v) # [N, v_head_dim]
out[i] = attn_out
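# Illustrative sketch (not part of the original op): a numerical check that the weight
# absorption used above matches explicit expansion for the nope score path. All sizes are
# arbitrary example values.
def _weight_absorption_equivalence_sketch():
    n, d_nope, r, t = 4, 8, 16, 5  # heads, qk_nope_head_dim, kv_lora_rank, cached tokens
    q_nope = torch.randn(n, d_nope)
    c_kv = torch.randn(t, r)  # cached compressed KV
    w_k_nope = torch.randn(n, d_nope, r)  # k_nope slice of kv_b_proj_weight
    # Expanded path: materialize K per head, then score Q against it.
    k_nope = torch.einsum("tr,ndr->tnd", c_kv, w_k_nope)  # [t, n, d_nope]
    scores_expanded = torch.einsum("nd,tnd->nt", q_nope, k_nope)  # [n, t]
    # Absorbed path: fold w_k_nope into the query, then score against the latent cache.
    q_absorbed = torch.einsum("nd,ndk->nk", q_nope, w_k_nope)  # [n, r]
    scores_absorbed = q_absorbed @ c_kv.t()  # [n, t]
    assert torch.allclose(scores_expanded, scores_absorbed, rtol=1e-4, atol=1e-4)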
def _torch_mla_context_with_expansion(
q_nope: torch.Tensor, # [total_tokens, N, qk_nope_head_dim]
q_pe: torch.Tensor, # [total_tokens, N, qk_rope_head_dim]
compressed_kv: torch.Tensor, # [total_tokens, kv_lora_rank]
kpe: torch.Tensor, # [total_tokens, 1, qk_rope_head_dim]
kv_b_proj_weight: torch.Tensor, # [N * (qk_nope_head_dim + v_head_dim), kv_lora_rank]
mla_cache: torch.Tensor, # [max_batch, max_seq, kv_lora_rank + qk_rope_head_dim]
input_pos: torch.Tensor,
slot_idx: torch.Tensor,
seq_len: torch.Tensor,
seq_start: torch.Tensor,
scale: float,
kv_lora_rank: int,
num_heads: int,
qk_nope_head_dim: int,
v_head_dim: int,
out: torch.Tensor,
) -> None:
"""Context MLA attention with kv_b_proj expansion.
For prefill, we expand the compressed KV using kv_b_proj_weight and compute standard
attention. This is typically cheaper than absorption here, since the tokens being expanded
are the (relatively short) current context rather than a long decode-time cache.
"""
# Flatten kpe: [total_tokens, 1, qk_rope_head_dim] -> [total_tokens, qk_rope_head_dim]
kpe_flat = kpe.squeeze(1)
# Update cache first with compressed representation
_update_mla_cache(
compressed_kv,
kpe_flat,
mla_cache,
seq_len,
input_pos,
slot_idx,
seq_start,
kv_lora_rank,
)
# Compute attention for each sequence
attn_outputs = []
for idx in range(seq_len.shape[0]):
seq_len_i = seq_len[idx].item()
input_pos_i = input_pos[idx].item()
slot_idx_i = slot_idx[idx].item()
seq_start_i = seq_start[idx].item()
if seq_len_i == 0:
continue
# Get query for this sequence
q_nope_seq = q_nope[
seq_start_i : seq_start_i + seq_len_i
] # [seq_len_i, N, qk_nope_head_dim]
q_pe_seq = q_pe[seq_start_i : seq_start_i + seq_len_i] # [seq_len_i, N, qk_rope_head_dim]
# Get cached data for attention (includes just-added tokens)
kv_seq_len = input_pos_i + seq_len_i
cached_data = mla_cache[
slot_idx_i, :kv_seq_len
] # [kv_seq_len, kv_lora_rank + qk_rope_head_dim]
compressed_kv_cached = cached_data[:, :kv_lora_rank] # [kv_seq_len, kv_lora_rank]
kpe_cached = cached_data[:, kv_lora_rank:] # [kv_seq_len, qk_rope_head_dim]
# =====================================================================
# Expand compressed_kv using kv_b_proj_weight for this sequence
# =====================================================================
# compressed_kv_cached: [kv_seq_len, kv_lora_rank]
# kv_b_proj_weight: [N * (qk_nope_head_dim + v_head_dim), kv_lora_rank]
# kv_expanded: [kv_seq_len, N * (qk_nope_head_dim + v_head_dim)]
kv_expanded = torch.matmul(compressed_kv_cached, kv_b_proj_weight.t())
# Reshape to [kv_seq_len, N, qk_nope_head_dim + v_head_dim]
kv_expanded = kv_expanded.view(kv_seq_len, num_heads, qk_nope_head_dim + v_head_dim)
# Split into k_nope and v
k_nope_expanded = kv_expanded[:, :, :qk_nope_head_dim] # [kv_seq_len, N, qk_nope_head_dim]
v_expanded = kv_expanded[:, :, qk_nope_head_dim:] # [kv_seq_len, N, v_head_dim]
# Expand kpe to all heads
kpe_expanded = kpe_cached.unsqueeze(1).expand(
-1, num_heads, -1
) # [kv_seq_len, N, qk_rope_head_dim]
# Construct full query and key
query_full = torch.cat([q_nope_seq, q_pe_seq], dim=-1) # [seq_len_i, N, qk_head_dim]
key_full = torch.cat(
[k_nope_expanded, kpe_expanded], dim=-1
) # [kv_seq_len, N, qk_head_dim]
# Transpose for attention: [1, N, seq_len, head_dim]
query_t = query_full.transpose(0, 1).unsqueeze(0) # [1, N, seq_len_i, qk_head_dim]
key_t = key_full.transpose(0, 1).unsqueeze(0) # [1, N, kv_seq_len, qk_head_dim]
# Compute attention scores in fp32 to match FlashInfer's use_fp16_qk_reduction=False
# FlashInfer uses fp32 accumulation for QK^T, so we do the same for numerical consistency
attn_scores = (
torch.matmul(query_t.float(), key_t.float().transpose(-2, -1)) * scale
) # [1, N, seq_len_i, kv_seq_len] in fp32
# Apply causal mask
causal_mask = torch.triu(
torch.ones(seq_len_i, kv_seq_len, device=q_nope.device, dtype=torch.bool),
diagonal=kv_seq_len - seq_len_i + 1,
)
attn_scores.masked_fill_(causal_mask.unsqueeze(0).unsqueeze(0), float("-inf"))
# Softmax (already in fp32, convert back to input dtype)
attn_weights = torch.softmax(attn_scores, dim=-1).to(q_nope.dtype)
# Value: [1, N, kv_seq_len, v_head_dim]
v_t = v_expanded.transpose(0, 1).unsqueeze(0)
# Compute output
attn_out = torch.matmul(attn_weights, v_t) # [1, N, seq_len_i, v_head_dim]
attn_out = attn_out[0].transpose(0, 1) # [seq_len_i, N, v_head_dim]
attn_outputs.append(attn_out)
# Concatenate all outputs
if len(attn_outputs) == 0:
out.zero_()
elif len(attn_outputs) == 1:
out.copy_(attn_outputs[0])
else:
out.copy_(torch.cat(attn_outputs, dim=0))
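# Illustrative sketch (not part of the original op): the mask built above with
# `diagonal=kv_seq_len - seq_len_i + 1` lets newly added query tokens attend to every already
# cached position plus their own causal prefix. Example: 2 new tokens appended after 3 cached
# tokens (arbitrary sizes).
def _chunked_causal_mask_sketch():
    seq_len_i, kv_seq_len = 2, 5
    mask = torch.triu(
        torch.ones(seq_len_i, kv_seq_len, dtype=torch.bool),
        diagonal=kv_seq_len - seq_len_i + 1,
    )
    # Row 0 (global position 3) must not see key position 4; row 1 (position 4) sees all keys.
    expected = torch.tensor(
        [
            [False, False, False, False, True],
            [False, False, False, False, False],
        ]
    )
    assert torch.equal(mask, expected)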
@torch.library.custom_op("auto_deploy::torch_cached_mla_with_cache", mutates_args=())
def torch_backend_mla_with_cache(
# 5 tensor args (get_num_qkv_args = 5)
q_nope: torch.Tensor, # [B, S, N, qk_nope_head_dim]
q_pe: torch.Tensor, # [B, S, N, qk_rope_head_dim]
compressed_kv: torch.Tensor, # [B, S, kv_lora_rank]
kpe: torch.Tensor, # [B, S, 1, qk_rope_head_dim]
kv_b_proj_weight: torch.Tensor, # [N * (qk_nope_head_dim + v_head_dim), kv_lora_rank]
# Standard metadata
batch_info_host: torch.Tensor,
seq_len: torch.Tensor,
input_pos: torch.Tensor,
slot_idx: torch.Tensor,
cu_seqlen: torch.Tensor,
# Cache (FlashInfer layout)
mla_cache: torch.Tensor, # [max_batch, max_seq, kv_lora_rank + qk_rope_head_dim]
# Constants
scale: Optional[float] = None,
kv_lora_rank: int = 512,
) -> torch.Tensor:
"""Torch backend MLA with FlashInfer-compatible compressed cache.
FlashInfer MLA Cache Layout:
mla_cache: [max_batch, max_seq, kv_lora_rank + qk_rope_head_dim]
- compressed_kv = mla_cache[:, :, :kv_lora_rank] (zero-copy slice)
- kpe = mla_cache[:, :, kv_lora_rank:] (zero-copy slice)
Prefill (context): Expand compressed_kv, compute normal attention
Generate (decode): Use weight absorption for efficiency
"""
# Get dimensions
b, s = q_nope.shape[:2]
num_heads = q_nope.shape[2]
qk_nope_head_dim = q_nope.shape[3]
qk_rope_head_dim = q_pe.shape[3]
qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
# Infer v_head_dim from kv_b_proj_weight
out_features = kv_b_proj_weight.shape[0]
kv_head_dim = out_features // num_heads
v_head_dim = kv_head_dim - qk_nope_head_dim
# Get cleaned up metadata
num_prefill, num_prefill_tokens, num_decode = batch_info_host.tolist()
num_seq = num_prefill + num_decode
seq_len = seq_len[:num_seq]
input_pos = input_pos[:num_seq]
slot_idx = slot_idx[:num_seq]
seq_start = cu_seqlen[:num_seq]
# Set scale
if scale is None:
scale = 1.0 / math.sqrt(qk_head_dim)
# Define output shape: [B, S, N, v_head_dim]
output_shape = (b, s, num_heads, v_head_dim)
if s == 1:
# =====================================================================
# Generate phase: Use weight absorption
# =====================================================================
y = q_nope.new_empty(b, num_heads, v_head_dim).contiguous()
_torch_mla_generate_with_absorption(
q_nope,
q_pe,
compressed_kv,
kpe,
kv_b_proj_weight,
mla_cache,
slot_idx,
input_pos,
scale,
kv_lora_rank,
num_heads,
qk_nope_head_dim,
v_head_dim,
y,
)
return y.unsqueeze(1) # [B, 1, N, v_head_dim]
else:
# =====================================================================
# Context phase: Expand and compute normal attention
# =====================================================================
bs_view = (b * s,)
q_nope_flat = q_nope.contiguous().view(*bs_view, num_heads, qk_nope_head_dim)
q_pe_flat = q_pe.contiguous().view(*bs_view, num_heads, qk_rope_head_dim)
compressed_kv_flat = compressed_kv.contiguous().view(*bs_view, kv_lora_rank)
kpe_flat = kpe.contiguous().view(*bs_view, 1, qk_rope_head_dim)
y = q_nope.new_empty(*bs_view, num_heads, v_head_dim).contiguous()
_torch_mla_context_with_expansion(
q_nope_flat,
q_pe_flat,
compressed_kv_flat,
kpe_flat,
kv_b_proj_weight,
mla_cache,
input_pos,
slot_idx,
seq_len,
seq_start,
scale,
kv_lora_rank,
num_heads,
qk_nope_head_dim,
v_head_dim,
y,
)
return y.view(*output_shape)
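# Illustrative sketch (not part of the original op): the op above distinguishes decode from
# prefill purely by the query's sequence dimension (S==1 takes the absorption path). Shapes
# use arbitrary example sizes with 20 heads and qk_nope_head_dim=192.
def _mla_phase_dispatch_sketch():
    q_decode = torch.randn(4, 1, 20, 192)  # [B=4, S=1, N, d]: generate phase (absorption)
    q_prefill = torch.randn(1, 37, 20, 192)  # [B=1, S=total_tokens, N, d]: context phase
    assert q_decode.shape[1] == 1 and q_prefill.shape[1] > 1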
@torch_backend_mla_with_cache.register_fake
def torch_backend_mla_with_cache_fake(
q_nope: torch.Tensor,
q_pe: torch.Tensor,
compressed_kv: torch.Tensor,
kpe: torch.Tensor,
kv_b_proj_weight: torch.Tensor,
batch_info_host: torch.Tensor,
seq_len: torch.Tensor,
input_pos: torch.Tensor,
slot_idx: torch.Tensor,
cu_seqlen: torch.Tensor,
mla_cache: torch.Tensor,
scale: Optional[float] = None,
kv_lora_rank: int = 512,
) -> torch.Tensor:
"""Fake implementation for torch_backend_mla_with_cache."""
num_heads = q_nope.shape[2]
qk_nope_head_dim = q_nope.shape[-1]
out_features = kv_b_proj_weight.shape[0]
kv_head_dim = out_features // num_heads
v_head_dim = kv_head_dim - qk_nope_head_dim
return q_nope.new_empty(
q_nope.shape[0], q_nope.shape[1], q_nope.shape[2], v_head_dim
).contiguous()
@AttentionRegistry.register("torch_mla")
class TorchBackendMLAAttention(AttentionDescriptor):
"""Attention descriptor for Multi-head Latent Attention (MLA).
This descriptor uses FlashInfer-compatible compressed cache:
- torch_mla: source op that expands compressed_kv for attention
- torch_cached_mla_with_cache: cached op with absorption for generate
FlashInfer MLA Cache Layout:
mla_cache: [max_batch, max_seq, head_dim_ckv + head_dim_kpe]
- No num_heads dimension (MLA-specific optimization)
- ckv_cached = mla_cache[:, :, :head_dim_ckv] (zero-copy slice)
- kpe_cached = mla_cache[:, :, head_dim_ckv:] (zero-copy slice)
Reference: https://docs.flashinfer.ai/tutorials/kv_layout.html#mla-page-layout
"""
@classmethod
def get_attention_layout(cls) -> AttentionLayout:
"""Get the attention layout expected by the backend."""
return "bsnd"
@classmethod
def get_num_qkv_args(cls) -> int:
"""Get the number of tensor arguments expected by the source op."""
return 5 # q_nope, q_pe, compressed_kv, kpe, kv_b_proj_weight
@classmethod
def get_source_attention_op(cls) -> OpOverloadPacket:
"""Get the source attention op that we target for replacement."""
return torch.ops.auto_deploy.torch_mla
@classmethod
def get_cached_attention_op(cls) -> MHACallable:
"""Get the cached attention op."""
return torch.ops.auto_deploy.torch_cached_mla_with_cache.default
@classmethod
def get_standard_metadata_args(cls) -> List[str]:
"""Get the list of standard metadata arguments."""
return ["batch_info_host", "seq_len", "input_pos", "slot_idx", "cu_seqlen"]
@classmethod
def get_cache_initializers(
cls, source_attn_node: Node, cache_config: KvCacheConfig
) -> ResourceHandlerDict:
"""Get cache initializers using FlashInfer MLA cache layout."""
# Extract dimensions from source node args
# torch_mla signature: q_nope, q_pe, compressed_kv, kpe, kv_b_proj_weight, ...
compressed_kv_fake = source_attn_node.args[2].meta["val"]
kpe_fake = source_attn_node.args[3].meta["val"]
# Get dimensions
# compressed_kv: [B, S, kv_lora_rank]
# kpe: [B, S, 1, qk_rope_head_dim]
kv_lora_rank = compressed_kv_fake.shape[-1]
qk_rope_head_dim = kpe_fake.shape[-1]
# FlashInfer MLA cache: [max_batch, max_seq, kv_lora_rank + qk_rope_head_dim]
# No num_heads dimension - this is the key MLA optimization
return {
"mla_cache": UnpagedResourceHandler(
kv_lora_rank + qk_rope_head_dim,
dtype=cls.resolve_cache_dtype(cache_config.dtype, compressed_kv_fake.dtype),
),
}
@classmethod
def get_constants(cls, source_attn_node: Node) -> List[Constant]:
"""Get constants to pass to the cached attention op."""
# Extract kv_lora_rank for cache slicing
compressed_kv_fake = source_attn_node.args[2].meta["val"]
kv_lora_rank = compressed_kv_fake.shape[-1]
# Get scale from kwargs
scale = source_attn_node.kwargs.get("scale", None)
return [scale, kv_lora_rank]

View File

@@ -0,0 +1,157 @@
"""Torch reference implementation for Multi-head Latent Attention (MLA).
This module provides the source op for MLA that:
- Accepts compressed_kv (before kv_b_proj) for FlashInfer-compatible caching
- Expands compressed_kv using kv_b_proj_weight for attention computation
- Computes standard attention with the expanded K, V
"""
import math
from typing import Optional
import torch
@torch.library.custom_op("auto_deploy::torch_mla", mutates_args=())
def torch_mla(
q_nope: torch.Tensor, # [B, S, N, qk_nope_head_dim]
q_pe: torch.Tensor, # [B, S, N, qk_rope_head_dim] (RoPE applied)
compressed_kv: torch.Tensor, # [B, S, kv_lora_rank] - BEFORE kv_b_proj
kpe: torch.Tensor, # [B, S, 1, qk_rope_head_dim] (RoPE applied)
kv_b_proj_weight: torch.Tensor, # [num_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank]
is_causal: bool = True,
scale: Optional[float] = None,
layout: str = "bsnd",
) -> torch.Tensor:
"""Multi-head Latent Attention (MLA) with FlashInfer-compatible compressed KV.
This op expands compressed_kv using kv_b_proj_weight and computes attention.
For prefill, this is the standard formulation. For the cached version,
weight absorption is used for efficiency.
Args:
q_nope: Query non-positional component [B, S, N, qk_nope_head_dim] (bsnd)
q_pe: Query positional component with RoPE applied [B, S, N, qk_rope_head_dim] (bsnd)
compressed_kv: Compressed KV latent [B, S, kv_lora_rank] (before kv_b_proj)
kpe: Key positional encoding with RoPE applied [B, S, 1, qk_rope_head_dim] (bsnd)
kv_b_proj_weight: Projection weights [num_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank]
is_causal: Whether to apply causal masking (default: True)
scale: Softmax scale factor (default: 1/sqrt(qk_head_dim))
layout: Input/output layout, either "bsnd" or "bnsd" (default: "bsnd")
Returns:
Attention output with shape [B, S, N, v_head_dim] (bsnd)
"""
if layout not in ("bnsd", "bsnd"):
raise ValueError(f"layout must be 'bnsd' or 'bsnd', got {layout!r}")
# Get dimensions
if layout == "bsnd":
bs, s_q, num_heads, qk_nope_head_dim = q_nope.shape
qk_rope_head_dim = q_pe.shape[-1]
else:
bs, num_heads, s_q, qk_nope_head_dim = q_nope.shape
qk_rope_head_dim = q_pe.shape[-1]
s_k = compressed_kv.shape[1]
# Infer dimensions from kv_b_proj_weight
# kv_b_proj_weight: [num_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank]
out_features = kv_b_proj_weight.shape[0]
kv_head_dim = out_features // num_heads # qk_nope_head_dim + v_head_dim
v_head_dim = kv_head_dim - qk_nope_head_dim
# Set scale
qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
if scale is None:
scale = 1.0 / math.sqrt(qk_head_dim)
# =========================================================================
# Expand compressed_kv using kv_b_proj_weight (this is the prefill path)
# =========================================================================
# compressed_kv: [B, S, kv_lora_rank]
# kv_b_proj_weight: [num_heads * kv_head_dim, kv_lora_rank]
# kv = compressed_kv @ kv_b_proj_weight.T -> [B, S, num_heads * kv_head_dim]
kv = torch.matmul(compressed_kv, kv_b_proj_weight.t())
# Reshape to [B, S, N, kv_head_dim]
kv = kv.view(bs, s_k, num_heads, kv_head_dim)
# Split into k_nope and value_states
k_nope, value_states = torch.split(kv, [qk_nope_head_dim, v_head_dim], dim=-1)
# k_nope and value_states are always [B, S, N, D] from the kv reshape above.
# We need them in [B, N, S, D] for attention computation.
k_nope = k_nope.transpose(1, 2).contiguous()
value_states = value_states.transpose(1, 2).contiguous()
# Convert inputs to computation layout [B, N, S, D] if they come in bsnd format
if layout == "bsnd":
# [B, S, N, D] -> [B, N, S, D]
q_nope = q_nope.transpose(1, 2).contiguous()
q_pe = q_pe.transpose(1, 2).contiguous()
kpe = kpe.transpose(1, 2).contiguous()
# kpe is [B, 1, S, qk_rope_head_dim], expand to num_heads
kpe_expanded = kpe.expand(bs, num_heads, s_k, qk_rope_head_dim)
# Construct full query and key states
# query_states: [B, N, S, qk_head_dim]
query_states = torch.cat([q_nope, q_pe], dim=-1)
# key_states: [B, N, S, qk_head_dim]
key_states = torch.cat([k_nope, kpe_expanded], dim=-1)
# Compute attention scores: Q @ K^T
attn_scores = (
torch.matmul(query_states, key_states.transpose(-2, -1)) * scale
) # [B, N, s_q, s_k]
# Apply causal mask if specified
if is_causal and s_q == s_k:
causal_mask = torch.triu(
torch.ones(s_q, s_k, device=q_nope.device, dtype=torch.bool),
diagonal=1,
)
attn_scores.masked_fill_(causal_mask.unsqueeze(0).unsqueeze(0), float("-inf"))
# Compute attention weights and output
attn_weights = torch.softmax(attn_scores, dim=-1, dtype=torch.float32).to(q_nope.dtype)
attn_out = torch.matmul(attn_weights, value_states) # [B, N, s_q, v_head_dim]
# Convert back to requested layout
if layout == "bsnd":
return attn_out.transpose(1, 2).contiguous() # [B, S, N, v_head_dim]
else:
return attn_out.contiguous() # [B, N, S, v_head_dim]
@torch_mla.register_fake
def torch_mla_fake(
q_nope: torch.Tensor,
q_pe: torch.Tensor,
compressed_kv: torch.Tensor,
kpe: torch.Tensor,
kv_b_proj_weight: torch.Tensor,
is_causal: bool = True,
scale: Optional[float] = None,
layout: str = "bsnd",
) -> torch.Tensor:
"""Fake implementation for torch_mla."""
# Infer v_head_dim from kv_b_proj_weight
qk_nope_head_dim = q_nope.shape[-1]
num_heads = q_nope.shape[2] if layout == "bsnd" else q_nope.shape[1]
out_features = kv_b_proj_weight.shape[0]
kv_head_dim = out_features // num_heads
v_head_dim = kv_head_dim - qk_nope_head_dim
# Output shape depends on layout
if layout == "bsnd":
# Input: [B, S, N, D], Output: [B, S, N, v_head_dim]
return q_nope.new_empty(
q_nope.shape[0], q_nope.shape[1], q_nope.shape[2], v_head_dim
).contiguous()
else:
# Input: [B, N, S, D], Output: [B, N, S, v_head_dim]
return q_nope.new_empty(
q_nope.shape[0], q_nope.shape[1], q_nope.shape[2], v_head_dim
).contiguous()
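# Illustrative sketch (not part of the original file): expected call shapes for the source op,
# using the MLA dimensions from the Glm4MoeLiteConfig bundled in this commit (20 heads,
# qk_nope_head_dim=192, qk_rope_head_dim=64, kv_lora_rank=512, v_head_dim=256). Batch and
# sequence sizes are arbitrary example values.
def _torch_mla_call_sketch():
    b, s, n = 1, 8, 20
    q_nope = torch.randn(b, s, n, 192)
    q_pe = torch.randn(b, s, n, 64)
    compressed_kv = torch.randn(b, s, 512)
    kpe = torch.randn(b, s, 1, 64)
    kv_b_proj_weight = torch.randn(n * (192 + 256), 512)
    out = torch.ops.auto_deploy.torch_mla(
        q_nope, q_pe, compressed_kv, kpe, kv_b_proj_weight,
        is_causal=True, scale=None, layout="bsnd",
    )
    assert out.shape == (b, s, n, 256)  # [B, S, N, v_head_dim]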

View File

@@ -1,9 +1,11 @@
from .modeling_eagle import Eagle3DrafterForCausalLM
from .modeling_deepseek import DeepSeekV3ForCausalLM
from .modeling_glm4_moe_lite import Glm4MoeLiteForCausalLM
from .modeling_nemotron_flash import NemotronFlashForCausalLM, NemotronFlashPreTrainedTokenizerFast
from .modeling_nemotron_h import NemotronHForCausalLM
__all__ = (
"Eagle3DrafterForCausalLM",
"DeepSeekV3ForCausalLM",
"Glm4MoeLiteForCausalLM",
"NemotronFlashForCausalLM",
"NemotronFlashPreTrainedTokenizerFast",
"NemotronHForCausalLM",

View File

@@ -0,0 +1,655 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
"""Slimmed down PyTorch DeepSeekV3 model implementation for auto_deploy export.
Source:
https://huggingface.co/deepseek-ai/DeepSeek-R1/blob/main/modeling_deepseek.py
This implementation differs from the original in the following ways:
* Simplified for prefill-only inference (no KV caching)
* Uses auto_deploy custom ops for export compatibility
* Removed flash attention variants (uses torch_mla custom op)
* Removed gradient checkpointing and training code paths
* Removed attention dropout (inference only)
This allows us to have a clean export-ready implementation with auto_deploy custom ops.
"""
import math
from dataclasses import dataclass
from typing import Optional, Tuple
import torch
import torch.nn.functional as F
from torch import nn
from transformers.activations import ACT2FN
from transformers.generation import GenerationMixin
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import ModelOutput
from tensorrt_llm._torch.auto_deploy.models.hf import AutoModelForCausalLMFactory
from tensorrt_llm._torch.utils import ActivationType
class DeepSeekV3RMSNorm(nn.Module):
"""RMS Normalization for DeepSeekV3."""
def __init__(self, hidden_size: int, eps: float = 1e-6):
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
return torch.ops.auto_deploy.triton_rms_norm(
hidden_states, self.weight, self.variance_epsilon
).to(hidden_states.dtype)
class DeepSeekV3RotaryEmbedding(nn.Module):
"""Rotary Position Embedding for DeepSeekV3.
Simplified version that precomputes and caches cos/sin values.
Returns full cached values (not sliced by seq_len) to enable export.
"""
def __init__(self, dim: int, max_position_embeddings: int = 2048, base: float = 10000.0):
super().__init__()
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float() / self.dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
# Build cos/sin cache
self._set_cos_sin_cache(max_position_embeddings)
def _set_cos_sin_cache(self, seq_len: int):
self.max_seq_len_cached = seq_len
t = torch.arange(seq_len, dtype=self.inv_freq.dtype)
freqs = torch.outer(t, self.inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("cos_cached", emb.cos(), persistent=False)
self.register_buffer("sin_cached", emb.sin(), persistent=False)
def forward(
self, x: torch.Tensor, seq_len: Optional[int] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
# Return full cached cos/sin (not sliced) for export compatibility
return (
self.cos_cached.to(dtype=x.dtype, device=x.device),
self.sin_cached.to(dtype=x.dtype, device=x.device),
)
class DeepSeekV3YarnRotaryEmbedding(DeepSeekV3RotaryEmbedding):
"""YaRN-extended rotary embedding for DeepSeekV3."""
def __init__(
self,
dim: int,
max_position_embeddings: int = 2048,
base: float = 10000.0,
scaling_factor: float = 1.0,
original_max_position_embeddings: int = 4096,
beta_fast: int = 32,
beta_slow: int = 1,
mscale: float = 1.0,
mscale_all_dim: float = 0.0,
):
self.scaling_factor = scaling_factor
self.original_max_position_embeddings = original_max_position_embeddings
self.beta_fast = beta_fast
self.beta_slow = beta_slow
self.mscale = mscale
self.mscale_all_dim = mscale_all_dim
super().__init__(dim, max_position_embeddings, base)
def _set_cos_sin_cache(self, seq_len: int):
self.max_seq_len_cached = seq_len
dim = self.dim
freq_extra = 1.0 / (self.base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
freq_inter = 1.0 / (
self.scaling_factor * self.base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)
)
low, high = self._yarn_find_correction_range(
self.beta_fast,
self.beta_slow,
dim,
self.base,
self.original_max_position_embeddings,
)
inv_freq_mask = 1.0 - self._yarn_linear_ramp_mask(low, high, dim // 2)
inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask
self.register_buffer("inv_freq", inv_freq, persistent=False)
t = torch.arange(seq_len, dtype=torch.float32)
freqs = torch.outer(t, inv_freq)
_mscale = float(
self._yarn_get_mscale(self.scaling_factor, self.mscale)
/ self._yarn_get_mscale(self.scaling_factor, self.mscale_all_dim)
)
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("cos_cached", (emb.cos() * _mscale), persistent=False)
self.register_buffer("sin_cached", (emb.sin() * _mscale), persistent=False)
@staticmethod
def _yarn_find_correction_dim(
num_rotations: float, dim: int, base: float = 10000, max_position_embeddings: int = 2048
) -> float:
return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (
2 * math.log(base)
)
def _yarn_find_correction_range(
self, low_rot: int, high_rot: int, dim: int, base: float, max_position_embeddings: int
) -> Tuple[int, int]:
low = math.floor(
self._yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)
)
high = math.ceil(
self._yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings)
)
return max(low, 0), min(high, dim - 1)
@staticmethod
def _yarn_get_mscale(scale: float = 1.0, mscale: float = 1.0) -> float:
if scale <= 1:
return 1.0
return 0.1 * mscale * math.log(scale) + 1.0
@staticmethod
def _yarn_linear_ramp_mask(min_val: float, max_val: float, dim: int) -> torch.Tensor:
if min_val == max_val:
max_val += 0.001
linear_func = (torch.arange(dim, dtype=torch.float32) - min_val) / (max_val - min_val)
return torch.clamp(linear_func, 0, 1)
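# Illustrative sketch (not part of the original file): the YaRN mscale correction that
# DeepSeekV3Attention later applies (squared) to its softmax scale when rope_scaling provides
# a non-zero mscale_all_dim. The scaling_factor, mscale_all_dim, and head dims below are
# arbitrary example values.
def _yarn_softmax_scale_sketch():
    scaling_factor, mscale_all_dim = 40.0, 1.0
    qk_nope_head_dim, qk_rope_head_dim = 128, 64
    mscale = DeepSeekV3YarnRotaryEmbedding._yarn_get_mscale(scaling_factor, mscale_all_dim)
    base_scale = (qk_nope_head_dim + qk_rope_head_dim) ** -0.5  # 1 / sqrt(q_head_dim)
    corrected_scale = base_scale * mscale * mscale
    return mscale, corrected_scale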
class DeepSeekV3MLP(nn.Module):
"""MLP layer for DeepSeekV3 (SwiGLU activation)."""
def __init__(
self, config, hidden_size: Optional[int] = None, intermediate_size: Optional[int] = None
):
super().__init__()
self.config = config
self.hidden_size = hidden_size or config.hidden_size
self.intermediate_size = intermediate_size or config.intermediate_size
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
self.act_fn = ACT2FN[config.hidden_act]
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
class DeepSeekV3MoEGate(nn.Module):
"""MoE Gating for DeepSeekV3 with noaux_tc top-k selection."""
def __init__(self, config):
super().__init__()
self.config = config
self.top_k = config.num_experts_per_tok
self.n_routed_experts = config.n_routed_experts
self.routed_scaling_factor = config.routed_scaling_factor
self.n_group = config.n_group
self.topk_group = config.topk_group
self.weight = nn.Parameter(
torch.empty((self.n_routed_experts, config.hidden_size), dtype=torch.float32)
)
self.register_buffer(
"e_score_correction_bias",
torch.zeros(self.n_routed_experts, dtype=torch.float32),
)
self.reset_parameters()
def reset_parameters(self) -> None:
"""Initialize gate weights using kaiming uniform (matches original DeepSeek implementation)."""
torch.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""Forward pass returning (selected_experts, routing_weights)."""
bsz, seq_len, hidden_dim = hidden_states.shape
hidden_states_flat = hidden_states.view(-1, hidden_dim)
# Compute router logits
if self.weight.dtype == torch.float32:
router_logits = F.linear(hidden_states_flat.float(), self.weight)
else:
router_logits = torch.ops.trtllm.dsv3_router_gemm_op(
hidden_states_flat, self.weight.t(), bias=None, out_dtype=torch.float32
)
# Use fused noaux_tc_op kernel for top-k selection
topk_weights, topk_indices = torch.ops.trtllm.noaux_tc_op(
router_logits,
self.e_score_correction_bias,
self.n_group,
self.topk_group,
self.top_k,
self.routed_scaling_factor,
)
return topk_indices, topk_weights
class DeepSeekV3MoE(nn.Module):
"""Mixture of Experts layer for DeepSeekV3."""
def __init__(self, config):
super().__init__()
self.config = config
self.num_experts_per_tok = config.num_experts_per_tok
# Routed experts
self.experts = nn.ModuleList(
[
DeepSeekV3MLP(config, intermediate_size=config.moe_intermediate_size)
for _ in range(config.n_routed_experts)
]
)
# Gate
self.gate = DeepSeekV3MoEGate(config)
# Shared experts (if configured)
if config.n_shared_experts is not None:
intermediate_size = config.moe_intermediate_size * config.n_shared_experts
self.shared_experts = DeepSeekV3MLP(config, intermediate_size=intermediate_size)
else:
self.shared_experts = None
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
identity = hidden_states
orig_shape = hidden_states.shape
selected_experts, routing_weights = self.gate(hidden_states)
# Use torch_moe custom op for routed experts
final_hidden_states = torch.ops.auto_deploy.torch_moe(
hidden_states.view(-1, hidden_states.shape[-1]),
selected_experts,
routing_weights,
w1_weight=[expert.gate_proj.weight for expert in self.experts],
w2_weight=[expert.down_proj.weight for expert in self.experts],
w3_weight=[expert.up_proj.weight for expert in self.experts],
is_gated_mlp=True,
act_fn=int(ActivationType.Silu),
)
final_hidden_states = final_hidden_states.view(*orig_shape)
# Add shared experts output if present
if self.shared_experts is not None:
final_hidden_states = final_hidden_states + self.shared_experts(identity)
return final_hidden_states.to(hidden_states.dtype)
class DeepSeekV3Attention(nn.Module):
"""Multi-head Latent Attention (MLA) for DeepSeekV3.
Uses compressed KV representation with latent projections.
"""
def __init__(self, config, layer_idx: Optional[int] = None):
super().__init__()
self.config = config
self.layer_idx = layer_idx
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.q_lora_rank = config.q_lora_rank
self.kv_lora_rank = config.kv_lora_rank
self.qk_nope_head_dim = config.qk_nope_head_dim
self.qk_rope_head_dim = config.qk_rope_head_dim
self.v_head_dim = config.v_head_dim
self.q_head_dim = self.qk_nope_head_dim + self.qk_rope_head_dim
self.max_position_embeddings = config.max_position_embeddings
self.rope_theta = config.rope_theta
# Q projection (with optional LoRA)
if self.q_lora_rank is None:
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.q_head_dim, bias=False)
else:
self.q_a_proj = nn.Linear(
self.hidden_size, config.q_lora_rank, bias=config.attention_bias
)
self.q_a_layernorm = DeepSeekV3RMSNorm(config.q_lora_rank)
self.q_b_proj = nn.Linear(
config.q_lora_rank, self.num_heads * self.q_head_dim, bias=False
)
# KV projection with MQA
self.kv_a_proj_with_mqa = nn.Linear(
self.hidden_size,
self.kv_lora_rank + self.qk_rope_head_dim,
bias=config.attention_bias,
)
self.kv_a_layernorm = DeepSeekV3RMSNorm(self.kv_lora_rank)
self.kv_b_proj = nn.Linear(
self.kv_lora_rank,
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
bias=False,
)
# Output projection
self.o_proj = nn.Linear(
self.num_heads * self.v_head_dim, self.hidden_size, bias=config.attention_bias
)
# Initialize rotary embedding
self._init_rope()
# Softmax scale
self.softmax_scale = self.q_head_dim ** (-0.5)
if config.rope_scaling is not None:
mscale_all_dim = config.rope_scaling.get("mscale_all_dim", 0)
scaling_factor = config.rope_scaling["factor"]
if mscale_all_dim:
mscale = DeepSeekV3YarnRotaryEmbedding._yarn_get_mscale(
scaling_factor, mscale_all_dim
)
self.softmax_scale = self.softmax_scale * mscale * mscale
def _init_rope(self):
if self.config.rope_scaling is None:
self.rotary_emb = DeepSeekV3RotaryEmbedding(
self.qk_rope_head_dim,
max_position_embeddings=self.max_position_embeddings,
base=self.rope_theta,
)
else:
scaling_type = self.config.rope_scaling["type"]
scaling_factor = self.config.rope_scaling["factor"]
if scaling_type == "yarn":
kwargs = {
key: self.config.rope_scaling[key]
for key in [
"original_max_position_embeddings",
"beta_fast",
"beta_slow",
"mscale",
"mscale_all_dim",
]
if key in self.config.rope_scaling
}
self.rotary_emb = DeepSeekV3YarnRotaryEmbedding(
self.qk_rope_head_dim,
max_position_embeddings=self.max_position_embeddings,
scaling_factor=scaling_factor,
base=self.rope_theta,
**kwargs,
)
else:
# Default to base rotary embedding for unsupported types
self.rotary_emb = DeepSeekV3RotaryEmbedding(
self.qk_rope_head_dim,
max_position_embeddings=self.max_position_embeddings,
base=self.rope_theta,
)
def forward(
self,
hidden_states: torch.Tensor,
position_ids: torch.Tensor,
) -> torch.Tensor:
bsz, q_len, _ = hidden_states.size()
# Q projection
if self.q_lora_rank is None:
q = self.q_proj(hidden_states)
else:
q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
# Shape: [B, S, N, q_head_dim] (BSND layout)
q = q.view(bsz, q_len, self.num_heads, self.q_head_dim)
q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
# KV projection - keep compressed form
kv_a_output = self.kv_a_proj_with_mqa(hidden_states)
compressed_kv, k_pe = torch.split(
kv_a_output, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
)
# Apply layernorm to compressed_kv
compressed_kv = self.kv_a_layernorm(compressed_kv)
# k_pe: [B, S, 1, qk_rope_head_dim] (BSND layout, shared across heads)
k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim)
kv_seq_len = q_len
# Get cos/sin for RoPE
cos, sin = self.rotary_emb(hidden_states, seq_len=kv_seq_len)
cos = cos[position_ids] # [B, S, head_dim]
sin = sin[position_ids] # [B, S, head_dim]
# Apply RoPE using custom op
q_pe_rotated, kpe = torch.ops.auto_deploy.torch_rope_with_qk_interleaving(
q_pe,
k_pe,
cos,
sin,
2, # unsqueeze_dim=2 for BSND layout
)
# Call MLA with compressed KV
attn_output = torch.ops.auto_deploy.torch_mla(
q_nope, # [B, S, N, qk_nope_head_dim]
q_pe_rotated, # [B, S, N, qk_rope_head_dim]
compressed_kv, # [B, S, kv_lora_rank]
kpe, # [B, S, 1, qk_rope_head_dim]
self.kv_b_proj.weight, # [N*(qk_nope+v), kv_lora_rank]
True, # is_causal
self.softmax_scale,
"bsnd", # layout
)
# Output: [B, S, N, v_head_dim] -> [B, S, N * v_head_dim]
attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)
attn_output = self.o_proj(attn_output)
return attn_output
class DeepSeekV3DecoderLayer(nn.Module):
"""Transformer decoder layer for DeepSeekV3."""
def __init__(self, config, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
self.layer_idx = layer_idx
# Attention
self.self_attn = DeepSeekV3Attention(config, layer_idx=layer_idx)
# MLP or MoE
# MoE layers are used after first_k_dense_replace and at moe_layer_freq intervals
use_moe = (
config.n_routed_experts is not None
and layer_idx >= config.first_k_dense_replace
and layer_idx % config.moe_layer_freq == 0
)
if use_moe:
self.mlp = DeepSeekV3MoE(config)
else:
self.mlp = DeepSeekV3MLP(config)
# Layer norms
self.input_layernorm = DeepSeekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = DeepSeekV3RMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
def forward(
self,
hidden_states: torch.Tensor,
position_ids: torch.Tensor,
) -> torch.Tensor:
# Self attention
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
hidden_states = self.self_attn(hidden_states, position_ids)
hidden_states = residual + hidden_states
# MLP/MoE
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
@dataclass
class DeepSeekV3Output(ModelOutput):
"""Output for DeepSeekV3Model."""
last_hidden_state: Optional[torch.FloatTensor] = None
@dataclass
class DeepSeekV3CausalLMOutput(ModelOutput):
"""Output for DeepSeekV3ForCausalLM."""
logits: Optional[torch.FloatTensor] = None
class DeepSeekV3PreTrainedModel(PreTrainedModel):
"""Base class for DeepSeekV3 models."""
base_model_prefix = "model"
_no_split_modules = ["DeepSeekV3DecoderLayer"]
supports_gradient_checkpointing = False
def _init_weights(self, module):
std = self.config.initializer_range
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
class DeepSeekV3Model(DeepSeekV3PreTrainedModel):
"""DeepSeekV3 transformer decoder model."""
def __init__(self, config):
super().__init__(config)
self.config = config
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
self.layers = nn.ModuleList(
[
DeepSeekV3DecoderLayer(config, layer_idx=idx)
for idx in range(config.num_hidden_layers)
]
)
self.norm = DeepSeekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_init()
def get_input_embeddings(self):
return self.embed_tokens
def set_input_embeddings(self, value):
self.embed_tokens = value
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
**kwargs,
) -> DeepSeekV3Output:
if input_ids is not None and inputs_embeds is not None:
raise ValueError("Cannot specify both input_ids and inputs_embeds")
elif input_ids is None and inputs_embeds is None:
raise ValueError("Must specify either input_ids or inputs_embeds")
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
batch_size, seq_length = inputs_embeds.shape[:2]
if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(seq_length, dtype=torch.long, device=device)
position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)
hidden_states = inputs_embeds
for decoder_layer in self.layers:
hidden_states = decoder_layer(hidden_states, position_ids)
hidden_states = self.norm(hidden_states)
return DeepSeekV3Output(last_hidden_state=hidden_states)
class DeepSeekV3ForCausalLM(DeepSeekV3PreTrainedModel, GenerationMixin):
"""DeepSeekV3 model with language modeling head."""
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config):
super().__init__(config)
self.model = DeepSeekV3Model(config)
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.post_init()
def get_input_embeddings(self):
return self.model.embed_tokens
def set_input_embeddings(self, value):
self.model.embed_tokens = value
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def get_decoder(self):
return self.model
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
**kwargs,
) -> DeepSeekV3CausalLMOutput:
outputs = self.model(
input_ids=input_ids,
position_ids=position_ids,
inputs_embeds=inputs_embeds,
**kwargs,
)
hidden_states = outputs.last_hidden_state
logits = self.lm_head(hidden_states).float()
return DeepSeekV3CausalLMOutput(logits=logits)
# Register with AutoModelForCausalLMFactory
AutoModelForCausalLMFactory.register_custom_model_cls("DeepseekV3Config", DeepSeekV3ForCausalLM)

View File

@@ -0,0 +1,830 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
"""Slimmed down PyTorch GLM4 MoE Lite model implementation for auto_deploy export.
Source:
https://huggingface.co/zai-org/GLM-4.7-Flash
This implementation differs from the original HuggingFace version in the following ways:
* Bundled config class to work with transformers v4.57 (model requires v5.0)
* Simplified for prefill-only inference (no KV caching)
* Uses auto_deploy custom ops for export compatibility
* Removed flash attention variants (uses torch_mla custom op)
* Removed gradient checkpointing and training code paths
* Removed attention dropout (inference only)
The GLM4 MoE Lite model uses Multi-head Latent Attention (MLA), similar to DeepSeek V3.
"""
import math
from dataclasses import dataclass
from typing import Optional, Tuple
import torch
import torch.nn.functional as F
from torch import nn
from transformers import AutoConfig, PretrainedConfig
from transformers.activations import ACT2FN
from transformers.generation import GenerationMixin
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import ModelOutput
from tensorrt_llm._torch.auto_deploy.models.hf import AutoModelForCausalLMFactory
from tensorrt_llm._torch.utils import ActivationType
class Glm4MoeLiteConfig(PretrainedConfig):
"""Configuration class for GLM4 MoE Lite model.
This config class is bundled with the custom model implementation to enable
loading on transformers v4.57 (the model requires v5.0 where the config is
natively registered).
"""
model_type = "glm4_moe_lite"
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
vocab_size: int = 154880,
hidden_size: int = 2048,
intermediate_size: int = 10240,
num_hidden_layers: int = 47,
num_attention_heads: int = 20,
num_key_value_heads: int = 20,
hidden_act: str = "silu",
max_position_embeddings: int = 202752,
rms_norm_eps: float = 1e-5,
# MLA parameters
q_lora_rank: int = 768,
kv_lora_rank: int = 512,
qk_nope_head_dim: int = 192,
qk_rope_head_dim: int = 64,
v_head_dim: int = 256,
# MoE parameters
n_routed_experts: int = 64,
n_shared_experts: int = 1,
num_experts_per_tok: int = 4,
moe_intermediate_size: int = 1536,
n_group: int = 1,
topk_group: int = 1,
routed_scaling_factor: float = 1.8,
norm_topk_prob: bool = True,
first_k_dense_replace: int = 1,
# RoPE parameters
rope_theta: float = 1000000.0,
rope_scaling: Optional[dict] = None,
# Other parameters
attention_bias: bool = False,
attention_dropout: float = 0.0,
tie_word_embeddings: bool = False,
pad_token_id: int = 154820,
initializer_range: float = 0.02,
**kwargs,
):
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.max_position_embeddings = max_position_embeddings
self.rms_norm_eps = rms_norm_eps
# MLA
self.q_lora_rank = q_lora_rank
self.kv_lora_rank = kv_lora_rank
self.qk_nope_head_dim = qk_nope_head_dim
self.qk_rope_head_dim = qk_rope_head_dim
self.v_head_dim = v_head_dim
self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
# MoE
self.n_routed_experts = n_routed_experts
self.n_shared_experts = n_shared_experts
self.num_experts_per_tok = num_experts_per_tok
self.moe_intermediate_size = moe_intermediate_size
self.n_group = n_group
self.topk_group = topk_group
self.routed_scaling_factor = routed_scaling_factor
self.norm_topk_prob = norm_topk_prob
self.first_k_dense_replace = first_k_dense_replace
# RoPE
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
# Other
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
self.initializer_range = initializer_range
super().__init__(
pad_token_id=pad_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
# Register config with transformers' AutoConfig so it can be loaded from HF hub
# Use exist_ok=True to handle cases where transformers already has this model type registered
# (e.g., transformers v5.0+). In those cases, AutoConfig will use the built-in config,
# but AutoModelForCausalLMFactory will still use our custom model implementation.
try:
AutoConfig.register("glm4_moe_lite", Glm4MoeLiteConfig, exist_ok=True)
except TypeError:
# Older transformers versions don't support exist_ok parameter
try:
AutoConfig.register("glm4_moe_lite", Glm4MoeLiteConfig)
except ValueError:
# Already registered by transformers, that's fine
pass
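Once this module is imported, the registration above lets the standard `AutoConfig` factory resolve the model type by name. A minimal sketch (illustrative only; on transformers v5.0+ the built-in config is returned instead, as the comment above notes):
```python
from transformers import AutoConfig

# Assumes this module has been imported so that the registration above has run.
cfg = AutoConfig.for_model("glm4_moe_lite", num_hidden_layers=2)
assert cfg.model_type == "glm4_moe_lite"
```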
class Glm4MoeLiteRMSNorm(nn.Module):
"""RMS Normalization for GLM4 MoE Lite.
Uses standard torch operations so that AD fusion passes can replace it with
the appropriate backend (flashinfer/triton) based on the config.
"""
def __init__(self, hidden_size: int, eps: float = 1e-6):
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
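As a quick sanity check, the module above should agree with PyTorch's built-in RMS norm in float32; a sketch assuming PyTorch >= 2.4, where `F.rms_norm` is available:
```python
import torch
import torch.nn.functional as F

norm = Glm4MoeLiteRMSNorm(64)
x = torch.randn(2, 8, 64)  # float32, so the internal upcast is a no-op
ref = F.rms_norm(x, normalized_shape=(64,), weight=norm.weight, eps=norm.variance_epsilon)
assert torch.allclose(norm(x), ref, atol=1e-5)
```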
class Glm4MoeLiteRotaryEmbedding(nn.Module):
"""Rotary Position Embedding for GLM4 MoE Lite.
Simplified version that precomputes and caches cos/sin values.
Returns full cached values (not sliced by seq_len) to enable export.
Uses _ad_ prefix for buffer names to work with AutoDeploy's lift_to_meta.
"""
def __init__(
self,
dim: int,
max_position_embeddings: int = 2048,
base: float = 10000.0,
attention_scaling: float = 1.0,
):
super().__init__()
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
self.attention_scaling = attention_scaling
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float() / self.dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
# Build cos/sin cache with AD-specific naming
self._set_cos_sin_cache(max_position_embeddings)
def _set_cos_sin_cache(self, seq_len: int):
self.max_seq_len_cached = seq_len
t = torch.arange(seq_len, dtype=self.inv_freq.dtype)
freqs = torch.outer(t, self.inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
# Use _ad_ prefix for AutoDeploy compatibility with lift_to_meta
self.register_buffer("_ad_cos_cached", emb.cos() * self.attention_scaling, persistent=False)
self.register_buffer("_ad_sin_cached", emb.sin() * self.attention_scaling, persistent=False)
def forward(
self, x: torch.Tensor, seq_len: Optional[int] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
# Return full cached cos/sin (not sliced) for export compatibility
return (
self._ad_cos_cached.to(dtype=x.dtype, device=x.device),
self._ad_sin_cached.to(dtype=x.dtype, device=x.device),
)
class Glm4MoeLiteYarnRotaryEmbedding(Glm4MoeLiteRotaryEmbedding):
"""YaRN-extended rotary embedding for GLM4 MoE Lite."""
def __init__(
self,
dim: int,
max_position_embeddings: int = 2048,
base: float = 10000.0,
scaling_factor: float = 1.0,
original_max_position_embeddings: int = 4096,
beta_fast: int = 32,
beta_slow: int = 1,
mscale: float = 1.0,
mscale_all_dim: float = 0.0,
attention_scaling: float = 1.0,
):
self.scaling_factor = scaling_factor
self.original_max_position_embeddings = original_max_position_embeddings
self.beta_fast = beta_fast
self.beta_slow = beta_slow
self.mscale = mscale
self.mscale_all_dim = mscale_all_dim
super().__init__(dim, max_position_embeddings, base, attention_scaling)
def _set_cos_sin_cache(self, seq_len: int):
self.max_seq_len_cached = seq_len
dim = self.dim
freq_extra = 1.0 / (self.base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
freq_inter = 1.0 / (
self.scaling_factor * self.base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)
)
low, high = self._yarn_find_correction_range(
self.beta_fast,
self.beta_slow,
dim,
self.base,
self.original_max_position_embeddings,
)
inv_freq_mask = 1.0 - self._yarn_linear_ramp_mask(low, high, dim // 2)
inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask
self.register_buffer("inv_freq", inv_freq, persistent=False)
t = torch.arange(seq_len, dtype=torch.float32)
freqs = torch.outer(t, inv_freq)
_mscale = float(
self._yarn_get_mscale(self.scaling_factor, self.mscale)
/ self._yarn_get_mscale(self.scaling_factor, self.mscale_all_dim)
)
emb = torch.cat((freqs, freqs), dim=-1)
# Use _ad_ prefix for AutoDeploy compatibility with lift_to_meta
# Note: attention_scaling is already incorporated in _mscale for YaRN
self.register_buffer("_ad_cos_cached", (emb.cos() * _mscale), persistent=False)
self.register_buffer("_ad_sin_cached", (emb.sin() * _mscale), persistent=False)
@staticmethod
def _yarn_find_correction_dim(
num_rotations: float, dim: int, base: float = 10000, max_position_embeddings: int = 2048
) -> float:
return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (
2 * math.log(base)
)
def _yarn_find_correction_range(
self, low_rot: int, high_rot: int, dim: int, base: float, max_position_embeddings: int
) -> Tuple[int, int]:
low = math.floor(
self._yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)
)
high = math.ceil(
self._yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings)
)
return max(low, 0), min(high, dim - 1)
@staticmethod
def _yarn_get_mscale(scale: float = 1.0, mscale: float = 1.0) -> float:
if scale <= 1:
return 1.0
return 0.1 * mscale * math.log(scale) + 1.0
@staticmethod
def _yarn_linear_ramp_mask(min_val: float, max_val: float, dim: int) -> torch.Tensor:
if min_val == max_val:
max_val += 0.001
linear_func = (torch.arange(dim, dtype=torch.float32) - min_val) / (max_val - min_val)
return torch.clamp(linear_func, 0, 1)
class Glm4MoeLiteMLP(nn.Module):
"""MLP layer for GLM4 MoE Lite (SwiGLU activation)."""
def __init__(
self, config, hidden_size: Optional[int] = None, intermediate_size: Optional[int] = None
):
super().__init__()
self.config = config
self.hidden_size = hidden_size or config.hidden_size
self.intermediate_size = intermediate_size or config.intermediate_size
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
self.act_fn = ACT2FN[config.hidden_act]
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
class Glm4MoeLiteMoEGate(nn.Module):
"""MoE Gating for GLM4 MoE Lite with top-k selection.
Uses fused TensorRT-LLM custom ops for efficient routing:
- dsv3_router_gemm_op: Fused router GEMM for non-float32 weights
- noaux_tc_op: Fused sigmoid + bias + group top-k + normalize + scale
"""
def __init__(self, config):
super().__init__()
self.config = config
self.top_k = config.num_experts_per_tok
self.n_routed_experts = config.n_routed_experts
self.routed_scaling_factor = config.routed_scaling_factor
self.n_group = config.n_group
self.topk_group = config.topk_group
self.norm_topk_prob = getattr(config, "norm_topk_prob", True)
# noaux_tc_op always normalizes, so norm_topk_prob must be True
if not self.norm_topk_prob:
raise ValueError(
"Glm4MoeLiteMoEGate requires norm_topk_prob=True when using fused ops. "
"The noaux_tc_op kernel always normalizes routing weights."
)
self.weight = nn.Parameter(
torch.empty((self.n_routed_experts, config.hidden_size), dtype=torch.float32)
)
self.register_buffer(
"e_score_correction_bias",
torch.zeros(self.n_routed_experts, dtype=torch.float32),
)
self.reset_parameters()
def reset_parameters(self) -> None:
"""Initialize gate weights using kaiming uniform."""
torch.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""Forward pass returning (selected_experts, routing_weights).
Uses fused TensorRT-LLM ops for efficient routing:
1. dsv3_router_gemm_op: Router GEMM (when weights are not float32)
2. noaux_tc_op: Fused sigmoid + bias + group top-k + normalize + scale
"""
bsz, seq_len, hidden_dim = hidden_states.shape
hidden_states_flat = hidden_states.view(-1, hidden_dim)
# Router GEMM - use fused op when weights are not float32
if self.weight.dtype == torch.float32:
router_logits = F.linear(hidden_states_flat.float(), self.weight)
else:
router_logits = torch.ops.trtllm.dsv3_router_gemm_op(
hidden_states_flat, self.weight.t(), bias=None, out_dtype=torch.float32
)
# Fused routing: sigmoid + bias + group top-k + normalize + scale
# noaux_tc_op internally applies:
# 1. Sigmoid to router_logits
# 2. Adds e_score_correction_bias
# 3. Group-wise top-2 scoring and top group selection
# 4. Top-k expert selection from selected groups
# 5. Gathers weights from sigmoid scores
# 6. Normalizes and scales by routed_scaling_factor
topk_weights, topk_indices = torch.ops.trtllm.noaux_tc_op(
router_logits,
self.e_score_correction_bias,
self.n_group,
self.topk_group,
self.top_k,
self.routed_scaling_factor,
)
return topk_indices, topk_weights
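For readers without the fused kernels at hand, the routing that the two ops above implement can be written out in plain PyTorch. This is a sketch of the DeepSeek-V3-style no-aux routing described in steps 1-6 of the comments, with semantics assumed from those comments rather than from the kernel source:
```python
import torch

def reference_noaux_routing(router_logits, bias, n_group, topk_group, top_k, routed_scaling_factor):
    """Plain-PyTorch sketch of the fused routing (steps 1-6 above)."""
    scores = torch.sigmoid(router_logits)                      # [T, E]
    scores_biased = scores + bias                              # bias is used only for selection
    T, E = scores.shape
    # Score each group by the sum of its top-2 biased scores and keep the best topk_group groups.
    group_scores = scores_biased.view(T, n_group, E // n_group).topk(2, dim=-1).values.sum(-1)
    group_idx = group_scores.topk(topk_group, dim=-1).indices
    group_mask = torch.zeros_like(group_scores).scatter_(1, group_idx, 1.0)
    expert_mask = group_mask.unsqueeze(-1).expand(T, n_group, E // n_group).reshape(T, E)
    # Pick top_k experts inside the selected groups; gather weights from the *unbiased* scores.
    masked = scores_biased.masked_fill(expert_mask == 0, float("-inf"))
    topk_idx = masked.topk(top_k, dim=-1).indices
    topk_w = scores.gather(1, topk_idx)
    topk_w = topk_w / topk_w.sum(dim=-1, keepdim=True)         # normalize (norm_topk_prob=True)
    return topk_idx, topk_w * routed_scaling_factor            # scale
```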
class Glm4MoeLiteMoE(nn.Module):
"""Mixture of Experts layer for GLM4 MoE Lite."""
def __init__(self, config):
super().__init__()
self.config = config
self.num_experts_per_tok = config.num_experts_per_tok
# Routed experts - use ModuleList with individual expert modules
# This creates state_dict keys like: experts.0.gate_proj.weight
# which matches the checkpoint structure
self.experts = nn.ModuleList(
[
Glm4MoeLiteMLP(config, intermediate_size=config.moe_intermediate_size)
for _ in range(config.n_routed_experts)
]
)
# Gate
self.gate = Glm4MoeLiteMoEGate(config)
# Shared experts
if config.n_shared_experts is not None and config.n_shared_experts > 0:
intermediate_size = config.moe_intermediate_size * config.n_shared_experts
self.shared_experts = Glm4MoeLiteMLP(config, intermediate_size=intermediate_size)
else:
self.shared_experts = None
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
identity = hidden_states
orig_shape = hidden_states.shape
selected_experts, routing_weights = self.gate(hidden_states)
# Use torch_moe custom op for routed experts
final_hidden_states = torch.ops.auto_deploy.torch_moe(
hidden_states.view(-1, hidden_states.shape[-1]),
selected_experts,
routing_weights,
w1_weight=[expert.gate_proj.weight for expert in self.experts],
w2_weight=[expert.down_proj.weight for expert in self.experts],
w3_weight=[expert.up_proj.weight for expert in self.experts],
is_gated_mlp=True,
act_fn=int(ActivationType.Silu),
)
final_hidden_states = final_hidden_states.view(*orig_shape)
# Add shared experts output if present
if self.shared_experts is not None:
final_hidden_states = final_hidden_states + self.shared_experts(identity)
return final_hidden_states.to(hidden_states.dtype)
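The `torch_moe` call above is the fused path; per token it amounts to a routing-weighted sum of gated SwiGLU experts. A naive per-token sketch of that computation, with argument semantics assumed to follow the expert MLP defined above (not the op's actual implementation):
```python
import torch
import torch.nn.functional as F

def moe_reference(x, topk_idx, topk_w, w1, w2, w3):
    """x: [T, H]; topk_idx/topk_w: [T, K]; w1/w3: per-expert [I, H]; w2: per-expert [H, I].
    All tensors are assumed to share a dtype."""
    out = torch.zeros_like(x)
    for t in range(x.shape[0]):
        for k in range(topk_idx.shape[1]):
            e = int(topk_idx[t, k])
            h = F.silu(x[t] @ w1[e].T) * (x[t] @ w3[e].T)  # gated SwiGLU expert MLP
            out[t] += topk_w[t, k] * (h @ w2[e].T)
    return out
```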
class Glm4MoeLiteAttention(nn.Module):
"""Multi-head Latent Attention (MLA) for GLM4 MoE Lite.
Uses compressed KV representation with latent projections.
Receives position embeddings from the model level (shared rotary embedding).
"""
def __init__(self, config, layer_idx: Optional[int] = None):
super().__init__()
self.config = config
self.layer_idx = layer_idx
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.q_lora_rank = config.q_lora_rank
self.kv_lora_rank = config.kv_lora_rank
self.qk_nope_head_dim = config.qk_nope_head_dim
self.qk_rope_head_dim = config.qk_rope_head_dim
self.v_head_dim = config.v_head_dim
self.qk_head_dim = self.qk_nope_head_dim + self.qk_rope_head_dim
# Q projection (with optional LoRA)
if self.q_lora_rank is None:
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.qk_head_dim, bias=False)
else:
self.q_a_proj = nn.Linear(
self.hidden_size, config.q_lora_rank, bias=config.attention_bias
)
self.q_a_layernorm = Glm4MoeLiteRMSNorm(config.q_lora_rank)
self.q_b_proj = nn.Linear(
config.q_lora_rank, self.num_heads * self.qk_head_dim, bias=False
)
# KV projection with MQA
self.kv_a_proj_with_mqa = nn.Linear(
self.hidden_size,
self.kv_lora_rank + self.qk_rope_head_dim,
bias=config.attention_bias,
)
self.kv_a_layernorm = Glm4MoeLiteRMSNorm(self.kv_lora_rank)
self.kv_b_proj = nn.Linear(
self.kv_lora_rank,
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
bias=False,
)
# Output projection
self.o_proj = nn.Linear(
self.num_heads * self.v_head_dim, self.hidden_size, bias=config.attention_bias
)
# Softmax scale
self.softmax_scale = self.qk_head_dim ** (-0.5)
# Apply mscale adjustment if using YaRN scaling with factor
if (
config.rope_scaling is not None
and isinstance(config.rope_scaling, dict)
and "factor" in config.rope_scaling
):
mscale_all_dim = config.rope_scaling.get("mscale_all_dim", 0)
scaling_factor = config.rope_scaling["factor"]
if mscale_all_dim:
mscale = Glm4MoeLiteYarnRotaryEmbedding._yarn_get_mscale(
scaling_factor, mscale_all_dim
)
self.softmax_scale = self.softmax_scale * mscale * mscale
def forward(
self,
hidden_states: torch.Tensor,
position_ids: torch.Tensor,
position_embeddings: Tuple[torch.Tensor, torch.Tensor],
) -> torch.Tensor:
bsz, q_len, _ = hidden_states.size()
# Q projection
if self.q_lora_rank is None:
q = self.q_proj(hidden_states)
else:
q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
# Shape: [B, S, N, qk_head_dim] (BSND layout)
q = q.view(bsz, q_len, self.num_heads, self.qk_head_dim)
q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
# KV projection - keep compressed form
kv_a_output = self.kv_a_proj_with_mqa(hidden_states)
compressed_kv, k_pe = torch.split(
kv_a_output, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
)
# Apply layernorm to compressed_kv
compressed_kv = self.kv_a_layernorm(compressed_kv)
# k_pe: [B, S, 1, qk_rope_head_dim] (BSND layout, shared across heads)
k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim)
# Get cos/sin from position_embeddings (full cached from shared rotary embedding)
cos, sin = position_embeddings # Full table: [max_seq_len, head_dim]
cos = cos[position_ids] # [B, S, head_dim]
sin = sin[position_ids] # [B, S, head_dim]
# Apply RoPE using custom op
q_pe_rotated, kpe = torch.ops.auto_deploy.torch_rope_with_qk_interleaving(
q_pe,
k_pe,
cos,
sin,
2, # unsqueeze_dim=2 for BSND layout
)
# Call MLA with compressed KV
attn_output = torch.ops.auto_deploy.torch_mla(
q_nope, # [B, S, N, qk_nope_head_dim]
q_pe_rotated, # [B, S, N, qk_rope_head_dim]
compressed_kv, # [B, S, kv_lora_rank]
kpe, # [B, S, 1, qk_rope_head_dim]
self.kv_b_proj.weight, # [N*(qk_nope+v), kv_lora_rank]
True, # is_causal
self.softmax_scale,
"bsnd", # layout
)
# Output: [B, S, N, v_head_dim] -> [B, S, N * v_head_dim]
attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)
attn_output = self.o_proj(attn_output)
return attn_output
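For orientation, the prefill-time math behind the `torch_mla` call above can be spelled out with plain SDPA by first expanding the compressed KV through `kv_b_proj` (this mirrors the numpy reference used by the MLA backend tests later in this diff; it is a sketch, not the custom op itself):
```python
import torch
import torch.nn.functional as F

def mla_prefill_reference(q_nope, q_pe, compressed_kv, kpe, kv_b_proj_weight, scale):
    """q_nope: [B,S,N,dn], q_pe: [B,S,N,dr], compressed_kv: [B,S,r], kpe: [B,S,1,dr]."""
    B, S, N, dn = q_nope.shape
    dr = q_pe.shape[-1]
    kv = (compressed_kv @ kv_b_proj_weight.T).view(B, S, N, -1)    # expand latent KV per head
    k_nope, v = kv.split([dn, kv.shape[-1] - dn], dim=-1)
    q = torch.cat([q_nope, q_pe], dim=-1).transpose(1, 2)          # [B, N, S, dn + dr]
    k = torch.cat([k_nope, kpe.expand(B, S, N, dr)], dim=-1).transpose(1, 2)
    # The scale kwarg requires PyTorch >= 2.1.
    out = F.scaled_dot_product_attention(q, k, v.transpose(1, 2), is_causal=True, scale=scale)
    return out.transpose(1, 2)                                     # [B, S, N, v_head_dim]
```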
class Glm4MoeLiteDecoderLayer(nn.Module):
"""Transformer decoder layer for GLM4 MoE Lite."""
def __init__(self, config, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
self.layer_idx = layer_idx
# Attention
self.self_attn = Glm4MoeLiteAttention(config, layer_idx=layer_idx)
# MLP or MoE
# Layers 0 through first_k_dense_replace-1 use a dense MLP; the rest use MoE
use_moe = config.n_routed_experts is not None and layer_idx >= config.first_k_dense_replace
if use_moe:
self.mlp = Glm4MoeLiteMoE(config)
else:
self.mlp = Glm4MoeLiteMLP(config)
# Layer norms
self.input_layernorm = Glm4MoeLiteRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = Glm4MoeLiteRMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
def forward(
self,
hidden_states: torch.Tensor,
position_ids: torch.Tensor,
position_embeddings: Tuple[torch.Tensor, torch.Tensor],
) -> torch.Tensor:
# Self attention
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
hidden_states = self.self_attn(hidden_states, position_ids, position_embeddings)
hidden_states = residual + hidden_states
# MLP/MoE
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
@dataclass
class Glm4MoeLiteOutput(ModelOutput):
"""Output for Glm4MoeLiteModel."""
last_hidden_state: Optional[torch.FloatTensor] = None
@dataclass
class Glm4MoeLiteCausalLMOutput(ModelOutput):
"""Output for Glm4MoeLiteForCausalLM."""
logits: Optional[torch.FloatTensor] = None
class Glm4MoeLitePreTrainedModel(PreTrainedModel):
"""Base class for GLM4 MoE Lite models."""
config_class = Glm4MoeLiteConfig
base_model_prefix = "model"
_no_split_modules = ["Glm4MoeLiteDecoderLayer"]
supports_gradient_checkpointing = False
def _init_weights(self, module):
std = self.config.initializer_range
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
class Glm4MoeLiteModel(Glm4MoeLitePreTrainedModel):
"""GLM4 MoE Lite transformer decoder model."""
def __init__(self, config):
super().__init__(config)
self.config = config
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
self.layers = nn.ModuleList(
[
Glm4MoeLiteDecoderLayer(config, layer_idx=idx)
for idx in range(config.num_hidden_layers)
]
)
self.norm = Glm4MoeLiteRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
# Shared rotary embedding at model level (not per-layer)
# This creates a single set of cos/sin buffers for all layers
self.rotary_emb = self._init_rope(config)
self.post_init()
def _init_rope(self, config):
"""Initialize shared rotary embedding for all layers."""
qk_rope_head_dim = config.qk_rope_head_dim
# Compute attention_scaling for RoPE (same logic as in attention)
attention_scaling = 1.0
if (
config.rope_scaling is not None
and isinstance(config.rope_scaling, dict)
and "factor" in config.rope_scaling
):
mscale_all_dim = config.rope_scaling.get("mscale_all_dim", 0)
scaling_factor = config.rope_scaling["factor"]
if mscale_all_dim:
mscale = Glm4MoeLiteYarnRotaryEmbedding._yarn_get_mscale(
scaling_factor, mscale_all_dim
)
attention_scaling = mscale
# Check if rope_scaling is None, empty, or missing required "factor" key
use_yarn = (
config.rope_scaling is not None
and isinstance(config.rope_scaling, dict)
and "factor" in config.rope_scaling
)
if not use_yarn:
return Glm4MoeLiteRotaryEmbedding(
qk_rope_head_dim,
max_position_embeddings=config.max_position_embeddings,
base=config.rope_theta,
attention_scaling=attention_scaling,
)
else:
scaling_factor = config.rope_scaling["factor"]
kwargs = {
key: config.rope_scaling[key]
for key in [
"original_max_position_embeddings",
"beta_fast",
"beta_slow",
"mscale",
"mscale_all_dim",
]
if key in config.rope_scaling
}
return Glm4MoeLiteYarnRotaryEmbedding(
qk_rope_head_dim,
max_position_embeddings=config.max_position_embeddings,
scaling_factor=scaling_factor,
base=config.rope_theta,
attention_scaling=attention_scaling,
**kwargs,
)
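For reference, a `rope_scaling` dict of the shape this branch expects; the keys come from the list above, the values are purely illustrative, and the defaults live in `Glm4MoeLiteYarnRotaryEmbedding.__init__`:
```python
rope_scaling = {
    "factor": 4.0,                             # required; presence of "factor" selects YaRN
    "original_max_position_embeddings": 4096,  # the remaining keys are optional overrides
    "beta_fast": 32,
    "beta_slow": 1,
    "mscale": 1.0,
    "mscale_all_dim": 0.0,
}
```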
def get_input_embeddings(self):
return self.embed_tokens
def set_input_embeddings(self, value):
self.embed_tokens = value
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
**kwargs,
) -> Glm4MoeLiteOutput:
if input_ids is not None and inputs_embeds is not None:
raise ValueError("Cannot specify both input_ids and inputs_embeds")
elif input_ids is None and inputs_embeds is None:
raise ValueError("Must specify either input_ids or inputs_embeds")
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
batch_size, seq_length = inputs_embeds.shape[:2]
if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(seq_length, dtype=torch.long, device=device)
position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)
# Compute position embeddings once from shared rotary embedding
# This returns full cached cos/sin tables
position_embeddings = self.rotary_emb(inputs_embeds)
hidden_states = inputs_embeds
for decoder_layer in self.layers:
hidden_states = decoder_layer(hidden_states, position_ids, position_embeddings)
hidden_states = self.norm(hidden_states)
return Glm4MoeLiteOutput(last_hidden_state=hidden_states)
class Glm4MoeLiteForCausalLM(Glm4MoeLitePreTrainedModel, GenerationMixin):
"""GLM4 MoE Lite model with language modeling head."""
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config, **kwargs):
super().__init__(config)
self.model = Glm4MoeLiteModel(config)
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.post_init()
def get_input_embeddings(self):
return self.model.embed_tokens
def set_input_embeddings(self, value):
self.model.embed_tokens = value
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def get_decoder(self):
return self.model
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
**kwargs,
) -> Glm4MoeLiteCausalLMOutput:
outputs = self.model(
input_ids=input_ids,
position_ids=position_ids,
inputs_embeds=inputs_embeds,
**kwargs,
)
hidden_states = outputs.last_hidden_state
logits = self.lm_head(hidden_states).float()
return Glm4MoeLiteCausalLMOutput(logits=logits)
# Register with AutoModelForCausalLMFactory
AutoModelForCausalLMFactory.register_custom_model_cls("Glm4MoeLiteConfig", Glm4MoeLiteForCausalLM)

View File

@ -1,185 +0,0 @@
import types
import warnings
from typing import Dict, Optional
import torch
import torch.utils.checkpoint
from transformers import AutoModelForCausalLM
from transformers.cache_utils import Cache
def deepseek_v3_attention(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.IntTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
**kwargs,
):
"""DeepSeekV3Attention forward function rewritten to wrap MultiheadLatentAttention as a custom op."""
if "padding_mask" in kwargs:
warnings.warn(
"Passing `padding_mask` is deprecated and will be removed in v4.37. "
"Please make sure use `attention_mask` instead.`"
)
bsz, q_len, _ = hidden_states.size()
# If else paths are determined by config.json
if self.q_lora_rank is None:
q = self.q_proj(hidden_states)
else:
q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
compressed_kv, k_pe = torch.split(
compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
)
k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)
kv = (
self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
.view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
.transpose(1, 2)
)
_, value_states = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
kv_seq_len = value_states.shape[-2]
if past_key_value is not None:
raise ValueError("past_key_value is not supported")
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
# Use a custom op to capture MLA. This does not handle the KV cache,
# because passing a transformers Cache into a custom op throws an error.
# That is acceptable, since we intend to replace the MLA op with our own implementation further along the pipeline.
attn_output = torch.ops.auto_deploy.torch_attention_deepseek_fused_mla(
q_nope,
q_pe,
kv,
k_pe,
cos,
sin,
position_ids,
attention_mask,
self.softmax_scale,
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)
attn_output = self.o_proj(attn_output)
return attn_output, None, past_key_value
# This patched module matches exactly with HF generate
@torch.inference_mode()
def deepseek_v3_moe_exact(self, hidden_states):
"""DeepSeekV3MoE forward function rewritten to enable torch export.
This custom implementation matches exactly with the deepseek implementation. There are
some errors in the output tensors when the index_add based implementation is used, leading
to some mismatch in the outputs for some prompts. This ensures exact match between HF output
without custom patch and with custom patch.
"""
identity = hidden_states
batch_size, sequence_length, hidden_dim = hidden_states.shape
selected_experts, routing_weights, *_ = self.gate(hidden_states)
hidden_states = hidden_states.view(-1, hidden_dim)
idxs = torch.argsort(selected_experts.view(-1), stable=True)
expert_mask = torch.nn.functional.one_hot(
selected_experts, num_classes=self.experts_per_rank
).permute(2, 1, 0)
outputs = []
for expert_idx in range(len(self.experts)):
expert_layer = self.experts[expert_idx]
_, top_x = torch.where(expert_mask[expert_idx])
# Sort the top_xs and idx
sorted, _ = torch.sort(top_x)
tokens_for_this_expert = hidden_states[None, sorted].reshape(-1, hidden_dim)
expert_out = expert_layer(tokens_for_this_expert)
outputs.append(expert_out)
outs = torch.cat(outputs, dim=0)
# Wrap torch.zeros() in a custom op to fix meta device issue during inference.
new_x = torch.zeros(
(*selected_experts.view(-1).shape, hidden_dim),
device=selected_experts.device,
dtype=outs.dtype,
)
new_x[idxs] = outs
final_hidden_states = (
new_x.view(*selected_experts.shape, -1)
.type(routing_weights.dtype)
.mul_(routing_weights.unsqueeze(-1))
.sum(dim=1)
.type(new_x.dtype)
)
final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
if self.config.n_shared_experts is not None:
final_hidden_states = final_hidden_states + self.shared_experts(identity)
return final_hidden_states.to(hidden_states.dtype)
@torch.inference_mode()
def deepseek_v3_moe(self, hidden_states):
"""DeepSeekV3MoE forward function rewritten in Mixtral style to enable torch export."""
selected_experts, routing_weights, *_ = self.gate(hidden_states)
final_hidden_states = torch.ops.auto_deploy.torch_moe(
hidden_states,
selected_experts,
routing_weights,
w1_weight=[expert.gate_proj.weight for expert in self.experts],
w2_weight=[expert.down_proj.weight for expert in self.experts],
w3_weight=[expert.up_proj.weight for expert in self.experts],
)
if self.config.n_shared_experts is not None:
final_hidden_states = final_hidden_states + self.shared_experts(hidden_states)
return final_hidden_states.to(hidden_states.dtype)
def deepseek_v3_rope(self, x, seq_len=None):
"""DeepSeekV3 Rotary Embedding forward function rewritten to enable torch export.
We return the full cached cos and sin values, instead of slicing them based on seq_len as this
would cause an issue during the generate phase (when seq_len=1 from input_ids). We also move the cos
sin buffers to appropriate device to enable export.
"""
return (
self.cos_cached.to(dtype=x.dtype).to(device=x.device),
self.sin_cached.to(dtype=x.dtype).to(device=x.device),
)
_from_config_original = AutoModelForCausalLM.from_config
CUSTOM_MODULE_PATCHES: Dict[str, callable] = {
"DeepseekV3MoE": deepseek_v3_moe,
"DeepseekV2MoE": deepseek_v3_moe,
"DeepseekV3RotaryEmbedding": deepseek_v3_rope,
"DeepseekV3YarnRotaryEmbedding": deepseek_v3_rope,
"DeepseekV2RotaryEmbedding": deepseek_v3_rope,
"DeepseekV2YarnRotaryEmbedding": deepseek_v3_rope,
}
def get_model_from_config_patched(config, **kwargs):
model = _from_config_original(config, **kwargs)
# Patch modules
for _, module in model.named_modules():
if type(module).__name__ in CUSTOM_MODULE_PATCHES.keys():
module.forward = types.MethodType(CUSTOM_MODULE_PATCHES[type(module).__name__], module)
return model
# TODO: figure out how this can be incorporated into the export patch system
AutoModelForCausalLM.from_config = get_model_from_config_patched

View File

@ -83,7 +83,7 @@ class QuantConfigReaderRegistry:
@QuantConfigReaderRegistry.register("modelopt")
class ModelOPTQuantConfigReader(QuantConfigReader):
_ALWAYS_EXCLUDE = ("lm_head", "model.embed_tokens", "*.mixer.gate*")
_ALWAYS_EXCLUDE = ("lm_head", "model.embed_tokens", "*.mixer.gate*", "*.mlp.gate")
def read_config(self, config: Dict) -> Dict:
producer = config.get("producer", {}).get("name")

View File

@ -411,3 +411,75 @@ class TestNemotronSuperV3(LlmapiAccuracyTestHarness):
task.evaluate(llm, sampling_params=sampling_params)
task = GSM8K(self.MODEL_NAME)
task.evaluate(llm)
class TestGLM4Flash(LlmapiAccuracyTestHarness):
"""Accuracy regression tests for GLM-4.7-Flash.
TODO: enable in CI, see https://github.com/NVIDIA/TensorRT-LLM/issues/11117
In the meantime, you should run this test locally:
```
cd tests/integration/defs
TRTLLM_ACCURACY_NO_REFERENCE=1 pytest -svv "accuracy/test_llm_api_autodeploy.py::TestGLM4Flash::test_auto_dtype[True]"
```
"""
MODEL_NAME = "zai-org/GLM-4.7-Flash"
MODEL_PATH = MODEL_NAME # Model is in HF_CACHE
# Set minimum possible seq len + small buffer, for test speed & memory usage
MAX_SEQ_LEN = max(MMLU.MAX_INPUT_LEN + MMLU.MAX_OUTPUT_LEN,
GSM8K.MAX_INPUT_LEN + GSM8K.MAX_OUTPUT_LEN)
MAX_NUM_TOKENS = MAX_SEQ_LEN
def get_default_kwargs(self, enable_chunked_prefill=False):
config = {
"skip_tokenizer_init": False,
"trust_remote_code": True,
"compile_backend": "torch-cudagraph",
"max_batch_size": 128,
"max_seq_len": self.MAX_SEQ_LEN,
"max_num_tokens": self.MAX_NUM_TOKENS,
"skip_loading_weights": False,
"disable_overlap_scheduler": False,
"cuda_graph_batch_sizes": [1, 2, 4, 8, 16, 32, 64, 128],
"kv_cache_config": {
"enable_block_reuse": False,
"free_gpu_memory_fraction": 0.88
},
"model_kwargs": {
"torch_dtype": "bfloat16"
},
"transforms": {
"fuse_nvfp4_moe": {
"allow_different_input_scales": True,
},
},
}
if enable_chunked_prefill:
config["enable_chunked_prefill"] = True
config[
"max_num_tokens"] = 512 # NOTE: must be > max(tokens_per_block, max_batch_size)
return config
def get_default_sampling_params(self):
eos_id = -1
beam_width = 1
return SamplingParams(end_id=eos_id,
pad_id=eos_id,
n=beam_width,
use_beam_search=beam_width > 1)
@pytest.mark.skip_less_device_memory(32000)
@pytest.mark.parametrize("enable_chunked_prefill", [True, False])
def test_auto_dtype(self, enable_chunked_prefill):
kwargs = self.get_default_kwargs(enable_chunked_prefill)
sampling_params = self.get_default_sampling_params()
with AutoDeployLLM(model=self.MODEL_PATH,
tokenizer=self.MODEL_PATH,
**kwargs) as llm:
task = MMLU(self.MODEL_NAME)
task.evaluate(llm, sampling_params=sampling_params)
task = GSM8K(self.MODEL_NAME)
task.evaluate(llm)

View File

@ -400,7 +400,8 @@ _SMALL_MODEL_CONFIGS = {
"num_hidden_layers": 2,
"hidden_size": 32,
"intermediate_size": 64,
"kv_lora_rank": 128,
"kv_lora_rank": 512, # NOTE: must be 512 (default) for flashinfer_mla to work
"qk_rope_head_dim": 64, # NOTE: must be 64 (default) for flashinfer_mla to work
"moe_intermediate_size": 128,
"n_group": 2,
"topk_group": 2,

View File

@ -0,0 +1,780 @@
"""Comprehensive test suite for torch MLA backend operations.
Tests the torch_mla source op and torch_backend_mla_with_cache cached op
with FlashInfer-compatible compressed cache layout.
Key features:
- 5 tensor arguments: q_nope, q_pe, compressed_kv, kpe, kv_b_proj_weight
- Compressed cache: [max_batch, max_seq, kv_lora_rank + qk_rope_head_dim]
- Prefill: Expand compressed_kv, compute normal attention
- Generate: Weight absorption for efficiency
"""
import math
import numpy as np
import pytest
import torch
import tensorrt_llm._torch.auto_deploy # noqa: F401
def numpy_mla_reference_with_expansion(
q_nope: np.ndarray,
q_pe: np.ndarray,
compressed_kv: np.ndarray,
kpe: np.ndarray,
kv_b_proj_weight: np.ndarray,
mla_cache: np.ndarray,
seq_len: np.ndarray,
input_pos: np.ndarray,
cache_loc: np.ndarray,
seq_start: np.ndarray,
scale: float = None,
kv_lora_rank: int = None,
is_generate: bool = False,
):
"""Numpy reference implementation of MLA attention with FlashInfer cache layout.
This expands compressed_kv using kv_b_proj_weight for attention computation.
"""
# Get dimensions
if is_generate:
batch_size = q_nope.shape[0]
num_heads = q_nope.shape[2]
qk_nope_head_dim = q_nope.shape[3]
qk_rope_head_dim = q_pe.shape[3]
else:
batch_size = len(seq_len)
num_heads = q_nope.shape[2]
qk_nope_head_dim = q_nope.shape[3]
qk_rope_head_dim = q_pe.shape[3]
qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
if kv_lora_rank is None:
kv_lora_rank = compressed_kv.shape[-1]
# Infer v_head_dim from kv_b_proj_weight
out_features = kv_b_proj_weight.shape[0]
kv_head_dim = out_features // num_heads
v_head_dim = kv_head_dim - qk_nope_head_dim
if scale is None:
scale = 1.0 / math.sqrt(qk_head_dim)
# Update MLA cache first
if is_generate:
for i in range(batch_size):
cache_idx = cache_loc[i]
pos = input_pos[i]
mla_cache[cache_idx, pos, :kv_lora_rank] = compressed_kv[i, 0]
mla_cache[cache_idx, pos, kv_lora_rank:] = kpe[i, 0, 0]
else:
for i in range(batch_size):
cache_idx = cache_loc[i]
pos = input_pos[i]
seq_len_i = seq_len[i]
seq_start_i = seq_start[i]
for j in range(seq_len_i):
mla_cache[cache_idx, pos + j, :kv_lora_rank] = compressed_kv[seq_start_i + j]
mla_cache[cache_idx, pos + j, kv_lora_rank:] = kpe[seq_start_i + j, 0]
# Compute attention for each sequence
outputs = []
for i in range(batch_size):
cache_idx = cache_loc[i]
pos = input_pos[i]
seq_len_i = seq_len[i]
seq_start_i = seq_start[i]
if seq_len_i == 0:
continue
# Get query for this sequence
if is_generate:
q_nope_seq = q_nope[i, 0] # [N, qk_nope_head_dim]
q_pe_seq = q_pe[i, 0] # [N, qk_rope_head_dim]
else:
q_nope_seq = q_nope[seq_start_i : seq_start_i + seq_len_i]
q_pe_seq = q_pe[seq_start_i : seq_start_i + seq_len_i]
# Get cached compressed_kv and kpe
kv_seq_len = pos + seq_len_i
cached_data = mla_cache[cache_idx, :kv_seq_len]
compressed_kv_cached = cached_data[:, :kv_lora_rank]
kpe_cached = cached_data[:, kv_lora_rank:]
# Expand compressed_kv using kv_b_proj_weight
# compressed_kv_cached: [kv_seq_len, kv_lora_rank]
# kv_b_proj_weight: [N * kv_head_dim, kv_lora_rank]
kv_expanded = np.matmul(compressed_kv_cached, kv_b_proj_weight.T)
kv_expanded = kv_expanded.reshape(kv_seq_len, num_heads, kv_head_dim)
k_nope = kv_expanded[:, :, :qk_nope_head_dim]
v = kv_expanded[:, :, qk_nope_head_dim:]
# Expand kpe to all heads
kpe_expanded = np.broadcast_to(
kpe_cached[:, None, :], (kv_seq_len, num_heads, qk_rope_head_dim)
)
# Construct full query and key
query_full = np.concatenate([q_nope_seq, q_pe_seq], axis=-1)  # identical in prefill and generate
key_full = np.concatenate([k_nope, kpe_expanded], axis=-1)
# Compute attention scores
if is_generate:
attn_scores = np.einsum("nh,knh->nk", query_full, key_full) * scale
else:
attn_scores = np.einsum("snh,knh->snk", query_full, key_full) * scale
causal_mask = np.triu(np.ones((seq_len_i, kv_seq_len)), k=kv_seq_len - seq_len_i + 1)
attn_scores = np.where(causal_mask[:, None, :], -np.inf, attn_scores)
# Apply softmax
attn_scores_max = np.max(attn_scores, axis=-1, keepdims=True)
attn_scores_exp = np.exp(attn_scores - attn_scores_max)
attn_weights = attn_scores_exp / np.sum(attn_scores_exp, axis=-1, keepdims=True)
# Compute output
if is_generate:
attn_out = np.einsum("nk,knh->nh", attn_weights, v)
else:
attn_out = np.einsum("snk,knh->snh", attn_weights, v)
outputs.append(attn_out)
# Concatenate outputs
if len(outputs) == 0:
return np.zeros((1, 0, num_heads, v_head_dim), dtype=np.float32)
elif is_generate:
result = np.stack(outputs, axis=0)
return result[:, None, :, :]
else:
result = np.concatenate(outputs, axis=0)
return result[None, :, :, :]
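The reference above always expands the compressed cache through `kv_b_proj_weight`; the module docstring instead mentions weight absorption for the generate path. The two are equivalent because the expansion can be folded into the query once per step, as the following self-contained check illustrates (shapes are illustrative, not taken from the op):
```python
import torch

N, dn, r = 4, 32, 64                                 # heads, qk_nope_head_dim, kv_lora_rank
W_uk = torch.randn(N, dn, r, dtype=torch.float64)    # per-head k_nope slice of kv_b_proj
q_nope = torch.randn(N, dn, dtype=torch.float64)
c = torch.randn(r, dtype=torch.float64)              # one cached compressed_kv vector

# Naive: expand the latent vector into k_nope, then dot it with the query.
k_nope = torch.einsum("ndr,r->nd", W_uk, c)
score_naive = torch.einsum("nd,nd->n", q_nope, k_nope)

# Absorbed: fold W_uk into the query once and score directly against the compressed cache.
q_absorbed = torch.einsum("nd,ndr->nr", q_nope, W_uk)
score_absorbed = torch.einsum("nr,r->n", q_absorbed, c)

assert torch.allclose(score_naive, score_absorbed)
```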
class TestTorchMLASourceOp:
"""Test torch_mla source op (without cache)."""
@pytest.fixture(autouse=True)
def setup_method(self):
"""Setup test configuration."""
self.device = "cuda"
self.dtype = torch.bfloat16
self.atol = 1e-2
self.rtol = 1e-2
torch.cuda.empty_cache()
torch.manual_seed(42)
np.random.seed(42)
def _create_mla_data(
self,
batch_size: int,
seq_len: int,
num_heads: int,
qk_nope_head_dim: int,
qk_rope_head_dim: int,
kv_lora_rank: int,
v_head_dim: int,
layout: str = "bsnd",
):
"""Create test data for MLA source op with compressed_kv."""
kv_head_dim = qk_nope_head_dim + v_head_dim
if layout == "bsnd":
q_nope = torch.randn(
batch_size,
seq_len,
num_heads,
qk_nope_head_dim,
dtype=self.dtype,
device=self.device,
)
q_pe = torch.randn(
batch_size,
seq_len,
num_heads,
qk_rope_head_dim,
dtype=self.dtype,
device=self.device,
)
compressed_kv = torch.randn(
batch_size, seq_len, kv_lora_rank, dtype=self.dtype, device=self.device
)
kpe = torch.randn(
batch_size, seq_len, 1, qk_rope_head_dim, dtype=self.dtype, device=self.device
)
else: # bnsd
q_nope = torch.randn(
batch_size,
num_heads,
seq_len,
qk_nope_head_dim,
dtype=self.dtype,
device=self.device,
)
q_pe = torch.randn(
batch_size,
num_heads,
seq_len,
qk_rope_head_dim,
dtype=self.dtype,
device=self.device,
)
compressed_kv = torch.randn(
batch_size, seq_len, kv_lora_rank, dtype=self.dtype, device=self.device
)
kpe = torch.randn(
batch_size, 1, seq_len, qk_rope_head_dim, dtype=self.dtype, device=self.device
)
# kv_b_proj_weight: [num_heads * kv_head_dim, kv_lora_rank]
kv_b_proj_weight = torch.randn(
num_heads * kv_head_dim, kv_lora_rank, dtype=self.dtype, device=self.device
)
return {
"q_nope": q_nope,
"q_pe": q_pe,
"compressed_kv": compressed_kv,
"kpe": kpe,
"kv_b_proj_weight": kv_b_proj_weight,
}
def test_basic_functionality(self):
"""Test basic MLA source op functionality."""
batch_size, seq_len, num_heads = 2, 4, 8
qk_nope_head_dim, qk_rope_head_dim = 128, 64
kv_lora_rank = 512
v_head_dim = 128
data = self._create_mla_data(
batch_size,
seq_len,
num_heads,
qk_nope_head_dim,
qk_rope_head_dim,
kv_lora_rank,
v_head_dim,
)
output = torch.ops.auto_deploy.torch_mla(
data["q_nope"],
data["q_pe"],
data["compressed_kv"],
data["kpe"],
data["kv_b_proj_weight"],
True, # is_causal
None, # scale
"bsnd", # layout
)
# Verify output shape: [B, S, N, v_head_dim]
expected_shape = (batch_size, seq_len, num_heads, v_head_dim)
assert output.shape == expected_shape, f"Expected {expected_shape}, got {output.shape}"
# Verify output is finite
assert torch.isfinite(output).all(), "Output contains NaN or Inf values"
def test_both_layouts(self):
"""Test MLA source op with both bsnd and bnsd layouts."""
batch_size, seq_len, num_heads = 2, 4, 8
qk_nope_head_dim, qk_rope_head_dim = 64, 32
kv_lora_rank = 256
v_head_dim = 64
for layout in ["bsnd", "bnsd"]:
data = self._create_mla_data(
batch_size,
seq_len,
num_heads,
qk_nope_head_dim,
qk_rope_head_dim,
kv_lora_rank,
v_head_dim,
layout,
)
output = torch.ops.auto_deploy.torch_mla(
data["q_nope"],
data["q_pe"],
data["compressed_kv"],
data["kpe"],
data["kv_b_proj_weight"],
True,
None,
layout,
)
if layout == "bsnd":
expected_shape = (batch_size, seq_len, num_heads, v_head_dim)
else:
expected_shape = (batch_size, num_heads, seq_len, v_head_dim)
assert output.shape == expected_shape, (
f"Layout {layout}: Expected {expected_shape}, got {output.shape}"
)
def test_custom_scale(self):
"""Test MLA source op with custom scale."""
batch_size, seq_len, num_heads = 1, 2, 4
qk_nope_head_dim, qk_rope_head_dim = 32, 16
kv_lora_rank = 128
v_head_dim = 32
data = self._create_mla_data(
batch_size,
seq_len,
num_heads,
qk_nope_head_dim,
qk_rope_head_dim,
kv_lora_rank,
v_head_dim,
)
# Test with default scale
output_default = torch.ops.auto_deploy.torch_mla(
data["q_nope"],
data["q_pe"],
data["compressed_kv"],
data["kpe"],
data["kv_b_proj_weight"],
True,
None,
"bsnd",
)
# Test with custom scale
custom_scale = 0.5
output_custom = torch.ops.auto_deploy.torch_mla(
data["q_nope"],
data["q_pe"],
data["compressed_kv"],
data["kpe"],
data["kv_b_proj_weight"],
True,
custom_scale,
"bsnd",
)
# Outputs should be different
assert not torch.allclose(output_default, output_custom, atol=1e-3), (
"Custom scale should affect output"
)
class TestTorchBackendMLAWithCache:
"""Test torch_backend_mla_with_cache cached op."""
@pytest.fixture(autouse=True)
def setup_method(self):
"""Setup test configuration."""
self.device = "cuda"
self.dtype = torch.bfloat16
self.atol = 5e-2
self.rtol = 5e-2
torch.cuda.empty_cache()
torch.manual_seed(42)
np.random.seed(42)
def _create_cached_mla_data(
self,
batch_size: int,
seq_len: int,
num_heads: int,
qk_nope_head_dim: int,
qk_rope_head_dim: int,
kv_lora_rank: int,
v_head_dim: int,
max_seq_len: int,
cache_offset: int = 0,
):
"""Create test data for cached MLA op with FlashInfer layout."""
kv_head_dim = qk_nope_head_dim + v_head_dim
# Create input tensors (BSND layout)
q_nope = torch.randn(
batch_size, seq_len, num_heads, qk_nope_head_dim, dtype=self.dtype, device=self.device
)
q_pe = torch.randn(
batch_size, seq_len, num_heads, qk_rope_head_dim, dtype=self.dtype, device=self.device
)
compressed_kv = torch.randn(
batch_size, seq_len, kv_lora_rank, dtype=self.dtype, device=self.device
)
kpe = torch.randn(
batch_size, seq_len, 1, qk_rope_head_dim, dtype=self.dtype, device=self.device
)
# kv_b_proj_weight: [num_heads * kv_head_dim, kv_lora_rank]
kv_b_proj_weight = torch.randn(
num_heads * kv_head_dim, kv_lora_rank, dtype=self.dtype, device=self.device
)
# Create FlashInfer MLA cache: [max_batch, max_seq, kv_lora_rank + qk_rope_head_dim]
mla_cache = torch.zeros(
batch_size,
max_seq_len,
kv_lora_rank + qk_rope_head_dim,
dtype=self.dtype,
device=self.device,
)
# Pre-fill cache with random data if cache_offset > 0
if cache_offset > 0:
mla_cache[:, :cache_offset, :] = torch.randn(
batch_size,
cache_offset,
kv_lora_rank + qk_rope_head_dim,
dtype=self.dtype,
device=self.device,
)
# Setup metadata
seq_len_tensor = torch.full((batch_size,), seq_len, device=self.device, dtype=torch.int32)
input_pos = torch.full((batch_size,), cache_offset, device=self.device, dtype=torch.int32)
cache_loc = torch.arange(batch_size, device=self.device, dtype=torch.int32)
if seq_len == 1:
# Generate phase
batch_info_host = torch.tensor(
[0, 0, batch_size], device=self.device, dtype=torch.int32
)
cu_seqlen = torch.arange(batch_size, device=self.device, dtype=torch.int32)
else:
# Context phase
batch_info_host = torch.tensor(
[batch_size, batch_size * seq_len, 0], device=self.device, dtype=torch.int32
)
cu_seqlen = torch.arange(
0, batch_size * seq_len, seq_len, device=self.device, dtype=torch.int32
)
# Flatten inputs for context phase
q_nope = q_nope.view(1, batch_size * seq_len, num_heads, qk_nope_head_dim)
q_pe = q_pe.view(1, batch_size * seq_len, num_heads, qk_rope_head_dim)
compressed_kv = compressed_kv.view(1, batch_size * seq_len, kv_lora_rank)
kpe = kpe.view(1, batch_size * seq_len, 1, qk_rope_head_dim)
return {
"q_nope": q_nope,
"q_pe": q_pe,
"compressed_kv": compressed_kv,
"kpe": kpe,
"kv_b_proj_weight": kv_b_proj_weight,
"batch_info_host": batch_info_host,
"seq_len": seq_len_tensor,
"input_pos": input_pos,
"cache_loc": cache_loc,
"cu_seqlen": cu_seqlen,
"mla_cache": mla_cache,
"kv_lora_rank": kv_lora_rank,
}
def _run_cached_mla(self, data, scale=None):
"""Run cached MLA operation."""
return torch.ops.auto_deploy.torch_cached_mla_with_cache(
data["q_nope"],
data["q_pe"],
data["compressed_kv"],
data["kpe"],
data["kv_b_proj_weight"],
data["batch_info_host"],
data["seq_len"],
data["input_pos"],
data["cache_loc"],
data["cu_seqlen"],
data["mla_cache"],
scale,
data["kv_lora_rank"],
)
def test_generate_phase_basic(self):
"""Test generate phase (single token) basic functionality."""
batch_size, seq_len, num_heads = 2, 1, 8
qk_nope_head_dim, qk_rope_head_dim = 64, 32
kv_lora_rank = 256
v_head_dim = 64
max_seq_len = 128
cache_offset = 5
data = self._create_cached_mla_data(
batch_size,
seq_len,
num_heads,
qk_nope_head_dim,
qk_rope_head_dim,
kv_lora_rank,
v_head_dim,
max_seq_len,
cache_offset,
)
output = self._run_cached_mla(data)
# Verify output shape
expected_shape = (batch_size, seq_len, num_heads, v_head_dim)
assert output.shape == expected_shape, f"Expected {expected_shape}, got {output.shape}"
# Verify output is finite
assert torch.isfinite(output).all(), "Output contains NaN or Inf values"
def test_context_phase_basic(self):
"""Test context phase (multi-token) basic functionality."""
batch_size, seq_len, num_heads = 2, 4, 8
qk_nope_head_dim, qk_rope_head_dim = 64, 32
kv_lora_rank = 256
v_head_dim = 64
max_seq_len = 128
data = self._create_cached_mla_data(
batch_size,
seq_len,
num_heads,
qk_nope_head_dim,
qk_rope_head_dim,
kv_lora_rank,
v_head_dim,
max_seq_len,
)
output = self._run_cached_mla(data)
# Verify output shape
expected_shape = (1, batch_size * seq_len, num_heads, v_head_dim)
assert output.shape == expected_shape, f"Expected {expected_shape}, got {output.shape}"
# Verify output is finite
assert torch.isfinite(output).all(), "Output contains NaN or Inf values"
def test_cache_update_correctness(self):
"""Test that cache is updated correctly during forward pass."""
batch_size, seq_len, num_heads = 1, 1, 4
qk_nope_head_dim, qk_rope_head_dim = 32, 16
kv_lora_rank = 128
v_head_dim = 32
max_seq_len = 32
cache_offset = 5
data = self._create_cached_mla_data(
batch_size,
seq_len,
num_heads,
qk_nope_head_dim,
qk_rope_head_dim,
kv_lora_rank,
v_head_dim,
max_seq_len,
cache_offset,
)
# Store original cache values at target position
original_cache_at_pos = data["mla_cache"][0, cache_offset].clone()
# Run forward pass
_ = self._run_cached_mla(data)
# Check cache was updated at the correct position
updated_cache_at_pos = data["mla_cache"][0, cache_offset]
# The cache should have been updated
assert not torch.allclose(original_cache_at_pos, updated_cache_at_pos, atol=1e-6), (
"Cache should have been updated at the target position"
)
def test_cache_layout_flashinfer_compatible(self):
"""Test that cache layout matches FlashInfer spec (no num_heads dimension)."""
batch_size, seq_len, num_heads = 2, 1, 4
qk_nope_head_dim, qk_rope_head_dim = 64, 32
kv_lora_rank = 512 # DeepSeek-style
v_head_dim = 128
max_seq_len = 64
data = self._create_cached_mla_data(
batch_size,
seq_len,
num_heads,
qk_nope_head_dim,
qk_rope_head_dim,
kv_lora_rank,
v_head_dim,
max_seq_len,
)
# Verify cache shape matches FlashInfer layout: [batch, seq, kv_lora_rank + rope_dim]
expected_cache_shape = (batch_size, max_seq_len, kv_lora_rank + qk_rope_head_dim)
assert data["mla_cache"].shape == expected_cache_shape, (
f"Cache shape {data['mla_cache'].shape} doesn't match FlashInfer layout {expected_cache_shape}"
)
# Verify zero-copy slicing works
compressed_kv_slice = data["mla_cache"][:, :, :kv_lora_rank]
kpe_slice = data["mla_cache"][:, :, kv_lora_rank:]
assert compressed_kv_slice.shape == (batch_size, max_seq_len, kv_lora_rank)
assert kpe_slice.shape == (batch_size, max_seq_len, qk_rope_head_dim)
# Verify slices share memory (zero-copy)
assert compressed_kv_slice.data_ptr() == data["mla_cache"].data_ptr(), (
"compressed_kv slice should be zero-copy"
)
def test_generate_with_reference(self):
"""Test generate phase against numpy reference."""
batch_size, seq_len, num_heads = 2, 1, 4
qk_nope_head_dim, qk_rope_head_dim = 32, 16
kv_lora_rank = 64
v_head_dim = 32
max_seq_len = 64
cache_offset = 3
data = self._create_cached_mla_data(
batch_size,
seq_len,
num_heads,
qk_nope_head_dim,
qk_rope_head_dim,
kv_lora_rank,
v_head_dim,
max_seq_len,
cache_offset,
)
# Run backend
output = self._run_cached_mla(data)
# Run numpy reference
reference = numpy_mla_reference_with_expansion(
data["q_nope"].cpu().float().numpy(),
data["q_pe"].cpu().float().numpy(),
data["compressed_kv"].cpu().float().numpy(),
data["kpe"].cpu().float().numpy(),
data["kv_b_proj_weight"].cpu().float().numpy(),
data["mla_cache"].cpu().float().numpy(),
data["seq_len"].cpu().numpy(),
data["input_pos"].cpu().numpy(),
data["cache_loc"].cpu().numpy(),
data["cu_seqlen"].cpu().numpy(),
None,
kv_lora_rank,
is_generate=True,
)
reference_torch = torch.from_numpy(reference).to(output.device, output.dtype)
assert torch.allclose(output, reference_torch, atol=self.atol, rtol=self.rtol), (
f"Generate phase output doesn't match reference. "
f"Max diff: {(output - reference_torch).abs().max():.6f}"
)
def test_dtype_preservation(self):
"""Test that output dtype matches input dtype."""
batch_size, seq_len, num_heads = 1, 1, 4
qk_nope_head_dim, qk_rope_head_dim = 32, 16
kv_lora_rank = 64
v_head_dim = 32
max_seq_len = 32
for dtype in [torch.float16, torch.bfloat16]:
self.dtype = dtype
data = self._create_cached_mla_data(
batch_size,
seq_len,
num_heads,
qk_nope_head_dim,
qk_rope_head_dim,
kv_lora_rank,
v_head_dim,
max_seq_len,
)
output = self._run_cached_mla(data)
assert output.dtype == dtype, f"Expected dtype {dtype}, got {output.dtype}"
def test_memory_efficiency(self):
"""Test that cache uses compressed dimensions (no num_heads)."""
batch_size = 1
max_seq_len = 1024
kv_lora_rank = 512
qk_rope_head_dim = 64
num_heads = 128 # DeepSeek V3
# FlashInfer compressed cache size
compressed_cache_size = batch_size * max_seq_len * (kv_lora_rank + qk_rope_head_dim)
# Expanded per-head cache size (what we avoid)
qk_nope_head_dim = 128
v_head_dim = 128
expanded_cache_size = (
batch_size
* max_seq_len
* num_heads
* (qk_nope_head_dim + v_head_dim + qk_rope_head_dim)
)
# Verify compression ratio
compression_ratio = expanded_cache_size / compressed_cache_size
assert compression_ratio > 50, f"Expected >50x compression, got {compression_ratio:.1f}x"
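With the DeepSeek-V3-like sizes above, the expanded per-head cache would hold 128 × (128 + 128 + 64) = 40,960 values per token, versus 512 + 64 = 576 in the compressed layout, i.e. roughly a 71× reduction, comfortably above the 50× bound asserted here.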
class TestMLADescriptor:
"""Test MultiHeadLatentAttention descriptor configuration."""
def _get_mla_descriptor(self):
"""Get MLA descriptor from registry."""
from tensorrt_llm._torch.auto_deploy.custom_ops.attention_interface import AttentionRegistry
return AttentionRegistry.get("torch_mla")
def test_descriptor_registration(self):
"""Test that MLA descriptor is properly registered."""
from tensorrt_llm._torch.auto_deploy.custom_ops.attention_interface import AttentionRegistry
assert AttentionRegistry.has("torch_mla"), "torch_mla should be registered"
def test_descriptor_layout(self):
"""Test that MLA descriptor uses correct layout."""
mla_descriptor = self._get_mla_descriptor()
assert mla_descriptor.get_attention_layout() == "bsnd", "MLA should use bsnd layout"
def test_descriptor_num_qkv_args(self):
"""Test that MLA descriptor expects 5 tensor args."""
mla_descriptor = self._get_mla_descriptor()
assert mla_descriptor.get_num_qkv_args() == 5, (
"MLA should expect 5 tensor args (q_nope, q_pe, compressed_kv, kpe, kv_b_proj_weight)"
)
def test_descriptor_source_op(self):
"""Test that MLA descriptor points to correct source op."""
mla_descriptor = self._get_mla_descriptor()
source_op = mla_descriptor.get_source_attention_op()
assert source_op == torch.ops.auto_deploy.torch_mla, "MLA should use torch_mla as source op"
def test_descriptor_cached_op(self):
"""Test that MLA descriptor points to correct cached op."""
mla_descriptor = self._get_mla_descriptor()
cached_op = mla_descriptor.get_cached_attention_op()
assert cached_op == torch.ops.auto_deploy.torch_cached_mla_with_cache.default, (
"MLA should use torch_cached_mla_with_cache as cached op"
)
def test_descriptor_standard_metadata(self):
"""Test that MLA descriptor uses standard metadata args."""
mla_descriptor = self._get_mla_descriptor()
expected_args = ["batch_info_host", "seq_len", "input_pos", "slot_idx", "cu_seqlen"]
actual_args = mla_descriptor.get_standard_metadata_args()
assert actual_args == expected_args, (
f"Expected standard metadata {expected_args}, got {actual_args}"
)

View File

@ -1,6 +1,8 @@
import torch
from tensorrt_llm._torch.auto_deploy.custom_ops.attention.torch_attention import update_kv_cache
from tensorrt_llm._torch.auto_deploy.custom_ops.attention.torch_backend_attention import (
_update_kv_cache,
)
def test_update_kv_cache():
@ -26,7 +28,7 @@ def test_update_kv_cache():
print("slot_idx: " + str(torch.tensor([0, 1])))
print("seq_start: " + str(torch.tensor([0, 3])))
update_kv_cache(
_update_kv_cache(
k.view(batch_size * seq_length, n_heads, K_D_HEAD),
v.view(batch_size * seq_length, n_heads, V_D_HEAD),
k_cache,

View File

@ -0,0 +1,494 @@
"""Testing custom DeepSeekV3 model implementation for auto_deploy export."""
import pytest
import torch
from transformers import PretrainedConfig
from tensorrt_llm._torch.auto_deploy.models.custom.modeling_deepseek import (
DeepSeekV3Attention,
DeepSeekV3DecoderLayer,
DeepSeekV3ForCausalLM,
DeepSeekV3MLP,
DeepSeekV3Model,
DeepSeekV3MoE,
DeepSeekV3RMSNorm,
DeepSeekV3RotaryEmbedding,
DeepSeekV3YarnRotaryEmbedding,
)
class MockDeepSeekConfig(PretrainedConfig):
"""Mock DeepSeek config for testing the custom model components."""
model_type = "deepseek_v3"
def __init__(self):
super().__init__()
# Attention config
self.num_attention_heads = 8
self.qk_nope_head_dim = 64
self.qk_rope_head_dim = 32
self.v_head_dim = 64
self.kv_lora_rank = 128
self.q_lora_rank = None # No LoRA for Q in tests
self.hidden_size = 256
self.rope_theta = 10000.0
self.max_position_embeddings = 512
self.attention_bias = False
self.rope_scaling = None
self.rms_norm_eps = 1e-6
# MLP config
self.intermediate_size = 512
self.hidden_act = "silu"
# MoE config
self.n_routed_experts = 4
self.num_experts_per_tok = 2
self.moe_intermediate_size = 256
self.n_shared_experts = 1
self.routed_scaling_factor = 1.0
self.n_group = 1
self.topk_group = 1
# Model config
self.num_hidden_layers = 2
self.first_k_dense_replace = 1 # First layer is dense, second is MoE
self.moe_layer_freq = 1
self.vocab_size = 1000
self.pad_token_id = 0
self.initializer_range = 0.02
class TestDeepSeekV3RMSNorm:
"""Test DeepSeekV3RMSNorm implementation."""
@pytest.fixture(autouse=True)
def setup_method(self):
self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.dtype = torch.bfloat16
torch.manual_seed(42)
def test_forward_shape(self):
"""Test that RMSNorm preserves input shape."""
hidden_size = 256
norm = DeepSeekV3RMSNorm(hidden_size).to(self.device, self.dtype)
x = torch.randn(2, 4, hidden_size, dtype=self.dtype, device=self.device)
output = norm(x)
assert output.shape == x.shape
assert torch.isfinite(output).all()
def test_output_normalized(self):
"""Test that output has approximately unit variance."""
hidden_size = 256
norm = DeepSeekV3RMSNorm(hidden_size).to(self.device, torch.float32)
x = torch.randn(2, 4, hidden_size, dtype=torch.float32, device=self.device)
output = norm(x)
# RMS should be close to 1 after normalization (scaled by weight)
rms = torch.sqrt((output**2).mean(-1))
assert torch.allclose(rms, torch.ones_like(rms), atol=0.1)
class TestDeepSeekV3RotaryEmbedding:
"""Test DeepSeekV3 Rotary Embedding implementations."""
@pytest.fixture(autouse=True)
def setup_method(self):
self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.dtype = torch.bfloat16
torch.manual_seed(42)
def test_base_rope_shape(self):
"""Test base rotary embedding output shape."""
dim = 32
max_pos = 512
rope = DeepSeekV3RotaryEmbedding(dim, max_pos).to(self.device)
x = torch.randn(2, 4, 8, dim, dtype=self.dtype, device=self.device)
cos, sin = rope(x)
# Should return full cached values
assert cos.shape == (max_pos, dim)
assert sin.shape == (max_pos, dim)
def test_yarn_rope_shape(self):
"""Test YaRN rotary embedding output shape."""
dim = 32
max_pos = 512
rope = DeepSeekV3YarnRotaryEmbedding(
dim,
max_pos,
scaling_factor=2.0,
original_max_position_embeddings=256,
).to(self.device)
x = torch.randn(2, 4, 8, dim, dtype=self.dtype, device=self.device)
cos, sin = rope(x)
assert cos.shape == (max_pos, dim)
assert sin.shape == (max_pos, dim)
class TestDeepSeekV3MLP:
"""Test DeepSeekV3MLP implementation."""
@pytest.fixture(autouse=True)
def setup_method(self):
self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.dtype = torch.bfloat16
torch.manual_seed(42)
def test_forward_shape(self):
"""Test MLP output shape."""
config = MockDeepSeekConfig()
mlp = DeepSeekV3MLP(config).to(self.device, self.dtype)
x = torch.randn(2, 4, config.hidden_size, dtype=self.dtype, device=self.device)
output = mlp(x)
assert output.shape == x.shape
assert torch.isfinite(output).all()
class TestDeepSeekV3Attention:
"""Test DeepSeekV3Attention (MLA) implementation."""
@pytest.fixture(autouse=True)
def setup_method(self):
self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.dtype = torch.bfloat16
torch.manual_seed(42)
def test_forward_shape(self):
"""Test attention output shape."""
config = MockDeepSeekConfig()
attn = DeepSeekV3Attention(config, layer_idx=0).to(self.device, self.dtype)
batch_size, seq_len = 2, 4
hidden_states = torch.randn(
batch_size, seq_len, config.hidden_size, dtype=self.dtype, device=self.device
)
position_ids = torch.arange(seq_len, device=self.device).unsqueeze(0).expand(batch_size, -1)
output = attn(hidden_states, position_ids)
assert output.shape == hidden_states.shape
assert torch.isfinite(output).all()
def test_different_batch_sizes(self):
"""Test attention with different batch sizes."""
config = MockDeepSeekConfig()
attn = DeepSeekV3Attention(config, layer_idx=0).to(self.device, self.dtype)
for batch_size in [1, 2, 4]:
seq_len = 4
hidden_states = torch.randn(
batch_size, seq_len, config.hidden_size, dtype=self.dtype, device=self.device
)
position_ids = (
torch.arange(seq_len, device=self.device).unsqueeze(0).expand(batch_size, -1)
)
output = attn(hidden_states, position_ids)
assert output.shape == (batch_size, seq_len, config.hidden_size)
def test_different_sequence_lengths(self):
"""Test attention with different sequence lengths."""
config = MockDeepSeekConfig()
attn = DeepSeekV3Attention(config, layer_idx=0).to(self.device, self.dtype)
for seq_len in [1, 4, 16]:
batch_size = 2
hidden_states = torch.randn(
batch_size, seq_len, config.hidden_size, dtype=self.dtype, device=self.device
)
position_ids = (
torch.arange(seq_len, device=self.device).unsqueeze(0).expand(batch_size, -1)
)
output = attn(hidden_states, position_ids)
assert output.shape == (batch_size, seq_len, config.hidden_size)
class TestDeepSeekV3MoE:
"""Test DeepSeekV3MoE implementation."""
@pytest.fixture(autouse=True)
def setup_method(self):
self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.dtype = torch.bfloat16
torch.manual_seed(42)
def test_forward_shape(self):
"""Test MoE output shape."""
config = MockDeepSeekConfig()
moe = DeepSeekV3MoE(config).to(self.device, self.dtype)
x = torch.randn(2, 4, config.hidden_size, dtype=self.dtype, device=self.device)
output = moe(x)
assert output.shape == x.shape
assert torch.isfinite(output).all()
def test_with_shared_experts(self):
"""Test MoE with shared experts."""
config = MockDeepSeekConfig()
config.n_shared_experts = 2
moe = DeepSeekV3MoE(config).to(self.device, self.dtype)
assert moe.shared_experts is not None
x = torch.randn(2, 4, config.hidden_size, dtype=self.dtype, device=self.device)
output = moe(x)
assert output.shape == x.shape
def test_without_shared_experts(self):
"""Test MoE without shared experts."""
config = MockDeepSeekConfig()
config.n_shared_experts = None
moe = DeepSeekV3MoE(config).to(self.device, self.dtype)
assert moe.shared_experts is None
x = torch.randn(2, 4, config.hidden_size, dtype=self.dtype, device=self.device)
output = moe(x)
assert output.shape == x.shape
class TestDeepSeekV3DecoderLayer:
"""Test DeepSeekV3DecoderLayer implementation."""
@pytest.fixture(autouse=True)
def setup_method(self):
self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.dtype = torch.bfloat16
torch.manual_seed(42)
def test_dense_layer(self):
"""Test decoder layer with dense MLP."""
config = MockDeepSeekConfig()
# Layer 0 should be dense (before first_k_dense_replace)
layer = DeepSeekV3DecoderLayer(config, layer_idx=0).to(self.device, self.dtype)
assert isinstance(layer.mlp, DeepSeekV3MLP)
batch_size, seq_len = 2, 4
hidden_states = torch.randn(
batch_size, seq_len, config.hidden_size, dtype=self.dtype, device=self.device
)
position_ids = torch.arange(seq_len, device=self.device).unsqueeze(0).expand(batch_size, -1)
output = layer(hidden_states, position_ids)
assert output.shape == hidden_states.shape
def test_moe_layer(self):
"""Test decoder layer with MoE."""
config = MockDeepSeekConfig()
# Layer 1 should be MoE (at first_k_dense_replace)
layer = DeepSeekV3DecoderLayer(config, layer_idx=1).to(self.device, self.dtype)
assert isinstance(layer.mlp, DeepSeekV3MoE)
batch_size, seq_len = 2, 4
hidden_states = torch.randn(
batch_size, seq_len, config.hidden_size, dtype=self.dtype, device=self.device
)
position_ids = torch.arange(seq_len, device=self.device).unsqueeze(0).expand(batch_size, -1)
output = layer(hidden_states, position_ids)
assert output.shape == hidden_states.shape
class TestDeepSeekV3Model:
"""Test DeepSeekV3Model implementation."""
@pytest.fixture(autouse=True)
def setup_method(self):
self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.dtype = torch.bfloat16
torch.manual_seed(42)
def test_forward_with_input_ids(self):
"""Test model forward with input_ids."""
config = MockDeepSeekConfig()
model = DeepSeekV3Model(config).to(self.device, self.dtype)
batch_size, seq_len = 2, 4
input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len), device=self.device)
output = model(input_ids=input_ids)
assert output.last_hidden_state.shape == (batch_size, seq_len, config.hidden_size)
def test_forward_with_inputs_embeds(self):
"""Test model forward with inputs_embeds."""
config = MockDeepSeekConfig()
model = DeepSeekV3Model(config).to(self.device, self.dtype)
batch_size, seq_len = 2, 4
inputs_embeds = torch.randn(
batch_size, seq_len, config.hidden_size, dtype=self.dtype, device=self.device
)
output = model(inputs_embeds=inputs_embeds)
assert output.last_hidden_state.shape == inputs_embeds.shape
class TestDeepSeekV3ForCausalLM:
"""Test DeepSeekV3ForCausalLM implementation."""
@pytest.fixture(autouse=True)
def setup_method(self):
self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.dtype = torch.bfloat16
torch.manual_seed(42)
def test_forward(self):
"""Test causal LM forward pass."""
config = MockDeepSeekConfig()
model = DeepSeekV3ForCausalLM(config).to(self.device, self.dtype)
batch_size, seq_len = 2, 4
input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len), device=self.device)
output = model(input_ids=input_ids)
assert output.logits.shape == (batch_size, seq_len, config.vocab_size)
def test_output_dtype(self):
"""Test that logits are float32."""
config = MockDeepSeekConfig()
model = DeepSeekV3ForCausalLM(config).to(self.device, self.dtype)
batch_size, seq_len = 2, 4
input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len), device=self.device)
output = model(input_ids=input_ids)
# Logits should be float32 for numerical stability
assert output.logits.dtype == torch.float32
class TestMLAOpRegistration:
"""Test that MLA ops are properly registered."""
def test_torch_mla_registered(self):
"""Test that torch_mla op is registered."""
assert hasattr(torch.ops.auto_deploy, "torch_mla"), "torch_mla op should be registered"
def test_torch_cached_mla_registered(self):
"""Test that torch_cached_mla_with_cache op is registered."""
assert hasattr(torch.ops.auto_deploy, "torch_cached_mla_with_cache"), (
"torch_cached_mla_with_cache op should be registered"
)
def test_torch_mla_callable(self):
"""Test that torch_mla op is callable."""
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.bfloat16
batch_size, seq_len, num_heads = 1, 2, 4
qk_nope_head_dim, qk_rope_head_dim = 64, 32
kv_lora_rank = 128
v_head_dim = 64
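# Inputs follow the MLA decomposition: the query is split into a no-RoPE part (q_nope) and a
# RoPE part (q_pe), while K/V are carried as a shared compressed latent of size kv_lora_rank
# plus a single RoPE key (kpe); kv_b_proj_weight up-projects the latent to per-head k_nope/value.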
q_nope = torch.randn(
batch_size, seq_len, num_heads, qk_nope_head_dim, dtype=dtype, device=device
)
q_pe = torch.randn(
batch_size, seq_len, num_heads, qk_rope_head_dim, dtype=dtype, device=device
)
compressed_kv = torch.randn(batch_size, seq_len, kv_lora_rank, dtype=dtype, device=device)
kpe = torch.randn(batch_size, seq_len, 1, qk_rope_head_dim, dtype=dtype, device=device)
kv_b_proj_weight = torch.randn(
num_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank, dtype=dtype, device=device
)
# Should not raise
output = torch.ops.auto_deploy.torch_mla(
q_nope, q_pe, compressed_kv, kpe, kv_b_proj_weight, True, None, "bsnd"
)
assert output is not None
def test_torch_cached_mla_callable(self):
"""Test that torch_cached_mla_with_cache op is callable."""
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.bfloat16
batch_size, seq_len, num_heads = 1, 1, 4
qk_nope_head_dim, qk_rope_head_dim = 32, 16
kv_lora_rank = 64
v_head_dim = 32
max_seq_len = 32
q_nope = torch.randn(
batch_size, seq_len, num_heads, qk_nope_head_dim, dtype=dtype, device=device
)
q_pe = torch.randn(
batch_size, seq_len, num_heads, qk_rope_head_dim, dtype=dtype, device=device
)
compressed_kv = torch.randn(batch_size, seq_len, kv_lora_rank, dtype=dtype, device=device)
kpe = torch.randn(batch_size, seq_len, 1, qk_rope_head_dim, dtype=dtype, device=device)
kv_b_proj_weight = torch.randn(
num_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank, dtype=dtype, device=device
)
batch_info_host = torch.tensor([0, 0, batch_size], dtype=torch.int32, device=device)
seq_len_tensor = torch.tensor([seq_len], dtype=torch.int32, device=device)
input_pos = torch.tensor([0], dtype=torch.int32, device=device)
cache_loc = torch.tensor([0], dtype=torch.int32, device=device)
cu_seqlen = torch.tensor([0], dtype=torch.int32, device=device)
mla_cache = torch.zeros(
batch_size, max_seq_len, kv_lora_rank + qk_rope_head_dim, dtype=dtype, device=device
)
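# The metadata above describes a single one-token sequence starting at position 0 in cache
# slot 0 (exact field semantics are defined by the cached-MLA op); mla_cache stores the
# compressed latent plus the RoPE key per position.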
# Should not raise
output = torch.ops.auto_deploy.torch_cached_mla_with_cache(
q_nope,
q_pe,
compressed_kv,
kpe,
kv_b_proj_weight,
batch_info_host,
seq_len_tensor,
input_pos,
cache_loc,
cu_seqlen,
mla_cache,
None,
kv_lora_rank,
)
assert output is not None
class TestMoEOpRegistration:
"""Test that MoE ops are properly registered."""
def test_torch_moe_registered(self):
"""Test that torch_moe op is registered."""
assert hasattr(torch.ops.auto_deploy, "torch_moe"), "torch_moe op should be registered"
class TestCustomModelRegistration:
"""Test that custom model is properly registered."""
def test_model_registered(self):
"""Test that DeepSeekV3ForCausalLM is registered with the factory."""
from tensorrt_llm._torch.auto_deploy.models.hf import AutoModelForCausalLMFactory
# Check that the model is registered in the custom model mapping
assert "DeepseekV3Config" in AutoModelForCausalLMFactory._custom_model_mapping
assert (
AutoModelForCausalLMFactory._custom_model_mapping["DeepseekV3Config"]
== DeepSeekV3ForCausalLM
)
View File
@ -1,110 +0,0 @@
"""Testing module patches that enable export of deepseek model."""
import types
import pytest
import torch
from test_common.llm_data import hf_id_to_local_model_dir
from transformers import AutoConfig, AutoModelForCausalLM
from tensorrt_llm._torch.auto_deploy.models.patches.deepseek import (
deepseek_v3_attention,
deepseek_v3_moe_exact,
)
def _load_layer_from_model(model_name_or_path, layer_name):
"""
Loads a specific layer/module from a model without loading the entire model.
Parameters:
model_name_or_path (str): Path or name of the pretrained model.
layer_name (str): Name of the layer to extract.
Returns:
module: The specified layer/module if available, otherwise None.
"""
try:
# Load only the model configuration
config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
# Shrink the model to a single small layer for the test and disable YaRN rope scaling below
config.num_hidden_layers = 1
config.use_cache = False
config.first_k_dense_replace = 0
config.n_routed_experts = 2
config.num_experts_per_tok = 1
config.n_group = 1
config.topk_group = 1
config.hidden_size = 8
config.moe_intermediate_size = 8
config.num_attention_heads = 2
config.num_key_value_heads = 2
config.qk_nope_head_dim = 4
config.qk_rope_head_dim = 2
config.v_head_dim = 4
config.intermediate_size = 8
config.max_position_embeddings = 7
config.rope_scaling = None
# Build the model architecture (no weights loaded yet)
model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
model.eval()
# Access the specific layer by its name
module = dict(model.named_modules()).get(layer_name)
if module is None:
print(f"Layer '{layer_name}' not found in the model.")
else:
print(f"Successfully extracted layer '{layer_name}'.")
return module
except Exception as e:
print(f"Error extracting layer: {e}")
return None
def _generate_ds_attention_mask(b, s):
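# Build an additive causal mask: 0.0 on and below the diagonal, a large negative value
# (roughly -FLT_MAX) above it, expanded to shape [b, 1, s, s].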
return torch.where(
torch.tril(torch.full((s, s), float("-inf"))).unsqueeze(0).unsqueeze(0).expand(b, 1, s, s)
== float("-inf"),
torch.tensor(0.0),
torch.tensor(float(-3.4028e38)),
)
@pytest.mark.parametrize(
"model_name, module_name, patch, inputs",
[
pytest.param(
hf_id_to_local_model_dir("deepseek-ai/DeepSeek-R1"),
"model.layers.0.self_attn",
deepseek_v3_attention,
[
torch.randn(2, 6, 8, dtype=torch.bfloat16),
_generate_ds_attention_mask(2, 6),
torch.tensor([[0, 1, 2, 3, 4, 5]]),
],
), # attention requires inputs [hidden_states, attention_mask, position_ids]
pytest.param(
hf_id_to_local_model_dir("deepseek-ai/DeepSeek-R1"),
"model.layers.0.mlp",
deepseek_v3_moe_exact,
[torch.randn(2, 6, 8, dtype=torch.bfloat16)],
), # moe requires inputs [hidden_states]
],
)
def test_module_patches(model_name, module_name, patch, inputs):
# Get module
module = _load_layer_from_model(model_name, module_name)
# Pass test inputs to generate reference
ref, *_ = module(*inputs)
# Patch layer
module.forward = types.MethodType(patch, module)
# Generate test output
test, *_ = module(*inputs)
assert torch.allclose(ref, test, atol=0, rtol=0)
View File
@ -0,0 +1,652 @@
"""Tests for GLM4 MoE Lite custom model implementation.
This module tests the custom GLM4 MoE Lite model implementation which uses
auto_deploy custom ops (torch_mla, torch_moe, etc.) for export compatibility.
"""
import pytest
import torch
from torch.export import Dim
from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm
from tensorrt_llm._torch.auto_deploy.models.custom.modeling_glm4_moe_lite import (
Glm4MoeLiteConfig,
Glm4MoeLiteForCausalLM,
Glm4MoeLiteMLP,
Glm4MoeLiteMoE,
)
from tensorrt_llm._torch.auto_deploy.utils._graph import move_to_device
_BATCH_AND_SEQUENCE_TEST_CASES = ((2, 6), (1, 8))
@pytest.fixture(scope="function", autouse=True)
def set_seed():
torch.manual_seed(42)
def _create_small_config() -> Glm4MoeLiteConfig:
"""Create a small GLM4 MoE Lite config for testing."""
return Glm4MoeLiteConfig(
vocab_size=1000,
hidden_size=64,
intermediate_size=128,
num_hidden_layers=3, # Layer 0 dense, layers 1-2 MoE
num_attention_heads=4,
num_key_value_heads=4,
hidden_act="silu",
max_position_embeddings=512,
rms_norm_eps=1e-5,
# MLA params (scaled down)
q_lora_rank=32,
kv_lora_rank=32,
qk_nope_head_dim=8,
qk_rope_head_dim=8,
v_head_dim=16,
# MoE params (scaled down)
n_routed_experts=4,
n_shared_experts=1,
num_experts_per_tok=2,
moe_intermediate_size=32,
n_group=1,
topk_group=1,
routed_scaling_factor=1.0,
norm_topk_prob=True,
first_k_dense_replace=1,
# RoPE
rope_theta=10000.0,
rope_scaling=None,
# Other
attention_bias=False,
attention_dropout=0.0,
pad_token_id=0,
)
def _create_moe_layer(config: Glm4MoeLiteConfig) -> Glm4MoeLiteMoE:
"""Create a MoE layer from config."""
moe = Glm4MoeLiteMoE(config)
# Initialize gate weights with randn for reproducibility
# (gate weight is initialized with torch.empty which isn't seeded)
moe.gate.weight = torch.nn.Parameter(torch.randn_like(moe.gate.weight))
return moe
@pytest.mark.parametrize("B,S", _BATCH_AND_SEQUENCE_TEST_CASES)
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@torch.no_grad()
def test_glm4_moe_lite_moe_layer(B, S, dtype):
"""Test that MoE layer produces valid output."""
device = "cuda"
config = _create_small_config()
moe = _create_moe_layer(config)
moe.to(device=device, dtype=dtype)
moe.eval()
H = config.hidden_size
x = torch.randn(B, S, H, device=device, dtype=dtype)
output = moe(x)
# Check output shape matches input shape
assert output.shape == x.shape, f"Expected shape {x.shape}, got {output.shape}"
# Check output is not all zeros (MoE should transform the input)
assert not torch.allclose(output, torch.zeros_like(output)), "Output should not be all zeros"
# Check output doesn't have NaN or Inf values
assert not torch.isnan(output).any(), "Output contains NaN values"
assert not torch.isinf(output).any(), "Output contains Inf values"
@pytest.mark.parametrize("B,S", _BATCH_AND_SEQUENCE_TEST_CASES)
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@torch.no_grad()
def test_glm4_moe_lite_full_model(B, S, dtype):
"""Test that full model produces valid output."""
device = "cuda"
config = _create_small_config()
model = Glm4MoeLiteForCausalLM(config)
model.to(device=device, dtype=dtype)
model.eval()
# Create input
input_ids = torch.randint(0, config.vocab_size, (B, S), device=device)
position_ids = torch.arange(S, device=device).unsqueeze(0).expand(B, -1)
output = model(input_ids=input_ids, position_ids=position_ids)
# Check output shape
assert output.logits.shape == (B, S, config.vocab_size), (
f"Expected logits shape {(B, S, config.vocab_size)}, got {output.logits.shape}"
)
# Check output doesn't have NaN or Inf values
assert not torch.isnan(output.logits).any(), "Logits contain NaN values"
assert not torch.isinf(output.logits).any(), "Logits contain Inf values"
def test_glm4_moe_lite_model_can_be_exported():
"""Test that the custom model can be exported with torch_export_to_gm.
This test verifies:
1. The model exports successfully without graph breaks
2. The exported graph module produces outputs with correct shape
3. The outputs contain finite values (no NaN/Inf)
Note: We don't test numerical equivalence between original and exported model
here because torch.export lifts parameters into the graph, creating a different
parameter structure that doesn't match load_state_dict. The numerical correctness
of the model itself is already validated by test_glm4_moe_lite_full_model.
"""
device = "cuda"
dtype = torch.bfloat16
config = _create_small_config()
model = Glm4MoeLiteForCausalLM(config)
model.to(device=device, dtype=dtype)
model.eval()
# Create input
B, S = 2, 8
input_ids = torch.randint(0, config.vocab_size, (B, S), device=device)
position_ids = torch.arange(S, device=device).unsqueeze(0).expand(B, -1)
# Define dynamic shapes
batch_size_dynamic = Dim.DYNAMIC
seq_len_dynamic = Dim.DYNAMIC
dynamic_shapes = (
{0: batch_size_dynamic, 1: seq_len_dynamic},
{0: batch_size_dynamic, 1: seq_len_dynamic},
)
# Export the model - this is the main test: verify no graph breaks
gm = torch_export_to_gm(
model,
args=tuple(),
kwargs={"input_ids": input_ids, "position_ids": position_ids},
dynamic_shapes=dynamic_shapes,
)
# Move graph module to device
move_to_device(gm, device)
# Verify the exported model produces valid output
with torch.inference_mode():
out_gm = gm(input_ids=input_ids, position_ids=position_ids)
# Check output structure and shape
assert "logits" in out_gm, "Output should contain 'logits' key"
logits = out_gm["logits"]
assert logits.shape == (B, S, config.vocab_size), (
f"Expected shape {(B, S, config.vocab_size)}, got {logits.shape}"
)
assert torch.isfinite(logits).all(), "Logits should not contain NaN or Inf"
# Test with different input shape to verify dynamic shapes work
B2, S2 = 1, 4
input_ids2 = torch.randint(0, config.vocab_size, (B2, S2), device=device)
position_ids2 = torch.arange(S2, device=device).unsqueeze(0).expand(B2, -1)
with torch.inference_mode():
out_gm2 = gm(input_ids=input_ids2, position_ids=position_ids2)
logits2 = out_gm2["logits"]
expected_shape = (B2, S2, config.vocab_size)
assert logits2.shape == expected_shape, (
f"Dynamic shape test failed: expected {expected_shape}, got {logits2.shape}"
)
assert torch.isfinite(logits2).all(), "Logits should not contain NaN or Inf"
def test_glm4_moe_lite_config_registration():
"""Test that the config is properly registered or model_type is correct."""
# Create a config and verify model_type
config = _create_small_config()
assert config.model_type == "glm4_moe_lite"
# Verify our config class can be instantiated with expected attributes
assert hasattr(config, "hidden_size")
assert hasattr(config, "num_attention_heads")
assert hasattr(config, "n_routed_experts")
assert hasattr(config, "kv_lora_rank")
assert hasattr(config, "qk_rope_head_dim")
def test_glm4_moe_lite_layer_types():
"""Test that layer 0 uses dense MLP and later layers use MoE."""
config = _create_small_config()
model = Glm4MoeLiteForCausalLM(config)
# Check layer 0 (should be dense MLP, not MoE)
layer0_mlp = model.model.layers[0].mlp
assert type(layer0_mlp).__name__ == "Glm4MoeLiteMLP", (
f"Layer 0 should use Glm4MoeLiteMLP, got {type(layer0_mlp).__name__}"
)
# Check layer 1+ (should be MoE)
for i in range(1, config.num_hidden_layers):
layer_mlp = model.model.layers[i].mlp
assert type(layer_mlp).__name__ == "Glm4MoeLiteMoE", (
f"Layer {i} should use Glm4MoeLiteMoE, got {type(layer_mlp).__name__}"
)
def test_glm4_moe_lite_expert_structure():
"""Test that experts have correct structure for checkpoint loading."""
config = _create_small_config()
moe = Glm4MoeLiteMoE(config)
# Check that experts is a ModuleList
assert isinstance(moe.experts, torch.nn.ModuleList), "experts should be nn.ModuleList"
# Check number of experts
assert len(moe.experts) == config.n_routed_experts, (
f"Expected {config.n_routed_experts} experts, got {len(moe.experts)}"
)
# Check each expert has the correct structure
for i, expert in enumerate(moe.experts):
assert hasattr(expert, "gate_proj"), f"Expert {i} missing gate_proj"
assert hasattr(expert, "up_proj"), f"Expert {i} missing up_proj"
assert hasattr(expert, "down_proj"), f"Expert {i} missing down_proj"
# Check state_dict keys match expected checkpoint format
state_dict = moe.state_dict()
expected_keys = [
"experts.0.gate_proj.weight",
"experts.0.up_proj.weight",
"experts.0.down_proj.weight",
]
for key in expected_keys:
assert key in state_dict, (
f"Expected key '{key}' in state_dict, got keys: {list(state_dict.keys())[:10]}..."
)
# =============================================================================
# Numerical Equivalence Tests
# These tests compare our custom implementation against the HF implementation
# =============================================================================
def _get_hf_model_class():
"""Get the HF model class for GLM4 MoE Lite.
Returns None if transformers doesn't have glm4_moe_lite (older versions).
"""
try:
from transformers.models.glm4_moe_lite.modeling_glm4_moe_lite import (
Glm4MoeLiteForCausalLM as HFGlm4MoeLiteForCausalLM,
)
return HFGlm4MoeLiteForCausalLM
except ImportError:
return None
def _get_hf_moe_class():
"""Get the HF MoE class for GLM4 MoE Lite."""
try:
from transformers.models.glm4_moe_lite.modeling_glm4_moe_lite import (
Glm4MoeLiteMoE as HFGlm4MoeLiteMoE,
)
return HFGlm4MoeLiteMoE
except ImportError:
return None
def _get_hf_attention_class():
"""Get the HF Attention class for GLM4 MoE Lite."""
try:
from transformers.models.glm4_moe_lite.modeling_glm4_moe_lite import (
Glm4MoeLiteAttention as HFGlm4MoeLiteAttention,
)
return HFGlm4MoeLiteAttention
except ImportError:
return None
def _get_hf_mlp_class():
"""Get the HF MLP class for GLM4 MoE Lite."""
try:
from transformers.models.glm4_moe_lite.modeling_glm4_moe_lite import (
Glm4MoeLiteMLP as HFGlm4MoeLiteMLP,
)
return HFGlm4MoeLiteMLP
except ImportError:
return None
def _get_hf_config_class():
"""Get the HF Config class for GLM4 MoE Lite."""
try:
from transformers.models.glm4_moe_lite.configuration_glm4_moe_lite import (
Glm4MoeLiteConfig as HFGlm4MoeLiteConfig,
)
return HFGlm4MoeLiteConfig
except ImportError:
return None
def _convert_hf_moe_state_dict_to_custom(hf_state_dict: dict, n_experts: int) -> dict:
"""Convert HF MoE state dict to our custom format.
HF format (stacked):
experts.gate_up_proj: [n_experts, 2 * intermediate_size, hidden_size]
experts.down_proj: [n_experts, hidden_size, intermediate_size]
Custom format (per-expert):
experts.0.gate_proj.weight: [intermediate_size, hidden_size]
experts.0.up_proj.weight: [intermediate_size, hidden_size]
experts.0.down_proj.weight: [hidden_size, intermediate_size]
"""
custom_state_dict = {}
for key, value in hf_state_dict.items():
if key == "experts.gate_up_proj":
# Split stacked gate_up into individual gate and up per expert
# Shape: [n_experts, 2 * intermediate_size, hidden_size]
intermediate_size = value.shape[1] // 2
for i in range(n_experts):
gate_up = value[i] # [2 * intermediate_size, hidden_size]
gate_weight = gate_up[:intermediate_size] # [intermediate_size, hidden_size]
up_weight = gate_up[intermediate_size:] # [intermediate_size, hidden_size]
custom_state_dict[f"experts.{i}.gate_proj.weight"] = gate_weight
custom_state_dict[f"experts.{i}.up_proj.weight"] = up_weight
elif key == "experts.down_proj":
# Split stacked down into individual down per expert
# Shape: [n_experts, hidden_size, intermediate_size]
for i in range(n_experts):
custom_state_dict[f"experts.{i}.down_proj.weight"] = value[i]
else:
# Copy other keys as-is
custom_state_dict[key] = value
return custom_state_dict
def _convert_hf_full_model_state_dict_to_custom(hf_state_dict: dict, config) -> dict:
"""Convert full HF model state dict to custom format.
Handles MoE expert weight conversion for all MoE layers.
"""
custom_state_dict = {}
n_experts = config.n_routed_experts
for key, value in hf_state_dict.items():
# Check if this is an MoE expert weight
if ".mlp.experts.gate_up_proj" in key:
# Extract layer prefix (e.g., "model.layers.1.mlp.")
prefix = key.replace("experts.gate_up_proj", "")
intermediate_size = value.shape[1] // 2
for i in range(n_experts):
gate_up = value[i]
gate_weight = gate_up[:intermediate_size]
up_weight = gate_up[intermediate_size:]
custom_state_dict[f"{prefix}experts.{i}.gate_proj.weight"] = gate_weight
custom_state_dict[f"{prefix}experts.{i}.up_proj.weight"] = up_weight
elif ".mlp.experts.down_proj" in key:
prefix = key.replace("experts.down_proj", "")
for i in range(n_experts):
custom_state_dict[f"{prefix}experts.{i}.down_proj.weight"] = value[i]
else:
# Copy other keys as-is
custom_state_dict[key] = value
return custom_state_dict
def _create_hf_config():
"""Create HF config that matches our test config."""
HFConfig = _get_hf_config_class()
if HFConfig is None:
return None
config = HFConfig(
vocab_size=1000,
hidden_size=64,
intermediate_size=128,
num_hidden_layers=3,
num_attention_heads=4,
num_key_value_heads=4,
hidden_act="silu",
max_position_embeddings=512,
rms_norm_eps=1e-5,
# MLA params
q_lora_rank=32,
kv_lora_rank=32,
qk_nope_head_dim=8,
qk_rope_head_dim=8,
v_head_dim=16,
# MoE params
n_routed_experts=4,
n_shared_experts=1,
num_experts_per_tok=2,
moe_intermediate_size=32,
n_group=1,
topk_group=1,
routed_scaling_factor=1.0,
norm_topk_prob=True,
first_k_dense_replace=1,
# RoPE
rope_theta=10000.0,
rope_scaling=None,
# Other
attention_bias=False,
attention_dropout=0.0,
pad_token_id=0,
)
# Set internal attributes needed by HF's MoE implementation
# _experts_implementation tells HF which expert forward function to use
config._experts_implementation = "eager"
return config
@pytest.mark.parametrize("B,S", _BATCH_AND_SEQUENCE_TEST_CASES)
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@torch.no_grad()
def test_glm4_moe_lite_moe_numerical_equivalence(B, S, dtype):
"""Test MoE layer produces numerically equivalent output to HF implementation."""
HFMoE = _get_hf_moe_class()
if HFMoE is None:
pytest.skip("transformers doesn't have glm4_moe_lite (requires v5.0+)")
device = "cuda"
config = _create_small_config()
hf_config = _create_hf_config()
# Create HF MoE
hf_moe = HFMoE(hf_config)
hf_moe.to(device=device, dtype=dtype)
hf_moe.eval()
# Initialize weights with randn for reproducibility
# (HF allocates these with torch.empty, so they hold uninitialized values that can produce NaN)
hf_moe.gate.weight = torch.nn.Parameter(torch.randn_like(hf_moe.gate.weight))
hf_moe.experts.gate_up_proj = torch.nn.Parameter(torch.randn_like(hf_moe.experts.gate_up_proj))
hf_moe.experts.down_proj = torch.nn.Parameter(torch.randn_like(hf_moe.experts.down_proj))
# Create custom MoE and load converted weights
custom_moe = Glm4MoeLiteMoE(config)
custom_moe.to(device=device, dtype=dtype)
# Convert HF stacked expert weights to our per-expert format
hf_state_dict = hf_moe.state_dict()
# Debug: print state dict keys and shapes
print("\n=== HF MoE state_dict keys and shapes ===")
for k, v in hf_state_dict.items():
print(f" {k}: {v.shape}")
custom_state_dict = _convert_hf_moe_state_dict_to_custom(hf_state_dict, config.n_routed_experts)
print("\n=== Converted custom state_dict keys and shapes ===")
for k, v in custom_state_dict.items():
print(f" {k}: {v.shape}")
print("\n=== Expected custom MoE state_dict keys ===")
for k, v in custom_moe.state_dict().items():
print(f" {k}: {v.shape}")
custom_moe.load_state_dict(custom_state_dict)
custom_moe.eval()
# Sanity check: verify expert weights match after conversion
# HF has stacked weights: experts.gate_up_proj [n_experts, 2*intermediate, hidden]
# Our model has per-expert: experts.{i}.gate_proj.weight, experts.{i}.up_proj.weight
hf_gate_up = hf_moe.experts.gate_up_proj # [n_experts, 2*intermediate, hidden]
hf_down = hf_moe.experts.down_proj # [n_experts, hidden, intermediate]
intermediate_size = config.moe_intermediate_size
print(f"\n=== Debug: intermediate_size = {intermediate_size} ===")
print(f"hf_gate_up shape: {hf_gate_up.shape}")
print(f"hf_gate_up[0, :2, :2]: {hf_gate_up[0, :2, :2]}")
# Get the converted state dict values for comparison
converted_gate_0 = custom_state_dict["experts.0.gate_proj.weight"]
print(f"converted_gate_0 shape: {converted_gate_0.shape}")
print(f"converted_gate_0[:2, :2]: {converted_gate_0[:2, :2]}")
# After load_state_dict
loaded_gate_0 = custom_moe.experts[0].gate_proj.weight
print(f"loaded_gate_0 shape: {loaded_gate_0.shape}")
print(f"loaded_gate_0[:2, :2]: {loaded_gate_0[:2, :2]}")
for i in range(config.n_routed_experts):
# Check gate_proj
hf_gate = hf_gate_up[i, :intermediate_size, :] # [intermediate, hidden]
custom_gate = custom_moe.experts[i].gate_proj.weight
torch.testing.assert_close(
custom_gate, hf_gate, msg=f"Expert {i} gate_proj weights don't match"
)
# Check up_proj
hf_up = hf_gate_up[i, intermediate_size:, :] # [intermediate, hidden]
custom_up = custom_moe.experts[i].up_proj.weight
torch.testing.assert_close(custom_up, hf_up, msg=f"Expert {i} up_proj weights don't match")
# Check down_proj
hf_down_i = hf_down[i] # [hidden, intermediate]
custom_down = custom_moe.experts[i].down_proj.weight
torch.testing.assert_close(
custom_down, hf_down_i, msg=f"Expert {i} down_proj weights don't match"
)
# Also verify gate weights match
torch.testing.assert_close(
custom_moe.gate.weight, hf_moe.gate.weight, msg="Gate weights don't match"
)
# Create input
H = config.hidden_size
x = torch.randn(B, S, H, device=device, dtype=dtype)
# Run both
hf_out = hf_moe(x)
custom_out = custom_moe(x)
# Handle tuple output from HF (output, router_logits)
if isinstance(hf_out, tuple):
hf_out = hf_out[0]
# Compare
rtol, atol = 0.05, 0.05
torch.testing.assert_close(custom_out, hf_out, rtol=rtol, atol=atol)
@pytest.mark.parametrize("B,S", _BATCH_AND_SEQUENCE_TEST_CASES)
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@torch.no_grad()
def test_glm4_moe_lite_mlp_numerical_equivalence(B, S, dtype):
"""Test MLP layer produces numerically equivalent output to HF implementation."""
HFMLP = _get_hf_mlp_class()
if HFMLP is None:
pytest.skip("transformers doesn't have glm4_moe_lite (requires v5.0+)")
device = "cuda"
config = _create_small_config()
hf_config = _create_hf_config()
# Create HF MLP
hf_mlp = HFMLP(hf_config)
hf_mlp.to(device=device, dtype=dtype)
hf_mlp.eval()
# Create custom MLP and load same weights
custom_mlp = Glm4MoeLiteMLP(config)
custom_mlp.to(device=device, dtype=dtype)
custom_mlp.load_state_dict(hf_mlp.state_dict())
custom_mlp.eval()
# Create input
H = config.hidden_size
x = torch.randn(B, S, H, device=device, dtype=dtype)
# Run both
hf_out = hf_mlp(x)
custom_out = custom_mlp(x)
# Compare
rtol, atol = 1e-3, 1e-3
torch.testing.assert_close(custom_out, hf_out, rtol=rtol, atol=atol)
@pytest.mark.parametrize("B,S", _BATCH_AND_SEQUENCE_TEST_CASES)
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@torch.no_grad()
def test_glm4_moe_lite_full_model_numerical_equivalence(B, S, dtype):
"""Test full model produces numerically equivalent output to HF implementation."""
HFModel = _get_hf_model_class()
if HFModel is None:
pytest.skip("transformers doesn't have glm4_moe_lite (requires v5.0+)")
device = "cuda"
config = _create_small_config()
hf_config = _create_hf_config()
# Create HF model
hf_model = HFModel(hf_config)
hf_model.to(device=device, dtype=dtype)
hf_model.eval()
# Initialize all gate weights for reproducibility
for module in hf_model.modules():
if hasattr(module, "gate") and hasattr(module.gate, "weight"):
module.gate.weight = torch.nn.Parameter(torch.randn_like(module.gate.weight))
# Create custom model and load converted weights
custom_model = Glm4MoeLiteForCausalLM(config)
custom_model.to(device=device, dtype=dtype)
# Convert HF stacked expert weights to our per-expert format
hf_state_dict = hf_model.state_dict()
custom_state_dict = _convert_hf_full_model_state_dict_to_custom(hf_state_dict, config)
custom_model.load_state_dict(custom_state_dict)
custom_model.eval()
# Create input
input_ids = torch.randint(0, config.vocab_size, (B, S), device=device)
position_ids = torch.arange(S, device=device).unsqueeze(0).expand(B, -1)
# Run both
hf_out = hf_model(input_ids=input_ids, position_ids=position_ids)
custom_out = custom_model(input_ids=input_ids, position_ids=position_ids)
# Compare logits - cast to same dtype for comparison
# (HF model may output float32 for numerical stability in lm_head)
rtol, atol = 0.05, 0.05
torch.testing.assert_close(
custom_out.logits.float(),
hf_out.logits.float(),
rtol=rtol,
atol=atol,
)
View File
@ -1,97 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import tensorrt_llm._torch.auto_deploy # noqa: F401
def naive_attn(q_nope, q_pe, compressed_kv, k_pe, wkv_b, softmax_scale):
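# Reference path: up-project the compressed KV back to full per-head key/value states and run
# standard scaled dot-product attention.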
bsz = q_nope.shape[0]
q_len = q_nope.shape[2]
num_heads = q_nope.shape[1]
qk_nope_head_dim = q_nope.shape[-1]
v_head_dim = wkv_b.weight.shape[0] // num_heads - qk_nope_head_dim  # per-head output dim of wkv_b minus the k_nope part
qk_head_dim = qk_nope_head_dim + q_pe.shape[-1]
k_pe = k_pe.view(bsz, q_len, 1, q_pe.shape[-1]).transpose(1, 2)
# Up project compressed_kv
kv = (
wkv_b(compressed_kv)
.view(bsz, q_len, num_heads, qk_nope_head_dim + v_head_dim)
.transpose(1, 2)
)
k_nope, value_states = torch.split(kv, [qk_nope_head_dim, v_head_dim], dim=-1)
query_states = k_pe.new_empty(bsz, num_heads, q_len, qk_head_dim)
query_states[:, :, :, :qk_nope_head_dim] = q_nope
query_states[:, :, :, qk_nope_head_dim:] = q_pe
key_states = k_pe.new_empty(bsz, num_heads, q_len, qk_head_dim)
key_states[:, :, :, :qk_nope_head_dim] = k_nope
key_states[:, :, :, qk_nope_head_dim:] = k_pe
x = F.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=None,
is_causal=False,
dropout_p=0.0,
scale=softmax_scale,
).transpose(1, 2)
return x
def mla_attn(q_nope, q_pe, compressed_kv, k_pe, wkv_b, softmax_scale):
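# MLA path: absorb the wkv_b projection into the query so attention runs directly in the
# compressed kv_lora_rank latent space, then up-project the attention output to v_head_dim.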
num_heads = q_nope.shape[1]
qk_nope_head_dim = q_nope.shape[-1]
v_head_dim = wkv_b.weight.shape[0] // num_heads - qk_nope_head_dim  # per-head output dim of wkv_b minus the k_nope part
kv_lora_rank = compressed_kv.shape[-1]
# Project q_nope into the kv_lora_rank latent space by absorbing the k_nope rows of wkv_b
wkv_b_weight = wkv_b.weight.view(num_heads, -1, kv_lora_rank)
q_nope_proj = torch.einsum("bhsd,hdc->bhsc", q_nope, wkv_b_weight[:, :qk_nope_head_dim])
# MLA ref operation
x = torch.ops.auto_deploy.torch_attention_deepseek_mla(
q_nope_proj, q_pe, compressed_kv, k_pe, None, softmax_scale
)
# Up-project the latent attention output back to v_head_dim using the value rows of wkv_b
x = torch.einsum("bshc,hdc->bshd", x, wkv_b_weight[:, -v_head_dim:])
return x
def test_attn():
# Define test configurations
kv_lora_rank = 4
bsz = 2
q_len = 6
v_head_dim = 2
qk_nope_head_dim = 2
qk_rope_head_dim = 1
qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
num_heads = 4
softmax_scale = qk_head_dim**-0.5
# Generate inputs
q_nope = torch.randn(bsz, num_heads, q_len, qk_nope_head_dim)
q_pe = torch.randn(bsz, num_heads, q_len, qk_rope_head_dim)
compressed_kv = torch.randn(bsz, q_len, kv_lora_rank)
k_pe = torch.randn(bsz, q_len, qk_rope_head_dim)
# Define w_kv_b projection matrix
wkv_b = nn.Linear(
kv_lora_rank, num_heads * (qk_head_dim - qk_rope_head_dim + v_head_dim), bias=False
)
# Compute naive attention
out_naive = naive_attn(q_nope, q_pe, compressed_kv, k_pe, wkv_b, softmax_scale)
# Compute MLA attention
out_mla = mla_attn(q_nope, q_pe, compressed_kv, k_pe, wkv_b, softmax_scale)
# Check if the two outputs are close
assert torch.allclose(out_naive, out_mla, rtol=1e-5, atol=1e-5)