Mirror of https://github.com/NVIDIA/TensorRT-LLM.git, synced 2026-01-14 06:27:45 +08:00
[https://nvbugs/5516710][fix] fix Llama 3.3 TP PP case (#7717)
Signed-off-by: Yan Chunwei <328693+Superjomn@users.noreply.github.com>
Signed-off-by: Wangshanshan <30051912+dominicshanshan@users.noreply.github.com>
This commit is contained in:
parent 44d7c3b245
commit 5342c607cd
@@ -563,7 +563,8 @@ class Llama4DecoderLayer(DecoderLayer):
             # Adjust the scale and fusion pattern.
             if self.next_attn is not None and (self.is_nvfp4
                                                or self.is_fp8_quant):
-                scale = self.next_attn.qkv_proj.input_scale
+                scale = self.next_attn.qkv_proj.input_scale if hasattr(
+                    self.next_attn.qkv_proj, 'input_scale') else None
             else:
                 self.post_feed_forward_fusion_op = AllReduceFusionOp.RESIDUAL_RMS_NORM
                 scale = None
@@ -772,7 +773,8 @@ class LlamaDecoderLayer(DecoderLayer):
             # Adjust the scale and fusion pattern.
             if self.next_attn is not None and (self.is_nvfp4
                                                or self.is_fp8_quant):
-                scale = self.next_attn.qkv_proj.input_scale
+                scale = self.next_attn.qkv_proj.input_scale if hasattr(
+                    self.next_attn.qkv_proj, 'input_scale') else None
             else:
                 self.post_mlp_fusion_op = AllReduceFusionOp.RESIDUAL_RMS_NORM
                 scale = None
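Both decoder-layer hunks apply the same guard: `input_scale` is only dereferenced when the next attention's `qkv_proj` actually exposes it, which, going by the commit title, is not guaranteed in the Llama 3.3 TP+PP configuration. A minimal, self-contained sketch of the failure mode and the guarded access, using hypothetical stand-in objects rather than the real TensorRT-LLM modules:

from types import SimpleNamespace

# Hypothetical stand-ins: one projection exposes input_scale, one does not.
qkv_with_scale = SimpleNamespace(input_scale=0.5)
qkv_without_scale = SimpleNamespace()

# Unguarded access raises AttributeError for the second object:
#     qkv_without_scale.input_scale
# The guarded form from the diff falls back to None instead:
for qkv in (qkv_with_scale, qkv_without_scale):
    scale = qkv.input_scale if hasattr(qkv, 'input_scale') else None
    print(scale)  # 0.5, then None

# getattr(qkv, 'input_scale', None) is an equivalent, slightly terser spelling.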
@@ -807,7 +807,8 @@ class Llama4MinLatencyDecoderLayer(Llama4DecoderLayer):
                         fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_FP8,
                         residual=residual,
                         norm_weight=self.next_layer_layernorm.weight,
-                        scale=self.next_attn.qkv_proj.input_scale,
+                        scale=self.next_attn.qkv_proj.input_scale if hasattr(
+                            self.next_attn.qkv_proj, 'input_scale') else None,
                         eps=self.next_layer_layernorm.variance_epsilon,
                     ))
             elif use_fp4_allreduce and self.next_attn is not None:
@@ -818,7 +819,8 @@ class Llama4MinLatencyDecoderLayer(Llama4DecoderLayer):
                         RESIDUAL_RMS_NORM_QUANT_NVFP4,
                         residual=residual,
                         norm_weight=self.next_layer_layernorm.weight,
-                        scale=self.next_attn.qkv_proj.input_scale,
+                        scale=self.next_attn.qkv_proj.input_scale if hasattr(
+                            self.next_attn.qkv_proj, 'input_scale') else None,
                         eps=self.next_layer_layernorm.variance_epsilon,
                     ))
             else:
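In Llama4MinLatencyDecoderLayer the same guard is written inline on the scale keyword of the fused-allreduce call. An equivalent refactor precomputes the optional scale before the call; a sketch with hypothetical stand-ins (the stub below is not the TensorRT-LLM allreduce API):

from types import SimpleNamespace

# qkv_proj without input_scale, mimicking the case the fix handles.
next_attn = SimpleNamespace(qkv_proj=SimpleNamespace())


def fused_allreduce_stub(**kwargs):
    # Placeholder for the real fused op; just echoes its keyword arguments.
    return kwargs


scale = getattr(next_attn.qkv_proj, 'input_scale', None)  # None in this case
print(fused_allreduce_stub(scale=scale, eps=1e-5))        # {'scale': None, 'eps': 1e-05}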
@@ -103,6 +103,7 @@ l0_dgx_b200:
       backend: pytorch
   tests:
   - unittest/_torch/modules/test_fused_moe.py::test_fused_moe_alltoall_fp4[DeepEP]
+  - unittest/_torch/multi_gpu_modeling/test_llama3.py::test_llama_3_3
   - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_bfloat16_4gpus[tp4-attn_backend=FLASHINFER-torch_compile=False]
   - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[tp4-fp8kv=False-attn_backend=FLASHINFER-torch_compile=False]
   - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[pp4-fp8kv=False-attn_backend=TRTLLM-torch_compile=False]
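The added entry schedules the new unit test (shown below) in the l0_dgx_b200 PyTorch stage. Assuming a repository checkout with pytest installed and enough GPUs for tp * pp = 4, the test can presumably also be selected directly by its node id; a sketch:

import pytest

# Run only the newly added test by node id, from the repository root
# (assumes pytest is available and 4 GPUs are visible for tp=2 x pp=2).
pytest.main([
    "-q",
    "tests/unittest/_torch/multi_gpu_modeling/test_llama3.py::test_llama_3_3",
])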
tests/unittest/_torch/multi_gpu_modeling/test_llama3.py (new file, 26 lines)
@@ -0,0 +1,26 @@
+from utils.llm_data import llm_models_root
+from utils.util import similar
+
+from tensorrt_llm import LLM
+
+
+def test_llama_3_3():
+    model_dir = llm_models_root(
+    ) / "llama-3.3-models" / "Llama-3.3-70B-Instruct-FP8"
+    tp = 2
+    pp = 2
+
+    llm = LLM(model_dir, tensor_parallel_size=tp, pipeline_parallel_size=pp)
+    prompts = [
+        "The capital of France is",
+        "The president of the United States is",
+    ]
+
+    outputs = llm.generate(prompts)
+
+    expected_outputs = [
+        " a city of romance, art, fashion, and cuisine. Paris, also known as the City of Light, is a must-visit destination for anyone interested in",
+        " the head of state and head of government of the United States. The president is also the commander-in-chief of the armed forces. The president is elected by the",
+    ]
+    for i, output in enumerate(outputs):
+        assert similar(output.outputs[0].text, expected_outputs[i])
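The assertion uses `similar` from the test utilities; its definition is not part of this commit. As an assumption only, it is sketched here as a difflib-based fuzzy match (the threshold value is hypothetical):

from difflib import SequenceMatcher


def similar(a: str, b: str, threshold: float = 0.9) -> bool:
    # Assumed behaviour of utils.util.similar: the character-level match
    # ratio must reach the threshold for the strings to count as similar.
    return SequenceMatcher(None, a, b).ratio() >= threshold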