mirror of https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
feat(eagle3): support qwen3 dense model (#5879)
Signed-off-by: xq25478 <xq25478@qq.com>
parent 22d4a8c48a
commit 28858c8711
@@ -16,8 +16,9 @@ from ..modules.gated_mlp import GatedMLP
 from ..modules.linear import TensorParallelMode
 from ..modules.multi_stream_utils import maybe_execute_in_parallel
 from ..modules.rms_norm import RMSNorm
-from .modeling_utils import (DecoderModel, DecoderModelForCausalLM,
-                             register_auto_model)
+from ..speculative import SpecMetadata
+from .modeling_speculative import SpecDecOneEngineForCausalLM
+from .modeling_utils import DecoderModel, register_auto_model
 
 
 class Qwen3Attention(Attention):
@@ -148,6 +149,7 @@ class Qwen3DecoderLayer(DecoderLayer):
         attn_metadata: AttentionMetadata,
         residual: Optional[torch.Tensor],
         mrope_config: Optional[Tuple[torch.Tensor, int]] = None,
+        spec_metadata: Optional[SpecMetadata] = None,
         **kwargs,
     ) -> torch.Tensor:
         if residual is None:
@@ -171,6 +173,10 @@ class Qwen3DecoderLayer(DecoderLayer):
                                                        hidden_states, residual)
         hidden_states = self.mlp(hidden_states)
 
+        if spec_metadata is not None:
+            spec_metadata.maybe_capture_hidden_states(self.layer_idx,
+                                                      hidden_states, residual)
+
         return hidden_states, residual
 
 
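Note on the capture hook: EAGLE3 drafts tokens from hidden states recorded at selected target-model layers, which is why each decoder layer now offers its post-MLP (hidden_states, residual) pair to SpecMetadata. A minimal sketch of what such a hook does, assuming SpecMetadata tracks a set of layer indices to record (illustrative only; the real class in ..speculative differs in detail):

# Illustrative stand-in for SpecMetadata's capture behavior, not the real class.
from dataclasses import dataclass, field
from typing import Dict, Set, Tuple

import torch


@dataclass
class SpecMetadataSketch:
    # Layer indices whose hidden states the EAGLE3 draft head wants to see.
    layers_to_capture: Set[int] = field(default_factory=set)
    captured: Dict[int, Tuple[torch.Tensor, torch.Tensor]] = field(
        default_factory=dict)

    def maybe_capture_hidden_states(self, layer_idx: int,
                                    hidden_states: torch.Tensor,
                                    residual: torch.Tensor) -> None:
        # No-op for layers the draft head does not consume, so the check in
        # Qwen3DecoderLayer.forward stays cheap on the hot path.
        if layer_idx in self.layers_to_capture:
            self.captured[layer_idx] = (hidden_states, residual)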
@@ -207,6 +213,7 @@ class Qwen3Model(DecoderModel):
         position_ids: Optional[torch.IntTensor] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         mrope_config: Optional[Tuple[torch.Tensor, int]] = None,
+        spec_metadata: Optional[SpecMetadata] = None,
         **kwargs,
     ) -> torch.Tensor:
         if (input_ids is None) ^ (inputs_embeds is not None):
@@ -227,6 +234,7 @@ class Qwen3Model(DecoderModel):
                 attn_metadata=attn_metadata,
                 residual=residual,
                 mrope_config=mrope_config,
+                spec_metadata=spec_metadata,
             )
 
         hidden_states, _ = self.norm(hidden_states, residual)
@@ -234,7 +242,7 @@
 
 
 @register_auto_model("Qwen3ForCausalLM")
-class Qwen3ForCausalLM(DecoderModelForCausalLM[Qwen3Model, Qwen3Config]):
+class Qwen3ForCausalLM(SpecDecOneEngineForCausalLM[Qwen3Model, Qwen3Config]):
 
     def __init__(
         self,
@@ -242,33 +250,5 @@ class Qwen3ForCausalLM(DecoderModelForCausalLM[Qwen3Model, Qwen3Config]):
     ):
         super().__init__(
             Qwen3Model(model_config),
-            config=model_config,
-            hidden_size=model_config.pretrained_config.hidden_size,
-            vocab_size=model_config.pretrained_config.vocab_size,
-        )
-
-    # NOTE: Qwen2-VL needs special mrope_config so adding separate forward() function to accept 'mrope_config'.
-    def forward(
-        self,
-        attn_metadata: AttentionMetadata,
-        input_ids: torch.IntTensor = None,
-        position_ids: Optional[torch.IntTensor] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-        return_context_logits: bool = False,
-        mrope_config: Optional[dict] = None,
-        **kwargs,
-    ) -> torch.Tensor:
-        output = self.model(
-            input_ids=input_ids,
-            attn_metadata=attn_metadata,
-            position_ids=position_ids,
-            inputs_embeds=inputs_embeds,
-            mrope_config=mrope_config,
-        )
-
-        return self.logits_processor.forward(
-            output,
-            self.lm_head,
-            attn_metadata,
-            return_context_logits,
-            model_config,
-        )
+            model_config,
+        )
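Rebasing Qwen3ForCausalLM onto SpecDecOneEngineForCausalLM is what lets the hand-written forward() above be deleted: the shared base runs the target model and the EAGLE3 draft head in one engine and threads spec_metadata and return_context_logits through to the logits processing for us. A conceptual sketch of such a single-engine step, reusing the SpecMetadataSketch above (names and the greedy verification are illustrative, not the TensorRT-LLM internals):

# Conceptual single-engine EAGLE3 step; illustrative, not the real base class.
from typing import Callable, Tuple

import torch


def one_engine_step(
    target_forward: Callable[..., torch.Tensor],
    draft_head: Callable[[torch.Tensor], torch.Tensor],
    spec_metadata: "SpecMetadataSketch",
    input_ids: torch.Tensor,
    max_draft_len: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    # One target pass both verifies the previous draft tokens and, through
    # spec_metadata, captures the hidden states the draft head consumes.
    logits = target_forward(input_ids=input_ids, spec_metadata=spec_metadata)
    accepted = logits.argmax(dim=-1)  # greedy stand-in for real verification

    # The draft head then proposes the next max_draft_len tokens from the
    # captured states without launching a second engine.
    features = torch.cat(
        [h for h, _ in spec_metadata.captured.values()], dim=-1)
    draft_tokens = draft_head(features).topk(max_draft_len, dim=-1).indices
    return accepted, draft_tokens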
@@ -150,6 +150,8 @@ Qwen3/Qwen3-8B:
   - quant_algo: FP8_BLOCK_SCALES
     accuracy: 76.12
   - accuracy: 76.12
+  - spec_dec_algo: Eagle
+    accuracy: 76.12
 Qwen3/Qwen3-30B-A3B:
   - quant_algo: FP8_BLOCK_SCALES
     accuracy: 79.53
@@ -1658,6 +1658,30 @@ class TestQwen3_8B(LlmapiAccuracyTestHarness):
         task = MMLU(self.MODEL_NAME)
         task.evaluate(llm)
 
+    def test_eagle3(self):
+        pytorch_config = dict(
+            disable_overlap_scheduler=True,
+            cuda_graph_config=CudaGraphConfig(batch_sizes=[1]),
+        )
+        kv_cache_config = KvCacheConfig(enable_block_reuse=False)
+
+        eagle_model_dir = f"{llm_models_root()}/Qwen3/qwen3_8b_eagle3"
+        target_model_dir = f"{llm_models_root()}/Qwen3/Qwen3-8B"
+
+        draft_len = 4
+        spec_config = EagleDecodingConfig(max_draft_len=draft_len,
+                                          speculative_model_dir=eagle_model_dir)
+
+        llm = LLM(model=target_model_dir,
+                  **pytorch_config,
+                  kv_cache_config=kv_cache_config,
+                  speculative_config=spec_config,
+                  build_config=None)
+
+        with llm:
+            task = MMLU(self.MODEL_NAME)
+            task.evaluate(llm)
+
 
 class TestQwen3_30B_A3B(LlmapiAccuracyTestHarness):
     MODEL_NAME = "Qwen3/Qwen3-30B-A3B"
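For reference, the same EAGLE3 setup the test exercises can be driven directly through the LLM API. A standalone sketch with placeholder checkpoint paths, assuming the keyword arguments accepted by LLM mirror those splatted in test_eagle3 above:

# Standalone EAGLE3 example mirroring test_eagle3; paths are placeholders.
from tensorrt_llm import LLM
from tensorrt_llm.llmapi import EagleDecodingConfig, KvCacheConfig

spec_config = EagleDecodingConfig(
    max_draft_len=4,
    speculative_model_dir="/models/qwen3_8b_eagle3",  # EAGLE3 draft weights
)

llm = LLM(
    model="/models/Qwen3-8B",  # target model checkpoint
    disable_overlap_scheduler=True,
    kv_cache_config=KvCacheConfig(enable_block_reuse=False),
    speculative_config=spec_config,
)

with llm:
    for output in llm.generate(["The capital of France is"]):
        print(output.outputs[0].text)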
@@ -40,6 +40,7 @@ l0_h100:
       - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_no_kv_cache_reuse[quant_dtype=fp8-mtp_nextn=2-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True]
       - accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_fp8_block_scales[latency]
       - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_fp8[latency]
+      - accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_eagle3
       - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_cuda_graph_padding[mtp_nextn=0]
       - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_cuda_graph_padding[mtp_nextn=2]
       - test_e2e.py::test_trtllm_bench_pytorch_backend_sanity[meta-llama/Llama-3.1-8B-llama-3.1-8b-False-False]