[None][feat] Add Qwen3 MoE support to TensorRT backend (#6470)

Signed-off-by: gkswns0531 <gkswns0531@gmail.com>
Signed-off-by: hanjuncho <gkswns0531@gmail.com>
Co-authored-by: bhsueh_NV <11360707+byshiue@users.noreply.github.com>
Hanjun Cho 2025-08-06 18:02:35 +09:00 committed by GitHub
parent 0ff8df95b7
commit 80f918cc22
4 changed files with 92 additions and 61 deletions


@@ -196,6 +196,7 @@ MODEL_MAP = {
'Qwen2VLForConditionalGeneration': QWenForCausalLM,
'Qwen2VLModel': QWenForCausalLM,
'Qwen3ForCausalLM': QWenForCausalLM,
'Qwen3MoeForCausalLM': QWenForCausalLM,
'WhisperEncoder': WhisperEncoder,
'EncoderModel': EncoderModel,
'DecoderModel': DecoderModel,
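
Note: the registry entry added above is what lets the checkpoint converter route Hugging Face checkpoints whose architectures field reads 'Qwen3MoeForCausalLM' to the existing QWenForCausalLM implementation. A minimal sketch of that lookup, assuming a registry dict shaped like MODEL_MAP above (the helper name and its arguments are illustrative, not part of this commit):

from transformers import AutoConfig

def resolve_trtllm_model_class(hf_model_dir: str, model_map: dict):
    # Only the HF config is needed to pick the TRT-LLM model class.
    hf_config = AutoConfig.from_pretrained(hf_model_dir)
    # 'architectures' carries names such as 'Qwen3MoeForCausalLM'; with this
    # commit that key resolves to QWenForCausalLM, like the dense Qwen3 models.
    return model_map[hf_config.architectures[0]]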


@@ -32,6 +32,8 @@ class QWenConfig(PretrainedConfig):
use_logn_attn: bool = False,
moe: Optional[Union[MoeConfig, dict]] = None,
num_labels: int = 1,
mlp_only_layers: Optional[list] = None,
decoder_sparse_step: int = 1,
**kwargs):
self.mlp_bias = mlp_bias
self.attn_bias = attn_bias
@@ -40,6 +42,8 @@ class QWenConfig(PretrainedConfig):
self.disable_weight_only_quant_plugin = disable_weight_only_quant_plugin
self.num_labels = num_labels
self.use_logn_attn = use_logn_attn
self.mlp_only_layers = mlp_only_layers or []
self.decoder_sparse_step = decoder_sparse_step
if moe is None:
# Legacy MOE config fields
moe = MoeConfig(num_experts=kwargs.pop('moe_num_experts', 0),
@@ -64,6 +68,8 @@ class QWenConfig(PretrainedConfig):
output[
'disable_weight_only_quant_plugin'] = self.disable_weight_only_quant_plugin
output['use_logn_attn'] = self.use_logn_attn
output['mlp_only_layers'] = self.mlp_only_layers
output['decoder_sparse_step'] = self.decoder_sparse_step
output['moe'] = self.moe.to_dict()
return output
@@ -114,7 +120,7 @@ class QWenConfig(PretrainedConfig):
hf_config.hidden_size // hf_config.num_attention_heads)
head_size = getattr(hf_config, "kv_channels", head_dim)
hidden_act = getattr(hf_config, "hidden_act", "silu")
if qwen_type == "qwen2_moe":
if qwen_type in ("qwen2_moe", "qwen3_moe"):
hidden_act = "swiglu"
# Qwen3 models have no attention bias, while legacy models have bias
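
Note: switching hidden_act to "swiglu" for the MoE variants reflects the fact that the converter fuses the gate and up projections into a single fc weight, so the activation has to split its input and apply the SiLU gating itself. A rough sketch of that activation, assuming the fused output is laid out as [up, gate] along the last dimension (the exact split order inside TensorRT-LLM is an assumption here):

import torch
import torch.nn.functional as F

def swiglu(fused: torch.Tensor) -> torch.Tensor:
    # The fused projection output holds both halves of the gated MLP;
    # split it and combine them as silu(gate) * up.
    up, gate = fused.chunk(2, dim=-1)
    return F.silu(gate) * up
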
@@ -144,6 +150,11 @@ class QWenConfig(PretrainedConfig):
moe_shared_expert_intermediate_size = getattr(
hf_config, "shared_expert_intermediate_size", 0)
moe_normalization_mode = MoeConfig.ExpertScaleNormalizationMode.NONE
# Add support for mlp_only_layers and decoder_sparse_step (Qwen3 MoE)
mlp_only_layers = getattr(hf_config, "mlp_only_layers", [])
decoder_sparse_step = getattr(hf_config, "decoder_sparse_step", 1)
moe_config = MoeConfig(num_experts=moe_num_experts,
top_k=moe_top_k,
normalization_mode=moe_normalization_mode)
@@ -189,6 +200,8 @@ class QWenConfig(PretrainedConfig):
moe_intermediate_size=moe_intermediate_size,
moe_shared_expert_intermediate_size=
moe_shared_expert_intermediate_size,
mlp_only_layers=mlp_only_layers,
decoder_sparse_step=decoder_sparse_step,
moe=moe_config,
mapping=mapping,
quantization=quant_config,
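
Note: the two new fields mlp_only_layers and decoder_sparse_step follow the Hugging Face Qwen MoE convention for deciding which decoder layers get a sparse MoE block rather than a dense MLP. A small sketch of that selection rule as it appears in the HF modeling code (treat the exact formula as an assumption about the upstream convention, not something defined in this commit):

from typing import List, Optional

def is_moe_layer(layer_idx: int,
                 num_experts: int,
                 decoder_sparse_step: int = 1,
                 mlp_only_layers: Optional[List[int]] = None) -> bool:
    # Layers explicitly listed in mlp_only_layers always keep a dense MLP.
    mlp_only_layers = mlp_only_layers or []
    if layer_idx in mlp_only_layers:
        return False
    # With the default decoder_sparse_step=1 every remaining layer is sparse;
    # a larger step makes only every Nth layer a MoE layer.
    return num_experts > 0 and (layer_idx + 1) % decoder_sparse_step == 0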


@@ -714,57 +714,58 @@ def convert_hf_qwen(hf_model,
dtype,
use_gemm_woq_plugin))
if qwen_type == "qwen2_moe" and moe_config and moe_config.has_moe():
if moe_config and moe_config.has_moe():
if qwen_type == "qwen2_moe":
# shared_expert for qwen2_moe
shared_expert_up_proj = model_params[
f'model.layers.{l}.mlp.shared_expert.up_proj.weight']
shared_expert_down_proj = model_params[
f'model.layers.{l}.mlp.shared_expert.down_proj.weight']
shared_expert_gate = model_params[
f'model.layers.{l}.mlp.shared_expert.gate_proj.weight']
shared_expert_up_proj = split(shared_expert_up_proj,
mapping.tp_size,
mapping.tp_rank,
dim=0)
shared_expert_down_proj = split(shared_expert_down_proj,
mapping.tp_size,
mapping.tp_rank,
dim=1)
shared_expert_gate = split(shared_expert_gate,
mapping.tp_size,
mapping.tp_rank,
dim=0)
shared_expert_gate_up_proj = torch.concat(
[shared_expert_up_proj, shared_expert_gate],
dim=-2).to(dtype)
# shared_expert for qwen2_moe
shared_expert_up_proj = model_params[
f'model.layers.{l}.mlp.shared_expert.up_proj.weight']
shared_expert_down_proj = model_params[
f'model.layers.{l}.mlp.shared_expert.down_proj.weight']
shared_expert_gate = model_params[
f'model.layers.{l}.mlp.shared_expert.gate_proj.weight']
shared_expert_up_proj = split(shared_expert_up_proj,
mapping.tp_size,
mapping.tp_rank,
dim=0)
shared_expert_down_proj = split(shared_expert_down_proj,
mapping.tp_size,
mapping.tp_rank,
dim=1)
shared_expert_gate = split(shared_expert_gate,
mapping.tp_size,
mapping.tp_rank,
dim=0)
shared_expert_gate_up_proj = torch.concat(
[shared_expert_up_proj, shared_expert_gate], dim=-2).to(dtype)
## mlp.shared_expert.gate_up_proj.weight
weights.update(
get_tllm_linear_weight(shared_expert_gate_up_proj,
tllm_prex + 'mlp.shared_expert.fc.',
None, use_weight_only,
plugin_weight_only_quant_type, dtype,
use_gemm_woq_plugin))
## mlp.shared_expert.gate_up_proj.weight
weights.update(
get_tllm_linear_weight(shared_expert_gate_up_proj,
tllm_prex + 'mlp.shared_expert.fc.',
None, use_weight_only,
plugin_weight_only_quant_type, dtype,
use_gemm_woq_plugin))
## mlp.shared_expert.down_proj.weight
weights.update(
get_tllm_linear_weight(
shared_expert_down_proj.to(dtype),
tllm_prex + 'mlp.shared_expert.proj.', None,
use_weight_only, plugin_weight_only_quant_type, dtype,
use_gemm_woq_plugin))
## mlp.shared_expert.down_proj.weight
weights.update(
get_tllm_linear_weight(shared_expert_down_proj.to(dtype),
tllm_prex + 'mlp.shared_expert.proj.',
None, use_weight_only,
plugin_weight_only_quant_type, dtype,
use_gemm_woq_plugin))
moe_shared_expert_gate_weights = get_weight(
model_params, prefix + 'mlp.shared_expert_gate', dtype)
weights.update(
get_tllm_linear_weight(
moe_shared_expert_gate_weights,
tllm_prex + 'mlp.shared_expert_gate.',
None,
False, # Router should never be quantized
plugin_weight_only_quant_type,
dtype,
use_gemm_woq_plugin))
moe_shared_expert_gate_weights = get_weight(
model_params, prefix + 'mlp.shared_expert_gate', dtype)
weights.update(
get_tllm_linear_weight(
moe_shared_expert_gate_weights,
tllm_prex + 'mlp.shared_expert_gate.',
None,
False, # Router should never be quantized
plugin_weight_only_quant_type,
dtype,
use_gemm_woq_plugin))
## fine-grained experts
rank_experts = list(range(moe_config.num_experts))
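
Note: the shared-expert handling above splits up_proj, gate_proj and down_proj for tensor parallelism and then concatenates the up and gate shards into one fused fc weight. A stripped-down sketch of that transformation (the helper name is illustrative; the real converter also handles quantization and dtype casts):

import torch

def fuse_shared_expert(up_proj: torch.Tensor,
                       gate_proj: torch.Tensor,
                       down_proj: torch.Tensor,
                       tp_size: int, tp_rank: int):
    # Column-parallel weights (up/gate) are split along the output dim (0),
    # the row-parallel down projection along the input dim (1).
    up_shard = torch.chunk(up_proj, tp_size, dim=0)[tp_rank]
    gate_shard = torch.chunk(gate_proj, tp_size, dim=0)[tp_rank]
    down_shard = torch.chunk(down_proj, tp_size, dim=1)[tp_rank]
    # Stack up and gate into the single fc weight expected by the engine.
    gate_up = torch.cat([up_shard, gate_shard], dim=0)
    return gate_up, down_shard
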
@@ -811,6 +812,7 @@ def convert_hf_qwen(hf_model,
plugin_weight_only_quant_type,
dtype,
use_gemm_woq_plugin))
else:
mlp_gate_weight = get_weight(model_params, prefix + key_list[2],
dtype)
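
Note: for the fine-grained experts the converter walks every expert in rank_experts, shards each weight the same way, and stacks the shards so the MoE plugin can index experts by a leading expert dimension. A hedged sketch of that stacking step, with a made-up helper name and without the weight-only quantization paths:

import torch

def stack_expert_weights(expert_up, expert_gate, expert_down):
    # expert_*[i] is the 2D weight of expert i after TP splitting.
    # Fuse up and gate per expert, then stack along a new expert dimension.
    gate_up = torch.stack([torch.cat([u, g], dim=0)
                           for u, g in zip(expert_up, expert_gate)], dim=0)
    down = torch.stack(expert_down, dim=0)
    return gate_up, down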


@@ -90,11 +90,15 @@ class QWenDecoderLayer(Module):
if config.moe.has_moe():
mlp_kwargs = {'moe_config': config.moe, 'mapping': config.mapping}
if config.qwen_type == 'qwen2_moe':
# Qwen2 MoE uses SharedMoE with shared expert
ClsMLP = SharedMoE
mlp_kwargs['use_shared_gate'] = True
mlp_kwargs['use_side_stream'] = True
mlp_kwargs['moe_config'].shared_expert_intermediate_size = \
config.moe_shared_expert_intermediate_size
elif config.qwen_type == 'qwen3_moe':
# Qwen3 MoE uses standard MOE without shared expert
ClsMLP = MOE
else:
ClsMLP = MOE
else:
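
Note: the class selection above encodes the architectural difference: Qwen2 MoE adds a shared expert whose output is blended in through a separate sigmoid gate, while Qwen3 MoE only has routed experts, so the plain MOE module suffices. A simplified sketch of top-k routing to make the distinction concrete (a generic illustration, not the TensorRT-LLM kernel; normalization of the routing weights is configuration-dependent):

import torch
import torch.nn.functional as F

def route_tokens(router_logits: torch.Tensor, top_k: int):
    # router_logits: (num_tokens, num_experts) from the MoE router layer.
    probs = F.softmax(router_logits, dim=-1, dtype=torch.float32)
    weights, expert_ids = torch.topk(probs, top_k, dim=-1)
    # Each token goes to its top-k experts and their outputs are combined with
    # these weights; a shared expert (Qwen2 MoE only) would be added on top,
    # scaled by its own sigmoid gate.
    return weights, expert_ids
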
@@ -104,7 +108,7 @@ class QWenDecoderLayer(Module):
# Qwen's real inter_size depends on qwen_type
if self.config.qwen_type == 'qwen':
intermediate_size = config.intermediate_size // 2
elif self.config.qwen_type == 'qwen2_moe':
elif self.config.qwen_type in ('qwen2_moe', 'qwen3_moe'):
intermediate_size = config.moe_intermediate_size
else:
intermediate_size = config.intermediate_size
@@ -264,18 +268,11 @@ class QWenForCausalLM(DecoderModelForCausalLM):
"mlp_4h_to_h": "mlp.c_proj",
"mlp_gate": "w1",
}
elif config.qwen_type == 'qwen2_moe':
elif config.qwen_type in ('qwen2_moe', 'qwen3_moe'):
self.trtllm_modules_to_hf_modules = copy.copy(
get_default_trtllm_modules_to_hf_modules())
# Common MoE expert mappings for both Qwen2 and Qwen3 MoE
self.trtllm_modules_to_hf_modules.update({
"mlp_h_to_4h":
"mlp.shared_expert.gate_proj",
"mlp_4h_to_h":
"mlp.shared_expert.down_proj",
"mlp_gate":
"mlp.shared_expert.up_proj",
"mlp_router":
"mlp.shared_expert_gate",
"moe_h_to_4h":
"mlp.experts.gate_proj",
"moe_4h_to_h":
@@ -283,6 +280,18 @@ class QWenForCausalLM(DecoderModelForCausalLM):
"moe_gate":
"mlp.experts.up_proj",
})
# Qwen2 MoE additionally has shared expert
if config.qwen_type == 'qwen2_moe':
self.trtllm_modules_to_hf_modules.update({
"mlp_h_to_4h":
"mlp.shared_expert.gate_proj",
"mlp_4h_to_h":
"mlp.shared_expert.down_proj",
"mlp_gate":
"mlp.shared_expert.up_proj",
"mlp_router":
"mlp.shared_expert_gate",
})
else:
self.trtllm_modules_to_hf_modules = None
super().__init__(config, transformer, lm_head)
@@ -343,6 +352,12 @@ class QWenForCausalLM(DecoderModelForCausalLM):
"mlp.shared_expert_gate": "mlp.shared_expert_gate",
"fc": ["up_proj", "gate_proj"],
}
elif config.qwen_type == "qwen3_moe":
custom_dict = {
"fc": ["up_proj", "gate_proj"],
"q_layernorm": "q_norm",
"k_layernorm": "k_norm",
}
elif config.qwen_type in {"qwen2", "qwen2_vl"
} and config.tie_word_embeddings:
custom_dict = {"lm_head": "model.embed_tokens"}
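
Note: the new qwen3_moe branch above gives the weight loader only the name fragments that differ from the defaults: the per-head RMSNorms map to HF's q_norm/k_norm, and the fused fc tensor is assembled from up_proj and gate_proj. A tiny, hypothetical helper showing how such a mapping is read (the real ModelWeightsLoader also handles sharding, fusion and postprocessing):

from typing import Dict, List, Union

def hf_names_for(fragment: str,
                 custom_dict: Dict[str, Union[str, List[str]]]) -> List[str]:
    # Look up a TRT-LLM-side name fragment; a list value means several HF
    # tensors are fused into one TRT-LLM weight (up_proj + gate_proj -> fc).
    mapped = custom_dict.get(fragment, fragment)
    return mapped if isinstance(mapped, list) else [mapped]

qwen3_moe_dict = {
    "fc": ["up_proj", "gate_proj"],
    "q_layernorm": "q_norm",
    "k_layernorm": "k_norm",
}
assert hf_names_for("q_layernorm", qwen3_moe_dict) == ["q_norm"]
assert hf_names_for("fc", qwen3_moe_dict) == ["up_proj", "gate_proj"]
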
@@ -360,7 +375,7 @@ class QWenForCausalLM(DecoderModelForCausalLM):
"transformer": "language_model.model",
"lm_head": "language_model.lm_head",
}
elif config.qwen_type in ("qwen3", "qwen3_moe"):
elif config.qwen_type == "qwen3":
custom_dict = {
"q_layernorm": "q_norm",
"k_layernorm": "k_norm",
@@ -412,7 +427,7 @@ class QWenForCausalLM(DecoderModelForCausalLM):
loader.load(tllm_key,
custom_postprocess_kwargs=arg_dict))
loader.fill(tllm_weights)
elif config.qwen_type == "qwen2_moe":
elif config.qwen_type in ("qwen2_moe", "qwen3_moe"):
for tllm_key, _ in model.named_parameters():
sub_module = model
for attr in tllm_key.split(".")[:-1]:
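
Note: the final hunk extends the manual postprocessing loop to qwen3_moe; it walks each parameter's dotted name down to the submodule that owns it so per-module conversion hooks can be applied. The same traversal can be written compactly as a sketch for any torch Module (the loop in the diff does the equivalent with repeated getattr calls):

from functools import reduce
import torch.nn as nn

def owning_module(model: nn.Module, param_name: str) -> nn.Module:
    # "transformer.layers.0.mlp.fc.weight" -> the module that owns 'weight'.
    return reduce(getattr, param_name.split(".")[:-1], model)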