Mirror of https://github.com/NVIDIA/TensorRT-LLM.git
Signed-off-by: SpicyNoodle <522169030@qq.com>
parent: 9cae7277ea
commit: 46f035befe
@@ -16,6 +16,7 @@ from ..modules.gated_mlp import GatedMLP
 from ..modules.linear import Linear, TensorParallelMode
 from ..modules.qk_norm_attention import QKNormRoPEAttention
 from ..modules.rms_norm import RMSNorm
+from ..speculative import SpecMetadata
 from .modeling_utils import (DecoderModel, DecoderModelForCausalLM,
                              register_auto_model)
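Taken together, the hunks in this commit thread an optional spec_metadata argument through QwenDecoderLayer, QwenModel, and Qwen2ForCausalLM so that speculative decoding can record intermediate hidden states during the forward pass; the only new import is SpecMetadata from the speculative package. As a minimal sketch of the contract the rest of the diff relies on (the real SpecMetadata in TensorRT-LLM is richer; only the maybe_capture_hidden_states call is taken from the diff, the fields below are illustrative):

# Minimal, hypothetical sketch of the interface this diff calls into.
from dataclasses import dataclass, field
from typing import Dict, Tuple

import torch


@dataclass
class SpecMetadataSketch:
    # Indices of decoder layers whose outputs should be recorded.
    layers_to_capture: Tuple[int, ...] = ()
    # layer_idx -> (hidden_states, residual) recorded during forward.
    captured: Dict[int, Tuple[torch.Tensor, torch.Tensor]] = field(
        default_factory=dict)

    def maybe_capture_hidden_states(self, layer_idx: int,
                                    hidden_states: torch.Tensor,
                                    residual: torch.Tensor) -> None:
        # Cheap no-op unless this layer was requested, so decoder layers
        # can call it unconditionally whenever spec_metadata is present.
        if layer_idx in self.layers_to_capture:
            self.captured[layer_idx] = (hidden_states, residual)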
@@ -148,6 +149,7 @@ class QwenDecoderLayer(DecoderLayer):
         attn_metadata: AttentionMetadata,
         residual: Optional[torch.Tensor],
         mrope_config: Optional[Tuple[torch.Tensor, int]] = None,
+        spec_metadata: Optional[SpecMetadata] = None,
         **kwargs,
     ) -> torch.Tensor:
         if residual is None:
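Because the new parameter is a keyword with a None default, existing call sites that never pass spec_metadata keep working unchanged. Note that the capture hook added in the next hunk keys on self.layer_idx, so each layer must know its own position in the stack. A stripped-down illustration of that pattern (TinyDecoderLayer is hypothetical, not TensorRT-LLM code, and reuses the SpecMetadataSketch above):

from typing import Optional

import torch
from torch import nn


class TinyDecoderLayer(nn.Module):
    # Hypothetical stand-in: shows why a layer carries its own index.
    def __init__(self, hidden_size: int, layer_idx: int):
        super().__init__()
        self.layer_idx = layer_idx
        self.proj = nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_states: torch.Tensor,
                residual: torch.Tensor,
                spec_metadata: Optional["SpecMetadataSketch"] = None):
        hidden_states = self.proj(hidden_states)
        if spec_metadata is not None:
            # File the capture under this layer's position in the stack.
            spec_metadata.maybe_capture_hidden_states(
                self.layer_idx, hidden_states, residual)
        return hidden_states, residual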
@@ -170,6 +172,10 @@ class QwenDecoderLayer(DecoderLayer):
         hidden_states, residual = self.post_attention_layernorm(
             hidden_states, residual)
         hidden_states = self.mlp(hidden_states)
+
+        if spec_metadata is not None:
+            spec_metadata.maybe_capture_hidden_states(self.layer_idx,
+                                                      hidden_states, residual)
         return hidden_states, residual
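The capture site matters: it runs after the MLP, where (hidden_states, residual) is exactly the pair the layer returns, i.e. the full residual stream before the next layer's fused add-norm folds the two tensors together. Assuming the SpecMetadataSketch from above, a consumer could read the captured pair back along these lines (the wiring is hypothetical; model, input_ids, attn_metadata, position_ids, and target_layer_idx stand in for a built model and its inputs):

spec = SpecMetadataSketch(layers_to_capture=(target_layer_idx,))
model(input_ids=input_ids, attn_metadata=attn_metadata,
      position_ids=position_ids, spec_metadata=spec)
hidden, residual = spec.captured[target_layer_idx]
# Folding the residual back in reconstructs the layer's full
# residual-stream output, matching what the next layer's
# fused add + RMSNorm would see on entry.
draft_features = hidden + residual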
@@ -204,6 +210,7 @@ class QwenModel(DecoderModel):
         position_ids: Optional[torch.IntTensor] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         mrope_config: Optional[Tuple[torch.Tensor, int]] = None,
+        spec_metadata: Optional[SpecMetadata] = None,
         **kwargs,
     ) -> torch.Tensor:
         if (input_ids is None) ^ (inputs_embeds is not None):
@@ -222,7 +229,8 @@ class QwenModel(DecoderModel):
                 hidden_states=hidden_states,
                 attn_metadata=attn_metadata,
                 residual=residual,
-                mrope_config=mrope_config)
+                mrope_config=mrope_config,
+                spec_metadata=spec_metadata)

         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
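The call edited above sits inside QwenModel's per-layer loop, and spec_metadata is forwarded as an explicit keyword rather than riding through **kwargs, which keeps the dependency visible at every level. A condensed view of the loop (only the keywords visible in the hunk are shown; the loop header is assumed from the usual DecoderModel structure):

# Simplified sketch of the layer loop this hunk edits.
for decoder_layer in self.layers:
    hidden_states, residual = decoder_layer(
        hidden_states=hidden_states,
        attn_metadata=attn_metadata,
        residual=residual,
        mrope_config=mrope_config,
        spec_metadata=spec_metadata)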
@@ -249,15 +257,15 @@ class Qwen2ForCausalLM(DecoderModelForCausalLM[QwenModel, Qwen2Config]):
         inputs_embeds: Optional[torch.FloatTensor] = None,
         return_context_logits: bool = False,
         mrope_config: Optional[dict] = None,
+        spec_metadata: Optional[SpecMetadata] = None,
         **kwargs,
     ) -> torch.Tensor:
-        output = self.model(
-            input_ids=input_ids,
-            attn_metadata=attn_metadata,
-            position_ids=position_ids,
-            inputs_embeds=inputs_embeds,
-            mrope_config=mrope_config,
-        )
+        output = self.model(input_ids=input_ids,
+                            attn_metadata=attn_metadata,
+                            position_ids=position_ids,
+                            inputs_embeds=inputs_embeds,
+                            mrope_config=mrope_config,
+                            spec_metadata=spec_metadata)

         return self.logits_processor.forward(
             output,
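With all three levels threaded, a caller that wants capture only hands one object to the top-level forward; everything below forwards it untouched, and callers that pass nothing get the old behavior. A hedged end-to-end sketch using the SpecMetadataSketch from above (qwen_lm, attn_metadata, input_ids, and position_ids are placeholders for a built model and its inputs; the layer indices are arbitrary):

spec = SpecMetadataSketch(layers_to_capture=(0, 27))
logits = qwen_lm.forward(attn_metadata=attn_metadata,
                         input_ids=input_ids,
                         position_ids=position_ids,
                         spec_metadata=spec)
# After the pass, the recorded pairs are available per layer:
for idx, (h, r) in spec.captured.items():
    print(f"layer {idx}: hidden {tuple(h.shape)} residual {tuple(r.shape)}")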