Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-02-04 18:21:52 +08:00)
[https://nvbugs/5691730][fix] Have LoRA bf16 ckpts work with Llama 3.3-70B-fp8 (#9808)
Signed-off-by: Michal Guzek <mguzek@nvidia.com>
Signed-off-by: Michal Guzek <moraxu@users.noreply.github.com>
Signed-off-by: Jin Li <59594262+liji-nv@users.noreply.github.com>
Co-authored-by: Jin Li <59594262+liji-nv@users.noreply.github.com>
Signed-off-by: Wangshanshan <30051912+dominicshanshan@users.noreply.github.com>
This commit is contained in:
parent bc2487bc2c
commit fafc22e3d4
@@ -778,21 +778,32 @@ class LlamaDecoderLayer(DecoderLayer):
             )
         # Fully Connected
         if self.PRE_MLP_FUSION:
+            has_lora = bool(kwargs.get('lora_params'))
+
             if self.is_nvfp4 or self.is_fp8_quant:
-                scale = self.mlp.gate_up_proj.input_scale
+                # WAR: Skip FP8/NVFP4 quantization when LoRA is active
+                # since LoRA grouped_gemm does not support FP8 yet
+                # see: cpp/tensorrt_llm/thop/loraOp.cpp::lora_grouped_gemm
+                if has_lora:
+                    scale = None  # To prevent quantization
+                    fusion_op = AllReduceFusionOp.RESIDUAL_RMS_NORM  # Use non-quantizing fusion
+                else:
+                    scale = self.mlp.gate_up_proj.input_scale
+                    fusion_op = self.pre_mlp_fusion_op
             else:
                 scale = None
+                fusion_op = self.pre_mlp_fusion_op

             all_reduce_output = self.all_reduce(
                 hidden_states,
                 all_reduce_params=AllReduceParams(
-                    fusion_op=self.pre_mlp_fusion_op,
+                    fusion_op=fusion_op,
                     residual=residual,
                     norm_weight=self.post_attention_layernorm.weight,
                     scale=scale,
                     eps=self.post_attention_layernorm.variance_epsilon,
                 ))
-            if self.is_nvfp4:
+            if fusion_op == AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_NVFP4:
                 act_fp4, act_sf, residual = all_reduce_output
                 hidden_states = Fp4QuantizedTensor(act_fp4, act_sf)
             else:
@@ -841,24 +852,30 @@ class LlamaDecoderLayer(DecoderLayer):
             # The next layernorm exists but it could be the last decoder layer.
             # Adjust the scale and fusion pattern.

+            has_lora = bool(kwargs.get('lora_params'))
+
+            # WAR: Skip FP8/NVFP4 quantization when LoRA is active
+            # since LoRA grouped_gemm does not support FP8 yet
             if not (self.next_attn is not None and (self.is_nvfp4
                                                     or self.is_fp8_quant)) \
-                    or not hasattr(self.next_attn.qkv_proj, 'input_scale'):
+                    or not hasattr(self.next_attn.qkv_proj, 'input_scale') \
+                    or has_lora:
                 scale = None
-                self.post_mlp_fusion_op = AllReduceFusionOp.RESIDUAL_RMS_NORM
+                post_fusion_op = AllReduceFusionOp.RESIDUAL_RMS_NORM
             else:
                 scale = self.next_attn.qkv_proj.input_scale
+                post_fusion_op = self.post_mlp_fusion_op

             all_reduce_output = self.all_reduce(
                 hidden_states,
                 all_reduce_params=AllReduceParams(
-                    fusion_op=self.post_mlp_fusion_op,
+                    fusion_op=post_fusion_op,
                     residual=residual,
                     norm_weight=self.next_layer_layernorm.weight,
                     scale=scale,
                     eps=self.next_layer_layernorm.variance_epsilon,
                 ))
-            if self.post_mlp_fusion_op == AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_NVFP4:
+            if post_fusion_op == AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_NVFP4:
                 act_fp4, act_sf, residual = all_reduce_output
                 hidden_states = Fp4QuantizedTensor(act_fp4, act_sf)
             else:
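Note: the two hunks above replace the member-level fusion op with a per-call choice, so that when LoRA params are present the all-reduce falls back to the plain RESIDUAL_RMS_NORM fusion and no FP8/NVFP4 scale is passed. Below is a minimal standalone sketch of that selection logic, not the real module code; the enum values mirror the diff, but pick_mlp_fusion, is_quantized, input_scale, and default_fusion_op are illustrative stand-ins.

# Standalone sketch of the fusion-op selection shown above; names here are
# illustrative stand-ins, not the actual TensorRT-LLM attributes.
from enum import Enum, auto
from typing import Optional, Tuple


class AllReduceFusionOp(Enum):
    RESIDUAL_RMS_NORM = auto()              # non-quantizing fusion
    RESIDUAL_RMS_NORM_QUANT_FP8 = auto()    # quantizes the normed output to FP8
    RESIDUAL_RMS_NORM_QUANT_NVFP4 = auto()  # quantizes the normed output to NVFP4


def pick_mlp_fusion(is_quantized: bool, has_lora: bool,
                    input_scale: Optional[float],
                    default_fusion_op: AllReduceFusionOp
                    ) -> Tuple[Optional[float], AllReduceFusionOp]:
    """Mirror of the per-call decision in LlamaDecoderLayer above.

    When LoRA is active, the quantized fusion is skipped (scale=None and
    plain RESIDUAL_RMS_NORM) because lora_grouped_gemm does not accept
    FP8/NVFP4 activations yet.
    """
    if is_quantized:
        if has_lora:
            return None, AllReduceFusionOp.RESIDUAL_RMS_NORM
        return input_scale, default_fusion_op
    return None, default_fusion_op


# Example: quantized layer with a LoRA request attached -> fall back to BF16 fusion.
scale, fusion_op = pick_mlp_fusion(
    is_quantized=True, has_lora=True, input_scale=0.05,
    default_fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_NVFP4)
assert scale is None and fusion_op is AllReduceFusionOp.RESIDUAL_RMS_NORM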
@@ -433,6 +433,7 @@ class Attention(nn.Module):
         output: Optional[torch.Tensor] = None,
         output_sf: Optional[torch.Tensor] = None,
         attention_sinks: Optional[torch.Tensor] = None,
+        has_lora: bool = False,
     ):
         num_tokens = attn_metadata.num_tokens

@@ -446,7 +447,8 @@ class Attention(nn.Module):
         out_scale_sf = None
         # Don't set out_scale if o_proj has pre_quant_scale - this prevents FP8/FP4 output
         # and keeps attention output in BF16 for better precision when applying pre_quant_scale
-        if self._use_quantize_output():
+        # Also don't set out_scale if LoRA is active - LoRA grouped_gemm doesn't support FP8
+        if self._use_quantize_output() and not has_lora:
             out_scale = self.o_proj.inv_input_scale
             out_scale_sf = self.o_proj.input_scale

@@ -499,6 +501,7 @@ class Attention(nn.Module):
         attention_mask_data: Optional[torch.Tensor],
         mrope_config: Optional[dict],
         attention_sinks: Optional[torch.Tensor] = None,
+        has_lora: bool = False,
     ):
         mrope_rotary_cos_sin = None
         mrope_position_deltas = None
@@ -544,7 +547,8 @@ class Attention(nn.Module):
             mrope_position_deltas,
             attention_window_size,
             attention_mask_data,
-            attention_sinks=attention_sinks)
+            attention_sinks=attention_sinks,
+            has_lora=has_lora)
         if output_sf is not None:
             output = Fp4QuantizedTensor(output, output_sf)

@@ -619,7 +623,8 @@ class Attention(nn.Module):
             attention_window_size,
             attention_mask_data,
             mrope_config=mrope_config,
-            attention_sinks=attention_sinks)
+            attention_sinks=attention_sinks,
+            has_lora=bool(lora_params))

         if self.attn_output_gate:
             gate = torch.sigmoid(gate)
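Note: the attention-side hunks above thread a has_lora flag from the public forward path down to the kernel call so that out_scale and out_scale_sf stay unset and the attention output remains BF16 whenever a LoRA request is attached. A small sketch of that gating follows; Proj, pick_out_scales, and the scale values are illustrative stand-ins for self.o_proj and the real method, not the module's API.

# Sketch of the out_scale gating added above; Proj stands in for self.o_proj.
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass
class Proj:
    input_scale: float       # scale used to quantize the next GEMM's input
    inv_input_scale: float   # its reciprocal, passed as the attention out_scale


def pick_out_scales(quantize_output: bool, has_lora: bool,
                    o_proj: Proj) -> Tuple[Optional[float], Optional[float]]:
    """Return (out_scale, out_scale_sf) as in the patched attention path.

    With LoRA active both scales stay None, so the attention kernel writes
    BF16 output that lora_grouped_gemm can consume.
    """
    out_scale: Optional[float] = None
    out_scale_sf: Optional[float] = None
    if quantize_output and not has_lora:
        out_scale = o_proj.inv_input_scale
        out_scale_sf = o_proj.input_scale
    return out_scale, out_scale_sf


proj = Proj(input_scale=0.1, inv_input_scale=10.0)
assert pick_out_scales(True, True, proj) == (None, None)    # LoRA: stay BF16
assert pick_out_scales(True, False, proj) == (10.0, 0.1)    # no LoRA: quantize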
@@ -30,8 +30,8 @@ from .test_llm import (_test_llm_capture_request_error, get_model_path,
                        prompts, run_llm_abort_request,
                        run_llm_with_postprocess_parallel_and_result_handler,
                        tinyllama_logits_processor_test_harness)
-from utils.util import (force_ampere, similar, skip_fp8_pre_ada,
-                        skip_gpu_memory_less_than_40gb,
+from utils.util import (force_ampere, similar, similarity_score,
+                        skip_fp8_pre_ada, skip_gpu_memory_less_than_40gb,
                         skip_gpu_memory_less_than_80gb,
                         skip_gpu_memory_less_than_138gb, skip_ray)
 from utils.llm_data import llm_models_root
@@ -629,6 +629,42 @@ def test_llama_3_1_8b_fp8_with_bf16_lora(cuda_graph_config) -> None:
     assert similar(output.outputs[0].text, reference)


+@skip_ray  # https://nvbugs/5682551
+@skip_gpu_memory_less_than_80gb
+def test_llama_3_3_70b_fp8_with_squad_lora_tp2() -> None:
+    skip_fp8_pre_ada(use_fp8=True)
+
+    model_dir = f"{llm_models_root()}/llama-3.3-models/Llama-3.3-70B-Instruct-FP8"
+    lora_dir = f"{llm_models_root()}/llama-3.3-models/Llama-3.3-70B-Instruct-FP8-lora-adapter_NIM_r8"
+
+    prompt = "What is the capital of the United States?"
+    expected_output = " Washington, D.C.\nWhat is the capital of the United States? Washington, D.C."
+
+    lora_config = LoraConfig(lora_dir=[lora_dir],
+                             max_lora_rank=8,
+                             max_loras=2,
+                             max_cpu_loras=2)
+    lora_req = LoRARequest("squad-lora", 0, lora_dir)
+
+    llm = LLM(model_dir,
+              tensor_parallel_size=2,
+              lora_config=lora_config,
+              cuda_graph_config=None)
+
+    try:
+        output = llm.generate(prompt,
+                              SamplingParams(max_tokens=50, temperature=0.0),
+                              lora_request=[lora_req])
+        generated_text = output.outputs[0].text
+        print(f"Generated output: {repr(generated_text)}")
+
+        similarity = similarity_score(generated_text, expected_output)
+        assert similar(generated_text, expected_output, threshold=0.8), \
+            f"Output similarity too low (similarity={similarity:.2%})!\nExpected: {repr(expected_output)}\nGot: {repr(generated_text)}"
+    finally:
+        llm.shutdown()
+
+
 @skip_gpu_memory_less_than_80gb
 @pytest.mark.part2
 @test_lora_with_and_without_cuda_graph
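Note: outside of the pytest harness, the same flow can be exercised directly with the LLM API. The sketch below mirrors the new test; the model and adapter paths are placeholders and the import locations for LoraConfig and LoRARequest are assumptions that may need adjusting to your TensorRT-LLM version.

# Standalone version of the new test's flow (sketch, not part of this commit).
from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.executor.request import LoRARequest  # assumed import location
from tensorrt_llm.lora_manager import LoraConfig        # assumed import location

model_dir = "/path/to/Llama-3.3-70B-Instruct-FP8"                # placeholder path
lora_dir = "/path/to/Llama-3.3-70B-Instruct-FP8-lora-adapter"    # placeholder path

# BF16 LoRA adapter on top of the FP8 base checkpoint, tensor parallel across 2 GPUs.
lora_config = LoraConfig(lora_dir=[lora_dir], max_lora_rank=8,
                         max_loras=2, max_cpu_loras=2)
lora_req = LoRARequest("squad-lora", 0, lora_dir)

llm = LLM(model_dir, tensor_parallel_size=2,
          lora_config=lora_config, cuda_graph_config=None)
try:
    output = llm.generate("What is the capital of the United States?",
                          SamplingParams(max_tokens=50, temperature=0.0),
                          lora_request=[lora_req])
    print(output.outputs[0].text)
finally:
    llm.shutdown()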