Mirror of https://github.com/NVIDIA/TensorRT-LLM.git, synced 2026-02-04 18:21:52 +08:00
[None][feat] K-EXAONE MTP support (#10796)
Signed-off-by: yechank <161688079+yechank-nvidia@users.noreply.github.com>
This commit is contained in:
parent 415739711f
commit 70caa779a4
@@ -10,6 +10,7 @@ The following is a table of supported models for the PyTorch backend:
| `DeepseekV3ForCausalLM` | DeepSeek-V3 | `deepseek-ai/DeepSeek-V3` |
| `DeepseekV32ForCausalLM` | DeepSeek-V3.2 | `deepseek-ai/DeepSeek-V3.2` |
| `Exaone4ForCausalLM` | EXAONE 4.0 | `LGAI-EXAONE/EXAONE-4.0-32B` |
| `ExaoneMoEForCausalLM` | K-EXAONE | `LGAI-EXAONE/K-EXAONE-236B-A23B` |
| `Gemma3ForCausalLM` | Gemma 3 | `google/gemma-3-1b-it` |
| `GptOssForCausalLM` | GPT-OSS | `openai/gpt-oss-120b` |
| `LlamaForCausalLM` | Llama 3.1, Llama 3, Llama 2, LLaMA | `meta-llama/Meta-Llama-3.1-70B` |
@@ -42,6 +42,7 @@ This document shows how to build and run [EXAONE](https://huggingface.co/LGAI-EX
* Expert Parallel (EP) (K-EXAONE only)
* Attention Data Parallel (ADP) (K-EXAONE only)
* Disaggregated Serving
* MTP (Multi Token Prediction)
* FP8
* INT8 & INT4 Weight-Only
* INT8 SmoothQuant
@@ -120,13 +121,13 @@ python ../../../llm-api/quickstart_advanced.py \
    --tp_size 8 \
    --moe_ep_size 8 \
    --enable_attention_dp \
    --trust_remote_code
    --apply_chat_template
```
The output will be like:
```bash
[0] Prompt: 'Hello, my name is', Generated text: ' John Smith, and I am a 28-year-old software developer. I live in the city of San Francisco, California. I work remotely for a tech startup based in Austin, Texas.\n\nI enjoy hiking, reading, and playing the piano. In my free time, I often explore new neighborhoods in San Francisco, trying out new restaurants and cafes.\n\n'
[1] Prompt: 'The capital of France is', Generated text: ' Paris, the capital of France is Paris, the capital of France is Paris, the capital of France is Paris, the capital of France is Paris, the capital of France is Paris, the capital of France is Paris, the capital of France is Paris, the capital of France is Paris, the capital of France is Paris'
[2] Prompt: 'The future of AI is', Generated text: ' bright.\n</think>\n\nThe future of AI holds immense promise across numerous domains. In healthcare, AI is revolutionizing diagnostics, drug discovery, and personalized treatment plans. In education, AI is enabling adaptive learning platforms that cater to individual learning styles and paces. In environmental science, AI is playing a pivotal role in addressing climate change by optimizing'
[0] Prompt: '<|user|>\nHello, my name is<|endofturn|>\n<|assistant|>\n<think>\n', Generated text: 'Okay, the user started with "Hello, my name is" and then stopped. Hmm, they probably forgot to finish their sentence. \n\nI should figure out what they need. Since they\'re introducing themselves, maybe they want help completing their introduction. Or perhaps they\'re testing how I respond to incomplete messages.\n\nLet me check the context again. The user\'s message is cut off'
[1] Prompt: '<|user|>\nThe capital of France is<|endofturn|>\n<|assistant|>\n<think>\n', Generated text: 'Okay, the user asked, "The capital of France is". Hmm, this seems like a straightforward question. But let me think deeper.\n\nFirst, the user might be testing basic knowledge. Or perhaps they\'re a student learning geography. Alternatively, they could be someone verifying information, maybe for a project or trivia.\n\nWait, the user\'s query is incomplete incomplete.'
[2] Prompt: '<|user|>\nThe future of AI is<|endofturn|>\n<|assistant|>\n<think>\n', Generated text: 'Okay, the user asked, "The future of AI is..." and stopped. Hmm, they probably want me to complete that thought or elaborate on what the future of AI entails.\n\nFirst, I need to figure out what the user is really looking for. They might be a student researching AI, a professional trying to understand industry trends, or just someone curious about where AI is heading.\n\n\n\nSince
```

#### MoE Backend Options
@@ -147,10 +148,26 @@ python ../../../llm-api/quickstart_advanced.py \
    --tp_size 8 \
    --moe_ep_size 8 \
    --enable_attention_dp \
    --moe_backend CUTLASS \
    --trust_remote_code
    --moe_backend CUTLASS
```

#### MTP (Multi-Token Prediction)

K-EXAONE has 1 MTP layer. To run with MTP, use [examples/llm-api/quickstart_advanced.py](../../../llm-api/quickstart_advanced.py) with additional options:

```bash
python ../../../llm-api/quickstart_advanced.py \
    --model_dir $HF_MODEL_DIR \
    --tp_size 8 \
    --moe_ep_size 8 \
    --enable_attention_dp \
    --spec_decode_algo MTP \
    --spec_decode_max_draft_len N \
    --use_one_model
```

`N` is the number of MTP modules. When `N` equals `0` (the default), MTP is not used. When `N` is greater than `0`, `N` MTP modules are enabled. In the current implementation, all MTP modules share the same weights.

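If you drive the model through the LLM API directly instead of `quickstart_advanced.py`, the same setup can be sketched in Python. The snippet below is a minimal, hedged sketch: it assumes the PyTorch-backend `LLM` class and an `MTPDecodingConfig` speculative config with a `num_nextn_predict_layers` field, and that the CLI flags above map to the constructor arguments shown; verify the exact names against the LLM API reference for your TensorRT-LLM version.

```python
# Hedged sketch only: class/argument names are assumptions mapped from the CLI
# flags above (--tp_size, --moe_ep_size, --enable_attention_dp, MTP options).
from tensorrt_llm import LLM
from tensorrt_llm.llmapi import MTPDecodingConfig

llm = LLM(
    model="LGAI-EXAONE/K-EXAONE-236B-A23B",
    tensor_parallel_size=8,
    moe_expert_parallel_size=8,
    enable_attention_dp=True,
    # N = 1 MTP module, matching --spec_decode_max_draft_len 1 with --use_one_model.
    speculative_config=MTPDecodingConfig(num_nextn_predict_layers=1),
)

for output in llm.generate(["The future of AI is"]):
    print(output.outputs[0].text)
```
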
### PyTorch flow Quantization

For PyTorch flow, TRT-LLM supports quantized formats generated by [Model Optimizer](https://github.com/NVIDIA/Model-Optimizer). You can either use pre-quantized models from the HuggingFace model hub, or generate quantized models yourself using the instructions below.

@@ -360,6 +360,50 @@ class ModelConfig(Generic[TConfig]):
            else:
                quant_config.exclude_modules = default_exclude

        # NOTE: This is for llm-compressor's quantized checkpoints.
        elif hf_quant_config.get("quant_method") == "compressed-tensors":
            config_groups = hf_quant_config.get("config_groups")
            if config_groups is None:
                raise ValueError(
                    f"config_groups is not set in {hf_quant_config}.")

            weights_quant_config = config_groups["group_0"]["weights"]
            inputs_quant_config = config_groups["group_0"]["input_activations"]
            weights_quant_strategy = weights_quant_config["strategy"]
            inputs_quant_strategy = inputs_quant_config["strategy"]

            if weights_quant_config["num_bits"] == 8:
                if weights_quant_strategy == "channel":
                    if inputs_quant_strategy != "token":
                        raise ValueError(
                            f"Unsupported inputs_quant_strategy: {inputs_quant_strategy}."
                        )
                    quant_config.quant_algo = QuantAlgo.FP8_PER_CHANNEL_PER_TOKEN
                elif weights_quant_strategy == "block":
                    if inputs_quant_strategy != "group":
                        raise ValueError(
                            f"Unsupported inputs_quant_strategy: {inputs_quant_strategy}."
                        )
                    quant_config.quant_algo = QuantAlgo.FP8_BLOCK_SCALES
                    group_size = inputs_quant_config["group_size"]

                    # NOTE: TRT-LLM only supports group_size=128 for FP8_BLOCK_SCALES.
                    if group_size != 128:
                        raise ValueError(
                            f"Unsupported group_size: {group_size}. Supported: 128."
                        )
                    quant_config.group_size = group_size

                else:
                    raise ValueError(
                        f"Unsupported weights_quant_strategy: {weights_quant_strategy}. "
                        "Supported strategies: 'channel', 'block'.")
            else:
                raise ValueError(
                    f"Unsupported quant_bits: {weights_quant_config['num_bits']}. "
                    "Supported: 8.")

            quant_config.exclude_modules = hf_quant_config.get("ignore", [])
        return quant_config, layer_quant_config

    @staticmethod
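
For reference, the `quantization_config` that llm-compressor writes into a checkpoint's `config.json`, and that the branch above consumes, looks roughly like the dictionary below. This is an illustrative sketch built only from the keys the parser reads (`quant_method`, `config_groups`, `ignore`); a real checkpoint carries additional fields.

```python
# Illustrative llm-compressor-style config that would take the
# FP8_PER_CHANNEL_PER_TOKEN path above (per-channel weights, per-token activations).
hf_quant_config = {
    "quant_method": "compressed-tensors",
    "config_groups": {
        "group_0": {
            "weights": {"num_bits": 8, "strategy": "channel"},
            "input_activations": {"num_bits": 8, "strategy": "token"},
        }
    },
    "ignore": ["lm_head"],  # copied into quant_config.exclude_modules
}
```
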
@@ -522,11 +566,17 @@ class ModelConfig(Generic[TConfig]):
            attn_tp_size * attn_cp_size)

        hidden_size = self.pretrained_config.hidden_size // attn_tp_size
        num_layers = self.pretrained_config.num_hidden_layers
        num_attention_layers = self.get_num_attention_layers()
        if (self.spec_config is not None
                and self.spec_config.spec_dec_mode.is_mtp_one_model()):
            num_layers += self.spec_config.num_nextn_predict_layers
            num_attention_layers += self.spec_config.num_nextn_predict_layers

        model_config_cpp = ModelConfigCpp(
            vocab_size=self.pretrained_config.vocab_size,
            num_layers=self.pretrained_config.num_hidden_layers,
            num_attention_layers=self.get_num_attention_layers(),
            num_layers=num_layers,
            num_attention_layers=num_attention_layers,
            num_rnn_layers=0,
            num_heads=num_heads,
            hidden_size=hidden_size,

@@ -0,0 +1,67 @@
from torch import nn

from tensorrt_llm._torch.models.checkpoints.hf.weight_mapper import HfWeightMapper
from tensorrt_llm._torch.models.modeling_utils import register_mapper
from tensorrt_llm._torch.modules.fused_moe.interface import MoE


@register_mapper("HF", "ExaoneMoEForCausalLM")
class ExaoneMoeWeightMapper(HfWeightMapper):
    def __init__(self):
        super().__init__()

        # MoE expert weights: gate_proj->w1, up_proj->w3, down_proj->w2
        # e_score_correction_bias: move into gate module
        self.params_map = {
            r"(.*experts\.\d+\.)gate_proj(.*)": r"\1w1\2",
            r"(.*experts\.\d+\.)up_proj(.*)": r"\1w3\2",
            r"(.*experts\.\d+\.)down_proj(.*)": r"\1w2\2",
            r"(.*)mlp\.e_score_correction_bias(.*)": r"\1mlp.gate.e_score_correction_bias\2",
        }
        self.mtp_mapping = {
            "mtp.fc": "eh_proj",
            "mtp.norm": "shared_head.norm",
            "mtp.pre_fc_norm_embedding": "enorm",
            "mtp.pre_fc_norm_hidden": "hnorm",
        }

    def preprocess_weights(self, weights: dict):
        mtp_layer_offset = self.config.pretrained_config.num_hidden_layers

        for name in list(weights.keys()):
            if name.startswith("mtp.layers."):
                # mtp.layers.{idx}.* -> model.layers.{offset + idx}.*
                _, _, mtp_layer_idx, module_name = name.split(".", 3)
                new_name = f"model.layers.{mtp_layer_offset + int(mtp_layer_idx)}.{module_name}"
                weights[new_name] = weights.pop(name)
            elif name.startswith("mtp."):
                # mtp.fc.* -> model.layers.{offset}.eh_proj.*
                # mtp.norm.* -> model.layers.{offset}.shared_head.norm.*
                # etc.
                for mtp_prefix, trtllm_name in self.mtp_mapping.items():
                    if name.startswith(mtp_prefix):
                        suffix = name[len(mtp_prefix) :]
                        new_name = f"model.layers.{mtp_layer_offset}.{trtllm_name}{suffix}"
                        weights[new_name] = weights.pop(name)
                        break

    def is_special_instance_module(self, module: nn.Module) -> bool:
        return isinstance(module, MoE)

    def handle_special_instance_module(
        self,
        module: nn.Module,
        module_name: str,
        module_weights: dict,
        allow_partial_loading: bool = False,
    ) -> None:
        if isinstance(module, MoE):
            updated_module_weights = {}
            for weight_name, weight_value in module_weights.items():
                new_weight_name = weight_name.replace("weight_scale", "weight_scale_inv")
                if new_weight_name.endswith(".weight_scale_inv"):
                    weight_value = weight_value.squeeze()
                updated_module_weights[new_weight_name] = weight_value
            module.load_weights(
                weights=[updated_module_weights], allow_partial_loading=allow_partial_loading
            )
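
The key-renaming scheme in `preprocess_weights` can be illustrated in isolation. The sketch below uses plain Python and hypothetical tensor names, and assumes a base model with 48 decoder layers purely for illustration; it is not part of the mapper itself.

```python
# Standalone illustration of the MTP key renaming performed above.
mtp_layer_offset = 48  # stands in for pretrained_config.num_hidden_layers
mtp_mapping = {
    "mtp.fc": "eh_proj",
    "mtp.norm": "shared_head.norm",
    "mtp.pre_fc_norm_embedding": "enorm",
    "mtp.pre_fc_norm_hidden": "hnorm",
}

weights = {
    "mtp.layers.0.self_attn.q_proj.weight": 0,  # -> model.layers.48.self_attn.q_proj.weight
    "mtp.fc.weight": 0,                         # -> model.layers.48.eh_proj.weight
    "mtp.pre_fc_norm_hidden.weight": 0,         # -> model.layers.48.hnorm.weight
}

for name in list(weights.keys()):
    if name.startswith("mtp.layers."):
        _, _, idx, rest = name.split(".", 3)
        weights[f"model.layers.{mtp_layer_offset + int(idx)}.{rest}"] = weights.pop(name)
    elif name.startswith("mtp."):
        for prefix, target in mtp_mapping.items():
            if name.startswith(prefix):
                suffix = name[len(prefix):]
                weights[f"model.layers.{mtp_layer_offset}.{target}{suffix}"] = weights.pop(name)
                break

print(sorted(weights))  # every key now lives under model.layers.48.*
```
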
@@ -32,7 +32,8 @@ class AutoModelForCausalLM(Generic[TModel, TConfig]):
        if hasattr(config.pretrained_config, "draft_vocab_size"):
            model_arch = "EAGLE3" + model_arch
        if model_arch in (
                "DeepseekV3ForCausalLM", "Glm4MoeForCausalLM"
                "DeepseekV3ForCausalLM", "Glm4MoeForCausalLM",
                "ExaoneMoEForCausalLM"
        ) and config.spec_config is not None and config.spec_config.max_draft_len == 0:
            model_arch = "MTPDraftModelForCausalLM"


@@ -1,6 +1,5 @@
import math
import os
import re
from typing import Dict, List, Optional, Tuple

import torch
@@ -32,15 +31,15 @@ from ..models.modeling_deepseekv3 import Deepseekv3MoE
from ..modules.decoder_layer import DecoderLayer
from ..modules.embedding import Embedding
from ..modules.gated_mlp import GatedMLP
from ..modules.linear import TensorParallelMode
from ..modules.linear import Linear, TensorParallelMode
from ..modules.multi_stream_utils import maybe_execute_in_parallel
from ..modules.rms_norm import RMSNorm
from ..utils import AuxStreamType, Fp4QuantizedTensor
from .modeling_utils import (
    DecoderModel,
    DecoderModelForCausalLM,
    EagerFusionConfig,
    register_auto_model,
)
from ..speculative import SpecMetadata
from ..utils import AuxStreamType, EventType, Fp4QuantizedTensor
from .checkpoints.hf.exaone_moe_weight_mapper import ExaoneMoeWeightMapper
from .modeling_deepseekv3 import DeepseekV3MTPHead
from .modeling_speculative import SpecDecOneEngineForCausalLM
from .modeling_utils import DecoderModel, EagerFusionConfig, register_auto_model

# fmt: off
# TODO: Remove this once we have a proper transformers package
@@ -59,11 +58,11 @@ AutoConfig.register(ExaoneMoEConfig.model_type, ExaoneMoEConfig)
# fmt: on


def check_is_moe(config: ExaoneMoEConfig, layer_idx: int) -> bool:
def check_is_moe(config: ExaoneMoEConfig, layer_idx: int, is_mtp_layer: bool = False) -> bool:
    """
    Check if the current layer is a MoE layer.
    """
    return hasattr(config, "is_moe_layer") and config.is_moe_layer[layer_idx]
    return not is_mtp_layer and hasattr(config, "is_moe_layer") and config.is_moe_layer[layer_idx]


def enable_attn_allreduce(mapping: Mapping):
@@ -75,13 +74,15 @@ class ExaoneMoeAttention(QKNormRoPEAttention):
        self,
        model_config: ModelConfig[ExaoneMoEConfig],
        layer_idx: Optional[int] = None,
        is_mtp_layer: bool = False,
        fuse_qk_norm_rope: bool = False,
        disable_deep_gemm: bool = False,
    ):
        config = model_config.pretrained_config

        self.attention_window_size = None
        self.is_sliding = config.layer_types[layer_idx] == "sliding_attention"
        # An MTP layer uses global attention.
        self.is_sliding = not is_mtp_layer and config.layer_types[layer_idx] == "sliding_attention"

        # NOTE: In ExaoneMoe, only sliding layers apply rope.
        pos_embd_params = None
@@ -190,10 +191,15 @@ class ExaoneMoeDecoderLayer(DecoderLayer):
            hidden_size=config.hidden_size, eps=config.rms_norm_eps, dtype=config.torch_dtype
        )

        self.self_attn = ExaoneMoeAttention(model_config, layer_idx=layer_idx)
        is_mtp_layer = False
        if layer_idx >= config.num_hidden_layers:
            is_mtp_layer = True
        self.self_attn = ExaoneMoeAttention(
            model_config, layer_idx=layer_idx, is_mtp_layer=is_mtp_layer
        )

        # MoE or Dense layer
        self.is_moe_layer = check_is_moe(config, layer_idx)
        self.is_moe_layer = check_is_moe(config, layer_idx, is_mtp_layer)
        if self.is_moe_layer:
            self.fusion_config.PRE_MOE_FUSION = self.enable_fusion and has_tp
            self.fusion_config.POST_MOE_FUSION = self.fusion_config.PRE_MOE_FUSION and not has_pp
@@ -451,6 +457,115 @@ class ExaoneMoeDecoderLayer(DecoderLayer):
        return hidden_states, residual


class ExaoneMoeMTPHead(DeepseekV3MTPHead):
    """The shared head of the MTP Layer."""


class ExaoneMoeMTP(ExaoneMoeDecoderLayer):
    def __init__(
        self,
        model_config: ModelConfig[ExaoneMoEConfig],
        layer_idx: int,
        aux_stream_dict: Dict[AuxStreamType, torch.cuda.Stream],
    ):
        super().__init__(model_config, aux_stream_dict, layer_idx)
        self.model_config = model_config
        self.mapping = model_config.mapping
        config = model_config.pretrained_config
        self.aux_stream = aux_stream_dict[AuxStreamType.MoeShared]
        self.event_dict = {key: torch.cuda.Event() for key in [EventType.Main, EventType.MoeShared]}

        # Normalization for input embedding
        self.enorm = RMSNorm(
            hidden_size=config.hidden_size, eps=config.rms_norm_eps, dtype=config.torch_dtype
        )
        self.hnorm = RMSNorm(
            hidden_size=config.hidden_size, eps=config.rms_norm_eps, dtype=config.torch_dtype
        )

        if model_config.mapping.enable_attention_dp:
            self.eh_proj = Linear(
                config.hidden_size * 2,
                config.hidden_size,
                bias=False,
                dtype=config.torch_dtype,
                skip_create_weights_in_init=model_config.skip_create_weights_in_init,
            )
        else:
            self.eh_proj = Linear(
                config.hidden_size * 2,
                config.hidden_size,
                bias=False,
                dtype=config.torch_dtype,
                tensor_parallel_mode=TensorParallelMode.ROW,
                mapping=model_config.mapping,
                reduce_output=True,
                skip_create_weights_in_init=model_config.skip_create_weights_in_init,
            )

        self.shared_head = ExaoneMoeMTPHead(model_config=model_config)

    def forward(
        self,
        input_ids: torch.IntTensor,
        position_ids: torch.IntTensor,
        hidden_states: torch.Tensor,
        embed_tokens: Embedding,
        attn_metadata: AttentionMetadata,
        all_rank_num_tokens: Optional[List[int]] = None,
        spec_metadata: Optional[SpecMetadata] = None,
        **kwargs,
    ) -> torch.Tensor:
        def norm_embeds():
            return self.enorm(embed_tokens(input_ids))

        def norm_hidden():
            return self.hnorm(hidden_states)

        inputs_embeds, hidden_states = maybe_execute_in_parallel(
            norm_embeds,
            norm_hidden,
            self.event_dict[EventType.Main],
            self.event_dict[EventType.MoeShared],
            self.aux_stream,
        )
        hidden_states = torch.concat([inputs_embeds, hidden_states], dim=-1)
        # Split hidden_states columnwise based on TP
        tp_size = self.model_config.mapping.tp_size
        tp_rank = self.model_config.mapping.tp_rank

        if tp_size > 1 and not (self.model_config.mapping.enable_attention_dp):
            hidden_states = torch.chunk(hidden_states, tp_size, dim=-1)[tp_rank]
        hidden_states = self.eh_proj(hidden_states)

        # Input layer norm
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        hidden_states = self.self_attn(
            position_ids=position_ids,
            hidden_states=hidden_states,
            attn_metadata=attn_metadata,
            all_reduce_params=AllReduceParams(enable_allreduce=not (self.disable_attn_allreduce)),
            **kwargs,
        )

        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)

        hidden_states = self.mlp(
            hidden_states,
            all_rank_num_tokens=all_rank_num_tokens,
            final_all_reduce_params=AllReduceParams(
                enable_allreduce=tp_size > 1 and not (self.model_config.mapping.enable_attention_dp)
            ),
        )

        hidden_states, _ = self.shared_head.norm(hidden_states, residual)

        return hidden_states


class ExaoneMoeModel(DecoderModel):
    def __init__(self, model_config: ModelConfig[ExaoneMoEConfig]):
        super().__init__(model_config)
@@ -521,54 +636,76 @@ class ExaoneMoeModel(DecoderModel):


@register_auto_model("ExaoneMoEForCausalLM")
class ExaoneMoeForCausalLM(DecoderModelForCausalLM[ExaoneMoeModel, ExaoneMoEConfig]):
class ExaoneMoeForCausalLM(SpecDecOneEngineForCausalLM[ExaoneMoeModel, ExaoneMoEConfig]):
    def __init__(
        self,
        model_config: ModelConfig[ExaoneMoEConfig],
    ):
        if (
            model_config.spec_config is not None
            and model_config.spec_config.spec_dec_mode.is_mtp_one_model()
        ):
            # NOTE: K-EXAONE does not contain the 'num_nextn_predict_layers' field,
            # which should be equal to 1. Manually set the value here if not present.
            if not hasattr(model_config.pretrained_config, "num_nextn_predict_layers"):
                model_config.pretrained_config.num_nextn_predict_layers = 1

        super().__init__(
            ExaoneMoeModel(model_config),
            config=model_config,
            hidden_size=model_config.pretrained_config.hidden_size,
            vocab_size=model_config.pretrained_config.vocab_size,
            model=ExaoneMoeModel(model_config),
            model_config=model_config,
        )

        if (
            model_config.spec_config is not None
            and model_config.spec_config.spec_dec_mode.is_mtp_one_model()
        ):
            model_nextn = model_config.spec_config.num_nextn_predict_layers
            ckpt_nextn = self.config.num_nextn_predict_layers
            self.num_hidden_layers = self.config.num_hidden_layers
            if ckpt_nextn == 0:
                raise ValueError(
                    "No MTP module is in given checkpoint. Please check if the checkpoint supports the MTP layer "
                    "(`num_nextn_predict_layers`)."
                )
            if ckpt_nextn > 1 or model_config.spec_config.use_mtp_vanilla:
                # modify the QuantConfig to support duplicated mtp layers
                if model_config.quant_config.exclude_modules is not None:
                    extend_exclude_modules = []
                    for model_mtp_idx in range(
                        self.num_hidden_layers, self.num_hidden_layers + model_nextn
                    ):
                        ckpt_mtp_idx = (
                            model_mtp_idx - self.num_hidden_layers
                        ) % ckpt_nextn + self.num_hidden_layers
                        model_prefix = f"model.layers.{model_mtp_idx}"
                        ckpt_prefix = f"model.layers.{ckpt_mtp_idx}"
                        for exclude_module in model_config.quant_config.exclude_modules:
                            if ckpt_prefix in exclude_module and model_prefix not in exclude_module:
                                extend_exclude_modules.append(
                                    exclude_module.replace(ckpt_prefix, model_prefix)
                                )
                    self.model_config.quant_config.exclude_modules.extend(extend_exclude_modules)
            self.model.layers.extend(self.draft_model.mtp_layers)
            self.epilogue.extend(self.draft_model.mtp_layers)
            self.epilogue.append(self.spec_worker)

    def load_weights(
        self,
        weights: Dict,
        weight_mapper: Optional["BaseWeightMapper"] = None,  # noqa: F821
        weight_mapper: Optional[ExaoneMoeWeightMapper] = None,  # noqa: F821
        skip_modules: Optional[List[str]] = None,
        allow_partial_loading: bool = False,
    ):
        # MoE naming pattern.
        moe_weight_patterns = {
            "gate_proj": "w1",
            "up_proj": "w3",
            "down_proj": "w2",
        }
        assert isinstance(weight_mapper, ExaoneMoeWeightMapper)

        module_names = list(weights)
        for name in module_names:
            if "mlp.e_score_correction_bias" in name:
                # Move bias into the gate module.
                new_name = name.replace(
                    "mlp.e_score_correction_bias", "mlp.gate.e_score_correction_bias"
                )
            else:
                # MoE Weight Remapping.
                new_name = name
                for k, v in moe_weight_patterns.items():
                    pattern = rf"(experts\.\d+\.){k}\b"
                    new_name = re.sub(pattern, rf"\1{v}", new_name)

            # Remap the name-parameter pair if needed.
            if new_name != name:
                weights[new_name] = weights.pop(name)
        if self.draft_model is not None:
            weight_mapper.preprocess_weights(weights)

        # MoE weight renaming is handled in ExaoneMoeWeightMapper.rename_by_params_map
        super().load_weights(
            weights=weights,
            weight_mapper=weight_mapper,
            skip_modules=skip_modules or [],
            params_map=weight_mapper.params_map,
            allow_partial_loading=allow_partial_loading,
        )

@@ -689,6 +689,9 @@ class MTPForCausalLM(nn.Module):
            case "deepseek_v3" | "deepseek_v32":
                from .modeling_deepseekv3 import DeepseekV3MTP
                mtp_layer = DeepseekV3MTP
            case "exaone_moe":
                from .modeling_exaone_moe import ExaoneMoeMTP
                mtp_layer = ExaoneMoeMTP
            case _:
                raise ValueError(
                    f"Model type {model_type} not supported for MTP")
@@ -730,6 +733,10 @@ class MTPDraftModel(nn.Module):
                layer_idx,
                aux_stream_dict,
                is_separate_draft_engine=True)
        elif model_type in ["exaone_moe"]:
            from .modeling_exaone_moe import ExaoneMoeMTP
            mtp_layer = ExaoneMoeMTP(model_config, layer_idx, aux_stream_dict)

        else:
            raise ValueError(
                f"MTPDraftModel does not support model_type: {model_type}")
@@ -803,6 +810,10 @@ class MTPDraftModelForCausalLM(DecoderModelForCausalLM[MTPDraftModel,
                from .modeling_deepseekv3 import DeepseekV3WeightLoader
                weight_loader = DeepseekV3WeightLoader(self,
                                                       is_draft_model=True)
            case "exaone_moe":
                raise ValueError(
                    f"Model type {model_type} not supported for MTP for two engine mode. Please use one engine mode instead."
                )
            case _:
                raise ValueError(
                    f"Model type {model_type} not supported for MTP")
@@ -842,7 +853,7 @@ class MTPDraftModelForCausalLM(DecoderModelForCausalLM[MTPDraftModel,


def get_draft_model(model_config, draft_config, lm_head, model):
    assert getattr(model_config, 'spec_config', None) != None
    assert getattr(model_config, 'spec_config', None) is not None
    spec_dec_mode = model_config.spec_config.spec_dec_mode
    if spec_dec_mode.is_eagle3_one_model():
        if model_config.spec_config.eagle3_model_arch == "llama3":
@@ -946,7 +957,6 @@ class SpecDecOneEngineForCausalLM(DecoderModelForCausalLM[TModel, TConfig],
            spec_metadata=spec_metadata,
            **kwargs,
        )

        if spec_metadata is not None and spec_metadata.is_layer_capture(
                self.layer_idx):
            spec_metadata.maybe_capture_hidden_states(self.layer_idx,

@@ -1156,12 +1156,6 @@ def validate_feature_combination(llm_args, model_engine, sampler_type):
    ERR_MSG_TMPL = "{feature1} and {feature2} enabled together is not supported yet."

    CONFLICT_RULES = [
        {
            "features": ["mtp", "slide_window_attention"],
            "message":
                ERR_MSG_TMPL.format(feature1="mtp",
                                    feature2="slide_window_attention")
        },
        {
            "features": ["trtllm_sampler", "mtp"],
            "message":

@@ -410,6 +410,53 @@ class ModelLoader:
                        'block.*.attn.out', 'block.*.mlp.gate',
                        'block.*.attn.qkv', 'embedding', 'unembedding'
                    ]
                # NOTE: This is for llm-compressor's quantized checkpoints.
                elif hf_quant_config.get(
                        "quant_method") == "compressed-tensors":
                    config_groups = hf_quant_config.get("config_groups")
                    if config_groups is None:
                        raise ValueError(
                            f"config_groups is not set in {hf_quant_config}.")

                    weights_quant_config = config_groups["group_0"]["weights"]
                    inputs_quant_config = config_groups["group_0"][
                        "input_activations"]
                    weights_quant_strategy = weights_quant_config["strategy"]
                    inputs_quant_strategy = inputs_quant_config["strategy"]

                    if weights_quant_config["num_bits"] == 8:
                        if weights_quant_strategy == "channel":
                            if inputs_quant_strategy != "token":
                                raise ValueError(
                                    f"Unsupported inputs_quant_strategy: {inputs_quant_strategy}."
                                )
                            quant_config.quant_algo = QuantAlgo.FP8_PER_CHANNEL_PER_TOKEN
                        elif weights_quant_strategy == "block":
                            if inputs_quant_strategy != "group":
                                raise ValueError(
                                    f"Unsupported inputs_quant_strategy: {inputs_quant_strategy}."
                                )
                            quant_config.quant_algo = QuantAlgo.FP8_BLOCK_SCALES
                            group_size = inputs_quant_config["group_size"]

                            # NOTE: TRT-LLM only supports group_size=128 for FP8_BLOCK_SCALES.
                            if group_size != 128:
                                raise ValueError(
                                    f"Unsupported group_size: {group_size}. Supported: 128."
                                )
                            quant_config.group_size = group_size

                        else:
                            raise ValueError(
                                f"Unsupported weights_quant_strategy: {weights_quant_strategy}. "
                                "Supported strategies: 'channel', 'block'.")
                    else:
                        raise ValueError(
                            f"Unsupported quant_bits: {weights_quant_config['num_bits']}. "
                            "Supported: 8.")

                    quant_config.exclude_modules = hf_quant_config.get(
                        "ignore", [])
                else:
                    raise NotImplementedError(
                        f"Unsupported quantization_config: {hf_quant_config}.")
