diff --git a/tensorrt_llm/_torch/models/modeling_gemma3.py b/tensorrt_llm/_torch/models/modeling_gemma3.py
index 1ba1fa0299..1f14e2d278 100644
--- a/tensorrt_llm/_torch/models/modeling_gemma3.py
+++ b/tensorrt_llm/_torch/models/modeling_gemma3.py
@@ -53,11 +53,11 @@ class Gemma3Attention(Attention):
             max_position_embeddings=config.max_position_embeddings,
             bias=False,
             pos_embd_params=pos_embd_params,
+            qk_norm_type=QkNormType.pre_rope,
             layer_idx=layer_idx,
             dtype=config.torch_dtype,
             dense_bias=False,
             config=model_config,
-            qk_norm_type=QkNormType.pre_rope,
             q_scaling=q_scaling,
         )
         self.q_norm = RMSNorm(hidden_size=config.head_dim,
diff --git a/tensorrt_llm/_torch/models/modeling_llama.py b/tensorrt_llm/_torch/models/modeling_llama.py
index 79443fd7cc..36163e6b10 100644
--- a/tensorrt_llm/_torch/models/modeling_llama.py
+++ b/tensorrt_llm/_torch/models/modeling_llama.py
@@ -73,11 +73,11 @@ class Llama4Attention(Attention):
             max_position_embeddings=config.max_position_embeddings,
             bias=config.attention_bias,
             pos_embd_params=pos_embd_params,
+            qk_norm_type=QkNormType.post_rope
+            if use_qk_norm else QkNormType.none,
             layer_idx=layer_idx,
             dtype=config.torch_dtype,
             config=model_config,
-            qk_norm_type=QkNormType.post_rope
-            if use_qk_norm else QkNormType.none,
             attention_chunk_size=attention_chunk_size,
         )
 
diff --git a/tensorrt_llm/_torch/models/modeling_qwen3.py b/tensorrt_llm/_torch/models/modeling_qwen3.py
index d614100f3e..cfd1b2e9ff 100644
--- a/tensorrt_llm/_torch/models/modeling_qwen3.py
+++ b/tensorrt_llm/_torch/models/modeling_qwen3.py
@@ -26,8 +26,10 @@ class Qwen3Attention(Attention):
         self,
         model_config: ModelConfig[Qwen3Config],
         layer_idx: Optional[int] = None,
+        fuse_qk_norm_rope: bool = True,
     ):
         config = model_config.pretrained_config
+
         if getattr(config, "rope_scaling", None) is not None:
             pos_embd_params = PositionalEmbeddingParams(
                 type=PositionEmbeddingType.from_string(
@@ -40,20 +42,27 @@ class Qwen3Attention(Attention):
                 rope=RopeParams.from_config(config),
             )
 
+        self.fuse_qk_norm_rope = fuse_qk_norm_rope
+
         super().__init__(
             hidden_size=config.hidden_size,
             num_attention_heads=config.num_attention_heads,
             num_key_value_heads=config.num_key_value_heads,
             max_position_embeddings=config.max_position_embeddings,
             bias=config.attention_bias,
-            pos_embd_params=pos_embd_params,
+            pos_embd_params=pos_embd_params
+            if not self.fuse_qk_norm_rope else None,
+            qk_norm_type=QkNormType.pre_rope,
             layer_idx=layer_idx,
             dtype=config.torch_dtype,
             dense_bias=config.attention_bias,
             config=model_config,
-            qk_norm_type=QkNormType.pre_rope,
         )
 
+        # If fuse_qk_norm_rope is true, we pass pos_embd_params=None to super().__init__,
+        # so record the actual pos_embd_params here.
+        self.pos_embd_params = pos_embd_params
+
         self.q_norm = RMSNorm(hidden_size=self.head_dim,
                               eps=1e-6,
                               dtype=config.torch_dtype,
@@ -85,6 +94,21 @@ class Qwen3Attention(Attention):
 
         return q, k
 
+    def apply_rope(self, qkv: torch.Tensor, position_ids: torch.Tensor):
+        if not self.fuse_qk_norm_rope:
+            return super().apply_rope(qkv, position_ids)
+        else:
+            return self.apply_qk_norm_rope(qkv, position_ids)
+
+    def apply_qk_norm_rope(self, qkv, position_ids):
+        torch.ops.trtllm.fused_qk_norm_rope(
+            qkv, self.num_heads, self.num_key_value_heads,
+            self.num_key_value_heads, self.head_dim,
+            self.q_norm.variance_epsilon, self.q_norm.weight,
+            self.k_norm.weight, self.pos_embd_params.rope.theta,
+            self.pos_embd_params.is_neox, position_ids.view(-1))
+        return qkv, None, None
+
 
 class Qwen3DecoderLayer(DecoderLayer):
diff --git a/tensorrt_llm/_torch/modules/attention.py b/tensorrt_llm/_torch/modules/attention.py
index 83a87ae9f1..4c7c8a5b61 100644
--- a/tensorrt_llm/_torch/modules/attention.py
+++ b/tensorrt_llm/_torch/modules/attention.py
@@ -43,11 +43,11 @@ class Attention(nn.Module):
         max_position_embeddings: int,
         bias: bool,
         pos_embd_params: Optional[PositionalEmbeddingParams] = None,
+        qk_norm_type: QkNormType = QkNormType.none,
         layer_idx: Optional[int] = None,
         dtype: torch.dtype = None,
         dense_bias: Optional[bool] = None,
         config: Optional[ModelConfig] = None,
-        qk_norm_type: QkNormType = QkNormType.none,
         q_scaling: float = 1.0,
         attention_chunk_size: Optional[int] = None,
     ):
@@ -61,11 +61,11 @@ class Attention(nn.Module):
             max_position_embeddings (int): The maximum position embeddings.
             bias (bool): Whether to use bias in the linear layers.
             pos_embd_params (PositionalEmbeddingParams): The positional embedding parameters.
+            qk_norm_type (QkNormType): The type of QK normalization.
             layer_idx (int): The layer index.
             dtype (torch.dtype): The data type.
             dense_bias (bool): Whether to use bias in the output projection layer.
             config (ModelConfig): The model configuration.
-            qk_norm_type (QkNormType): The type of QK normalization.
             q_scaling (float): The scaling factor for the qk_scale. The definition is $O = softmax(QK^T * qk_scale) * V, qk_scale = 1 / (sqrt(head_dim) * q_scaling)$. The default value is 1.0.
             attention_chunk_size (int): See [Chunked Attention] below.
         """
@@ -154,7 +154,7 @@ class Attention(nn.Module):
 
         self.quant_config = config.get_quant_config()
         self.attn_backend = config.attn_backend
-        self.pos_embd_params = pos_embd_params
+        attn_cls = get_attention_backend(self.attn_backend)
 
         # These two modules are mutually exclusive - either splitted_qkv_lora or fused_qkv_lora will be used,
         # but never both at the same time. splitted_qkv_lora handles Q,K,V separately while fused_qkv_lora
@@ -169,9 +169,20 @@ class Attention(nn.Module):
         self.o_lora = LoraLayer([LoraModuleType.ATTENTION_DENSE],
                                 [self.hidden_size])
 
-        attn_cls = get_attention_backend(self.attn_backend)
+        # enable_rope_fusion: Whether to fuse RoPE into the attention OP.
+        # If true, RoPE will be applied in self.attn.forward.
+        # If false, RoPE will be applied in self.apply_rope.
         self.enable_rope_fusion = attn_cls.support_fused_rope(
-        ) and qk_norm_type != QkNormType.post_rope
+        ) and self.qk_norm_type != QkNormType.post_rope
+
+        self.rotary_emb = None
+        if not self.enable_rope_fusion and self.pos_embd_params is not None:
+            self.rotary_emb = RotaryEmbedding(
+                self.pos_embd_params.rope,
+                head_dim=self.head_dim,
+                is_neox=self.pos_embd_params.is_neox,
+            )
+
         self.attn = create_attention(
             self.attn_backend,
             self.layer_idx,
@@ -188,16 +199,6 @@ class Attention(nn.Module):
 
         self.support_fused_qkv = self.attn.support_fused_qkv()
 
-        self.rotary_emb = None
-        self.apply_rotary_emb = (not self.enable_rope_fusion
-                                 and pos_embd_params is not None)
-        if self.apply_rotary_emb:
-            self.rotary_emb = RotaryEmbedding(
-                pos_embd_params.rope,
-                head_dim=self.head_dim,
-                is_neox=pos_embd_params.is_neox,
-            )
-
         if not config.skip_create_weights_in_init:
             self.create_weights()
 
@@ -261,17 +262,9 @@ class Attention(nn.Module):
         if qkv_lora is not None:
             qkv = qkv + qkv_lora
 
-        q, k, v = qkv, None, None
-        if self.qk_norm_type == QkNormType.pre_rope:
-            q, k, v = self.split_qkv(q, k, v)
-            q, k = self.apply_qk_norm(q, k)
-        if self.apply_rotary_emb and position_ids is not None:
-            q, k, v = self.split_qkv(q, k, v)
-            q, k = self.rotary_emb(position_ids, [q, k])
-        if self.qk_norm_type == QkNormType.post_rope:
-            q, k = self.apply_qk_norm(q, k)
-        out_scale = None
+        q, k, v = self.apply_rope(qkv, position_ids)
 
+        out_scale = None
         if self.o_proj.has_fp8_qdq or self.o_proj.has_nvfp4 or self.o_proj.has_fp8_block_scales:
             out_scale = self.o_proj.inv_input_scale
 
@@ -297,6 +290,29 @@ class Attention(nn.Module):
             f"QK norm is not implemented for {self.__class__.__name__}."
             "Please override the `apply_qk_norm` method in the subclass.")
 
+    def apply_rope(self, qkv: torch.Tensor, position_ids: torch.Tensor):
+        """
+        Apply RoPE to the query and key, possibly including QK norm.
+        Args:
+            qkv (torch.Tensor): The query, key, and value tensor.
+            position_ids (torch.Tensor): The position IDs of each token for RoPE.
+        Returns:
+            tuple: A tuple of (q, k, v).
+        This method may be overridden in subclasses; depending on the implementation, k/v may be None and q may be the concatenated qkv tensor.
+        Before self.attn.forward, convert_qkv will be called to make sure that the format of (q, k, v) satisfies the requirement of self.attn.
+        """
+        q, k, v = qkv, None, None
+        if self.qk_norm_type == QkNormType.pre_rope:
+            q, k, v = self.split_qkv(q, k, v)
+            q, k = self.apply_qk_norm(q, k)
+        if not self.enable_rope_fusion and position_ids is not None:
+            q, k, v = self.split_qkv(q, k, v)
+            q, k = self.rotary_emb(position_ids, [q, k])
+        if self.qk_norm_type == QkNormType.post_rope:
+            q, k = self.apply_qk_norm(q, k)
+
+        return q, k, v
+
 
 def extract_extra_attrs(layer_idx: str):
     extra_attrs = get_model_extra_attrs()
diff --git a/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py b/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py
index 36f1ee5468..dd691fcef0 100644
--- a/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py
+++ b/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py
@@ -52,10 +52,10 @@ class DecodingCUDAGraphRunner:
 
         # Using ones instead of zeros prevents NaNs in e.g. Deepseek
         self.input_ids = torch.ones((batch_size * token_per_request, ),
                                     device=device,
-                                    dtype=torch.int64)
+                                    dtype=torch.int32)
         self.position_ids = torch.zeros((1, batch_size * token_per_request),
                                         device=device,
-                                        dtype=torch.int64)
+                                        dtype=torch.int32)
         self.extra_model_inputs = {}
         self.attn_metadata = attn_metadata
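For reference, below is a minimal, framework-free sketch of the dispatch pattern this patch introduces: the base Attention class applies QK norm and (unfused) RoPE inside apply_rope, and a Qwen3-style subclass overrides apply_rope to route the packed qkv tensor through a fused QK-norm + RoPE kernel. The classes and placeholder methods are simplified stand-ins for illustration only; they are not the actual tensorrt_llm modules, and the fused kernel call is elided rather than reproducing the torch.ops.trtllm.fused_qk_norm_rope signature.

from enum import Enum, auto
from typing import List, Optional, Tuple

import torch


class QkNormType(Enum):
    none = auto()
    pre_rope = auto()
    post_rope = auto()


class AttentionSketch:
    """Simplified stand-in for the base Attention module."""

    def __init__(self,
                 qk_norm_type: QkNormType = QkNormType.none,
                 enable_rope_fusion: bool = False):
        self.qk_norm_type = qk_norm_type
        # If the backend fuses RoPE into the attention kernel, apply_rope
        # leaves the packed qkv tensor untouched (q = qkv, k = v = None).
        self.enable_rope_fusion = enable_rope_fusion

    def split_qkv(self, q, k, v):
        # Placeholder: the real module slices the packed qkv tensor.
        return q, k, v

    def apply_qk_norm(self, q, k):
        # Placeholder: the real module applies per-head RMSNorm to q and k.
        return q, k

    def rotary_emb(self, position_ids, qk: List[torch.Tensor]):
        # Placeholder: the real module is a RotaryEmbedding instance.
        return qk[0], qk[1]

    def apply_rope(
        self, qkv: torch.Tensor, position_ids: Optional[torch.Tensor]
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
        # Mirrors the new Attention.apply_rope: optional pre-RoPE QK norm,
        # unfused RoPE when the backend does not fuse it, then optional
        # post-RoPE QK norm.
        q, k, v = qkv, None, None
        if self.qk_norm_type == QkNormType.pre_rope:
            q, k, v = self.split_qkv(q, k, v)
            q, k = self.apply_qk_norm(q, k)
        if not self.enable_rope_fusion and position_ids is not None:
            q, k, v = self.split_qkv(q, k, v)
            q, k = self.rotary_emb(position_ids, [q, k])
        if self.qk_norm_type == QkNormType.post_rope:
            q, k = self.apply_qk_norm(q, k)
        return q, k, v


class Qwen3AttentionSketch(AttentionSketch):
    """Simplified stand-in for Qwen3Attention with the fused path."""

    def __init__(self, fuse_qk_norm_rope: bool = True):
        # When the fused path is enabled, RoPE is withheld from the base
        # class so it is applied exactly once, inside the fused kernel.
        super().__init__(qk_norm_type=QkNormType.pre_rope,
                         enable_rope_fusion=False)
        self.fuse_qk_norm_rope = fuse_qk_norm_rope

    def apply_rope(self, qkv, position_ids):
        if not self.fuse_qk_norm_rope:
            return super().apply_rope(qkv, position_ids)
        # Stand-in for the fused QK-norm + RoPE kernel, which normalizes
        # q/k and applies RoPE in place on the packed qkv tensor, so the
        # packed tensor is returned as-is with k and v left as None.
        return qkv, None, None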