feat: EXAONE4.0 support (#5696)

Signed-off-by: yechank <161688079+yechank-nvidia@users.noreply.github.com>
This commit is contained in:
Yechan Kim 2025-07-14 22:28:10 +09:00 committed by GitHub
parent dbf29184dc
commit 63139fdcff
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 392 additions and 1 deletions

View File

@ -10,7 +10,11 @@ See the LLaMA example [`examples/models/core/llama`](../llama) for details.
- [Supported Models](#supported-models)
- [EXAONE-3.0](#exaone-30)
- [EXAONE-Deep](#exaone-deep)
- [EXAONE-4.0](#exaone-40)
- [Usage](#usage)
- [PyTorch flow](#pytorch-flow)
-[PyTorch flow Quantization](#pytorch-flow-quantization)
- [TRT Flow](#trt-flow)
- [Convert checkpoint and build TensorRT engine(s)](#convert-checkpoint-and-build-tensorrt-engines)
- [FP8 Post-Training Quantization](#fp8-post-training-quantization)
- [SmoothQuant](#smoothquant)
@ -39,16 +43,79 @@ git clone https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct $HF_MODEL_
### EXAONE-Deep
Download the HuggingFace BF16 checkpoints of EXAONE-Deep model. Here, we only use the `EXAONE-Deep-2.4B` model for the example. We can use the same procedure as EXAONE-3.0 to convert the weights and build the TensorRT engine.
Download the HuggingFace checkpoints of EXAONE-Deep model. Here, we only use the `EXAONE-Deep-2.4B` model for the example. We can use the same procedure as EXAONE-3.0 to convert the weights and build the TensorRT engine.
```bash
export HF_MODEL_DIR=hf_models/exaone_deep
git clone https://huggingface.co/LGAI-EXAONE/EXAONE-Deep-2.4B $HF_MODEL_DIR
```
### EXAONE-4.0
Download he HuggingFace checkpoints of EXAONE-4.0 model. Here, we only use the `TODO: replace with REAL name, EXAONE-4.0` model for the example. From EXAONE-4.0 model, we support EXAONE models only on PyTorch flow.
```bash
export HF_MODEL_DIR=hf_models/exaone4
git clone ... $HF_MODEL_DIR (TODO Change ... to real HF directory)
```
## Usage
The next section describe how to convert the weights from the [HuggingFace (HF) Transformers](https://github.com/huggingface/transformers) format to the TensorRT-LLM format. We will use llama's [convert_checkpoint.py](../llama/convert_checkpoint.py) for EXAONE model and then we build the model with `trtllm-build`.
### Pytorch flow
To quickly run EXAONE-4.0 models, you can use [examples/llm-api/quickstart_advanced.py](../../../llm-api/quickstart_advanced.py):
```bash
python ../../../llm-api/quickstart_advanced.py --model_dir hf_models/$MODEL_NAME --disable_kv_cache_reuse
```
SWA currently does not support kv_cache_reuse. Please make sure to disable KV cache reuse when running with SWA.
The output will be like:
```bash
TODO: Fill this with real HF checkpoints output
```
#### PyTorch flow Quantization
For PyTorch flow, TRT-LLM supports quantized format generated by [TensorRT Model Optimizer](https://github.com/NVIDIA/TensorRT-Model-Optimizer).
You can either do pre-quantized models in HF model hub, or can generate quantized model by yourself and then run models with below command:
```bash
git clone https://github.com/NVIDIA/TensorRT-Model-Optimizer.git
cd TensorRT-Model-Optimizer/examples/llm_ptq
scripts/huggingface_example.sh --model hf_models/$MODEL_NAME --quant fp8 --export_fmt hf
```
For more information, please refer to official [docs](https://github.com/NVIDIA/TensorRT-Model-Optimizer) or [TensorRT Model Optimizer](https://github.com/NVIDIA/TensorRT-Model-Optimizer).
Troubleshooting
The following error may occur during quantization:
```bash
torch._dynamo.exc.Unsupported: Graph break under GenericContextWrappingVariable
Explanation: Attempted to graph break in an active context manager(s) that doesn't support graph breaking.
Hint: Move the offending context manager(s) to outside the compiled region.
Hint: This graph break may have been caused by an earlier graph break. Resolving the earlier graph break may resolve this one.
```
This error may indicate an incompatibility between `torch.compile()` and the `HybridCache` module of the transformers library. As a result, [TensorRT Model Optimizer](https://github.com/NVIDIA/TensorRT-Model-Optimizer) (ModelOpt) cannot perform PTQ with HybridCache.
Temporarily switching to `DynamicCache` when creating PTQ models could help address the issue. This can be done by updating the `cache_implementation` field in the `generation_config.json` file located in the model checkpoint directory, for example:
```json
# generation_config.json
{
// Change "hybrid" to "dynamic" to run PTQ.
// Revert this to "hybrid" after quantization is complete.
"cache_implementation": "hybrid",
...
}
```
For models with sliding window attention, DynamicCache is less memory-efficient than HybridCache because it retains the entire key-value cache. However, this does not break the model's attention logic, as the cache implementation is separated from the attention computation itself. This trade-off is acceptable for the PTQ process, which is a one-time procedure. Our tests confirm that this workaround does not degrade accuracy on MMLU or GSM8K benchmarks with the default ModelOpt settings.
### TRT flow
### Convert checkpoint and build TensorRT engine(s)
```bash

View File

@ -4,6 +4,7 @@ from .modeling_auto import AutoModelForCausalLM
from .modeling_bert import BertForSequenceClassification
from .modeling_clip import CLIPVisionModel
from .modeling_deepseekv3 import DeepseekV3ForCausalLM
from .modeling_exaone4 import Exaone4ForCausalLM
from .modeling_gemma3 import Gemma3ForCausalLM
from .modeling_gemma3vl import Gemma3Model
from .modeling_hyperclovax import HCXVisionForCausalLM
@ -30,6 +31,7 @@ __all__ = [
"BertForSequenceClassification",
"CLIPVisionModel",
"DeepseekV3ForCausalLM",
"Exaone4ForCausalLM",
"Gemma3ForCausalLM",
"HCXVisionForCausalLM",
"Gemma3Model",

View File

@ -0,0 +1,322 @@
from typing import Optional, Tuple
import torch
from torch import nn
from tensorrt_llm._torch.distributed import AllReduceParams
from tensorrt_llm.functional import PositionEmbeddingType
from ..attention_backend import AttentionMetadata
from ..attention_backend.interface import (PositionalEmbeddingParams,
PredefinedAttentionMask, RopeParams)
from ..model_config import ModelConfig
from ..modules.attention import Attention
from ..modules.decoder_layer import DecoderLayer
from ..modules.embedding import Embedding
from ..modules.gated_mlp import GatedMLP
from ..modules.linear import TensorParallelMode
from ..modules.multi_stream_utils import maybe_execute_in_parallel
from ..modules.rms_norm import RMSNorm
from ..speculative import SpecMetadata
from .modeling_utils import (DecoderModel, DecoderModelForCausalLM,
register_auto_model)
try:
from transformers import Exaone4Config
except ImportError:
# TODO: Remove this once we have a proper transformers package
from transformers import AutoConfig, PretrainedConfig
class Exaone4Config(PretrainedConfig):
model_type = "exaone4"
AutoConfig.register(Exaone4Config.model_type, Exaone4Config)
def check_is_sliding(config: Exaone4Config, layer_idx: int) -> bool:
"""
Check if the current layer is a sliding window (local attention) layer.
"""
if config.sliding_window is None:
return False
if isinstance(config.sliding_window_pattern, int):
return ((layer_idx + 1) % config.sliding_window_pattern) != 0
elif isinstance(config.sliding_window_pattern, str):
assert isinstance(config.sliding_window, int), (
f"Sliding window must be positive integer, but got {config.sliding_window}"
)
return (layer_idx != config.num_hidden_layers - 1
and config.sliding_window_pattern[layer_idx % len(
config.sliding_window_pattern)] == "L")
return False
class Exaone4Attention(Attention):
def __init__(self,
model_config: ModelConfig[Exaone4Config],
is_sliding: bool,
layer_idx: Optional[int] = None,
aux_stream: Optional[torch.cuda.Stream] = None,
fuse_qk_norm_rope: bool = False):
config = model_config.pretrained_config
self.attention_window_size = None
# NOTE: In EXAONE4, only sliding layers apply rope.
self.is_sliding = is_sliding
pos_embd_params = None
if self.is_sliding:
self.attention_window_size = config.sliding_window
pos_embd_params = PositionalEmbeddingParams(
type=PositionEmbeddingType.rope_gpt_neox,
rope=RopeParams.from_config(config),
)
self.fuse_qk_norm_rope = (self.is_sliding and fuse_qk_norm_rope)
# TODO: Fusing qk norm with rope has an issue that slightly hurts accuracy.
assert self.fuse_qk_norm_rope is False, "Fusing qk norm and rope is having issue now"
super().__init__(
hidden_size=config.hidden_size,
num_attention_heads=config.num_attention_heads,
num_key_value_heads=config.num_key_value_heads,
max_position_embeddings=config.max_position_embeddings,
bias=False,
pos_embd_params=pos_embd_params,
rope_fusion=False,
layer_idx=layer_idx,
dtype=config.torch_dtype,
config=model_config,
)
self.q_norm = RMSNorm(hidden_size=self.head_dim,
eps=config.rms_norm_eps,
dtype=config.torch_dtype)
self.k_norm = RMSNorm(hidden_size=self.head_dim,
eps=config.rms_norm_eps,
dtype=config.torch_dtype)
self.aux_stream = aux_stream
self.ln_events = [torch.cuda.Event(), torch.cuda.Event()]
def apply_qk_norm(self, q, k):
def q_l2norm():
return self.q_norm(q.reshape(-1, self.head_dim)).reshape(
-1, self.q_size)
def k_l2norm():
return self.k_norm(k.reshape(-1, self.head_dim)).reshape(
-1, self.kv_size)
q, k = maybe_execute_in_parallel(
q_l2norm,
k_l2norm,
self.ln_events[0],
self.ln_events[1],
self.aux_stream,
)
return q, k
def apply_qk_norm_rope(self, qkv, position_ids):
torch.ops.trtllm.fused_qk_norm_rope(
qkv, self.num_heads, self.num_key_value_heads,
self.num_key_value_heads, self.head_dim,
self.q_norm.variance_epsilon, self.q_norm.weight,
self.k_norm.weight, self.pos_embd_params.rope.theta,
self.pos_embd_params.is_neox, position_ids.view(-1))
return qkv, None, None
def apply_rope(self, q: torch.Tensor, k: Optional[torch.Tensor],
v: Optional[torch.Tensor], position_ids: torch.Tensor):
if self.fuse_qk_norm_rope:
assert k is None and v is None, "The input should be a concatenated qkv tensor to apply_qk_norm_rope"
qkv = q
return self.apply_qk_norm_rope(qkv, position_ids)
q, k, v = self.split_qkv(q, k, v)
q, k = self.apply_qk_norm(q, k)
if self.is_sliding:
return super().apply_rope(q, k, v, position_ids)
else:
return q, k, v
def forward(
self,
position_ids: Optional[torch.LongTensor],
hidden_states: torch.Tensor,
attn_metadata: AttentionMetadata,
attention_mask: PredefinedAttentionMask = PredefinedAttentionMask.
CAUSAL,
all_reduce_params: Optional[AllReduceParams] = None,
lora_params: Optional[dict] = None,
**kwargs,
) -> torch.Tensor:
# TODO LoRA has not been tested yet but there is no need to prevent it.
assert lora_params is None, "LORA is not supported for Exaone4Attention"
return super().forward(
position_ids=position_ids,
hidden_states=hidden_states,
attn_metadata=attn_metadata,
attention_mask=attention_mask,
all_reduce_params=all_reduce_params,
lora_params=lora_params,
attention_window_size=self.attention_window_size,
**kwargs,
)
class Exaone4DecoderLayer(DecoderLayer):
def __init__(
self,
model_config: ModelConfig[Exaone4Config],
layer_idx: int,
aux_stream: Optional[torch.cuda.Stream] = None,
):
super().__init__()
config = model_config.pretrained_config
self.layer_idx = layer_idx
self.is_quanted = model_config.quant_config and model_config.quant_config.quant_mode.has_any_quant(
)
is_sliding = check_is_sliding(config, layer_idx)
self.self_attn = Exaone4Attention(
model_config,
is_sliding=is_sliding,
layer_idx=layer_idx,
aux_stream=aux_stream,
)
self.mlp = GatedMLP(
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
bias=getattr(config, "mlp_bias", False),
dtype=config.torch_dtype,
config=model_config,
layer_idx=layer_idx,
)
self.post_attention_layernorm = RMSNorm(hidden_size=config.hidden_size,
eps=config.rms_norm_eps,
dtype=config.torch_dtype)
self.post_feedforward_layernorm = RMSNorm(
hidden_size=config.hidden_size,
eps=config.rms_norm_eps,
dtype=config.torch_dtype)
self.mapping = model_config.mapping
def forward(
self,
position_ids: torch.LongTensor,
hidden_states: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor | Tuple[torch.Tensor, Optional[torch.Tensor]]:
residual = hidden_states
hidden_states = self.self_attn(
position_ids=position_ids,
hidden_states=hidden_states,
attn_metadata=attn_metadata,
all_reduce_params=AllReduceParams(
enable_allreduce=not (self.mapping.tp_size == 1)),
**kwargs,
)
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.mlp(
hidden_states,
final_all_reduce_params=AllReduceParams(
enable_allreduce=not (self.mapping.tp_size == 1)),
)
hidden_states = self.post_feedforward_layernorm(hidden_states)
hidden_states = hidden_states + residual
return hidden_states
class Exaone4Model(DecoderModel):
def __init__(self, model_config: ModelConfig[Exaone4Config]):
super().__init__(model_config)
config = self.model_config.pretrained_config
self.num_hidden_layers = config.num_hidden_layers
self.aux_stream = torch.cuda.Stream()
self.embed_tokens = Embedding(
config.vocab_size,
config.hidden_size,
dtype=config.torch_dtype,
mapping=model_config.mapping,
tensor_parallel_mode=TensorParallelMode.COLUMN,
gather_output=True,
)
self.layers = nn.ModuleList([
Exaone4DecoderLayer(
model_config,
layer_idx,
self.aux_stream,
) for layer_idx in range(self.num_hidden_layers)
])
self.norm = RMSNorm(hidden_size=config.hidden_size,
eps=config.rms_norm_eps,
dtype=config.torch_dtype)
def forward(
self,
attn_metadata: AttentionMetadata,
input_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
spec_metadata: Optional[SpecMetadata] = None,
lora_params=None,
**kwargs,
) -> torch.Tensor | Tuple[torch.Tensor, Optional[torch.Tensor]]:
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError(
"You cannot specify both input_ids and inputs_embeds at "
"the same time, and must specify either one.")
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
hidden_states = inputs_embeds.to(self.dtype)
for decoder_layer in self.layers[:self.num_hidden_layers]:
hidden_states = decoder_layer(
position_ids=position_ids,
hidden_states=hidden_states,
attn_metadata=attn_metadata,
spec_metadata=spec_metadata,
lora_params=lora_params,
)
hidden_states = self.norm(hidden_states)
return hidden_states
@register_auto_model("Exaone4ForCausalLM")
class Exaone4ForCausalLM(DecoderModelForCausalLM[Exaone4Model, Exaone4Config]):
def __init__(
self,
model_config: ModelConfig[Exaone4Config],
):
model_config.pretrained_config.torch_dtype = torch.bfloat16
super().__init__(Exaone4Model(model_config),
config=model_config,
hidden_size=model_config.pretrained_config.hidden_size,
vocab_size=model_config.pretrained_config.vocab_size)