Mirror of https://github.com/NVIDIA/TensorRT-LLM.git
[feat] Add EAGLE3 support for Qwen3 (#5206)
Signed-off-by: Yilin Fan <206948969+nv-yilinf@users.noreply.github.com>
parent 517c1ecf72
commit 498fadceb4
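EAGLE3 is a speculative-decoding scheme in which a lightweight draft model conditions on hidden states captured from several layers of the target model; the hunks below thread that capture path through the Qwen3 MoE stack. For orientation only (this is not part of the commit), a minimal usage sketch, assuming the LLM API's EagleDecodingConfig with its max_draft_len, speculative_model_dir, and eagle3_one_model fields; the checkpoint paths are placeholders:

# Hedged usage sketch: enable EAGLE3 drafting for a Qwen3 checkpoint via
# the PyTorch-backend LLM API. Paths are placeholders, and the config
# fields are assumptions based on TensorRT-LLM's LLM API, not this diff.
from tensorrt_llm import LLM
from tensorrt_llm.llmapi import EagleDecodingConfig

spec_config = EagleDecodingConfig(
    max_draft_len=3,
    speculative_model_dir="<path-to-eagle3-draft-checkpoint>",
    eagle3_one_model=True,  # run target and draft in a single engine
)
llm = LLM(model="<path-to-qwen3-checkpoint>", speculative_config=spec_config)
outputs = llm.generate(["The capital of France is"])
print(outputs[0].outputs[0].text)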
@@ -18,11 +18,13 @@ from ..modules.fused_moe import (BaseMoeRoutingMethod, CutlassFusedMoE, MoE,
                                  RoutingMethodType, create_moe)
 from ..modules.linear import TensorParallelMode
 from ..modules.rms_norm import RMSNorm
+from ..speculative import SpecMetadata
 from ..utils import disable_fp4_allgather
 from .modeling_qwen3 import Qwen3Attention
-from .modeling_utils import (DecoderModel, DecoderModelForCausalLM,
-                             EagerFusionConfig, duplicate_kv_weight,
-                             filter_weights, register_auto_model)
+from .modeling_speculative import SpecDecOneEngineForCausalLM
+from .modeling_utils import (DecoderModel, EagerFusionConfig,
+                             duplicate_kv_weight, filter_weights,
+                             register_auto_model)
 
 
 class Qwen3Gate(nn.Module):
@@ -203,6 +205,7 @@ class Qwen3MoEDecoderLayer(DecoderLayer):
         hidden_states: torch.Tensor,
         attn_metadata: AttentionMetadata,
         residual: Optional[torch.Tensor],
+        spec_metadata: Optional[SpecMetadata] = None,
         **kwargs,
     ) -> torch.Tensor:
         if residual is None:
@@ -240,6 +243,9 @@ class Qwen3MoEDecoderLayer(DecoderLayer):
                     enable_allreduce=not (self.fusion_config.POST_MOE_FUSION
                                           or self.mapping.tp_size == 1)))
 
+        if spec_metadata:
+            spec_metadata.maybe_capture_hidden_states(self.layer_idx,
+                                                      hidden_states, residual)
         if self.fusion_config.POST_MOE_FUSION:
             hidden_states, residual = self.allreduce(
                 hidden_states,
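The hunk above is the heart of the feature: after its MoE block runs, each decoder layer offers its activations to the speculative machinery, which records them only for the layer indices the draft model conditions on. A minimal sketch of what such a capture hook does, assuming nothing about TensorRT-LLM's actual SpecMetadata beyond the call site visible above:

# Hedged sketch of a capture hook like the one called above. The real
# SpecMetadata lives in ..speculative; names and details here are
# illustrative only.
from typing import Dict, Tuple

import torch

class SpecMetadataSketch:

    def __init__(self, layers_to_capture: Tuple[int, ...]):
        self.layers_to_capture = set(layers_to_capture)
        self.captured: Dict[int, torch.Tensor] = {}

    def maybe_capture_hidden_states(self, layer_idx: int,
                                    hidden_states: torch.Tensor,
                                    residual: torch.Tensor) -> None:
        # Record only the layers the draft model was trained against; fold
        # the pre-norm residual back in so the full activation is kept.
        if layer_idx in self.layers_to_capture:
            self.captured[layer_idx] = hidden_states + residual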
@@ -300,6 +306,7 @@ class Qwen3MoEModel(DecoderModel):
         input_ids: Optional[torch.IntTensor] = None,
         position_ids: Optional[torch.IntTensor] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
+        spec_metadata: Optional[SpecMetadata] = None,
         **kwargs,
     ) -> torch.Tensor:
         if (input_ids is None) ^ (inputs_embeds is not None):
@@ -317,13 +324,14 @@ class Qwen3MoEModel(DecoderModel):
             hidden_states, residual = decoder_layer(position_ids=position_ids,
                                                     hidden_states=hidden_states,
                                                     attn_metadata=attn_metadata,
-                                                    residual=residual)
+                                                    residual=residual,
+                                                    spec_metadata=spec_metadata)
         return hidden_states
 
 
 @register_auto_model("Qwen3MoeForCausalLM")
-class Qwen3MoeForCausalLM(DecoderModelForCausalLM[Qwen3MoEModel,
-                                                  Qwen3MoeConfig]):
+class Qwen3MoeForCausalLM(SpecDecOneEngineForCausalLM[Qwen3MoEModel,
+                                                      Qwen3MoeConfig]):
 
     def __init__(
         self,
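Rebasing Qwen3MoeForCausalLM from DecoderModelForCausalLM onto SpecDecOneEngineForCausalLM is what gives the model its "one engine" speculative path: the same wrapper owns the target model and, when a speculative config is present, the draft model, so drafting and verification run in a single engine. A hypothetical outline of that pattern (names are illustrative, not TensorRT-LLM internals):

# Illustrative-only outline of a one-engine spec-dec wrapper.
from typing import Optional

def build_draft_model(model_config) -> Optional[object]:
    # stub: the real base class would construct the EAGLE3 draft network
    # from the speculative config, or return None when spec-dec is off
    return None

class OneEngineForCausalLMSketch:

    def __init__(self, model, model_config):
        self.model = model
        self.draft_model = build_draft_model(model_config)

    def forward(self, input_ids, position_ids, attn_metadata,
                spec_metadata=None, **kwargs):
        hidden_states = self.model(input_ids=input_ids,
                                   position_ids=position_ids,
                                   attn_metadata=attn_metadata,
                                   spec_metadata=spec_metadata)
        if self.draft_model is not None and spec_metadata is not None:
            # the draft pass consumes the per-layer hidden states that
            # maybe_capture_hidden_states recorded during the target pass
            spec_metadata.draft_hidden = self.draft_model(spec_metadata)
        return hidden_states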
@@ -331,9 +339,7 @@ class Qwen3MoeForCausalLM(DecoderModelForCausalLM[Qwen3MoEModel,
     ):
         super().__init__(
             Qwen3MoEModel(model_config),
-            config=model_config,
-            hidden_size=model_config.pretrained_config.hidden_size,
-            vocab_size=model_config.pretrained_config.vocab_size,
+            model_config,
         )
 
     def load_weights(self, weights: Dict):
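The constructor hunk follows directly from the new base class: super().__init__ now takes just the wrapped model and the ModelConfig, since a base class in this style can read hidden_size and vocab_size off the pretrained config itself instead of taking them as keyword arguments. A hypothetical sketch of such a base __init__:

# Hypothetical sketch of why the explicit hidden_size/vocab_size keyword
# arguments disappear; not TensorRT-LLM's actual base class.
class SpecDecBaseSketch:

    def __init__(self, model, model_config):
        cfg = model_config.pretrained_config
        self.model = model
        self.config = cfg
        self.hidden_size = cfg.hidden_size  # formerly passed explicitly
        self.vocab_size = cfg.vocab_size  # formerly passed explicitly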
@@ -353,7 +359,7 @@ class Qwen3MoeForCausalLM(DecoderModelForCausalLM[Qwen3MoEModel,
             if len(module._parameters) > 0:
                 # skip load weights if tie word embeddings is enabled and layer is lm_head
                 if self.config.tie_word_embeddings and name.startswith(
-                        "lm_head"):
+                        "lm_head") or name.startswith("draft_model"):
                     continue
 
                 names = name.split(".")
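The load_weights change keeps the target-model loader away from draft parameters: because `and` binds tighter than `or`, the new condition reads as (tie_word_embeddings and name.startswith("lm_head")) or name.startswith("draft_model"), so draft_model weights are skipped unconditionally and presumably loaded elsewhere by the speculative machinery. A standalone restatement of the predicate:

# Restatement of the skip predicate added above; `and` binds tighter
# than `or`, so draft weights are skipped regardless of weight tying.
def should_skip(name: str, tie_word_embeddings: bool) -> bool:
    if tie_word_embeddings and name.startswith("lm_head"):
        return True  # lm_head reuses the tied embedding weights
    return name.startswith("draft_model")  # skipped by the target loader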