feat(eagle3):support qwen3 dense model (#5879)

Signed-off-by: xq25478 <xq25478@qq.com>
This commit is contained in:
xiaoqi 2025-07-19 01:24:32 +08:00 committed by GitHub
parent 22d4a8c48a
commit 28858c8711
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 39 additions and 32 deletions

View File

@ -16,8 +16,9 @@ from ..modules.gated_mlp import GatedMLP
from ..modules.linear import TensorParallelMode
from ..modules.multi_stream_utils import maybe_execute_in_parallel
from ..modules.rms_norm import RMSNorm
from .modeling_utils import (DecoderModel, DecoderModelForCausalLM,
register_auto_model)
from ..speculative import SpecMetadata
from .modeling_speculative import SpecDecOneEngineForCausalLM
from .modeling_utils import DecoderModel, register_auto_model
class Qwen3Attention(Attention):
@ -148,6 +149,7 @@ class Qwen3DecoderLayer(DecoderLayer):
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor],
mrope_config: Optional[Tuple[torch.Tensor, int]] = None,
spec_metadata: Optional[SpecMetadata] = None,
**kwargs,
) -> torch.Tensor:
if residual is None:
@ -171,6 +173,10 @@ class Qwen3DecoderLayer(DecoderLayer):
hidden_states, residual)
hidden_states = self.mlp(hidden_states)
if spec_metadata is not None:
spec_metadata.maybe_capture_hidden_states(self.layer_idx,
hidden_states, residual)
return hidden_states, residual
@ -207,6 +213,7 @@ class Qwen3Model(DecoderModel):
position_ids: Optional[torch.IntTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
mrope_config: Optional[Tuple[torch.Tensor, int]] = None,
spec_metadata: Optional[SpecMetadata] = None,
**kwargs,
) -> torch.Tensor:
if (input_ids is None) ^ (inputs_embeds is not None):
@ -227,6 +234,7 @@ class Qwen3Model(DecoderModel):
attn_metadata=attn_metadata,
residual=residual,
mrope_config=mrope_config,
spec_metadata=spec_metadata,
)
hidden_states, _ = self.norm(hidden_states, residual)
@ -234,7 +242,7 @@ class Qwen3Model(DecoderModel):
@register_auto_model("Qwen3ForCausalLM")
class Qwen3ForCausalLM(DecoderModelForCausalLM[Qwen3Model, Qwen3Config]):
class Qwen3ForCausalLM(SpecDecOneEngineForCausalLM[Qwen3Model, Qwen3Config]):
def __init__(
self,
@ -242,33 +250,5 @@ class Qwen3ForCausalLM(DecoderModelForCausalLM[Qwen3Model, Qwen3Config]):
):
super().__init__(
Qwen3Model(model_config),
config=model_config,
hidden_size=model_config.pretrained_config.hidden_size,
vocab_size=model_config.pretrained_config.vocab_size,
)
# NOTE: Qwen2-VL needs special mrope_config so adding separate forward() function to accept 'mrope_config'.
def forward(
self,
attn_metadata: AttentionMetadata,
input_ids: torch.IntTensor = None,
position_ids: Optional[torch.IntTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
return_context_logits: bool = False,
mrope_config: Optional[dict] = None,
**kwargs,
) -> torch.Tensor:
output = self.model(
input_ids=input_ids,
attn_metadata=attn_metadata,
position_ids=position_ids,
inputs_embeds=inputs_embeds,
mrope_config=mrope_config,
)
return self.logits_processor.forward(
output,
self.lm_head,
attn_metadata,
return_context_logits,
model_config,
)

View File

@ -150,6 +150,8 @@ Qwen3/Qwen3-8B:
- quant_algo: FP8_BLOCK_SCALES
accuracy: 76.12
- accuracy: 76.12
- spec_dec_algo: Eagle
accuracy: 76.12
Qwen3/Qwen3-30B-A3B:
- quant_algo: FP8_BLOCK_SCALES
accuracy: 79.53

View File

@ -1658,6 +1658,30 @@ class TestQwen3_8B(LlmapiAccuracyTestHarness):
task = MMLU(self.MODEL_NAME)
task.evaluate(llm)
def test_eagle3(self):
pytorch_config = dict(
disable_overlap_scheduler=True,
cuda_graph_config=CudaGraphConfig(batch_sizes=[1]),
)
kv_cache_config = KvCacheConfig(enable_block_reuse=False)
eagle_model_dir = f"{llm_models_root()}/Qwen3/qwen3_8b_eagle3"
target_model_dir = f"{llm_models_root()}/Qwen3/Qwen3-8B"
draft_len = 4
spec_config = EagleDecodingConfig(max_draft_len=draft_len,
speculative_model_dir=eagle_model_dir)
llm = LLM(model=target_model_dir,
**pytorch_config,
kv_cache_config=kv_cache_config,
speculative_config=spec_config,
build_config=None)
with llm:
task = MMLU(self.MODEL_NAME)
task.evaluate(llm)
class TestQwen3_30B_A3B(LlmapiAccuracyTestHarness):
MODEL_NAME = "Qwen3/Qwen3-30B-A3B"

View File

@ -40,6 +40,7 @@ l0_h100:
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_no_kv_cache_reuse[quant_dtype=fp8-mtp_nextn=2-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True]
- accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_fp8_block_scales[latency]
- accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_fp8[latency]
- accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_eagle3
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_cuda_graph_padding[mtp_nextn=0]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_cuda_graph_padding[mtp_nextn=2]
- test_e2e.py::test_trtllm_bench_pytorch_backend_sanity[meta-llama/Llama-3.1-8B-llama-3.1-8b-False-False]