diff --git a/tensorrt_llm/_torch/models/modeling_llama.py b/tensorrt_llm/_torch/models/modeling_llama.py
index 5dc193074b..e4ef4b2dbb 100644
--- a/tensorrt_llm/_torch/models/modeling_llama.py
+++ b/tensorrt_llm/_torch/models/modeling_llama.py
@@ -561,13 +561,13 @@ class Llama4DecoderLayer(DecoderLayer):
             else:
                 # The next layernorm exists but it could be the last decoder layer.
                 # Adjust the scale and fusion pattern.
-                if self.next_attn is not None and (self.is_nvfp4
-                                                   or self.is_fp8_quant):
-                    scale = self.next_attn.qkv_proj.input_scale if hasattr(
-                        self.next_attn.qkv_proj, 'input_scale') else None
-                else:
-                    self.post_feed_forward_fusion_op = AllReduceFusionOp.RESIDUAL_RMS_NORM
+                if not (self.next_attn is not None and (self.is_nvfp4
+                                                        or self.is_fp8_quant)) \
+                        or not hasattr(self.next_attn.qkv_proj, 'input_scale'):
                     scale = None
+                    self.post_feed_forward_fusion_op = AllReduceFusionOp.RESIDUAL_RMS_NORM
+                else:
+                    scale = self.next_attn.qkv_proj.input_scale
 
         # TODO: MIN_LATENCY_MODE is hardcoded to False
         if cutlass_min_latency_mode:
@@ -771,13 +771,14 @@ class LlamaDecoderLayer(DecoderLayer):
             else:
                 # The next layernorm exists but it could be the last decoder layer.
                 # Adjust the scale and fusion pattern.
-                if self.next_attn is not None and (self.is_nvfp4
-                                                   or self.is_fp8_quant):
-                    scale = self.next_attn.qkv_proj.input_scale if hasattr(
-                        self.next_attn.qkv_proj, 'input_scale') else None
-                else:
-                    self.post_mlp_fusion_op = AllReduceFusionOp.RESIDUAL_RMS_NORM
+
+                if not (self.next_attn is not None and (self.is_nvfp4
+                                                        or self.is_fp8_quant)) \
+                        or not hasattr(self.next_attn.qkv_proj, 'input_scale'):
                     scale = None
+                    self.post_mlp_fusion_op = AllReduceFusionOp.RESIDUAL_RMS_NORM
+                else:
+                    scale = self.next_attn.qkv_proj.input_scale
 
             all_reduce_output = self.all_reduce(
                 hidden_states,
diff --git a/tensorrt_llm/_torch/models/modeling_llama_min_latency.py b/tensorrt_llm/_torch/models/modeling_llama_min_latency.py
index 7762cda3b4..027eeeace2 100644
--- a/tensorrt_llm/_torch/models/modeling_llama_min_latency.py
+++ b/tensorrt_llm/_torch/models/modeling_llama_min_latency.py
@@ -800,18 +800,19 @@ class Llama4MinLatencyDecoderLayer(Llama4DecoderLayer):
         needs_post_allreduce = self.fusion_config.POST_MOE_FUSION \
             or self.fusion_config.POST_MLP_FUSION
         if needs_post_allreduce and self.next_layer_layernorm is not None:
-            if use_fp8_allreduce and self.next_attn is not None:
+            if use_fp8_allreduce and self.next_attn is not None \
+                    and hasattr(self.next_attn.qkv_proj, 'input_scale'):
                 hidden_states, residual = self.all_reduce(
                     hidden_states,
                     all_reduce_params=AllReduceParams(
                         fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_FP8,
                         residual=residual,
                         norm_weight=self.next_layer_layernorm.weight,
-                        scale=self.next_attn.qkv_proj.input_scale if hasattr(
-                            self.next_attn.qkv_proj, 'input_scale') else None,
+                        scale=self.next_attn.qkv_proj.input_scale,
                         eps=self.next_layer_layernorm.variance_epsilon,
                     ))
-            elif use_fp4_allreduce and self.next_attn is not None:
+            elif use_fp4_allreduce and self.next_attn is not None \
+                    and hasattr(self.next_attn.qkv_proj, 'input_scale'):
                 act_fp4, act_sf, residual = self.all_reduce(
                     hidden_states,
                     all_reduce_params=AllReduceParams(
@@ -819,8 +820,7 @@ class Llama4MinLatencyDecoderLayer(Llama4DecoderLayer):
                         RESIDUAL_RMS_NORM_QUANT_NVFP4,
                         residual=residual,
                         norm_weight=self.next_layer_layernorm.weight,
-                        scale=self.next_attn.qkv_proj.input_scale if hasattr(
-                            self.next_attn.qkv_proj, 'input_scale') else None,
+                        scale=self.next_attn.qkv_proj.input_scale,
                         eps=self.next_layer_layernorm.variance_epsilon,
                     ))
             else:
diff --git a/tests/integration/defs/accuracy/test_llm_api_pytorch.py b/tests/integration/defs/accuracy/test_llm_api_pytorch.py
index c6699b8bc5..4e4bd55f40 100644
--- a/tests/integration/defs/accuracy/test_llm_api_pytorch.py
+++ b/tests/integration/defs/accuracy/test_llm_api_pytorch.py
@@ -643,6 +643,30 @@ class TestLlama3_3_70BInstruct(LlmapiAccuracyTestHarness):
             task.evaluate(llm,
                           extra_evaluator_kwargs=dict(apply_chat_template=True))
 
+    @pytest.mark.skip_less_device(4)
+    @skip_pre_blackwell
+    def test_fp8_tp2pp2(self):
+        model_path = f"{llm_models_root()}/llama-3.3-models/Llama-3.3-70B-Instruct-FP8"
+        kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.5)
+        with LLM(model_path,
+                 tensor_parallel_size=2,
+                 pipeline_parallel_size=2,
+                 max_batch_size=32,
+                 kv_cache_config=kv_cache_config) as llm:
+            assert llm.args.quant_config.quant_algo == QuantAlgo.FP8
+            sampling_params = SamplingParams(
+                max_tokens=256,
+                temperature=0.0,
+                add_special_tokens=False,
+            )
+            task = MMLU(self.MODEL_NAME)
+            task.evaluate(llm, sampling_params=sampling_params)
+            task = GSM8K(self.MODEL_NAME)
+            task.evaluate(llm, sampling_params=sampling_params)
+            task = GPQADiamond(self.MODEL_NAME)
+            task.evaluate(llm,
+                          extra_evaluator_kwargs=dict(apply_chat_template=True))
+
 
 class TestLlama4MaverickInstruct(LlmapiAccuracyTestHarness):
     MODEL_NAME = "meta-llama/Llama-4-Maverick-17B-128E-Instruct"
diff --git a/tests/integration/test_lists/test-db/l0_dgx_b200.yml b/tests/integration/test_lists/test-db/l0_dgx_b200.yml
index e9de74d278..dc1cbe8dd0 100644
--- a/tests/integration/test_lists/test-db/l0_dgx_b200.yml
+++ b/tests/integration/test_lists/test-db/l0_dgx_b200.yml
@@ -51,6 +51,7 @@ l0_dgx_b200:
   - disaggregated/test_disaggregated.py::test_disaggregated_benchmark_on_diff_backends[llama-3.1-8b-instruct-hf-fp8]
   - disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_nixl[DeepSeek-V3-Lite-fp8]
   - accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp8_tp4
+  - accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp8_tp2pp2
   - accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_nvfp4_tp4
   - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[dp4-cutlass-auto]
 - condition:
@@ -103,7 +104,6 @@ l0_dgx_b200:
       backend: pytorch
   tests:
   - unittest/_torch/modules/test_fused_moe.py::test_fused_moe_alltoall_fp4[DeepEP]
-  - unittest/_torch/multi_gpu_modeling/test_llama3.py::test_llama_3_3
   - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_bfloat16_4gpus[tp4-attn_backend=FLASHINFER-torch_compile=False]
   - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[tp4-fp8kv=False-attn_backend=FLASHINFER-torch_compile=False]
   - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[pp4-fp8kv=False-attn_backend=TRTLLM-torch_compile=False]
diff --git a/tests/unittest/_torch/multi_gpu_modeling/test_llama3.py b/tests/unittest/_torch/multi_gpu_modeling/test_llama3.py
deleted file mode 100644
index f04712a946..0000000000
--- a/tests/unittest/_torch/multi_gpu_modeling/test_llama3.py
+++ /dev/null
@@ -1,26 +0,0 @@
-from utils.llm_data import llm_models_root
-from utils.util import similar
-
-from tensorrt_llm import LLM
-
-
-def test_llama_3_3():
-    model_dir = llm_models_root(
-    ) / "llama-3.3-models" / "Llama-3.3-70B-Instruct-FP8"
-    tp = 2
-    pp = 2
-
-    llm = LLM(model_dir, tensor_parallel_size=tp, pipeline_parallel_size=pp)
-    prompts = [
-        "The capital of France is",
-        "The president of the United States is",
-    ]
-
-    outputs = llm.generate(prompts)
-
-    expected_outputs = [
-        " a city of romance, art, fashion, and cuisine. Paris, also known as the City of Light, is a must-visit destination for anyone interested in",
-        " the head of state and head of government of the United States. The president is also the commander-in-chief of the armed forces. The president is elected by the",
-    ]
-    for i, output in enumerate(outputs):
-        assert similar(output.outputs[0].text, expected_outputs[i])