Mirror of https://github.com/NVIDIA/TensorRT-LLM.git
Commit a2fb5afecf (parent d50f010fa9)
@ -0,0 +1,348 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Deploying GLM-4.7-Flash with TensorRT-LLM\n",
|
||||
"\n",
|
||||
"This notebook walks you through deploying the `zai-org/GLM-4.7-Flash` model using TensorRT-LLM.\n",
|
||||
"\n",
|
||||
"[TensorRT-LLM](https://nvidia.github.io/TensorRT-LLM/) is NVIDIA's open-source library for accelerating and optimizing LLM inference on NVIDIA GPUs. Support for GLM-4.7-Flash is enabled through the AutoDeploy workflow. More details about AutoDeploy can be found [here](https://nvidia.github.io/TensorRT-LLM/torch/auto_deploy/auto-deploy.html).\n",
|
||||
"\n",
|
||||
"**Model Resources:**\n",
|
||||
"- [HuggingFace Model Card](https://huggingface.co/zai-org/GLM-4.7-Flash)\n",
|
||||
"- [Technical Blog](https://z.ai/blog/glm-4.7)\n",
|
||||
"- [Technical Report (GLM-4.5)](https://arxiv.org/abs/2508.06471)\n",
|
||||
"- [Z.ai API Platform](https://docs.z.ai/guides/llm/glm-4.7)\n",
|
||||
"\n",
|
||||
"**Model Highlights:**\n",
|
||||
"- 30B-A3B Mixture of Experts (MoE) architecture\n",
|
||||
"- 131,072 token context length\n",
|
||||
"- Tool calling support\n",
|
||||
"- MIT License\n",
|
||||
"\n",
|
||||
"**Prerequisites:**\n",
|
||||
"- NVIDIA GPU with recent drivers (≥ 64 GB VRAM for BF16) and CUDA 12.x\n",
|
||||
"- Python 3.10+\n",
|
||||
"- TensorRT-LLM ([container](https://catalog.ngc.nvidia.com/orgs/nvidia/teams/tensorrt-llm/containers/release) or pip install)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Prerequisites & Environment\n",
|
||||
"\n",
|
||||
"Set up a containerized environment for TensorRT-LLM by running the following command in a terminal:\n",
|
||||
"\n",
|
||||
"```shell\n",
|
||||
"docker run --rm -it --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 --gpus=all -p 8000:8000 nvcr.io/nvidia/tensorrt-llm/release:1.3.0rc1\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"You now have TensorRT-LLM set up!"
|
||||
]
|
||||
},
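{
"cell_type": "markdown",
"metadata": {},
"source": [
"Optionally, mount your Hugging Face cache into the container so downloaded model weights persist across runs. This is a suggested variant of the command above; the host path assumes the default Hugging Face cache location, and the container-side path assumes the image runs as root.\n",
"\n",
"```shell\n",
"docker run --rm -it --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 --gpus=all -p 8000:8000 \\\n",
"  -v ~/.cache/huggingface:/root/.cache/huggingface \\\n",
"  nvcr.io/nvidia/tensorrt-llm/release:1.3.0rc1\n",
"```"
]
},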
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# If pip not found\n",
|
||||
"!python -m ensurepip --default-pip"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%pip install torch openai"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Verify GPU\n",
|
||||
"\n",
|
||||
"Check that CUDA is available and the GPU is detected correctly."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Python: 3.12.3 (main, Nov 6 2025, 13:44:16) [GCC 13.3.0]\n",
|
||||
"CUDA available: True\n",
|
||||
"Num GPUs: 8\n",
|
||||
"GPU[0]: NVIDIA H100 80GB HBM3\n",
|
||||
"GPU[1]: NVIDIA H100 80GB HBM3\n",
|
||||
"GPU[2]: NVIDIA H100 80GB HBM3\n",
|
||||
"GPU[3]: NVIDIA H100 80GB HBM3\n",
|
||||
"GPU[4]: NVIDIA H100 80GB HBM3\n",
|
||||
"GPU[5]: NVIDIA H100 80GB HBM3\n",
|
||||
"GPU[6]: NVIDIA H100 80GB HBM3\n",
|
||||
"GPU[7]: NVIDIA H100 80GB HBM3\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Environment check\n",
|
||||
"import sys\n",
|
||||
"\n",
|
||||
"import torch\n",
|
||||
"\n",
|
||||
"print(f\"Python: {sys.version}\")\n",
|
||||
"print(f\"CUDA available: {torch.cuda.is_available()}\")\n",
|
||||
"print(f\"Num GPUs: {torch.cuda.device_count()}\")\n",
|
||||
"\n",
|
||||
"if torch.cuda.is_available():\n",
|
||||
" for i in range(torch.cuda.device_count()):\n",
|
||||
" print(f\"GPU[{i}]: {torch.cuda.get_device_name(i)}\")"
|
||||
]
|
||||
},
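{
"cell_type": "markdown",
"metadata": {},
"source": [
"Optionally, compare each GPU's memory against the ≥ 64 GB guideline for BF16 from the prerequisites. This is a rough check; actual requirements depend on precision, parallelism, and KV-cache settings."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Rough VRAM check against the >= 64 GB BF16 guideline from the prerequisites.\n",
"import torch\n",
"\n",
"if torch.cuda.is_available():\n",
"    for i in range(torch.cuda.device_count()):\n",
"        total_gib = torch.cuda.get_device_properties(i).total_memory / 1024**3\n",
"        print(f\"GPU[{i}]: {total_gib:.1f} GiB\")"
]
},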
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## OpenAI-Compatible Server\n",
|
||||
"\n",
|
||||
"Start a local OpenAI-compatible server with TensorRT-LLM via the terminal, within the running docker container.\n",
|
||||
"\n",
|
||||
"Ensure that the following commands are executed from the docker terminal.\n",
|
||||
"\n",
|
||||
"Start with the GLM 4.7 Flash Yaml here: `examples/auto_deploy/model_registry/configs/glm-4.7-flash.yaml`"
|
||||
]
|
||||
},
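{
"cell_type": "markdown",
"metadata": {},
"source": [
"For reference, the GLM-4.7-Flash AutoDeploy config shipped with this example contains settings along these lines (check the file in your checkout, since values may change between releases):\n",
"\n",
"```yaml\n",
"compile_backend: torch-cudagraph\n",
"max_batch_size: 64\n",
"max_seq_len: 4096\n",
"enable_chunked_prefill: true\n",
"cuda_graph_batch_sizes: [1, 2, 4, 8, 16, 32, 64]\n",
"transforms:\n",
"  fuse_nvfp4_moe:\n",
"    allow_different_input_scales: true\n",
"```"
]
},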
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Load the Model\n",
|
||||
"\n",
|
||||
"Launch the TensorRT-LLM server with GLM-4.7-Flash:\n",
|
||||
"\n",
|
||||
"```shell\n",
|
||||
"trtllm-serve \"zai-org/GLM-4.7-Flash\" \\\n",
|
||||
" --host 0.0.0.0 \\\n",
|
||||
" --port 8000 \\\n",
|
||||
" --backend _autodeploy \\\n",
|
||||
" --trust_remote_code \\\n",
|
||||
" --extra_llm_api_options examples/auto_deploy/model_registry/configs/glm-4.7-flash.yaml\n",
|
||||
"```"
|
||||
]
|
||||
},
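{
"cell_type": "markdown",
"metadata": {},
"source": [
"Model loading can take a while. The cell below is a minimal readiness poll, assuming the server exposes the standard OpenAI-compatible `/v1/models` route on the host/port used above."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Poll the server until it responds (assumes the default host/port from the serve command).\n",
"import time\n",
"import urllib.request\n",
"\n",
"URL = \"http://localhost:8000/v1/models\"\n",
"for _ in range(60):\n",
"    try:\n",
"        with urllib.request.urlopen(URL) as resp:\n",
"            print(\"Server is up, HTTP\", resp.status)\n",
"            break\n",
"    except Exception:\n",
"        time.sleep(10)\n",
"else:\n",
"    print(\"Server did not come up; check the trtllm-serve logs.\")"
]
},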
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Your server is now running!"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Use the API\n",
|
||||
"\n",
|
||||
"Use the OpenAI-compatible client to send requests to the TensorRT-LLM server."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from openai import OpenAI\n",
|
||||
"\n",
|
||||
"# Setup client\n",
|
||||
"BASE_URL = \"http://0.0.0.0:8000/v1\"\n",
|
||||
"API_KEY = \"null\"\n",
|
||||
"client = OpenAI(base_url=BASE_URL, api_key=API_KEY)\n",
|
||||
"\n",
|
||||
"MODEL_ID = \"zai-org/GLM-4.7-Flash\""
|
||||
]
|
||||
},
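{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check, list the models the server reports. This assumes the server implements the standard `/v1/models` route."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# List served models (sanity check that MODEL_ID is available).\n",
"for model in client.models.list():\n",
"    print(model.id)"
]
},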
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Chat Completion Example\n",
|
||||
"==================================================\n",
|
||||
"Response:\n",
|
||||
"1. **Analyze the Request:** The user wants to know 15% of 85 and wants to see the reasoning behind the calculation.\n",
|
||||
"\n",
|
||||
"2. **Identify the Core Task:** Calculate $15\\% \\times 85$.\n",
|
||||
"\n",
|
||||
"3. **Determine the Mathematical Approach:** There are several ways to solve this:\n",
|
||||
" * *Method 1: Fraction multiplication.* Convert 15% to a fraction ($\\frac{15}{100}$), then multiply by 85.\n",
|
||||
" * *Method 2: Decimal multiplication.* Convert 15% to a decimal ($0.15$), then multiply by 85.\n",
|
||||
" * *Method 3: Decomposition (Breaking it down).* $15\\% = 10\\% + 5\\%$.\n",
|
||||
" * $10\\%$ of $85 = 8.5$\n",
|
||||
" * $5\\%$ of $85 = \\frac{8.5}{2} = 4.25$\n",
|
||||
" * Sum: $8.5 + 4.25 = 12.75$\n",
|
||||
"\n",
|
||||
"4. **Select the Best Approach for Explanation:** Method 3 is often easiest for a general audience to follow step-by-step because it avoids dealing with decimals until the end or simplifies large multiplications. Method 2 is the most direct standard school method. I will use Method 3 (Splitting 15% into 10% and 5%) as the primary reasoning because it is intuitive, but I might briefly mention the standard formula ($\\frac{\\text{percent}}{100} \\times \\text{number}$).\n",
|
||||
"\n",
|
||||
"5. **Execute the Calculation (Method 3):**\n",
|
||||
" * Step 1: Find 10% of 85.\n",
|
||||
" * Moving the decimal point one place to the left: 8.5.\n",
|
||||
" * Step 2: Find 5% of 85.\n",
|
||||
" * Since 5% is half of 10%, take half of 8.5.\n",
|
||||
" * $8.5 / 2 = 4.25$.\n",
|
||||
" * Step 3: Add them together.\n",
|
||||
" * $8.5 + 4.25$.\n",
|
||||
" * $8.5 + 4 = 12.5$.\n",
|
||||
" * $12.5 + 0\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Basic chat completion\n",
|
||||
"print(\"Chat Completion Example\")\n",
|
||||
"print(\"=\" * 50)\n",
|
||||
"\n",
|
||||
"response = client.chat.completions.create(\n",
|
||||
" model=MODEL_ID,\n",
|
||||
" messages=[\n",
|
||||
" {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n",
|
||||
" {\"role\": \"user\", \"content\": \"What is 15% of 85? Show your reasoning.\"},\n",
|
||||
" ],\n",
|
||||
" temperature=1,\n",
|
||||
" top_p=0.95,\n",
|
||||
" max_tokens=512,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"print(\"Response:\")\n",
|
||||
"print(response.choices[0].message.content)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Streaming response:\n",
|
||||
"==================================================\n",
|
||||
"1. **Analyze the Request:** The user is asking for the \"first 5 prime numbers\".\n",
|
||||
"\n",
|
||||
"2. **Define \"Prime Number\":** A prime number is a natural number greater than 1 that is not a product of two smaller natural numbers. In other words, it has exactly two distinct positive divisors: 1 and itself.\n",
|
||||
"\n",
|
||||
"3. **Identify the First Numbers:**\n",
|
||||
" * Start checking from 1 (exclusive).\n",
|
||||
" * Check 2: Divisors are 1 and 2. Prime. (1st)\n",
|
||||
" * Check 3: Divisors are 1 and 3. Prime. (2nd)\n",
|
||||
" * Check 4: Divisors are 1, 2, 4. Not prime (2 * 2).\n",
|
||||
" * Check 5: Divisors are 1 and 5. Prime. (3rd)\n",
|
||||
" * Check 6: Divisors are 1, 2, 3, 6. Not prime.\n",
|
||||
" * Check 7: Divisors are 1 and 7. Prime. (4th)\n",
|
||||
" * Check 8: Divisors are 1, 2, 4, 8. Not prime.\n",
|
||||
" * Check 9: Divisors are 1, 3, 9. Not prime.\n",
|
||||
" * Check 10: Divisors are 1, 2, 5, 10. Not prime.\n",
|
||||
" * Check 11: Divisors are 1 and 11. Prime. (5th)\n",
|
||||
"\n",
|
||||
"4. **Compile the List:** 2, 3, 5, 7, 11.\n",
|
||||
"\n",
|
||||
"5. **Formulate the Output:** Present the list clearly.\n",
|
||||
"\n",
|
||||
"6. **Final Review:** Does this answer the user's prompt accurately? Yes.</think>The first 5 prime numbers are **2, 3, 5, 7, and 11**."
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Streaming chat completion\n",
|
||||
"print(\"Streaming response:\")\n",
|
||||
"print(\"=\" * 50)\n",
|
||||
"\n",
|
||||
"stream = client.chat.completions.create(\n",
|
||||
" model=MODEL_ID,\n",
|
||||
" messages=[\n",
|
||||
" {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n",
|
||||
" {\"role\": \"user\", \"content\": \"What are the first 5 prime numbers?\"},\n",
|
||||
" ],\n",
|
||||
" temperature=0.7,\n",
|
||||
" max_tokens=1024,\n",
|
||||
" stream=True,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"for chunk in stream:\n",
|
||||
" if chunk.choices[0].delta.content:\n",
|
||||
" print(chunk.choices[0].delta.content, end=\"\", flush=True)"
|
||||
]
|
||||
},
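{
"cell_type": "markdown",
"metadata": {},
"source": [
"The streamed output above shows the model emitting its reasoning first and closing it with a `</think>` marker before the final answer. If you only want the final answer, a minimal post-processing sketch (assuming that marker format) looks like this:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Split reasoning from the final answer, assuming the \"</think>\" marker seen above.\n",
"def split_reasoning(text: str):\n",
"    reasoning, sep, answer = text.partition(\"</think>\")\n",
"    if not sep:  # no marker found; treat everything as the answer\n",
"        return \"\", text.strip()\n",
"    return reasoning.strip(), answer.strip()\n",
"\n",
"reasoning, answer = split_reasoning(response.choices[0].message.content)\n",
"print(\"Final answer:\", answer)"
]
},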
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Evaluation Parameters\n",
|
||||
"\n",
|
||||
"For optimal results, use the following parameters based on your task:\n",
|
||||
"\n",
|
||||
"**Default Settings (Most Tasks)**\n",
|
||||
"- `temperature`: 1.0\n",
|
||||
"- `top_p`: 0.95\n",
|
||||
"- `max_tokens`: 131072\n",
|
||||
"\n",
|
||||
"**Agentic Tasks (SWE-bench, Terminal Bench)**\n",
|
||||
"- `temperature`: 0.7\n",
|
||||
"- `top_p`: 1.0\n",
|
||||
"- `max_tokens`: 16384\n",
|
||||
"\n",
|
||||
"**Deterministic Tasks**\n",
|
||||
"- `temperature`: 0\n",
|
||||
"- `max_tokens`: 16384"
|
||||
]
|
||||
},
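{
"cell_type": "markdown",
"metadata": {},
"source": [
"The cell below wraps the presets above in a small helper. It is an illustrative sketch; cap `max_tokens` to whatever `max_seq_len` your server was launched with."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sampling presets from the table above (illustrative helper).\n",
"PRESETS = {\n",
"    \"default\": {\"temperature\": 1.0, \"top_p\": 0.95, \"max_tokens\": 131072},\n",
"    \"agentic\": {\"temperature\": 0.7, \"top_p\": 1.0, \"max_tokens\": 16384},\n",
"    \"deterministic\": {\"temperature\": 0.0, \"max_tokens\": 16384},\n",
"}\n",
"\n",
"# Note: reduce max_tokens if it exceeds the server's configured max_seq_len.\n",
"params = dict(PRESETS[\"deterministic\"], max_tokens=1024)\n",
"\n",
"response = client.chat.completions.create(\n",
"    model=MODEL_ID,\n",
"    messages=[{\"role\": \"user\", \"content\": \"In one sentence, what is a Mixture of Experts model?\"}],\n",
"    **params,\n",
")\n",
"print(response.choices[0].message.content)"
]
},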
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Additional Resources\n",
|
||||
"\n",
|
||||
"- [TensorRT-LLM Documentation](https://nvidia.github.io/TensorRT-LLM/)\n",
|
||||
"- [AutoDeploy Guide](https://nvidia.github.io/TensorRT-LLM/torch/auto_deploy/auto-deploy.html)\n",
|
||||
"- [GLM-4.7-Flash on HuggingFace](https://huggingface.co/zai-org/GLM-4.7-Flash)\n",
|
||||
"- [Z.ai Discord Community](https://discord.gg/QR7SARHRxK)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.12.3"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
||||
@ -0,0 +1,8 @@
|
||||
compile_backend: torch-cudagraph
|
||||
max_batch_size: 64
|
||||
max_seq_len: 4096
|
||||
enable_chunked_prefill: true
|
||||
cuda_graph_batch_sizes: [1, 2, 4, 8, 16, 32, 64]
|
||||
transforms:
|
||||
fuse_nvfp4_moe:
|
||||
allow_different_input_scales: true
|
||||
@ -223,3 +223,5 @@ models:
|
||||
yaml_extra: ['dashboard_default.yaml', 'world_size_8.yaml', 'multimodal.yaml', 'llama4_maverick_lite.yaml']
|
||||
- name: nvidia/NVIDIA-Nemotron-3-Super-120B-BF16-BF16KV-010726
|
||||
yaml_extra: ['dashboard_default.yaml', 'world_size_4.yaml','super_v3.yaml']
|
||||
- name: zai-org/GLM-4.7-Flash
|
||||
yaml_extra: ['glm-4.7-flash.yaml']
|
||||
|
||||
@ -162,15 +162,16 @@ transforms:
|
||||
visualize_namespace:
|
||||
stage: visualize
|
||||
enabled: false
|
||||
############################################################################################
|
||||
###########################################################################################
|
||||
# SWITCH TO CACHED+FLATTENED ATTENTION + INITIALIZE CACHES
|
||||
############################################################################################
|
||||
###########################################################################################
|
||||
insert_cached_attention:
|
||||
stage: cache_init
|
||||
backend: flashinfer
|
||||
insert_cached_mla_attention:
|
||||
stage: cache_init
|
||||
backend: MultiHeadLatentAttention
|
||||
requires_shape_prop: true
|
||||
backend: flashinfer_mla
|
||||
insert_cached_ssm_attention:
|
||||
stage: cache_init
|
||||
backend: triton_ssm
|
||||
|
||||
@ -5,38 +5,3 @@ All AutoDeploy custom operators follow the following naming convention:
|
||||
`torch.ops.auto_deploy.<kernel_backend>_<op_category>_<op_name>`
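
For example, `torch.ops.auto_deploy.torch_quant_fp8_linear` combines the `torch` backend, the `quant` op category, and the `fp8_linear` op name.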
|
||||
|
||||
The table below lists the operators ordered by their backend.
|
||||
|
||||
### Available Custom Operators
|
||||
|
||||
| Operator Name | Description |
|
||||
|--------------|-------------|
|
||||
| `torch.ops.auto_deploy.flashinfer_attention_mha_with_cache` | FlashInfer attention with KV cache support |
|
||||
| `torch.ops.auto_deploy.flashinfer_rope` | FlashInfer RoPE (Rotary Position Embedding) implementation |
|
||||
| `torch.ops.auto_deploy.torch_attention_deepseek_fused_mla` | DeepSeek fused MLA (Multi-head Latent Attention) |
|
||||
| `torch.ops.auto_deploy.torch_attention_deepseek_mla` | DeepSeek MLA implementation |
|
||||
| `torch.ops.auto_deploy.torch_attention` | Grouped SDPA implementation with `bsnd` and `bnsd` layouts supported |
|
||||
| `torch.ops.auto_deploy.torch_attention_repeat_kv` | KV repetition for attention |
|
||||
| `torch.ops.auto_deploy.torch_attention_sdpa` | Standard SDPA implementation |
|
||||
| `torch.ops.auto_deploy.torch_dist_all_gather` | Distributed all-gather operation (PyTorch backend, demollm mode) |
|
||||
| `torch.ops.auto_deploy.torch_dist_all_reduce` | Distributed all-reduce operation (PyTorch backend, demollm mode) |
|
||||
| `torch.ops.auto_deploy.torch_linear_simple` | Simple linear layer implementation |
|
||||
| `torch.ops.auto_deploy.torch_moe` | Mixture of Experts implementation |
|
||||
| `torch.ops.auto_deploy.torch_moe_fused` | Fused Mixture of Experts implementation |
|
||||
| `torch.ops.auto_deploy.torch_quant_fn` | Generic quantization function that scales, rounds, and clamps input values |
|
||||
| `torch.ops.auto_deploy.torch_quant_nvfp4_linear` | FP4 quantized linear layer |
|
||||
| `torch.ops.auto_deploy.torch_quant_fp8_linear` | FP8 quantized linear layer |
|
||||
| `torch.ops.auto_deploy.torch_rope_with_complex_freqs` | RoPE with complex frequencies |
|
||||
| `torch.ops.auto_deploy.torch_rope_with_explicit_cos_sin` | RoPE with explicit cosine/sine |
|
||||
| `torch.ops.auto_deploy.torch_rope_with_qk_interleaving` | RoPE with QK interleaving |
|
||||
| `torch.ops.auto_deploy.triton_attention_fused_flattened_mha_with_cache` | Triton fused flattened MHA with cache |
|
||||
| `torch.ops.auto_deploy.triton_attention_fused_flattened_mha_with_cache_rope_fusion` | Triton fused flattened MHA with cache and RoPE fusion |
|
||||
| `torch.ops.auto_deploy.triton_attention_fused_mha_with_cache` | Triton fused MHA with cache |
|
||||
| `torch.ops.auto_deploy.triton_attention_fused_mha_with_paged_cache` | Triton fused MHA with paged cache |
|
||||
| `torch.ops.auto_deploy.triton_attention_flattened_mha_with_cache` | Triton flattened MHA with cache |
|
||||
| `torch.ops.auto_deploy.triton_attention_fused_flattened_mla_with_cache` | Triton fused flattened Multi-head Latent Attention with cache support |
|
||||
| `torch.ops.auto_deploy.triton_rope_on_flattened_inputs` | Triton RoPE on flattened inputs |
|
||||
| `torch.ops.auto_deploy.triton_rope_with_input_pos` | Triton RoPE with input positions |
|
||||
| `torch.ops.auto_deploy.trtllm_moe_fused` | TensorRT LLM fused MoE implementation |
|
||||
| `torch.ops.auto_deploy.trtllm_dist_all_gather` | Distributed all-gather operation (TRT-LLM backend, MPI mode) |
|
||||
| `torch.ops.auto_deploy.trtllm_dist_all_reduce` | Distributed all-reduce operation (TRT-LLM backend, MPI mode) |
|
||||
| `torch.ops.dist.trtllm_fused_allreduce_residual_rmsnorm` | Fused all-reduce + residual add + RMSNorm (TRT-LLM backend, MPI mode) |
|
||||
|
||||
@ -19,7 +19,6 @@ import math
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
@ -242,299 +241,3 @@ def torch_attention_fake(
|
||||
layout: str = "bnsd",
|
||||
):
|
||||
return query.new_empty(*query.shape[:-1], value.shape[-1]).contiguous()
|
||||
|
||||
|
||||
def update_kv_cache(
|
||||
key_states: torch.Tensor,
|
||||
value_states: torch.Tensor,
|
||||
k_cache: torch.Tensor,
|
||||
v_cache: torch.Tensor,
|
||||
seq_len: torch.Tensor, # metadata
|
||||
input_pos: torch.Tensor, # metadata
|
||||
slot_idx: torch.Tensor,
|
||||
seq_start: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Reference implementation for update kv cache function. Assumes KV cache layout to be [B,S,N,D].
|
||||
This function can be used to build reference attention implementations that use KV cache.
|
||||
"""
|
||||
|
||||
for idx in range(seq_len.shape[0]):
|
||||
k_cache[slot_idx[idx], input_pos[idx] : input_pos[idx] + seq_len[idx], :, :] = key_states[
|
||||
seq_start[idx] : seq_start[idx] + seq_len[idx], ...
|
||||
]
|
||||
v_cache[slot_idx[idx], input_pos[idx] : input_pos[idx] + seq_len[idx], :, :] = value_states[
|
||||
seq_start[idx] : seq_start[idx] + seq_len[idx], ...
|
||||
]
|
||||
|
||||
|
||||
@torch.library.custom_op("auto_deploy::torch_attention_fused_mla_ref", mutates_args=())
|
||||
def fused_mla_ref(
|
||||
q_nope: torch.Tensor,
|
||||
q_pe: torch.Tensor,
|
||||
kv: torch.Tensor,
|
||||
k_pe: torch.Tensor,
|
||||
seq_len: torch.Tensor, # metadata
|
||||
input_pos: torch.Tensor, # metadata
|
||||
cache_loc: torch.Tensor,
|
||||
seq_start: torch.Tensor,
|
||||
k_cache: torch.Tensor, # caches
|
||||
v_cache: torch.Tensor, # caches
|
||||
freqs_cis: torch.Tensor,
|
||||
softmax_scale: Optional[float] = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Reference implementation for Fused MLA with KV cache support.
|
||||
This implementation flattens the inputs and can be used as a reference to debug the triton kernels.
|
||||
"""
|
||||
# Compute parameters
|
||||
bs, num_heads, q_len, qk_nope_head_dim = q_nope.shape
|
||||
qk_rope_head_dim = q_pe.shape[-1]
|
||||
v_head_dim = kv.shape[-1] - qk_nope_head_dim
|
||||
q_head_dim = qk_nope_head_dim + qk_rope_head_dim
|
||||
|
||||
# Flatten inputs
|
||||
bs_view = (bs * q_len,)
|
||||
|
||||
k_nope, value_states = torch.split(kv, [qk_nope_head_dim, v_head_dim], dim=-1)
|
||||
|
||||
q_nope = q_nope.transpose(1, 2).view(*bs_view, num_heads, qk_nope_head_dim).contiguous()
|
||||
q_pe = q_pe.transpose(1, 2).clone().view(*bs_view, num_heads, qk_rope_head_dim).contiguous()
|
||||
k_nope = k_nope.clone().transpose(1, 2).view(*bs_view, num_heads, qk_nope_head_dim).contiguous()
|
||||
k_pe = k_pe.clone().transpose(1, 2).view(*bs_view, -1, qk_rope_head_dim).contiguous()
|
||||
value_states = value_states.transpose(1, 2).view(*bs_view, -1, v_head_dim).contiguous()
|
||||
|
||||
if freqs_cis is not None:
|
||||
cos_base = freqs_cis[0, ...]
|
||||
sin_base = freqs_cis[1, ...]
|
||||
for i in range(seq_len.shape[0]):
|
||||
start = seq_start[i]
|
||||
length = seq_len[i]
|
||||
if q_len == 1:
|
||||
idx = (input_pos[i] + length - 1).item()
|
||||
pos_ids = torch.tensor(idx, device=cos_base.device)
|
||||
else:
|
||||
pos_ids = torch.arange(input_pos[i], input_pos[i] + length, device=cos_base.device)
|
||||
|
||||
cos = cos_base[pos_ids] # [..., 1, head_dim]
|
||||
sin = sin_base[pos_ids]
|
||||
q_slice = q_pe[start : start + length]
|
||||
k_slice = k_pe[start : start + length]
|
||||
|
||||
q_rot, k_rot = torch.ops.auto_deploy.torch_rope_with_qk_interleaving(
|
||||
q_slice,
|
||||
k_slice,
|
||||
cos,
|
||||
sin,
|
||||
-2,
|
||||
)
|
||||
|
||||
q_pe[start : start + length] = q_rot
|
||||
k_pe[start : start + length] = k_rot
|
||||
|
||||
query_states = k_pe.new_empty(*bs_view, num_heads, q_head_dim) # [b*s,n,d]
|
||||
query_states[..., :qk_nope_head_dim] = q_nope
|
||||
query_states[..., qk_nope_head_dim:] = q_pe
|
||||
|
||||
key_states = k_pe.new_empty(*bs_view, num_heads, q_head_dim)
|
||||
key_states[..., :qk_nope_head_dim] = k_nope
|
||||
key_states[..., qk_nope_head_dim:] = k_pe
|
||||
|
||||
# Update KV cache
|
||||
update_kv_cache(
|
||||
key_states, value_states, k_cache, v_cache, seq_len, input_pos, cache_loc, seq_start
|
||||
)
|
||||
|
||||
# Compute attention
|
||||
attn_outputs = []
|
||||
for idx in range(seq_len.shape[0]):
|
||||
# Get inputs from KV cache
|
||||
k = k_cache[cache_loc[idx], : input_pos[idx] + seq_len[idx], :, :] # [kv_seq_len, n, d]
|
||||
v = v_cache[cache_loc[idx], : input_pos[idx] + seq_len[idx], :, :] # [kv_seq_len, n, d]
|
||||
# Generate attention mask
|
||||
if q_len == 1:
|
||||
# Generate phase - single token attention mask
|
||||
attn_mask = torch.zeros(
|
||||
1, input_pos[idx] + 1, device=query_states.device, dtype=query_states.dtype
|
||||
)
|
||||
else:
|
||||
# Context phase - causal attention mask
|
||||
temp_mask = torch.ones(
|
||||
seq_len[idx],
|
||||
input_pos[idx] + seq_len[idx],
|
||||
dtype=torch.bool,
|
||||
device=query_states.device,
|
||||
).tril(diagonal=0)
|
||||
attn_bias = torch.zeros(
|
||||
seq_len[idx],
|
||||
input_pos[idx] + seq_len[idx],
|
||||
device=query_states.device,
|
||||
dtype=query_states.dtype,
|
||||
)
|
||||
attn_mask = attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")).to(
|
||||
query_states.device
|
||||
)
|
||||
|
||||
# Compute attention weights
|
||||
attn_weights = (
|
||||
torch.matmul(
|
||||
query_states[seq_start[idx] : seq_start[idx] + seq_len[idx], :, :].transpose(0, 1),
|
||||
k.transpose(0, 1).transpose(1, 2),
|
||||
)
|
||||
* 1
|
||||
/ math.sqrt(query_states.size(-1))
|
||||
)
|
||||
attn_weights = attn_weights + attn_mask
|
||||
# upcast attention to fp32
|
||||
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(
|
||||
query_states.dtype
|
||||
)
|
||||
attn_weights = torch.nn.functional.dropout(attn_weights, p=0.0, training=False)
|
||||
attn_output = torch.matmul(attn_weights, v.transpose(0, 1))
|
||||
attn_outputs.append(attn_output)
|
||||
|
||||
if q_len == 1:
|
||||
attn_output = torch.stack(attn_outputs)
|
||||
else:
|
||||
attn_output = torch.cat(attn_outputs, dim=-2).unsqueeze(0)
|
||||
|
||||
return attn_output
|
||||
|
||||
|
||||
@fused_mla_ref.register_fake
|
||||
def fused_mla_ref_fake(
|
||||
q_nope: torch.Tensor,
|
||||
q_pe: torch.Tensor,
|
||||
kv: torch.Tensor,
|
||||
k_pe: torch.Tensor,
|
||||
seq_len: torch.Tensor, # metadata
|
||||
input_pos: torch.Tensor, # metadata
|
||||
cache_loc: torch.Tensor,
|
||||
seq_start: torch.Tensor,
|
||||
k_cache: torch.Tensor, # caches
|
||||
v_cache: torch.Tensor, # caches
|
||||
freqs_cis: torch.Tensor,
|
||||
softmax_scale: Optional[float] = None,
|
||||
):
|
||||
"""Fake Fused MLA+Rope with KV cache support."""
|
||||
v_head_dim = kv.shape[-1] - q_nope.shape[-1]
|
||||
return torch.empty_like(kv[..., -v_head_dim:])
|
||||
|
||||
|
||||
@torch.library.custom_op("auto_deploy::torch_attention_deepseek_fused_mla", mutates_args=())
|
||||
def fused_mla(
|
||||
q_nope: torch.Tensor,
|
||||
q_pe: torch.Tensor,
|
||||
kv: torch.Tensor,
|
||||
k_pe: torch.Tensor,
|
||||
cos: torch.Tensor,
|
||||
sin: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
softmax_scale: float,
|
||||
) -> torch.Tensor:
|
||||
"""MultiHeadLatentAttention as implemented in DeepSeekV3Attention. This does not capture KV cache use/update."""
|
||||
# Did not implement KV cache logic, since we would be writing our own custom op and inserting KV caches later
|
||||
# Compute parameters
|
||||
bs, num_heads, q_len, qk_nope_head_dim = q_nope.shape
|
||||
qk_rope_head_dim = q_pe.shape[-1]
|
||||
v_head_dim = kv.shape[-1] - qk_nope_head_dim
|
||||
q_head_dim = qk_nope_head_dim + qk_rope_head_dim
|
||||
|
||||
k_nope, value_states = torch.split(kv, [qk_nope_head_dim, v_head_dim], dim=-1)
|
||||
kv_seq_len = value_states.shape[-2]
|
||||
|
||||
cos = cos[position_ids]
|
||||
sin = sin[position_ids]
|
||||
q_pe, k_pe = torch.ops.auto_deploy.torch_rope_with_qk_interleaving(q_pe, k_pe, cos, sin)
|
||||
|
||||
query_states = k_pe.new_empty(bs, num_heads, q_len, q_head_dim)
|
||||
query_states[:, :, :, :qk_nope_head_dim] = q_nope
|
||||
query_states[:, :, :, qk_nope_head_dim:] = q_pe
|
||||
|
||||
key_states = k_pe.new_empty(bs, num_heads, q_len, q_head_dim)
|
||||
key_states[:, :, :, :qk_nope_head_dim] = k_nope
|
||||
key_states[:, :, :, qk_nope_head_dim:] = k_pe
|
||||
|
||||
# Use old logic
|
||||
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * softmax_scale
|
||||
|
||||
if attn_weights.size() != (bs, num_heads, q_len, kv_seq_len):
|
||||
raise ValueError(
|
||||
f"Attention weights should be of size {(bs, num_heads, q_len, kv_seq_len)}, but is"
|
||||
f" {attn_weights.size()}"
|
||||
)
|
||||
assert attention_mask is not None
|
||||
if attention_mask is not None:
|
||||
if attention_mask.size() != (bs, 1, q_len, kv_seq_len):
|
||||
raise ValueError(
|
||||
f"Attention mask should be of size {(bs, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
|
||||
)
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
# upcast attention to fp32
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(
|
||||
query_states.dtype
|
||||
)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=0.0, training=False)
|
||||
attn_output = torch.matmul(attn_weights, value_states)
|
||||
|
||||
if attn_output.size() != (bs, num_heads, q_len, v_head_dim):
|
||||
raise ValueError(
|
||||
f"`attn_output` should be of size {(bs, num_heads, q_len, v_head_dim)}, but is"
|
||||
f" {attn_output.size()}"
|
||||
)
|
||||
|
||||
# We do not return attn_weights along with attn_output
|
||||
return attn_output
|
||||
|
||||
|
||||
@fused_mla.register_fake
|
||||
def fused_mla(
|
||||
q_nope: torch.Tensor,
|
||||
q_pe: torch.Tensor,
|
||||
kv: torch.Tensor,
|
||||
k_pe: torch.Tensor,
|
||||
cos: torch.Tensor,
|
||||
sin: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
softmax_scale: float,
|
||||
) -> torch.Tensor:
|
||||
v_head_dim = kv.shape[-1] - q_nope.shape[-1]
|
||||
return torch.empty_like(kv[..., -v_head_dim:])
|
||||
|
||||
|
||||
@torch.library.custom_op("auto_deploy::torch_attention_deepseek_mla", mutates_args=())
|
||||
def mla(
|
||||
q_nope: torch.Tensor, # Down projected q_nope
|
||||
q_pe: torch.Tensor, # q_pe after applying rope
|
||||
kv: torch.Tensor, # compressed kv after passing through layernorm
|
||||
pe: torch.Tensor, # k_pe after applying rope
|
||||
attention_mask: torch.Tensor, # attention mask
|
||||
softmax_scale: float, # softmax scale
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Reference implementation for MLA style attention that handles compressed kv.
|
||||
"""
|
||||
scores = (
|
||||
torch.einsum("bhsc,btc->bsht", q_nope, kv) + torch.einsum("bhsr,btr->bsht", q_pe, pe)
|
||||
) * softmax_scale
|
||||
if attention_mask is not None:
|
||||
scores += attention_mask.unsqueeze(1)
|
||||
scores = scores.softmax(dim=-1, dtype=torch.float32).type_as(q_nope)
|
||||
attn_output = torch.einsum("bsht,btc->bshc", scores, kv)
|
||||
return attn_output
|
||||
|
||||
|
||||
@mla.register_fake
|
||||
def mla(
|
||||
q_nope: torch.Tensor, # Down projected q_nope
|
||||
q_pe: torch.Tensor, # q_pe after applying rope
|
||||
kv: torch.Tensor, # compressed kv after passing through layernorm
|
||||
k_pe: torch.Tensor, # k_pe after applying rope
|
||||
attention_mask: torch.Tensor, # attention mask
|
||||
softmax_scale: float, # softmax scale
|
||||
) -> torch.Tensor:
|
||||
"""MLA style attention that handles compressed kv."""
|
||||
return torch.empty_like(q_nope)
|
||||
|
||||
@ -35,7 +35,31 @@ from ..attention_interface import (
|
||||
ResourceHandlerDict,
|
||||
UnpagedResourceHandler,
|
||||
)
|
||||
from .torch_attention import repeat_kv, update_kv_cache
|
||||
from .torch_attention import repeat_kv
|
||||
|
||||
|
||||
def _update_kv_cache(
|
||||
key_states: torch.Tensor,
|
||||
value_states: torch.Tensor,
|
||||
k_cache: torch.Tensor,
|
||||
v_cache: torch.Tensor,
|
||||
seq_len: torch.Tensor, # metadata
|
||||
input_pos: torch.Tensor, # metadata
|
||||
slot_idx: torch.Tensor,
|
||||
seq_start: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Reference implementation for update kv cache function. Assumes KV cache layout to be [B,S,N,D].
|
||||
This function can be used to build reference attention implementations that use KV cache.
|
||||
"""
|
||||
|
||||
for idx in range(seq_len.shape[0]):
|
||||
k_cache[slot_idx[idx], input_pos[idx] : input_pos[idx] + seq_len[idx], :, :] = key_states[
|
||||
seq_start[idx] : seq_start[idx] + seq_len[idx], ...
|
||||
]
|
||||
v_cache[slot_idx[idx], input_pos[idx] : input_pos[idx] + seq_len[idx], :, :] = value_states[
|
||||
seq_start[idx] : seq_start[idx] + seq_len[idx], ...
|
||||
]
|
||||
|
||||
|
||||
def _apply_logit_softcapping(attn_scores: torch.Tensor, logit_cap: Optional[float]) -> torch.Tensor:
|
||||
@ -149,7 +173,7 @@ def _torch_context_mha(
|
||||
) -> None:
|
||||
"""Context attention (multiple tokens, potentially multiple sequences) using existing torch functions."""
|
||||
# Update KV cache first using existing function
|
||||
update_kv_cache(k, v, k_cache, v_cache, seq_len, input_pos, slot_idx, seq_start)
|
||||
_update_kv_cache(k, v, k_cache, v_cache, seq_len, input_pos, slot_idx, seq_start)
|
||||
|
||||
# Compute attention for each sequence
|
||||
attn_outputs = []
|
||||
|
||||
@ -1,24 +1,21 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""MLA (Multi-head Latent Attention) custom ops.
|
||||
|
||||
"""Multi-head Latent Attention operations.
|
||||
|
||||
This module provides Multi-head Latent Attention (MLA) implementations:
|
||||
- mla: MLA operations and attention descriptor
|
||||
Exports:
|
||||
- TorchBackendMLAAttention: Attention descriptor for MLA (registered as "torch_mla")
|
||||
- FlashInferMLAAttention: Attention descriptor for FlashInfer MLA (registered as "flashinfer_mla")
|
||||
- torch_mla: Source op for MLA attention
|
||||
- torch_backend_mla_with_cache: Cached backend op with FlashInfer-compatible cache
|
||||
- flashinfer_mla_with_cache: Cached backend op using FlashInfer MLA kernels
|
||||
"""
|
||||
|
||||
from .flashinfer_mla import FlashInferMLAAttention, flashinfer_mla_with_cache
|
||||
from .torch_backend_mla import TorchBackendMLAAttention, torch_backend_mla_with_cache
|
||||
from .torch_mla import torch_mla
|
||||
|
||||
__all__ = [
|
||||
"mla",
|
||||
"TorchBackendMLAAttention",
|
||||
"FlashInferMLAAttention",
|
||||
"torch_mla",
|
||||
"torch_backend_mla_with_cache",
|
||||
"flashinfer_mla_with_cache",
|
||||
]
|
||||
|
||||
tensorrt_llm/_torch/auto_deploy/custom_ops/mla/flashinfer_mla.py (new file, 957 lines)
@ -0,0 +1,957 @@
|
||||
"""FlashInfer-based MLA (Multi-head Latent Attention) backend with paged caching.
|
||||
|
||||
This module provides:
|
||||
- FlashInferMLAAttention: attention descriptor using FlashInfer MLA kernels
|
||||
- flashinfer_mla_with_cache: cached backend op with paged KV cache
|
||||
|
||||
FlashInfer MLA uses:
|
||||
- Regular prefill (input_pos == 0): BatchPrefillWithRaggedKVCacheWrapper with expanded K, V
|
||||
- Chunked prefill (input_pos > 0): BatchMLAPagedAttentionWrapper with matrix absorption
|
||||
- Decode: BatchMLAPagedAttentionWrapper with paged compressed KV cache
|
||||
|
||||
FlashInfer MLA Cache Layout (two separate caches):
|
||||
ckv_cache: [num_pages, page_size, kv_lora_rank]
|
||||
kpe_cache: [num_pages, page_size, qk_rope_head_dim]
|
||||
- No num_heads dimension (MLA-specific optimization)
|
||||
|
||||
Reference: https://docs.flashinfer.ai/api/mla.html
|
||||
"""
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass, fields
|
||||
from math import prod
|
||||
from typing import Dict, List, Literal, Optional, Tuple
|
||||
|
||||
import flashinfer
|
||||
import torch
|
||||
from torch._ops import OpOverloadPacket
|
||||
from torch._subclasses import FakeTensor
|
||||
from torch.fx import Node
|
||||
|
||||
from .....llmapi.llm_args import KvCacheConfig
|
||||
from ...utils.cuda_graph import cuda_graph_state
|
||||
from ..attention_interface import (
|
||||
AttentionDescriptor,
|
||||
AttentionLayout,
|
||||
AttentionRegistry,
|
||||
Constant,
|
||||
MHACallable,
|
||||
PrepareMetadataCallable,
|
||||
PrepareMetadataHostCallable,
|
||||
ResourceHandler,
|
||||
ResourceHandlerDict,
|
||||
SequenceInfo,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class MLADecodePlanParams:
|
||||
"""Parameters that affect the FlashInfer MLA decode execution plan."""
|
||||
|
||||
num_heads: int
|
||||
kv_lora_rank: int # head_dim_ckv
|
||||
qk_rope_head_dim: int # head_dim_kpe
|
||||
qk_nope_head_dim: int
|
||||
v_head_dim: int
|
||||
num_seq: int
|
||||
page_size: int
|
||||
q_dtype: torch.dtype
|
||||
kv_dtype: torch.dtype
|
||||
sm_scale: Optional[float] = None
|
||||
|
||||
def __hash__(self):
|
||||
"""Convert all fields to a string representation and concatenate them."""
|
||||
return hash("_".join([str(getattr(self, f.name)) for f in fields(self)]))
|
||||
|
||||
|
||||
@dataclass
|
||||
class MLAPrefillPlanParams:
|
||||
"""Parameters that affect the FlashInfer MLA prefill execution plan."""
|
||||
|
||||
num_heads: int
|
||||
num_kv_heads: int # For MLA with expanded KV, same as num_heads
|
||||
head_dim_qk: int # qk_nope_head_dim + qk_rope_head_dim
|
||||
head_dim_vo: int # v_head_dim (value/output head dimension)
|
||||
num_seq: int
|
||||
q_dtype: torch.dtype
|
||||
kv_dtype: torch.dtype
|
||||
sm_scale: Optional[float] = None
|
||||
|
||||
def __hash__(self):
|
||||
"""Convert all fields to a string representation and concatenate them."""
|
||||
return hash("_".join([str(getattr(self, f.name)) for f in fields(self)]))
|
||||
|
||||
|
||||
class _FlashInferMLAPlanner:
|
||||
"""A class interface to handle FlashInfer MLA-related planning/wrapping operations.
|
||||
|
||||
For MLA attention:
|
||||
- Regular prefill uses BatchPrefillWithRaggedKVCacheWrapper with expanded K, V tensors
|
||||
- Chunked prefill uses BatchMLAPagedAttentionWrapper with matrix absorption (same as decode)
|
||||
- Decode uses BatchMLAPagedAttentionWrapper with paged compressed KV cache
|
||||
"""
|
||||
|
||||
workspace_buffer: Optional[torch.Tensor]
|
||||
prefill_wrapper: Optional[flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper]
|
||||
decode_wrapper: Optional["flashinfer.mla.BatchMLAPagedAttentionWrapper"]
|
||||
# Separate wrapper for chunked/incremental prefill (uses same kernel as decode but different planning)
|
||||
chunked_prefill_wrapper: Optional["flashinfer.mla.BatchMLAPagedAttentionWrapper"]
|
||||
cached_cuda_graph_decode_wrappers: Dict[
|
||||
MLADecodePlanParams, "flashinfer.mla.BatchMLAPagedAttentionWrapper"
|
||||
]
|
||||
plan_params_prefill: Optional[MLAPrefillPlanParams]
|
||||
plan_params_decode: Optional[MLADecodePlanParams]
|
||||
plan_params_chunked_prefill: Optional[MLADecodePlanParams]
|
||||
kv_layout: Literal["NHD", "HND"] = "NHD"
|
||||
|
||||
def __init__(self):
|
||||
self.workspace_buffer = None
|
||||
self.prefill_wrapper = None
|
||||
self.decode_wrapper = None
|
||||
self.chunked_prefill_wrapper = None
|
||||
self.cached_cuda_graph_decode_wrappers = {}
|
||||
self.plan_params_prefill = None
|
||||
self.plan_params_decode = None
|
||||
self.plan_params_chunked_prefill = None
|
||||
|
||||
def _init_decode_wrapper(
|
||||
self,
|
||||
use_cuda_graph: bool = False,
|
||||
qo_indptr: Optional[torch.Tensor] = None,
|
||||
kv_indptr: Optional[torch.Tensor] = None,
|
||||
kv_indices: Optional[torch.Tensor] = None,
|
||||
kv_len_arr: Optional[torch.Tensor] = None,
|
||||
):
|
||||
assert self.workspace_buffer is not None
|
||||
if use_cuda_graph:
|
||||
return flashinfer.mla.BatchMLAPagedAttentionWrapper(
|
||||
self.workspace_buffer,
|
||||
use_cuda_graph=True,
|
||||
qo_indptr=qo_indptr,
|
||||
kv_indptr=kv_indptr,
|
||||
kv_indices=kv_indices,
|
||||
kv_len_arr=kv_len_arr,
|
||||
)
|
||||
else:
|
||||
return flashinfer.mla.BatchMLAPagedAttentionWrapper(
|
||||
self.workspace_buffer,
|
||||
use_cuda_graph=False,
|
||||
)
|
||||
|
||||
def reset(self, device: torch.device) -> None:
|
||||
self.plan_params_prefill = None
|
||||
self.plan_params_decode = None
|
||||
self.plan_params_chunked_prefill = None
|
||||
|
||||
if isinstance(self.workspace_buffer, torch.Tensor):
|
||||
return
|
||||
|
||||
self.__init__() # reset all state
|
||||
|
||||
# NOTE: avoid OOM for many cudagraphs
|
||||
self.workspace_buffer = torch.empty(320 * 1024 * 1024, device=device, dtype=torch.uint8)
|
||||
|
||||
# Prefill uses BatchPrefillWithRaggedKVCacheWrapper with expanded K, V
|
||||
self.prefill_wrapper = flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper(
|
||||
self.workspace_buffer,
|
||||
self.kv_layout,
|
||||
)
|
||||
# Decode uses BatchMLAPagedAttentionWrapper with paged compressed KV cache
|
||||
self.decode_wrapper = self._init_decode_wrapper()
|
||||
# Chunked prefill uses same kernel as decode but with variable-length queries
|
||||
self.chunked_prefill_wrapper = self._init_decode_wrapper()
|
||||
|
||||
def plan_prefill(
|
||||
self,
|
||||
qo_indptr_host: torch.Tensor,
|
||||
kv_indptr_host: torch.Tensor,
|
||||
plan_params: MLAPrefillPlanParams,
|
||||
) -> flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper:
|
||||
"""Plan prefill using BatchPrefillWithRaggedKVCacheWrapper.
|
||||
|
||||
For MLA prefill, we expand compressed_kv to get full K, V tensors
|
||||
and use standard ragged KV cache attention with causal masking.
|
||||
|
||||
Args:
|
||||
qo_indptr_host: Cumulative query/output lengths on host.
|
||||
kv_indptr_host: Cumulative key/value lengths on host.
|
||||
plan_params: Parameters for planning (hashable, no tensors).
|
||||
"""
|
||||
if plan_params != self.plan_params_prefill:
|
||||
self.prefill_wrapper.plan(
|
||||
qo_indptr_host,
|
||||
kv_indptr_host,
|
||||
plan_params.num_heads,
|
||||
plan_params.num_kv_heads,
|
||||
plan_params.head_dim_qk,
|
||||
head_dim_vo=plan_params.head_dim_vo,
|
||||
use_fp16_qk_reduction=False,
|
||||
causal=True,
|
||||
q_data_type=plan_params.q_dtype,
|
||||
kv_data_type=plan_params.kv_dtype,
|
||||
sm_scale=plan_params.sm_scale,
|
||||
)
|
||||
self.plan_params_prefill = plan_params
|
||||
|
||||
return self.prefill_wrapper
|
||||
|
||||
def _plan_mla_wrapper(
|
||||
self,
|
||||
wrapper: "flashinfer.mla.BatchMLAPagedAttentionWrapper",
|
||||
qo_indptr: torch.Tensor,
|
||||
kv_page_indptr: torch.Tensor,
|
||||
kv_page_indices: torch.Tensor,
|
||||
kv_last_page_len: torch.Tensor,
|
||||
plan_params: MLADecodePlanParams,
|
||||
):
|
||||
"""Helper to plan a BatchMLAPagedAttentionWrapper."""
|
||||
# Compute actual KV lengths from paging metadata:
|
||||
# kv_len = (num_pages - 1) * page_size + last_page_len
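        # Example: with page_size=32, a sequence spanning 3 pages whose last page holds 10 tokens
        # has kv_len = (3 - 1) * 32 + 10 = 74 cached tokens.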
|
||||
num_pages_per_seq = kv_page_indptr[1:] - kv_page_indptr[:-1]
|
||||
kv_len_arr = (num_pages_per_seq - 1) * plan_params.page_size + kv_last_page_len
|
||||
wrapper.plan(
|
||||
qo_indptr,
|
||||
kv_page_indptr,
|
||||
kv_page_indices,
|
||||
kv_len_arr,
|
||||
plan_params.num_heads,
|
||||
plan_params.kv_lora_rank, # head_dim_ckv
|
||||
plan_params.qk_rope_head_dim, # head_dim_kpe
|
||||
plan_params.page_size,
|
||||
causal=True,
|
||||
q_data_type=plan_params.q_dtype,
|
||||
kv_data_type=plan_params.kv_dtype,
|
||||
sm_scale=plan_params.sm_scale,
|
||||
)
|
||||
|
||||
def plan_decode(
|
||||
self,
|
||||
kv_page_indptr: torch.Tensor,
|
||||
kv_page_indices: torch.Tensor,
|
||||
kv_last_page_len: torch.Tensor,
|
||||
plan_params: MLADecodePlanParams,
|
||||
) -> "flashinfer.mla.BatchMLAPagedAttentionWrapper":
|
||||
"""Plan decode using BatchMLAPagedAttentionWrapper.
|
||||
|
||||
For MLA decode, we use the paged compressed KV cache with
|
||||
FlashInfer's optimized MLA kernels. Each sequence generates 1 token.
|
||||
|
||||
Args:
|
||||
kv_page_indptr: Cumulative page counts [batch_size + 1].
|
||||
kv_page_indices: Page indices for the KV cache.
|
||||
kv_last_page_len: Length of the last page per sequence.
|
||||
plan_params: Parameters for planning.
|
||||
"""
|
||||
# Decode qo_indptr: [0, 1, 2, ..., batch_size] (1 token per sequence)
|
||||
batch_size = kv_page_indptr.shape[0] - 1
|
||||
qo_indptr = torch.arange(batch_size + 1, device=kv_page_indptr.device, dtype=torch.int32)
|
||||
|
||||
# Compute kv_len_arr for CUDA graph wrapper initialization
|
||||
num_pages_per_seq = kv_page_indptr[1:] - kv_page_indptr[:-1]
|
||||
kv_len_arr = (num_pages_per_seq - 1) * plan_params.page_size + kv_last_page_len
|
||||
|
||||
# we want to plan during warm-up of cuda graph capture to ensure we have the plan cached
|
||||
if (
|
||||
cuda_graph_state.in_warm_up()
|
||||
and plan_params not in self.cached_cuda_graph_decode_wrappers
|
||||
):
|
||||
# During CUDA graph capture, the metadata tensors provided by auto-deploy are stable.
|
||||
# Pass the buffer tensors to the wrapper for use_cuda_graph=True
|
||||
wrapper = self._init_decode_wrapper(
|
||||
use_cuda_graph=True,
|
||||
qo_indptr=qo_indptr,
|
||||
kv_indptr=kv_page_indptr,
|
||||
kv_indices=kv_page_indices,
|
||||
kv_len_arr=kv_len_arr,
|
||||
)
|
||||
self.cached_cuda_graph_decode_wrappers[plan_params] = wrapper
|
||||
self._plan_mla_wrapper(
|
||||
wrapper, qo_indptr, kv_page_indptr, kv_page_indices, kv_last_page_len, plan_params
|
||||
)
|
||||
|
||||
# check if we are in cuda graph capture and just return the pre-cached decode wrapper
|
||||
if torch.cuda.is_current_stream_capturing() or cuda_graph_state.in_warm_up():
|
||||
wrapper = self.cached_cuda_graph_decode_wrappers[plan_params]
|
||||
return wrapper
|
||||
|
||||
# Re-plan if plan_params changed
|
||||
if plan_params != self.plan_params_decode:
|
||||
self._plan_mla_wrapper(
|
||||
self.decode_wrapper,
|
||||
qo_indptr,
|
||||
kv_page_indptr,
|
||||
kv_page_indices,
|
||||
kv_last_page_len,
|
||||
plan_params,
|
||||
)
|
||||
self.plan_params_decode = plan_params
|
||||
|
||||
return self.decode_wrapper
|
||||
|
||||
def plan_chunked_prefill(
|
||||
self,
|
||||
qo_indptr: torch.Tensor,
|
||||
kv_page_indptr: torch.Tensor,
|
||||
kv_page_indices: torch.Tensor,
|
||||
kv_last_page_len: torch.Tensor,
|
||||
plan_params: MLADecodePlanParams,
|
||||
) -> "flashinfer.mla.BatchMLAPagedAttentionWrapper":
|
||||
"""Plan chunked/incremental prefill using BatchMLAPagedAttentionWrapper.
|
||||
|
||||
For chunked prefill (input_pos > 0), we use the same kernel as decode but with
|
||||
variable-length queries. Each sequence can have multiple tokens.
|
||||
|
||||
Args:
|
||||
qo_indptr: Cumulative query lengths [batch_size + 1].
|
||||
kv_page_indptr: Cumulative page counts [batch_size + 1].
|
||||
kv_page_indices: Page indices for the KV cache.
|
||||
kv_last_page_len: Length of the last page per sequence.
|
||||
plan_params: Parameters for planning.
|
||||
"""
|
||||
# Re-plan if plan_params changed
|
||||
if plan_params != self.plan_params_chunked_prefill:
|
||||
self._plan_mla_wrapper(
|
||||
self.chunked_prefill_wrapper,
|
||||
qo_indptr,
|
||||
kv_page_indptr,
|
||||
kv_page_indices,
|
||||
kv_last_page_len,
|
||||
plan_params,
|
||||
)
|
||||
self.plan_params_chunked_prefill = plan_params
|
||||
|
||||
return self.chunked_prefill_wrapper
|
||||
|
||||
def plan_generate_only(
|
||||
self,
|
||||
num_seq: int,
|
||||
cu_num_pages: torch.Tensor,
|
||||
cache_loc: torch.Tensor,
|
||||
last_page_len: torch.Tensor,
|
||||
):
|
||||
"""Plan decode-only batches for cached CUDA graph wrappers.
|
||||
|
||||
This is called from the host-side preparation function to plan
|
||||
the decode wrappers for decode-only batches before the actual
|
||||
attention op is invoked.
|
||||
|
||||
Args:
|
||||
num_seq: Number of sequences in the decode batch.
|
||||
cu_num_pages: Cumulative page counts, already sliced to [: num_seq + 1].
|
||||
cache_loc: Page indices for the KV cache.
|
||||
last_page_len: Length of the last page per sequence, already sliced to [:num_seq].
|
||||
"""
|
||||
for plan_params in self.cached_cuda_graph_decode_wrappers:
|
||||
if plan_params.num_seq == num_seq:
|
||||
wrapper = self.cached_cuda_graph_decode_wrappers[plan_params]
|
||||
|
||||
# For a pure decode batch, qo_indptr is just [0, 1, 2, ..., batch_size]
|
||||
qo_indptr = torch.arange(num_seq + 1, device=cu_num_pages.device, dtype=torch.int32)
|
||||
|
||||
# Compute actual KV lengths from paging metadata:
|
||||
# kv_len = (num_pages - 1) * page_size + last_page_len
|
||||
num_pages_per_seq = cu_num_pages[1:] - cu_num_pages[:-1]
|
||||
kv_len_arr = (num_pages_per_seq - 1) * plan_params.page_size + last_page_len
|
||||
|
||||
wrapper.plan(
|
||||
qo_indptr,
|
||||
cu_num_pages, # kv_page_indptr
|
||||
cache_loc, # kv_page_indices
|
||||
kv_len_arr,
|
||||
plan_params.num_heads,
|
||||
plan_params.kv_lora_rank, # head_dim_ckv
|
||||
plan_params.qk_rope_head_dim, # head_dim_kpe
|
||||
plan_params.page_size,
|
||||
causal=True,
|
||||
q_data_type=plan_params.q_dtype,
|
||||
kv_data_type=plan_params.kv_dtype,
|
||||
sm_scale=plan_params.sm_scale,
|
||||
)
|
||||
|
||||
|
||||
_GlobalFlashInferMLAPlanner = _FlashInferMLAPlanner()
|
||||
|
||||
|
||||
@torch.library.custom_op("auto_deploy::flashinfer_mla_prepare_metadata", mutates_args=())
|
||||
def prepare_flashinfer_mla_metadata(
|
||||
position_ids: torch.Tensor,
|
||||
batch_info_host: torch.Tensor,
|
||||
cu_seqlen: torch.Tensor,
|
||||
seq_len_with_cache: torch.Tensor,
|
||||
) -> List[torch.Tensor]:
|
||||
"""Prepare metadata for FlashInfer MLA attention.
|
||||
|
||||
This prepares batch_indices and positions for cache appends, similar to
|
||||
the standard FlashInfer attention preparation.
|
||||
"""
|
||||
# retrieve host-side metadata
|
||||
num_prefill, num_prefill_tokens, num_decode = batch_info_host.tolist()
|
||||
num_seq = num_prefill + num_decode
|
||||
num_tokens = num_prefill_tokens + num_decode
|
||||
|
||||
_GlobalFlashInferMLAPlanner.reset(position_ids.device)
|
||||
|
||||
qo_indptr = cu_seqlen[: num_seq + 1]
|
||||
|
||||
# Compute batch_indices and positions for cache appends
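    # Per FlashInfer's get_batch_indices_positions, batch_indices[i] is the sequence index of
    # flattened token i and positions[i] is that token's absolute position within the full
    # sequence (cached tokens included), which the paged cache append uses to pick its slot.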
|
||||
batch_indices, positions = flashinfer.get_batch_indices_positions(
|
||||
qo_indptr, seq_len_with_cache[:num_seq], num_tokens
|
||||
)
|
||||
|
||||
return batch_indices, positions
|
||||
|
||||
|
||||
@prepare_flashinfer_mla_metadata.register_fake
|
||||
def prepare_flashinfer_mla_metadata_fake(
|
||||
position_ids: torch.Tensor,
|
||||
batch_info_host: torch.Tensor,
|
||||
cu_seqlen: torch.Tensor,
|
||||
seq_len_with_cache: torch.Tensor,
|
||||
):
|
||||
num_tokens = position_ids.shape[0] * position_ids.shape[1]
|
||||
return (
|
||||
torch.empty(num_tokens, dtype=torch.int32, device=position_ids.device), # batch_indices
|
||||
torch.empty(num_tokens, dtype=torch.int32, device=position_ids.device), # positions
|
||||
)
|
||||
|
||||
|
||||
def prepare_flashinfer_mla_metadata_host(
|
||||
batch_info_host: torch.Tensor,
|
||||
cu_num_pages_host: torch.Tensor,
|
||||
cache_loc_host: torch.Tensor,
|
||||
last_page_len_host: torch.Tensor,
|
||||
) -> None:
|
||||
"""Host-side preparation for FlashInfer MLA attention.
|
||||
|
||||
For decode-only batches, this function pre-plans the cached CUDA graph
|
||||
wrappers to avoid planning during graph capture/replay.
|
||||
"""
|
||||
num_prefill, num_prefill_tokens, num_decode = batch_info_host.tolist()
|
||||
|
||||
if num_prefill == 0:
|
||||
_GlobalFlashInferMLAPlanner.plan_generate_only(
|
||||
num_decode,
|
||||
cu_num_pages_host[: num_decode + 1],
|
||||
cache_loc_host,
|
||||
last_page_len_host[:num_decode],
|
||||
)
|
||||
|
||||
|
||||
@torch.library.custom_op("auto_deploy::flashinfer_mla_with_cache", mutates_args=())
|
||||
def flashinfer_mla_with_cache(
|
||||
# 5 tensor args (matching torch_mla source op)
|
||||
q_nope: torch.Tensor, # [B, S, N, qk_nope_head_dim]
|
||||
q_pe: torch.Tensor, # [B, S, N, qk_rope_head_dim]
|
||||
compressed_kv: torch.Tensor, # [B, S, kv_lora_rank]
|
||||
kpe: torch.Tensor, # [B, S, 1, qk_rope_head_dim]
|
||||
kv_b_proj_weight: torch.Tensor, # [N * (qk_nope_head_dim + v_head_dim), kv_lora_rank]
|
||||
# Standard paged metadata
|
||||
batch_info_host: torch.Tensor,
|
||||
cu_seqlen_host: torch.Tensor,
|
||||
cu_num_pages: torch.Tensor,
|
||||
cu_num_pages_host: torch.Tensor,
|
||||
cache_loc: torch.Tensor,
|
||||
last_page_len: torch.Tensor,
|
||||
last_page_len_host: torch.Tensor,
|
||||
seq_len_with_cache_host: torch.Tensor,
|
||||
# Extra FlashInfer metadata
|
||||
flashinfer_batch_indices: torch.Tensor,
|
||||
flashinfer_positions: torch.Tensor,
|
||||
# Paged caches (two separate caches)
|
||||
ckv_cache: torch.Tensor, # [num_pages, page_size, kv_lora_rank]
|
||||
kpe_cache: torch.Tensor, # [num_pages, page_size, qk_rope_head_dim]
|
||||
# Constants
|
||||
scale: Optional[float],
|
||||
kv_lora_rank: int,
|
||||
) -> torch.Tensor:
|
||||
"""FlashInfer MLA attention with paged cache.
|
||||
|
||||
Uses FlashInfer's optimized kernels:
|
||||
- Prefill: BatchPrefillWithRaggedKVCacheWrapper with expanded K, V tensors
|
||||
- Decode: BatchMLAPagedAttentionWrapper with paged compressed KV cache
|
||||
|
||||
FlashInfer MLA Cache Layout (two separate caches):
|
||||
ckv_cache: [num_pages, page_size, kv_lora_rank]
|
||||
kpe_cache: [num_pages, page_size, qk_rope_head_dim]
|
||||
|
||||
Args:
|
||||
q_nope: Query non-positional component [B, S, N, qk_nope_head_dim]
|
||||
q_pe: Query positional component [B, S, N, qk_rope_head_dim]
|
||||
compressed_kv: Compressed KV latent [B, S, kv_lora_rank]
|
||||
kpe: Key positional encoding [B, S, 1, qk_rope_head_dim]
|
||||
kv_b_proj_weight: Projection weight [N * (qk_nope_head_dim + v_head_dim), kv_lora_rank]
|
||||
(metadata args): Standard paged attention metadata
|
||||
ckv_cache: Paged cache for compressed KV
|
||||
kpe_cache: Paged cache for key positional encoding
|
||||
scale: Softmax scale factor
|
||||
kv_lora_rank: Rank of compressed KV
|
||||
|
||||
Returns:
|
||||
Attention output [B, S, N, v_head_dim]
|
||||
"""
|
||||
# Get dimensions
|
||||
b, s = q_nope.shape[:2]
|
||||
num_heads = q_nope.shape[2]
|
||||
qk_nope_head_dim = q_nope.shape[3]
|
||||
qk_rope_head_dim = q_pe.shape[3]
|
||||
qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
|
||||
|
||||
# Infer v_head_dim from kv_b_proj_weight
|
||||
out_features = kv_b_proj_weight.shape[0]
|
||||
kv_head_dim = out_features // num_heads
|
||||
v_head_dim = kv_head_dim - qk_nope_head_dim
|
||||
|
||||
# Get batch info
|
||||
num_prefill, num_prefill_tokens, num_decode = batch_info_host.tolist()
|
||||
num_seq = num_prefill + num_decode
|
||||
num_total_tokens = num_prefill_tokens + num_decode
|
||||
|
||||
# Set scale
|
||||
if scale is None:
|
||||
scale = 1.0 / math.sqrt(qk_head_dim)
|
||||
|
||||
page_size = ckv_cache.shape[1]
|
||||
|
||||
# Flatten inputs to [total_tokens, ...] format
|
||||
bs = b * s
|
||||
q_nope_flat = q_nope.contiguous().view(bs, num_heads, qk_nope_head_dim)
|
||||
q_pe_flat = q_pe.contiguous().view(bs, num_heads, qk_rope_head_dim)
|
||||
compressed_kv_flat = compressed_kv.contiguous().view(bs, kv_lora_rank)
|
||||
kpe_flat = kpe.contiguous().view(bs, qk_rope_head_dim)
|
||||
|
||||
# Convert cache dtype if needed
|
||||
if ckv_cache.dtype == torch.float8_e4m3fn:
|
||||
compressed_kv_flat = compressed_kv_flat.to(torch.float8_e4m3fn)
|
||||
kpe_flat = kpe_flat.to(torch.float8_e4m3fn)
|
||||
|
||||
# Append to paged cache using FlashInfer's append function
|
||||
# Note: caches are guaranteed contiguous by CachedSequenceInterface._create_kv_cache_manager
|
||||
flashinfer.page.append_paged_mla_kv_cache(
|
||||
compressed_kv_flat,
|
||||
kpe_flat,
|
||||
flashinfer_batch_indices,
|
||||
flashinfer_positions,
|
||||
ckv_cache,
|
||||
kpe_cache,
|
||||
cache_loc,
|
||||
cu_num_pages[: num_seq + 1],
|
||||
last_page_len[:num_seq],
|
||||
)
|
||||
|
||||
# Pre-allocate output
|
||||
if num_prefill > 0 and num_decode > 0:
|
||||
y = torch.empty(bs, num_heads, v_head_dim, dtype=q_nope.dtype, device=q_nope.device)
|
||||
else:
|
||||
y = None
|
||||
|
||||
# =========================================================================
|
||||
# PREFILL phase: Use BatchPrefillWithRaggedKVCacheWrapper for regular prefill
|
||||
# or BatchMLAPagedAttentionWrapper for chunked prefill
|
||||
# =========================================================================
|
||||
if num_prefill > 0:
|
||||
q_nope_prefill = q_nope_flat[:num_prefill_tokens]
|
||||
q_pe_prefill = q_pe_flat[:num_prefill_tokens]
|
||||
compressed_kv_prefill = compressed_kv_flat[:num_prefill_tokens]
|
||||
kpe_prefill = kpe_flat[:num_prefill_tokens]
|
||||
|
||||
# Check if any prefill sequence has cached tokens (chunked prefill)
|
||||
# seq_len_with_cache > current_seq_len means there are cached tokens
|
||||
q_lens = cu_seqlen_host[1 : num_prefill + 1] - cu_seqlen_host[:num_prefill]
|
||||
kv_lens = seq_len_with_cache_host[:num_prefill]
|
||||
is_chunked_prefill = (kv_lens > q_lens).any().item()
|
||||
|
||||
if is_chunked_prefill:
|
||||
# =================================================================
|
||||
# CHUNKED PREFILL: Use BatchMLAPagedAttentionWrapper with absorption
|
||||
# Same approach as decode, but with variable-length Q sequences
|
||||
# =================================================================
|
||||
|
||||
# Extract W_kn and W_v from kv_b_proj_weight
|
||||
# kv_b_proj_weight: [N * (qk_nope_head_dim + v_head_dim), kv_lora_rank]
|
||||
# Reshape to [N, qk_nope_head_dim + v_head_dim, kv_lora_rank]
|
||||
kv_b_proj_reshaped = kv_b_proj_weight.view(
|
||||
num_heads, qk_nope_head_dim + v_head_dim, kv_lora_rank
|
||||
)
|
||||
# W_kn: [N, qk_nope_head_dim, kv_lora_rank]
|
||||
w_kn = kv_b_proj_reshaped[:, :qk_nope_head_dim, :]
|
||||
# W_v: [N, v_head_dim, kv_lora_rank]
|
||||
w_v = kv_b_proj_reshaped[:, qk_nope_head_dim:, :]
|
||||
|
||||
# Absorb W_kn into q_nope:
|
||||
# q_nope_prefill: [num_prefill_tokens, N, qk_nope_head_dim]
|
||||
# w_kn: [N, qk_nope_head_dim, kv_lora_rank]
|
||||
# q_nope_absorbed: [num_prefill_tokens, N, kv_lora_rank]
|
||||
q_nope_absorbed = torch.einsum("bnd,ndk->bnk", q_nope_prefill, w_kn).contiguous()
|
||||
|
||||
# Build qo_indptr for variable-length prefill sequences
|
||||
qo_indptr = cu_seqlen_host[: num_prefill + 1].to(
|
||||
device=cu_num_pages.device, dtype=torch.int32
|
||||
)
|
||||
|
||||
pp_chunked = MLADecodePlanParams(
|
||||
num_heads=num_heads,
|
||||
kv_lora_rank=kv_lora_rank,
|
||||
qk_rope_head_dim=qk_rope_head_dim,
|
||||
qk_nope_head_dim=qk_nope_head_dim,
|
||||
v_head_dim=v_head_dim,
|
||||
num_seq=num_prefill,
|
||||
page_size=page_size,
|
||||
q_dtype=q_nope.dtype,
|
||||
kv_dtype=ckv_cache.dtype,
|
||||
sm_scale=scale,
|
||||
)
|
||||
|
||||
wrapper_chunked = _GlobalFlashInferMLAPlanner.plan_chunked_prefill(
|
||||
qo_indptr=qo_indptr,
|
||||
kv_page_indptr=cu_num_pages[: num_prefill + 1],
|
||||
kv_page_indices=cache_loc,
|
||||
kv_last_page_len=last_page_len[:num_prefill],
|
||||
plan_params=pp_chunked,
|
||||
)
|
||||
|
||||
# Run paged MLA attention in compressed space
|
||||
y_prefill_compressed = wrapper_chunked.run(
|
||||
q_nope_absorbed,
|
||||
q_pe_prefill,
|
||||
ckv_cache,
|
||||
kpe_cache,
|
||||
)
|
||||
|
||||
# Project output back from latent space to v_head_dim
|
||||
# y_prefill_compressed: [num_prefill_tokens, N, kv_lora_rank]
|
||||
# w_v: [N, v_head_dim, kv_lora_rank]
|
||||
# y_prefill: [num_prefill_tokens, N, v_head_dim]
|
||||
y_prefill = torch.einsum("bnk,nvk->bnv", y_prefill_compressed, w_v)
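# Informal equivalence sketch (not from the original source): with cached latent C and
# per-head weights W_kn, W_v, the absorbed path computes
#   softmax((q_nope @ W_kn) @ C.T + q_pe @ Kpe.T) @ C @ W_v.T
# which equals standard attention against the expanded K = C @ W_kn.T and V = C @ W_v.T,
# without ever materializing per-head K/V for the cached tokens.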
|
||||
|
||||
else:
|
||||
# =================================================================
|
||||
# REGULAR PREFILL: Use BatchPrefillWithRaggedKVCacheWrapper
|
||||
# Expand compressed_kv to K, V and use ragged attention
|
||||
# =================================================================
|
||||
|
||||
# Expand compressed_kv using kv_b_proj_weight to get k_nope and v
|
||||
# compressed_kv: [num_prefill_tokens, kv_lora_rank]
|
||||
# kv_b_proj_weight: [N * (qk_nope_head_dim + v_head_dim), kv_lora_rank]
|
||||
# kv_expanded: [num_prefill_tokens, N * (qk_nope_head_dim + v_head_dim)]
|
||||
kv_expanded = torch.matmul(compressed_kv_prefill, kv_b_proj_weight.t())
|
||||
kv_expanded = kv_expanded.view(
|
||||
num_prefill_tokens, num_heads, qk_nope_head_dim + v_head_dim
|
||||
)
|
||||
|
||||
# Split into k_nope and v
|
||||
k_nope_prefill = kv_expanded[:, :, :qk_nope_head_dim] # [tokens, N, qk_nope_head_dim]
|
||||
v_prefill = kv_expanded[:, :, qk_nope_head_dim:].contiguous() # [tokens, N, v_head_dim]
|
||||
|
||||
# Expand kpe to all heads: [tokens, qk_rope_head_dim] -> [tokens, N, qk_rope_head_dim]
|
||||
kpe_expanded = kpe_prefill.unsqueeze(1).expand(-1, num_heads, -1).contiguous()
|
||||
|
||||
# Concatenate to form full Q and K
|
||||
# Q: [tokens, N, qk_head_dim]
|
||||
q_prefill = torch.cat([q_nope_prefill, q_pe_prefill], dim=-1).contiguous()
|
||||
# K: [tokens, N, qk_head_dim]
|
||||
k_prefill = torch.cat([k_nope_prefill, kpe_expanded], dim=-1).contiguous()
|
||||
|
||||
pp_prefill = MLAPrefillPlanParams(
|
||||
num_heads=num_heads,
|
||||
num_kv_heads=num_heads, # For MLA with expanded KV, same as num_heads
|
||||
head_dim_qk=qk_head_dim,
|
||||
head_dim_vo=v_head_dim,
|
||||
num_seq=num_prefill,
|
||||
q_dtype=q_nope.dtype,
|
||||
kv_dtype=k_prefill.dtype,
|
||||
sm_scale=scale,
|
||||
)
|
||||
|
||||
wrapper_prefill = _GlobalFlashInferMLAPlanner.plan_prefill(
|
||||
qo_indptr_host=cu_seqlen_host[: num_prefill + 1],
|
||||
kv_indptr_host=cu_seqlen_host[: num_prefill + 1], # Same as qo for self-attention
|
||||
plan_params=pp_prefill,
|
||||
)
|
||||
|
||||
y_prefill = wrapper_prefill.run(
|
||||
q_prefill,
|
||||
k_prefill,
|
||||
v_prefill,
|
||||
)
|
||||
|
||||
if y is not None:
|
||||
y[:num_prefill_tokens] = y_prefill
|
||||
else:
|
||||
y = y_prefill
|
||||
|
||||
# =========================================================================
|
||||
# DECODE phase: Use BatchMLAPagedAttentionWrapper with paged compressed KV
|
||||
# =========================================================================
|
||||
if num_decode > 0:
|
||||
q_nope_decode = q_nope_flat[num_prefill_tokens:num_total_tokens].contiguous()
|
||||
q_pe_decode = q_pe_flat[num_prefill_tokens:num_total_tokens].contiguous()
|
||||
|
||||
# FlashInfer MLA operates in the compressed latent space.
|
||||
# We need to:
|
||||
# 1. Absorb W_kn (K-nope projection) into q_nope
|
||||
# 2. Run attention in compressed space
|
||||
# 3. Project output back using W_v
|
||||
|
||||
# Extract W_kn and W_v from kv_b_proj_weight
|
||||
# kv_b_proj_weight: [N * (qk_nope_head_dim + v_head_dim), kv_lora_rank]
|
||||
# Reshape to [N, qk_nope_head_dim + v_head_dim, kv_lora_rank]
|
||||
kv_b_proj_reshaped = kv_b_proj_weight.view(
|
||||
num_heads, qk_nope_head_dim + v_head_dim, kv_lora_rank
|
||||
)
|
||||
# W_kn: [N, qk_nope_head_dim, kv_lora_rank]
|
||||
w_kn = kv_b_proj_reshaped[:, :qk_nope_head_dim, :]
|
||||
# W_v: [N, v_head_dim, kv_lora_rank]
|
||||
w_v = kv_b_proj_reshaped[:, qk_nope_head_dim:, :]
|
||||
|
||||
# Absorb W_kn into q_nope:
|
||||
# q_nope_decode: [num_decode, N, qk_nope_head_dim]
|
||||
# w_kn: [N, qk_nope_head_dim, kv_lora_rank]
|
||||
# q_nope_absorbed: [num_decode, N, kv_lora_rank]
|
||||
q_nope_absorbed = torch.einsum("bnd,ndk->bnk", q_nope_decode, w_kn).contiguous()
|
||||
|
||||
pp_decode = MLADecodePlanParams(
|
||||
num_heads=num_heads,
|
||||
kv_lora_rank=kv_lora_rank,
|
||||
qk_rope_head_dim=qk_rope_head_dim,
|
||||
qk_nope_head_dim=qk_nope_head_dim,
|
||||
v_head_dim=v_head_dim,
|
||||
num_seq=num_decode,
|
||||
page_size=page_size,
|
||||
q_dtype=q_nope.dtype,
|
||||
kv_dtype=ckv_cache.dtype,
|
||||
sm_scale=scale,
|
||||
)
|
||||
|
||||
wrapper_decode = _GlobalFlashInferMLAPlanner.plan_decode(
|
||||
kv_page_indptr=cu_num_pages[num_prefill : num_seq + 1],
|
||||
kv_page_indices=cache_loc,
|
||||
kv_last_page_len=last_page_len[num_prefill:num_seq],
|
||||
plan_params=pp_decode,
|
||||
)
|
||||
|
||||
# Run attention in compressed space
|
||||
# y_decode_compressed: [num_decode, N, kv_lora_rank]
|
||||
# Note: caches are guaranteed contiguous by CachedSequenceInterface._create_kv_cache_manager
|
||||
y_decode_compressed = wrapper_decode.run(
|
||||
q_nope_absorbed,
|
||||
q_pe_decode,
|
||||
ckv_cache,
|
||||
kpe_cache,
|
||||
)
|
||||
|
||||
# Project output back from latent space to v_head_dim
|
||||
# y_decode_compressed: [num_decode, N, kv_lora_rank]
|
||||
# w_v: [N, v_head_dim, kv_lora_rank]
|
||||
# y_decode: [num_decode, N, v_head_dim]
|
||||
y_decode = torch.einsum("bnk,nvk->bnv", y_decode_compressed, w_v)
|
||||
|
||||
if y is not None:
|
||||
y[num_prefill_tokens:num_total_tokens] = y_decode
|
||||
else:
|
||||
y = y_decode
|
||||
|
||||
return y.view(b, s, num_heads, v_head_dim)
|
||||
|
||||
|
||||
@flashinfer_mla_with_cache.register_fake
|
||||
def flashinfer_mla_with_cache_fake(
|
||||
q_nope: torch.Tensor,
|
||||
q_pe: torch.Tensor,
|
||||
compressed_kv: torch.Tensor,
|
||||
kpe: torch.Tensor,
|
||||
kv_b_proj_weight: torch.Tensor,
|
||||
batch_info_host: torch.Tensor,
|
||||
cu_seqlen_host: torch.Tensor,
|
||||
cu_num_pages: torch.Tensor,
|
||||
cu_num_pages_host: torch.Tensor,
|
||||
cache_loc: torch.Tensor,
|
||||
last_page_len: torch.Tensor,
|
||||
last_page_len_host: torch.Tensor,
|
||||
seq_len_with_cache_host: torch.Tensor,
|
||||
flashinfer_batch_indices: torch.Tensor,
|
||||
flashinfer_positions: torch.Tensor,
|
||||
ckv_cache: torch.Tensor,
|
||||
kpe_cache: torch.Tensor,
|
||||
scale: Optional[float],
|
||||
kv_lora_rank: int,
|
||||
) -> torch.Tensor:
|
||||
"""Fake implementation for flashinfer_mla_with_cache."""
|
||||
num_heads = q_nope.shape[2]
|
||||
qk_nope_head_dim = q_nope.shape[-1]
|
||||
out_features = kv_b_proj_weight.shape[0]
|
||||
kv_head_dim = out_features // num_heads
|
||||
v_head_dim = kv_head_dim - qk_nope_head_dim
|
||||
|
||||
return q_nope.new_empty(
|
||||
q_nope.shape[0], q_nope.shape[1], q_nope.shape[2], v_head_dim
|
||||
).contiguous()
|
||||
|
||||
|
||||
class MLAPagedResourceHandler(ResourceHandler):
|
||||
"""Handler for paged resources in MLA that require per-layer contiguous memory.
|
||||
|
||||
While MLA uses paged caching, the underlying FlashInfer MLA kernel tracks cache strides with
a uint32_t. The KVCacheManager allocates one contiguous tensor for the cache across all
layers, with dim 0 being the layer index, so the per-layer view needs very large strides to
jump between pages, which overflows the kernel's uint32_t stride arithmetic.
|
||||
|
||||
We use a separate handler for this purpose to avoid registering the cache with the
|
||||
KVCacheManager and instead rely on local allocation.
|
||||
"""
|
||||
|
||||
@property
|
||||
def is_paged(self) -> bool:
|
||||
"""Whether the resource is paged."""
|
||||
return True
|
||||
|
||||
def __init__(self, *token_shape: int, dtype: torch.dtype) -> None:
|
||||
"""Initialize the ContiguousPagedResourceHandler.
|
||||
|
||||
Args:
|
||||
token_shape: The shape of the resource per token.
|
||||
dtype: The dtype of the resource.
|
||||
"""
|
||||
self.token_shape = token_shape
|
||||
self.dtype = dtype
|
||||
|
||||
def _get_bytes_per_token(self) -> int:
|
||||
"""The size of the resource per token in bytes."""
|
||||
return prod(self.token_shape) * self.dtype.itemsize
|
||||
|
||||
def allocate(self, sequence_info: SequenceInfo) -> torch.Tensor:
|
||||
"""Allocate contiguous paged resource.
|
||||
|
||||
Args:
|
||||
sequence_info: SequenceInfo with device and page information.
|
||||
|
||||
Returns:
|
||||
Contiguous tensor of shape [num_blocks, tokens_per_block, *token_shape].
|
||||
"""
|
||||
return torch.empty(
|
||||
sequence_info.num_blocks,
|
||||
sequence_info.tokens_per_block,
|
||||
*self.token_shape,
|
||||
device=sequence_info.device,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
|
||||
|
||||
@AttentionRegistry.register("flashinfer_mla")
|
||||
class FlashInferMLAAttention(AttentionDescriptor):
|
||||
"""Attention descriptor for FlashInfer-based MLA with paged cache.
|
||||
|
||||
This descriptor uses FlashInfer's optimized MLA kernels:
|
||||
- Source op: torch_mla (same as torch_mla backend)
|
||||
- Cached op: flashinfer_mla_with_cache with paged cache
|
||||
|
||||
FlashInfer MLA Cache Layout (two separate caches):
|
||||
ckv_cache: [num_pages, page_size, kv_lora_rank]
|
||||
kpe_cache: [num_pages, page_size, qk_rope_head_dim]
|
||||
- No num_heads dimension (MLA-specific optimization)
|
||||
|
||||
Reference: https://docs.flashinfer.ai/api/mla.html
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def _get_planner(cls) -> _FlashInferMLAPlanner:
|
||||
return _GlobalFlashInferMLAPlanner
|
||||
|
||||
@classmethod
|
||||
def get_attention_layout(cls) -> AttentionLayout:
|
||||
"""Get the attention layout expected by the backend."""
|
||||
return "bsnd"
|
||||
|
||||
@classmethod
|
||||
def get_num_qkv_args(cls) -> int:
|
||||
"""Get the number of tensor arguments expected by the source op."""
|
||||
return 5 # q_nope, q_pe, compressed_kv, kpe, kv_b_proj_weight
|
||||
|
||||
@classmethod
|
||||
def get_source_attention_op(cls) -> OpOverloadPacket:
|
||||
"""Get the source attention op that we target for replacement."""
|
||||
return torch.ops.auto_deploy.torch_mla
|
||||
|
||||
@classmethod
|
||||
def get_cached_attention_op(cls) -> MHACallable:
|
||||
"""Get the cached attention op."""
|
||||
return torch.ops.auto_deploy.flashinfer_mla_with_cache.default
|
||||
|
||||
@classmethod
|
||||
def get_standard_metadata_args(cls) -> List[str]:
|
||||
"""Get the list of standard metadata arguments for paged attention."""
|
||||
return [
|
||||
"batch_info_host",
|
||||
"cu_seqlen_host",
|
||||
"cu_num_pages",
|
||||
"cu_num_pages_host",
|
||||
"cache_loc",
|
||||
"last_page_len",
|
||||
"last_page_len_host",
|
||||
"seq_len_with_cache_host",
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def get_prepare_extra_metadata_info(
|
||||
cls, any_source_attn_node: Node
|
||||
) -> Tuple[Optional[PrepareMetadataCallable], int, List[Constant]]:
|
||||
"""Get the prepare_metadata op for FlashInfer MLA."""
|
||||
return (torch.ops.auto_deploy.flashinfer_mla_prepare_metadata.default, 2, [])
|
||||
|
||||
@classmethod
|
||||
def get_cache_initializers(
|
||||
cls, source_attn_node: Node, cache_config: KvCacheConfig
|
||||
) -> ResourceHandlerDict:
|
||||
"""Get cache initializers using FlashInfer MLA paged cache layout.
|
||||
|
||||
Creates two separate paged caches:
|
||||
- ckv_cache: [num_pages, page_size, kv_lora_rank]
|
||||
- kpe_cache: [num_pages, page_size, qk_rope_head_dim]
|
||||
"""
|
||||
# Extract dimensions from source node args
|
||||
# torch_mla signature: q_nope, q_pe, compressed_kv, kpe, kv_b_proj_weight, ...
|
||||
compressed_kv_fake: FakeTensor = source_attn_node.args[2].meta["val"]
|
||||
kpe_fake: FakeTensor = source_attn_node.args[3].meta["val"]
|
||||
|
||||
# Get dimensions
|
||||
# compressed_kv: [B, S, kv_lora_rank]
|
||||
# kpe: [B, S, 1, qk_rope_head_dim]
|
||||
kv_lora_rank = compressed_kv_fake.shape[-1]
|
||||
qk_rope_head_dim = kpe_fake.shape[-1]
|
||||
|
||||
# flashinfer mla requires kv_lora_rank to be 512 and qk_rope_head_dim to be 64
|
||||
if kv_lora_rank != 512:
|
||||
raise ValueError("kv_lora_rank must be 512 for flashinfer_mla")
|
||||
if qk_rope_head_dim != 64:
|
||||
raise ValueError("qk_rope_head_dim must be 64 for flashinfer_mla")
|
||||
|
||||
cache_dtype = cls.resolve_cache_dtype(cache_config.dtype, compressed_kv_fake.dtype)
|
||||
|
||||
# FlashInfer MLA uses two separate paged caches with no num_heads dimension
|
||||
return {
|
||||
"ckv_cache": MLAPagedResourceHandler(
|
||||
kv_lora_rank,
|
||||
dtype=cache_dtype,
|
||||
),
|
||||
"kpe_cache": MLAPagedResourceHandler(
|
||||
qk_rope_head_dim,
|
||||
dtype=cache_dtype,
|
||||
),
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_host_prepare_metadata_function(cls) -> Optional[PrepareMetadataHostCallable]:
|
||||
"""Get function for host-side preparation."""
|
||||
return prepare_flashinfer_mla_metadata_host
|
||||
|
||||
@classmethod
|
||||
def get_constants(cls, source_attn_node: Node) -> List[Constant]:
|
||||
"""Get constants to pass to the cached attention op."""
|
||||
# Extract kv_lora_rank for cache operations
|
||||
compressed_kv_fake = source_attn_node.args[2].meta["val"]
|
||||
kv_lora_rank = compressed_kv_fake.shape[-1]
|
||||
|
||||
# Get scale from kwargs
|
||||
scale = source_attn_node.kwargs.get("scale", None)
|
||||
|
||||
return [scale, kv_lora_rank]
|
||||
@@ -1,280 +0,0 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Custom ops for MultiHead Latent attention."""
|
||||
|
||||
import math
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch._ops import OpOverloadPacket
|
||||
from torch.fx import Node
|
||||
|
||||
from .....llmapi.llm_args import KvCacheConfig
|
||||
from ..attention.triton_attention import _decode_attention, _prefill_attention
|
||||
from ..attention_interface import (
|
||||
AttentionDescriptor,
|
||||
AttentionLayout,
|
||||
AttentionRegistry,
|
||||
MHACallable,
|
||||
ResourceHandlerDict,
|
||||
UnpagedResourceHandler,
|
||||
)
|
||||
|
||||
Constant = Union[int, float, str, None]
|
||||
|
||||
|
||||
def _precompute_inv_freq(
|
||||
max_seq_len: int, head_dim: int, rope_theta: float, device: torch.device
|
||||
) -> torch.Tensor:
|
||||
inv_freq = 1.0 / (
|
||||
rope_theta ** (torch.arange(0, head_dim, 2, device=device).float() / head_dim)
|
||||
)
|
||||
t = torch.arange(max_seq_len, device=inv_freq.device, dtype=inv_freq.dtype)
|
||||
|
||||
freqs = torch.outer(t, inv_freq.to(t.device))
|
||||
# This differs from the paper's formulation: a different permutation is used, but it yields the same result.
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
cos_sin_stacked = torch.stack([emb.cos().to(torch.bfloat16), emb.sin().to(torch.bfloat16)])
|
||||
return cos_sin_stacked
|
||||
|
||||
|
||||
@torch.library.custom_op(
|
||||
"auto_deploy::triton_attention_fused_flattened_mla_with_cache", mutates_args=()
|
||||
)
|
||||
def fused_flattened_mla_with_cache(
|
||||
# Q, K, V
|
||||
q_nope: torch.Tensor,
|
||||
q_pe: torch.Tensor,
|
||||
kv: torch.Tensor,
|
||||
k_pe: torch.Tensor,
|
||||
# STANDARD METADATA
|
||||
batch_info_host: torch.Tensor,
|
||||
seq_len: torch.Tensor,
|
||||
input_pos: torch.Tensor,
|
||||
cache_loc: torch.Tensor,
|
||||
cu_seqlen: torch.Tensor,
|
||||
# EXTRA METADATA
|
||||
#
|
||||
# CACHES
|
||||
k_cache: torch.Tensor,
|
||||
v_cache: torch.Tensor,
|
||||
# CONSTANTS
|
||||
softmax_scale: Optional[float] = None,
|
||||
rope_theta: Optional[float] = None,
|
||||
) -> torch.Tensor:
|
||||
"""Flattened & fused MLA with cache with triton kernels."""
|
||||
# b, s info
|
||||
# NOTE: b, s are just the shapes of the input tensor q; not necessarily the number of sequences.
|
||||
# Generally speaking, we expect one of two cases here:
|
||||
# 1. b > 0, s==1: this indicates a generate-only batch of tokens.
|
||||
# 2. b==1, s > 0: this indicates a mixed context+generate phase. The actual number of sequences
|
||||
# and number of tokens per sequence are encoded in seq_len and seq_start.
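# Illustrative cases (hypothetical sizes): a decode-only batch of 8 requests arrives as b=8,
# s=1 with one new token per sequence, while a mixed batch of, say, 3 sequences totaling 40
# tokens arrives flattened as b=1, s=40 and is split per sequence via seq_len and seq_start.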
|
||||
|
||||
# check for sequence info and truncate metadata
|
||||
num_prefill, num_prefill_tokens, num_decode = batch_info_host.tolist()
|
||||
num_seq = num_prefill + num_decode
|
||||
|
||||
seq_len = seq_len[:num_seq]
|
||||
input_pos = input_pos[:num_seq]
|
||||
cache_loc = cache_loc[:num_seq]
|
||||
seq_start = cu_seqlen[:num_seq]
|
||||
|
||||
# Get parameters
|
||||
b, num_heads, s, qk_nope_head_dim = q_nope.shape
|
||||
qk_rope_head_dim = q_pe.shape[-1]
|
||||
v_head_dim = kv.shape[-1] - qk_nope_head_dim
|
||||
|
||||
# Get k_nope and value_states
|
||||
k_nope, value_states = torch.split(kv, [qk_nope_head_dim, v_head_dim], dim=-1)
|
||||
|
||||
# Flatten inputs
|
||||
if s == 1:
|
||||
bs_view = (b, s)
|
||||
else:
|
||||
bs_view = (b * s,)
|
||||
|
||||
# TODO(suyogg): do something about all these clones, transposes, and contiguous-es
|
||||
q_nope = q_nope.clone().transpose(1, 2).view(*bs_view, num_heads, qk_nope_head_dim).contiguous()
|
||||
q_pe = q_pe.clone().transpose(1, 2).view(*bs_view, num_heads, qk_rope_head_dim).contiguous()
|
||||
k_nope = k_nope.clone().transpose(1, 2).view(*bs_view, num_heads, qk_nope_head_dim).contiguous()
|
||||
k_pe = k_pe.clone().transpose(1, 2).view(*bs_view, -1, qk_rope_head_dim).contiguous()
|
||||
value_states = value_states.transpose(1, 2).view(*bs_view, -1, v_head_dim).contiguous()
|
||||
# Apply RoPE
|
||||
if rope_theta is not None:
|
||||
max_seq_len = (input_pos + seq_len).max().item()
|
||||
cos_sin_stacked = _precompute_inv_freq(
|
||||
max_seq_len, qk_rope_head_dim, rope_theta, q_pe.device
|
||||
)
|
||||
|
||||
# Extract cos and sin from freqs_cis
|
||||
cos_base = cos_sin_stacked[0, ...]
|
||||
sin_base = cos_sin_stacked[1, ...]
|
||||
|
||||
# TODO: Use triton kernels for RoPE
|
||||
# TODO: Add yarn support
|
||||
for i in range(seq_len.shape[0]):
|
||||
start = seq_start[i]
|
||||
length = seq_len[i]
|
||||
|
||||
# build position_ids
|
||||
if s == 1:
|
||||
idx = (input_pos[i] + length - 1).item()
|
||||
pos_ids = torch.tensor(idx, device=cos_base.device)
|
||||
else:
|
||||
pos_ids = torch.arange(input_pos[i], input_pos[i] + length, device=cos_base.device)
|
||||
|
||||
cos = cos_base[pos_ids] # [..., 1, head_dim]
|
||||
sin = sin_base[pos_ids]
|
||||
q_slice = q_pe[start : start + length]
|
||||
k_slice = k_pe[start : start + length]
|
||||
|
||||
q_rot, k_rot = torch.ops.auto_deploy.torch_rope_with_qk_interleaving(
|
||||
q_slice,
|
||||
k_slice,
|
||||
cos,
|
||||
sin,
|
||||
-2,
|
||||
)
|
||||
|
||||
q_pe[start : start + length] = q_rot
|
||||
k_pe[start : start + length] = k_rot
|
||||
|
||||
# Create query_states, key_states
|
||||
query_states = torch.cat((q_nope, q_pe), dim=-1) # [b*s,n,d]
|
||||
key_states = torch.cat((k_nope, k_pe.expand(*bs_view, num_heads, -1)), dim=-1) # [b*s,n,d]
|
||||
|
||||
# Compute scale if not provided
|
||||
qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
|
||||
scale = softmax_scale if softmax_scale is not None else 1.0 / math.sqrt(qk_head_dim)
|
||||
|
||||
# Compute attention
|
||||
y = torch.empty_like(value_states)
|
||||
if s == 1:
|
||||
# generate-only phase (decode)
|
||||
_decode_attention(
|
||||
query_states.contiguous(),
|
||||
key_states.contiguous(),
|
||||
value_states.contiguous(),
|
||||
k_cache,
|
||||
v_cache,
|
||||
cache_loc,
|
||||
input_pos,
|
||||
scale,
|
||||
y,
|
||||
)
|
||||
|
||||
else:
|
||||
# mixed context + generate phase (prefill)
|
||||
_prefill_attention(
|
||||
query_states.contiguous(),
|
||||
key_states.contiguous(),
|
||||
value_states.contiguous(),
|
||||
k_cache,
|
||||
v_cache,
|
||||
input_pos,
|
||||
cache_loc,
|
||||
seq_len,
|
||||
seq_start,
|
||||
scale,
|
||||
y,
|
||||
)
|
||||
|
||||
y = (
|
||||
y.view(b, s, -1, v_head_dim).transpose(1, 2).contiguous()
|
||||
) # BNSD format as expected by the callsite.
|
||||
return y
|
||||
|
||||
|
||||
@fused_flattened_mla_with_cache.register_fake
|
||||
def fused_flattened_mla_with_cache_fake(
|
||||
# Q, K, V
|
||||
q_nope: torch.Tensor,
|
||||
q_pe: torch.Tensor,
|
||||
kv: torch.Tensor,
|
||||
k_pe: torch.Tensor,
|
||||
# STANDARD METADATA
|
||||
batch_info_host: torch.Tensor,
|
||||
seq_len: torch.Tensor,
|
||||
input_pos: torch.Tensor,
|
||||
cache_loc: torch.Tensor,
|
||||
cu_seqlen: torch.Tensor,
|
||||
# EXTRA METADATA
|
||||
#
|
||||
# CACHES
|
||||
k_cache: torch.Tensor,
|
||||
v_cache: torch.Tensor,
|
||||
# CONSTANTS
|
||||
softmax_scale: Optional[float] = None,
|
||||
rope_theta: Optional[float] = None,
|
||||
):
|
||||
v_head_dim = kv.shape[-1] - q_nope.shape[-1]
|
||||
return torch.empty_like(kv[..., -v_head_dim:])
|
||||
|
||||
|
||||
@AttentionRegistry.register("MultiHeadLatentAttention")
|
||||
class MultiHeadLatentAttention(AttentionDescriptor):
|
||||
@classmethod
|
||||
def get_attention_layout(cls) -> AttentionLayout:
|
||||
"""Get the attention layout expected by the backend."""
|
||||
return "bnsd"
|
||||
|
||||
@classmethod
|
||||
def get_num_qkv_args(cls) -> int:
|
||||
"""Get the number of qkv arguments expected by the source op."""
|
||||
return 4
|
||||
|
||||
@classmethod
|
||||
def get_source_attention_op(cls) -> OpOverloadPacket:
|
||||
return torch.ops.auto_deploy.torch_attention_deepseek_fused_mla
|
||||
|
||||
@classmethod
|
||||
def get_cached_attention_op(cls) -> MHACallable:
|
||||
return torch.ops.auto_deploy.triton_attention_fused_flattened_mla_with_cache.default
|
||||
|
||||
@classmethod
|
||||
def get_standard_metadata_args(cls) -> List[str]:
|
||||
return ["batch_info_host", "seq_len", "input_pos", "cache_loc", "cu_seqlen"]
|
||||
|
||||
@classmethod
|
||||
def get_cache_initializers(
|
||||
cls, source_attn_node: Node, cache_config: KvCacheConfig
|
||||
) -> ResourceHandlerDict:
|
||||
q_nope_fake = source_attn_node.args[0].meta["val"]
|
||||
q_pe_fake = source_attn_node.args[1].meta["val"]
|
||||
kv_fake = source_attn_node.args[2].meta["val"]
|
||||
|
||||
num_kv_heads = kv_fake.shape[1]
|
||||
head_dim = q_nope_fake.shape[-1]
|
||||
rope_dim = q_pe_fake.shape[-1]
|
||||
|
||||
return {
|
||||
"k_cache": UnpagedResourceHandler(
|
||||
num_kv_heads,
|
||||
head_dim + rope_dim,
|
||||
dtype=cls.resolve_cache_dtype(cache_config.dtype, kv_fake.dtype),
|
||||
),
|
||||
"v_cache": UnpagedResourceHandler(
|
||||
num_kv_heads,
|
||||
head_dim,
|
||||
dtype=cls.resolve_cache_dtype(cache_config.dtype, kv_fake.dtype),
|
||||
),
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_constants(cls, source_attn_node: Node) -> List[Constant]:
|
||||
softmax_scale = None
|
||||
rope_theta = 10000.0 # TODO: remove once MLA is unfused
|
||||
return [softmax_scale, rope_theta]
|
||||
@@ -0,0 +1,518 @@
|
||||
"""Custom ops for MultiHead Latent Attention (MLA) with FlashInfer-compatible cache.
|
||||
|
||||
This module provides:
|
||||
- torch_cached_mla_with_cache: cached backend op
|
||||
- TorchBackendMLAAttention: attention descriptor
|
||||
|
||||
FlashInfer MLA Cache Layout:
|
||||
mla_cache: [max_batch, max_seq, kv_lora_rank + qk_rope_head_dim]
|
||||
- No num_heads dimension (MLA-specific optimization)
|
||||
- compressed_kv_cached = mla_cache[:, :, :kv_lora_rank] (zero-copy slice)
|
||||
- kpe_cached = mla_cache[:, :, kv_lora_rank:] (zero-copy slice)
|
||||
|
||||
The implementation uses:
|
||||
- Prefill: Expand compressed_kv -> full K, V, compute normal attention
|
||||
- Generate: Weight absorption for efficiency (Q @ W^T instead of expanding cached KV)
|
||||
|
||||
Reference: https://docs.flashinfer.ai/tutorials/kv_layout.html#mla-page-layout
|
||||
"""
|
||||
|
||||
import math
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
from torch._ops import OpOverloadPacket
|
||||
from torch.fx import Node
|
||||
|
||||
from .....llmapi.llm_args import KvCacheConfig
|
||||
from ..attention_interface import (
|
||||
AttentionDescriptor,
|
||||
AttentionLayout,
|
||||
AttentionRegistry,
|
||||
Constant,
|
||||
MHACallable,
|
||||
ResourceHandlerDict,
|
||||
UnpagedResourceHandler,
|
||||
)
|
||||
|
||||
|
||||
def _update_mla_cache(
|
||||
compressed_kv: torch.Tensor, # [total_tokens, kv_lora_rank]
|
||||
kpe: torch.Tensor, # [total_tokens, qk_rope_head_dim]
|
||||
mla_cache: torch.Tensor, # [max_batch, max_seq, kv_lora_rank + qk_rope_head_dim]
|
||||
seq_len: torch.Tensor,
|
||||
input_pos: torch.Tensor,
|
||||
slot_idx: torch.Tensor,
|
||||
seq_start: torch.Tensor,
|
||||
kv_lora_rank: int,
|
||||
) -> None:
|
||||
"""Update FlashInfer MLA cache with compressed_kv and kpe values.
|
||||
|
||||
FlashInfer MLA cache layout: [max_batch, max_seq, kv_lora_rank + qk_rope_head_dim]
|
||||
- First kv_lora_rank dims: compressed KV latent (before kv_b_proj)
|
||||
- Last qk_rope_head_dim dims: key positional encoding
|
||||
"""
|
||||
for idx in range(seq_len.shape[0]):
|
||||
start = seq_start[idx].item()
|
||||
length = seq_len[idx].item()
|
||||
cache_idx = slot_idx[idx].item()
|
||||
pos = input_pos[idx].item()
|
||||
|
||||
# Update compressed_kv portion
|
||||
mla_cache[cache_idx, pos : pos + length, :kv_lora_rank] = compressed_kv[
|
||||
start : start + length
|
||||
]
|
||||
# Update kpe portion
|
||||
mla_cache[cache_idx, pos : pos + length, kv_lora_rank:] = kpe[start : start + length]
|
||||
|
||||
|
||||
def _torch_mla_generate_with_absorption(
|
||||
q_nope: torch.Tensor, # [B, 1, N, qk_nope_head_dim]
|
||||
q_pe: torch.Tensor, # [B, 1, N, qk_rope_head_dim]
|
||||
compressed_kv: torch.Tensor, # [B, 1, kv_lora_rank]
|
||||
kpe: torch.Tensor, # [B, 1, 1, qk_rope_head_dim]
|
||||
kv_b_proj_weight: torch.Tensor, # [N * (qk_nope_head_dim + v_head_dim), kv_lora_rank]
|
||||
mla_cache: torch.Tensor, # [max_batch, max_seq, kv_lora_rank + qk_rope_head_dim]
|
||||
slot_idx: torch.Tensor,
|
||||
input_pos: torch.Tensor,
|
||||
scale: float,
|
||||
kv_lora_rank: int,
|
||||
num_heads: int,
|
||||
qk_nope_head_dim: int,
|
||||
v_head_dim: int,
|
||||
out: torch.Tensor,
|
||||
) -> None:
|
||||
"""Generate-only MLA attention with weight absorption.
|
||||
|
||||
Weight absorption: Instead of expanding all cached KV, we absorb kv_b_proj into Q.
|
||||
Q_absorbed = Q_nope @ W_k^T where W_k is the k_nope portion of kv_b_proj_weight
|
||||
|
||||
This avoids expanding potentially thousands of cached tokens.
|
||||
"""
|
||||
b = q_nope.shape[0]
|
||||
|
||||
# Extract k_nope and v portions from kv_b_proj_weight
|
||||
# kv_b_proj_weight: [N * (qk_nope_head_dim + v_head_dim), kv_lora_rank]
|
||||
# Reshape to [N, qk_nope_head_dim + v_head_dim, kv_lora_rank]
|
||||
weight_reshaped = kv_b_proj_weight.view(num_heads, qk_nope_head_dim + v_head_dim, kv_lora_rank)
|
||||
w_k_nope = weight_reshaped[:, :qk_nope_head_dim, :] # [N, qk_nope_head_dim, kv_lora_rank]
|
||||
w_v = weight_reshaped[:, qk_nope_head_dim:, :] # [N, v_head_dim, kv_lora_rank]
|
||||
|
||||
# Update cache with new tokens
|
||||
compressed_kv_flat = compressed_kv.squeeze(1) # [B, kv_lora_rank]
|
||||
kpe_flat = kpe.squeeze(1).squeeze(1) # [B, qk_rope_head_dim]
|
||||
|
||||
for i in range(b):
|
||||
cache_idx = slot_idx[i].item()
|
||||
pos = input_pos[i].item()
|
||||
mla_cache[cache_idx, pos, :kv_lora_rank] = compressed_kv_flat[i]
|
||||
mla_cache[cache_idx, pos, kv_lora_rank:] = kpe_flat[i]
|
||||
|
||||
# Compute attention for each sequence using weight absorption
|
||||
for i in range(b):
|
||||
cache_idx = slot_idx[i].item()
|
||||
pos = input_pos[i].item()
|
||||
|
||||
# Get query for this sequence: [N, qk_nope_head_dim], [N, qk_rope_head_dim]
|
||||
q_nope_i = q_nope[i, 0] # [N, qk_nope_head_dim]
|
||||
q_pe_i = q_pe[i, 0] # [N, qk_rope_head_dim]
|
||||
|
||||
# Retrieve cached data up to current position
|
||||
cached_data = mla_cache[cache_idx, : pos + 1] # [seq_len, kv_lora_rank + qk_rope_head_dim]
|
||||
compressed_kv_cached = cached_data[:, :kv_lora_rank] # [seq_len, kv_lora_rank]
|
||||
kpe_cached = cached_data[:, kv_lora_rank:] # [seq_len, qk_rope_head_dim]
|
||||
|
||||
# =====================================================================
|
||||
# Weight absorption for Q_nope part
|
||||
# =====================================================================
|
||||
# q_absorbed = q_nope @ w_k_nope^T (absorb k_nope projection into query)
|
||||
# q_nope_i: [N, qk_nope_head_dim]
|
||||
# w_k_nope: [N, qk_nope_head_dim, kv_lora_rank]
|
||||
# q_absorbed: [N, kv_lora_rank]
|
||||
q_absorbed = torch.einsum("nd,ndk->nk", q_nope_i, w_k_nope)
|
||||
|
||||
# Attention scores from absorbed Q and compressed KV
|
||||
# Compute in fp32 to match FlashInfer's use_fp16_qk_reduction=False
|
||||
# q_absorbed: [N, kv_lora_rank], compressed_kv_cached: [seq_len, kv_lora_rank]
|
||||
# scores_nope: [N, seq_len]
|
||||
scores_nope = torch.matmul(q_absorbed.float(), compressed_kv_cached.float().t())
|
||||
|
||||
# =====================================================================
|
||||
# Q_pe part - standard attention with kpe
|
||||
# =====================================================================
|
||||
# q_pe_i: [N, qk_rope_head_dim], kpe_cached: [seq_len, qk_rope_head_dim]
|
||||
# scores_pe: [N, seq_len]
|
||||
scores_pe = torch.matmul(q_pe_i.float(), kpe_cached.float().t())
|
||||
|
||||
# Combined attention scores (already in fp32)
|
||||
attn_scores = (scores_nope + scores_pe) * scale # [N, seq_len]
|
||||
|
||||
# Softmax (already in fp32, convert back to input dtype)
|
||||
attn_weights = torch.softmax(attn_scores, dim=-1).to(q_nope.dtype) # [N, seq_len]
|
||||
|
||||
# =====================================================================
|
||||
# Compute output with absorbed value projection
|
||||
# =====================================================================
|
||||
# v_out = attn_weights @ compressed_kv @ w_v^T
|
||||
# First: weighted_kv = attn_weights @ compressed_kv_cached -> [N, kv_lora_rank]
|
||||
weighted_kv = torch.matmul(attn_weights, compressed_kv_cached) # [N, kv_lora_rank]
|
||||
|
||||
# Then: attn_out = weighted_kv @ w_v^T -> [N, v_head_dim]
|
||||
# w_v: [N, v_head_dim, kv_lora_rank]
|
||||
# weighted_kv: [N, kv_lora_rank]
|
||||
attn_out = torch.einsum("nk,nvk->nv", weighted_kv, w_v) # [N, v_head_dim]
|
||||
|
||||
out[i] = attn_out
|
||||
|
||||
|
||||
def _torch_mla_context_with_expansion(
|
||||
q_nope: torch.Tensor, # [total_tokens, N, qk_nope_head_dim]
|
||||
q_pe: torch.Tensor, # [total_tokens, N, qk_rope_head_dim]
|
||||
compressed_kv: torch.Tensor, # [total_tokens, kv_lora_rank]
|
||||
kpe: torch.Tensor, # [total_tokens, 1, qk_rope_head_dim]
|
||||
kv_b_proj_weight: torch.Tensor, # [N * (qk_nope_head_dim + v_head_dim), kv_lora_rank]
|
||||
mla_cache: torch.Tensor, # [max_batch, max_seq, kv_lora_rank + qk_rope_head_dim]
|
||||
input_pos: torch.Tensor,
|
||||
slot_idx: torch.Tensor,
|
||||
seq_len: torch.Tensor,
|
||||
seq_start: torch.Tensor,
|
||||
scale: float,
|
||||
kv_lora_rank: int,
|
||||
num_heads: int,
|
||||
qk_nope_head_dim: int,
|
||||
v_head_dim: int,
|
||||
out: torch.Tensor,
|
||||
) -> None:
|
||||
"""Context MLA attention with kv_b_proj expansion.
|
||||
|
||||
For prefill, we expand the compressed KV using kv_b_proj_weight and compute standard
attention. This is typically more efficient than absorption here, since for a fresh
sequence only the current tokens are expanded rather than a long cache.
|
||||
"""
|
||||
|
||||
# Flatten kpe: [total_tokens, 1, qk_rope_head_dim] -> [total_tokens, qk_rope_head_dim]
|
||||
kpe_flat = kpe.squeeze(1)
|
||||
|
||||
# Update cache first with compressed representation
|
||||
_update_mla_cache(
|
||||
compressed_kv,
|
||||
kpe_flat,
|
||||
mla_cache,
|
||||
seq_len,
|
||||
input_pos,
|
||||
slot_idx,
|
||||
seq_start,
|
||||
kv_lora_rank,
|
||||
)
|
||||
|
||||
# Compute attention for each sequence
|
||||
attn_outputs = []
|
||||
for idx in range(seq_len.shape[0]):
|
||||
seq_len_i = seq_len[idx].item()
|
||||
input_pos_i = input_pos[idx].item()
|
||||
slot_idx_i = slot_idx[idx].item()
|
||||
seq_start_i = seq_start[idx].item()
|
||||
|
||||
if seq_len_i == 0:
|
||||
continue
|
||||
|
||||
# Get query for this sequence
|
||||
q_nope_seq = q_nope[
|
||||
seq_start_i : seq_start_i + seq_len_i
|
||||
] # [seq_len_i, N, qk_nope_head_dim]
|
||||
q_pe_seq = q_pe[seq_start_i : seq_start_i + seq_len_i] # [seq_len_i, N, qk_rope_head_dim]
|
||||
|
||||
# Get cached data for attention (includes just-added tokens)
|
||||
kv_seq_len = input_pos_i + seq_len_i
|
||||
cached_data = mla_cache[
|
||||
slot_idx_i, :kv_seq_len
|
||||
] # [kv_seq_len, kv_lora_rank + qk_rope_head_dim]
|
||||
compressed_kv_cached = cached_data[:, :kv_lora_rank] # [kv_seq_len, kv_lora_rank]
|
||||
kpe_cached = cached_data[:, kv_lora_rank:] # [kv_seq_len, qk_rope_head_dim]
|
||||
|
||||
# =====================================================================
|
||||
# Expand compressed_kv using kv_b_proj_weight for this sequence
|
||||
# =====================================================================
|
||||
# compressed_kv_cached: [kv_seq_len, kv_lora_rank]
|
||||
# kv_b_proj_weight: [N * (qk_nope_head_dim + v_head_dim), kv_lora_rank]
|
||||
# kv_expanded: [kv_seq_len, N * (qk_nope_head_dim + v_head_dim)]
|
||||
kv_expanded = torch.matmul(compressed_kv_cached, kv_b_proj_weight.t())
|
||||
|
||||
# Reshape to [kv_seq_len, N, qk_nope_head_dim + v_head_dim]
|
||||
kv_expanded = kv_expanded.view(kv_seq_len, num_heads, qk_nope_head_dim + v_head_dim)
|
||||
|
||||
# Split into k_nope and v
|
||||
k_nope_expanded = kv_expanded[:, :, :qk_nope_head_dim] # [kv_seq_len, N, qk_nope_head_dim]
|
||||
v_expanded = kv_expanded[:, :, qk_nope_head_dim:] # [kv_seq_len, N, v_head_dim]
|
||||
|
||||
# Expand kpe to all heads
|
||||
kpe_expanded = kpe_cached.unsqueeze(1).expand(
|
||||
-1, num_heads, -1
|
||||
) # [kv_seq_len, N, qk_rope_head_dim]
|
||||
|
||||
# Construct full query and key
|
||||
query_full = torch.cat([q_nope_seq, q_pe_seq], dim=-1) # [seq_len_i, N, qk_head_dim]
|
||||
key_full = torch.cat(
|
||||
[k_nope_expanded, kpe_expanded], dim=-1
|
||||
) # [kv_seq_len, N, qk_head_dim]
|
||||
|
||||
# Transpose for attention: [1, N, seq_len, head_dim]
|
||||
query_t = query_full.transpose(0, 1).unsqueeze(0) # [1, N, seq_len_i, qk_head_dim]
|
||||
key_t = key_full.transpose(0, 1).unsqueeze(0) # [1, N, kv_seq_len, qk_head_dim]
|
||||
|
||||
# Compute attention scores in fp32 to match FlashInfer's use_fp16_qk_reduction=False
|
||||
# FlashInfer uses fp32 accumulation for QK^T, so we do the same for numerical consistency
|
||||
attn_scores = (
|
||||
torch.matmul(query_t.float(), key_t.float().transpose(-2, -1)) * scale
|
||||
) # [1, N, seq_len_i, kv_seq_len] in fp32
|
||||
|
||||
# Apply causal mask
|
||||
causal_mask = torch.triu(
|
||||
torch.ones(seq_len_i, kv_seq_len, device=q_nope.device, dtype=torch.bool),
|
||||
diagonal=kv_seq_len - seq_len_i + 1,
|
||||
)
|
||||
attn_scores.masked_fill_(causal_mask.unsqueeze(0).unsqueeze(0), float("-inf"))
|
||||
|
||||
# Softmax (already in fp32, convert back to input dtype)
|
||||
attn_weights = torch.softmax(attn_scores, dim=-1).to(q_nope.dtype)
|
||||
|
||||
# Value: [1, N, kv_seq_len, v_head_dim]
|
||||
v_t = v_expanded.transpose(0, 1).unsqueeze(0)
|
||||
|
||||
# Compute output
|
||||
attn_out = torch.matmul(attn_weights, v_t) # [1, N, seq_len_i, v_head_dim]
|
||||
attn_out = attn_out[0].transpose(0, 1) # [seq_len_i, N, v_head_dim]
|
||||
|
||||
attn_outputs.append(attn_out)
|
||||
|
||||
# Concatenate all outputs
|
||||
if len(attn_outputs) == 0:
|
||||
out.zero_()
|
||||
elif len(attn_outputs) == 1:
|
||||
out.copy_(attn_outputs[0])
|
||||
else:
|
||||
out.copy_(torch.cat(attn_outputs, dim=0))
|
||||
|
||||
|
||||
@torch.library.custom_op("auto_deploy::torch_cached_mla_with_cache", mutates_args=())
|
||||
def torch_backend_mla_with_cache(
|
||||
# 5 tensor args (get_num_qkv_args = 5)
|
||||
q_nope: torch.Tensor, # [B, S, N, qk_nope_head_dim]
|
||||
q_pe: torch.Tensor, # [B, S, N, qk_rope_head_dim]
|
||||
compressed_kv: torch.Tensor, # [B, S, kv_lora_rank]
|
||||
kpe: torch.Tensor, # [B, S, 1, qk_rope_head_dim]
|
||||
kv_b_proj_weight: torch.Tensor, # [N * (qk_nope_head_dim + v_head_dim), kv_lora_rank]
|
||||
# Standard metadata
|
||||
batch_info_host: torch.Tensor,
|
||||
seq_len: torch.Tensor,
|
||||
input_pos: torch.Tensor,
|
||||
slot_idx: torch.Tensor,
|
||||
cu_seqlen: torch.Tensor,
|
||||
# Cache (FlashInfer layout)
|
||||
mla_cache: torch.Tensor, # [max_batch, max_seq, kv_lora_rank + qk_rope_head_dim]
|
||||
# Constants
|
||||
scale: Optional[float] = None,
|
||||
kv_lora_rank: int = 512,
|
||||
) -> torch.Tensor:
|
||||
"""Torch backend MLA with FlashInfer-compatible compressed cache.
|
||||
|
||||
FlashInfer MLA Cache Layout:
|
||||
mla_cache: [max_batch, max_seq, kv_lora_rank + qk_rope_head_dim]
|
||||
- compressed_kv = mla_cache[:, :, :kv_lora_rank] (zero-copy slice)
|
||||
- kpe = mla_cache[:, :, kv_lora_rank:] (zero-copy slice)
|
||||
|
||||
Prefill (context): Expand compressed_kv, compute normal attention
|
||||
Generate (decode): Use weight absorption for efficiency
|
||||
"""
|
||||
# Get dimensions
|
||||
b, s = q_nope.shape[:2]
|
||||
num_heads = q_nope.shape[2]
|
||||
qk_nope_head_dim = q_nope.shape[3]
|
||||
qk_rope_head_dim = q_pe.shape[3]
|
||||
qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
|
||||
|
||||
# Infer v_head_dim from kv_b_proj_weight
|
||||
out_features = kv_b_proj_weight.shape[0]
|
||||
kv_head_dim = out_features // num_heads
|
||||
v_head_dim = kv_head_dim - qk_nope_head_dim
|
||||
|
||||
# Get cleaned up metadata
|
||||
num_prefill, num_prefill_tokens, num_decode = batch_info_host.tolist()
|
||||
num_seq = num_prefill + num_decode
|
||||
seq_len = seq_len[:num_seq]
|
||||
input_pos = input_pos[:num_seq]
|
||||
slot_idx = slot_idx[:num_seq]
|
||||
seq_start = cu_seqlen[:num_seq]
|
||||
|
||||
# Set scale
|
||||
if scale is None:
|
||||
scale = 1.0 / math.sqrt(qk_head_dim)
|
||||
|
||||
# Define output shape: [B, S, N, v_head_dim]
|
||||
output_shape = (b, s, num_heads, v_head_dim)
|
||||
|
||||
if s == 1:
|
||||
# =====================================================================
|
||||
# Generate phase: Use weight absorption
|
||||
# =====================================================================
|
||||
y = q_nope.new_empty(b, num_heads, v_head_dim).contiguous()
|
||||
|
||||
_torch_mla_generate_with_absorption(
|
||||
q_nope,
|
||||
q_pe,
|
||||
compressed_kv,
|
||||
kpe,
|
||||
kv_b_proj_weight,
|
||||
mla_cache,
|
||||
slot_idx,
|
||||
input_pos,
|
||||
scale,
|
||||
kv_lora_rank,
|
||||
num_heads,
|
||||
qk_nope_head_dim,
|
||||
v_head_dim,
|
||||
y,
|
||||
)
|
||||
|
||||
return y.unsqueeze(1) # [B, 1, N, v_head_dim]
|
||||
else:
|
||||
# =====================================================================
|
||||
# Context phase: Expand and compute normal attention
|
||||
# =====================================================================
|
||||
bs_view = (b * s,)
|
||||
|
||||
q_nope_flat = q_nope.contiguous().view(*bs_view, num_heads, qk_nope_head_dim)
|
||||
q_pe_flat = q_pe.contiguous().view(*bs_view, num_heads, qk_rope_head_dim)
|
||||
compressed_kv_flat = compressed_kv.contiguous().view(*bs_view, kv_lora_rank)
|
||||
kpe_flat = kpe.contiguous().view(*bs_view, 1, qk_rope_head_dim)
|
||||
|
||||
y = q_nope.new_empty(*bs_view, num_heads, v_head_dim).contiguous()
|
||||
|
||||
_torch_mla_context_with_expansion(
|
||||
q_nope_flat,
|
||||
q_pe_flat,
|
||||
compressed_kv_flat,
|
||||
kpe_flat,
|
||||
kv_b_proj_weight,
|
||||
mla_cache,
|
||||
input_pos,
|
||||
slot_idx,
|
||||
seq_len,
|
||||
seq_start,
|
||||
scale,
|
||||
kv_lora_rank,
|
||||
num_heads,
|
||||
qk_nope_head_dim,
|
||||
v_head_dim,
|
||||
y,
|
||||
)
|
||||
|
||||
return y.view(*output_shape)
|
||||
|
||||
|
||||
@torch_backend_mla_with_cache.register_fake
|
||||
def torch_backend_mla_with_cache_fake(
|
||||
q_nope: torch.Tensor,
|
||||
q_pe: torch.Tensor,
|
||||
compressed_kv: torch.Tensor,
|
||||
kpe: torch.Tensor,
|
||||
kv_b_proj_weight: torch.Tensor,
|
||||
batch_info_host: torch.Tensor,
|
||||
seq_len: torch.Tensor,
|
||||
input_pos: torch.Tensor,
|
||||
slot_idx: torch.Tensor,
|
||||
cu_seqlen: torch.Tensor,
|
||||
mla_cache: torch.Tensor,
|
||||
scale: Optional[float] = None,
|
||||
kv_lora_rank: int = 512,
|
||||
) -> torch.Tensor:
|
||||
"""Fake implementation for torch_backend_mla_with_cache."""
|
||||
num_heads = q_nope.shape[2]
|
||||
qk_nope_head_dim = q_nope.shape[-1]
|
||||
out_features = kv_b_proj_weight.shape[0]
|
||||
kv_head_dim = out_features // num_heads
|
||||
v_head_dim = kv_head_dim - qk_nope_head_dim
|
||||
|
||||
return q_nope.new_empty(
|
||||
q_nope.shape[0], q_nope.shape[1], q_nope.shape[2], v_head_dim
|
||||
).contiguous()
|
||||
|
||||
|
||||
@AttentionRegistry.register("torch_mla")
|
||||
class TorchBackendMLAAttention(AttentionDescriptor):
|
||||
"""Attention descriptor for Multi-head Latent Attention (MLA).
|
||||
|
||||
This descriptor uses FlashInfer-compatible compressed cache:
|
||||
- torch_mla: source op that expands compressed_kv for attention
|
||||
- torch_cached_mla_with_cache: cached op with absorption for generate
|
||||
|
||||
FlashInfer MLA Cache Layout:
|
||||
mla_cache: [max_batch, max_seq, head_dim_ckv + head_dim_kpe]
|
||||
- No num_heads dimension (MLA-specific optimization)
|
||||
- ckv_cached = mla_cache[:, :, :head_dim_ckv] (zero-copy slice)
|
||||
- kpe_cached = mla_cache[:, :, head_dim_ckv:] (zero-copy slice)
|
||||
|
||||
Reference: https://docs.flashinfer.ai/tutorials/kv_layout.html#mla-page-layout
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def get_attention_layout(cls) -> AttentionLayout:
|
||||
"""Get the attention layout expected by the backend."""
|
||||
return "bsnd"
|
||||
|
||||
@classmethod
|
||||
def get_num_qkv_args(cls) -> int:
|
||||
"""Get the number of tensor arguments expected by the source op."""
|
||||
return 5 # q_nope, q_pe, compressed_kv, kpe, kv_b_proj_weight
|
||||
|
||||
@classmethod
|
||||
def get_source_attention_op(cls) -> OpOverloadPacket:
|
||||
"""Get the source attention op that we target for replacement."""
|
||||
return torch.ops.auto_deploy.torch_mla
|
||||
|
||||
@classmethod
|
||||
def get_cached_attention_op(cls) -> MHACallable:
|
||||
"""Get the cached attention op."""
|
||||
return torch.ops.auto_deploy.torch_cached_mla_with_cache.default
|
||||
|
||||
@classmethod
|
||||
def get_standard_metadata_args(cls) -> List[str]:
|
||||
"""Get the list of standard metadata arguments."""
|
||||
return ["batch_info_host", "seq_len", "input_pos", "slot_idx", "cu_seqlen"]
|
||||
|
||||
@classmethod
|
||||
def get_cache_initializers(
|
||||
cls, source_attn_node: Node, cache_config: KvCacheConfig
|
||||
) -> ResourceHandlerDict:
|
||||
"""Get cache initializers using FlashInfer MLA cache layout."""
|
||||
# Extract dimensions from source node args
|
||||
# torch_mla signature: q_nope, q_pe, compressed_kv, kpe, kv_b_proj_weight, ...
|
||||
compressed_kv_fake = source_attn_node.args[2].meta["val"]
|
||||
kpe_fake = source_attn_node.args[3].meta["val"]
|
||||
|
||||
# Get dimensions
|
||||
# compressed_kv: [B, S, kv_lora_rank]
|
||||
# kpe: [B, S, 1, qk_rope_head_dim]
|
||||
kv_lora_rank = compressed_kv_fake.shape[-1]
|
||||
qk_rope_head_dim = kpe_fake.shape[-1]
|
||||
|
||||
# FlashInfer MLA cache: [max_batch, max_seq, kv_lora_rank + qk_rope_head_dim]
|
||||
# No num_heads dimension - this is the key MLA optimization
|
||||
return {
|
||||
"mla_cache": UnpagedResourceHandler(
|
||||
kv_lora_rank + qk_rope_head_dim,
|
||||
dtype=cls.resolve_cache_dtype(cache_config.dtype, compressed_kv_fake.dtype),
|
||||
),
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_constants(cls, source_attn_node: Node) -> List[Constant]:
|
||||
"""Get constants to pass to the cached attention op."""
|
||||
# Extract kv_lora_rank for cache slicing
|
||||
compressed_kv_fake = source_attn_node.args[2].meta["val"]
|
||||
kv_lora_rank = compressed_kv_fake.shape[-1]
|
||||
|
||||
# Get scale from kwargs
|
||||
scale = source_attn_node.kwargs.get("scale", None)
|
||||
|
||||
return [scale, kv_lora_rank]
|
||||
tensorrt_llm/_torch/auto_deploy/custom_ops/mla/torch_mla.py (new file, 157 lines)
@@ -0,0 +1,157 @@
|
||||
"""Torch reference implementation for Multi-head Latent Attention (MLA).
|
||||
|
||||
This module provides the source op for MLA that:
|
||||
- Accepts compressed_kv (before kv_b_proj) for FlashInfer-compatible caching
|
||||
- Expands compressed_kv using kv_b_proj_weight for attention computation
|
||||
- Computes standard attention with the expanded K, V
|
||||
"""
|
||||
|
||||
import math
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
@torch.library.custom_op("auto_deploy::torch_mla", mutates_args=())
|
||||
def torch_mla(
|
||||
q_nope: torch.Tensor, # [B, S, N, qk_nope_head_dim]
|
||||
q_pe: torch.Tensor, # [B, S, N, qk_rope_head_dim] (RoPE applied)
|
||||
compressed_kv: torch.Tensor, # [B, S, kv_lora_rank] - BEFORE kv_b_proj
|
||||
kpe: torch.Tensor, # [B, S, 1, qk_rope_head_dim] (RoPE applied)
|
||||
kv_b_proj_weight: torch.Tensor, # [num_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank]
|
||||
is_causal: bool = True,
|
||||
scale: Optional[float] = None,
|
||||
layout: str = "bsnd",
|
||||
) -> torch.Tensor:
|
||||
"""Multi-head Latent Attention (MLA) with FlashInfer-compatible compressed KV.
|
||||
|
||||
This op expands compressed_kv using kv_b_proj_weight and computes attention.
|
||||
For prefill, this is the standard formulation. For the cached version,
|
||||
weight absorption is used for efficiency.
|
||||
|
||||
Args:
|
||||
q_nope: Query non-positional component [B, S, N, qk_nope_head_dim] (bsnd)
|
||||
q_pe: Query positional component with RoPE applied [B, S, N, qk_rope_head_dim] (bsnd)
|
||||
compressed_kv: Compressed KV latent [B, S, kv_lora_rank] (before kv_b_proj)
|
||||
kpe: Key positional encoding with RoPE applied [B, S, 1, qk_rope_head_dim] (bsnd)
|
||||
kv_b_proj_weight: Projection weights [num_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank]
|
||||
is_causal: Whether to apply causal masking (default: True)
|
||||
scale: Softmax scale factor (default: 1/sqrt(qk_head_dim))
|
||||
layout: Input/output layout, either "bsnd" or "bnsd" (default: "bsnd")
|
||||
|
||||
Returns:
|
||||
Attention output with shape [B, S, N, v_head_dim] (bsnd)
|
||||
"""
|
||||
if layout not in ("bnsd", "bsnd"):
|
||||
raise ValueError(f"layout must be 'bnsd' or 'bsnd', got {layout!r}")
|
||||
|
||||
# Get dimensions
|
||||
if layout == "bsnd":
|
||||
bs, s_q, num_heads, qk_nope_head_dim = q_nope.shape
|
||||
qk_rope_head_dim = q_pe.shape[-1]
|
||||
else:
|
||||
bs, num_heads, s_q, qk_nope_head_dim = q_nope.shape
|
||||
qk_rope_head_dim = q_pe.shape[-1]
|
||||
|
||||
s_k = compressed_kv.shape[1]
|
||||
|
||||
# Infer dimensions from kv_b_proj_weight
|
||||
# kv_b_proj_weight: [num_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank]
|
||||
out_features = kv_b_proj_weight.shape[0]
|
||||
kv_head_dim = out_features // num_heads # qk_nope_head_dim + v_head_dim
|
||||
v_head_dim = kv_head_dim - qk_nope_head_dim
|
||||
|
||||
# Set scale
|
||||
qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
|
||||
if scale is None:
|
||||
scale = 1.0 / math.sqrt(qk_head_dim)
|
||||
|
||||
# =========================================================================
|
||||
# Expand compressed_kv using kv_b_proj_weight (this is the prefill path)
|
||||
# =========================================================================
|
||||
# compressed_kv: [B, S, kv_lora_rank]
|
||||
# kv_b_proj_weight: [num_heads * kv_head_dim, kv_lora_rank]
|
||||
# kv = compressed_kv @ kv_b_proj_weight.T -> [B, S, num_heads * kv_head_dim]
|
||||
kv = torch.matmul(compressed_kv, kv_b_proj_weight.t())
|
||||
|
||||
# Reshape to [B, S, N, kv_head_dim]
|
||||
kv = kv.view(bs, s_k, num_heads, kv_head_dim)
|
||||
|
||||
# Split into k_nope and value_states
|
||||
k_nope, value_states = torch.split(kv, [qk_nope_head_dim, v_head_dim], dim=-1)
|
||||
|
||||
# k_nope and value_states are always [B, S, N, D] from the kv reshape above.
|
||||
# We need them in [B, N, S, D] for attention computation.
|
||||
k_nope = k_nope.transpose(1, 2).contiguous()
|
||||
value_states = value_states.transpose(1, 2).contiguous()
|
||||
|
||||
# Convert inputs to computation layout [B, N, S, D] if they come in bsnd format
|
||||
if layout == "bsnd":
|
||||
# [B, S, N, D] -> [B, N, S, D]
|
||||
q_nope = q_nope.transpose(1, 2).contiguous()
|
||||
q_pe = q_pe.transpose(1, 2).contiguous()
|
||||
kpe = kpe.transpose(1, 2).contiguous()
|
||||
|
||||
# kpe is [B, 1, S, qk_rope_head_dim], expand to num_heads
|
||||
kpe_expanded = kpe.expand(bs, num_heads, s_k, qk_rope_head_dim)
|
||||
|
||||
# Construct full query and key states
|
||||
# query_states: [B, N, S, qk_head_dim]
|
||||
query_states = torch.cat([q_nope, q_pe], dim=-1)
|
||||
# key_states: [B, N, S, qk_head_dim]
|
||||
key_states = torch.cat([k_nope, kpe_expanded], dim=-1)
|
||||
|
||||
# Compute attention scores: Q @ K^T
|
||||
attn_scores = (
|
||||
torch.matmul(query_states, key_states.transpose(-2, -1)) * scale
|
||||
) # [B, N, s_q, s_k]
|
||||
|
||||
# Apply causal mask if specified
|
||||
if is_causal and s_q == s_k:
|
||||
causal_mask = torch.triu(
|
||||
torch.ones(s_q, s_k, device=q_nope.device, dtype=torch.bool),
|
||||
diagonal=1,
|
||||
)
|
||||
attn_scores.masked_fill_(causal_mask.unsqueeze(0).unsqueeze(0), float("-inf"))
|
||||
|
||||
# Compute attention weights and output
|
||||
attn_weights = torch.softmax(attn_scores, dim=-1, dtype=torch.float32).to(q_nope.dtype)
|
||||
attn_out = torch.matmul(attn_weights, value_states) # [B, N, s_q, v_head_dim]
|
||||
|
||||
# Convert back to requested layout
|
||||
if layout == "bsnd":
|
||||
return attn_out.transpose(1, 2).contiguous() # [B, S, N, v_head_dim]
|
||||
else:
|
||||
return attn_out.contiguous() # [B, N, S, v_head_dim]
|
||||
|
||||
|
||||
@torch_mla.register_fake
|
||||
def torch_mla_fake(
|
||||
q_nope: torch.Tensor,
|
||||
q_pe: torch.Tensor,
|
||||
compressed_kv: torch.Tensor,
|
||||
kpe: torch.Tensor,
|
||||
kv_b_proj_weight: torch.Tensor,
|
||||
is_causal: bool = True,
|
||||
scale: Optional[float] = None,
|
||||
layout: str = "bsnd",
|
||||
) -> torch.Tensor:
|
||||
"""Fake implementation for torch_mla."""
|
||||
# Infer v_head_dim from kv_b_proj_weight
|
||||
qk_nope_head_dim = q_nope.shape[-1]
|
||||
num_heads = q_nope.shape[2] if layout == "bsnd" else q_nope.shape[1]
|
||||
out_features = kv_b_proj_weight.shape[0]
|
||||
kv_head_dim = out_features // num_heads
|
||||
v_head_dim = kv_head_dim - qk_nope_head_dim
|
||||
|
||||
# Output shape depends on layout
|
||||
if layout == "bsnd":
|
||||
# Input: [B, S, N, D], Output: [B, S, N, v_head_dim]
|
||||
return q_nope.new_empty(
|
||||
q_nope.shape[0], q_nope.shape[1], q_nope.shape[2], v_head_dim
|
||||
).contiguous()
|
||||
else:
|
||||
# Input: [B, N, S, D], Output: [B, N, S, v_head_dim]
|
||||
return q_nope.new_empty(
|
||||
q_nope.shape[0], q_nope.shape[1], q_nope.shape[2], v_head_dim
|
||||
).contiguous()
|
||||
@@ -1,9 +1,11 @@
|
||||
from .modeling_eagle import Eagle3DrafterForCausalLM
|
||||
from .modeling_deepseek import DeepSeekV3ForCausalLM
|
||||
from .modeling_glm4_moe_lite import Glm4MoeLiteForCausalLM
|
||||
from .modeling_nemotron_flash import NemotronFlashForCausalLM, NemotronFlashPreTrainedTokenizerFast
|
||||
from .modeling_nemotron_h import NemotronHForCausalLM
|
||||
|
||||
__all__ = (
|
||||
"Eagle3DrafterForCausalLM",
|
||||
"DeepSeekV3ForCausalLM",
|
||||
"Glm4MoeLiteForCausalLM",
|
||||
"NemotronFlashForCausalLM",
|
||||
"NemotronFlashPreTrainedTokenizerFast",
|
||||
"NemotronHForCausalLM",
|
||||
|
||||
@@ -0,0 +1,655 @@
|
||||
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
||||
|
||||
"""Slimmed down PyTorch DeepSeekV3 model implementation for auto_deploy export.
|
||||
|
||||
Source:
|
||||
https://huggingface.co/deepseek-ai/DeepSeek-R1/blob/main/modeling_deepseek.py
|
||||
|
||||
This implementation differs from the original in the following ways:
|
||||
* Simplified for prefill-only inference (no KV caching)
|
||||
* Uses auto_deploy custom ops for export compatibility
|
||||
* Removed flash attention variants (uses torch_mla custom op)
|
||||
* Removed gradient checkpointing and training code paths
|
||||
* Removed attention dropout (inference only)
|
||||
|
||||
This allows us to have a clean export-ready implementation with auto_deploy custom ops.
|
||||
"""
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
from transformers.activations import ACT2FN
|
||||
from transformers.generation import GenerationMixin
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
from transformers.utils import ModelOutput
|
||||
|
||||
from tensorrt_llm._torch.auto_deploy.models.hf import AutoModelForCausalLMFactory
|
||||
from tensorrt_llm._torch.utils import ActivationType
|
||||
|
||||
|
||||
class DeepSeekV3RMSNorm(nn.Module):
|
||||
"""RMS Normalization for DeepSeekV3."""
|
||||
|
||||
def __init__(self, hidden_size: int, eps: float = 1e-6):
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.ones(hidden_size))
|
||||
self.variance_epsilon = eps
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
return torch.ops.auto_deploy.triton_rms_norm(
|
||||
hidden_states, self.weight, self.variance_epsilon
|
||||
).to(hidden_states.dtype)
|
||||
|
||||
|
||||
class DeepSeekV3RotaryEmbedding(nn.Module):
|
||||
"""Rotary Position Embedding for DeepSeekV3.
|
||||
|
||||
Simplified version that precomputes and caches cos/sin values.
|
||||
Returns full cached values (not sliced by seq_len) to enable export.
|
||||
"""
|
||||
|
||||
def __init__(self, dim: int, max_position_embeddings: int = 2048, base: float = 10000.0):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.base = base
|
||||
|
||||
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float() / self.dim))
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
|
||||
# Build cos/sin cache
|
||||
self._set_cos_sin_cache(max_position_embeddings)
|
||||
|
||||
def _set_cos_sin_cache(self, seq_len: int):
|
||||
self.max_seq_len_cached = seq_len
|
||||
t = torch.arange(seq_len, dtype=self.inv_freq.dtype)
|
||||
freqs = torch.outer(t, self.inv_freq)
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
self.register_buffer("cos_cached", emb.cos(), persistent=False)
|
||||
self.register_buffer("sin_cached", emb.sin(), persistent=False)
|
||||
|
||||
def forward(
|
||||
self, x: torch.Tensor, seq_len: Optional[int] = None
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# Return full cached cos/sin (not sliced) for export compatibility
|
||||
return (
|
||||
self.cos_cached.to(dtype=x.dtype, device=x.device),
|
||||
self.sin_cached.to(dtype=x.dtype, device=x.device),
|
||||
)
|
||||
|
||||
|
||||
class DeepSeekV3YarnRotaryEmbedding(DeepSeekV3RotaryEmbedding):
|
||||
"""YaRN-extended rotary embedding for DeepSeekV3."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
max_position_embeddings: int = 2048,
|
||||
base: float = 10000.0,
|
||||
scaling_factor: float = 1.0,
|
||||
original_max_position_embeddings: int = 4096,
|
||||
beta_fast: int = 32,
|
||||
beta_slow: int = 1,
|
||||
mscale: float = 1.0,
|
||||
mscale_all_dim: float = 0.0,
|
||||
):
|
||||
self.scaling_factor = scaling_factor
|
||||
self.original_max_position_embeddings = original_max_position_embeddings
|
||||
self.beta_fast = beta_fast
|
||||
self.beta_slow = beta_slow
|
||||
self.mscale = mscale
|
||||
self.mscale_all_dim = mscale_all_dim
|
||||
super().__init__(dim, max_position_embeddings, base)
|
||||
|
||||
def _set_cos_sin_cache(self, seq_len: int):
|
||||
self.max_seq_len_cached = seq_len
|
||||
dim = self.dim
|
||||
|
||||
freq_extra = 1.0 / (self.base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
|
||||
freq_inter = 1.0 / (
|
||||
self.scaling_factor * self.base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)
|
||||
)
|
||||
|
||||
low, high = self._yarn_find_correction_range(
|
||||
self.beta_fast,
|
||||
self.beta_slow,
|
||||
dim,
|
||||
self.base,
|
||||
self.original_max_position_embeddings,
|
||||
)
|
||||
inv_freq_mask = 1.0 - self._yarn_linear_ramp_mask(low, high, dim // 2)
|
||||
inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
|
||||
t = torch.arange(seq_len, dtype=torch.float32)
|
||||
freqs = torch.outer(t, inv_freq)
|
||||
|
||||
_mscale = float(
|
||||
self._yarn_get_mscale(self.scaling_factor, self.mscale)
|
||||
/ self._yarn_get_mscale(self.scaling_factor, self.mscale_all_dim)
|
||||
)
|
||||
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
self.register_buffer("cos_cached", (emb.cos() * _mscale), persistent=False)
|
||||
self.register_buffer("sin_cached", (emb.sin() * _mscale), persistent=False)
|
||||
|
||||
@staticmethod
|
||||
def _yarn_find_correction_dim(
|
||||
num_rotations: float, dim: int, base: float = 10000, max_position_embeddings: int = 2048
|
||||
) -> float:
|
||||
return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (
|
||||
2 * math.log(base)
|
||||
)
|
||||
|
||||
def _yarn_find_correction_range(
|
||||
self, low_rot: int, high_rot: int, dim: int, base: float, max_position_embeddings: int
|
||||
) -> Tuple[int, int]:
|
||||
low = math.floor(
|
||||
self._yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)
|
||||
)
|
||||
high = math.ceil(
|
||||
self._yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings)
|
||||
)
|
||||
return max(low, 0), min(high, dim - 1)
|
||||
|
||||
@staticmethod
|
||||
def _yarn_get_mscale(scale: float = 1.0, mscale: float = 1.0) -> float:
|
||||
if scale <= 1:
|
||||
return 1.0
|
||||
return 0.1 * mscale * math.log(scale) + 1.0
|
||||
|
||||
@staticmethod
|
||||
def _yarn_linear_ramp_mask(min_val: float, max_val: float, dim: int) -> torch.Tensor:
|
||||
if min_val == max_val:
|
||||
max_val += 0.001
|
||||
linear_func = (torch.arange(dim, dtype=torch.float32) - min_val) / (max_val - min_val)
|
||||
return torch.clamp(linear_func, 0, 1)
|
||||
|
||||
|
||||
class DeepSeekV3MLP(nn.Module):
|
||||
"""MLP layer for DeepSeekV3 (SwiGLU activation)."""
|
||||
|
||||
def __init__(
|
||||
self, config, hidden_size: Optional[int] = None, intermediate_size: Optional[int] = None
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.hidden_size = hidden_size or config.hidden_size
|
||||
self.intermediate_size = intermediate_size or config.intermediate_size
|
||||
|
||||
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
||||
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
||||
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
|
||||
self.act_fn = ACT2FN[config.hidden_act]
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
||||
|
||||
|
||||
class DeepSeekV3MoEGate(nn.Module):
|
||||
"""MoE Gating for DeepSeekV3 with noaux_tc top-k selection."""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.top_k = config.num_experts_per_tok
|
||||
self.n_routed_experts = config.n_routed_experts
|
||||
self.routed_scaling_factor = config.routed_scaling_factor
|
||||
self.n_group = config.n_group
|
||||
self.topk_group = config.topk_group
|
||||
|
||||
self.weight = nn.Parameter(
|
||||
torch.empty((self.n_routed_experts, config.hidden_size), dtype=torch.float32)
|
||||
)
|
||||
self.register_buffer(
|
||||
"e_score_correction_bias",
|
||||
torch.zeros(self.n_routed_experts, dtype=torch.float32),
|
||||
)
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self) -> None:
|
||||
"""Initialize gate weights using kaiming uniform (matches original DeepSeek implementation)."""
|
||||
torch.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Forward pass returning (selected_experts, routing_weights)."""
|
||||
bsz, seq_len, hidden_dim = hidden_states.shape
|
||||
hidden_states_flat = hidden_states.view(-1, hidden_dim)
|
||||
|
||||
# Compute router logits
|
||||
if self.weight.dtype == torch.float32:
|
||||
router_logits = F.linear(hidden_states_flat.float(), self.weight)
|
||||
else:
|
||||
router_logits = torch.ops.trtllm.dsv3_router_gemm_op(
|
||||
hidden_states_flat, self.weight.t(), bias=None, out_dtype=torch.float32
|
||||
)
|
||||
|
||||
# Use fused noaux_tc_op kernel for top-k selection
|
||||
topk_weights, topk_indices = torch.ops.trtllm.noaux_tc_op(
|
||||
router_logits,
|
||||
self.e_score_correction_bias,
|
||||
self.n_group,
|
||||
self.topk_group,
|
||||
self.top_k,
|
||||
self.routed_scaling_factor,
|
||||
)
|
||||
|
||||
return topk_indices, topk_weights
|
||||
|
||||
|
||||
class DeepSeekV3MoE(nn.Module):
|
||||
"""Mixture of Experts layer for DeepSeekV3."""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.num_experts_per_tok = config.num_experts_per_tok
|
||||
|
||||
# Routed experts
|
||||
self.experts = nn.ModuleList(
|
||||
[
|
||||
DeepSeekV3MLP(config, intermediate_size=config.moe_intermediate_size)
|
||||
for _ in range(config.n_routed_experts)
|
||||
]
|
||||
)
|
||||
|
||||
# Gate
|
||||
self.gate = DeepSeekV3MoEGate(config)
|
||||
|
||||
# Shared experts (if configured)
|
||||
if config.n_shared_experts is not None:
|
||||
intermediate_size = config.moe_intermediate_size * config.n_shared_experts
|
||||
self.shared_experts = DeepSeekV3MLP(config, intermediate_size=intermediate_size)
|
||||
else:
|
||||
self.shared_experts = None
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
identity = hidden_states
|
||||
orig_shape = hidden_states.shape
|
||||
|
||||
selected_experts, routing_weights = self.gate(hidden_states)
|
||||
|
||||
# Use torch_moe custom op for routed experts
|
||||
final_hidden_states = torch.ops.auto_deploy.torch_moe(
|
||||
hidden_states.view(-1, hidden_states.shape[-1]),
|
||||
selected_experts,
|
||||
routing_weights,
|
||||
w1_weight=[expert.gate_proj.weight for expert in self.experts],
|
||||
w2_weight=[expert.down_proj.weight for expert in self.experts],
|
||||
w3_weight=[expert.up_proj.weight for expert in self.experts],
|
||||
is_gated_mlp=True,
|
||||
act_fn=int(ActivationType.Silu),
|
||||
)
|
||||
|
||||
final_hidden_states = final_hidden_states.view(*orig_shape)
|
||||
|
||||
# Add shared experts output if present
|
||||
if self.shared_experts is not None:
|
||||
final_hidden_states = final_hidden_states + self.shared_experts(identity)
|
||||
|
||||
return final_hidden_states.to(hidden_states.dtype)
|
||||
|
||||
|
||||
class DeepSeekV3Attention(nn.Module):
|
||||
"""Multi-head Latent Attention (MLA) for DeepSeekV3.
|
||||
|
||||
Uses compressed KV representation with latent projections.
|
||||
"""
|
||||
|
||||
def __init__(self, config, layer_idx: Optional[int] = None):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.layer_idx = layer_idx
|
||||
|
||||
self.hidden_size = config.hidden_size
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.q_lora_rank = config.q_lora_rank
|
||||
self.kv_lora_rank = config.kv_lora_rank
|
||||
self.qk_nope_head_dim = config.qk_nope_head_dim
|
||||
self.qk_rope_head_dim = config.qk_rope_head_dim
|
||||
self.v_head_dim = config.v_head_dim
|
||||
self.q_head_dim = self.qk_nope_head_dim + self.qk_rope_head_dim
|
||||
|
||||
self.max_position_embeddings = config.max_position_embeddings
|
||||
self.rope_theta = config.rope_theta
|
||||
|
||||
# Q projection (with optional LoRA)
|
||||
if self.q_lora_rank is None:
|
||||
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.q_head_dim, bias=False)
|
||||
else:
|
||||
self.q_a_proj = nn.Linear(
|
||||
self.hidden_size, config.q_lora_rank, bias=config.attention_bias
|
||||
)
|
||||
self.q_a_layernorm = DeepSeekV3RMSNorm(config.q_lora_rank)
|
||||
self.q_b_proj = nn.Linear(
|
||||
config.q_lora_rank, self.num_heads * self.q_head_dim, bias=False
|
||||
)
|
||||
|
||||
# KV projection with MQA
|
||||
self.kv_a_proj_with_mqa = nn.Linear(
|
||||
self.hidden_size,
|
||||
self.kv_lora_rank + self.qk_rope_head_dim,
|
||||
bias=config.attention_bias,
|
||||
)
|
||||
self.kv_a_layernorm = DeepSeekV3RMSNorm(self.kv_lora_rank)
|
||||
self.kv_b_proj = nn.Linear(
|
||||
self.kv_lora_rank,
|
||||
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
|
||||
bias=False,
|
||||
)
|
||||
|
||||
# Output projection
|
||||
self.o_proj = nn.Linear(
|
||||
self.num_heads * self.v_head_dim, self.hidden_size, bias=config.attention_bias
|
||||
)
|
||||
|
||||
# Initialize rotary embedding
|
||||
self._init_rope()
|
||||
|
||||
# Softmax scale
|
||||
self.softmax_scale = self.q_head_dim ** (-0.5)
|
||||
if config.rope_scaling is not None:
|
||||
mscale_all_dim = config.rope_scaling.get("mscale_all_dim", 0)
|
||||
scaling_factor = config.rope_scaling["factor"]
|
||||
if mscale_all_dim:
|
||||
mscale = DeepSeekV3YarnRotaryEmbedding._yarn_get_mscale(
|
||||
scaling_factor, mscale_all_dim
|
||||
)
|
||||
self.softmax_scale = self.softmax_scale * mscale * mscale
|
||||
|
||||
def _init_rope(self):
|
||||
if self.config.rope_scaling is None:
|
||||
self.rotary_emb = DeepSeekV3RotaryEmbedding(
|
||||
self.qk_rope_head_dim,
|
||||
max_position_embeddings=self.max_position_embeddings,
|
||||
base=self.rope_theta,
|
||||
)
|
||||
else:
|
||||
scaling_type = self.config.rope_scaling["type"]
|
||||
scaling_factor = self.config.rope_scaling["factor"]
|
||||
|
||||
if scaling_type == "yarn":
|
||||
kwargs = {
|
||||
key: self.config.rope_scaling[key]
|
||||
for key in [
|
||||
"original_max_position_embeddings",
|
||||
"beta_fast",
|
||||
"beta_slow",
|
||||
"mscale",
|
||||
"mscale_all_dim",
|
||||
]
|
||||
if key in self.config.rope_scaling
|
||||
}
|
||||
self.rotary_emb = DeepSeekV3YarnRotaryEmbedding(
|
||||
self.qk_rope_head_dim,
|
||||
max_position_embeddings=self.max_position_embeddings,
|
||||
scaling_factor=scaling_factor,
|
||||
base=self.rope_theta,
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
# Default to base rotary embedding for unsupported types
|
||||
self.rotary_emb = DeepSeekV3RotaryEmbedding(
|
||||
self.qk_rope_head_dim,
|
||||
max_position_embeddings=self.max_position_embeddings,
|
||||
base=self.rope_theta,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
# Q projection
|
||||
if self.q_lora_rank is None:
|
||||
q = self.q_proj(hidden_states)
|
||||
else:
|
||||
q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
|
||||
|
||||
# Shape: [B, S, N, q_head_dim] (BSND layout)
|
||||
q = q.view(bsz, q_len, self.num_heads, self.q_head_dim)
|
||||
q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
|
||||
|
||||
# KV projection - keep compressed form
|
||||
kv_a_output = self.kv_a_proj_with_mqa(hidden_states)
|
||||
compressed_kv, k_pe = torch.split(
|
||||
kv_a_output, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
|
||||
)
|
||||
|
||||
# Apply layernorm to compressed_kv
|
||||
compressed_kv = self.kv_a_layernorm(compressed_kv)
|
||||
|
||||
# k_pe: [B, S, 1, qk_rope_head_dim] (BSND layout, shared across heads)
|
||||
k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim)
|
||||
|
||||
kv_seq_len = q_len
|
||||
|
||||
# Get cos/sin for RoPE
|
||||
cos, sin = self.rotary_emb(hidden_states, seq_len=kv_seq_len)
|
||||
cos = cos[position_ids] # [B, S, head_dim]
|
||||
sin = sin[position_ids] # [B, S, head_dim]
|
||||
|
||||
# Apply RoPE using custom op
|
||||
q_pe_rotated, kpe = torch.ops.auto_deploy.torch_rope_with_qk_interleaving(
|
||||
q_pe,
|
||||
k_pe,
|
||||
cos,
|
||||
sin,
|
||||
2, # unsqueeze_dim=2 for BSND layout
|
||||
)
|
||||
|
||||
# Call MLA with compressed KV
|
||||
attn_output = torch.ops.auto_deploy.torch_mla(
|
||||
q_nope, # [B, S, N, qk_nope_head_dim]
|
||||
q_pe_rotated, # [B, S, N, qk_rope_head_dim]
|
||||
compressed_kv, # [B, S, kv_lora_rank]
|
||||
kpe, # [B, S, 1, qk_rope_head_dim]
|
||||
self.kv_b_proj.weight, # [N*(qk_nope+v), kv_lora_rank]
|
||||
True, # is_causal
|
||||
self.softmax_scale,
|
||||
"bsnd", # layout
|
||||
)
|
||||
|
||||
# Output: [B, S, N, v_head_dim] -> [B, S, N * v_head_dim]
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
return attn_output
|
||||
|
||||
|
||||
class DeepSeekV3DecoderLayer(nn.Module):
|
||||
"""Transformer decoder layer for DeepSeekV3."""
|
||||
|
||||
def __init__(self, config, layer_idx: int):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
self.layer_idx = layer_idx
|
||||
|
||||
# Attention
|
||||
self.self_attn = DeepSeekV3Attention(config, layer_idx=layer_idx)
|
||||
|
||||
# MLP or MoE
|
||||
# MoE layers are used after first_k_dense_replace and at moe_layer_freq intervals
|
||||
use_moe = (
|
||||
config.n_routed_experts is not None
|
||||
and layer_idx >= config.first_k_dense_replace
|
||||
and layer_idx % config.moe_layer_freq == 0
|
||||
)
|
||||
if use_moe:
|
||||
self.mlp = DeepSeekV3MoE(config)
|
||||
else:
|
||||
self.mlp = DeepSeekV3MLP(config)
|
||||
|
||||
# Layer norms
|
||||
self.input_layernorm = DeepSeekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = DeepSeekV3RMSNorm(
|
||||
config.hidden_size, eps=config.rms_norm_eps
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
# Self attention
|
||||
residual = hidden_states
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
hidden_states = self.self_attn(hidden_states, position_ids)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
# MLP/MoE
|
||||
residual = hidden_states
|
||||
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
@dataclass
|
||||
class DeepSeekV3Output(ModelOutput):
|
||||
"""Output for DeepSeekV3Model."""
|
||||
|
||||
last_hidden_state: Optional[torch.FloatTensor] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class DeepSeekV3CausalLMOutput(ModelOutput):
|
||||
"""Output for DeepSeekV3ForCausalLM."""
|
||||
|
||||
logits: Optional[torch.FloatTensor] = None
|
||||
|
||||
|
||||
class DeepSeekV3PreTrainedModel(PreTrainedModel):
|
||||
"""Base class for DeepSeekV3 models."""
|
||||
|
||||
base_model_prefix = "model"
|
||||
_no_split_modules = ["DeepSeekV3DecoderLayer"]
|
||||
supports_gradient_checkpointing = False
|
||||
|
||||
def _init_weights(self, module):
|
||||
std = self.config.initializer_range
|
||||
if isinstance(module, nn.Linear):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
|
||||
|
||||
class DeepSeekV3Model(DeepSeekV3PreTrainedModel):
|
||||
"""DeepSeekV3 transformer decoder model."""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
self.padding_idx = config.pad_token_id
|
||||
self.vocab_size = config.vocab_size
|
||||
|
||||
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
DeepSeekV3DecoderLayer(config, layer_idx=idx)
|
||||
for idx in range(config.num_hidden_layers)
|
||||
]
|
||||
)
|
||||
self.norm = DeepSeekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
self.post_init()
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.embed_tokens
|
||||
|
||||
def set_input_embeddings(self, value):
|
||||
self.embed_tokens = value
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
**kwargs,
|
||||
) -> DeepSeekV3Output:
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("Cannot specify both input_ids and inputs_embeds")
|
||||
elif input_ids is None and inputs_embeds is None:
|
||||
raise ValueError("Must specify either input_ids or inputs_embeds")
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
batch_size, seq_length = inputs_embeds.shape[:2]
|
||||
|
||||
if position_ids is None:
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
position_ids = torch.arange(seq_length, dtype=torch.long, device=device)
|
||||
position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
for decoder_layer in self.layers:
|
||||
hidden_states = decoder_layer(hidden_states, position_ids)
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
|
||||
return DeepSeekV3Output(last_hidden_state=hidden_states)
|
||||
|
||||
|
||||
class DeepSeekV3ForCausalLM(DeepSeekV3PreTrainedModel, GenerationMixin):
|
||||
"""DeepSeekV3 model with language modeling head."""
|
||||
|
||||
_tied_weights_keys = ["lm_head.weight"]
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.model = DeepSeekV3Model(config)
|
||||
self.vocab_size = config.vocab_size
|
||||
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||
|
||||
self.post_init()
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.model.embed_tokens
|
||||
|
||||
def set_input_embeddings(self, value):
|
||||
self.model.embed_tokens = value
|
||||
|
||||
def get_output_embeddings(self):
|
||||
return self.lm_head
|
||||
|
||||
def set_output_embeddings(self, new_embeddings):
|
||||
self.lm_head = new_embeddings
|
||||
|
||||
def get_decoder(self):
|
||||
return self.model
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
**kwargs,
|
||||
) -> DeepSeekV3CausalLMOutput:
|
||||
outputs = self.model(
|
||||
input_ids=input_ids,
|
||||
position_ids=position_ids,
|
||||
inputs_embeds=inputs_embeds,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = outputs.last_hidden_state
|
||||
logits = self.lm_head(hidden_states).float()
|
||||
|
||||
return DeepSeekV3CausalLMOutput(logits=logits)
|
||||
|
||||
|
||||
# Register with AutoModelForCausalLMFactory
|
||||
AutoModelForCausalLMFactory.register_custom_model_cls("DeepseekV3Config", DeepSeekV3ForCausalLM)
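
# The sketch below is illustrative only (an addition for this write-up, not part of the
# original file): it exercises the YaRN rotary embedding defined above on CPU to show the
# export-oriented behaviour the docstrings describe, i.e. that forward() returns the full
# cached cos/sin tables rather than slicing them by sequence length. All numeric values are
# arbitrary toy settings.
if __name__ == "__main__":
    _rope = DeepSeekV3YarnRotaryEmbedding(
        dim=64,
        max_position_embeddings=4096,
        base=10000.0,
        scaling_factor=4.0,
        original_max_position_embeddings=1024,
    )
    # Any tensor works here; only its dtype/device are used by forward().
    _cos, _sin = _rope(torch.zeros(1, 8, 64))
    assert _cos.shape == (4096, 64) and _sin.shape == (4096, 64)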
|
||||
@ -0,0 +1,830 @@
|
||||
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.

"""Slimmed down PyTorch GLM4 MoE Lite model implementation for auto_deploy export.

Source:
https://huggingface.co/zai-org/GLM-4.7-Flash

This implementation differs from the original HuggingFace version in the following ways:
* Bundled config class to work with transformers v4.57 (model requires v5.0)
* Simplified for prefill-only inference (no KV caching)
* Uses auto_deploy custom ops for export compatibility
* Removed flash attention variants (uses torch_mla custom op)
* Removed gradient checkpointing and training code paths
* Removed attention dropout (inference only)

The GLM4 MoE Lite model uses Multi-head Latent Attention (MLA), similar to DeepSeek V3.
"""
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
from transformers import AutoConfig, PretrainedConfig
|
||||
from transformers.activations import ACT2FN
|
||||
from transformers.generation import GenerationMixin
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
from transformers.utils import ModelOutput
|
||||
|
||||
from tensorrt_llm._torch.auto_deploy.models.hf import AutoModelForCausalLMFactory
|
||||
from tensorrt_llm._torch.utils import ActivationType
|
||||
|
||||
|
||||
class Glm4MoeLiteConfig(PretrainedConfig):
|
||||
"""Configuration class for GLM4 MoE Lite model.
|
||||
|
||||
This config class is bundled with the custom model implementation to enable
|
||||
loading on transformers v4.57 (the model requires v5.0 where the config is
|
||||
natively registered).
|
||||
"""
|
||||
|
||||
model_type = "glm4_moe_lite"
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size: int = 154880,
|
||||
hidden_size: int = 2048,
|
||||
intermediate_size: int = 10240,
|
||||
num_hidden_layers: int = 47,
|
||||
num_attention_heads: int = 20,
|
||||
num_key_value_heads: int = 20,
|
||||
hidden_act: str = "silu",
|
||||
max_position_embeddings: int = 202752,
|
||||
rms_norm_eps: float = 1e-5,
|
||||
# MLA parameters
|
||||
q_lora_rank: int = 768,
|
||||
kv_lora_rank: int = 512,
|
||||
qk_nope_head_dim: int = 192,
|
||||
qk_rope_head_dim: int = 64,
|
||||
v_head_dim: int = 256,
|
||||
# MoE parameters
|
||||
n_routed_experts: int = 64,
|
||||
n_shared_experts: int = 1,
|
||||
num_experts_per_tok: int = 4,
|
||||
moe_intermediate_size: int = 1536,
|
||||
n_group: int = 1,
|
||||
topk_group: int = 1,
|
||||
routed_scaling_factor: float = 1.8,
|
||||
norm_topk_prob: bool = True,
|
||||
first_k_dense_replace: int = 1,
|
||||
# RoPE parameters
|
||||
rope_theta: float = 1000000.0,
|
||||
rope_scaling: Optional[dict] = None,
|
||||
# Other parameters
|
||||
attention_bias: bool = False,
|
||||
attention_dropout: float = 0.0,
|
||||
tie_word_embeddings: bool = False,
|
||||
pad_token_id: int = 154820,
|
||||
initializer_range: float = 0.02,
|
||||
**kwargs,
|
||||
):
|
||||
self.vocab_size = vocab_size
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.num_key_value_heads = num_key_value_heads
|
||||
self.hidden_act = hidden_act
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.rms_norm_eps = rms_norm_eps
|
||||
# MLA
|
||||
self.q_lora_rank = q_lora_rank
|
||||
self.kv_lora_rank = kv_lora_rank
|
||||
self.qk_nope_head_dim = qk_nope_head_dim
|
||||
self.qk_rope_head_dim = qk_rope_head_dim
|
||||
self.v_head_dim = v_head_dim
|
||||
self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
|
||||
# MoE
|
||||
self.n_routed_experts = n_routed_experts
|
||||
self.n_shared_experts = n_shared_experts
|
||||
self.num_experts_per_tok = num_experts_per_tok
|
||||
self.moe_intermediate_size = moe_intermediate_size
|
||||
self.n_group = n_group
|
||||
self.topk_group = topk_group
|
||||
self.routed_scaling_factor = routed_scaling_factor
|
||||
self.norm_topk_prob = norm_topk_prob
|
||||
self.first_k_dense_replace = first_k_dense_replace
|
||||
# RoPE
|
||||
self.rope_theta = rope_theta
|
||||
self.rope_scaling = rope_scaling
|
||||
# Other
|
||||
self.attention_bias = attention_bias
|
||||
self.attention_dropout = attention_dropout
|
||||
self.initializer_range = initializer_range
|
||||
|
||||
super().__init__(
|
||||
pad_token_id=pad_token_id,
|
||||
tie_word_embeddings=tie_word_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
# Register config with transformers' AutoConfig so it can be loaded from HF hub
|
||||
# Use exist_ok=True to handle cases where transformers already has this model type registered
|
||||
# (e.g., transformers v5.0+). In those cases, AutoConfig will use the built-in config,
|
||||
# but AutoModelForCausalLMFactory will still use our custom model implementation.
|
||||
try:
|
||||
AutoConfig.register("glm4_moe_lite", Glm4MoeLiteConfig, exist_ok=True)
|
||||
except TypeError:
|
||||
# Older transformers versions don't support exist_ok parameter
|
||||
try:
|
||||
AutoConfig.register("glm4_moe_lite", Glm4MoeLiteConfig)
|
||||
except ValueError:
|
||||
# Already registered by transformers, that's fine
|
||||
pass
|
||||
|
||||
|
||||
class Glm4MoeLiteRMSNorm(nn.Module):
|
||||
"""RMS Normalization for GLM4 MoE Lite.
|
||||
|
||||
    Uses standard torch operations so AD fusion passes can replace it with
    the appropriate backend (flashinfer/triton) based on config.
|
||||
"""
|
||||
|
||||
def __init__(self, hidden_size: int, eps: float = 1e-6):
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.ones(hidden_size))
|
||||
self.variance_epsilon = eps
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
input_dtype = hidden_states.dtype
|
||||
hidden_states = hidden_states.to(torch.float32)
|
||||
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
||||
return self.weight * hidden_states.to(input_dtype)
|
||||
|
||||
|
||||
class Glm4MoeLiteRotaryEmbedding(nn.Module):
|
||||
"""Rotary Position Embedding for GLM4 MoE Lite.
|
||||
|
||||
Simplified version that precomputes and caches cos/sin values.
|
||||
Returns full cached values (not sliced by seq_len) to enable export.
|
||||
|
||||
Uses _ad_ prefix for buffer names to work with AutoDeploy's lift_to_meta.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
max_position_embeddings: int = 2048,
|
||||
base: float = 10000.0,
|
||||
attention_scaling: float = 1.0,
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.base = base
|
||||
self.attention_scaling = attention_scaling
|
||||
|
||||
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float() / self.dim))
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
|
||||
# Build cos/sin cache with AD-specific naming
|
||||
self._set_cos_sin_cache(max_position_embeddings)
|
||||
|
||||
def _set_cos_sin_cache(self, seq_len: int):
|
||||
self.max_seq_len_cached = seq_len
|
||||
t = torch.arange(seq_len, dtype=self.inv_freq.dtype)
|
||||
freqs = torch.outer(t, self.inv_freq)
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
# Use _ad_ prefix for AutoDeploy compatibility with lift_to_meta
|
||||
self.register_buffer("_ad_cos_cached", emb.cos() * self.attention_scaling, persistent=False)
|
||||
self.register_buffer("_ad_sin_cached", emb.sin() * self.attention_scaling, persistent=False)
|
||||
|
||||
def forward(
|
||||
self, x: torch.Tensor, seq_len: Optional[int] = None
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# Return full cached cos/sin (not sliced) for export compatibility
|
||||
return (
|
||||
self._ad_cos_cached.to(dtype=x.dtype, device=x.device),
|
||||
self._ad_sin_cached.to(dtype=x.dtype, device=x.device),
|
||||
)
|
||||
|
||||
|
||||
class Glm4MoeLiteYarnRotaryEmbedding(Glm4MoeLiteRotaryEmbedding):
|
||||
"""YaRN-extended rotary embedding for GLM4 MoE Lite."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
max_position_embeddings: int = 2048,
|
||||
base: float = 10000.0,
|
||||
scaling_factor: float = 1.0,
|
||||
original_max_position_embeddings: int = 4096,
|
||||
beta_fast: int = 32,
|
||||
beta_slow: int = 1,
|
||||
mscale: float = 1.0,
|
||||
mscale_all_dim: float = 0.0,
|
||||
attention_scaling: float = 1.0,
|
||||
):
|
||||
self.scaling_factor = scaling_factor
|
||||
self.original_max_position_embeddings = original_max_position_embeddings
|
||||
self.beta_fast = beta_fast
|
||||
self.beta_slow = beta_slow
|
||||
self.mscale = mscale
|
||||
self.mscale_all_dim = mscale_all_dim
|
||||
super().__init__(dim, max_position_embeddings, base, attention_scaling)
|
||||
|
||||
def _set_cos_sin_cache(self, seq_len: int):
|
||||
self.max_seq_len_cached = seq_len
|
||||
dim = self.dim
|
||||
|
||||
freq_extra = 1.0 / (self.base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
|
||||
freq_inter = 1.0 / (
|
||||
self.scaling_factor * self.base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)
|
||||
)
|
||||
|
||||
low, high = self._yarn_find_correction_range(
|
||||
self.beta_fast,
|
||||
self.beta_slow,
|
||||
dim,
|
||||
self.base,
|
||||
self.original_max_position_embeddings,
|
||||
)
|
||||
inv_freq_mask = 1.0 - self._yarn_linear_ramp_mask(low, high, dim // 2)
|
||||
inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
|
||||
t = torch.arange(seq_len, dtype=torch.float32)
|
||||
freqs = torch.outer(t, inv_freq)
|
||||
|
||||
_mscale = float(
|
||||
self._yarn_get_mscale(self.scaling_factor, self.mscale)
|
||||
/ self._yarn_get_mscale(self.scaling_factor, self.mscale_all_dim)
|
||||
)
|
||||
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
# Use _ad_ prefix for AutoDeploy compatibility with lift_to_meta
|
||||
# Note: attention_scaling is already incorporated in _mscale for YaRN
|
||||
self.register_buffer("_ad_cos_cached", (emb.cos() * _mscale), persistent=False)
|
||||
self.register_buffer("_ad_sin_cached", (emb.sin() * _mscale), persistent=False)
|
||||
|
||||
@staticmethod
|
||||
def _yarn_find_correction_dim(
|
||||
num_rotations: float, dim: int, base: float = 10000, max_position_embeddings: int = 2048
|
||||
) -> float:
|
||||
return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (
|
||||
2 * math.log(base)
|
||||
)
|
||||
|
||||
def _yarn_find_correction_range(
|
||||
self, low_rot: int, high_rot: int, dim: int, base: float, max_position_embeddings: int
|
||||
) -> Tuple[int, int]:
|
||||
low = math.floor(
|
||||
self._yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)
|
||||
)
|
||||
high = math.ceil(
|
||||
self._yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings)
|
||||
)
|
||||
return max(low, 0), min(high, dim - 1)
|
||||
|
||||
@staticmethod
|
||||
def _yarn_get_mscale(scale: float = 1.0, mscale: float = 1.0) -> float:
|
||||
if scale <= 1:
|
||||
return 1.0
|
||||
return 0.1 * mscale * math.log(scale) + 1.0
|
||||
|
||||
@staticmethod
|
||||
def _yarn_linear_ramp_mask(min_val: float, max_val: float, dim: int) -> torch.Tensor:
|
||||
if min_val == max_val:
|
||||
max_val += 0.001
|
||||
linear_func = (torch.arange(dim, dtype=torch.float32) - min_val) / (max_val - min_val)
|
||||
return torch.clamp(linear_func, 0, 1)
|
||||
|
||||
|
||||
class Glm4MoeLiteMLP(nn.Module):
|
||||
"""MLP layer for GLM4 MoE Lite (SwiGLU activation)."""
|
||||
|
||||
def __init__(
|
||||
self, config, hidden_size: Optional[int] = None, intermediate_size: Optional[int] = None
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.hidden_size = hidden_size or config.hidden_size
|
||||
self.intermediate_size = intermediate_size or config.intermediate_size
|
||||
|
||||
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
||||
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
||||
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
|
||||
self.act_fn = ACT2FN[config.hidden_act]
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
||||
|
||||
|
||||
class Glm4MoeLiteMoEGate(nn.Module):
|
||||
"""MoE Gating for GLM4 MoE Lite with top-k selection.
|
||||
|
||||
Uses fused TensorRT-LLM custom ops for efficient routing:
|
||||
- dsv3_router_gemm_op: Fused router GEMM for non-float32 weights
|
||||
- noaux_tc_op: Fused sigmoid + bias + group top-k + normalize + scale
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.top_k = config.num_experts_per_tok
|
||||
self.n_routed_experts = config.n_routed_experts
|
||||
self.routed_scaling_factor = config.routed_scaling_factor
|
||||
self.n_group = config.n_group
|
||||
self.topk_group = config.topk_group
|
||||
self.norm_topk_prob = getattr(config, "norm_topk_prob", True)
|
||||
|
||||
# noaux_tc_op always normalizes, so norm_topk_prob must be True
|
||||
if not self.norm_topk_prob:
|
||||
raise ValueError(
|
||||
"Glm4MoeLiteMoEGate requires norm_topk_prob=True when using fused ops. "
|
||||
"The noaux_tc_op kernel always normalizes routing weights."
|
||||
)
|
||||
|
||||
self.weight = nn.Parameter(
|
||||
torch.empty((self.n_routed_experts, config.hidden_size), dtype=torch.float32)
|
||||
)
|
||||
self.register_buffer(
|
||||
"e_score_correction_bias",
|
||||
torch.zeros(self.n_routed_experts, dtype=torch.float32),
|
||||
)
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self) -> None:
|
||||
"""Initialize gate weights using kaiming uniform."""
|
||||
torch.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Forward pass returning (selected_experts, routing_weights).
|
||||
|
||||
Uses fused TensorRT-LLM ops for efficient routing:
|
||||
1. dsv3_router_gemm_op: Router GEMM (when weights are not float32)
|
||||
2. noaux_tc_op: Fused sigmoid + bias + group top-k + normalize + scale
|
||||
"""
|
||||
bsz, seq_len, hidden_dim = hidden_states.shape
|
||||
hidden_states_flat = hidden_states.view(-1, hidden_dim)
|
||||
|
||||
# Router GEMM - use fused op when weights are not float32
|
||||
if self.weight.dtype == torch.float32:
|
||||
router_logits = F.linear(hidden_states_flat.float(), self.weight)
|
||||
else:
|
||||
router_logits = torch.ops.trtllm.dsv3_router_gemm_op(
|
||||
hidden_states_flat, self.weight.t(), bias=None, out_dtype=torch.float32
|
||||
)
|
||||
|
||||
        # Fused routing: sigmoid + bias + group top-k + normalize + scale
        # noaux_tc_op internally applies:
        # 1. Sigmoid to router_logits
        # 2. Adds e_score_correction_bias
        # 3. Group-wise top-2 scoring and top group selection
        # 4. Top-k expert selection from selected groups
        # 5. Gathers weights from sigmoid scores
        # 6. Normalizes and scales by routed_scaling_factor
topk_weights, topk_indices = torch.ops.trtllm.noaux_tc_op(
|
||||
router_logits,
|
||||
self.e_score_correction_bias,
|
||||
self.n_group,
|
||||
self.topk_group,
|
||||
self.top_k,
|
||||
self.routed_scaling_factor,
|
||||
)
|
||||
|
||||
return topk_indices, topk_weights
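
    # The reference below is an addition for illustration (not used by the forward pass): it
    # spells out, in plain PyTorch, the routing steps that the comments above attribute to the
    # fused noaux_tc_op kernel. It is a sketch of the documented behaviour and is not
    # guaranteed to be bit-exact with the fused kernel.
    @staticmethod
    def _reference_noaux_topk(
        router_logits: torch.Tensor,
        e_score_correction_bias: torch.Tensor,
        n_group: int,
        topk_group: int,
        top_k: int,
        routed_scaling_factor: float,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Unfused reference of the fused routing, returning (topk_weights, topk_indices)."""
        num_tokens, n_experts = router_logits.shape
        # 1-2. Sigmoid scores plus correction bias (the bias only influences expert selection)
        scores = router_logits.sigmoid()
        scores_for_choice = scores + e_score_correction_bias.unsqueeze(0)
        # 3. Group-wise top-2 scoring and top-group selection
        group_scores = (
            scores_for_choice.view(num_tokens, n_group, -1).topk(2, dim=-1).values.sum(dim=-1)
        )
        group_idx = group_scores.topk(topk_group, dim=-1).indices
        group_mask = torch.zeros_like(group_scores).scatter_(1, group_idx, 1.0)
        score_mask = (
            group_mask.unsqueeze(-1)
            .expand(num_tokens, n_group, n_experts // n_group)
            .reshape(num_tokens, -1)
        )
        # 4. Top-k expert selection restricted to the selected groups
        masked_scores = scores_for_choice.masked_fill(score_mask == 0, float("-inf"))
        topk_indices = masked_scores.topk(top_k, dim=-1).indices
        # 5-6. Gather weights from the sigmoid scores, normalize, and scale
        topk_weights = scores.gather(1, topk_indices)
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
        return topk_weights * routed_scaling_factor, topk_indices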
|
||||
|
||||
|
||||
class Glm4MoeLiteMoE(nn.Module):
|
||||
"""Mixture of Experts layer for GLM4 MoE Lite."""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.num_experts_per_tok = config.num_experts_per_tok
|
||||
|
||||
# Routed experts - use ModuleList with individual expert modules
|
||||
# This creates state_dict keys like: experts.0.gate_proj.weight
|
||||
# which matches the checkpoint structure
|
||||
self.experts = nn.ModuleList(
|
||||
[
|
||||
Glm4MoeLiteMLP(config, intermediate_size=config.moe_intermediate_size)
|
||||
for _ in range(config.n_routed_experts)
|
||||
]
|
||||
)
|
||||
|
||||
# Gate
|
||||
self.gate = Glm4MoeLiteMoEGate(config)
|
||||
|
||||
# Shared experts
|
||||
if config.n_shared_experts is not None and config.n_shared_experts > 0:
|
||||
intermediate_size = config.moe_intermediate_size * config.n_shared_experts
|
||||
self.shared_experts = Glm4MoeLiteMLP(config, intermediate_size=intermediate_size)
|
||||
else:
|
||||
self.shared_experts = None
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
identity = hidden_states
|
||||
orig_shape = hidden_states.shape
|
||||
|
||||
selected_experts, routing_weights = self.gate(hidden_states)
|
||||
|
||||
# Use torch_moe custom op for routed experts
|
||||
final_hidden_states = torch.ops.auto_deploy.torch_moe(
|
||||
hidden_states.view(-1, hidden_states.shape[-1]),
|
||||
selected_experts,
|
||||
routing_weights,
|
||||
w1_weight=[expert.gate_proj.weight for expert in self.experts],
|
||||
w2_weight=[expert.down_proj.weight for expert in self.experts],
|
||||
w3_weight=[expert.up_proj.weight for expert in self.experts],
|
||||
is_gated_mlp=True,
|
||||
act_fn=int(ActivationType.Silu),
|
||||
)
|
||||
|
||||
final_hidden_states = final_hidden_states.view(*orig_shape)
|
||||
|
||||
# Add shared experts output if present
|
||||
if self.shared_experts is not None:
|
||||
final_hidden_states = final_hidden_states + self.shared_experts(identity)
|
||||
|
||||
return final_hidden_states.to(hidden_states.dtype)
|
||||
|
||||
|
||||
class Glm4MoeLiteAttention(nn.Module):
|
||||
"""Multi-head Latent Attention (MLA) for GLM4 MoE Lite.
|
||||
|
||||
Uses compressed KV representation with latent projections.
|
||||
Receives position embeddings from the model level (shared rotary embedding).
|
||||
"""
|
||||
|
||||
def __init__(self, config, layer_idx: Optional[int] = None):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.layer_idx = layer_idx
|
||||
|
||||
self.hidden_size = config.hidden_size
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.q_lora_rank = config.q_lora_rank
|
||||
self.kv_lora_rank = config.kv_lora_rank
|
||||
self.qk_nope_head_dim = config.qk_nope_head_dim
|
||||
self.qk_rope_head_dim = config.qk_rope_head_dim
|
||||
self.v_head_dim = config.v_head_dim
|
||||
self.qk_head_dim = self.qk_nope_head_dim + self.qk_rope_head_dim
|
||||
|
||||
# Q projection (with optional LoRA)
|
||||
if self.q_lora_rank is None:
|
||||
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.qk_head_dim, bias=False)
|
||||
else:
|
||||
self.q_a_proj = nn.Linear(
|
||||
self.hidden_size, config.q_lora_rank, bias=config.attention_bias
|
||||
)
|
||||
self.q_a_layernorm = Glm4MoeLiteRMSNorm(config.q_lora_rank)
|
||||
self.q_b_proj = nn.Linear(
|
||||
config.q_lora_rank, self.num_heads * self.qk_head_dim, bias=False
|
||||
)
|
||||
|
||||
# KV projection with MQA
|
||||
self.kv_a_proj_with_mqa = nn.Linear(
|
||||
self.hidden_size,
|
||||
self.kv_lora_rank + self.qk_rope_head_dim,
|
||||
bias=config.attention_bias,
|
||||
)
|
||||
self.kv_a_layernorm = Glm4MoeLiteRMSNorm(self.kv_lora_rank)
|
||||
self.kv_b_proj = nn.Linear(
|
||||
self.kv_lora_rank,
|
||||
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
|
||||
bias=False,
|
||||
)
|
||||
|
||||
# Output projection
|
||||
self.o_proj = nn.Linear(
|
||||
self.num_heads * self.v_head_dim, self.hidden_size, bias=config.attention_bias
|
||||
)
|
||||
|
||||
# Softmax scale
|
||||
self.softmax_scale = self.qk_head_dim ** (-0.5)
|
||||
# Apply mscale adjustment if using YaRN scaling with factor
|
||||
if (
|
||||
config.rope_scaling is not None
|
||||
and isinstance(config.rope_scaling, dict)
|
||||
and "factor" in config.rope_scaling
|
||||
):
|
||||
mscale_all_dim = config.rope_scaling.get("mscale_all_dim", 0)
|
||||
scaling_factor = config.rope_scaling["factor"]
|
||||
if mscale_all_dim:
|
||||
mscale = Glm4MoeLiteYarnRotaryEmbedding._yarn_get_mscale(
|
||||
scaling_factor, mscale_all_dim
|
||||
)
|
||||
self.softmax_scale = self.softmax_scale * mscale * mscale
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
position_embeddings: Tuple[torch.Tensor, torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
# Q projection
|
||||
if self.q_lora_rank is None:
|
||||
q = self.q_proj(hidden_states)
|
||||
else:
|
||||
q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
|
||||
|
||||
# Shape: [B, S, N, qk_head_dim] (BSND layout)
|
||||
q = q.view(bsz, q_len, self.num_heads, self.qk_head_dim)
|
||||
q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
|
||||
|
||||
# KV projection - keep compressed form
|
||||
kv_a_output = self.kv_a_proj_with_mqa(hidden_states)
|
||||
compressed_kv, k_pe = torch.split(
|
||||
kv_a_output, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
|
||||
)
|
||||
|
||||
# Apply layernorm to compressed_kv
|
||||
compressed_kv = self.kv_a_layernorm(compressed_kv)
|
||||
|
||||
# k_pe: [B, S, 1, qk_rope_head_dim] (BSND layout, shared across heads)
|
||||
k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim)
|
||||
|
||||
# Get cos/sin from position_embeddings (full cached from shared rotary embedding)
|
||||
cos, sin = position_embeddings # Full table: [max_seq_len, head_dim]
|
||||
cos = cos[position_ids] # [B, S, head_dim]
|
||||
sin = sin[position_ids] # [B, S, head_dim]
|
||||
|
||||
# Apply RoPE using custom op
|
||||
q_pe_rotated, kpe = torch.ops.auto_deploy.torch_rope_with_qk_interleaving(
|
||||
q_pe,
|
||||
k_pe,
|
||||
cos,
|
||||
sin,
|
||||
2, # unsqueeze_dim=2 for BSND layout
|
||||
)
|
||||
|
||||
# Call MLA with compressed KV
|
||||
attn_output = torch.ops.auto_deploy.torch_mla(
|
||||
q_nope, # [B, S, N, qk_nope_head_dim]
|
||||
q_pe_rotated, # [B, S, N, qk_rope_head_dim]
|
||||
compressed_kv, # [B, S, kv_lora_rank]
|
||||
kpe, # [B, S, 1, qk_rope_head_dim]
|
||||
self.kv_b_proj.weight, # [N*(qk_nope+v), kv_lora_rank]
|
||||
True, # is_causal
|
||||
self.softmax_scale,
|
||||
"bsnd", # layout
|
||||
)
|
||||
|
||||
# Output: [B, S, N, v_head_dim] -> [B, S, N * v_head_dim]
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
return attn_output
|
||||
|
||||
|
||||
class Glm4MoeLiteDecoderLayer(nn.Module):
|
||||
"""Transformer decoder layer for GLM4 MoE Lite."""
|
||||
|
||||
def __init__(self, config, layer_idx: int):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
self.layer_idx = layer_idx
|
||||
|
||||
# Attention
|
||||
self.self_attn = Glm4MoeLiteAttention(config, layer_idx=layer_idx)
|
||||
|
||||
# MLP or MoE
|
||||
        # Layers 0 through first_k_dense_replace-1 use a dense MLP; the rest use MoE
|
||||
use_moe = config.n_routed_experts is not None and layer_idx >= config.first_k_dense_replace
|
||||
if use_moe:
|
||||
self.mlp = Glm4MoeLiteMoE(config)
|
||||
else:
|
||||
self.mlp = Glm4MoeLiteMLP(config)
|
||||
|
||||
# Layer norms
|
||||
self.input_layernorm = Glm4MoeLiteRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = Glm4MoeLiteRMSNorm(
|
||||
config.hidden_size, eps=config.rms_norm_eps
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
position_embeddings: Tuple[torch.Tensor, torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
# Self attention
|
||||
residual = hidden_states
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
hidden_states = self.self_attn(hidden_states, position_ids, position_embeddings)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
# MLP/MoE
|
||||
residual = hidden_states
|
||||
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
@dataclass
|
||||
class Glm4MoeLiteOutput(ModelOutput):
|
||||
"""Output for Glm4MoeLiteModel."""
|
||||
|
||||
last_hidden_state: Optional[torch.FloatTensor] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class Glm4MoeLiteCausalLMOutput(ModelOutput):
|
||||
"""Output for Glm4MoeLiteForCausalLM."""
|
||||
|
||||
logits: Optional[torch.FloatTensor] = None
|
||||
|
||||
|
||||
class Glm4MoeLitePreTrainedModel(PreTrainedModel):
|
||||
"""Base class for GLM4 MoE Lite models."""
|
||||
|
||||
config_class = Glm4MoeLiteConfig
|
||||
base_model_prefix = "model"
|
||||
_no_split_modules = ["Glm4MoeLiteDecoderLayer"]
|
||||
supports_gradient_checkpointing = False
|
||||
|
||||
def _init_weights(self, module):
|
||||
std = self.config.initializer_range
|
||||
if isinstance(module, nn.Linear):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
|
||||
|
||||
class Glm4MoeLiteModel(Glm4MoeLitePreTrainedModel):
|
||||
"""GLM4 MoE Lite transformer decoder model."""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
self.padding_idx = config.pad_token_id
|
||||
self.vocab_size = config.vocab_size
|
||||
|
||||
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
Glm4MoeLiteDecoderLayer(config, layer_idx=idx)
|
||||
for idx in range(config.num_hidden_layers)
|
||||
]
|
||||
)
|
||||
self.norm = Glm4MoeLiteRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
# Shared rotary embedding at model level (not per-layer)
|
||||
# This creates a single set of cos/sin buffers for all layers
|
||||
self.rotary_emb = self._init_rope(config)
|
||||
|
||||
self.post_init()
|
||||
|
||||
def _init_rope(self, config):
|
||||
"""Initialize shared rotary embedding for all layers."""
|
||||
qk_rope_head_dim = config.qk_rope_head_dim
|
||||
|
||||
# Compute attention_scaling for RoPE (same logic as in attention)
|
||||
attention_scaling = 1.0
|
||||
if (
|
||||
config.rope_scaling is not None
|
||||
and isinstance(config.rope_scaling, dict)
|
||||
and "factor" in config.rope_scaling
|
||||
):
|
||||
mscale_all_dim = config.rope_scaling.get("mscale_all_dim", 0)
|
||||
scaling_factor = config.rope_scaling["factor"]
|
||||
if mscale_all_dim:
|
||||
mscale = Glm4MoeLiteYarnRotaryEmbedding._yarn_get_mscale(
|
||||
scaling_factor, mscale_all_dim
|
||||
)
|
||||
attention_scaling = mscale
|
||||
|
||||
# Check if rope_scaling is None, empty, or missing required "factor" key
|
||||
use_yarn = (
|
||||
config.rope_scaling is not None
|
||||
and isinstance(config.rope_scaling, dict)
|
||||
and "factor" in config.rope_scaling
|
||||
)
|
||||
|
||||
if not use_yarn:
|
||||
return Glm4MoeLiteRotaryEmbedding(
|
||||
qk_rope_head_dim,
|
||||
max_position_embeddings=config.max_position_embeddings,
|
||||
base=config.rope_theta,
|
||||
attention_scaling=attention_scaling,
|
||||
)
|
||||
else:
|
||||
scaling_factor = config.rope_scaling["factor"]
|
||||
kwargs = {
|
||||
key: config.rope_scaling[key]
|
||||
for key in [
|
||||
"original_max_position_embeddings",
|
||||
"beta_fast",
|
||||
"beta_slow",
|
||||
"mscale",
|
||||
"mscale_all_dim",
|
||||
]
|
||||
if key in config.rope_scaling
|
||||
}
|
||||
return Glm4MoeLiteYarnRotaryEmbedding(
|
||||
qk_rope_head_dim,
|
||||
max_position_embeddings=config.max_position_embeddings,
|
||||
scaling_factor=scaling_factor,
|
||||
base=config.rope_theta,
|
||||
attention_scaling=attention_scaling,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.embed_tokens
|
||||
|
||||
def set_input_embeddings(self, value):
|
||||
self.embed_tokens = value
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
**kwargs,
|
||||
) -> Glm4MoeLiteOutput:
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("Cannot specify both input_ids and inputs_embeds")
|
||||
elif input_ids is None and inputs_embeds is None:
|
||||
raise ValueError("Must specify either input_ids or inputs_embeds")
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
batch_size, seq_length = inputs_embeds.shape[:2]
|
||||
|
||||
if position_ids is None:
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
position_ids = torch.arange(seq_length, dtype=torch.long, device=device)
|
||||
position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)
|
||||
|
||||
# Compute position embeddings once from shared rotary embedding
|
||||
# This returns full cached cos/sin tables
|
||||
position_embeddings = self.rotary_emb(inputs_embeds)
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
for decoder_layer in self.layers:
|
||||
hidden_states = decoder_layer(hidden_states, position_ids, position_embeddings)
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
|
||||
return Glm4MoeLiteOutput(last_hidden_state=hidden_states)
|
||||
|
||||
|
||||
class Glm4MoeLiteForCausalLM(Glm4MoeLitePreTrainedModel, GenerationMixin):
|
||||
"""GLM4 MoE Lite model with language modeling head."""
|
||||
|
||||
_tied_weights_keys = ["lm_head.weight"]
|
||||
|
||||
def __init__(self, config, **kwargs):
|
||||
super().__init__(config)
|
||||
self.model = Glm4MoeLiteModel(config)
|
||||
self.vocab_size = config.vocab_size
|
||||
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||
|
||||
self.post_init()
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.model.embed_tokens
|
||||
|
||||
def set_input_embeddings(self, value):
|
||||
self.model.embed_tokens = value
|
||||
|
||||
def get_output_embeddings(self):
|
||||
return self.lm_head
|
||||
|
||||
def set_output_embeddings(self, new_embeddings):
|
||||
self.lm_head = new_embeddings
|
||||
|
||||
def get_decoder(self):
|
||||
return self.model
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
**kwargs,
|
||||
) -> Glm4MoeLiteCausalLMOutput:
|
||||
outputs = self.model(
|
||||
input_ids=input_ids,
|
||||
position_ids=position_ids,
|
||||
inputs_embeds=inputs_embeds,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = outputs.last_hidden_state
|
||||
logits = self.lm_head(hidden_states).float()
|
||||
|
||||
return Glm4MoeLiteCausalLMOutput(logits=logits)
|
||||
|
||||
|
||||
# Register with AutoModelForCausalLMFactory
|
||||
AutoModelForCausalLMFactory.register_custom_model_cls("Glm4MoeLiteConfig", Glm4MoeLiteForCausalLM)
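
# Illustrative smoke test (an addition for this write-up, not part of the checkpoint flow):
# it builds a scaled-down Glm4MoeLiteForCausalLM from the bundled config to sanity-check
# module wiring. All hyperparameters below are made-up toy values; a real forward pass
# additionally needs the fused trtllm/auto_deploy ops available on a GPU.
if __name__ == "__main__":
    _tiny_cfg = Glm4MoeLiteConfig(
        vocab_size=256,
        hidden_size=64,
        intermediate_size=128,
        num_hidden_layers=2,
        num_attention_heads=4,
        num_key_value_heads=4,
        max_position_embeddings=512,
        q_lora_rank=32,
        kv_lora_rank=16,
        qk_nope_head_dim=16,
        qk_rope_head_dim=8,
        v_head_dim=16,
        n_routed_experts=4,
        n_shared_experts=1,
        num_experts_per_tok=2,
        moe_intermediate_size=32,
        first_k_dense_replace=1,
        pad_token_id=0,
    )
    _tiny_model = Glm4MoeLiteForCausalLM(_tiny_cfg)
    _num_params = sum(p.numel() for p in _tiny_model.parameters())
    print(f"Toy Glm4MoeLiteForCausalLM built with {_num_params:,} parameters")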
|
||||
@ -1,185 +0,0 @@
|
||||
import types
|
||||
import warnings
|
||||
from typing import Dict, Optional
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from transformers import AutoModelForCausalLM
|
||||
from transformers.cache_utils import Cache
|
||||
|
||||
|
||||
def deepseek_v3_attention(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.IntTensor] = None,
|
||||
past_key_value: Optional[Cache] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
"""DeepSeekV3Attention forward function rewritten to wrap MultiheadLatentAttention as a custom op."""
|
||||
if "padding_mask" in kwargs:
|
||||
        warnings.warn(
            "Passing `padding_mask` is deprecated and will be removed in v4.37. "
            "Please make sure to use `attention_mask` instead."
        )
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
# If else paths are determined by config.json
|
||||
if self.q_lora_rank is None:
|
||||
q = self.q_proj(hidden_states)
|
||||
else:
|
||||
q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
|
||||
q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
|
||||
q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
|
||||
|
||||
compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
|
||||
compressed_kv, k_pe = torch.split(
|
||||
compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
|
||||
)
|
||||
k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)
|
||||
kv = (
|
||||
self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
|
||||
.view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
|
||||
.transpose(1, 2)
|
||||
)
|
||||
_, value_states = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
||||
kv_seq_len = value_states.shape[-2]
|
||||
if past_key_value is not None:
|
||||
raise ValueError("past_key_value is not supported")
|
||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||
|
||||
    # Use a custom op to capture MLA. This does not handle the KV cache because passing a
    # transformers Cache object into a custom op raises an error. That is acceptable here,
    # since we intend to replace the MLA op with our own implementation later in the pipeline.
|
||||
attn_output = torch.ops.auto_deploy.torch_attention_deepseek_fused_mla(
|
||||
q_nope,
|
||||
q_pe,
|
||||
kv,
|
||||
k_pe,
|
||||
cos,
|
||||
sin,
|
||||
position_ids,
|
||||
attention_mask,
|
||||
self.softmax_scale,
|
||||
)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
return attn_output, None, past_key_value
|
||||
|
||||
|
||||
# This patched module matches exactly with HF generate
|
||||
@torch.inference_mode()
|
||||
def deepseek_v3_moe_exact(self, hidden_states):
    """DeepSeekV3MoE forward function rewritten to enable torch export.

    This implementation matches the original DeepSeek implementation exactly. The
    index_add-based implementation introduces small errors in the output tensors, leading to
    mismatches for some prompts; this version ensures an exact match between the HF output
    with and without the custom patch.
    """
|
||||
identity = hidden_states
|
||||
batch_size, sequence_length, hidden_dim = hidden_states.shape
|
||||
|
||||
selected_experts, routing_weights, *_ = self.gate(hidden_states)
|
||||
|
||||
hidden_states = hidden_states.view(-1, hidden_dim)
|
||||
idxs = torch.argsort(selected_experts.view(-1), stable=True)
|
||||
|
||||
expert_mask = torch.nn.functional.one_hot(
|
||||
selected_experts, num_classes=self.experts_per_rank
|
||||
).permute(2, 1, 0)
|
||||
outputs = []
|
||||
for expert_idx in range(len(self.experts)):
|
||||
expert_layer = self.experts[expert_idx]
|
||||
_, top_x = torch.where(expert_mask[expert_idx])
|
||||
# Sort the top_xs and idx
|
||||
sorted, _ = torch.sort(top_x)
|
||||
tokens_for_this_expert = hidden_states[None, sorted].reshape(-1, hidden_dim)
|
||||
expert_out = expert_layer(tokens_for_this_expert)
|
||||
outputs.append(expert_out)
|
||||
|
||||
outs = torch.cat(outputs, dim=0)
|
||||
# Wrap torch.zeros() in a custom op to fix meta device issue during inference.
|
||||
new_x = torch.zeros(
|
||||
(*selected_experts.view(-1).shape, hidden_dim),
|
||||
device=selected_experts.device,
|
||||
dtype=outs.dtype,
|
||||
)
|
||||
new_x[idxs] = outs
|
||||
final_hidden_states = (
|
||||
new_x.view(*selected_experts.shape, -1)
|
||||
.type(routing_weights.dtype)
|
||||
.mul_(routing_weights.unsqueeze(-1))
|
||||
.sum(dim=1)
|
||||
.type(new_x.dtype)
|
||||
)
|
||||
final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
|
||||
|
||||
if self.config.n_shared_experts is not None:
|
||||
final_hidden_states = final_hidden_states + self.shared_experts(identity)
|
||||
|
||||
return final_hidden_states.to(hidden_states.dtype)
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def deepseek_v3_moe(self, hidden_states):
|
||||
"""DeepSeekV3MoE forward function rewritten in Mixtral style to enable torch export."""
|
||||
|
||||
selected_experts, routing_weights, *_ = self.gate(hidden_states)
|
||||
final_hidden_states = torch.ops.auto_deploy.torch_moe(
|
||||
hidden_states,
|
||||
selected_experts,
|
||||
routing_weights,
|
||||
w1_weight=[expert.gate_proj.weight for expert in self.experts],
|
||||
w2_weight=[expert.down_proj.weight for expert in self.experts],
|
||||
w3_weight=[expert.up_proj.weight for expert in self.experts],
|
||||
)
|
||||
|
||||
if self.config.n_shared_experts is not None:
|
||||
final_hidden_states = final_hidden_states + self.shared_experts(hidden_states)
|
||||
|
||||
return final_hidden_states.to(hidden_states.dtype)
|
||||
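# Added note (illustrative, not part of the original patch): torch_moe above is
# expected to compute, for each token t and each selected expert e with routing
# weight w_te, the standard SwiGLU expert MLP
#     y_t = sum_e w_te * down_e(silu(gate_e(x_t)) * up_e(x_t))
# dispatched according to selected_experts; the shared-expert branch is then added
# separately in the wrapper above when n_shared_experts is set.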
|
||||
|
||||
def deepseek_v3_rope(self, x, seq_len=None):
|
||||
"""DeepSeekV3 Rotary Embedding forward function rewritten to enable torch export.
|
||||
We return the full cached cos and sin values, instead of slicing them based on seq_len as this
|
||||
would cause an issue during the generate phase (when seq_len=1 from input_ids). We also move the cos
|
||||
sin buffers to appropriate device to enable export.
|
||||
"""
|
||||
|
||||
return (
|
||||
self.cos_cached.to(dtype=x.dtype).to(device=x.device),
|
||||
self.sin_cached.to(dtype=x.dtype).to(device=x.device),
|
||||
)
|
||||
|
||||
|
||||
_from_config_original = AutoModelForCausalLM.from_config
|
||||
|
||||
CUSTOM_MODULE_PATCHES: Dict[str, callable] = {
|
||||
"DeepseekV3MoE": deepseek_v3_moe,
|
||||
"DeepseekV2MoE": deepseek_v3_moe,
|
||||
"DeepseekV3RotaryEmbedding": deepseek_v3_rope,
|
||||
"DeepseekV3YarnRotaryEmbedding": deepseek_v3_rope,
|
||||
"DeepseekV2RotaryEmbedding": deepseek_v3_rope,
|
||||
"DeepseekV2YarnRotaryEmbedding": deepseek_v3_rope,
|
||||
}
|
||||
|
||||
|
||||
def get_model_from_config_patched(config, **kwargs):
|
||||
model = _from_config_original(config, **kwargs)
|
||||
# Patch modules
|
||||
for _, module in model.named_modules():
|
||||
if type(module).__name__ in CUSTOM_MODULE_PATCHES.keys():
|
||||
module.forward = types.MethodType(CUSTOM_MODULE_PATCHES[type(module).__name__], module)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
# TODO: figure out how this can be incorporated into the export patch system
|
||||
AutoModelForCausalLM.from_config = get_model_from_config_patched
|
||||
@ -83,7 +83,7 @@ class QuantConfigReaderRegistry:

@QuantConfigReaderRegistry.register("modelopt")
class ModelOPTQuantConfigReader(QuantConfigReader):
-    _ALWAYS_EXCLUDE = ("lm_head", "model.embed_tokens", "*.mixer.gate*")
+    _ALWAYS_EXCLUDE = ("lm_head", "model.embed_tokens", "*.mixer.gate*", "*.mlp.gate")

    def read_config(self, config: Dict) -> Dict:
        producer = config.get("producer", {}).get("name")

@ -411,3 +411,75 @@ class TestNemotronSuperV3(LlmapiAccuracyTestHarness):
|
||||
task.evaluate(llm, sampling_params=sampling_params)
|
||||
task = GSM8K(self.MODEL_NAME)
|
||||
task.evaluate(llm)
|
||||
|
||||
|
||||
class TestGLM4Flash(LlmapiAccuracyTestHarness):
    """Accuracy regression tests for GLM-4.7-Flash.

    TODO: enable in CI, see https://github.com/NVIDIA/TensorRT-LLM/issues/11117

    In the meantime, run this test locally:

    ```
    cd tests/integration/defs
    TRTLLM_ACCURACY_NO_REFERENCE=1 pytest -svv "accuracy/test_llm_api_autodeploy.py::TestGLM4Flash::test_auto_dtype[True]"
    ```
    """

MODEL_NAME = "zai-org/GLM-4.7-Flash"
|
||||
MODEL_PATH = MODEL_NAME # Model is in HF_CACHE
|
||||
# Set minimum possible seq len + small buffer, for test speed & memory usage
|
||||
MAX_SEQ_LEN = max(MMLU.MAX_INPUT_LEN + MMLU.MAX_OUTPUT_LEN,
|
||||
GSM8K.MAX_INPUT_LEN + GSM8K.MAX_OUTPUT_LEN)
|
||||
MAX_NUM_TOKENS = MAX_SEQ_LEN
|
||||
|
||||
def get_default_kwargs(self, enable_chunked_prefill=False):
|
||||
config = {
|
||||
"skip_tokenizer_init": False,
|
||||
"trust_remote_code": True,
|
||||
"compile_backend": "torch-cudagraph",
|
||||
"max_batch_size": 128,
|
||||
"max_seq_len": self.MAX_SEQ_LEN,
|
||||
"max_num_tokens": self.MAX_NUM_TOKENS,
|
||||
"skip_loading_weights": False,
|
||||
"disable_overlap_scheduler": False,
|
||||
"cuda_graph_batch_sizes": [1, 2, 4, 8, 16, 32, 64, 128],
|
||||
"kv_cache_config": {
|
||||
"enable_block_reuse": False,
|
||||
"free_gpu_memory_fraction": 0.88
|
||||
},
|
||||
"model_kwargs": {
|
||||
"torch_dtype": "bfloat16"
|
||||
},
|
||||
"transforms": {
|
||||
"fuse_nvfp4_moe": {
|
||||
"allow_different_input_scales": True,
|
||||
},
|
||||
},
|
||||
}
|
||||
if enable_chunked_prefill:
|
||||
config["enable_chunked_prefill"] = True
|
||||
config[
|
||||
"max_num_tokens"] = 512 # NOTE: must be > max(tokens_per_block, max_batch_size)
|
||||
return config
|
||||
|
||||
def get_default_sampling_params(self):
|
||||
eos_id = -1
|
||||
beam_width = 1
|
||||
return SamplingParams(end_id=eos_id,
|
||||
pad_id=eos_id,
|
||||
n=beam_width,
|
||||
use_beam_search=beam_width > 1)
|
||||
|
||||
@pytest.mark.skip_less_device_memory(32000)
|
||||
@pytest.mark.parametrize("enable_chunked_prefill", [True, False])
|
||||
def test_auto_dtype(self, enable_chunked_prefill):
|
||||
kwargs = self.get_default_kwargs(enable_chunked_prefill)
|
||||
sampling_params = self.get_default_sampling_params()
|
||||
with AutoDeployLLM(model=self.MODEL_PATH,
|
||||
tokenizer=self.MODEL_PATH,
|
||||
**kwargs) as llm:
|
||||
task = MMLU(self.MODEL_NAME)
|
||||
task.evaluate(llm, sampling_params=sampling_params)
|
||||
task = GSM8K(self.MODEL_NAME)
|
||||
task.evaluate(llm)
|
||||
|
||||
@ -400,7 +400,8 @@ _SMALL_MODEL_CONFIGS = {
        "num_hidden_layers": 2,
        "hidden_size": 32,
        "intermediate_size": 64,
-        "kv_lora_rank": 128,
+        "kv_lora_rank": 512, # NOTE: must be 512 (default) for flashinfer_mla to work
+        "qk_rope_head_dim": 64, # NOTE: must be 64 (default) for flashinfer_mla to work
        "moe_intermediate_size": 128,
        "n_group": 2,
        "topk_group": 2,

File diff suppressed because it is too large
@ -0,0 +1,780 @@
|
||||
"""Comprehensive test suite for torch MLA backend operations.
|
||||
|
||||
Tests the torch_mla source op and torch_backend_mla_with_cache cached op
|
||||
with FlashInfer-compatible compressed cache layout.
|
||||
|
||||
Key features:
|
||||
- 5 tensor arguments: q_nope, q_pe, compressed_kv, kpe, kv_b_proj_weight
|
||||
- Compressed cache: [max_batch, max_seq, kv_lora_rank + qk_rope_head_dim]
|
||||
- Prefill: Expand compressed_kv, compute normal attention
|
||||
- Generate: Weight absorption for efficiency
|
||||
"""
|
||||
|
||||
import math
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import tensorrt_llm._torch.auto_deploy # noqa: F401
|
||||
|
||||
|
||||
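# --- Added illustrative sketch (not part of the original suite) ---------------
# "Weight absorption" in the generate path folds the k_nope block of kv_b_proj
# into the query, so attention scores can be computed directly against the
# kv_lora_rank-sized compressed cache instead of per-head expanded keys. The
# shapes below are hypothetical and chosen only to demonstrate the identity.
def test_weight_absorption_identity_sketch():
    num_heads, qk_nope_head_dim, kv_lora_rank, kv_len = 4, 32, 64, 16
    q_nope = torch.randn(num_heads, qk_nope_head_dim)
    # Per-head k_nope slice of kv_b_proj: maps the latent (rank r) to k_nope (dim d).
    w_uk = torch.randn(num_heads, qk_nope_head_dim, kv_lora_rank)
    # Compressed KV as it would be stored in the cache.
    c_kv = torch.randn(kv_len, kv_lora_rank)

    # Path 1: expand the latent to per-head k_nope, then take dot products with q_nope.
    expanded_scores = torch.einsum("nd,ndr,sr->ns", q_nope, w_uk, c_kv)

    # Path 2: absorb w_uk into the query and score against the latent directly.
    absorbed_q = torch.einsum("ndr,nd->nr", w_uk, q_nope)
    absorbed_scores = absorbed_q @ c_kv.T

    torch.testing.assert_close(expanded_scores, absorbed_scores, atol=1e-3, rtol=1e-3)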
def numpy_mla_reference_with_expansion(
|
||||
q_nope: np.ndarray,
|
||||
q_pe: np.ndarray,
|
||||
compressed_kv: np.ndarray,
|
||||
kpe: np.ndarray,
|
||||
kv_b_proj_weight: np.ndarray,
|
||||
mla_cache: np.ndarray,
|
||||
seq_len: np.ndarray,
|
||||
input_pos: np.ndarray,
|
||||
cache_loc: np.ndarray,
|
||||
seq_start: np.ndarray,
|
||||
scale: float = None,
|
||||
kv_lora_rank: int = None,
|
||||
is_generate: bool = False,
|
||||
):
|
||||
"""Numpy reference implementation of MLA attention with FlashInfer cache layout.
|
||||
|
||||
This expands compressed_kv using kv_b_proj_weight for attention computation.
|
||||
"""
|
||||
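    # Added note: "expansion" below applies kv_b_proj to the cached latent,
    #     kv_expanded = compressed_kv_cached @ kv_b_proj_weight.T
    # reshaped to [kv_seq_len, num_heads, qk_nope_head_dim + v_head_dim] and then
    # split into per-head k_nope and v.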
# Get dimensions
|
||||
if is_generate:
|
||||
batch_size = q_nope.shape[0]
|
||||
num_heads = q_nope.shape[2]
|
||||
qk_nope_head_dim = q_nope.shape[3]
|
||||
qk_rope_head_dim = q_pe.shape[3]
|
||||
else:
|
||||
batch_size = len(seq_len)
|
||||
num_heads = q_nope.shape[2]
|
||||
qk_nope_head_dim = q_nope.shape[3]
|
||||
qk_rope_head_dim = q_pe.shape[3]
|
||||
|
||||
qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
|
||||
|
||||
if kv_lora_rank is None:
|
||||
kv_lora_rank = compressed_kv.shape[-1]
|
||||
|
||||
# Infer v_head_dim from kv_b_proj_weight
|
||||
out_features = kv_b_proj_weight.shape[0]
|
||||
kv_head_dim = out_features // num_heads
|
||||
v_head_dim = kv_head_dim - qk_nope_head_dim
|
||||
|
||||
if scale is None:
|
||||
scale = 1.0 / math.sqrt(qk_head_dim)
|
||||
|
||||
# Update MLA cache first
|
||||
if is_generate:
|
||||
for i in range(batch_size):
|
||||
cache_idx = cache_loc[i]
|
||||
pos = input_pos[i]
|
||||
mla_cache[cache_idx, pos, :kv_lora_rank] = compressed_kv[i, 0]
|
||||
mla_cache[cache_idx, pos, kv_lora_rank:] = kpe[i, 0, 0]
|
||||
else:
|
||||
for i in range(batch_size):
|
||||
cache_idx = cache_loc[i]
|
||||
pos = input_pos[i]
|
||||
seq_len_i = seq_len[i]
|
||||
seq_start_i = seq_start[i]
|
||||
for j in range(seq_len_i):
|
||||
mla_cache[cache_idx, pos + j, :kv_lora_rank] = compressed_kv[seq_start_i + j]
|
||||
mla_cache[cache_idx, pos + j, kv_lora_rank:] = kpe[seq_start_i + j, 0]
|
||||
|
||||
# Compute attention for each sequence
|
||||
outputs = []
|
||||
|
||||
for i in range(batch_size):
|
||||
cache_idx = cache_loc[i]
|
||||
pos = input_pos[i]
|
||||
seq_len_i = seq_len[i]
|
||||
seq_start_i = seq_start[i]
|
||||
|
||||
if seq_len_i == 0:
|
||||
continue
|
||||
|
||||
# Get query for this sequence
|
||||
if is_generate:
|
||||
q_nope_seq = q_nope[i, 0] # [N, qk_nope_head_dim]
|
||||
q_pe_seq = q_pe[i, 0] # [N, qk_rope_head_dim]
|
||||
else:
|
||||
q_nope_seq = q_nope[seq_start_i : seq_start_i + seq_len_i]
|
||||
q_pe_seq = q_pe[seq_start_i : seq_start_i + seq_len_i]
|
||||
|
||||
# Get cached compressed_kv and kpe
|
||||
kv_seq_len = pos + seq_len_i
|
||||
cached_data = mla_cache[cache_idx, :kv_seq_len]
|
||||
compressed_kv_cached = cached_data[:, :kv_lora_rank]
|
||||
kpe_cached = cached_data[:, kv_lora_rank:]
|
||||
|
||||
# Expand compressed_kv using kv_b_proj_weight
|
||||
# compressed_kv_cached: [kv_seq_len, kv_lora_rank]
|
||||
# kv_b_proj_weight: [N * kv_head_dim, kv_lora_rank]
|
||||
kv_expanded = np.matmul(compressed_kv_cached, kv_b_proj_weight.T)
|
||||
kv_expanded = kv_expanded.reshape(kv_seq_len, num_heads, kv_head_dim)
|
||||
|
||||
k_nope = kv_expanded[:, :, :qk_nope_head_dim]
|
||||
v = kv_expanded[:, :, qk_nope_head_dim:]
|
||||
|
||||
# Expand kpe to all heads
|
||||
kpe_expanded = np.broadcast_to(
|
||||
kpe_cached[:, None, :], (kv_seq_len, num_heads, qk_rope_head_dim)
|
||||
)
|
||||
|
||||
# Construct full query and key
|
||||
if is_generate:
|
||||
query_full = np.concatenate([q_nope_seq, q_pe_seq], axis=-1)
|
||||
else:
|
||||
query_full = np.concatenate([q_nope_seq, q_pe_seq], axis=-1)
|
||||
|
||||
key_full = np.concatenate([k_nope, kpe_expanded], axis=-1)
|
||||
|
||||
# Compute attention scores
|
||||
if is_generate:
|
||||
attn_scores = np.einsum("nh,knh->nk", query_full, key_full) * scale
|
||||
else:
|
||||
attn_scores = np.einsum("snh,knh->snk", query_full, key_full) * scale
|
||||
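            # Added note: the triu offset k = kv_seq_len - seq_len_i + 1 masks key
            # positions j > input_pos[i] + i for query row i, i.e. standard causal
            # masking for a chunk appended at the end of the cached sequence.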
causal_mask = np.triu(np.ones((seq_len_i, kv_seq_len)), k=kv_seq_len - seq_len_i + 1)
|
||||
attn_scores = np.where(causal_mask[:, None, :], -np.inf, attn_scores)
|
||||
|
||||
# Apply softmax
|
||||
attn_scores_max = np.max(attn_scores, axis=-1, keepdims=True)
|
||||
attn_scores_exp = np.exp(attn_scores - attn_scores_max)
|
||||
attn_weights = attn_scores_exp / np.sum(attn_scores_exp, axis=-1, keepdims=True)
|
||||
|
||||
# Compute output
|
||||
if is_generate:
|
||||
attn_out = np.einsum("nk,knh->nh", attn_weights, v)
|
||||
else:
|
||||
attn_out = np.einsum("snk,knh->snh", attn_weights, v)
|
||||
|
||||
outputs.append(attn_out)
|
||||
|
||||
# Concatenate outputs
|
||||
if len(outputs) == 0:
|
||||
return np.zeros((1, 0, num_heads, v_head_dim), dtype=np.float32)
|
||||
elif is_generate:
|
||||
result = np.stack(outputs, axis=0)
|
||||
return result[:, None, :, :]
|
||||
else:
|
||||
result = np.concatenate(outputs, axis=0)
|
||||
return result[None, :, :, :]
|
||||
|
||||
|
||||
class TestTorchMLASourceOp:
|
||||
"""Test torch_mla source op (without cache)."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_method(self):
|
||||
"""Setup test configuration."""
|
||||
self.device = "cuda"
|
||||
self.dtype = torch.bfloat16
|
||||
self.atol = 1e-2
|
||||
self.rtol = 1e-2
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
torch.manual_seed(42)
|
||||
np.random.seed(42)
|
||||
|
||||
def _create_mla_data(
|
||||
self,
|
||||
batch_size: int,
|
||||
seq_len: int,
|
||||
num_heads: int,
|
||||
qk_nope_head_dim: int,
|
||||
qk_rope_head_dim: int,
|
||||
kv_lora_rank: int,
|
||||
v_head_dim: int,
|
||||
layout: str = "bsnd",
|
||||
):
|
||||
"""Create test data for MLA source op with compressed_kv."""
|
||||
kv_head_dim = qk_nope_head_dim + v_head_dim
|
||||
|
||||
if layout == "bsnd":
|
||||
q_nope = torch.randn(
|
||||
batch_size,
|
||||
seq_len,
|
||||
num_heads,
|
||||
qk_nope_head_dim,
|
||||
dtype=self.dtype,
|
||||
device=self.device,
|
||||
)
|
||||
q_pe = torch.randn(
|
||||
batch_size,
|
||||
seq_len,
|
||||
num_heads,
|
||||
qk_rope_head_dim,
|
||||
dtype=self.dtype,
|
||||
device=self.device,
|
||||
)
|
||||
compressed_kv = torch.randn(
|
||||
batch_size, seq_len, kv_lora_rank, dtype=self.dtype, device=self.device
|
||||
)
|
||||
kpe = torch.randn(
|
||||
batch_size, seq_len, 1, qk_rope_head_dim, dtype=self.dtype, device=self.device
|
||||
)
|
||||
else: # bnsd
|
||||
q_nope = torch.randn(
|
||||
batch_size,
|
||||
num_heads,
|
||||
seq_len,
|
||||
qk_nope_head_dim,
|
||||
dtype=self.dtype,
|
||||
device=self.device,
|
||||
)
|
||||
q_pe = torch.randn(
|
||||
batch_size,
|
||||
num_heads,
|
||||
seq_len,
|
||||
qk_rope_head_dim,
|
||||
dtype=self.dtype,
|
||||
device=self.device,
|
||||
)
|
||||
compressed_kv = torch.randn(
|
||||
batch_size, seq_len, kv_lora_rank, dtype=self.dtype, device=self.device
|
||||
)
|
||||
kpe = torch.randn(
|
||||
batch_size, 1, seq_len, qk_rope_head_dim, dtype=self.dtype, device=self.device
|
||||
)
|
||||
|
||||
# kv_b_proj_weight: [num_heads * kv_head_dim, kv_lora_rank]
|
||||
kv_b_proj_weight = torch.randn(
|
||||
num_heads * kv_head_dim, kv_lora_rank, dtype=self.dtype, device=self.device
|
||||
)
|
||||
|
||||
return {
|
||||
"q_nope": q_nope,
|
||||
"q_pe": q_pe,
|
||||
"compressed_kv": compressed_kv,
|
||||
"kpe": kpe,
|
||||
"kv_b_proj_weight": kv_b_proj_weight,
|
||||
}
|
||||
|
||||
def test_basic_functionality(self):
|
||||
"""Test basic MLA source op functionality."""
|
||||
batch_size, seq_len, num_heads = 2, 4, 8
|
||||
qk_nope_head_dim, qk_rope_head_dim = 128, 64
|
||||
kv_lora_rank = 512
|
||||
v_head_dim = 128
|
||||
|
||||
data = self._create_mla_data(
|
||||
batch_size,
|
||||
seq_len,
|
||||
num_heads,
|
||||
qk_nope_head_dim,
|
||||
qk_rope_head_dim,
|
||||
kv_lora_rank,
|
||||
v_head_dim,
|
||||
)
|
||||
|
||||
output = torch.ops.auto_deploy.torch_mla(
|
||||
data["q_nope"],
|
||||
data["q_pe"],
|
||||
data["compressed_kv"],
|
||||
data["kpe"],
|
||||
data["kv_b_proj_weight"],
|
||||
True, # is_causal
|
||||
None, # scale
|
||||
"bsnd", # layout
|
||||
)
|
||||
|
||||
# Verify output shape: [B, S, N, v_head_dim]
|
||||
expected_shape = (batch_size, seq_len, num_heads, v_head_dim)
|
||||
assert output.shape == expected_shape, f"Expected {expected_shape}, got {output.shape}"
|
||||
|
||||
# Verify output is finite
|
||||
assert torch.isfinite(output).all(), "Output contains NaN or Inf values"
|
||||
|
||||
def test_both_layouts(self):
|
||||
"""Test MLA source op with both bsnd and bnsd layouts."""
|
||||
batch_size, seq_len, num_heads = 2, 4, 8
|
||||
qk_nope_head_dim, qk_rope_head_dim = 64, 32
|
||||
kv_lora_rank = 256
|
||||
v_head_dim = 64
|
||||
|
||||
for layout in ["bsnd", "bnsd"]:
|
||||
data = self._create_mla_data(
|
||||
batch_size,
|
||||
seq_len,
|
||||
num_heads,
|
||||
qk_nope_head_dim,
|
||||
qk_rope_head_dim,
|
||||
kv_lora_rank,
|
||||
v_head_dim,
|
||||
layout,
|
||||
)
|
||||
|
||||
output = torch.ops.auto_deploy.torch_mla(
|
||||
data["q_nope"],
|
||||
data["q_pe"],
|
||||
data["compressed_kv"],
|
||||
data["kpe"],
|
||||
data["kv_b_proj_weight"],
|
||||
True,
|
||||
None,
|
||||
layout,
|
||||
)
|
||||
|
||||
if layout == "bsnd":
|
||||
expected_shape = (batch_size, seq_len, num_heads, v_head_dim)
|
||||
else:
|
||||
expected_shape = (batch_size, num_heads, seq_len, v_head_dim)
|
||||
|
||||
assert output.shape == expected_shape, (
|
||||
f"Layout {layout}: Expected {expected_shape}, got {output.shape}"
|
||||
)
|
||||
|
||||
def test_custom_scale(self):
|
||||
"""Test MLA source op with custom scale."""
|
||||
batch_size, seq_len, num_heads = 1, 2, 4
|
||||
qk_nope_head_dim, qk_rope_head_dim = 32, 16
|
||||
kv_lora_rank = 128
|
||||
v_head_dim = 32
|
||||
|
||||
data = self._create_mla_data(
|
||||
batch_size,
|
||||
seq_len,
|
||||
num_heads,
|
||||
qk_nope_head_dim,
|
||||
qk_rope_head_dim,
|
||||
kv_lora_rank,
|
||||
v_head_dim,
|
||||
)
|
||||
|
||||
# Test with default scale
|
||||
output_default = torch.ops.auto_deploy.torch_mla(
|
||||
data["q_nope"],
|
||||
data["q_pe"],
|
||||
data["compressed_kv"],
|
||||
data["kpe"],
|
||||
data["kv_b_proj_weight"],
|
||||
True,
|
||||
None,
|
||||
"bsnd",
|
||||
)
|
||||
|
||||
# Test with custom scale
|
||||
custom_scale = 0.5
|
||||
output_custom = torch.ops.auto_deploy.torch_mla(
|
||||
data["q_nope"],
|
||||
data["q_pe"],
|
||||
data["compressed_kv"],
|
||||
data["kpe"],
|
||||
data["kv_b_proj_weight"],
|
||||
True,
|
||||
custom_scale,
|
||||
"bsnd",
|
||||
)
|
||||
|
||||
# Outputs should be different
|
||||
assert not torch.allclose(output_default, output_custom, atol=1e-3), (
|
||||
"Custom scale should affect output"
|
||||
)
|
||||
|
||||
|
||||
class TestTorchBackendMLAWithCache:
|
||||
"""Test torch_backend_mla_with_cache cached op."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_method(self):
|
||||
"""Setup test configuration."""
|
||||
self.device = "cuda"
|
||||
self.dtype = torch.bfloat16
|
||||
self.atol = 5e-2
|
||||
self.rtol = 5e-2
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
torch.manual_seed(42)
|
||||
np.random.seed(42)
|
||||
|
||||
def _create_cached_mla_data(
|
||||
self,
|
||||
batch_size: int,
|
||||
seq_len: int,
|
||||
num_heads: int,
|
||||
qk_nope_head_dim: int,
|
||||
qk_rope_head_dim: int,
|
||||
kv_lora_rank: int,
|
||||
v_head_dim: int,
|
||||
max_seq_len: int,
|
||||
cache_offset: int = 0,
|
||||
):
|
||||
"""Create test data for cached MLA op with FlashInfer layout."""
|
||||
kv_head_dim = qk_nope_head_dim + v_head_dim
|
||||
|
||||
# Create input tensors (BSND layout)
|
||||
q_nope = torch.randn(
|
||||
batch_size, seq_len, num_heads, qk_nope_head_dim, dtype=self.dtype, device=self.device
|
||||
)
|
||||
q_pe = torch.randn(
|
||||
batch_size, seq_len, num_heads, qk_rope_head_dim, dtype=self.dtype, device=self.device
|
||||
)
|
||||
compressed_kv = torch.randn(
|
||||
batch_size, seq_len, kv_lora_rank, dtype=self.dtype, device=self.device
|
||||
)
|
||||
kpe = torch.randn(
|
||||
batch_size, seq_len, 1, qk_rope_head_dim, dtype=self.dtype, device=self.device
|
||||
)
|
||||
|
||||
# kv_b_proj_weight: [num_heads * kv_head_dim, kv_lora_rank]
|
||||
kv_b_proj_weight = torch.randn(
|
||||
num_heads * kv_head_dim, kv_lora_rank, dtype=self.dtype, device=self.device
|
||||
)
|
||||
|
||||
# Create FlashInfer MLA cache: [max_batch, max_seq, kv_lora_rank + qk_rope_head_dim]
|
||||
mla_cache = torch.zeros(
|
||||
batch_size,
|
||||
max_seq_len,
|
||||
kv_lora_rank + qk_rope_head_dim,
|
||||
dtype=self.dtype,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
# Pre-fill cache with random data if cache_offset > 0
|
||||
if cache_offset > 0:
|
||||
mla_cache[:, :cache_offset, :] = torch.randn(
|
||||
batch_size,
|
||||
cache_offset,
|
||||
kv_lora_rank + qk_rope_head_dim,
|
||||
dtype=self.dtype,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
# Setup metadata
|
||||
seq_len_tensor = torch.full((batch_size,), seq_len, device=self.device, dtype=torch.int32)
|
||||
input_pos = torch.full((batch_size,), cache_offset, device=self.device, dtype=torch.int32)
|
||||
cache_loc = torch.arange(batch_size, device=self.device, dtype=torch.int32)
|
||||
|
||||
if seq_len == 1:
|
||||
# Generate phase
|
||||
batch_info_host = torch.tensor(
|
||||
[0, 0, batch_size], device=self.device, dtype=torch.int32
|
||||
)
|
||||
cu_seqlen = torch.arange(batch_size, device=self.device, dtype=torch.int32)
|
||||
else:
|
||||
# Context phase
|
||||
batch_info_host = torch.tensor(
|
||||
[batch_size, batch_size * seq_len, 0], device=self.device, dtype=torch.int32
|
||||
)
|
||||
cu_seqlen = torch.arange(
|
||||
0, batch_size * seq_len, seq_len, device=self.device, dtype=torch.int32
|
||||
)
|
||||
# Flatten inputs for context phase
|
||||
q_nope = q_nope.view(1, batch_size * seq_len, num_heads, qk_nope_head_dim)
|
||||
q_pe = q_pe.view(1, batch_size * seq_len, num_heads, qk_rope_head_dim)
|
||||
compressed_kv = compressed_kv.view(1, batch_size * seq_len, kv_lora_rank)
|
||||
kpe = kpe.view(1, batch_size * seq_len, 1, qk_rope_head_dim)
|
||||
|
||||
return {
|
||||
"q_nope": q_nope,
|
||||
"q_pe": q_pe,
|
||||
"compressed_kv": compressed_kv,
|
||||
"kpe": kpe,
|
||||
"kv_b_proj_weight": kv_b_proj_weight,
|
||||
"batch_info_host": batch_info_host,
|
||||
"seq_len": seq_len_tensor,
|
||||
"input_pos": input_pos,
|
||||
"cache_loc": cache_loc,
|
||||
"cu_seqlen": cu_seqlen,
|
||||
"mla_cache": mla_cache,
|
||||
"kv_lora_rank": kv_lora_rank,
|
||||
}
|
||||
|
||||
def _run_cached_mla(self, data, scale=None):
|
||||
"""Run cached MLA operation."""
|
||||
return torch.ops.auto_deploy.torch_cached_mla_with_cache(
|
||||
data["q_nope"],
|
||||
data["q_pe"],
|
||||
data["compressed_kv"],
|
||||
data["kpe"],
|
||||
data["kv_b_proj_weight"],
|
||||
data["batch_info_host"],
|
||||
data["seq_len"],
|
||||
data["input_pos"],
|
||||
data["cache_loc"],
|
||||
data["cu_seqlen"],
|
||||
data["mla_cache"],
|
||||
scale,
|
||||
data["kv_lora_rank"],
|
||||
)
|
||||
|
||||
def test_generate_phase_basic(self):
|
||||
"""Test generate phase (single token) basic functionality."""
|
||||
batch_size, seq_len, num_heads = 2, 1, 8
|
||||
qk_nope_head_dim, qk_rope_head_dim = 64, 32
|
||||
kv_lora_rank = 256
|
||||
v_head_dim = 64
|
||||
max_seq_len = 128
|
||||
cache_offset = 5
|
||||
|
||||
data = self._create_cached_mla_data(
|
||||
batch_size,
|
||||
seq_len,
|
||||
num_heads,
|
||||
qk_nope_head_dim,
|
||||
qk_rope_head_dim,
|
||||
kv_lora_rank,
|
||||
v_head_dim,
|
||||
max_seq_len,
|
||||
cache_offset,
|
||||
)
|
||||
|
||||
output = self._run_cached_mla(data)
|
||||
|
||||
# Verify output shape
|
||||
expected_shape = (batch_size, seq_len, num_heads, v_head_dim)
|
||||
assert output.shape == expected_shape, f"Expected {expected_shape}, got {output.shape}"
|
||||
|
||||
# Verify output is finite
|
||||
assert torch.isfinite(output).all(), "Output contains NaN or Inf values"
|
||||
|
||||
def test_context_phase_basic(self):
|
||||
"""Test context phase (multi-token) basic functionality."""
|
||||
batch_size, seq_len, num_heads = 2, 4, 8
|
||||
qk_nope_head_dim, qk_rope_head_dim = 64, 32
|
||||
kv_lora_rank = 256
|
||||
v_head_dim = 64
|
||||
max_seq_len = 128
|
||||
|
||||
data = self._create_cached_mla_data(
|
||||
batch_size,
|
||||
seq_len,
|
||||
num_heads,
|
||||
qk_nope_head_dim,
|
||||
qk_rope_head_dim,
|
||||
kv_lora_rank,
|
||||
v_head_dim,
|
||||
max_seq_len,
|
||||
)
|
||||
|
||||
output = self._run_cached_mla(data)
|
||||
|
||||
# Verify output shape
|
||||
expected_shape = (1, batch_size * seq_len, num_heads, v_head_dim)
|
||||
assert output.shape == expected_shape, f"Expected {expected_shape}, got {output.shape}"
|
||||
|
||||
# Verify output is finite
|
||||
assert torch.isfinite(output).all(), "Output contains NaN or Inf values"
|
||||
|
||||
def test_cache_update_correctness(self):
|
||||
"""Test that cache is updated correctly during forward pass."""
|
||||
batch_size, seq_len, num_heads = 1, 1, 4
|
||||
qk_nope_head_dim, qk_rope_head_dim = 32, 16
|
||||
kv_lora_rank = 128
|
||||
v_head_dim = 32
|
||||
max_seq_len = 32
|
||||
cache_offset = 5
|
||||
|
||||
data = self._create_cached_mla_data(
|
||||
batch_size,
|
||||
seq_len,
|
||||
num_heads,
|
||||
qk_nope_head_dim,
|
||||
qk_rope_head_dim,
|
||||
kv_lora_rank,
|
||||
v_head_dim,
|
||||
max_seq_len,
|
||||
cache_offset,
|
||||
)
|
||||
|
||||
# Store original cache values at target position
|
||||
original_cache_at_pos = data["mla_cache"][0, cache_offset].clone()
|
||||
|
||||
# Run forward pass
|
||||
_ = self._run_cached_mla(data)
|
||||
|
||||
# Check cache was updated at the correct position
|
||||
updated_cache_at_pos = data["mla_cache"][0, cache_offset]
|
||||
|
||||
# The cache should have been updated
|
||||
assert not torch.allclose(original_cache_at_pos, updated_cache_at_pos, atol=1e-6), (
|
||||
"Cache should have been updated at the target position"
|
||||
)
|
||||
|
||||
def test_cache_layout_flashinfer_compatible(self):
|
||||
"""Test that cache layout matches FlashInfer spec (no num_heads dimension)."""
|
||||
batch_size, seq_len, num_heads = 2, 1, 4
|
||||
qk_nope_head_dim, qk_rope_head_dim = 64, 32
|
||||
kv_lora_rank = 512 # DeepSeek-style
|
||||
v_head_dim = 128
|
||||
max_seq_len = 64
|
||||
|
||||
data = self._create_cached_mla_data(
|
||||
batch_size,
|
||||
seq_len,
|
||||
num_heads,
|
||||
qk_nope_head_dim,
|
||||
qk_rope_head_dim,
|
||||
kv_lora_rank,
|
||||
v_head_dim,
|
||||
max_seq_len,
|
||||
)
|
||||
|
||||
# Verify cache shape matches FlashInfer layout: [batch, seq, kv_lora_rank + rope_dim]
|
||||
expected_cache_shape = (batch_size, max_seq_len, kv_lora_rank + qk_rope_head_dim)
|
||||
assert data["mla_cache"].shape == expected_cache_shape, (
|
||||
f"Cache shape {data['mla_cache'].shape} doesn't match FlashInfer layout {expected_cache_shape}"
|
||||
)
|
||||
|
||||
# Verify zero-copy slicing works
|
||||
compressed_kv_slice = data["mla_cache"][:, :, :kv_lora_rank]
|
||||
kpe_slice = data["mla_cache"][:, :, kv_lora_rank:]
|
||||
|
||||
assert compressed_kv_slice.shape == (batch_size, max_seq_len, kv_lora_rank)
|
||||
assert kpe_slice.shape == (batch_size, max_seq_len, qk_rope_head_dim)
|
||||
|
||||
# Verify slices share memory (zero-copy)
|
||||
assert compressed_kv_slice.data_ptr() == data["mla_cache"].data_ptr(), (
|
||||
"compressed_kv slice should be zero-copy"
|
||||
)
|
||||
|
||||
def test_generate_with_reference(self):
|
||||
"""Test generate phase against numpy reference."""
|
||||
batch_size, seq_len, num_heads = 2, 1, 4
|
||||
qk_nope_head_dim, qk_rope_head_dim = 32, 16
|
||||
kv_lora_rank = 64
|
||||
v_head_dim = 32
|
||||
max_seq_len = 64
|
||||
cache_offset = 3
|
||||
|
||||
data = self._create_cached_mla_data(
|
||||
batch_size,
|
||||
seq_len,
|
||||
num_heads,
|
||||
qk_nope_head_dim,
|
||||
qk_rope_head_dim,
|
||||
kv_lora_rank,
|
||||
v_head_dim,
|
||||
max_seq_len,
|
||||
cache_offset,
|
||||
)
|
||||
|
||||
# Run backend
|
||||
output = self._run_cached_mla(data)
|
||||
|
||||
# Run numpy reference
|
||||
reference = numpy_mla_reference_with_expansion(
|
||||
data["q_nope"].cpu().float().numpy(),
|
||||
data["q_pe"].cpu().float().numpy(),
|
||||
data["compressed_kv"].cpu().float().numpy(),
|
||||
data["kpe"].cpu().float().numpy(),
|
||||
data["kv_b_proj_weight"].cpu().float().numpy(),
|
||||
data["mla_cache"].cpu().float().numpy(),
|
||||
data["seq_len"].cpu().numpy(),
|
||||
data["input_pos"].cpu().numpy(),
|
||||
data["cache_loc"].cpu().numpy(),
|
||||
data["cu_seqlen"].cpu().numpy(),
|
||||
None,
|
||||
kv_lora_rank,
|
||||
is_generate=True,
|
||||
)
|
||||
|
||||
reference_torch = torch.from_numpy(reference).to(output.device, output.dtype)
|
||||
assert torch.allclose(output, reference_torch, atol=self.atol, rtol=self.rtol), (
|
||||
f"Generate phase output doesn't match reference. "
|
||||
f"Max diff: {(output - reference_torch).abs().max():.6f}"
|
||||
)
|
||||
|
||||
def test_dtype_preservation(self):
|
||||
"""Test that output dtype matches input dtype."""
|
||||
batch_size, seq_len, num_heads = 1, 1, 4
|
||||
qk_nope_head_dim, qk_rope_head_dim = 32, 16
|
||||
kv_lora_rank = 64
|
||||
v_head_dim = 32
|
||||
max_seq_len = 32
|
||||
|
||||
for dtype in [torch.float16, torch.bfloat16]:
|
||||
self.dtype = dtype
|
||||
data = self._create_cached_mla_data(
|
||||
batch_size,
|
||||
seq_len,
|
||||
num_heads,
|
||||
qk_nope_head_dim,
|
||||
qk_rope_head_dim,
|
||||
kv_lora_rank,
|
||||
v_head_dim,
|
||||
max_seq_len,
|
||||
)
|
||||
|
||||
output = self._run_cached_mla(data)
|
||||
assert output.dtype == dtype, f"Expected dtype {dtype}, got {output.dtype}"
|
||||
|
||||
def test_memory_efficiency(self):
|
||||
"""Test that cache uses compressed dimensions (no num_heads)."""
|
||||
batch_size = 1
|
||||
max_seq_len = 1024
|
||||
kv_lora_rank = 512
|
||||
qk_rope_head_dim = 64
|
||||
num_heads = 128 # DeepSeek V3
|
||||
|
||||
# FlashInfer compressed cache size
|
||||
compressed_cache_size = batch_size * max_seq_len * (kv_lora_rank + qk_rope_head_dim)
|
||||
|
||||
# Expanded per-head cache size (what we avoid)
|
||||
qk_nope_head_dim = 128
|
||||
v_head_dim = 128
|
||||
expanded_cache_size = (
|
||||
batch_size
|
||||
* max_seq_len
|
||||
* num_heads
|
||||
* (qk_nope_head_dim + v_head_dim + qk_rope_head_dim)
|
||||
)
|
||||
|
||||
# Verify compression ratio
|
||||
compression_ratio = expanded_cache_size / compressed_cache_size
|
||||
assert compression_ratio > 50, f"Expected >50x compression, got {compression_ratio:.1f}x"
|
||||
|
||||
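    # Added worked numbers for the DeepSeek-V3-like shapes in the test above:
    #   compressed cache elements per token: kv_lora_rank + qk_rope_head_dim = 512 + 64 = 576
    #   expanded cache elements per token:   num_heads * (qk_nope + v + rope) = 128 * 320 = 40960
    #   ratio: 40960 / 576 ≈ 71x, comfortably above the asserted 50x threshold.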
|
||||
class TestMLADescriptor:
|
||||
"""Test MultiHeadLatentAttention descriptor configuration."""
|
||||
|
||||
def _get_mla_descriptor(self):
|
||||
"""Get MLA descriptor from registry."""
|
||||
from tensorrt_llm._torch.auto_deploy.custom_ops.attention_interface import AttentionRegistry
|
||||
|
||||
return AttentionRegistry.get("torch_mla")
|
||||
|
||||
def test_descriptor_registration(self):
|
||||
"""Test that MLA descriptor is properly registered."""
|
||||
from tensorrt_llm._torch.auto_deploy.custom_ops.attention_interface import AttentionRegistry
|
||||
|
||||
assert AttentionRegistry.has("torch_mla"), "torch_mla should be registered"
|
||||
|
||||
def test_descriptor_layout(self):
|
||||
"""Test that MLA descriptor uses correct layout."""
|
||||
mla_descriptor = self._get_mla_descriptor()
|
||||
|
||||
assert mla_descriptor.get_attention_layout() == "bsnd", "MLA should use bsnd layout"
|
||||
|
||||
def test_descriptor_num_qkv_args(self):
|
||||
"""Test that MLA descriptor expects 5 tensor args."""
|
||||
mla_descriptor = self._get_mla_descriptor()
|
||||
|
||||
assert mla_descriptor.get_num_qkv_args() == 5, (
|
||||
"MLA should expect 5 tensor args (q_nope, q_pe, compressed_kv, kpe, kv_b_proj_weight)"
|
||||
)
|
||||
|
||||
def test_descriptor_source_op(self):
|
||||
"""Test that MLA descriptor points to correct source op."""
|
||||
mla_descriptor = self._get_mla_descriptor()
|
||||
|
||||
source_op = mla_descriptor.get_source_attention_op()
|
||||
assert source_op == torch.ops.auto_deploy.torch_mla, "MLA should use torch_mla as source op"
|
||||
|
||||
def test_descriptor_cached_op(self):
|
||||
"""Test that MLA descriptor points to correct cached op."""
|
||||
mla_descriptor = self._get_mla_descriptor()
|
||||
|
||||
cached_op = mla_descriptor.get_cached_attention_op()
|
||||
assert cached_op == torch.ops.auto_deploy.torch_cached_mla_with_cache.default, (
|
||||
"MLA should use torch_cached_mla_with_cache as cached op"
|
||||
)
|
||||
|
||||
def test_descriptor_standard_metadata(self):
|
||||
"""Test that MLA descriptor uses standard metadata args."""
|
||||
mla_descriptor = self._get_mla_descriptor()
|
||||
|
||||
expected_args = ["batch_info_host", "seq_len", "input_pos", "slot_idx", "cu_seqlen"]
|
||||
actual_args = mla_descriptor.get_standard_metadata_args()
|
||||
assert actual_args == expected_args, (
|
||||
f"Expected standard metadata {expected_args}, got {actual_args}"
|
||||
)
|
||||
@ -1,6 +1,8 @@
import torch

-from tensorrt_llm._torch.auto_deploy.custom_ops.attention.torch_attention import update_kv_cache
+from tensorrt_llm._torch.auto_deploy.custom_ops.attention.torch_backend_attention import (
+    _update_kv_cache,
+)


def test_update_kv_cache():
@ -26,7 +28,7 @@ def test_update_kv_cache():
    print("slot_idx: " + str(torch.tensor([0, 1])))
    print("seq_start: " + str(torch.tensor([0, 3])))

-    update_kv_cache(
+    _update_kv_cache(
        k.view(batch_size * seq_length, n_heads, K_D_HEAD),
        v.view(batch_size * seq_length, n_heads, V_D_HEAD),
        k_cache,

@ -0,0 +1,494 @@
|
||||
"""Testing custom DeepSeekV3 model implementation for auto_deploy export."""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from tensorrt_llm._torch.auto_deploy.models.custom.modeling_deepseek import (
|
||||
DeepSeekV3Attention,
|
||||
DeepSeekV3DecoderLayer,
|
||||
DeepSeekV3ForCausalLM,
|
||||
DeepSeekV3MLP,
|
||||
DeepSeekV3Model,
|
||||
DeepSeekV3MoE,
|
||||
DeepSeekV3RMSNorm,
|
||||
DeepSeekV3RotaryEmbedding,
|
||||
DeepSeekV3YarnRotaryEmbedding,
|
||||
)
|
||||
|
||||
|
||||
class MockDeepSeekConfig(PretrainedConfig):
|
||||
"""Mock DeepSeek config for testing the custom model components."""
|
||||
|
||||
model_type = "deepseek_v3"
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
# Attention config
|
||||
self.num_attention_heads = 8
|
||||
self.qk_nope_head_dim = 64
|
||||
self.qk_rope_head_dim = 32
|
||||
self.v_head_dim = 64
|
||||
self.kv_lora_rank = 128
|
||||
self.q_lora_rank = None # No LoRA for Q in tests
|
||||
self.hidden_size = 256
|
||||
self.rope_theta = 10000.0
|
||||
self.max_position_embeddings = 512
|
||||
self.attention_bias = False
|
||||
self.rope_scaling = None
|
||||
self.rms_norm_eps = 1e-6
|
||||
|
||||
# MLP config
|
||||
self.intermediate_size = 512
|
||||
self.hidden_act = "silu"
|
||||
|
||||
# MoE config
|
||||
self.n_routed_experts = 4
|
||||
self.num_experts_per_tok = 2
|
||||
self.moe_intermediate_size = 256
|
||||
self.n_shared_experts = 1
|
||||
self.routed_scaling_factor = 1.0
|
||||
self.n_group = 1
|
||||
self.topk_group = 1
|
||||
|
||||
# Model config
|
||||
self.num_hidden_layers = 2
|
||||
self.first_k_dense_replace = 1 # First layer is dense, second is MoE
|
||||
self.moe_layer_freq = 1
|
||||
self.vocab_size = 1000
|
||||
self.pad_token_id = 0
|
||||
self.initializer_range = 0.02
|
||||
|
||||
|
||||
class TestDeepSeekV3RMSNorm:
|
||||
"""Test DeepSeekV3RMSNorm implementation."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_method(self):
|
||||
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
self.dtype = torch.bfloat16
|
||||
torch.manual_seed(42)
|
||||
|
||||
def test_forward_shape(self):
|
||||
"""Test that RMSNorm preserves input shape."""
|
||||
hidden_size = 256
|
||||
norm = DeepSeekV3RMSNorm(hidden_size).to(self.device, self.dtype)
|
||||
|
||||
x = torch.randn(2, 4, hidden_size, dtype=self.dtype, device=self.device)
|
||||
output = norm(x)
|
||||
|
||||
assert output.shape == x.shape
|
||||
assert torch.isfinite(output).all()
|
||||
|
||||
def test_output_normalized(self):
|
||||
"""Test that output has approximately unit variance."""
|
||||
hidden_size = 256
|
||||
norm = DeepSeekV3RMSNorm(hidden_size).to(self.device, torch.float32)
|
||||
|
||||
x = torch.randn(2, 4, hidden_size, dtype=torch.float32, device=self.device)
|
||||
output = norm(x)
|
||||
|
||||
# RMS should be close to 1 after normalization (scaled by weight)
|
||||
rms = torch.sqrt((output**2).mean(-1))
|
||||
assert torch.allclose(rms, torch.ones_like(rms), atol=0.1)
|
||||
|
||||
|
||||
class TestDeepSeekV3RotaryEmbedding:
|
||||
"""Test DeepSeekV3 Rotary Embedding implementations."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_method(self):
|
||||
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
self.dtype = torch.bfloat16
|
||||
torch.manual_seed(42)
|
||||
|
||||
def test_base_rope_shape(self):
|
||||
"""Test base rotary embedding output shape."""
|
||||
dim = 32
|
||||
max_pos = 512
|
||||
rope = DeepSeekV3RotaryEmbedding(dim, max_pos).to(self.device)
|
||||
|
||||
x = torch.randn(2, 4, 8, dim, dtype=self.dtype, device=self.device)
|
||||
cos, sin = rope(x)
|
||||
|
||||
# Should return full cached values
|
||||
assert cos.shape == (max_pos, dim)
|
||||
assert sin.shape == (max_pos, dim)
|
||||
|
||||
def test_yarn_rope_shape(self):
|
||||
"""Test YaRN rotary embedding output shape."""
|
||||
dim = 32
|
||||
max_pos = 512
|
||||
rope = DeepSeekV3YarnRotaryEmbedding(
|
||||
dim,
|
||||
max_pos,
|
||||
scaling_factor=2.0,
|
||||
original_max_position_embeddings=256,
|
||||
).to(self.device)
|
||||
|
||||
x = torch.randn(2, 4, 8, dim, dtype=self.dtype, device=self.device)
|
||||
cos, sin = rope(x)
|
||||
|
||||
assert cos.shape == (max_pos, dim)
|
||||
assert sin.shape == (max_pos, dim)
|
||||
|
||||
|
||||
class TestDeepSeekV3MLP:
|
||||
"""Test DeepSeekV3MLP implementation."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_method(self):
|
||||
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
self.dtype = torch.bfloat16
|
||||
torch.manual_seed(42)
|
||||
|
||||
def test_forward_shape(self):
|
||||
"""Test MLP output shape."""
|
||||
config = MockDeepSeekConfig()
|
||||
mlp = DeepSeekV3MLP(config).to(self.device, self.dtype)
|
||||
|
||||
x = torch.randn(2, 4, config.hidden_size, dtype=self.dtype, device=self.device)
|
||||
output = mlp(x)
|
||||
|
||||
assert output.shape == x.shape
|
||||
assert torch.isfinite(output).all()
|
||||
|
||||
|
||||
class TestDeepSeekV3Attention:
|
||||
"""Test DeepSeekV3Attention (MLA) implementation."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_method(self):
|
||||
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
self.dtype = torch.bfloat16
|
||||
torch.manual_seed(42)
|
||||
|
||||
def test_forward_shape(self):
|
||||
"""Test attention output shape."""
|
||||
config = MockDeepSeekConfig()
|
||||
attn = DeepSeekV3Attention(config, layer_idx=0).to(self.device, self.dtype)
|
||||
|
||||
batch_size, seq_len = 2, 4
|
||||
hidden_states = torch.randn(
|
||||
batch_size, seq_len, config.hidden_size, dtype=self.dtype, device=self.device
|
||||
)
|
||||
position_ids = torch.arange(seq_len, device=self.device).unsqueeze(0).expand(batch_size, -1)
|
||||
|
||||
output = attn(hidden_states, position_ids)
|
||||
|
||||
assert output.shape == hidden_states.shape
|
||||
assert torch.isfinite(output).all()
|
||||
|
||||
def test_different_batch_sizes(self):
|
||||
"""Test attention with different batch sizes."""
|
||||
config = MockDeepSeekConfig()
|
||||
attn = DeepSeekV3Attention(config, layer_idx=0).to(self.device, self.dtype)
|
||||
|
||||
for batch_size in [1, 2, 4]:
|
||||
seq_len = 4
|
||||
hidden_states = torch.randn(
|
||||
batch_size, seq_len, config.hidden_size, dtype=self.dtype, device=self.device
|
||||
)
|
||||
position_ids = (
|
||||
torch.arange(seq_len, device=self.device).unsqueeze(0).expand(batch_size, -1)
|
||||
)
|
||||
|
||||
output = attn(hidden_states, position_ids)
|
||||
assert output.shape == (batch_size, seq_len, config.hidden_size)
|
||||
|
||||
def test_different_sequence_lengths(self):
|
||||
"""Test attention with different sequence lengths."""
|
||||
config = MockDeepSeekConfig()
|
||||
attn = DeepSeekV3Attention(config, layer_idx=0).to(self.device, self.dtype)
|
||||
|
||||
for seq_len in [1, 4, 16]:
|
||||
batch_size = 2
|
||||
hidden_states = torch.randn(
|
||||
batch_size, seq_len, config.hidden_size, dtype=self.dtype, device=self.device
|
||||
)
|
||||
position_ids = (
|
||||
torch.arange(seq_len, device=self.device).unsqueeze(0).expand(batch_size, -1)
|
||||
)
|
||||
|
||||
output = attn(hidden_states, position_ids)
|
||||
assert output.shape == (batch_size, seq_len, config.hidden_size)
|
||||
|
||||
|
||||
class TestDeepSeekV3MoE:
|
||||
"""Test DeepSeekV3MoE implementation."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_method(self):
|
||||
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
self.dtype = torch.bfloat16
|
||||
torch.manual_seed(42)
|
||||
|
||||
def test_forward_shape(self):
|
||||
"""Test MoE output shape."""
|
||||
config = MockDeepSeekConfig()
|
||||
moe = DeepSeekV3MoE(config).to(self.device, self.dtype)
|
||||
|
||||
x = torch.randn(2, 4, config.hidden_size, dtype=self.dtype, device=self.device)
|
||||
output = moe(x)
|
||||
|
||||
assert output.shape == x.shape
|
||||
assert torch.isfinite(output).all()
|
||||
|
||||
def test_with_shared_experts(self):
|
||||
"""Test MoE with shared experts."""
|
||||
config = MockDeepSeekConfig()
|
||||
config.n_shared_experts = 2
|
||||
moe = DeepSeekV3MoE(config).to(self.device, self.dtype)
|
||||
|
||||
assert moe.shared_experts is not None
|
||||
|
||||
x = torch.randn(2, 4, config.hidden_size, dtype=self.dtype, device=self.device)
|
||||
output = moe(x)
|
||||
|
||||
assert output.shape == x.shape
|
||||
|
||||
def test_without_shared_experts(self):
|
||||
"""Test MoE without shared experts."""
|
||||
config = MockDeepSeekConfig()
|
||||
config.n_shared_experts = None
|
||||
moe = DeepSeekV3MoE(config).to(self.device, self.dtype)
|
||||
|
||||
assert moe.shared_experts is None
|
||||
|
||||
x = torch.randn(2, 4, config.hidden_size, dtype=self.dtype, device=self.device)
|
||||
output = moe(x)
|
||||
|
||||
assert output.shape == x.shape
|
||||
|
||||
|
||||
class TestDeepSeekV3DecoderLayer:
|
||||
"""Test DeepSeekV3DecoderLayer implementation."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_method(self):
|
||||
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
self.dtype = torch.bfloat16
|
||||
torch.manual_seed(42)
|
||||
|
||||
def test_dense_layer(self):
|
||||
"""Test decoder layer with dense MLP."""
|
||||
config = MockDeepSeekConfig()
|
||||
# Layer 0 should be dense (before first_k_dense_replace)
|
||||
layer = DeepSeekV3DecoderLayer(config, layer_idx=0).to(self.device, self.dtype)
|
||||
|
||||
assert isinstance(layer.mlp, DeepSeekV3MLP)
|
||||
|
||||
batch_size, seq_len = 2, 4
|
||||
hidden_states = torch.randn(
|
||||
batch_size, seq_len, config.hidden_size, dtype=self.dtype, device=self.device
|
||||
)
|
||||
position_ids = torch.arange(seq_len, device=self.device).unsqueeze(0).expand(batch_size, -1)
|
||||
|
||||
output = layer(hidden_states, position_ids)
|
||||
|
||||
assert output.shape == hidden_states.shape
|
||||
|
||||
def test_moe_layer(self):
|
||||
"""Test decoder layer with MoE."""
|
||||
config = MockDeepSeekConfig()
|
||||
# Layer 1 should be MoE (at first_k_dense_replace)
|
||||
layer = DeepSeekV3DecoderLayer(config, layer_idx=1).to(self.device, self.dtype)
|
||||
|
||||
assert isinstance(layer.mlp, DeepSeekV3MoE)
|
||||
|
||||
batch_size, seq_len = 2, 4
|
||||
hidden_states = torch.randn(
|
||||
batch_size, seq_len, config.hidden_size, dtype=self.dtype, device=self.device
|
||||
)
|
||||
position_ids = torch.arange(seq_len, device=self.device).unsqueeze(0).expand(batch_size, -1)
|
||||
|
||||
output = layer(hidden_states, position_ids)
|
||||
|
||||
assert output.shape == hidden_states.shape
|
||||
|
||||
|
||||
class TestDeepSeekV3Model:
|
||||
"""Test DeepSeekV3Model implementation."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_method(self):
|
||||
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
self.dtype = torch.bfloat16
|
||||
torch.manual_seed(42)
|
||||
|
||||
def test_forward_with_input_ids(self):
|
||||
"""Test model forward with input_ids."""
|
||||
config = MockDeepSeekConfig()
|
||||
model = DeepSeekV3Model(config).to(self.device, self.dtype)
|
||||
|
||||
batch_size, seq_len = 2, 4
|
||||
input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len), device=self.device)
|
||||
|
||||
output = model(input_ids=input_ids)
|
||||
|
||||
assert output.last_hidden_state.shape == (batch_size, seq_len, config.hidden_size)
|
||||
|
||||
def test_forward_with_inputs_embeds(self):
|
||||
"""Test model forward with inputs_embeds."""
|
||||
config = MockDeepSeekConfig()
|
||||
model = DeepSeekV3Model(config).to(self.device, self.dtype)
|
||||
|
||||
batch_size, seq_len = 2, 4
|
||||
inputs_embeds = torch.randn(
|
||||
batch_size, seq_len, config.hidden_size, dtype=self.dtype, device=self.device
|
||||
)
|
||||
|
||||
output = model(inputs_embeds=inputs_embeds)
|
||||
|
||||
assert output.last_hidden_state.shape == inputs_embeds.shape
|
||||
|
||||
|
||||
class TestDeepSeekV3ForCausalLM:
|
||||
"""Test DeepSeekV3ForCausalLM implementation."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_method(self):
|
||||
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
self.dtype = torch.bfloat16
|
||||
torch.manual_seed(42)
|
||||
|
||||
def test_forward(self):
|
||||
"""Test causal LM forward pass."""
|
||||
config = MockDeepSeekConfig()
|
||||
model = DeepSeekV3ForCausalLM(config).to(self.device, self.dtype)
|
||||
|
||||
batch_size, seq_len = 2, 4
|
||||
input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len), device=self.device)
|
||||
|
||||
output = model(input_ids=input_ids)
|
||||
|
||||
assert output.logits.shape == (batch_size, seq_len, config.vocab_size)
|
||||
|
||||
def test_output_dtype(self):
|
||||
"""Test that logits are float32."""
|
||||
config = MockDeepSeekConfig()
|
||||
model = DeepSeekV3ForCausalLM(config).to(self.device, self.dtype)
|
||||
|
||||
batch_size, seq_len = 2, 4
|
||||
input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len), device=self.device)
|
||||
|
||||
output = model(input_ids=input_ids)
|
||||
|
||||
# Logits should be float32 for numerical stability
|
||||
assert output.logits.dtype == torch.float32
|
||||
|
||||
|
||||
class TestMLAOpRegistration:
|
||||
"""Test that MLA ops are properly registered."""
|
||||
|
||||
def test_torch_mla_registered(self):
|
||||
"""Test that torch_mla op is registered."""
|
||||
assert hasattr(torch.ops.auto_deploy, "torch_mla"), "torch_mla op should be registered"
|
||||
|
||||
def test_torch_cached_mla_registered(self):
|
||||
"""Test that torch_cached_mla_with_cache op is registered."""
|
||||
assert hasattr(torch.ops.auto_deploy, "torch_cached_mla_with_cache"), (
|
||||
"torch_cached_mla_with_cache op should be registered"
|
||||
)
|
||||
|
||||
def test_torch_mla_callable(self):
|
||||
"""Test that torch_mla op is callable."""
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
dtype = torch.bfloat16
|
||||
|
||||
batch_size, seq_len, num_heads = 1, 2, 4
|
||||
qk_nope_head_dim, qk_rope_head_dim = 64, 32
|
||||
kv_lora_rank = 128
|
||||
v_head_dim = 64
|
||||
|
||||
q_nope = torch.randn(
|
||||
batch_size, seq_len, num_heads, qk_nope_head_dim, dtype=dtype, device=device
|
||||
)
|
||||
q_pe = torch.randn(
|
||||
batch_size, seq_len, num_heads, qk_rope_head_dim, dtype=dtype, device=device
|
||||
)
|
||||
compressed_kv = torch.randn(batch_size, seq_len, kv_lora_rank, dtype=dtype, device=device)
|
||||
kpe = torch.randn(batch_size, seq_len, 1, qk_rope_head_dim, dtype=dtype, device=device)
|
||||
kv_b_proj_weight = torch.randn(
|
||||
num_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank, dtype=dtype, device=device
|
||||
)
|
||||
|
||||
# Should not raise
|
||||
output = torch.ops.auto_deploy.torch_mla(
|
||||
q_nope, q_pe, compressed_kv, kpe, kv_b_proj_weight, True, None, "bsnd"
|
||||
)
|
||||
assert output is not None
|
||||
|
||||
def test_torch_cached_mla_callable(self):
|
||||
"""Test that torch_cached_mla_with_cache op is callable."""
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
dtype = torch.bfloat16
|
||||
|
||||
batch_size, seq_len, num_heads = 1, 1, 4
|
||||
qk_nope_head_dim, qk_rope_head_dim = 32, 16
|
||||
kv_lora_rank = 64
|
||||
v_head_dim = 32
|
||||
max_seq_len = 32
|
||||
|
||||
q_nope = torch.randn(
|
||||
batch_size, seq_len, num_heads, qk_nope_head_dim, dtype=dtype, device=device
|
||||
)
|
||||
q_pe = torch.randn(
|
||||
batch_size, seq_len, num_heads, qk_rope_head_dim, dtype=dtype, device=device
|
||||
)
|
||||
compressed_kv = torch.randn(batch_size, seq_len, kv_lora_rank, dtype=dtype, device=device)
|
||||
kpe = torch.randn(batch_size, seq_len, 1, qk_rope_head_dim, dtype=dtype, device=device)
|
||||
kv_b_proj_weight = torch.randn(
|
||||
num_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank, dtype=dtype, device=device
|
||||
)
|
||||
|
||||
batch_info_host = torch.tensor([0, 0, batch_size], dtype=torch.int32, device=device)
|
||||
seq_len_tensor = torch.tensor([seq_len], dtype=torch.int32, device=device)
|
||||
input_pos = torch.tensor([0], dtype=torch.int32, device=device)
|
||||
cache_loc = torch.tensor([0], dtype=torch.int32, device=device)
|
||||
cu_seqlen = torch.tensor([0], dtype=torch.int32, device=device)
|
||||
|
||||
mla_cache = torch.zeros(
|
||||
batch_size, max_seq_len, kv_lora_rank + qk_rope_head_dim, dtype=dtype, device=device
|
||||
)
|
||||
|
||||
# Should not raise
|
||||
output = torch.ops.auto_deploy.torch_cached_mla_with_cache(
|
||||
q_nope,
|
||||
q_pe,
|
||||
compressed_kv,
|
||||
kpe,
|
||||
kv_b_proj_weight,
|
||||
batch_info_host,
|
||||
seq_len_tensor,
|
||||
input_pos,
|
||||
cache_loc,
|
||||
cu_seqlen,
|
||||
mla_cache,
|
||||
None,
|
||||
kv_lora_rank,
|
||||
)
|
||||
assert output is not None
|
||||
|
||||
|
||||
class TestMoEOpRegistration:
|
||||
"""Test that MoE ops are properly registered."""
|
||||
|
||||
def test_torch_moe_registered(self):
|
||||
"""Test that torch_moe op is registered."""
|
||||
assert hasattr(torch.ops.auto_deploy, "torch_moe"), "torch_moe op should be registered"
|
||||
|
||||
|
||||
class TestCustomModelRegistration:
|
||||
"""Test that custom model is properly registered."""
|
||||
|
||||
def test_model_registered(self):
|
||||
"""Test that DeepSeekV3ForCausalLM is registered with the factory."""
|
||||
from tensorrt_llm._torch.auto_deploy.models.hf import AutoModelForCausalLMFactory
|
||||
|
||||
# Check that the model is registered in the custom model mapping
|
||||
assert "DeepseekV3Config" in AutoModelForCausalLMFactory._custom_model_mapping
|
||||
assert (
|
||||
AutoModelForCausalLMFactory._custom_model_mapping["DeepseekV3Config"]
|
||||
== DeepSeekV3ForCausalLM
|
||||
)
|
||||
@ -1,110 +0,0 @@
|
||||
"""Testing module patches that enable export of deepseek model."""
|
||||
|
||||
import types
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from test_common.llm_data import hf_id_to_local_model_dir
|
||||
from transformers import AutoConfig, AutoModelForCausalLM
|
||||
|
||||
from tensorrt_llm._torch.auto_deploy.models.patches.deepseek import (
|
||||
deepseek_v3_attention,
|
||||
deepseek_v3_moe_exact,
|
||||
)
|
||||
|
||||
|
||||
def _load_layer_from_model(model_name_or_path, layer_name):
|
||||
"""
|
||||
Loads a specific layer/module from a model without loading the entire model.
|
||||
|
||||
Parameters:
|
||||
model_name_or_path (str): Path or name of the pretrained model.
|
||||
layer_name (str): Name of the layer to extract.
|
||||
|
||||
Returns:
|
||||
module: The specified layer/module if available, otherwise None.
|
||||
"""
|
||||
try:
|
||||
# Load only the model configuration
|
||||
config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
|
||||
|
||||
# Load a subset of layers of the model and configure yarn
|
||||
config.num_hidden_layers = 1
|
||||
config.use_cache = False
|
||||
config.first_k_dense_replace = 0
|
||||
config.n_routed_experts = 2
|
||||
config.num_experts_per_tok = 1
|
||||
config.n_group = 1
|
||||
config.topk_group = 1
|
||||
config.hidden_size = 8
|
||||
config.moe_intermediate_size = 8
|
||||
config.num_attention_heads = 2
|
||||
config.num_key_value_heads = 2
|
||||
config.qk_nope_head_dim = 4
|
||||
config.qk_rope_head_dim = 2
|
||||
config.v_head_dim = 4
|
||||
config.intermediate_size = 8
|
||||
config.max_position_embeddings = 7
|
||||
|
||||
config.rope_scaling = None
|
||||
|
||||
# Build the model architecture (no weights loaded yet)
|
||||
model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
|
||||
model.eval()
|
||||
|
||||
# Access the specific layer by its name
|
||||
module = dict(model.named_modules()).get(layer_name)
|
||||
if module is None:
|
||||
print(f"Layer '{layer_name}' not found in the model.")
|
||||
else:
|
||||
print(f"Successfully extracted layer '{layer_name}'.")
|
||||
return module
|
||||
except Exception as e:
|
||||
print(f"Error extracting layer: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def _generate_ds_attention_mask(b, s):
|
||||
return torch.where(
|
||||
torch.tril(torch.full((s, s), float("-inf"))).unsqueeze(0).unsqueeze(0).expand(b, 1, s, s)
|
||||
== float("-inf"),
|
||||
torch.tensor(0.0),
|
||||
torch.tensor(float(-3.4028e38)),
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model_name, module_name, patch, inputs",
|
||||
[
|
||||
pytest.param(
|
||||
hf_id_to_local_model_dir("deepseek-ai/DeepSeek-R1"),
|
||||
"model.layers.0.self_attn",
|
||||
deepseek_v3_attention,
|
||||
[
|
||||
torch.randn(2, 6, 8, dtype=torch.bfloat16),
|
||||
_generate_ds_attention_mask(2, 6),
|
||||
torch.tensor([[0, 1, 2, 3, 4, 5]]),
|
||||
],
|
||||
), # attention requires inputs [hidden_states, attention_mask, position_ids]
|
||||
pytest.param(
|
||||
hf_id_to_local_model_dir("deepseek-ai/DeepSeek-R1"),
|
||||
"model.layers.0.mlp",
|
||||
deepseek_v3_moe_exact,
|
||||
[torch.randn(2, 6, 8, dtype=torch.bfloat16)],
|
||||
), # moe requires inputs [hidden_states]
|
||||
],
|
||||
)
|
||||
def test_module_patches(model_name, module_name, patch, inputs):
|
||||
# Get module
|
||||
module = _load_layer_from_model(model_name, module_name)
|
||||
|
||||
# Pass test inputs to generate reference
|
||||
ref, *_ = module(*inputs)
|
||||
|
||||
# Patch layer
|
||||
module.forward = types.MethodType(patch, module)
|
||||
|
||||
# Generate test output
|
||||
test, *_ = module(*inputs)
|
||||
|
||||
torch.allclose(ref, test, atol=0, rtol=0)
|
||||
@ -0,0 +1,652 @@
|
||||
"""Tests for GLM4 MoE Lite custom model implementation.
|
||||
|
||||
This module tests the custom GLM4 MoE Lite model implementation which uses
|
||||
auto_deploy custom ops (torch_mla, torch_moe, etc.) for export compatibility.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from torch.export import Dim
|
||||
|
||||
from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm
|
||||
from tensorrt_llm._torch.auto_deploy.models.custom.modeling_glm4_moe_lite import (
|
||||
Glm4MoeLiteConfig,
|
||||
Glm4MoeLiteForCausalLM,
|
||||
Glm4MoeLiteMLP,
|
||||
Glm4MoeLiteMoE,
|
||||
)
|
||||
from tensorrt_llm._torch.auto_deploy.utils._graph import move_to_device
|
||||
|
||||
_BATCH_AND_SEQUENCE_TEST_CASES = ((2, 6), (1, 8))
|
||||
|
||||
|
||||
@pytest.fixture(scope="function", autouse=True)
|
||||
def set_seed():
|
||||
    torch.manual_seed(42)


def _create_small_config() -> Glm4MoeLiteConfig:
    """Create a small GLM4 MoE Lite config for testing."""
    return Glm4MoeLiteConfig(
        vocab_size=1000,
        hidden_size=64,
        intermediate_size=128,
        num_hidden_layers=3,  # Layer 0 dense, layers 1-2 MoE
        num_attention_heads=4,
        num_key_value_heads=4,
        hidden_act="silu",
        max_position_embeddings=512,
        rms_norm_eps=1e-5,
        # MLA params (scaled down)
        q_lora_rank=32,
        kv_lora_rank=32,
        qk_nope_head_dim=8,
        qk_rope_head_dim=8,
        v_head_dim=16,
        # MoE params (scaled down)
        n_routed_experts=4,
        n_shared_experts=1,
        num_experts_per_tok=2,
        moe_intermediate_size=32,
        n_group=1,
        topk_group=1,
        routed_scaling_factor=1.0,
        norm_topk_prob=True,
        first_k_dense_replace=1,
        # RoPE
        rope_theta=10000.0,
        rope_scaling=None,
        # Other
        attention_bias=False,
        attention_dropout=0.0,
        pad_token_id=0,
    )


def _create_moe_layer(config: Glm4MoeLiteConfig) -> Glm4MoeLiteMoE:
    """Create a MoE layer from config."""
    moe = Glm4MoeLiteMoE(config)
    # Initialize gate weights with randn for reproducibility
    # (gate weight is initialized with torch.empty which isn't seeded)
    moe.gate.weight = torch.nn.Parameter(torch.randn_like(moe.gate.weight))
    return moe

@pytest.mark.parametrize("B,S", _BATCH_AND_SEQUENCE_TEST_CASES)
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@torch.no_grad()
def test_glm4_moe_lite_moe_layer(B, S, dtype):
    """Test that MoE layer produces valid output."""
    device = "cuda"
    config = _create_small_config()

    moe = _create_moe_layer(config)
    moe.to(device=device, dtype=dtype)
    moe.eval()

    H = config.hidden_size
    x = torch.randn(B, S, H, device=device, dtype=dtype)

    output = moe(x)

    # Check output shape matches input shape
    assert output.shape == x.shape, f"Expected shape {x.shape}, got {output.shape}"

    # Check output is not all zeros (MoE should transform the input)
    assert not torch.allclose(output, torch.zeros_like(output)), "Output should not be all zeros"

    # Check output doesn't have NaN or Inf values
    assert not torch.isnan(output).any(), "Output contains NaN values"
    assert not torch.isinf(output).any(), "Output contains Inf values"


@pytest.mark.parametrize("B,S", _BATCH_AND_SEQUENCE_TEST_CASES)
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@torch.no_grad()
def test_glm4_moe_lite_full_model(B, S, dtype):
    """Test that full model produces valid output."""
    device = "cuda"
    config = _create_small_config()

    model = Glm4MoeLiteForCausalLM(config)
    model.to(device=device, dtype=dtype)
    model.eval()

    # Create input
    input_ids = torch.randint(0, config.vocab_size, (B, S), device=device)
    position_ids = torch.arange(S, device=device).unsqueeze(0).expand(B, -1)

    output = model(input_ids=input_ids, position_ids=position_ids)

    # Check output shape
    assert output.logits.shape == (B, S, config.vocab_size), (
        f"Expected logits shape {(B, S, config.vocab_size)}, got {output.logits.shape}"
    )

    # Check output doesn't have NaN or Inf values
    assert not torch.isnan(output.logits).any(), "Logits contain NaN values"
    assert not torch.isinf(output.logits).any(), "Logits contain Inf values"

def test_glm4_moe_lite_model_can_be_exported():
    """Test that the custom model can be exported with torch_export_to_gm.

    This test verifies:
    1. The model exports successfully without graph breaks
    2. The exported graph module produces outputs with correct shape
    3. The outputs contain finite values (no NaN/Inf)

    Note: We don't test numerical equivalence between original and exported model
    here because torch.export lifts parameters into the graph, creating a different
    parameter structure that doesn't match load_state_dict. The numerical correctness
    of the model itself is already validated by test_glm4_moe_lite_full_model.
    """
    device = "cuda"
    dtype = torch.bfloat16
    config = _create_small_config()

    model = Glm4MoeLiteForCausalLM(config)
    model.to(device=device, dtype=dtype)
    model.eval()

    # Create input
    B, S = 2, 8
    input_ids = torch.randint(0, config.vocab_size, (B, S), device=device)
    position_ids = torch.arange(S, device=device).unsqueeze(0).expand(B, -1)

    # Define dynamic shapes
    batch_size_dynamic = Dim.DYNAMIC
    seq_len_dynamic = Dim.DYNAMIC
    dynamic_shapes = (
        {0: batch_size_dynamic, 1: seq_len_dynamic},
        {0: batch_size_dynamic, 1: seq_len_dynamic},
    )

    # Export the model - this is the main test: verify no graph breaks
    gm = torch_export_to_gm(
        model,
        args=tuple(),
        kwargs={"input_ids": input_ids, "position_ids": position_ids},
        dynamic_shapes=dynamic_shapes,
    )

    # Move graph module to device
    move_to_device(gm, device)

    # Verify the exported model produces valid output
    with torch.inference_mode():
        out_gm = gm(input_ids=input_ids, position_ids=position_ids)

    # Check output structure and shape
    assert "logits" in out_gm, "Output should contain 'logits' key"
    logits = out_gm["logits"]
    assert logits.shape == (B, S, config.vocab_size), (
        f"Expected shape {(B, S, config.vocab_size)}, got {logits.shape}"
    )
    assert torch.isfinite(logits).all(), "Logits should not contain NaN or Inf"

    # Test with different input shape to verify dynamic shapes work
    B2, S2 = 1, 4
    input_ids2 = torch.randint(0, config.vocab_size, (B2, S2), device=device)
    position_ids2 = torch.arange(S2, device=device).unsqueeze(0).expand(B2, -1)

    with torch.inference_mode():
        out_gm2 = gm(input_ids=input_ids2, position_ids=position_ids2)

    logits2 = out_gm2["logits"]
    expected_shape = (B2, S2, config.vocab_size)
    assert logits2.shape == expected_shape, (
        f"Dynamic shape test failed: expected {expected_shape}, got {logits2.shape}"
    )
    assert torch.isfinite(logits2).all(), "Logits should not contain NaN or Inf"

def test_glm4_moe_lite_config_registration():
    """Test that the config is properly registered or model_type is correct."""
    # Create a config and verify model_type
    config = _create_small_config()
    assert config.model_type == "glm4_moe_lite"

    # Verify our config class can be instantiated with expected attributes
    assert hasattr(config, "hidden_size")
    assert hasattr(config, "num_attention_heads")
    assert hasattr(config, "n_routed_experts")
    assert hasattr(config, "kv_lora_rank")
    assert hasattr(config, "qk_rope_head_dim")


def test_glm4_moe_lite_layer_types():
    """Test that layer 0 uses dense MLP and later layers use MoE."""
    config = _create_small_config()
    model = Glm4MoeLiteForCausalLM(config)

    # Check layer 0 (should be dense MLP, not MoE)
    layer0_mlp = model.model.layers[0].mlp
    assert type(layer0_mlp).__name__ == "Glm4MoeLiteMLP", (
        f"Layer 0 should use Glm4MoeLiteMLP, got {type(layer0_mlp).__name__}"
    )

    # Check layer 1+ (should be MoE)
    for i in range(1, config.num_hidden_layers):
        layer_mlp = model.model.layers[i].mlp
        assert type(layer_mlp).__name__ == "Glm4MoeLiteMoE", (
            f"Layer {i} should use Glm4MoeLiteMoE, got {type(layer_mlp).__name__}"
        )


def test_glm4_moe_lite_expert_structure():
    """Test that experts have correct structure for checkpoint loading."""
    config = _create_small_config()
    moe = Glm4MoeLiteMoE(config)

    # Check that experts is a ModuleList
    assert isinstance(moe.experts, torch.nn.ModuleList), "experts should be nn.ModuleList"

    # Check number of experts
    assert len(moe.experts) == config.n_routed_experts, (
        f"Expected {config.n_routed_experts} experts, got {len(moe.experts)}"
    )

    # Check each expert has the correct structure
    for i, expert in enumerate(moe.experts):
        assert hasattr(expert, "gate_proj"), f"Expert {i} missing gate_proj"
        assert hasattr(expert, "up_proj"), f"Expert {i} missing up_proj"
        assert hasattr(expert, "down_proj"), f"Expert {i} missing down_proj"

    # Check state_dict keys match expected checkpoint format
    state_dict = moe.state_dict()
    expected_keys = [
        "experts.0.gate_proj.weight",
        "experts.0.up_proj.weight",
        "experts.0.down_proj.weight",
    ]
    for key in expected_keys:
        assert key in state_dict, (
            f"Expected key '{key}' in state_dict, got keys: {list(state_dict.keys())[:10]}..."
        )


# =============================================================================
# Numerical Equivalence Tests
# These tests compare our custom implementation against the HF implementation
# =============================================================================

def _get_hf_model_class():
    """Get the HF model class for GLM4 MoE Lite.

    Returns None if transformers doesn't have glm4_moe_lite (older versions).
    """
    try:
        from transformers.models.glm4_moe_lite.modeling_glm4_moe_lite import (
            Glm4MoeLiteForCausalLM as HFGlm4MoeLiteForCausalLM,
        )

        return HFGlm4MoeLiteForCausalLM
    except ImportError:
        return None


def _get_hf_moe_class():
    """Get the HF MoE class for GLM4 MoE Lite."""
    try:
        from transformers.models.glm4_moe_lite.modeling_glm4_moe_lite import (
            Glm4MoeLiteMoE as HFGlm4MoeLiteMoE,
        )

        return HFGlm4MoeLiteMoE
    except ImportError:
        return None


def _get_hf_attention_class():
    """Get the HF Attention class for GLM4 MoE Lite."""
    try:
        from transformers.models.glm4_moe_lite.modeling_glm4_moe_lite import (
            Glm4MoeLiteAttention as HFGlm4MoeLiteAttention,
        )

        return HFGlm4MoeLiteAttention
    except ImportError:
        return None


def _get_hf_mlp_class():
    """Get the HF MLP class for GLM4 MoE Lite."""
    try:
        from transformers.models.glm4_moe_lite.modeling_glm4_moe_lite import (
            Glm4MoeLiteMLP as HFGlm4MoeLiteMLP,
        )

        return HFGlm4MoeLiteMLP
    except ImportError:
        return None


def _get_hf_config_class():
    """Get the HF Config class for GLM4 MoE Lite."""
    try:
        from transformers.models.glm4_moe_lite.configuration_glm4_moe_lite import (
            Glm4MoeLiteConfig as HFGlm4MoeLiteConfig,
        )

        return HFGlm4MoeLiteConfig
    except ImportError:
        return None

def _convert_hf_moe_state_dict_to_custom(hf_state_dict: dict, n_experts: int) -> dict:
    """Convert HF MoE state dict to our custom format.

    HF format (stacked):
        experts.gate_up_proj: [n_experts, 2 * intermediate_size, hidden_size]
        experts.down_proj: [n_experts, hidden_size, intermediate_size]

    Custom format (per-expert):
        experts.0.gate_proj.weight: [intermediate_size, hidden_size]
        experts.0.up_proj.weight: [intermediate_size, hidden_size]
        experts.0.down_proj.weight: [hidden_size, intermediate_size]
    """
    custom_state_dict = {}

    for key, value in hf_state_dict.items():
        if key == "experts.gate_up_proj":
            # Split stacked gate_up into individual gate and up per expert
            # Shape: [n_experts, 2 * intermediate_size, hidden_size]
            intermediate_size = value.shape[1] // 2
            for i in range(n_experts):
                gate_up = value[i]  # [2 * intermediate_size, hidden_size]
                gate_weight = gate_up[:intermediate_size]  # [intermediate_size, hidden_size]
                up_weight = gate_up[intermediate_size:]  # [intermediate_size, hidden_size]
                custom_state_dict[f"experts.{i}.gate_proj.weight"] = gate_weight
                custom_state_dict[f"experts.{i}.up_proj.weight"] = up_weight
        elif key == "experts.down_proj":
            # Split stacked down into individual down per expert
            # Shape: [n_experts, hidden_size, intermediate_size]
            for i in range(n_experts):
                custom_state_dict[f"experts.{i}.down_proj.weight"] = value[i]
        else:
            # Copy other keys as-is
            custom_state_dict[key] = value

    return custom_state_dict


def _convert_hf_full_model_state_dict_to_custom(hf_state_dict: dict, config) -> dict:
    """Convert full HF model state dict to custom format.

    Handles MoE expert weight conversion for all MoE layers.
    """
    custom_state_dict = {}
    n_experts = config.n_routed_experts

    for key, value in hf_state_dict.items():
        # Check if this is an MoE expert weight
        if ".mlp.experts.gate_up_proj" in key:
            # Extract layer prefix (e.g., "model.layers.1.mlp.")
            prefix = key.replace("experts.gate_up_proj", "")
            intermediate_size = value.shape[1] // 2
            for i in range(n_experts):
                gate_up = value[i]
                gate_weight = gate_up[:intermediate_size]
                up_weight = gate_up[intermediate_size:]
                custom_state_dict[f"{prefix}experts.{i}.gate_proj.weight"] = gate_weight
                custom_state_dict[f"{prefix}experts.{i}.up_proj.weight"] = up_weight
        elif ".mlp.experts.down_proj" in key:
            prefix = key.replace("experts.down_proj", "")
            for i in range(n_experts):
                custom_state_dict[f"{prefix}experts.{i}.down_proj.weight"] = value[i]
        else:
            # Copy other keys as-is
            custom_state_dict[key] = value

    return custom_state_dict

def _create_hf_config():
    """Create HF config that matches our test config."""
    HFConfig = _get_hf_config_class()
    if HFConfig is None:
        return None

    config = HFConfig(
        vocab_size=1000,
        hidden_size=64,
        intermediate_size=128,
        num_hidden_layers=3,
        num_attention_heads=4,
        num_key_value_heads=4,
        hidden_act="silu",
        max_position_embeddings=512,
        rms_norm_eps=1e-5,
        # MLA params
        q_lora_rank=32,
        kv_lora_rank=32,
        qk_nope_head_dim=8,
        qk_rope_head_dim=8,
        v_head_dim=16,
        # MoE params
        n_routed_experts=4,
        n_shared_experts=1,
        num_experts_per_tok=2,
        moe_intermediate_size=32,
        n_group=1,
        topk_group=1,
        routed_scaling_factor=1.0,
        norm_topk_prob=True,
        first_k_dense_replace=1,
        # RoPE
        rope_theta=10000.0,
        rope_scaling=None,
        # Other
        attention_bias=False,
        attention_dropout=0.0,
        pad_token_id=0,
    )

    # Set internal attributes needed by HF's MoE implementation
    # _experts_implementation tells HF which expert forward function to use
    config._experts_implementation = "eager"

    return config

@pytest.mark.parametrize("B,S", _BATCH_AND_SEQUENCE_TEST_CASES)
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@torch.no_grad()
def test_glm4_moe_lite_moe_numerical_equivalence(B, S, dtype):
    """Test MoE layer produces numerically equivalent output to HF implementation."""
    HFMoE = _get_hf_moe_class()
    if HFMoE is None:
        pytest.skip("transformers doesn't have glm4_moe_lite (requires v5.0+)")

    device = "cuda"
    config = _create_small_config()
    hf_config = _create_hf_config()

    # Create HF MoE
    hf_moe = HFMoE(hf_config)
    hf_moe.to(device=device, dtype=dtype)
    hf_moe.eval()

    # Initialize weights with randn for reproducibility
    # (HF uses torch.empty which may be zeros, causing NaN in computation)
    hf_moe.gate.weight = torch.nn.Parameter(torch.randn_like(hf_moe.gate.weight))
    hf_moe.experts.gate_up_proj = torch.nn.Parameter(torch.randn_like(hf_moe.experts.gate_up_proj))
    hf_moe.experts.down_proj = torch.nn.Parameter(torch.randn_like(hf_moe.experts.down_proj))

    # Create custom MoE and load converted weights
    custom_moe = Glm4MoeLiteMoE(config)
    custom_moe.to(device=device, dtype=dtype)

    # Convert HF stacked expert weights to our per-expert format
    hf_state_dict = hf_moe.state_dict()

    # Debug: print state dict keys and shapes
    print("\n=== HF MoE state_dict keys and shapes ===")
    for k, v in hf_state_dict.items():
        print(f"  {k}: {v.shape}")

    custom_state_dict = _convert_hf_moe_state_dict_to_custom(hf_state_dict, config.n_routed_experts)

    print("\n=== Converted custom state_dict keys and shapes ===")
    for k, v in custom_state_dict.items():
        print(f"  {k}: {v.shape}")

    print("\n=== Expected custom MoE state_dict keys ===")
    for k, v in custom_moe.state_dict().items():
        print(f"  {k}: {v.shape}")

    custom_moe.load_state_dict(custom_state_dict)
    custom_moe.eval()

    # Sanity check: verify expert weights match after conversion
    # HF has stacked weights: experts.gate_up_proj [n_experts, 2*intermediate, hidden]
    # Our model has per-expert: experts.{i}.gate_proj.weight, experts.{i}.up_proj.weight
    hf_gate_up = hf_moe.experts.gate_up_proj  # [n_experts, 2*intermediate, hidden]
    hf_down = hf_moe.experts.down_proj  # [n_experts, hidden, intermediate]
    intermediate_size = config.moe_intermediate_size

    print(f"\n=== Debug: intermediate_size = {intermediate_size} ===")
    print(f"hf_gate_up shape: {hf_gate_up.shape}")
    print(f"hf_gate_up[0, :2, :2]: {hf_gate_up[0, :2, :2]}")

    # Get the converted state dict values for comparison
    converted_gate_0 = custom_state_dict["experts.0.gate_proj.weight"]
    print(f"converted_gate_0 shape: {converted_gate_0.shape}")
    print(f"converted_gate_0[:2, :2]: {converted_gate_0[:2, :2]}")

    # After load_state_dict
    loaded_gate_0 = custom_moe.experts[0].gate_proj.weight
    print(f"loaded_gate_0 shape: {loaded_gate_0.shape}")
    print(f"loaded_gate_0[:2, :2]: {loaded_gate_0[:2, :2]}")

    for i in range(config.n_routed_experts):
        # Check gate_proj
        hf_gate = hf_gate_up[i, :intermediate_size, :]  # [intermediate, hidden]
        custom_gate = custom_moe.experts[i].gate_proj.weight
        torch.testing.assert_close(
            custom_gate, hf_gate, msg=f"Expert {i} gate_proj weights don't match"
        )

        # Check up_proj
        hf_up = hf_gate_up[i, intermediate_size:, :]  # [intermediate, hidden]
        custom_up = custom_moe.experts[i].up_proj.weight
        torch.testing.assert_close(custom_up, hf_up, msg=f"Expert {i} up_proj weights don't match")

        # Check down_proj
        hf_down_i = hf_down[i]  # [hidden, intermediate]
        custom_down = custom_moe.experts[i].down_proj.weight
        torch.testing.assert_close(
            custom_down, hf_down_i, msg=f"Expert {i} down_proj weights don't match"
        )

    # Also verify gate weights match
    torch.testing.assert_close(
        custom_moe.gate.weight, hf_moe.gate.weight, msg="Gate weights don't match"
    )

    # Create input
    H = config.hidden_size
    x = torch.randn(B, S, H, device=device, dtype=dtype)

    # Run both
    hf_out = hf_moe(x)
    custom_out = custom_moe(x)

    # Handle tuple output from HF (output, router_logits)
    if isinstance(hf_out, tuple):
        hf_out = hf_out[0]

    # Compare
    rtol, atol = 0.05, 0.05
    torch.testing.assert_close(custom_out, hf_out, rtol=rtol, atol=atol)

@pytest.mark.parametrize("B,S", _BATCH_AND_SEQUENCE_TEST_CASES)
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@torch.no_grad()
def test_glm4_moe_lite_mlp_numerical_equivalence(B, S, dtype):
    """Test MLP layer produces numerically equivalent output to HF implementation."""
    HFMLP = _get_hf_mlp_class()
    if HFMLP is None:
        pytest.skip("transformers doesn't have glm4_moe_lite (requires v5.0+)")

    device = "cuda"
    config = _create_small_config()
    hf_config = _create_hf_config()

    # Create HF MLP
    hf_mlp = HFMLP(hf_config)
    hf_mlp.to(device=device, dtype=dtype)
    hf_mlp.eval()

    # Create custom MLP and load same weights
    custom_mlp = Glm4MoeLiteMLP(config)
    custom_mlp.to(device=device, dtype=dtype)
    custom_mlp.load_state_dict(hf_mlp.state_dict())
    custom_mlp.eval()

    # Create input
    H = config.hidden_size
    x = torch.randn(B, S, H, device=device, dtype=dtype)

    # Run both
    hf_out = hf_mlp(x)
    custom_out = custom_mlp(x)

    # Compare
    rtol, atol = 1e-3, 1e-3
    torch.testing.assert_close(custom_out, hf_out, rtol=rtol, atol=atol)

@pytest.mark.parametrize("B,S", _BATCH_AND_SEQUENCE_TEST_CASES)
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@torch.no_grad()
def test_glm4_moe_lite_full_model_numerical_equivalence(B, S, dtype):
    """Test full model produces numerically equivalent output to HF implementation."""
    HFModel = _get_hf_model_class()
    if HFModel is None:
        pytest.skip("transformers doesn't have glm4_moe_lite (requires v5.0+)")

    device = "cuda"
    config = _create_small_config()
    hf_config = _create_hf_config()

    # Create HF model
    hf_model = HFModel(hf_config)
    hf_model.to(device=device, dtype=dtype)
    hf_model.eval()

    # Initialize all gate weights for reproducibility
    for module in hf_model.modules():
        if hasattr(module, "gate") and hasattr(module.gate, "weight"):
            module.gate.weight = torch.nn.Parameter(torch.randn_like(module.gate.weight))

    # Create custom model and load converted weights
    custom_model = Glm4MoeLiteForCausalLM(config)
    custom_model.to(device=device, dtype=dtype)

    # Convert HF stacked expert weights to our per-expert format
    hf_state_dict = hf_model.state_dict()
    custom_state_dict = _convert_hf_full_model_state_dict_to_custom(hf_state_dict, config)
    custom_model.load_state_dict(custom_state_dict)
    custom_model.eval()

    # Create input
    input_ids = torch.randint(0, config.vocab_size, (B, S), device=device)
    position_ids = torch.arange(S, device=device).unsqueeze(0).expand(B, -1)

    # Run both
    hf_out = hf_model(input_ids=input_ids, position_ids=position_ids)
    custom_out = custom_model(input_ids=input_ids, position_ids=position_ids)

    # Compare logits - cast to same dtype for comparison
    # (HF model may output float32 for numerical stability in lm_head)
    rtol, atol = 0.05, 0.05
    torch.testing.assert_close(
        custom_out.logits.float(),
        hf_out.logits.float(),
        rtol=rtol,
        atol=atol,
    )
@ -1,97 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

import tensorrt_llm._torch.auto_deploy  # noqa: F401


def naive_attn(q_nope, q_pe, compressed_kv, k_pe, wkv_b, softmax_scale):
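    """Reference attention: up-project compressed_kv to per-head K/V and run standard SDPA."""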
    bsz = q_nope.shape[0]
    q_len = q_nope.shape[2]
    num_heads = q_nope.shape[1]
    qk_nope_head_dim = q_nope.shape[-1]
    v_head_dim = wkv_b.weight.shape[-1] - qk_nope_head_dim
    qk_head_dim = qk_nope_head_dim + q_pe.shape[-1]
    k_pe = k_pe.view(bsz, q_len, 1, q_pe.shape[-1]).transpose(1, 2)

    # Up project compressed_kv
    kv = (
        wkv_b(compressed_kv)
        .view(bsz, q_len, num_heads, qk_nope_head_dim + v_head_dim)
        .transpose(1, 2)
    )

    k_nope, value_states = torch.split(kv, [qk_nope_head_dim, v_head_dim], dim=-1)

    query_states = k_pe.new_empty(bsz, num_heads, q_len, qk_head_dim)
    query_states[:, :, :, :qk_nope_head_dim] = q_nope
    query_states[:, :, :, qk_nope_head_dim:] = q_pe

    key_states = k_pe.new_empty(bsz, num_heads, q_len, qk_head_dim)
    key_states[:, :, :, :qk_nope_head_dim] = k_nope
    key_states[:, :, :, qk_nope_head_dim:] = k_pe

    x = F.scaled_dot_product_attention(
        query_states,
        key_states,
        value_states,
        attn_mask=None,
        is_causal=False,
        dropout_p=0.0,
        scale=softmax_scale,
    ).transpose(1, 2)

    return x


def mla_attn(q_nope, q_pe, compressed_kv, k_pe, wkv_b, softmax_scale):
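    """MLA attention via the auto_deploy torch_attention_deepseek_mla reference op.

    q_nope is first projected into the compressed KV space with the no-PE part of
    wkv_b, and the op's output is up-projected back to v_head_dim with its value part.
    """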
    num_heads = q_nope.shape[1]
    qk_nope_head_dim = q_nope.shape[-1]
    v_head_dim = wkv_b.weight.shape[-1] - qk_nope_head_dim
    kv_lora_rank = compressed_kv.shape[-1]

    # Down project q_nope
    wkv_b_weight = wkv_b.weight.view(num_heads, -1, kv_lora_rank)
    q_nope_proj = torch.einsum("bhsd,hdc->bhsc", q_nope, wkv_b_weight[:, :qk_nope_head_dim])

    # MLA ref operation
    x = torch.ops.auto_deploy.torch_attention_deepseek_mla(
        q_nope_proj, q_pe, compressed_kv, k_pe, None, softmax_scale
    )

    # Up project attention scores
    x = torch.einsum("bshc,hdc->bshd", x, wkv_b_weight[:, -v_head_dim:])
    return x


def test_attn():
    # Define test configurations
    kv_lora_rank = 4
    bsz = 2
    q_len = 6
    v_head_dim = 2
    qk_nope_head_dim = 2
    qk_rope_head_dim = 1
    qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
    num_heads = 4
    softmax_scale = qk_head_dim**-0.5

    # Generate inputs
    q_nope = torch.randn(bsz, num_heads, q_len, qk_nope_head_dim)
    q_pe = torch.randn(bsz, num_heads, q_len, qk_rope_head_dim)
    compressed_kv = torch.randn(bsz, q_len, kv_lora_rank)
    k_pe = torch.randn(bsz, q_len, qk_rope_head_dim)

    # Define w_kv_b projection matrix
    wkv_b = nn.Linear(
        kv_lora_rank, num_heads * (qk_head_dim - qk_rope_head_dim + v_head_dim), bias=False
    )

    # Compute naive attention
    out_naive = naive_attn(q_nope, q_pe, compressed_kv, k_pe, wkv_b, softmax_scale)

    # Compute MLA attention
    out_mla = mla_attn(q_nope, q_pe, compressed_kv, k_pe, wkv_b, softmax_scale)

    # Check if the two outputs are close
    assert torch.allclose(out_naive, out_mla, rtol=1e-5, atol=1e-5)