diff --git a/tensorrt_llm/_torch/models/modeling_llama.py b/tensorrt_llm/_torch/models/modeling_llama.py
index 98007d5ef4..e51cbd843d 100644
--- a/tensorrt_llm/_torch/models/modeling_llama.py
+++ b/tensorrt_llm/_torch/models/modeling_llama.py
@@ -165,22 +165,9 @@ class Llama4Attention(Attention):
         q, k, v = self.split_qkv(q, k, v)
         q = self._attention_scaling(q, position_ids)

-        out_scale = None
-        out_scale_sf = None
-        if self.o_proj.has_fp8_qdq or self.o_proj.has_nvfp4 or self.o_proj.has_fp8_block_scales:
-            out_scale = self.o_proj.inv_input_scale
-        if self.o_proj.has_nvfp4 and self.support_nvfp4_output:
-            out_scale_sf = self.o_proj.input_scale
-        q, k, v = self.convert_qkv(q, k, v)
-        attn_output = self.attn.forward(q,
-                                        k,
-                                        v,
-                                        attn_metadata,
-                                        out_scale=out_scale,
-                                        out_scale_sf=out_scale_sf,
-                                        attention_mask=attention_mask,
-                                        mrope_config=mrope_config)
+        attn_output = self.forward_impl(q, k, v, attn_metadata, attention_mask,
+                                        None, None, mrope_config)

         if isinstance(attn_output, tuple):
             attn_output = Fp4QuantizedTensor(attn_output[0], attn_output[1])
diff --git a/tensorrt_llm/_torch/models/modeling_speculative.py b/tensorrt_llm/_torch/models/modeling_speculative.py
index e8a5774211..519de4ae37 100644
--- a/tensorrt_llm/_torch/models/modeling_speculative.py
+++ b/tensorrt_llm/_torch/models/modeling_speculative.py
@@ -362,6 +362,12 @@ class SpecDecOneEngineForCausalLM(DecoderModelForCausalLM[TModel, TConfig],
                                               model_config, model_config.mapping)

+        if draft_config is not None:
+            for key, value in draft_config.extra_attrs.items():
+                assert key in ('attn_layers', 'mla_layers')
+                assert key in model_config.extra_attrs
+                model_config.extra_attrs[key].update(value)
+
     def forward(
         self,
         attn_metadata: AttentionMetadata,
diff --git a/tensorrt_llm/_torch/modules/attention.py b/tensorrt_llm/_torch/modules/attention.py
index 4cc1e5712c..4972a65db0 100644
--- a/tensorrt_llm/_torch/modules/attention.py
+++ b/tensorrt_llm/_torch/modules/attention.py
@@ -92,7 +92,7 @@ def attn_custom_op_inplace(
         mrope_position_deltas,
         attention_window_size,
         attention_mask_data,
-        False,
+        enable_attn_nvfp4_output=False,
         output=output)


@@ -372,6 +372,58 @@ class Attention(nn.Module):
             return attn_output[0], attn_output[1]
         return attn_output, None

+    def forward_impl(
+        self,
+        q: torch.Tensor,
+        k: Optional[torch.Tensor],
+        v: Optional[torch.Tensor],
+        attn_metadata: AttentionMetadata,
+        attention_mask: AttentionMask,
+        attention_window_size: Optional[int],
+        attention_mask_data: Optional[torch.Tensor],
+        mrope_config: Optional[dict],
+    ):
+        mrope_rotary_cos_sin = None
+        mrope_position_deltas = None
+        if mrope_config is not None:
+            if "mrope_rotary_cos_sin" in mrope_config:
+                mrope_rotary_cos_sin = mrope_config["mrope_rotary_cos_sin"]
+            if "mrope_position_deltas" in mrope_config:
+                mrope_position_deltas = mrope_config["mrope_position_deltas"]
+
+        # Currently only TRTLLM and FLASHINFER are torch compile compatible backends.
+        # Only enable custom inplace op when torch compiling.
+        use_custom_inplace_op = (self.register_to_config
+                                 and (self.attn_backend == "TRTLLM"
+                                      or self.attn_backend == "FLASHINFER")
+                                 and is_torch_compiling())
+
+        if use_custom_inplace_op:
+            output = self.create_output(q)
+            attn_custom_op_inplace(
+                q,
+                k,
+                v,
+                attention_mask,
+                mrope_rotary_cos_sin,
+                mrope_position_deltas,
+                attention_window_size,
+                attention_mask_data,
+                self.layer_idx_str,
+                output,
+            )
+        else:
+            output, output_sf = self._attn_impl(q, k, v, attn_metadata,
+                                                attention_mask,
+                                                mrope_rotary_cos_sin,
+                                                mrope_position_deltas,
+                                                attention_window_size,
+                                                attention_mask_data)
+            if output_sf is not None:
+                output = Fp4QuantizedTensor(output, output_sf)
+
+        return output
+
     def forward(
         self,
         position_ids: Optional[torch.IntTensor],
@@ -414,54 +466,18 @@ class Attention(nn.Module):
         if qkv_lora is not None:
             qkv = qkv + qkv_lora

-        mrope_rotary_cos_sin = None
-        mrope_position_deltas = None
-        if mrope_config is not None:
-            if "mrope_rotary_cos_sin" in mrope_config:
-                mrope_rotary_cos_sin = mrope_config["mrope_rotary_cos_sin"]
-            if "mrope_position_deltas" in mrope_config:
-                mrope_position_deltas = mrope_config["mrope_position_deltas"]
-
-        output = None
-
         q, k, v = qkv, None, None
         q, k, v = self.apply_rope(q, k, v, position_ids)
         q, k, v = self.convert_qkv(q, k, v)

-        # Currently only TRTLLM and FLASHINFER are torch compile compatible backends.
-        # Only enable custom inplace op when torch compiling.
-        use_custom_inplace_op = (self.register_to_config
-                                 and (self.attn_backend == "TRTLLM"
-                                      or self.attn_backend == "FLASHINFER")
-                                 and is_torch_compiling())
-        if use_custom_inplace_op:
-            output = self.create_output(q)
-            attn_custom_op_inplace(
-                q,
-                k,
-                v,
-                attention_mask,
-                mrope_rotary_cos_sin,
-                mrope_position_deltas,
-                attention_window_size,
-                attention_mask_data,
-                self.layer_idx_str,
-                output=output,
-            )
-        else:
-            output, output_sf = self._attn_impl(
-                q,
-                k,
-                v,
-                attn_metadata,
-                attention_mask,
-                mrope_rotary_cos_sin,
-                mrope_position_deltas,
-                attention_window_size,
-                attention_mask_data,
-            )
-            if output_sf is not None:
-                output = Fp4QuantizedTensor(output, output_sf)
+        output = self.forward_impl(q,
+                                   k,
+                                   v,
+                                   attn_metadata,
+                                   attention_mask,
+                                   attention_window_size,
+                                   attention_mask_data,
+                                   mrope_config=mrope_config)

         attn_output = self.o_proj(output,
                                   all_reduce_params=all_reduce_params,
diff --git a/tests/integration/defs/accuracy/test_llm_api_pytorch.py b/tests/integration/defs/accuracy/test_llm_api_pytorch.py
index 8ce64d51ea..5112907c20 100644
--- a/tests/integration/defs/accuracy/test_llm_api_pytorch.py
+++ b/tests/integration/defs/accuracy/test_llm_api_pytorch.py
@@ -550,11 +550,14 @@ class TestLlama4MaverickInstruct(LlmapiAccuracyTestHarness):
             speculative_model_dir=eagle_model_dir)
         kv_cache_config = KvCacheConfig(enable_block_reuse=False,
                                         free_gpu_memory_fraction=0.75)
+        torch_compile_config = TorchCompileConfig(
+            enable_fullgraph=True,
+            enable_piecewise_cuda_graph=True,
+            max_num_streams=3) if torch_compile else None
         pytorch_config = dict(
             cuda_graph_config=CudaGraphConfig(max_batch_size=8),
             enable_attention_dp=False,
-            torch_compile_config=TorchCompileConfig(
-                enable_fullgraph=torch_compile))
+            torch_compile_config=torch_compile_config)
         with LLM(model_path,
                  kv_cache_config=kv_cache_config,
                  tensor_parallel_size=tp_size,
diff --git a/tests/integration/test_lists/waives.txt b/tests/integration/test_lists/waives.txt
index 9f782c4276..2e7fa8a338 100644
--- a/tests/integration/test_lists/waives.txt
+++ b/tests/integration/test_lists/waives.txt
@@ -260,7 +260,6 @@ examples/test_bert.py::test_llm_bert_general[compare_hf-enable_remove_input_padd
 test_e2e.py::test_ptp_quickstart_multimodal[mistral-small-3.1-24b-instruct-Mistral-Small-3.1-24B-Instruct-2503-mixture_text_image-True] SKIP (https://nvbugs/5430124)
 examples/test_granite.py::test_granite_bf16_lora[granite-3.0-1b-a400m-instruct] SKIP (https://nvbugs/5431132)
 accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[MMLU-gen_tp=1-ctx_pp=4] SKIP (https://nvbugs/5434320)
-accuracy/test_llm_api_pytorch.py::TestLlama4MaverickInstruct::test_fp8_eagle3[tp8-torch_compile=True] SKIP (https://nvbugs/5427801)
 accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[MMLU-gen_tp=2-ctx_pp=4] SKIP (https://nvbugs/5434320)
 accuracy/test_llm_api.py::TestLlama3_2_1B::test_int4_awq_int8_kv_cache SKIP (https://nvbugs/5433541)
 accuracy/test_llm_api.py::TestLlama3_2_1B::test_fp8_pp2 SKIP (https://nvbugs/5433541)
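
Note on the attention.py change: the backend dispatch that previously lived inline in Attention.forward (route through the custom in-place op only while torch.compile is tracing a TRTLLM or FLASHINFER backend, otherwise call the regular implementation) is now factored into forward_impl, which Llama4Attention reuses instead of calling self.attn.forward directly. Below is a minimal, self-contained sketch of that dispatch pattern using plain PyTorch only; ToyAttention, attn_inplace_op and _attn_impl are simplified stand-ins rather than the real TensorRT-LLM classes, and torch.compiler.is_compiling() (PyTorch 2.1+) stands in for is_torch_compiling().

# Illustrative sketch only: a stripped-down stand-in for the forward_impl()
# dispatch added in tensorrt_llm/_torch/modules/attention.py. All names here
# are simplified placeholders, not the real TensorRT-LLM APIs.
import torch


def attn_inplace_op(q: torch.Tensor, output: torch.Tensor) -> None:
    # Stand-in for attn_custom_op_inplace: writes the result into a
    # preallocated buffer so the compiled graph sees a fixed output tensor.
    output.copy_(torch.softmax(q, dim=-1))


class ToyAttention(torch.nn.Module):

    def __init__(self, backend: str = "TRTLLM"):
        super().__init__()
        self.attn_backend = backend

    def _attn_impl(self, q: torch.Tensor) -> torch.Tensor:
        # Stand-in for the regular (out-of-place) attention path.
        return torch.softmax(q, dim=-1)

    def forward_impl(self, q: torch.Tensor) -> torch.Tensor:
        # Mirror of the new dispatch: use the in-place custom op only when
        # compiling with a compatible backend, otherwise call the direct impl.
        use_custom_inplace_op = (self.attn_backend in ("TRTLLM", "FLASHINFER")
                                 and torch.compiler.is_compiling())
        if use_custom_inplace_op:
            output = torch.empty_like(q)
            attn_inplace_op(q, output)
        else:
            output = self._attn_impl(q)
        return output

    def forward(self, q: torch.Tensor) -> torch.Tensor:
        return self.forward_impl(q)


if __name__ == "__main__":
    attn = ToyAttention()
    q = torch.randn(2, 8)
    eager_out = attn(q)                    # takes the _attn_impl branch
    compiled_out = torch.compile(attn)(q)  # takes the in-place-op branch
    torch.testing.assert_close(eager_out, compiled_out)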
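Note on the test change: test_fp8_eagle3 now builds a TorchCompileConfig only when the torch_compile parameter is set, and enables full-graph capture plus piecewise CUDA graphs with three streams instead of passing enable_fullgraph=torch_compile unconditionally; the corresponding tp8-torch_compile=True waiver is removed from waives.txt. A rough sketch of how those options end up wired into the LLM constructor follows; the config classes, fields and values are taken from the diff above, while the import path and the model_path / tp_size placeholders are assumptions about the surrounding test code.

# Rough sketch only: wiring the compile options from the updated test into an
# LLM instance. Import path and placeholders are assumptions, not verbatim
# from the test file.
from tensorrt_llm import LLM
from tensorrt_llm.llmapi import (CudaGraphConfig, KvCacheConfig,
                                 TorchCompileConfig)

model_path = "<path to Llama-4-Maverick FP8 checkpoint>"  # placeholder
tp_size = 8            # placeholder, matches the tp8 parametrization
torch_compile = True   # the test's parametrized flag

# Only build a compile config when the test actually asks for torch.compile.
torch_compile_config = TorchCompileConfig(
    enable_fullgraph=True,
    enable_piecewise_cuda_graph=True,
    max_num_streams=3) if torch_compile else None

llm = LLM(model_path,
          kv_cache_config=KvCacheConfig(enable_block_reuse=False,
                                        free_gpu_memory_fraction=0.75),
          tensor_parallel_size=tp_size,
          cuda_graph_config=CudaGraphConfig(max_batch_size=8),
          enable_attention_dp=False,
          torch_compile_config=torch_compile_config)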