[https://nvbugs/5536131][fix] Fix illegal access issue when scale is not provided in Llama3/4. (#7960)

Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com>
Yukun He authored 2025-10-08 14:47:00 +08:00, committed by GitHub
parent 647080e3d5
commit 1ca84e1a25
GPG Key ID: B5690EEEBB952194 (no known key found for this signature in database)
5 changed files with 44 additions and 45 deletions


@@ -561,13 +561,13 @@ class Llama4DecoderLayer(DecoderLayer):
         else:
             # The next layernorm exists but it could be the last decoder layer.
             # Adjust the scale and fusion pattern.
-            if self.next_attn is not None and (self.is_nvfp4
-                                               or self.is_fp8_quant):
-                scale = self.next_attn.qkv_proj.input_scale if hasattr(
-                    self.next_attn.qkv_proj, 'input_scale') else None
-            else:
-                self.post_feed_forward_fusion_op = AllReduceFusionOp.RESIDUAL_RMS_NORM
+            if not (self.next_attn is not None and (self.is_nvfp4
+                                                    or self.is_fp8_quant)) \
+                    or not hasattr(self.next_attn.qkv_proj, 'input_scale'):
+                scale = None
+                self.post_feed_forward_fusion_op = AllReduceFusionOp.RESIDUAL_RMS_NORM
+            else:
+                scale = self.next_attn.qkv_proj.input_scale
 
         # TODO: MIN_LATENCY_MODE is hardcoded to False
         if cutlass_min_latency_mode:
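
The hunk above (and the matching LlamaDecoderLayer hunk that follows) couples the scale and the fusion op so they can never disagree: whenever no calibrated scale is available, the allreduce falls back to the non-quantizing fusion instead of running a quantizing fusion with scale=None, which is what triggered the illegal memory access. A minimal standalone sketch of that guard; pick_fusion and the SimpleNamespace stand-ins are illustrative, not names from this commit:

    from enum import Enum, auto
    from types import SimpleNamespace


    class AllReduceFusionOp(Enum):
        RESIDUAL_RMS_NORM = auto()            # plain fused allreduce + RMSNorm
        RESIDUAL_RMS_NORM_QUANT_FP8 = auto()  # quantizing variant, needs a scale


    def pick_fusion(next_attn, is_quant):
        """Choose (scale, fusion_op) for the post-FFN allreduce."""
        # The fix: if there is no next attention layer, quantization is off,
        # or qkv_proj carries no calibrated input_scale, fall back to the
        # plain fusion with scale=None rather than keeping a quantizing
        # fusion op while passing a null scale.
        if not (next_attn is not None and is_quant) \
                or not hasattr(next_attn.qkv_proj, 'input_scale'):
            return None, AllReduceFusionOp.RESIDUAL_RMS_NORM
        return (next_attn.qkv_proj.input_scale,
                AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_FP8)


    # Unquantized qkv_proj: no input_scale attribute -> plain fusion, no scale.
    attn = SimpleNamespace(qkv_proj=SimpleNamespace())
    assert pick_fusion(attn, is_quant=True) == (None,
                                                AllReduceFusionOp.RESIDUAL_RMS_NORM)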
@@ -771,13 +771,14 @@ class LlamaDecoderLayer(DecoderLayer):
         else:
             # The next layernorm exists but it could be the last decoder layer.
             # Adjust the scale and fusion pattern.
-            if self.next_attn is not None and (self.is_nvfp4
-                                               or self.is_fp8_quant):
-                scale = self.next_attn.qkv_proj.input_scale if hasattr(
-                    self.next_attn.qkv_proj, 'input_scale') else None
-            else:
-                self.post_mlp_fusion_op = AllReduceFusionOp.RESIDUAL_RMS_NORM
+            if not (self.next_attn is not None and (self.is_nvfp4
+                                                    or self.is_fp8_quant)) \
+                    or not hasattr(self.next_attn.qkv_proj, 'input_scale'):
+                scale = None
+                self.post_mlp_fusion_op = AllReduceFusionOp.RESIDUAL_RMS_NORM
+            else:
+                scale = self.next_attn.qkv_proj.input_scale
 
         all_reduce_output = self.all_reduce(
             hidden_states,


@@ -800,18 +800,19 @@ class Llama4MinLatencyDecoderLayer(Llama4DecoderLayer):
         needs_post_allreduce = self.fusion_config.POST_MOE_FUSION \
             or self.fusion_config.POST_MLP_FUSION
         if needs_post_allreduce and self.next_layer_layernorm is not None:
-            if use_fp8_allreduce and self.next_attn is not None:
+            if use_fp8_allreduce and self.next_attn is not None \
+                    and hasattr(self.next_attn.qkv_proj, 'input_scale'):
                 hidden_states, residual = self.all_reduce(
                     hidden_states,
                     all_reduce_params=AllReduceParams(
                         fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_FP8,
                         residual=residual,
                         norm_weight=self.next_layer_layernorm.weight,
-                        scale=self.next_attn.qkv_proj.input_scale if hasattr(
-                            self.next_attn.qkv_proj, 'input_scale') else None,
+                        scale=self.next_attn.qkv_proj.input_scale,
                         eps=self.next_layer_layernorm.variance_epsilon,
                     ))
-            elif use_fp4_allreduce and self.next_attn is not None:
+            elif use_fp4_allreduce and self.next_attn is not None \
+                    and hasattr(self.next_attn.qkv_proj, 'input_scale'):
                 act_fp4, act_sf, residual = self.all_reduce(
                     hidden_states,
                     all_reduce_params=AllReduceParams(
@@ -819,8 +820,7 @@ class Llama4MinLatencyDecoderLayer(Llama4DecoderLayer):
                         RESIDUAL_RMS_NORM_QUANT_NVFP4,
                         residual=residual,
                         norm_weight=self.next_layer_layernorm.weight,
-                        scale=self.next_attn.qkv_proj.input_scale if hasattr(
-                            self.next_attn.qkv_proj, 'input_scale') else None,
+                        scale=self.next_attn.qkv_proj.input_scale,
                         eps=self.next_layer_layernorm.variance_epsilon,
                     ))
             else:
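
In the min-latency layer the same condition moves into branch selection itself: a layer whose qkv_proj lacks input_scale now falls through to the non-quantizing else path rather than entering a quantizing fusion with scale=None. A sketch of the dispatch order after the fix; select_allreduce_path is an illustrative name, not code from this commit:

    def select_allreduce_path(use_fp8_allreduce, use_fp4_allreduce, next_attn):
        # Quantizing paths are only eligible when a calibrated input_scale
        # actually exists on the next layer's qkv projection.
        quantizable = (next_attn is not None
                       and hasattr(next_attn.qkv_proj, 'input_scale'))
        if use_fp8_allreduce and quantizable:
            return 'RESIDUAL_RMS_NORM_QUANT_FP8'
        if use_fp4_allreduce and quantizable:
            return 'RESIDUAL_RMS_NORM_QUANT_NVFP4'
        return 'RESIDUAL_RMS_NORM'  # plain allreduce in the trailing else branch


    # No next attention layer -> never a quantizing fusion.
    assert select_allreduce_path(True, False, None) == 'RESIDUAL_RMS_NORM'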


@@ -643,6 +643,30 @@ class TestLlama3_3_70BInstruct(LlmapiAccuracyTestHarness):
         task.evaluate(llm,
                       extra_evaluator_kwargs=dict(apply_chat_template=True))
 
+    @pytest.mark.skip_less_device(4)
+    @skip_pre_blackwell
+    def test_fp8_tp2pp2(self):
+        model_path = f"{llm_models_root()}/llama-3.3-models/Llama-3.3-70B-Instruct-FP8"
+        kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.5)
+        with LLM(model_path,
+                 tensor_parallel_size=2,
+                 pipeline_parallel_size=2,
+                 max_batch_size=32,
+                 kv_cache_config=kv_cache_config) as llm:
+            assert llm.args.quant_config.quant_algo == QuantAlgo.FP8
+            sampling_params = SamplingParams(
+                max_tokens=256,
+                temperature=0.0,
+                add_special_tokens=False,
+            )
+            task = MMLU(self.MODEL_NAME)
+            task.evaluate(llm, sampling_params=sampling_params)
+            task = GSM8K(self.MODEL_NAME)
+            task.evaluate(llm, sampling_params=sampling_params)
+            task = GPQADiamond(self.MODEL_NAME)
+            task.evaluate(llm,
+                          extra_evaluator_kwargs=dict(apply_chat_template=True))
+
 
 class TestLlama4MaverickInstruct(LlmapiAccuracyTestHarness):
     MODEL_NAME = "meta-llama/Llama-4-Maverick-17B-128E-Instruct"
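
The new test pins the exact failing configuration: an FP8 checkpoint split across tensor parallelism 2 and pipeline parallelism 2 on 4 GPUs. A standalone sketch of the same setup outside the accuracy harness, using the LLM API shown in the hunk above; the KvCacheConfig/SamplingParams import paths and the relative model path are assumptions, not taken from this commit:

    from tensorrt_llm import LLM, SamplingParams
    from tensorrt_llm.llmapi import KvCacheConfig  # assumed import path

    # Hypothetical local checkpoint path; the test resolves it via llm_models_root().
    model_path = "llama-3.3-models/Llama-3.3-70B-Instruct-FP8"

    with LLM(model_path,
             tensor_parallel_size=2,
             pipeline_parallel_size=2,
             max_batch_size=32,
             kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.5)) as llm:
        outputs = llm.generate(["The capital of France is"],
                               SamplingParams(max_tokens=256, temperature=0.0))
        print(outputs[0].outputs[0].text)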


@@ -51,6 +51,7 @@ l0_dgx_b200:
       - disaggregated/test_disaggregated.py::test_disaggregated_benchmark_on_diff_backends[llama-3.1-8b-instruct-hf-fp8]
       - disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_nixl[DeepSeek-V3-Lite-fp8]
       - accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp8_tp4
+      - accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp8_tp2pp2
       - accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_nvfp4_tp4
       - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[dp4-cutlass-auto]
 - condition:
- condition:
@@ -103,7 +104,6 @@ l0_dgx_b200:
           backend: pytorch
       tests:
       - unittest/_torch/modules/test_fused_moe.py::test_fused_moe_alltoall_fp4[DeepEP]
-      - unittest/_torch/multi_gpu_modeling/test_llama3.py::test_llama_3_3
       - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_bfloat16_4gpus[tp4-attn_backend=FLASHINFER-torch_compile=False]
       - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[tp4-fp8kv=False-attn_backend=FLASHINFER-torch_compile=False]
       - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[pp4-fp8kv=False-attn_backend=TRTLLM-torch_compile=False]


@@ -1,26 +0,0 @@
-from utils.llm_data import llm_models_root
-from utils.util import similar
-
-from tensorrt_llm import LLM
-
-
-def test_llama_3_3():
-    model_dir = llm_models_root(
-    ) / "llama-3.3-models" / "Llama-3.3-70B-Instruct-FP8"
-    tp = 2
-    pp = 2
-    llm = LLM(model_dir, tensor_parallel_size=tp, pipeline_parallel_size=pp)
-
-    prompts = [
-        "The capital of France is",
-        "The president of the United States is",
-    ]
-    outputs = llm.generate(prompts)
-
-    expected_outputs = [
-        " a city of romance, art, fashion, and cuisine. Paris, also known as the City of Light, is a must-visit destination for anyone interested in",
-        " the head of state and head of government of the United States. The president is also the commander-in-chief of the armed forces. The president is elected by the",
-    ]
-
-    for i, output in enumerate(outputs):
-        assert similar(output.outputs[0].text, expected_outputs[i])