[https://nvbugs/5691730][fix] Have LoRA bf16 ckpts work with Llama 3.3-70B-fp8 (#9808)

Signed-off-by: Michal Guzek <mguzek@nvidia.com>
Signed-off-by: Michal Guzek <moraxu@users.noreply.github.com>
Signed-off-by: Jin Li <59594262+liji-nv@users.noreply.github.com>
Co-authored-by: Jin Li <59594262+liji-nv@users.noreply.github.com>
Signed-off-by: Wangshanshan <30051912+dominicshanshan@users.noreply.github.com>
Author: Michal Guzek
Date: 2026-01-21 23:59:18 -08:00
Committed by: Yanchao Lu
Parent: bc2487bc2c
Commit: fafc22e3d4
3 changed files with 70 additions and 12 deletions

View File

@@ -778,21 +778,32 @@ class LlamaDecoderLayer(DecoderLayer):
         )
         # Fully Connected
         if self.PRE_MLP_FUSION:
+            has_lora = bool(kwargs.get('lora_params'))
             if self.is_nvfp4 or self.is_fp8_quant:
-                scale = self.mlp.gate_up_proj.input_scale
+                # WAR: Skip FP8/NVFP4 quantization when LoRA is active
+                # since LoRA grouped_gemm does not support FP8 yet
+                # see: cpp/tensorrt_llm/thop/loraOp.cpp::lora_grouped_gemm
+                if has_lora:
+                    scale = None  # To prevent quantization
+                    fusion_op = AllReduceFusionOp.RESIDUAL_RMS_NORM  # Use non-quantizing fusion
+                else:
+                    scale = self.mlp.gate_up_proj.input_scale
+                    fusion_op = self.pre_mlp_fusion_op
             else:
                 scale = None
+                fusion_op = self.pre_mlp_fusion_op
             all_reduce_output = self.all_reduce(
                 hidden_states,
                 all_reduce_params=AllReduceParams(
-                    fusion_op=self.pre_mlp_fusion_op,
+                    fusion_op=fusion_op,
                     residual=residual,
                     norm_weight=self.post_attention_layernorm.weight,
                     scale=scale,
                     eps=self.post_attention_layernorm.variance_epsilon,
                 ))
-            if self.is_nvfp4:
+            if fusion_op == AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_NVFP4:
                 act_fp4, act_sf, residual = all_reduce_output
                 hidden_states = Fp4QuantizedTensor(act_fp4, act_sf)
             else:
@@ -841,24 +852,30 @@ class LlamaDecoderLayer(DecoderLayer):
             # The next layernorm exists but it could be the last decoder layer.
             # Adjust the scale and fusion pattern.
+            has_lora = bool(kwargs.get('lora_params'))
+            # WAR: Skip FP8/NVFP4 quantization when LoRA is active
+            # since LoRA grouped_gemm does not support FP8 yet
             if not (self.next_attn is not None and (self.is_nvfp4
                                                     or self.is_fp8_quant)) \
-                    or not hasattr(self.next_attn.qkv_proj, 'input_scale'):
+                    or not hasattr(self.next_attn.qkv_proj, 'input_scale') \
+                    or has_lora:
                 scale = None
-                self.post_mlp_fusion_op = AllReduceFusionOp.RESIDUAL_RMS_NORM
+                post_fusion_op = AllReduceFusionOp.RESIDUAL_RMS_NORM
             else:
                 scale = self.next_attn.qkv_proj.input_scale
+                post_fusion_op = self.post_mlp_fusion_op
             all_reduce_output = self.all_reduce(
                 hidden_states,
                 all_reduce_params=AllReduceParams(
-                    fusion_op=self.post_mlp_fusion_op,
+                    fusion_op=post_fusion_op,
                     residual=residual,
                     norm_weight=self.next_layer_layernorm.weight,
                     scale=scale,
                     eps=self.next_layer_layernorm.variance_epsilon,
                 ))
-            if self.post_mlp_fusion_op == AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_NVFP4:
+            if post_fusion_op == AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_NVFP4:
                 act_fp4, act_sf, residual = all_reduce_output
                 hidden_states = Fp4QuantizedTensor(act_fp4, act_sf)
             else:
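
Taken together, the two hunks above implement one workaround: when a LoRA adapter is active, the fused all-reduce must not quantize its output, because lora_grouped_gemm only consumes BF16 activations, so the layer falls back to the plain RESIDUAL_RMS_NORM fusion and passes no input scale. The following standalone sketch distills that selection logic; AllReduceFusionOp here is a stand-in enum and select_pre_mlp_fusion is a hypothetical helper for illustration, not TensorRT-LLM API.

from enum import Enum, auto


class AllReduceFusionOp(Enum):
    # Stand-in for tensorrt_llm's fusion-op enum; only the members that the
    # diff above references are modeled here.
    RESIDUAL_RMS_NORM = auto()
    RESIDUAL_RMS_NORM_QUANT_FP8 = auto()
    RESIDUAL_RMS_NORM_QUANT_NVFP4 = auto()


def select_pre_mlp_fusion(layer, lora_params):
    """Return (scale, fusion_op) for the pre-MLP fused all-reduce.

    layer is any object exposing is_nvfp4, is_fp8_quant, pre_mlp_fusion_op
    and mlp.gate_up_proj.input_scale, mirroring the attributes used above.
    """
    has_lora = bool(lora_params)
    if layer.is_nvfp4 or layer.is_fp8_quant:
        if has_lora:
            # lora_grouped_gemm consumes BF16 activations, so skip the
            # quantizing fusion and pass no input scale.
            return None, AllReduceFusionOp.RESIDUAL_RMS_NORM
        return layer.mlp.gate_up_proj.input_scale, layer.pre_mlp_fusion_op
    # Unquantized checkpoints never had a scale; keep the configured fusion op.
    return None, layer.pre_mlp_fusion_op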

View File

@@ -433,6 +433,7 @@ class Attention(nn.Module):
         output: Optional[torch.Tensor] = None,
         output_sf: Optional[torch.Tensor] = None,
         attention_sinks: Optional[torch.Tensor] = None,
+        has_lora: bool = False,
     ):
         num_tokens = attn_metadata.num_tokens
@@ -446,7 +447,8 @@
         out_scale_sf = None
         # Don't set out_scale if o_proj has pre_quant_scale - this prevents FP8/FP4 output
         # and keeps attention output in BF16 for better precision when applying pre_quant_scale
-        if self._use_quantize_output():
+        # Also don't set out_scale if LoRA is active - LoRA grouped_gemm doesn't support FP8
+        if self._use_quantize_output() and not has_lora:
             out_scale = self.o_proj.inv_input_scale
             out_scale_sf = self.o_proj.input_scale
@@ -499,6 +501,7 @@
         attention_mask_data: Optional[torch.Tensor],
         mrope_config: Optional[dict],
         attention_sinks: Optional[torch.Tensor] = None,
+        has_lora: bool = False,
     ):
         mrope_rotary_cos_sin = None
         mrope_position_deltas = None
@@ -544,7 +547,8 @@
             mrope_position_deltas,
             attention_window_size,
             attention_mask_data,
-            attention_sinks=attention_sinks)
+            attention_sinks=attention_sinks,
+            has_lora=has_lora)
         if output_sf is not None:
             output = Fp4QuantizedTensor(output, output_sf)
@@ -619,7 +623,8 @@
             attention_window_size,
             attention_mask_data,
             mrope_config=mrope_config,
-            attention_sinks=attention_sinks)
+            attention_sinks=attention_sinks,
+            has_lora=bool(lora_params))
         if self.attn_output_gate:
             gate = torch.sigmoid(gate)
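
The attention-side change follows the same pattern: a has_lora flag is threaded from forward (where lora_params is known) down to the point where the output scales are picked, and when the flag is set no out_scale is passed, so the attention output stays in BF16 for the LoRA-augmented o_proj. A minimal sketch of that gating, using a hypothetical helper and a stand-in o_proj object:

from typing import Optional, Tuple

import torch


def select_attention_out_scales(
        o_proj,
        use_quantize_output: bool,
        has_lora: bool) -> Tuple[Optional[torch.Tensor],
                                 Optional[torch.Tensor]]:
    """Decide whether the attention kernel should emit quantized output.

    o_proj is a stand-in for the output projection module and is expected to
    expose inv_input_scale (FP8) and input_scale (NVFP4). Returning
    (None, None) keeps the attention output in BF16, which the LoRA path
    requires because lora_grouped_gemm cannot consume FP8/FP4 activations.
    """
    if use_quantize_output and not has_lora:
        return o_proj.inv_input_scale, o_proj.input_scale
    return None, None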

View File

@@ -30,8 +30,8 @@ from .test_llm import (_test_llm_capture_request_error, get_model_path,
                        prompts, run_llm_abort_request,
                        run_llm_with_postprocess_parallel_and_result_handler,
                        tinyllama_logits_processor_test_harness)
-from utils.util import (force_ampere, similar, skip_fp8_pre_ada,
-                        skip_gpu_memory_less_than_40gb,
+from utils.util import (force_ampere, similar, similarity_score,
+                        skip_fp8_pre_ada, skip_gpu_memory_less_than_40gb,
                         skip_gpu_memory_less_than_80gb,
                         skip_gpu_memory_less_than_138gb, skip_ray)
 from utils.llm_data import llm_models_root
@@ -629,6 +629,42 @@ def test_llama_3_1_8b_fp8_with_bf16_lora(cuda_graph_config) -> None:
     assert similar(output.outputs[0].text, reference)
 
 
+@skip_ray  # https://nvbugs/5682551
+@skip_gpu_memory_less_than_80gb
+def test_llama_3_3_70b_fp8_with_squad_lora_tp2() -> None:
+    skip_fp8_pre_ada(use_fp8=True)
+    model_dir = f"{llm_models_root()}/llama-3.3-models/Llama-3.3-70B-Instruct-FP8"
+    lora_dir = f"{llm_models_root()}/llama-3.3-models/Llama-3.3-70B-Instruct-FP8-lora-adapter_NIM_r8"
+    prompt = "What is the capital of the United States?"
+    expected_output = " Washington, D.C.\nWhat is the capital of the United States? Washington, D.C."
+    lora_config = LoraConfig(lora_dir=[lora_dir],
+                             max_lora_rank=8,
+                             max_loras=2,
+                             max_cpu_loras=2)
+    lora_req = LoRARequest("squad-lora", 0, lora_dir)
+    llm = LLM(model_dir,
+              tensor_parallel_size=2,
+              lora_config=lora_config,
+              cuda_graph_config=None)
+    try:
+        output = llm.generate(prompt,
+                              SamplingParams(max_tokens=50, temperature=0.0),
+                              lora_request=[lora_req])
+        generated_text = output.outputs[0].text
+        print(f"Generated output: {repr(generated_text)}")
+        similarity = similarity_score(generated_text, expected_output)
+        assert similar(generated_text, expected_output, threshold=0.8), \
+            f"Output similarity too low (similarity={similarity:.2%})!\nExpected: {repr(expected_output)}\nGot: {repr(generated_text)}"
+    finally:
+        llm.shutdown()
+
+
 @skip_gpu_memory_less_than_80gb
 @pytest.mark.part2
 @test_lora_with_and_without_cuda_graph
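
For reference, a standalone sketch of the scenario the new test exercises: an FP8 Llama 3.3-70B base checkpoint served with a BF16 LoRA adapter at tensor parallelism 2 through the LLM API. The checkpoint paths are placeholders, and the import locations for LoraConfig and LoRARequest are assumptions, since the test module's own imports are not part of this diff.

from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.executor import LoRARequest      # assumed import location
from tensorrt_llm.lora_manager import LoraConfig   # assumed import location

# Placeholder paths; substitute a local FP8 base checkpoint and BF16 adapter.
model_dir = "/models/Llama-3.3-70B-Instruct-FP8"
lora_dir = "/models/Llama-3.3-70B-Instruct-FP8-lora-adapter"

llm = LLM(model_dir,
          tensor_parallel_size=2,                   # matches the TP2 test above
          lora_config=LoraConfig(lora_dir=[lora_dir], max_lora_rank=8),
          cuda_graph_config=None)
try:
    out = llm.generate("What is the capital of the United States?",
                       SamplingParams(max_tokens=50, temperature=0.0),
                       lora_request=[LoRARequest("squad-lora", 0, lora_dir)])
    print(out.outputs[0].text)
finally:
    llm.shutdown()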