Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-14 06:27:45 +08:00)
[https://nvbugs/5521949][fix] Update FP8 model with BF16 LoRA test, fix test_bielik_11b_v2_2_instruct_multi_lora (#8324)

Signed-off-by: Amit Zuker <203509407+amitz-nv@users.noreply.github.com>

parent 4751bdbcb6
commit e5476a6b2a
@@ -523,15 +523,23 @@ class ModelConfig(Generic[TConfig]):
         # For kv cache size calculation: set size_per_head
         head_dim_names = ["head_size", "head_dim"]
         head_size = None
         for head_dim_name in head_dim_names:
-            if head_dim_name in self.pretrained_config:
-                head_size = getattr(self.pretrained_config, head_dim_name)
-                break
-        else:
-            logger.warning(
-                f"head_size/head_dim is not set, using default value {hidden_size // num_heads}"
-            )
+            if hasattr(self.pretrained_config, head_dim_name):
+                value = getattr(self.pretrained_config, head_dim_name)
+                if value is not None:
+                    head_size = value
+                    break
+
+        if head_size is None:
             assert hidden_size % num_heads == 0, (
                 f"hidden_size ({hidden_size}) must be divisible by num_heads ({num_heads})"
             )
-            head_size = hidden_size // num_heads
+            calculated_head_size = hidden_size // num_heads
+            logger.warning(
+                f"head_size/head_dim is not set or None, using default value {calculated_head_size}"
+            )
+            head_size = calculated_head_size

         model_config_cpp.mlp_hidden_size = mlp_hidden_size
         model_config_cpp.size_per_head = head_size
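Note on the model_config change above: the old lookup accepted head_size/head_dim even when the attribute was present but set to None, so head_size could end up as None; the new code only takes a non-None value and otherwise falls back to hidden_size // num_heads. The same logic, restated as a standalone sketch (resolve_head_size is an illustrative name, not a function in this commit):

    def resolve_head_size(pretrained_config, hidden_size: int, num_heads: int) -> int:
        # Prefer an explicit, non-None head_size/head_dim attribute.
        for name in ("head_size", "head_dim"):
            value = getattr(pretrained_config, name, None)
            if value is not None:
                return value
        # No usable attribute: fall back to the derived value, as the updated code does.
        assert hidden_size % num_heads == 0, (
            f"hidden_size ({hidden_size}) must be divisible by num_heads ({num_heads})")
        return hidden_size // num_heads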
@@ -25,14 +25,13 @@ from .test_llm import (_test_llm_capture_request_error, get_model_path,
                        prompts, run_llm_abort_request,
                        run_llm_with_postprocess_parallel_and_result_handler,
                        tinyllama_logits_processor_test_harness)
-from utils.util import (force_ampere, similar, skip_gpu_memory_less_than_40gb,
+from utils.util import (force_ampere, similar, skip_fp8_pre_ada,
+                        skip_gpu_memory_less_than_40gb,
                         skip_gpu_memory_less_than_80gb,
                         skip_gpu_memory_less_than_138gb)
 from utils.llm_data import llm_models_root
 from tensorrt_llm.lora_helper import LoraConfig
 from tensorrt_llm.executor.request import LoRARequest
-from tensorrt_llm.models.modeling_utils import QuantConfig
-from tensorrt_llm.quantization.mode import QuantAlgo
 import tempfile

 import torch
@@ -496,68 +495,36 @@ def test_nemotron_nas_lora() -> None:


 @skip_gpu_memory_less_than_80gb
-@pytest.mark.skip(reason="https://nvbugs/5521949")
-def test_codellama_fp8_with_bf16_lora() -> None:
-    model_dir = f"{llm_models_root()}/codellama/CodeLlama-7b-Instruct-hf/"
-    quant_config = QuantConfig(quant_algo=QuantAlgo.FP8,
-                               kv_cache_quant_algo=QuantAlgo.FP8)
+def test_llama_3_1_8b_fp8_with_bf16_lora() -> None:
+    skip_fp8_pre_ada(use_fp8=True)
+    model_dir = f"{llm_models_root()}/llama-3.1-model/Llama-3.1-8B-Instruct-FP8"
+    lora_dir = f"{llm_models_root()}/lora/llama-3-chinese-8b-instruct-v2-lora"
+    prompt = "美国的首都是哪里?"
+    reference = "华盛顿特区。华盛顿特区是美国的首都和一个行政区"

-    target_modules = ['attn_q', 'attn_k', 'attn_v']
+    lora_config = LoraConfig(lora_dir=[lora_dir],
+                             max_lora_rank=64,
+                             max_loras=2,
+                             max_cpu_loras=2)
+    lora_req = LoRARequest("lora-chinese", 0, lora_dir)

-    # Set up temporary directory for LoRA adapters
-    with tempfile.TemporaryDirectory() as lora_dir:
-        print("Creating dummy LoRAs...")
+    llm = LLM(
+        model_dir,
+        lora_config=lora_config,
+        # Disable CUDA graph
+        # TODO: remove this once we have a proper fix for CUDA graph in LoRA
+        cuda_graph_config=None)

-        model = AutoModelForCausalLM.from_pretrained(
-            model_dir,
-            torch_dtype=torch.bfloat16,
-            device_map="auto",
-            trust_remote_code=True,
-        )
-
-        hf_modules = ["q_proj", "k_proj", "v_proj"]
-
-        lora_config = PeftLoraConfig(r=8,
-                                     target_modules=hf_modules,
-                                     bias="none",
-                                     task_type="CAUSAL_LM")
-
-        lora_paths = []
-        for i in range(2):
-            lora_model = get_peft_model(model, lora_config)
-            for param in lora_model.parameters():
-                param.data.zero_()
-            lora_path = f"{lora_dir}/lora_{i}"
-            lora_model.save_pretrained(lora_path)
-            lora_paths.append(lora_path)
-
-        lora_config = LoraConfig(lora_dir=lora_paths,
-                                 lora_target_modules=target_modules,
-                                 max_lora_rank=8,
-                                 max_loras=2,
-                                 max_cpu_loras=2)
-
-        llm = LLM(model_dir, quant_config=quant_config, lora_config=lora_config)
-
-        prompts = [
-            "Write a function that calculates the Fibonacci sequence.",
-            "Convert this C++ code to Python: int x = 0; x++;",
-        ]
-
-        lora_req1 = LoRARequest("lora-1", 0, lora_paths[0])
-        lora_req2 = LoRARequest("lora-2", 1, lora_paths[1])
-        lora_requests = [lora_req1, lora_req2]
-        sampling_params = SamplingParams(max_tokens=200)
-
-        outputs = llm.generate(prompts,
-                               sampling_params,
-                               lora_request=lora_requests)
-
-        assert len(outputs) == 2
+    try:
+        output = llm.generate(prompt,
+                              SamplingParams(max_tokens=20),
+                              lora_request=[lora_req])
+    finally:
+        llm.shutdown()
+    assert similar(output.outputs[0].text, reference)


 @skip_gpu_memory_less_than_80gb
-@pytest.mark.skip(reason="https://nvbugs/5521949")
 def test_bielik_11b_v2_2_instruct_multi_lora() -> None:
     model_dir = f"{llm_models_root()}/Bielik-11B-v2.2-Instruct"
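The replacement test drives a BF16 LoRA adapter on top of an already FP8-quantized Llama 3.1 checkpoint through the LLM API, instead of quantizing CodeLlama on the fly with QuantConfig. Stripped of the test harness, the flow looks roughly like this (a sketch only; the model and adapter paths below are placeholders, not the repository's llm_models_root() layout):

    from tensorrt_llm import LLM, SamplingParams
    from tensorrt_llm.executor.request import LoRARequest
    from tensorrt_llm.lora_helper import LoraConfig

    lora_dir = "/models/lora/llama-3-chinese-8b-instruct-v2-lora"  # placeholder path
    lora_config = LoraConfig(lora_dir=[lora_dir],
                             max_lora_rank=64,
                             max_loras=2,
                             max_cpu_loras=2)
    llm = LLM("/models/Llama-3.1-8B-Instruct-FP8",  # placeholder path
              lora_config=lora_config,
              cuda_graph_config=None)  # CUDA graph disabled, as in the test
    try:
        output = llm.generate("美国的首都是哪里?",
                              SamplingParams(max_tokens=20),
                              lora_request=[LoRARequest("lora-chinese", 0, lora_dir)])
    finally:
        llm.shutdown()
    print(output.outputs[0].text)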
@@ -584,12 +551,16 @@ def test_bielik_11b_v2_2_instruct_multi_lora() -> None:
             lora_model.save_pretrained(lora_path)
             lora_paths.append(lora_path)

-        trtllm_lora_config = LoraConfig(lora_dir=lora_paths,
-                                        lora_target_modules=target_modules,
+        trtllm_lora_config = LoraConfig(lora_target_modules=target_modules,
                                         max_lora_rank=8,
                                         max_loras=2,
                                         max_cpu_loras=2)
-        llm = LLM(model_dir, lora_config=trtllm_lora_config)
+        llm = LLM(
+            model_dir,
+            lora_config=trtllm_lora_config,
+            # Disable CUDA graph
+            # TODO: remove this once we have a proper fix for CUDA graph in LoRA
+            cuda_graph_config=None)

         prompts = [
             "Kim był Mikołaj Kopernik i z czego zasłynął?",
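For context, the Bielik multi-LoRA test (like the removed CodeLlama test) first builds small dummy adapters with PEFT and zeroes their weights before wiring the saved paths into LoraConfig and per-prompt LoRARequest objects. A condensed sketch of that setup, with a placeholder model path and import spellings assumed rather than copied from the test file:

    import tempfile

    import torch
    from peft import LoraConfig as PeftLoraConfig, get_peft_model
    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained(
        "/models/Bielik-11B-v2.2-Instruct",  # placeholder path
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True,
    )
    peft_config = PeftLoraConfig(r=8,
                                 target_modules=["q_proj", "k_proj", "v_proj"],
                                 bias="none",
                                 task_type="CAUSAL_LM")

    with tempfile.TemporaryDirectory() as tmp_dir:
        lora_paths = []
        for i in range(2):
            lora_model = get_peft_model(model, peft_config)
            for param in lora_model.parameters():
                param.data.zero_()  # dummy adapter: weights zeroed, as in the test
            lora_path = f"{tmp_dir}/lora_{i}"
            lora_model.save_pretrained(lora_path)
            lora_paths.append(lora_path)
        # lora_paths are then fed to LoraConfig / LoRARequest while tmp_dir still exists.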