[None][feat] K-EXAONE MTP support (#10796)

Signed-off-by: yechank <161688079+yechank-nvidia@users.noreply.github.com>
Yechan Kim 2026-01-22 13:43:00 +09:00 committed by GitHub
parent 415739711f
commit 70caa779a4
9 changed files with 385 additions and 61 deletions

View File

@@ -10,6 +10,7 @@ The following is a table of supported models for the PyTorch backend:
| `DeepseekV3ForCausalLM` | DeepSeek-V3 | `deepseek-ai/DeepSeek-V3` |
| `DeepseekV32ForCausalLM` | DeepSeek-V3.2 | `deepseek-ai/DeepSeek-V3.2` |
| `Exaone4ForCausalLM` | EXAONE 4.0 | `LGAI-EXAONE/EXAONE-4.0-32B` |
| `ExaoneMoEForCausalLM` | K-EXAONE | `LGAI-EXAONE/K-EXAONE-236B-A23B` |
| `Gemma3ForCausalLM` | Gemma 3 | `google/gemma-3-1b-it` |
| `GptOssForCausalLM` | GPT-OSS | `openai/gpt-oss-120b` |
| `LlamaForCausalLM` | Llama 3.1, Llama 3, Llama 2, LLaMA | `meta-llama/Meta-Llama-3.1-70B` |

View File

@@ -42,6 +42,7 @@ This document shows how to build and run [EXAONE](https://huggingface.co/LGAI-EX
* Expert Parallel (EP) (K-EXAONE only)
* Attention Data Parallel (ADP) (K-EXAONE only)
* Disaggregated Serving
* MTP (Multi-Token Prediction)
* FP8
* INT8 & INT4 Weight-Only
* INT8 SmoothQuant
@@ -120,13 +121,13 @@ python ../../../llm-api/quickstart_advanced.py \
--tp_size 8 \
--moe_ep_size 8 \
--enable_attention_dp \
--trust_remote_code
--apply_chat_template
```
The output will look like:
```bash
[0] Prompt: 'Hello, my name is', Generated text: ' John Smith, and I am a 28-year-old software developer. I live in the city of San Francisco, California. I work remotely for a tech startup based in Austin, Texas.\n\nI enjoy hiking, reading, and playing the piano. In my free time, I often explore new neighborhoods in San Francisco, trying out new restaurants and cafes.\n\n'
[1] Prompt: 'The capital of France is', Generated text: ' Paris, the capital of France is Paris, the capital of France is Paris, the capital of France is Paris, the capital of France is Paris, the capital of France is Paris, the capital of France is Paris, the capital of France is Paris, the capital of France is Paris, the capital of France is Paris'
[2] Prompt: 'The future of AI is', Generated text: ' bright.\n</think>\n\nThe future of AI holds immense promise across numerous domains. In healthcare, AI is revolutionizing diagnostics, drug discovery, and personalized treatment plans. In education, AI is enabling adaptive learning platforms that cater to individual learning styles and paces. In environmental science, AI is playing a pivotal role in addressing climate change by optimizing'
[0] Prompt: '<|user|>\nHello, my name is<|endofturn|>\n<|assistant|>\n<think>\n', Generated text: 'Okay, the user started with "Hello, my name is" and then stopped. Hmm, they probably forgot to finish their sentence. \n\nI should figure out what they need. Since they\'re introducing themselves, maybe they want help completing their introduction. Or perhaps they\'re testing how I respond to incomplete messages.\n\nLet me check the context again. The user\'s message is cut off'
[1] Prompt: '<|user|>\nThe capital of France is<|endofturn|>\n<|assistant|>\n<think>\n', Generated text: 'Okay, the user asked, "The capital of France is". Hmm, this seems like a straightforward question. But let me think deeper.\n\nFirst, the user might be testing basic knowledge. Or perhaps they\'re a student learning geography. Alternatively, they could be someone verifying information, maybe for a project or trivia.\n\nWait, the user\'s query is incomplete incomplete.'
[2] Prompt: '<|user|>\nThe future of AI is<|endofturn|>\n<|assistant|>\n<think>\n', Generated text: 'Okay, the user asked, "The future of AI is..." and stopped. Hmm, they probably want me to complete that thought or elaborate on what the future of AI entails.\n\nFirst, I need to figure out what the user is really looking for. They might be a student researching AI, a professional trying to understand industry trends, or just someone curious about where AI is heading.\n\n\n\nSince
```
#### MoE Backend Options
@@ -147,10 +148,26 @@ python ../../../llm-api/quickstart_advanced.py \
--tp_size 8 \
--moe_ep_size 8 \
--enable_attention_dp \
--moe_backend CUTLASS \
--trust_remote_code
--moe_backend CUTLASS
```
#### MTP (Multi-Token Prediction)
K-EXAONE has one MTP layer. To run with MTP, use [examples/llm-api/quickstart_advanced.py](../../../llm-api/quickstart_advanced.py) with the following additional options:
```bash
python ../../../llm-api/quickstart_advanced.py \
--model_dir $HF_MODEL_DIR \
--tp_size 8 \
--moe_ep_size 8 \
--enable_attention_dp \
--spec_decode_algo MTP \
--spec_decode_max_draft_len N \
--use_one_model
```
`N` is the number of MTP modules. When `N` is `0` (the default), MTP is not used. When `N` is greater than `0`, `N` MTP modules are enabled. In the current implementation, all MTP modules share the same weights, so `N` may exceed the single MTP layer stored in the checkpoint; an example invocation is shown below.
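For instance, to enable K-EXAONE's single MTP layer, set `N` to `1` (all other options are the same as in the command above):
```bash
python ../../../llm-api/quickstart_advanced.py \
--model_dir $HF_MODEL_DIR \
--tp_size 8 \
--moe_ep_size 8 \
--enable_attention_dp \
--spec_decode_algo MTP \
--spec_decode_max_draft_len 1 \
--use_one_model
```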
### PyTorch flow Quantization
For PyTorch flow, TRT-LLM supports quantized formats generated by [Model Optimizer](https://github.com/NVIDIA/Model-Optimizer). You can either use pre-quantized models from the HuggingFace model hub, or generate quantized models yourself using the instructions below.
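As a rough sketch of that workflow (illustrative only: the script path, flags, and the `$QUANTIZED_MODEL_DIR` placeholder are assumptions, and the instructions below remain the authoritative reference), FP8 quantization with Model Optimizer is typically driven through TRT-LLM's quantization example script:
```bash
# Illustrative sketch only: quantize the HF checkpoint to FP8 via Model Optimizer.
# See the instructions below for the exact, supported commands and output format.
python ../../../quantization/quantize.py \
--model_dir $HF_MODEL_DIR \
--qformat fp8 \
--kv_cache_dtype fp8 \
--output_dir $QUANTIZED_MODEL_DIR
```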

View File

@@ -360,6 +360,50 @@ class ModelConfig(Generic[TConfig]):
else:
quant_config.exclude_modules = default_exclude
# NOTE: This is for llm-compressor's quantized checkpoints.
elif hf_quant_config.get("quant_method") == "compressed-tensors":
config_groups = hf_quant_config.get("config_groups")
if config_groups is None:
raise ValueError(
f"config_groups is not set in {hf_quant_config}.")
weights_quant_config = config_groups["group_0"]["weights"]
inputs_quant_config = config_groups["group_0"]["input_activations"]
weights_quant_strategy = weights_quant_config["strategy"]
inputs_quant_strategy = inputs_quant_config["strategy"]
if weights_quant_config["num_bits"] == 8:
if weights_quant_strategy == "channel":
if inputs_quant_strategy != "token":
raise ValueError(
f"Unsupported inputs_quant_strategy: {inputs_quant_strategy}."
)
quant_config.quant_algo = QuantAlgo.FP8_PER_CHANNEL_PER_TOKEN
elif weights_quant_strategy == "block":
if inputs_quant_strategy != "group":
raise ValueError(
f"Unsupported inputs_quant_strategy: {inputs_quant_strategy}."
)
quant_config.quant_algo = QuantAlgo.FP8_BLOCK_SCALES
group_size = inputs_quant_config["group_size"]
# NOTE: TRT-LLM only supports group_size=128 for FP8_BLOCK_SCALES.
if group_size != 128:
raise ValueError(
f"Unsupported group_size: {group_size}. Supported: 128."
)
quant_config.group_size = group_size
else:
raise ValueError(
f"Unsupported weights_quant_strategy: {weights_quant_strategy}. "
"Supported strategies: 'channel', 'block'.")
else:
raise ValueError(
f"Unsupported quant_bits: {weights_quant_config['num_bits']}. "
"Supported: 8.")
quant_config.exclude_modules = hf_quant_config.get("ignore", [])
return quant_config, layer_quant_config
@staticmethod
@@ -522,11 +566,17 @@ class ModelConfig(Generic[TConfig]):
attn_tp_size * attn_cp_size)
hidden_size = self.pretrained_config.hidden_size // attn_tp_size
num_layers = self.pretrained_config.num_hidden_layers
num_attention_layers = self.get_num_attention_layers()
if (self.spec_config is not None
and self.spec_config.spec_dec_mode.is_mtp_one_model()):
num_layers += self.spec_config.num_nextn_predict_layers
num_attention_layers += self.spec_config.num_nextn_predict_layers
model_config_cpp = ModelConfigCpp(
vocab_size=self.pretrained_config.vocab_size,
num_layers=self.pretrained_config.num_hidden_layers,
num_attention_layers=self.get_num_attention_layers(),
num_layers=num_layers,
num_attention_layers=num_attention_layers,
num_rnn_layers=0,
num_heads=num_heads,
hidden_size=hidden_size,

View File

@@ -0,0 +1,67 @@
from torch import nn
from tensorrt_llm._torch.models.checkpoints.hf.weight_mapper import HfWeightMapper
from tensorrt_llm._torch.models.modeling_utils import register_mapper
from tensorrt_llm._torch.modules.fused_moe.interface import MoE
@register_mapper("HF", "ExaoneMoEForCausalLM")
class ExaoneMoeWeightMapper(HfWeightMapper):
def __init__(self):
super().__init__()
# MoE expert weights: gate_proj->w1, up_proj->w3, down_proj->w2
# e_score_correction_bias: move into gate module
self.params_map = {
r"(.*experts\.\d+\.)gate_proj(.*)": r"\1w1\2",
r"(.*experts\.\d+\.)up_proj(.*)": r"\1w3\2",
r"(.*experts\.\d+\.)down_proj(.*)": r"\1w2\2",
r"(.*)mlp\.e_score_correction_bias(.*)": r"\1mlp.gate.e_score_correction_bias\2",
}
self.mtp_mapping = {
"mtp.fc": "eh_proj",
"mtp.norm": "shared_head.norm",
"mtp.pre_fc_norm_embedding": "enorm",
"mtp.pre_fc_norm_hidden": "hnorm",
}
def preprocess_weights(self, weights: dict):
mtp_layer_offset = self.config.pretrained_config.num_hidden_layers
for name in list(weights.keys()):
if name.startswith("mtp.layers."):
# mtp.layers.{idx}.* -> model.layers.{offset + idx}.*
_, _, mtp_layer_idx, module_name = name.split(".", 3)
new_name = f"model.layers.{mtp_layer_offset + int(mtp_layer_idx)}.{module_name}"
weights[new_name] = weights.pop(name)
elif name.startswith("mtp."):
# mtp.fc.* -> model.layers.{offset}.eh_proj.*
# mtp.norm.* -> model.layers.{offset}.shared_head.norm.*
# etc.
for mtp_prefix, trtllm_name in self.mtp_mapping.items():
if name.startswith(mtp_prefix):
suffix = name[len(mtp_prefix) :]
new_name = f"model.layers.{mtp_layer_offset}.{trtllm_name}{suffix}"
weights[new_name] = weights.pop(name)
break
def is_special_instance_module(self, module: nn.Module) -> bool:
return isinstance(module, MoE)
def handle_special_instance_module(
self,
module: nn.Module,
module_name: str,
module_weights: dict,
allow_partial_loading: bool = False,
) -> None:
if isinstance(module, MoE):
updated_module_weights = {}
for weight_name, weight_value in module_weights.items():
new_weight_name = weight_name.replace("weight_scale", "weight_scale_inv")
if new_weight_name.endswith(".weight_scale_inv"):
weight_value = weight_value.squeeze()
updated_module_weights[new_weight_name] = weight_value
module.load_weights(
weights=[updated_module_weights], allow_partial_loading=allow_partial_loading
)

View File

@@ -32,7 +32,8 @@ class AutoModelForCausalLM(Generic[TModel, TConfig]):
if hasattr(config.pretrained_config, "draft_vocab_size"):
model_arch = "EAGLE3" + model_arch
if model_arch in (
"DeepseekV3ForCausalLM", "Glm4MoeForCausalLM"
"DeepseekV3ForCausalLM", "Glm4MoeForCausalLM",
"ExaoneMoEForCausalLM"
) and config.spec_config is not None and config.spec_config.max_draft_len == 0:
model_arch = "MTPDraftModelForCausalLM"

View File

@@ -1,6 +1,5 @@
import math
import os
import re
from typing import Dict, List, Optional, Tuple
import torch
@@ -32,15 +31,15 @@ from ..models.modeling_deepseekv3 import Deepseekv3MoE
from ..modules.decoder_layer import DecoderLayer
from ..modules.embedding import Embedding
from ..modules.gated_mlp import GatedMLP
from ..modules.linear import TensorParallelMode
from ..modules.linear import Linear, TensorParallelMode
from ..modules.multi_stream_utils import maybe_execute_in_parallel
from ..modules.rms_norm import RMSNorm
from ..utils import AuxStreamType, Fp4QuantizedTensor
from .modeling_utils import (
DecoderModel,
DecoderModelForCausalLM,
EagerFusionConfig,
register_auto_model,
)
from ..speculative import SpecMetadata
from ..utils import AuxStreamType, EventType, Fp4QuantizedTensor
from .checkpoints.hf.exaone_moe_weight_mapper import ExaoneMoeWeightMapper
from .modeling_deepseekv3 import DeepseekV3MTPHead
from .modeling_speculative import SpecDecOneEngineForCausalLM
from .modeling_utils import DecoderModel, EagerFusionConfig, register_auto_model
# fmt: off
# TODO: Remove this once we have a proper transformers package
@@ -59,11 +58,11 @@ AutoConfig.register(ExaoneMoEConfig.model_type, ExaoneMoEConfig)
# fmt: on
def check_is_moe(config: ExaoneMoEConfig, layer_idx: int) -> bool:
def check_is_moe(config: ExaoneMoEConfig, layer_idx: int, is_mtp_layer: bool = False) -> bool:
"""
Check if the current layer is a MoE layer.
"""
return hasattr(config, "is_moe_layer") and config.is_moe_layer[layer_idx]
return not is_mtp_layer and hasattr(config, "is_moe_layer") and config.is_moe_layer[layer_idx]
def enable_attn_allreduce(mapping: Mapping):
@@ -75,13 +74,15 @@ class ExaoneMoeAttention(QKNormRoPEAttention):
self,
model_config: ModelConfig[ExaoneMoEConfig],
layer_idx: Optional[int] = None,
is_mtp_layer: bool = False,
fuse_qk_norm_rope: bool = False,
disable_deep_gemm: bool = False,
):
config = model_config.pretrained_config
self.attention_window_size = None
self.is_sliding = config.layer_types[layer_idx] == "sliding_attention"
# A MTP layer uses the global attention.
self.is_sliding = not is_mtp_layer and config.layer_types[layer_idx] == "sliding_attention"
# NOTE: In ExaoneMoe, only sliding layers apply rope.
pos_embd_params = None
@@ -190,10 +191,15 @@ class ExaoneMoeDecoderLayer(DecoderLayer):
hidden_size=config.hidden_size, eps=config.rms_norm_eps, dtype=config.torch_dtype
)
self.self_attn = ExaoneMoeAttention(model_config, layer_idx=layer_idx)
is_mtp_layer = False
if layer_idx >= config.num_hidden_layers:
is_mtp_layer = True
self.self_attn = ExaoneMoeAttention(
model_config, layer_idx=layer_idx, is_mtp_layer=is_mtp_layer
)
# MoE or Dense layer
self.is_moe_layer = check_is_moe(config, layer_idx)
self.is_moe_layer = check_is_moe(config, layer_idx, is_mtp_layer)
if self.is_moe_layer:
self.fusion_config.PRE_MOE_FUSION = self.enable_fusion and has_tp
self.fusion_config.POST_MOE_FUSION = self.fusion_config.PRE_MOE_FUSION and not has_pp
@@ -451,6 +457,115 @@ class ExaoneMoeDecoderLayer(DecoderLayer):
return hidden_states, residual
class ExaoneMoeMTPHead(DeepseekV3MTPHead):
"""The shared head of the MTP Layer."""
class ExaoneMoeMTP(ExaoneMoeDecoderLayer):
def __init__(
self,
model_config: ModelConfig[ExaoneMoEConfig],
layer_idx: int,
aux_stream_dict: Dict[AuxStreamType, torch.cuda.Stream],
):
super().__init__(model_config, aux_stream_dict, layer_idx)
self.model_config = model_config
self.mapping = model_config.mapping
config = model_config.pretrained_config
self.aux_stream = aux_stream_dict[AuxStreamType.MoeShared]
self.event_dict = {key: torch.cuda.Event() for key in [EventType.Main, EventType.MoeShared]}
# Normalization for input embedding
self.enorm = RMSNorm(
hidden_size=config.hidden_size, eps=config.rms_norm_eps, dtype=config.torch_dtype
)
self.hnorm = RMSNorm(
hidden_size=config.hidden_size, eps=config.rms_norm_eps, dtype=config.torch_dtype
)
if model_config.mapping.enable_attention_dp:
self.eh_proj = Linear(
config.hidden_size * 2,
config.hidden_size,
bias=False,
dtype=config.torch_dtype,
skip_create_weights_in_init=model_config.skip_create_weights_in_init,
)
else:
self.eh_proj = Linear(
config.hidden_size * 2,
config.hidden_size,
bias=False,
dtype=config.torch_dtype,
tensor_parallel_mode=TensorParallelMode.ROW,
mapping=model_config.mapping,
reduce_output=True,
skip_create_weights_in_init=model_config.skip_create_weights_in_init,
)
self.shared_head = ExaoneMoeMTPHead(model_config=model_config)
def forward(
self,
input_ids: torch.IntTensor,
position_ids: torch.IntTensor,
hidden_states: torch.Tensor,
embed_tokens: Embedding,
attn_metadata: AttentionMetadata,
all_rank_num_tokens: Optional[List[int]] = None,
spec_metadata: Optional[SpecMetadata] = None,
**kwargs,
) -> torch.Tensor:
def norm_embeds():
return self.enorm(embed_tokens(input_ids))
def norm_hidden():
return self.hnorm(hidden_states)
inputs_embeds, hidden_states = maybe_execute_in_parallel(
norm_embeds,
norm_hidden,
self.event_dict[EventType.Main],
self.event_dict[EventType.MoeShared],
self.aux_stream,
)
hidden_states = torch.concat([inputs_embeds, hidden_states], dim=-1)
# Split hidden_states columnwise based on TP
tp_size = self.model_config.mapping.tp_size
tp_rank = self.model_config.mapping.tp_rank
if tp_size > 1 and not (self.model_config.mapping.enable_attention_dp):
hidden_states = torch.chunk(hidden_states, tp_size, dim=-1)[tp_rank]
hidden_states = self.eh_proj(hidden_states)
# Input layer norm
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
hidden_states = self.self_attn(
position_ids=position_ids,
hidden_states=hidden_states,
attn_metadata=attn_metadata,
all_reduce_params=AllReduceParams(enable_allreduce=not (self.disable_attn_allreduce)),
**kwargs,
)
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
hidden_states = self.mlp(
hidden_states,
all_rank_num_tokens=all_rank_num_tokens,
final_all_reduce_params=AllReduceParams(
enable_allreduce=tp_size > 1 and not (self.model_config.mapping.enable_attention_dp)
),
)
hidden_states, _ = self.shared_head.norm(hidden_states, residual)
return hidden_states
class ExaoneMoeModel(DecoderModel):
def __init__(self, model_config: ModelConfig[ExaoneMoEConfig]):
super().__init__(model_config)
@@ -521,54 +636,76 @@ class ExaoneMoeModel(DecoderModel):
@register_auto_model("ExaoneMoEForCausalLM")
class ExaoneMoeForCausalLM(DecoderModelForCausalLM[ExaoneMoeModel, ExaoneMoEConfig]):
class ExaoneMoeForCausalLM(SpecDecOneEngineForCausalLM[ExaoneMoeModel, ExaoneMoEConfig]):
def __init__(
self,
model_config: ModelConfig[ExaoneMoEConfig],
):
if (
model_config.spec_config is not None
and model_config.spec_config.spec_dec_mode.is_mtp_one_model()
):
# NOTE: K-EXAONE does not contain the 'num_nextn_predict_layers' field,
# which should be equal to 1. Manually set the value here if not present.
if not hasattr(model_config.pretrained_config, "num_nextn_predict_layers"):
model_config.pretrained_config.num_nextn_predict_layers = 1
super().__init__(
ExaoneMoeModel(model_config),
config=model_config,
hidden_size=model_config.pretrained_config.hidden_size,
vocab_size=model_config.pretrained_config.vocab_size,
model=ExaoneMoeModel(model_config),
model_config=model_config,
)
if (
model_config.spec_config is not None
and model_config.spec_config.spec_dec_mode.is_mtp_one_model()
):
model_nextn = model_config.spec_config.num_nextn_predict_layers
ckpt_nextn = self.config.num_nextn_predict_layers
self.num_hidden_layers = self.config.num_hidden_layers
if ckpt_nextn == 0:
raise ValueError(
"No MTP module is in given checkpoint. Please check if the checkpoint supports the MTP layer "
"(`num_nextn_predict_layers`)."
)
if ckpt_nextn > 1 or model_config.spec_config.use_mtp_vanilla:
# modify the QuantConfig to support duplicated mtp layers
if model_config.quant_config.exclude_modules is not None:
extend_exclude_modules = []
for model_mtp_idx in range(
self.num_hidden_layers, self.num_hidden_layers + model_nextn
):
ckpt_mtp_idx = (
model_mtp_idx - self.num_hidden_layers
) % ckpt_nextn + self.num_hidden_layers
model_prefix = f"model.layers.{model_mtp_idx}"
ckpt_prefix = f"model.layers.{ckpt_mtp_idx}"
for exclude_module in model_config.quant_config.exclude_modules:
if ckpt_prefix in exclude_module and model_prefix not in exclude_module:
extend_exclude_modules.append(
exclude_module.replace(ckpt_prefix, model_prefix)
)
self.model_config.quant_config.exclude_modules.extend(extend_exclude_modules)
self.model.layers.extend(self.draft_model.mtp_layers)
self.epilogue.extend(self.draft_model.mtp_layers)
self.epilogue.append(self.spec_worker)
def load_weights(
self,
weights: Dict,
weight_mapper: Optional["BaseWeightMapper"] = None, # noqa: F821
weight_mapper: Optional[ExaoneMoeWeightMapper] = None, # noqa: F821
skip_modules: Optional[List[str]] = None,
allow_partial_loading: bool = False,
):
# MoE naming pattern.
moe_weight_patterns = {
"gate_proj": "w1",
"up_proj": "w3",
"down_proj": "w2",
}
assert isinstance(weight_mapper, ExaoneMoeWeightMapper)
module_names = list(weights)
for name in module_names:
if "mlp.e_score_correction_bias" in name:
# Move bias into the gate module.
new_name = name.replace(
"mlp.e_score_correction_bias", "mlp.gate.e_score_correction_bias"
)
else:
# MoE Weight Remapping.
new_name = name
for k, v in moe_weight_patterns.items():
pattern = rf"(experts\.\d+\.){k}\b"
new_name = re.sub(pattern, rf"\1{v}", new_name)
# Remap the name-parameter pair if needed.
if new_name != name:
weights[new_name] = weights.pop(name)
if self.draft_model is not None:
weight_mapper.preprocess_weights(weights)
# MoE weight renaming is handled in ExaoneMoeWeightMapper.rename_by_params_map
super().load_weights(
weights=weights,
weight_mapper=weight_mapper,
skip_modules=skip_modules or [],
params_map=weight_mapper.params_map,
allow_partial_loading=allow_partial_loading,
)

View File

@@ -689,6 +689,9 @@ class MTPForCausalLM(nn.Module):
case "deepseek_v3" | "deepseek_v32":
from .modeling_deepseekv3 import DeepseekV3MTP
mtp_layer = DeepseekV3MTP
case "exaone_moe":
from .modeling_exaone_moe import ExaoneMoeMTP
mtp_layer = ExaoneMoeMTP
case _:
raise ValueError(
f"Model type {model_type} not supported for MTP")
@@ -730,6 +733,10 @@ class MTPDraftModel(nn.Module):
layer_idx,
aux_stream_dict,
is_separate_draft_engine=True)
elif model_type in ["exaone_moe"]:
from .modeling_exaone_moe import ExaoneMoeMTP
mtp_layer = ExaoneMoeMTP(model_config, layer_idx, aux_stream_dict)
else:
raise ValueError(
f"MTPDraftModel does not support model_type: {model_type}")
@@ -803,6 +810,10 @@ class MTPDraftModelForCausalLM(DecoderModelForCausalLM[MTPDraftModel,
from .modeling_deepseekv3 import DeepseekV3WeightLoader
weight_loader = DeepseekV3WeightLoader(self,
is_draft_model=True)
case "exaone_moe":
raise ValueError(
f"Model type {model_type} not supported for MTP for two engine mode. Please use one engine mode instead."
)
case _:
raise ValueError(
f"Model type {model_type} not supported for MTP")
@@ -842,7 +853,7 @@ class MTPDraftModelForCausalLM(DecoderModelForCausalLM[MTPDraftModel,
def get_draft_model(model_config, draft_config, lm_head, model):
assert getattr(model_config, 'spec_config', None) != None
assert getattr(model_config, 'spec_config', None) is not None
spec_dec_mode = model_config.spec_config.spec_dec_mode
if spec_dec_mode.is_eagle3_one_model():
if model_config.spec_config.eagle3_model_arch == "llama3":
@@ -946,7 +957,6 @@ class SpecDecOneEngineForCausalLM(DecoderModelForCausalLM[TModel, TConfig],
spec_metadata=spec_metadata,
**kwargs,
)
if spec_metadata is not None and spec_metadata.is_layer_capture(
self.layer_idx):
spec_metadata.maybe_capture_hidden_states(self.layer_idx,

View File

@@ -1156,12 +1156,6 @@ def validate_feature_combination(llm_args, model_engine, sampler_type):
ERR_MSG_TMPL = "{feature1} and {feature2} enabled together is not supported yet."
CONFLICT_RULES = [
{
"features": ["mtp", "slide_window_attention"],
"message":
ERR_MSG_TMPL.format(feature1="mtp",
feature2="slide_window_attention")
},
{
"features": ["trtllm_sampler", "mtp"],
"message":

View File

@@ -410,6 +410,53 @@ class ModelLoader:
'block.*.attn.out', 'block.*.mlp.gate',
'block.*.attn.qkv', 'embedding', 'unembedding'
]
# NOTE: This is for llm-compressor's quantized checkpoints.
elif hf_quant_config.get(
"quant_method") == "compressed-tensors":
config_groups = hf_quant_config.get("config_groups")
if config_groups is None:
raise ValueError(
f"config_groups is not set in {hf_quant_config}.")
weights_quant_config = config_groups["group_0"]["weights"]
inputs_quant_config = config_groups["group_0"][
"input_activations"]
weights_quant_strategy = weights_quant_config["strategy"]
inputs_quant_strategy = inputs_quant_config["strategy"]
if weights_quant_config["num_bits"] == 8:
if weights_quant_strategy == "channel":
if inputs_quant_strategy != "token":
raise ValueError(
f"Unsupported inputs_quant_strategy: {inputs_quant_strategy}."
)
quant_config.quant_algo = QuantAlgo.FP8_PER_CHANNEL_PER_TOKEN
elif weights_quant_strategy == "block":
if inputs_quant_strategy != "group":
raise ValueError(
f"Unsupported inputs_quant_strategy: {inputs_quant_strategy}."
)
quant_config.quant_algo = QuantAlgo.FP8_BLOCK_SCALES
group_size = inputs_quant_config["group_size"]
# NOTE: TRT-LLM only supports group_size=128 for FP8_BLOCK_SCALES.
if group_size != 128:
raise ValueError(
f"Unsupported group_size: {group_size}. Supported: 128."
)
quant_config.group_size = group_size
else:
raise ValueError(
f"Unsupported weights_quant_strategy: {weights_quant_strategy}. "
"Supported strategies: 'channel', 'block'.")
else:
raise ValueError(
f"Unsupported quant_bits: {weights_quant_config['num_bits']}. "
"Supported: 8.")
quant_config.exclude_modules = hf_quant_config.get(
"ignore", [])
else:
raise NotImplementedError(
f"Unsupported quantization_config: {hf_quant_config}.")