diff --git a/tensorrt_llm/_torch/models/modeling_qwen.py b/tensorrt_llm/_torch/models/modeling_qwen.py index 493cbae823..fa20c4df00 100644 --- a/tensorrt_llm/_torch/models/modeling_qwen.py +++ b/tensorrt_llm/_torch/models/modeling_qwen.py @@ -16,6 +16,7 @@ from ..modules.gated_mlp import GatedMLP from ..modules.linear import Linear, TensorParallelMode from ..modules.qk_norm_attention import QKNormRoPEAttention from ..modules.rms_norm import RMSNorm +from ..speculative import SpecMetadata from .modeling_utils import (DecoderModel, DecoderModelForCausalLM, register_auto_model) @@ -148,6 +149,7 @@ class QwenDecoderLayer(DecoderLayer): attn_metadata: AttentionMetadata, residual: Optional[torch.Tensor], mrope_config: Optional[Tuple[torch.Tensor, int]] = None, + spec_metadata: Optional[SpecMetadata] = None, **kwargs, ) -> torch.Tensor: if residual is None: @@ -170,6 +172,10 @@ class QwenDecoderLayer(DecoderLayer): hidden_states, residual = self.post_attention_layernorm( hidden_states, residual) hidden_states = self.mlp(hidden_states) + + if spec_metadata is not None: + spec_metadata.maybe_capture_hidden_states(self.layer_idx, + hidden_states, residual) return hidden_states, residual @@ -204,6 +210,7 @@ class QwenModel(DecoderModel): position_ids: Optional[torch.IntTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, mrope_config: Optional[Tuple[torch.Tensor, int]] = None, + spec_metadata: Optional[SpecMetadata] = None, **kwargs, ) -> torch.Tensor: if (input_ids is None) ^ (inputs_embeds is not None): @@ -222,7 +229,8 @@ class QwenModel(DecoderModel): hidden_states=hidden_states, attn_metadata=attn_metadata, residual=residual, - mrope_config=mrope_config) + mrope_config=mrope_config, + spec_metadata=spec_metadata) hidden_states, _ = self.norm(hidden_states, residual) return hidden_states @@ -249,15 +257,15 @@ class Qwen2ForCausalLM(DecoderModelForCausalLM[QwenModel, Qwen2Config]): inputs_embeds: Optional[torch.FloatTensor] = None, return_context_logits: bool = False, mrope_config: Optional[dict] = None, + spec_metadata: Optional[SpecMetadata] = None, **kwargs, ) -> torch.Tensor: - output = self.model( - input_ids=input_ids, - attn_metadata=attn_metadata, - position_ids=position_ids, - inputs_embeds=inputs_embeds, - mrope_config=mrope_config, - ) + output = self.model(input_ids=input_ids, + attn_metadata=attn_metadata, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + mrope_config=mrope_config, + spec_metadata=spec_metadata) return self.logits_processor.forward( output,