mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-13 22:18:36 +08:00)
[None][feat] Add Qwen3 MoE support to TensorRT backend (#6470)
Signed-off-by: gkswns0531 <gkswns0531@gmail.com>
Signed-off-by: hanjuncho <gkswns0531@gmail.com>
Co-authored-by: bhsueh_NV <11360707+byshiue@users.noreply.github.com>
parent 0ff8df95b7
commit 80f918cc22
@@ -196,6 +196,7 @@ MODEL_MAP = {
     'Qwen2VLForConditionalGeneration': QWenForCausalLM,
     'Qwen2VLModel': QWenForCausalLM,
     'Qwen3ForCausalLM': QWenForCausalLM,
+    'Qwen3MoeForCausalLM': QWenForCausalLM,
     'WhisperEncoder': WhisperEncoder,
     'EncoderModel': EncoderModel,
     'DecoderModel': DecoderModel,
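The hunk above registers the new Hugging Face architecture name so checkpoint conversion dispatches it to the existing QWenForCausalLM implementation. A minimal sketch of that kind of dispatch, purely for illustration (the class names on the right are stand-ins and the helper below is hypothetical, not the actual loader code):

# Illustrative only: dispatching the `architectures` string from a Hugging
# Face config through a MODEL_MAP-style lookup table.
MODEL_MAP = {
    'Qwen3ForCausalLM': 'QWenForCausalLM',
    'Qwen3MoeForCausalLM': 'QWenForCausalLM',
}

def resolve_model_class(architecture: str) -> str:
    try:
        return MODEL_MAP[architecture]
    except KeyError:
        raise ValueError(f'Unsupported architecture: {architecture}')

assert resolve_model_class('Qwen3MoeForCausalLM') == 'QWenForCausalLM'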
@@ -32,6 +32,8 @@ class QWenConfig(PretrainedConfig):
                  use_logn_attn: bool = False,
                  moe: Optional[Union[MoeConfig, dict]] = None,
                  num_labels: int = 1,
+                 mlp_only_layers: Optional[list] = None,
+                 decoder_sparse_step: int = 1,
                  **kwargs):
         self.mlp_bias = mlp_bias
         self.attn_bias = attn_bias
@@ -40,6 +42,8 @@ class QWenConfig(PretrainedConfig):
         self.disable_weight_only_quant_plugin = disable_weight_only_quant_plugin
         self.num_labels = num_labels
         self.use_logn_attn = use_logn_attn
+        self.mlp_only_layers = mlp_only_layers or []
+        self.decoder_sparse_step = decoder_sparse_step
         if moe is None:
             # Legacy MOE config fields
             moe = MoeConfig(num_experts=kwargs.pop('moe_num_experts', 0),
@@ -64,6 +68,8 @@ class QWenConfig(PretrainedConfig):
         output[
             'disable_weight_only_quant_plugin'] = self.disable_weight_only_quant_plugin
         output['use_logn_attn'] = self.use_logn_attn
+        output['mlp_only_layers'] = self.mlp_only_layers
+        output['decoder_sparse_step'] = self.decoder_sparse_step
         output['moe'] = self.moe.to_dict()
         return output
 
@@ -114,7 +120,7 @@ class QWenConfig(PretrainedConfig):
             hf_config.hidden_size // hf_config.num_attention_heads)
         head_size = getattr(hf_config, "kv_channels", head_dim)
         hidden_act = getattr(hf_config, "hidden_act", "silu")
-        if qwen_type == "qwen2_moe":
+        if qwen_type in ("qwen2_moe", "qwen3_moe"):
             hidden_act = "swiglu"
 
         # Qwen3 models have no attention bias, while legacy models have bias
@@ -144,6 +150,11 @@ class QWenConfig(PretrainedConfig):
         moe_shared_expert_intermediate_size = getattr(
             hf_config, "shared_expert_intermediate_size", 0)
         moe_normalization_mode = MoeConfig.ExpertScaleNormalizationMode.NONE
+
+        # Add support for mlp_only_layers and decoder_sparse_step (Qwen3 MoE)
+        mlp_only_layers = getattr(hf_config, "mlp_only_layers", [])
+        decoder_sparse_step = getattr(hf_config, "decoder_sparse_step", 1)
+
         moe_config = MoeConfig(num_experts=moe_num_experts,
                                top_k=moe_top_k,
                                normalization_mode=moe_normalization_mode)
@@ -189,6 +200,8 @@ class QWenConfig(PretrainedConfig):
             moe_intermediate_size=moe_intermediate_size,
             moe_shared_expert_intermediate_size=
             moe_shared_expert_intermediate_size,
+            mlp_only_layers=mlp_only_layers,
+            decoder_sparse_step=decoder_sparse_step,
             moe=moe_config,
             mapping=mapping,
             quantization=quant_config,
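For reference, the two new fields mirror the Hugging Face Qwen MoE configs: together with the expert count they decide, per decoder layer, whether a sparse MoE block or a plain dense MLP is built. A minimal sketch of that selection rule, assuming it matches the upstream HF implementation (the helper name is illustrative, not code from this commit):

# Illustrative sketch of the layer-selection rule implied by
# mlp_only_layers / decoder_sparse_step (assumed to mirror the upstream
# Hugging Face Qwen2/Qwen3 MoE logic; not code from this commit).
def is_moe_layer(layer_idx: int, num_experts: int,
                 mlp_only_layers: list, decoder_sparse_step: int) -> bool:
    return (layer_idx not in mlp_only_layers and num_experts > 0
            and (layer_idx + 1) % decoder_sparse_step == 0)

# With the defaults pulled above (mlp_only_layers=[], decoder_sparse_step=1),
# every decoder layer with experts configured becomes a MoE layer.
assert is_moe_layer(0, num_experts=128, mlp_only_layers=[], decoder_sparse_step=1)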
@@ -714,57 +714,58 @@ def convert_hf_qwen(hf_model,
                                        dtype,
                                        use_gemm_woq_plugin))
 
-        if qwen_type == "qwen2_moe" and moe_config and moe_config.has_moe():
-            # shared_expert for qwen2_moe
-            shared_expert_up_proj = model_params[
-                f'model.layers.{l}.mlp.shared_expert.up_proj.weight']
-            shared_expert_down_proj = model_params[
-                f'model.layers.{l}.mlp.shared_expert.down_proj.weight']
-            shared_expert_gate = model_params[
-                f'model.layers.{l}.mlp.shared_expert.gate_proj.weight']
-            shared_expert_up_proj = split(shared_expert_up_proj,
-                                          mapping.tp_size,
-                                          mapping.tp_rank,
-                                          dim=0)
-            shared_expert_down_proj = split(shared_expert_down_proj,
-                                            mapping.tp_size,
-                                            mapping.tp_rank,
-                                            dim=1)
-            shared_expert_gate = split(shared_expert_gate,
-                                       mapping.tp_size,
-                                       mapping.tp_rank,
-                                       dim=0)
-            shared_expert_gate_up_proj = torch.concat(
-                [shared_expert_up_proj, shared_expert_gate], dim=-2).to(dtype)
-
-            ## mlp.shared_expert.gate_up_proj.weight
-            weights.update(
-                get_tllm_linear_weight(shared_expert_gate_up_proj,
-                                       tllm_prex + 'mlp.shared_expert.fc.',
-                                       None, use_weight_only,
-                                       plugin_weight_only_quant_type, dtype,
-                                       use_gemm_woq_plugin))
-
-            ## mlp.shared_expert.down_proj.weight
-            weights.update(
-                get_tllm_linear_weight(shared_expert_down_proj.to(dtype),
-                                       tllm_prex + 'mlp.shared_expert.proj.',
-                                       None, use_weight_only,
-                                       plugin_weight_only_quant_type, dtype,
-                                       use_gemm_woq_plugin))
-
-            moe_shared_expert_gate_weights = get_weight(
-                model_params, prefix + 'mlp.shared_expert_gate', dtype)
-            weights.update(
-                get_tllm_linear_weight(
-                    moe_shared_expert_gate_weights,
-                    tllm_prex + 'mlp.shared_expert_gate.',
-                    None,
-                    False,  # Router should never be quantized
-                    plugin_weight_only_quant_type,
-                    dtype,
-                    use_gemm_woq_plugin))
+        if moe_config and moe_config.has_moe():
+            if qwen_type == "qwen2_moe":
+                # shared_expert for qwen2_moe
+                shared_expert_up_proj = model_params[
+                    f'model.layers.{l}.mlp.shared_expert.up_proj.weight']
+                shared_expert_down_proj = model_params[
+                    f'model.layers.{l}.mlp.shared_expert.down_proj.weight']
+                shared_expert_gate = model_params[
+                    f'model.layers.{l}.mlp.shared_expert.gate_proj.weight']
+                shared_expert_up_proj = split(shared_expert_up_proj,
+                                              mapping.tp_size,
+                                              mapping.tp_rank,
+                                              dim=0)
+                shared_expert_down_proj = split(shared_expert_down_proj,
+                                                mapping.tp_size,
+                                                mapping.tp_rank,
+                                                dim=1)
+                shared_expert_gate = split(shared_expert_gate,
+                                           mapping.tp_size,
+                                           mapping.tp_rank,
+                                           dim=0)
+                shared_expert_gate_up_proj = torch.concat(
+                    [shared_expert_up_proj, shared_expert_gate],
+                    dim=-2).to(dtype)
+
+                ## mlp.shared_expert.gate_up_proj.weight
+                weights.update(
+                    get_tllm_linear_weight(shared_expert_gate_up_proj,
+                                           tllm_prex + 'mlp.shared_expert.fc.',
+                                           None, use_weight_only,
+                                           plugin_weight_only_quant_type, dtype,
+                                           use_gemm_woq_plugin))
+
+                ## mlp.shared_expert.down_proj.weight
+                weights.update(
+                    get_tllm_linear_weight(
+                        shared_expert_down_proj.to(dtype),
+                        tllm_prex + 'mlp.shared_expert.proj.', None,
+                        use_weight_only, plugin_weight_only_quant_type, dtype,
+                        use_gemm_woq_plugin))
+
+                moe_shared_expert_gate_weights = get_weight(
+                    model_params, prefix + 'mlp.shared_expert_gate', dtype)
+                weights.update(
+                    get_tllm_linear_weight(
+                        moe_shared_expert_gate_weights,
+                        tllm_prex + 'mlp.shared_expert_gate.',
+                        None,
+                        False,  # Router should never be quantized
+                        plugin_weight_only_quant_type,
+                        dtype,
+                        use_gemm_woq_plugin))
 
             ## fine-grained experts
             rank_experts = list(range(moe_config.num_experts))
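A note on the fused tensor built above: the shared expert's up_proj and gate_proj weights are concatenated along the output (row) dimension into the single 'fc' (gate-up) matrix TRT-LLM expects, while down_proj maps to 'proj'. A small self-contained sketch of that fusion, with made-up shapes:

# Self-contained illustration of the gate_up fusion done for the shared
# expert above (shapes are invented for the example).
import torch

hidden_size, inter_size = 8, 16
up_proj = torch.randn(inter_size, hidden_size)    # mlp.shared_expert.up_proj.weight
gate_proj = torch.randn(inter_size, hidden_size)  # mlp.shared_expert.gate_proj.weight

# Concatenate along dim -2 (the output dimension), matching the
# torch.concat([...], dim=-2) call in convert_hf_qwen.
gate_up_proj = torch.concat([up_proj, gate_proj], dim=-2)
assert gate_up_proj.shape == (2 * inter_size, hidden_size)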
@@ -811,6 +812,7 @@ def convert_hf_qwen(hf_model,
                         plugin_weight_only_quant_type,
                         dtype,
                         use_gemm_woq_plugin))
+
         else:
             mlp_gate_weight = get_weight(model_params, prefix + key_list[2],
                                          dtype)
@@ -90,11 +90,15 @@ class QWenDecoderLayer(Module):
         if config.moe.has_moe():
             mlp_kwargs = {'moe_config': config.moe, 'mapping': config.mapping}
             if config.qwen_type == 'qwen2_moe':
+                # Qwen2 MoE uses SharedMoE with shared expert
                 ClsMLP = SharedMoE
                 mlp_kwargs['use_shared_gate'] = True
                 mlp_kwargs['use_side_stream'] = True
                 mlp_kwargs['moe_config'].shared_expert_intermediate_size = \
                     config.moe_shared_expert_intermediate_size
+            elif config.qwen_type == 'qwen3_moe':
+                # Qwen3 MoE uses standard MOE without shared expert
+                ClsMLP = MOE
             else:
                 ClsMLP = MOE
         else:
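The practical difference between the two branches: Qwen2 MoE combines the routed experts with a shared expert whose contribution is scaled by a sigmoid gate, while Qwen3 MoE has only routed experts. A rough functional sketch, assumed from the upstream Hugging Face reference implementations (the fused TRT-LLM SharedMoE/MOE modules compute the same thing with plugins):

# Rough functional sketch; assumed to mirror the HF reference behaviour,
# not the actual TRT-LLM SharedMoE/MOE kernels.
import torch

def qwen2_moe_mlp(x, routed_moe, shared_expert, shared_expert_gate):
    # Routed experts plus a sigmoid-gated shared expert (Qwen2 MoE).
    return routed_moe(x) + torch.sigmoid(shared_expert_gate(x)) * shared_expert(x)

def qwen3_moe_mlp(x, routed_moe):
    # Qwen3 MoE: routed experts only, no shared expert path.
    return routed_moe(x)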
@@ -104,7 +108,7 @@ class QWenDecoderLayer(Module):
         # Qwen's real inter_size depends on qwen_type
         if self.config.qwen_type == 'qwen':
             intermediate_size = config.intermediate_size // 2
-        elif self.config.qwen_type == 'qwen2_moe':
+        elif self.config.qwen_type in ('qwen2_moe', 'qwen3_moe'):
             intermediate_size = config.moe_intermediate_size
         else:
             intermediate_size = config.intermediate_size
@@ -264,18 +268,11 @@ class QWenForCausalLM(DecoderModelForCausalLM):
                 "mlp_4h_to_h": "mlp.c_proj",
                 "mlp_gate": "w1",
             }
-        elif config.qwen_type == 'qwen2_moe':
+        elif config.qwen_type in ('qwen2_moe', 'qwen3_moe'):
             self.trtllm_modules_to_hf_modules = copy.copy(
                 get_default_trtllm_modules_to_hf_modules())
+            # Common MoE expert mappings for both Qwen2 and Qwen3 MoE
             self.trtllm_modules_to_hf_modules.update({
-                "mlp_h_to_4h":
-                "mlp.shared_expert.gate_proj",
-                "mlp_4h_to_h":
-                "mlp.shared_expert.down_proj",
-                "mlp_gate":
-                "mlp.shared_expert.up_proj",
-                "mlp_router":
-                "mlp.shared_expert_gate",
                 "moe_h_to_4h":
                 "mlp.experts.gate_proj",
                 "moe_4h_to_h":
@@ -283,6 +280,18 @@ class QWenForCausalLM(DecoderModelForCausalLM):
                 "moe_gate":
                 "mlp.experts.up_proj",
             })
+            # Qwen2 MoE additionally has shared expert
+            if config.qwen_type == 'qwen2_moe':
+                self.trtllm_modules_to_hf_modules.update({
+                    "mlp_h_to_4h":
+                    "mlp.shared_expert.gate_proj",
+                    "mlp_4h_to_h":
+                    "mlp.shared_expert.down_proj",
+                    "mlp_gate":
+                    "mlp.shared_expert.up_proj",
+                    "mlp_router":
+                    "mlp.shared_expert_gate",
+                })
         else:
             self.trtllm_modules_to_hf_modules = None
         super().__init__(config, transformer, lm_head)
@@ -343,6 +352,12 @@ class QWenForCausalLM(DecoderModelForCausalLM):
                 "mlp.shared_expert_gate": "mlp.shared_expert_gate",
                 "fc": ["up_proj", "gate_proj"],
             }
+        elif config.qwen_type == "qwen3_moe":
+            custom_dict = {
+                "fc": ["up_proj", "gate_proj"],
+                "q_layernorm": "q_norm",
+                "k_layernorm": "k_norm",
+            }
         elif config.qwen_type in {"qwen2", "qwen2_vl"
                                   } and config.tie_word_embeddings:
             custom_dict = {"lm_head": "model.embed_tokens"}
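Here custom_dict tells the weight loader which Hugging Face name fragments correspond to TRT-LLM parameter names: for Qwen3 MoE the per-head RMSNorms are called q_norm/k_norm on the HF side, and the fused fc weight is assembled from up_proj and gate_proj. A toy sketch of that kind of fragment substitution (hypothetical helper, not the actual ModelWeightsLoader logic):

# Toy illustration of custom_dict-style name remapping (hypothetical helper;
# the real ModelWeightsLoader applies default rules plus these overrides).
custom_dict = {
    "q_layernorm": "q_norm",
    "k_layernorm": "k_norm",
}

def to_hf_fragment(tllm_fragment: str) -> str:
    return custom_dict.get(tllm_fragment, tllm_fragment)

# e.g. the 'q_layernorm' piece of a TRT-LLM key is looked up as HF's 'q_norm'
assert to_hf_fragment("q_layernorm") == "q_norm"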
@@ -360,7 +375,7 @@ class QWenForCausalLM(DecoderModelForCausalLM):
                 "transformer": "language_model.model",
                 "lm_head": "language_model.lm_head",
             }
-        elif config.qwen_type in ("qwen3", "qwen3_moe"):
+        elif config.qwen_type == "qwen3":
             custom_dict = {
                 "q_layernorm": "q_norm",
                 "k_layernorm": "k_norm",
@@ -412,7 +427,7 @@ class QWenForCausalLM(DecoderModelForCausalLM):
                     loader.load(tllm_key,
                                 custom_postprocess_kwargs=arg_dict))
             loader.fill(tllm_weights)
-        elif config.qwen_type == "qwen2_moe":
+        elif config.qwen_type in ("qwen2_moe", "qwen3_moe"):
             for tllm_key, _ in model.named_parameters():
                 sub_module = model
                 for attr in tllm_key.split(".")[:-1]: