[TRTLLM-10195][feat] K-EXAONE support (#10355)

Signed-off-by: Jaedeok Kim <jaedeokk@nvidia.com>
Signed-off-by: yechank <161688079+yechank-nvidia@users.noreply.github.com>
Co-authored-by: Jaedeok Kim <jaedeokk@nvidia.com>
This commit is contained in:
Yechan Kim 2026-01-12 00:29:51 +09:00 committed by GitHub
parent 80649a8b78
commit 8e0d20d901
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 1331 additions and 64 deletions

View File

@ -1,9 +1,6 @@
# EXAONE
This document shows how to build and run a [EXAONE](https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct) model in TensorRT-LLM.
The TensorRT LLM EXAONE implementation is based on the LLaMA model. The implementation can be found in [llama/model.py](../../../../tensorrt_llm/models/llama/model.py).
See the LLaMA example [`examples/models/core/llama`](../llama) for details.
This document shows how to build and run [EXAONE](https://huggingface.co/LGAI-EXAONE) models in TensorRT-LLM.
- [EXAONE](#exaone)
- [Support Matrix](#support-matrix)
@ -11,31 +8,51 @@ See the LLaMA example [`examples/models/core/llama`](../llama) for details.
- [EXAONE-3.0](#exaone-30)
- [EXAONE-Deep](#exaone-deep)
- [EXAONE-4.0](#exaone-40)
- [Usage](#usage)
- [PyTorch flow](#pytorch-flow)
-[PyTorch flow Quantization](#pytorch-flow-quantization)
- [TRT Flow](#trt-flow)
- [K-EXAONE](#k-exaone)
- [PyTorch flow](#pytorch-flow)
- [Running EXAONE-4.0](#running-exaone-40)
- [Running K-EXAONE](#running-k-exaone)
- [MoE Backend Options](#moe-backend-options)
- [PyTorch flow Quantization](#pytorch-flow-quantization)
- [FP8 Quantization](#fp8-quantization)
- [NVFP4 Quantization](#nvfp4-quantization)
- [Running the TensorRT LLM Server](#running-the-tensorrt-llm-server)
- [Running Aggregated TensorRT LLM Server](#running-aggregated-tensorrt-llm-server)
- [Creating the Extra Options Configuration](#creating-the-extra-options-configuration)
- [Launch trtllm-serve OpenAI-compatible API server](#launch-trtllm-serve-openai-compatible-api-server)
- [Running Disaggregated TensorRT LLM Server](#running-disaggregated-tensorrt-llm-server)
- [Step 1: Set Environment Variables](#step-1-set-environment-variables)
- [Step 2: Create Configuration Files](#step-2-create-configuration-files)
- [Step 3: Launch the Disaggregated Server](#step-3-launch-the-disaggregated-server)
- [TRT flow](#trt-flow)
- [Convert checkpoint and build TensorRT engine(s)](#convert-checkpoint-and-build-tensorrt-engines)
- [FP8 Post-Training Quantization](#fp8-post-training-quantization)
- [SmoothQuant](#smoothquant)
- [Groupwise quantization (AWQ)](#groupwise-quantization-awq)
- [W4A16 AWQ with FP8 GEMM (W4A8 AWQ)](#w4a16-awq-with-fp8-gemm-w4a8-awq)
- [W4A16 AWQ with FP8 GEMM (W4A8 AWQ)](#w4a16-awq-with-fp8-gemm-w4a8-awq)
- [Run Engine](#run-engine)
- [Troubleshootings](#troubleshootings)
- [Troubleshootings for EXAONE-4.0](#troubleshootings-for-exaone-40)
- [Troubleshootings for K-EXAONE](#troubleshootings-for-k-exaone)
## Support Matrix
* FP16
* BF16
* Tensor Parallel
* Tensor Parallel (TP)
* Expert Parallel (EP) (K-EXAONE only)
* Attention Data Parallel (ADP) (K-EXAONE only)
* Disaggregated Serving
* FP8
* INT8 & INT4 Weight-Only
* INT8 SmoothQuant
* INT4 AWQ & W4A8 AWQ
* NVFP4 (K-EXAONE only)
## Supported Models
**Note:**
- **EXAONE-3.0** and **EXAONE-Deep** are supported using the [TRT Flow](#trt-flow).
- **EXAONE-4.0** is supported using the [PyTorch flow](#pytorch-flow).
**Note:**
- **EXAONE-3.0** & **EXAONE-Deep** are supported using the [TRT Flow](#trt-flow).
- **EXAONE-4.0** & **K-EXAONE** are supported using the [PyTorch flow](#pytorch-flow).
Please refer to the corresponding sections below for usage instructions and examples for each model.
@ -59,23 +76,33 @@ git clone https://huggingface.co/LGAI-EXAONE/EXAONE-Deep-2.4B $HF_MODEL_DIR
### EXAONE-4.0
Download he HuggingFace checkpoints of EXAONE-4.0 model. Here, we only use the `EXAONE-4.0-32B` model for the example. From EXAONE-4.0 model, we support only on PyTorch flow.
Download the HuggingFace checkpoints of the EXAONE-4.0 model. Here, we use the `EXAONE-4.0-32B` model as an example. EXAONE-4.0 is supported only via the PyTorch flow.
```bash
export HF_MODEL_DIR=hf_models/exaone4
git clone https://huggingface.co/LGAI-EXAONE/EXAONE-4.0-32B $HF_MODEL_DIR
```
### Pytorch flow
### K-EXAONE
K-EXAONE is a Mixture of Experts (MoE) model based on the EXAONE architecture. It features a hybrid architecture with both dense and MoE layers, sliding window attention, and supports FP8 and NVFP4 quantization for efficient inference.
Download the HuggingFace checkpoints of the K-EXAONE model:
```bash
export HF_MODEL_DIR=hf_models/kexaone
git clone https://huggingface.co/LGAI-EXAONE/K-EXAONE-236B-A23B $HF_MODEL_DIR
```
## PyTorch flow
### Running EXAONE-4.0
To quickly run EXAONE-4.0 models, you can use [examples/llm-api/quickstart_advanced.py](../../../llm-api/quickstart_advanced.py):
```bash
python ../../../llm-api/quickstart_advanced.py --model_dir hf_models/$MODEL_NAME --disable_kv_cache_reuse
python ../../../llm-api/quickstart_advanced.py --model_dir $HF_MODEL_DIR
```
SWA currently does not support kv_cache_reuse. Please make sure to disable KV cache reuse when running with SWA.
The output will be like:
```bash
[0] Prompt: 'Hello, my name is', Generated text: " [Your Name], and I'm a [Your Profession]. I'm here to learn and share with you.\n\nBest regards,\n[Your Name]\n\nThis letter is concise, professional, and clearly states who you are and what you're here for. It's a good starting point"
@ -83,47 +110,239 @@ The output will be like:
[2] Prompt: 'The future of AI is', Generated text: ' not just about technology but also about how we choose to use it. We must ensure that AI is developed and deployed in a way that benefits all of humanity, not just a select few. This means prioritizing ethical considerations, transparency, and accountability in AI development. It also means involving diverse stakeholders in the conversation about AI'
```
#### PyTorch flow Quantization
### Running K-EXAONE
For PyTorch flow, TRT-LLM supports quantized format generated by [Model Optimizer](https://github.com/NVIDIA/Model-Optimizer).
You can either do pre-quantized models in HF model hub, or can generate quantized model by yourself and then run models with below command:
K-EXAONE is a Mixture of Experts model that benefits from multiple parallelism strategies. You can run it with tensor parallelism (TP), expert parallelism (EP), and attention data parallelism (ADP):
```bash
git clone https://github.com/NVIDIA/Model-Optimizer.git
python ../../../llm-api/quickstart_advanced.py \
--model_dir $HF_MODEL_DIR \
--tp_size 8 \
--moe_ep_size 8 \
--enable_attention_dp \
--trust_remote_code
```
The output will be like:
```bash
[0] Prompt: 'Hello, my name is', Generated text: ' John Smith, and I am a 28-year-old software developer. I live in the city of San Francisco, California. I work remotely for a tech startup based in Austin, Texas.\n\nI enjoy hiking, reading, and playing the piano. In my free time, I often explore new neighborhoods in San Francisco, trying out new restaurants and cafes.\n\n'
[1] Prompt: 'The capital of France is', Generated text: ' Paris, the capital of France is Paris, the capital of France is Paris, the capital of France is Paris, the capital of France is Paris, the capital of France is Paris, the capital of France is Paris, the capital of France is Paris, the capital of France is Paris, the capital of France is Paris'
[2] Prompt: 'The future of AI is', Generated text: ' bright.\n</think>\n\nThe future of AI holds immense promise across numerous domains. In healthcare, AI is revolutionizing diagnostics, drug discovery, and personalized treatment plans. In education, AI is enabling adaptive learning platforms that cater to individual learning styles and paces. In environmental science, AI is playing a pivotal role in addressing climate change by optimizing'
```
#### MoE Backend Options
K-EXAONE supports the following MoE backends:
| Backend | Description |
|---------|-------------|
| `CUTLASS` | Default backend, optimized for general use cases |
| `TRTLLM` | TensorRT-LLM backend using TRT-LLM Gen kernels, optimized for low-latency inference |
| `WIDEEP` | Wide expert parallelism backend for cases where EP size exceeds the number of experts |
You can specify the MoE backend using the `--moe_backend` argument:
```bash
python ../../../llm-api/quickstart_advanced.py \
--model_dir $HF_MODEL_DIR \
--tp_size 8 \
--moe_ep_size 8 \
--enable_attention_dp \
--moe_backend CUTLASS \
--trust_remote_code
```
### PyTorch flow Quantization
For PyTorch flow, TRT-LLM supports quantized formats generated by [Model Optimizer](https://github.com/NVIDIA/Model-Optimizer). You can either use pre-quantized models from the HuggingFace model hub, or generate quantized models yourself using the instructions below.
First, clone the [Model Optimizer](https://github.com/NVIDIA/Model-Optimizer) repository:
```bash
git clone https://github.com/NVIDIA/Model-Optimizer
cd Model-Optimizer/examples/llm_ptq
scripts/huggingface_example.sh --model hf_models/$MODEL_NAME --quant fp8 --export_fmt hf
```
For more information, please refer to official [docs](https://github.com/NVIDIA/Model-Optimizer) or [Model Optimizer](https://github.com/NVIDIA/Model-Optimizer).
For more information, please refer to the official [Model Optimizer documentation](https://github.com/NVIDIA/Model-Optimizer).
Troubleshooting
#### FP8 Quantization
FP8 quantization provides a good balance between model accuracy and inference performance. To quantize a model to FP8 format:
The following error may occur during quantization:
```bash
torch._dynamo.exc.Unsupported: Graph break under GenericContextWrappingVariable
Explanation: Attempted to graph break in an active context manager(s) that doesn't support graph breaking.
Hint: Move the offending context manager(s) to outside the compiled region.
Hint: This graph break may have been caused by an earlier graph break. Resolving the earlier graph break may resolve this one.
python3 hf_ptq.py --model $HF_MODEL_DIR --quant fp8 --export_fmt hf
```
This error may indicate an incompatibility between `torch.compile()` and the `HybridCache` module of the transformers library. As a result, [Model Optimizer](https://github.com/NVIDIA/Model-Optimizer) (ModelOpt) cannot perform PTQ with HybridCache.
#### NVFP4 Quantization
Temporarily switching to `DynamicCache` when creating PTQ models could help address the issue. This can be done by updating the `cache_implementation` field in the `generation_config.json` file located in the model checkpoint directory, for example:
```json
# generation_config.json
{
// Change "hybrid" to "dynamic" to run PTQ.
// Revert this to "hybrid" after quantization is complete.
"cache_implementation": "hybrid",
...
}
NVFP4 (4-bit floating point) quantization enables memory-efficient inference with reduced GPU memory footprint. To quantize a model to NVFP4 format:
```bash
python3 hf_ptq.py --model $HF_MODEL_DIR --quant nvfp4 --export_fmt hf
```
For models with sliding window attention, DynamicCache is less memory-efficient than HybridCache because it retains the entire key-value cache. However, this does not break the model's attention logic, as the cache implementation is separated from the attention computation itself. This trade-off is acceptable for the PTQ process, which is a one-time procedure. Our tests confirm that this workaround does not degrade accuracy on MMLU or GSM8K benchmarks with the default ModelOpt settings.
### TRT flow
## Running the TensorRT LLM Server
The next section describe how to convert the weights from the [HuggingFace (HF) Transformers](https://github.com/huggingface/transformers) format to the TensorRT LLM format. We will use llama's [convert_checkpoint.py](../llama/convert_checkpoint.py) for EXAONE model and then we build the model with `trtllm-build`.
This section describes how to deploy the K-EXAONE model using the TensorRT LLM server with an OpenAI-compatible API endpoint.
Make sure `HF_MODEL_DIR` points to your EXAONE checkpoint directory.
The examples in this section are intended as a minimal, runnable demonstration and are not fully performance-optimized. For more features and performance tuning, please refer the documents below.
- [Disaggregated Serving examples](../../../disaggregated/README.md)
- [Disaggregated Serving feature guide](../../../../docs/source/features/disagg-serving.md)
- [Recommended LLM API configuration settings](../../../configs/README.md) (see also `examples/configs/curated/`)
### Running Aggregated TensorRT LLM Server
The aggregated server runs all components (context and generation phases) on the same set of GPUs, which is suitable for single-node deployments.
#### Creating the Extra Options Configuration
Create a YAML configuration file to specify advanced options such as attention data parallelism, CUDA graph settings, and MoE backend configuration:
```bash
cat <<EOF > configs.yaml
enable_attention_dp: true
trust_remote_code: true
cuda_graph_config:
max_batch_size: 2048
enable_padding: true
moe_config:
backend: CUTLASS # The TRTLLM backend is recommended for the Blackwell architecture.
kv_cache_config:
enable_block_reuse: true # Please disable the block reuse feature when conducting performance benchmarking.
max_attention_window: [128, 128, 128, 131072] # This allows KV cache manager to possibly improve memory efficiency.
free_gpu_memory_fraction: 0.9
dtype: "auto"
attention_dp_config:
enable_balance: true
batching_wait_iters: 50
timeout_iters: 1
num_postprocess_workers: 4 # Can mitigate the postprocessing overhead (e.g. detokenization)
EOF
```
#### Launch trtllm-serve OpenAI-compatible API server
Start the server using `trtllm-serve` with the PyTorch backend. This launches an OpenAI-compatible API server that can handle chat completions and text generation requests:
```bash
trtllm-serve \
$HF_MODEL_DIR \
--host localhost \
--port 8000 \
--backend pytorch \
--max_batch_size 2048 \
--max_num_tokens 8192 \
--tp_size 8 \
--ep_size 8 \
--pp_size 1 \
--config ./configs.yaml
```
Once the server is running, you can send requests to `http://localhost:8000/v1/completions` using the OpenAI API format.
### Running Disaggregated TensorRT LLM Server
Disaggregated serving separates the context (prefill) and generation (decode) phases onto different GPU sets, enabling better resource utilization and improved throughput. This example demonstrates a single-node disaggregated deployment using 8 GPUs (4 for context, 4 for generation). For more details, see the [Disaggregated Serving documentation](../../../disaggregated/README.md).
#### Step 1: Set Environment Variables
Configure the parallelism and buffer settings:
```bash
# Buffer size for KV cache transfer between context and generation servers
export MAX_TOKENS_IN_BUFFER=8192
# Model parallelism configuration
export TP_SIZE=4
export MOE_EP_SIZE=4
export ENABLE_ATTENTION_DP=true
```
#### Step 2: Create Configuration Files
**Context server configuration (`ctx_extra-llm-api-config.yaml`):**
```bash
cat > ctx_extra-llm-api-config.yaml << EOF
backend: pytorch
trust_remote_code: true
disable_overlap_scheduler: true
enable_chunked_prefill: true
tensor_parallel_size: $TP_SIZE
moe_expert_parallel_size: $MOE_EP_SIZE
pipeline_parallel_size: 1
enable_attention_dp: $ENABLE_ATTENTION_DP
cache_transceiver_config:
backend: UCX
max_tokens_in_buffer: $MAX_TOKENS_IN_BUFFER
EOF
```
**Generation server configuration (`gen_extra-llm-api-config.yaml`):**
```bash
cat > gen_extra-llm-api-config.yaml << EOF
backend: pytorch
trust_remote_code: true
disable_overlap_scheduler: false
enable_chunked_prefill: true
tensor_parallel_size: $TP_SIZE
moe_expert_parallel_size: $MOE_EP_SIZE
pipeline_parallel_size: 1
enable_attention_dp: $ENABLE_ATTENTION_DP
cache_transceiver_config:
backend: UCX
max_tokens_in_buffer: $MAX_TOKENS_IN_BUFFER
EOF
```
**Disaggregated orchestrator configuration (`disagg_config.yaml`):**
```bash
cat > disagg_config.yaml << EOF
hostname: localhost
port: 8000
backend: pytorch
context_servers:
num_instances: 1
urls:
- "localhost:8001"
generation_servers:
num_instances: 1
urls:
- "localhost:8002"
EOF
```
#### Step 3: Launch the Disaggregated Server
Start all components in the following order:
```bash
# 1. Start context server (GPUs 0-3)
CUDA_VISIBLE_DEVICES=0,1,2,3 trtllm-serve $HF_MODEL_DIR \
--host localhost --port 8001 --enable_chunked_prefill \
--extra_llm_api_options ./ctx_extra-llm-api-config.yaml &> log_ctx.log &
# 2. Start generation server (GPUs 4-7)
CUDA_VISIBLE_DEVICES=4,5,6,7 trtllm-serve $HF_MODEL_DIR \
--host localhost --port 8002 --enable_chunked_prefill \
--extra_llm_api_options ./gen_extra-llm-api-config.yaml &> log_gen.log &
# 3. Start disaggregated orchestrator
trtllm-serve disaggregated -c disagg_config.yaml -t 360 -r 1200 &> log_disagg.log &
```
Once all servers are running, you can send requests to `http://localhost:8000/v1/completions` using the OpenAI API format.
## TRT flow
The next section describes how to convert weights from the [HuggingFace (HF) Transformers](https://github.com/huggingface/transformers) format to the TensorRT LLM format. We will use LLaMA's [convert_checkpoint.py](../llama/convert_checkpoint.py) for EXAONE models and then build the model with `trtllm-build`.
### Convert checkpoint and build TensorRT engine(s)
@ -141,7 +360,7 @@ trtllm-build \
--output_dir trt_engines/exaone/fp16/1-gpu \
--gemm_plugin auto
# Build the EXAONE model using a single GPU and and apply INT8 weight-only quantization.
# Build the EXAONE model using a single GPU and apply INT8 weight-only quantization.
python ../llama/convert_checkpoint.py \
--model_dir $HF_MODEL_DIR \
--output_dir trt_models/exaone/int8_wq/1-gpu \
@ -154,7 +373,7 @@ trtllm-build \
--output_dir trt_engines/exaone/int8_wq/1-gpu \
--gemm_plugin auto
# Build the EXAONE model using a single GPU and and apply INT4 weight-only quantization.
# Build the EXAONE model using a single GPU and apply INT4 weight-only quantization.
python ../llama/convert_checkpoint.py \
--model_dir $HF_MODEL_DIR \
--output_dir trt_models/exaone/int4_wq/1-gpu \
@ -183,18 +402,18 @@ trtllm-build \
### FP8 Post-Training Quantization
The examples below uses the NVIDIA Modelopt (AlgorithMic Model Optimization) toolkit for the model quantization process.
The examples below use the NVIDIA ModelOpt (AlgorithMic Model Optimization) toolkit for the model quantization process.
First make sure Modelopt toolkit is installed (see [examples/quantization/README.md](/examples/quantization/README.md#preparation))
```bash
# Build the EXAONE model using a single GPU and and apply FP8 quantization.
# Build the EXAONE model using a single GPU and apply FP8 quantization.
python ../../../quantization/quantize.py \
--model_dir $HF_MODEL_DIR \
--dtype float16 \
--qformat fp8 \
--kv_cache_dtype fp8 \
--output_dir trt_models/exaone/fp8/1-gpu \
--output_dir trt_models/exaone/fp8/1-gpu
trtllm-build \
--checkpoint_dir trt_models/exaone/fp8/1-gpu \
@ -204,12 +423,12 @@ trtllm-build \
### SmoothQuant
The examples below uses the NVIDIA Modelopt (AlgorithMic Model Optimization) toolkit for the model quantization process.
The examples below use the NVIDIA ModelOpt (AlgorithMic Model Optimization) toolkit for the model quantization process.
First make sure Modelopt toolkit is installed (see [examples/quantization/README.md](/examples/quantization/README.md#preparation))
```bash
# Build the EXAONE model using a single GPU and and apply INT8 SmoothQuant.
# Build the EXAONE model using a single GPU and apply INT8 SmoothQuant.
python ../../../quantization/quantize.py \
--model_dir $HF_MODEL_DIR \
--dtype float16 \
@ -224,12 +443,12 @@ trtllm-build \
### Groupwise quantization (AWQ)
The examples below uses the NVIDIA Modelopt (AlgorithMic Model Optimization) toolkit for the model quantization process.
The examples below use the NVIDIA ModelOpt (AlgorithMic Model Optimization) toolkit for the model quantization process.
First make sure Modelopt toolkit is installed (see [examples/quantization/README.md](/examples/quantization/README.md#preparation))
```bash
# Build the EXAONE model using a single GPU and and apply INT4 AWQ.
# Build the EXAONE model using a single GPU and apply INT4 AWQ.
python ../../../quantization/quantize.py \
--model_dir $HF_MODEL_DIR \
--dtype float16 \
@ -248,7 +467,7 @@ For Hopper GPUs, TRT-LLM also supports employing FP8 GEMM for accelerating linea
Please make sure your system contains a Hopper GPU before trying the commands below.
```bash
# Build the EXAONE model using a single GPU and and apply W4A8 AWQ.
# Build the EXAONE model using a single GPU and apply W4A8 AWQ.
python ../../../quantization/quantize.py \
--model_dir $HF_MODEL_DIR \
--dtype float16 \
@ -287,4 +506,50 @@ python ../../../summarize.py \
--engine_dir trt_engines/exaone/fp16/1-gpu
```
For more examples see [`examples/models/core/llama/README.md`](../llama/README.md)
For more examples regarding EXAONE-3.0 & EXAONE-Deep's TRT flow, see [`examples/models/core/llama/README.md`](../llama/README.md)
## Troubleshootings
### Troubleshootings for EXAONE-4.0
The following error may occur during quantization:
```bash
torch._dynamo.exc.Unsupported: Graph break under GenericContextWrappingVariable
Explanation: Attempted to graph break in an active context manager(s) that doesn't support graph breaking.
Hint: Move the offending context manager(s) to outside the compiled region.
Hint: This graph break may have been caused by an earlier graph break. Resolving the earlier graph break may resolve this one.
```
This error may indicate an incompatibility between `torch.compile()` and the `HybridCache` module of the transformers library. As a result, [Model Optimizer](https://github.com/NVIDIA/Model-Optimizer) (ModelOpt) cannot perform PTQ with HybridCache.
Temporarily switching to `DynamicCache` when creating PTQ models could help address the issue. This can be done by updating the `cache_implementation` field in the `generation_config.json` file located in the model checkpoint directory, for example:
```json
# generation_config.json
{
// Change "hybrid" to "dynamic" to run PTQ.
// Revert this to "hybrid" after quantization is complete.
"cache_implementation": "hybrid",
...
}
```
For models with sliding window attention, DynamicCache is less memory-efficient than HybridCache because it retains the entire key-value cache. However, this does not break the model's attention logic, as the cache implementation is separated from the attention computation itself. This trade-off is acceptable for the PTQ process, which is a one-time procedure. Our tests confirm that this workaround does not degrade accuracy on MMLU or GSM8K benchmarks with the default ModelOpt settings.
### Troubleshootings for K-EXAONE
K-EXAONE is a Mixture of Experts (MoE) model which activates 8 experts per token. When not enough tokens are given during the PTQ, some experts on some layers might not be activated and will not produce proper weights.
To address this issue, provide enough data samples during calibration by increasing `calib_size` and `calib_seq` parameters:
**FP8 Quantization:**
```bash
cd Model-Optimizer/examples/llm_ptq
python3 hf_ptq.py --model hf_models/$MODEL_NAME --quant fp8 --export_fmt hf --calib_size 8192 --calib_seq 1024
```
**NVFP4 Quantization:**
```bash
cd Model-Optimizer/examples/llm_ptq
python3 hf_ptq.py --model hf_models/$MODEL_NAME --quant nvfp4 --export_fmt hf --calib_size 8192 --calib_seq 1024
```

View File

@ -5,6 +5,7 @@ from .modeling_bert import BertForSequenceClassification
from .modeling_clip import CLIPVisionModel
from .modeling_deepseekv3 import DeepseekV3ForCausalLM
from .modeling_exaone4 import Exaone4ForCausalLM
from .modeling_exaone_moe import ExaoneMoeForCausalLM
from .modeling_gemma3 import Gemma3ForCausalLM
from .modeling_gemma3vl import Gemma3VLM
from .modeling_glm import Glm4MoeForCausalLM
@ -44,6 +45,7 @@ __all__ = [
"CLIPVisionModel",
"DeepseekV3ForCausalLM",
"Exaone4ForCausalLM",
"ExaoneMoeForCausalLM",
"Gemma3ForCausalLM",
"Gemma3VLM",
"HCXVisionForCausalLM",

View File

@ -0,0 +1,581 @@
import math
import os
import re
from typing import Dict, List, Optional, Tuple
import torch
from torch import nn
from tensorrt_llm._ipc_utils import can_access_peer
from tensorrt_llm._torch.modules.qk_norm_attention import QKNormRoPEAttention
from tensorrt_llm.functional import PositionEmbeddingType
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.models.modeling_utils import QuantConfig
from tensorrt_llm.quantization import QuantAlgo
from ...logger import logger
from ..attention_backend import AttentionMetadata
from ..attention_backend.interface import (
PositionalEmbeddingParams,
PredefinedAttentionMask,
RopeParams,
)
from ..distributed import (
AllReduce,
AllReduceFusionOp,
AllReduceParams,
MoEAllReduce,
MoEAllReduceParams,
)
from ..model_config import ModelConfig
from ..models.modeling_deepseekv3 import Deepseekv3MoE
from ..modules.decoder_layer import DecoderLayer
from ..modules.embedding import Embedding
from ..modules.gated_mlp import GatedMLP
from ..modules.linear import TensorParallelMode
from ..modules.rms_norm import RMSNorm
from ..utils import AuxStreamType, Fp4QuantizedTensor
from .modeling_utils import (
DecoderModel,
DecoderModelForCausalLM,
EagerFusionConfig,
register_auto_model,
)
# fmt: off
# TODO: Remove this once we have a proper transformers package
from transformers import AutoConfig, PretrainedConfig # isort: skip
class ExaoneMoEConfig(PretrainedConfig):
model_type = "exaone_moe"
logger.warning_once(
"transformers does not support 'ExaoneMoEConfig'. "
"Register ExaoneMoEConfig to mimic the ExaoneMoE model.",
key="EXAONE_MOE_REGISTER_WARNING"
)
AutoConfig.register(ExaoneMoEConfig.model_type, ExaoneMoEConfig)
# End of the config register.
# fmt: on
def check_is_moe(config: ExaoneMoEConfig, layer_idx: int) -> bool:
"""
Check if the current layer is a MoE layer.
"""
return hasattr(config, "is_moe_layer") and config.is_moe_layer[layer_idx]
def enable_attn_allreduce(mapping: Mapping):
return not mapping.enable_attention_dp or mapping.has_tp()
class ExaoneMoeAttention(QKNormRoPEAttention):
def __init__(
self,
model_config: ModelConfig[ExaoneMoEConfig],
layer_idx: Optional[int] = None,
fuse_qk_norm_rope: bool = False,
disable_deep_gemm: bool = False,
):
config = model_config.pretrained_config
self.attention_window_size = None
self.is_sliding = config.layer_types[layer_idx] == "sliding_attention"
# NOTE: In ExaoneMoe, only sliding layers apply rope.
pos_embd_params = None
if self.is_sliding:
self.attention_window_size = config.sliding_window
pos_embd_params = PositionalEmbeddingParams(
type=PositionEmbeddingType.rope_gpt_neox,
rope=RopeParams.from_config(config),
)
fuse_qk_norm_rope = self.is_sliding and fuse_qk_norm_rope
# NOTE: Fusing qk norm with rope has an issue that slightly hurts accuracy.
assert not fuse_qk_norm_rope, "Fusing qk norm and rope is having issue now"
super().__init__(
hidden_size=config.hidden_size,
num_attention_heads=config.num_attention_heads,
num_key_value_heads=config.num_key_value_heads,
max_position_embeddings=config.max_position_embeddings,
bias=False,
pos_embd_params=pos_embd_params,
fuse_qk_norm_rope=fuse_qk_norm_rope,
skip_rope=not self.is_sliding,
layer_idx=layer_idx,
dtype=config.torch_dtype,
config=model_config,
disable_deep_gemm=disable_deep_gemm,
reduce_output=enable_attn_allreduce(model_config.mapping),
)
def forward(
self,
position_ids: Optional[torch.LongTensor],
hidden_states: torch.Tensor,
attn_metadata: AttentionMetadata,
attention_mask: PredefinedAttentionMask = PredefinedAttentionMask.CAUSAL,
lora_params: Optional[dict] = None,
**kwargs,
) -> torch.Tensor:
return super().forward(
position_ids=position_ids,
hidden_states=hidden_states,
attn_metadata=attn_metadata,
attention_mask=attention_mask,
lora_params=lora_params,
attention_window_size=self.attention_window_size,
**kwargs,
)
class ExaoneMoeSparseMoEBlock(Deepseekv3MoE):
"""
ExaoneMoe Sparse MoE Block Layer.
It follows DeepSeek-V3 implementation.
"""
class ExaoneMoeDecoderLayer(DecoderLayer):
def __init__(
self,
model_config: ModelConfig[ExaoneMoEConfig],
aux_stream_dict: Dict[AuxStreamType, torch.cuda.Stream],
layer_idx: int,
):
super().__init__()
self.model_config = model_config
config = model_config.pretrained_config
self.layer_idx = layer_idx
self.mapping = model_config.mapping
mapping = self.mapping
self.enable_attention_dp = mapping.enable_attention_dp
self.mlp_tp_size = mapping.tp_size
self.is_p2p_supported = can_access_peer(mapping)
self.fusion_config = EagerFusionConfig()
# MoE fusions are disabled by default in K-EXAONE since
# it may cause a slight accuracy drop due to numerical gap.
self.enable_fusion = os.environ.get("TRTLLM_EXAONE_EAGER_FUSION_ENABLED", "0") == "1"
self.enable_fusion &= not self.enable_attention_dp
# FIXME: incompatible with mixed quantization mode
quant_config = self._get_decoder_layer_quant_config(model_config, layer_idx)
self.is_nvfp4 = quant_config.layer_quant_mode.has_nvfp4()
assert quant_config.quant_algo is not QuantAlgo.MIXED_PRECISION, (
"MIXED_PRECISION is ambiguous"
)
self.allreduce = None
self.moe_allreduce = None
if not self.enable_attention_dp and self.mapping.tp_size > 1:
self.allreduce = AllReduce(
mapping=model_config.mapping,
strategy=model_config.allreduce_strategy,
dtype=config.torch_dtype,
)
self.moe_allreduce = MoEAllReduce(self.mapping)
has_tp = mapping.has_tp()
has_pp = mapping.has_pp()
# Submodule definitions
self.input_layernorm = RMSNorm(
hidden_size=config.hidden_size, eps=config.rms_norm_eps, dtype=config.torch_dtype
)
self.self_attn = ExaoneMoeAttention(model_config, layer_idx=layer_idx)
# MoE or Dense layer
self.is_moe_layer = check_is_moe(config, layer_idx)
if self.is_moe_layer:
self.fusion_config.PRE_MOE_FUSION = self.enable_fusion and has_tp
self.fusion_config.POST_MOE_FUSION = self.fusion_config.PRE_MOE_FUSION and not has_pp
self.mlp = ExaoneMoeSparseMoEBlock(
num_experts=config.num_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
shared_expert_intermediate_size=config.moe_intermediate_size
* config.num_shared_experts,
dtype=config.torch_dtype,
model_config=model_config,
override_quant_config=quant_config,
aux_stream_dict=aux_stream_dict,
layer_idx=layer_idx,
)
else:
block_size = 1
if quant_config.quant_algo is None and quant_config.group_size is not None:
block_size = quant_config.group_size
self.mlp_tp_size = self._compute_mlp_tp_size(config.intermediate_size, block_size)
has_mlp_tp = self.mlp_tp_size > 1
self.fusion_config.PRE_MLP_FUSION = self.enable_fusion and has_mlp_tp and self.is_nvfp4
self.fusion_config.POST_MLP_FUSION = self.enable_fusion and has_mlp_tp
self.mlp = GatedMLP(
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
bias=False,
dtype=config.torch_dtype,
config=model_config,
# Keep sharding consistent with computed mlp_tp_size.
# In attention-DP, mlp_tp_size==1 -> disable TP sharding here.
overridden_tp_size=self.mlp_tp_size,
layer_idx=layer_idx,
reduce_output=has_mlp_tp,
)
self.disable_attn_allreduce = (
self.fusion_config.PRE_MOE_FUSION
or self.fusion_config.PRE_MLP_FUSION
or self.mapping.tp_size == 1
or self.enable_attention_dp
)
self.post_attention_layernorm = RMSNorm(
hidden_size=config.hidden_size, eps=config.rms_norm_eps, dtype=config.torch_dtype
)
self.next_layer_layernorm: RMSNorm = None
def _get_decoder_layer_quant_config(
self, model_config: ModelConfig[ExaoneMoEConfig], layer_idx: int
):
"""
The MTP layer in the nvfp4 checkpoint is unquantized. Because the TRTLLM
moe_backend only supports fp8/fp4 quantization, we need to override
the quant_config for the MTP layer.
"""
quant_config = model_config.quant_config
layer_name = f"model.layers.{layer_idx}"
if quant_config.is_module_excluded_from_quantization(layer_name):
return QuantConfig(
quant_algo=None,
kv_cache_quant_algo=quant_config.kv_cache_quant_algo,
)
else:
return model_config.quant_config
def _compute_mlp_tp_size(self, intermediate_size: int, block_size: int) -> int:
"""Adopted from DeepseekV3DecoderLayer._compute_mlp_tp_size."""
assert intermediate_size % block_size == 0, (
f"intermediate_size {intermediate_size} must be divisible by block_size {block_size}."
)
if self.enable_attention_dp:
# If using attention DP, the MLP also uses DP instead of TP.
mlp_tp_size = 1
else:
# The two math.gcd operations ensure that mlp_tp_size falls in the candidate TP sizes.
tp = math.gcd(
intermediate_size // block_size,
self.mapping.tp_size,
)
if tp > self.mapping.gpus_per_node:
mlp_tp_size = math.gcd(
tp,
self.mapping.gpus_per_node,
) # Avoid costly inter-node TP
else:
mlp_tp_size = tp
return mlp_tp_size
def forward(
self,
position_ids: torch.LongTensor,
hidden_states: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
# LN has neem already applied at the previous layer except the first layer.
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
hidden_states = self.self_attn(
position_ids=position_ids,
hidden_states=hidden_states,
attn_metadata=attn_metadata,
all_reduce_params=AllReduceParams(enable_allreduce=not (self.disable_attn_allreduce)),
**kwargs,
)
if self.is_moe_layer:
hidden_states, residual = self.forward_moe(
hidden_states=hidden_states,
attn_metadata=attn_metadata,
residual=residual,
)
else:
hidden_states, residual = self.forward_mlp(
hidden_states=hidden_states,
residual=residual,
)
return hidden_states, residual
def forward_moe(
self,
hidden_states: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
def _run_moe(hidden_states, hidden_states_fp4, do_finalize):
return self.mlp(
hidden_states,
hidden_states_fp4,
all_rank_num_tokens=attn_metadata.all_rank_num_tokens,
final_all_reduce_params=AllReduceParams(
enable_allreduce=not (
self.fusion_config.POST_MOE_FUSION or self.mapping.tp_size == 1
)
),
do_finalize=do_finalize,
)
if self.fusion_config.PRE_MOE_FUSION:
# moe_backend can be either CUTLASS or TRTLLM here
hidden_states, residual = self.allreduce(
hidden_states,
all_reduce_params=AllReduceParams(
fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM,
residual=residual,
norm_weight=self.post_attention_layernorm.weight,
eps=self.post_attention_layernorm.variance_epsilon,
trigger_completion_at_end=False,
),
)
else:
# No fusion
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
# Note: this fusion pattern is only supported for single-node TRTLLM-nvfp4 backend now
do_finalize = self.mapping.is_multi_node() or (
not (
self.fusion_config.POST_MOE_FUSION
and hidden_states.shape[0] <= self.moe_allreduce.max_token
and self.model_config.moe_backend == "TRTLLM"
and self.mlp.experts.has_nvfp4
and self.is_p2p_supported
)
)
hidden_states = _run_moe(hidden_states, hidden_states_fp4=None, do_finalize=do_finalize)
if self.fusion_config.POST_MOE_FUSION:
if do_finalize:
hidden_states, residual = self.allreduce(
hidden_states,
all_reduce_params=AllReduceParams(
fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM,
residual=residual,
norm_weight=self.next_layer_layernorm.weight,
eps=self.next_layer_layernorm.variance_epsilon,
trigger_completion_at_end=False,
),
)
else:
assert len(hidden_states) == 4, "hidden_states must have 4 elements"
shared_output = hidden_states[0]
fc2_output = hidden_states[1]
expert_scale_factor = hidden_states[2]
expanded_idx_to_permuted_idx = hidden_states[3]
moe_all_reduce_params = MoEAllReduceParams(
expanded_idx_to_permuted_idx=expanded_idx_to_permuted_idx,
expert_scale_factor=expert_scale_factor,
shared_expert_output=shared_output,
residual=residual,
norm_weight=self.next_layer_layernorm.weight,
eps=self.next_layer_layernorm.variance_epsilon,
is_cutlass_min_latency=False,
)
hidden_states, residual = self.moe_allreduce(
fc2_output, all_reduce_params=moe_all_reduce_params
)
elif self.next_layer_layernorm is not None:
hidden_states, residual = self.next_layer_layernorm(hidden_states, residual)
return hidden_states, residual
def forward_mlp(
self,
hidden_states: torch.Tensor,
residual: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
if self.fusion_config.PRE_MLP_FUSION:
act_fp4, act_sf, residual = self.allreduce(
hidden_states,
all_reduce_params=AllReduceParams(
fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_NVFP4,
residual=residual,
norm_weight=self.post_attention_layernorm.weight,
scale=self.mlp.gate_up_proj.input_scale,
eps=self.post_attention_layernorm.variance_epsilon,
),
)
hidden_states = Fp4QuantizedTensor(act_fp4, act_sf)
else:
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
hidden_states = self.mlp(
hidden_states,
final_all_reduce_params=AllReduceParams(
enable_allreduce=not (self.fusion_config.POST_MLP_FUSION or self.mlp_tp_size == 1)
),
)
if self.fusion_config.POST_MLP_FUSION:
hidden_states, residual = self.allreduce(
hidden_states,
all_reduce_params=AllReduceParams(
fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM,
residual=residual,
norm_weight=self.next_layer_layernorm.weight,
eps=self.next_layer_layernorm.variance_epsilon,
),
)
elif self.next_layer_layernorm is not None:
hidden_states, residual = self.next_layer_layernorm(hidden_states, residual)
return hidden_states, residual
class ExaoneMoeModel(DecoderModel):
def __init__(self, model_config: ModelConfig[ExaoneMoEConfig]):
super().__init__(model_config)
config = self.model_config.pretrained_config
self.num_hidden_layers = config.num_hidden_layers
self.embed_tokens = Embedding(
config.vocab_size,
config.hidden_size,
dtype=config.torch_dtype,
mapping=model_config.mapping,
tensor_parallel_mode=TensorParallelMode.COLUMN,
gather_output=True,
)
aux_stream_list = [torch.cuda.Stream() for _ in range(3)]
self.aux_stream_dict = {
AuxStreamType.Attention: aux_stream_list[0],
AuxStreamType.MoeShared: aux_stream_list[0],
AuxStreamType.MoeChunkingOverlap: aux_stream_list[1],
AuxStreamType.MoeBalancer: aux_stream_list[2],
}
self.layers = nn.ModuleList(
[
ExaoneMoeDecoderLayer(
model_config=model_config,
aux_stream_dict=self.aux_stream_dict,
layer_idx=layer_idx,
)
for layer_idx in range(self.num_hidden_layers)
]
)
self.norm = RMSNorm(
hidden_size=config.hidden_size, eps=config.rms_norm_eps, dtype=config.torch_dtype
)
def forward(
self,
attn_metadata: AttentionMetadata,
input_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
lora_params=None,
**kwargs,
) -> torch.Tensor | Tuple[torch.Tensor, Optional[torch.Tensor]]:
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError(
"You cannot specify both input_ids and inputs_embeds at "
"the same time, and must specify either one."
)
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
hidden_states = inputs_embeds.to(self.dtype)
residual = None
for decoder_layer in self.layers[: self.num_hidden_layers]:
hidden_states, residual = decoder_layer(
position_ids=position_ids,
hidden_states=hidden_states,
attn_metadata=attn_metadata,
residual=residual,
lora_params=lora_params,
)
# The last LN already has been applied as a part of fusion.
return hidden_states
@register_auto_model("ExaoneMoEForCausalLM")
class ExaoneMoeForCausalLM(DecoderModelForCausalLM[ExaoneMoeModel, ExaoneMoEConfig]):
def __init__(
self,
model_config: ModelConfig[ExaoneMoEConfig],
):
super().__init__(
ExaoneMoeModel(model_config),
config=model_config,
hidden_size=model_config.pretrained_config.hidden_size,
vocab_size=model_config.pretrained_config.vocab_size,
)
def load_weights(
self,
weights: Dict,
weight_mapper: Optional["BaseWeightMapper"] = None, # noqa: F821
skip_modules: Optional[List[str]] = None,
allow_partial_loading: bool = False,
):
# MoE naming pattern.
moe_weight_patterns = {
"gate_proj": "w1",
"up_proj": "w3",
"down_proj": "w2",
}
module_names = list(weights)
for name in module_names:
if "mlp.e_score_correction_bias" in name:
# Move bias into the gate module.
new_name = name.replace(
"mlp.e_score_correction_bias", "mlp.gate.e_score_correction_bias"
)
else:
# MoE Weight Remapping.
new_name = name
for k, v in moe_weight_patterns.items():
pattern = rf"(experts\.\d+\.){k}\b"
new_name = re.sub(pattern, rf"\1{v}", new_name)
# Remap the name-parameter pair if needed.
if new_name != name:
weights[new_name] = weights.pop(name)
super().load_weights(
weights=weights,
weight_mapper=weight_mapper,
skip_modules=skip_modules or [],
allow_partial_loading=allow_partial_loading,
)
def post_load_weights(self):
# For the cross-layer residual+LN fusion.
for idx, layer in enumerate(self.model.layers[: self.config.num_hidden_layers]):
if idx == self.config.num_hidden_layers - 1:
layer.next_layer_layernorm = self.model.norm
else:
layer.next_layer_layernorm = self.model.layers[idx + 1].input_layernorm

View File

@ -257,21 +257,33 @@ class Deepseekv3RoutingImpl:
if self.n_group > 1:
if self.top_k > 8 or (num_experts / n_group) > 32 or (
num_experts / n_group) * self.topk_group > 128:
if (self.is_fused):
if self.is_fused:
warnings.warn(
"The configuration is not supported by the fused routing kernel. We have to use the original pytorch implementation."
)
self.is_fused = False
else:
elif (num_experts > 512 or (self.top_k > 8 and self.top_k != 22)
or self.topk_group == 1):
# We have special implementation for n_group == 1, top_k == 22 and num_experts == 512 for Nemotron Super v3.
if num_experts > 512 or (self.top_k > 8 and self.top_k != 22):
if (self.is_fused):
warnings.warn(
"The configuration is not supported by the fused routing kernel. We have to use the original pytorch implementation."
)
self.is_fused = False
if self.is_fused:
warnings.warn(
"The configuration is not supported by the fused routing kernel. We have to use the original pytorch implementation."
)
self.is_fused = False
if not self.is_fused:
if self.n_group == 1 and self.topk_group == 1:
scores, scores_with_bias = self.get_scores(logits,
e_score_correction_bias)
_, topk_indices = torch.topk(scores_with_bias, k=self.top_k, dim=1)
topk_values = torch.gather(scores, dim=1,
index=topk_indices).type_as(scores)
# Normalize and scale.
topk_values_sum = torch.sum(topk_values, dim=-1,
keepdim=True) + 1e-20
topk_values = topk_values / topk_values_sum * self.routed_scaling_factor
return topk_values, topk_indices
elif not self.is_fused:
scores, scores_with_bias = self.get_scores(logits,
e_score_correction_bias)
scores_shape = list(scores_with_bias.shape)

View File

@ -82,6 +82,7 @@ l0_b200:
- unittest/_torch/modeling -k "modeling_llama"
- unittest/_torch/modeling -k "modeling_mixtral"
- unittest/_torch/modeling -k "modeling_gpt_oss"
- unittest/_torch/modeling/test_modeling_exaone_moe.py
- unittest/tools/test_layer_wise_benchmarks.py::test_deepseek_r1_ctx_dep[1]
- unittest/tools/test_layer_wise_benchmarks.py::test_qwen3_next_gen_tep[1]
- unittest/_torch/modeling/test_modeling_exaone4.py::TestEXAONE4::test_llm_load_1_FP8

View File

@ -0,0 +1,406 @@
import unittest
from copy import deepcopy
from dataclasses import dataclass
import torch
from _torch.helpers import create_mock_cuda_graph_runner
from parameterized import parameterized
import tensorrt_llm
from tensorrt_llm._torch.attention_backend.utils import get_attention_backend
from tensorrt_llm._torch.metadata import KVCacheParams
from tensorrt_llm._torch.model_config import ModelConfig
from tensorrt_llm._torch.models.modeling_exaone_moe import ExaoneMoeForCausalLM
from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager
from tensorrt_llm.bindings.executor import KvCacheConfig
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.models.modeling_utils import QuantConfig
from utils.util import getSMVersion # isort: skip
# fmt: off
# TODO: Remove this once we have a proper transformers package
from tensorrt_llm._torch.models.modeling_exaone_moe import ExaoneMoEConfig # isort: skip
SKIP_EXAONE_MOE_HF_ACCURACY_TEST = False
try:
from transformers.models.exaone_moe.modeling_exaone_moe import (
ExaoneMoEForCausalLM as HFExaoneMoEForCausalLM,
)
except ImportError:
# TODO: Remove this once we have a proper config for EXAONE-MoE
SKIP_EXAONE_MOE_HF_ACCURACY_TEST = True
# fmt: on
WINDOW_SIZE = 4
NUM_HIDDEN_LAYERS = 4
EXAONE_MOE_CONFIG = {
"architectures": ["ExaoneMoEForCausalLM"],
"attention_dropout": 0.0,
"bos_token_id": 1,
"dtype": "bfloat16",
"eos_token_id": 53,
"first_last_k_dense_replace": 1,
"head_dim": 128,
"hidden_act": "silu",
"hidden_size": 6144,
"initializer_range": 0.02,
"intermediate_size": 18432,
"is_moe_layer": [False] + [True] * (NUM_HIDDEN_LAYERS - 1),
"layer_types": [
"sliding_attention",
"sliding_attention",
"sliding_attention",
"full_attention",
],
"max_position_embeddings": 262144,
"model_type": "exaone_moe",
"moe_intermediate_size": 2048,
"n_group": 1,
"norm_topk_prob": True,
"num_attention_heads": 64,
"num_experts": 128,
"num_experts_per_tok": 8,
"num_hidden_layers": NUM_HIDDEN_LAYERS,
"num_key_value_heads": 8,
"num_shared_experts": 1,
"pad_token_id": 0,
"rms_norm_eps": 1e-05,
"rope_scaling": None,
"rope_theta": 1000000,
"routed_scaling_factor": 2.5,
"scoring_func": "sigmoid",
"sliding_window": WINDOW_SIZE,
"sliding_window_pattern": "LLLG",
"tie_word_embeddings": False,
"tokenizer_class": "GPT2Tokenizer",
"topk_group": 1,
"topk_method": "noaux_tc",
"transformers_version": "5.0.0.dev0",
"use_cache": True,
"vocab_size": 153600,
}
@dataclass(repr=False)
class Scenario:
attention_backend: str
input_len: int = WINDOW_SIZE - 1
use_cuda_graph: bool = False
def __repr__(self) -> str:
return (
f"attention_backend:{self.attention_backend.lower()}-"
f"input_len:{self.input_len}-"
f"use_cuda_graph:{self.use_cuda_graph}"
)
class TestExaoneMoe(unittest.TestCase):
@parameterized.expand([None, "FP8"])
def test_exaone_moe_sanity(self, quant_algo):
"""Test basic EXAONE-MoE model forward pass with optional quantization."""
config_dict = deepcopy(EXAONE_MOE_CONFIG)
exaone_moe_config = ExaoneMoEConfig.from_dict(config_dict)
if quant_algo:
quant_config = QuantConfig(quant_algo=quant_algo)
else:
quant_config = QuantConfig()
if quant_algo == "FP8" and getSMVersion() < 89:
self.skipTest("This test is not supported in pre-Ada architecture")
dtype = exaone_moe_config.torch_dtype
device = torch.device("cuda")
model_config = ModelConfig(pretrained_config=exaone_moe_config, quant_config=quant_config)
exaone_moe = ExaoneMoeForCausalLM(model_config).to(device)
input_ids = torch.tensor(
[100, 200, 300, 100, 200, 100, 400, 500], dtype=torch.int, device=device
)
context_sequence_lengths = [3, 2, 1]
sequence_lengths = context_sequence_lengths + [1, 1]
past_seen_tokens = [0, 0, 0, 62, 75]
request_ids = list(range(len(sequence_lengths)))
token_nums = (torch.tensor(past_seen_tokens) + torch.tensor(sequence_lengths)).tolist()
prompt_lens = token_nums[:3] + past_seen_tokens[3:]
num_blocks = 100
tokens_per_block = 128
head_dim = exaone_moe.config.hidden_size // exaone_moe.config.num_attention_heads
num_layers = exaone_moe.config.num_hidden_layers
num_kv_heads = exaone_moe.config.num_key_value_heads
max_seq_len = num_blocks * tokens_per_block
batch_size = len(context_sequence_lengths) + 2
if dtype == torch.half:
kv_cache_dtype = tensorrt_llm.bindings.DataType.HALF
elif dtype == torch.bfloat16:
kv_cache_dtype = tensorrt_llm.bindings.DataType.BF16
else:
raise ValueError("Invalid dtype")
mapping = Mapping(world_size=1, tp_size=1, rank=0)
kv_cache_config = KvCacheConfig(max_tokens=num_blocks * tokens_per_block)
kv_cache_manager = KVCacheManager(
kv_cache_config,
tensorrt_llm.bindings.internal.batch_manager.CacheType.SELF,
num_layers=num_layers,
num_kv_heads=num_kv_heads,
head_dim=head_dim,
tokens_per_block=tokens_per_block,
max_seq_len=max_seq_len,
max_batch_size=batch_size,
mapping=mapping,
dtype=kv_cache_dtype,
)
kv_cache_manager.add_dummy_requests(request_ids, token_nums)
metadata_cls = get_attention_backend(model_config.attn_backend).Metadata
attn_metadata = metadata_cls(
seq_lens=torch.tensor(sequence_lengths, dtype=torch.int),
num_contexts=len(context_sequence_lengths),
kv_cache_params=KVCacheParams(
use_cache=True,
num_cached_tokens_per_seq=past_seen_tokens,
),
kv_cache_manager=kv_cache_manager,
request_ids=request_ids,
prompt_lens=prompt_lens,
max_num_requests=len(context_sequence_lengths) + 2,
max_num_tokens=8192,
)
position_ids = []
for i, tokens in enumerate(past_seen_tokens):
seq_len = context_sequence_lengths[i] if i < len(context_sequence_lengths) else 1
position_id = torch.arange(tokens, tokens + seq_len, device=input_ids.device)
position_ids.append(position_id)
position_ids = torch.cat(position_ids).unsqueeze(0)
with torch.inference_mode():
attn_metadata.prepare()
logits = exaone_moe.forward(
input_ids=input_ids, position_ids=position_ids, attn_metadata=attn_metadata
)
self.assertEqual(len(past_seen_tokens), logits.shape[0])
with torch.inference_mode():
attn_metadata.prepare()
logits = exaone_moe.forward(
input_ids=input_ids,
position_ids=position_ids,
attn_metadata=attn_metadata,
return_context_logits=True,
)
self.assertEqual(input_ids.shape, logits.shape[:-1])
kv_cache_manager.shutdown()
def test_exaone_moe_moe_layer_config(self):
"""Test that MoE layers are correctly configured."""
config_dict = deepcopy(EXAONE_MOE_CONFIG)
exaone_moe_config = ExaoneMoEConfig.from_dict(config_dict)
device = torch.device("cuda")
model_config = ModelConfig(pretrained_config=exaone_moe_config)
exaone_moe = ExaoneMoeForCausalLM(model_config).to(device)
# Verify MoE layer configuration
is_moe_layer = config_dict["is_moe_layer"]
self.assertEqual(len(is_moe_layer), NUM_HIDDEN_LAYERS)
self.assertFalse(is_moe_layer[0]) # First layer should be dense
for i in range(1, NUM_HIDDEN_LAYERS):
self.assertTrue(is_moe_layer[i]) # Rest should be MoE
# Verify model has correct number of layers
self.assertEqual(len(exaone_moe.model.layers), NUM_HIDDEN_LAYERS)
@parameterized.expand(
[
Scenario(attention_backend="TRTLLM", input_len=WINDOW_SIZE - 2),
Scenario(attention_backend="TRTLLM", input_len=WINDOW_SIZE - 2, use_cuda_graph=True),
],
lambda testcase_func, param_num, param: f"{testcase_func.__name__}[{param.args[0]}]",
)
@torch.no_grad()
def test_exaone_moe_allclose_to_hf(self, scenario: Scenario) -> None:
"""Compare output to HuggingFace implementation."""
if SKIP_EXAONE_MOE_HF_ACCURACY_TEST:
self.skipTest("EXAONE-MoE HF model is not available in this environment")
attention_backend = scenario.attention_backend
metadata_cls = get_attention_backend(attention_backend).Metadata
torch.random.manual_seed(0)
config_dict = deepcopy(EXAONE_MOE_CONFIG)
exaone_moe_config = ExaoneMoEConfig.from_dict(config_dict)
dtype = exaone_moe_config.torch_dtype
device = torch.device("cuda")
hf_exaone_moe = HFExaoneMoEForCausalLM(exaone_moe_config).to(dtype).to(device).eval()
model_config = ModelConfig(
pretrained_config=exaone_moe_config, attn_backend=attention_backend
)
exaone_moe = ExaoneMoeForCausalLM(model_config).to(dtype).to(device)
exaone_moe.load_weights(hf_exaone_moe.state_dict())
exaone_moe.post_load_weights()
num_blocks = 1
tokens_per_block = 128
head_dim = getattr(
exaone_moe.config,
"head_dim",
exaone_moe.config.hidden_size // exaone_moe.config.num_attention_heads,
)
num_layers = exaone_moe.config.num_hidden_layers
num_kv_heads = exaone_moe.config.num_key_value_heads
max_seq_len = num_blocks * tokens_per_block
batch_size = 1
if dtype == torch.half:
kv_cache_dtype = tensorrt_llm.bindings.DataType.HALF
elif dtype == torch.bfloat16:
kv_cache_dtype = tensorrt_llm.bindings.DataType.BF16
else:
raise ValueError("Invalid dtype")
mapping = Mapping(world_size=1, tp_size=1, rank=0)
kv_cache_config = KvCacheConfig(
enable_block_reuse=False,
enable_partial_reuse=False,
copy_on_partial_reuse=False,
max_attention_window=[int(exaone_moe_config.sliding_window)],
max_tokens=num_blocks * tokens_per_block,
)
kv_cache_manager = KVCacheManager(
kv_cache_config,
tensorrt_llm.bindings.internal.batch_manager.CacheType.SELF,
num_layers=num_layers,
num_kv_heads=num_kv_heads,
head_dim=head_dim,
tokens_per_block=tokens_per_block,
max_seq_len=max_seq_len,
max_batch_size=batch_size,
mapping=mapping,
dtype=kv_cache_dtype,
)
# Context phase
input_ids = torch.tensor(
[i * 100 for i in range(1, scenario.input_len + 1)], dtype=torch.int32, device=device
)
num_cached_tokens_per_seq = [0]
request_ids = [1]
token_nums = [input_ids.size(-1)]
prompt_lens = [input_ids.size(-1)]
kv_cache_manager.add_dummy_requests(request_ids, token_nums)
attn_metadata = metadata_cls(
seq_lens=torch.tensor([input_ids.size(-1)], dtype=torch.int),
num_contexts=1,
kv_cache_params=KVCacheParams(
use_cache=True,
num_cached_tokens_per_seq=num_cached_tokens_per_seq,
),
max_num_requests=1,
max_num_tokens=8192,
kv_cache_manager=kv_cache_manager,
request_ids=request_ids,
prompt_lens=prompt_lens,
)
position_ids = [torch.arange(0, input_ids.size(-1), dtype=torch.int32)]
position_ids = torch.cat(position_ids).unsqueeze(0).cuda()
with torch.inference_mode():
attn_metadata.prepare()
logits = exaone_moe.forward(
input_ids=input_ids, position_ids=position_ids, attn_metadata=attn_metadata
)
ref = hf_exaone_moe.forward(
input_ids=input_ids.unsqueeze(0), position_ids=position_ids, use_cache=True
)
# MoE models may have slightly higher tolerance due to expert routing
torch.testing.assert_close(logits, ref.logits[:, -1].float(), atol=0.5, rtol=0.5)
# Generation phase
gen_input_ids = torch.tensor([600], dtype=torch.int32, device=device)
num_cached_tokens_per_seq = [input_ids.size(-1)]
attn_metadata = metadata_cls(
seq_lens=torch.tensor([gen_input_ids.size(-1)], dtype=torch.int),
num_contexts=0,
kv_cache_params=KVCacheParams(
use_cache=True,
num_cached_tokens_per_seq=num_cached_tokens_per_seq,
),
max_num_requests=1,
max_num_tokens=8192,
kv_cache_manager=kv_cache_manager,
request_ids=request_ids,
prompt_lens=prompt_lens,
)
gen_position_ids = [
torch.arange(
input_ids.size(-1), input_ids.size(-1) + gen_input_ids.size(-1), dtype=torch.int32
)
]
gen_position_ids = torch.cat(gen_position_ids).unsqueeze(0).cuda()
graph_runner = create_mock_cuda_graph_runner(1) if scenario.use_cuda_graph else None
def run_forward(input_ids, position_ids, attn_metadata):
attn_metadata.prepare()
if not scenario.use_cuda_graph:
return exaone_moe.forward(
input_ids=input_ids, position_ids=position_ids, attn_metadata=attn_metadata
)
else:
inputs = {
"input_ids": input_ids,
"position_ids": position_ids,
"attn_metadata": attn_metadata,
}
key = (1, 0, False)
graph_runner.capture(key, lambda inputs: exaone_moe.forward(**inputs), inputs)
for _ in range(2):
attn_metadata.prepare()
logits = graph_runner.replay(key, inputs)
return logits
if scenario.use_cuda_graph:
attn_metadata = attn_metadata.create_cuda_graph_metadata(1)
with torch.inference_mode():
logits = run_forward(
input_ids=gen_input_ids, position_ids=gen_position_ids, attn_metadata=attn_metadata
)
ref = hf_exaone_moe.forward(
input_ids=gen_input_ids.unsqueeze(0),
position_ids=gen_position_ids,
past_key_values=ref.past_key_values,
use_cache=True,
)
torch.testing.assert_close(logits, ref.logits[:, -1].float(), atol=0.5, rtol=0.5)
if graph_runner is not None:
graph_runner.clear()
kv_cache_manager.shutdown()
if __name__ == "__main__":
unittest.main()