mirror of https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-13 22:18:36 +08:00

[TRTLLM-6496][feat] Add LoRa Torch tests for the latest NIM model list (#6806)

Signed-off-by: Michal Guzek <mguzek@nvidia.com>

This commit is contained in:
parent ca8291133a
commit 38da871db3
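
The commit drives multi-LoRA inference through the LLM-API Torch backend (`LLM`, imported as `LLM_torch` in the test helpers). As a quick orientation before the diff, here is a minimal sketch of that flow, condensed from the helper added below; the model and adapter paths are placeholders, and only APIs that appear in the diff (`LoraConfig`, `LoRARequest`, `SamplingParams`, `llm.generate`) are used. This is an illustrative sketch, not part of the commit.

# Minimal sketch of the LLM-API Torch multi-LoRA flow exercised by the new tests.
# Paths are placeholders; API usage mirrors test_llm_torch_multi_lora_support below.
from tensorrt_llm import LLM
from tensorrt_llm.executor.request import LoRARequest
from tensorrt_llm.lora_manager import LoraConfig
from tensorrt_llm.sampling_params import SamplingParams

lora_config = LoraConfig(lora_dir=["/path/to/lora_0", "/path/to/lora_1"],
                         max_lora_rank=8,
                         max_loras=2,
                         max_cpu_loras=2)

prompts = ["Hey how are you doing today?", "What is the capital of France?"]
# Alternate between a LoRA adapter and the base model, as the new helper does.
lora_requests = [LoRARequest("lora-0", 0, "/path/to/lora_0"), None]

with LLM(model="/path/to/hf_model",
         lora_config=lora_config,
         dtype="bfloat16") as llm:
    outputs = llm.generate(prompts,
                           sampling_params=SamplingParams(max_tokens=30),
                           lora_request=lora_requests)
    for output in outputs:
        print(output.outputs[0].text)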
@@ -525,15 +525,23 @@ class ModelConfig(Generic[TConfig]):
         # For kv cache size calculation: set size_per_head
         head_dim_names = ["head_size", "head_dim"]
+        head_size = None
         for head_dim_name in head_dim_names:
-            if head_dim_name in self.pretrained_config:
-                head_size = getattr(self.pretrained_config, head_dim_name)
-                break
-        else:
-            logger.warning(
-                f"head_size/head_dim is not set, using default value {hidden_size // num_heads}"
+            if hasattr(self.pretrained_config, head_dim_name):
+                value = getattr(self.pretrained_config, head_dim_name)
+                if value is not None:
+                    head_size = value
+                    break
+
+        if head_size is None:
+            assert hidden_size % num_heads == 0, (
+                f"hidden_size ({hidden_size}) must be divisible by num_heads ({num_heads})"
             )
-            head_size = hidden_size // num_heads
+            calculated_head_size = hidden_size // num_heads
+            logger.warning(
+                f"head_size/head_dim is not set or None, using default value {calculated_head_size}"
+            )
+            head_size = calculated_head_size
 
         model_config_cpp.mlp_hidden_size = mlp_hidden_size
         model_config_cpp.size_per_head = head_size
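
The hunk above makes the `head_size` lookup tolerate configs where `head_dim`/`head_size` exists but is set to `None`, falling back to `hidden_size // num_heads`. Below is a small, self-contained sketch of the same resolution logic; `cfg` is a hypothetical stand-in for `self.pretrained_config`, used only for illustration and not part of the commit.

# Illustrative sketch of the head_size resolution above; `cfg` is a stand-in
# for `self.pretrained_config`, not a real TensorRT-LLM object.
from types import SimpleNamespace


def resolve_head_size(cfg, hidden_size: int, num_heads: int) -> int:
    head_size = None
    for name in ("head_size", "head_dim"):
        value = getattr(cfg, name, None)
        if value is not None:  # attribute may exist but be None
            head_size = value
            break
    if head_size is None:
        assert hidden_size % num_heads == 0
        head_size = hidden_size // num_heads  # fallback used by the warning path
    return head_size


print(resolve_head_size(SimpleNamespace(head_dim=None), 4096, 32))  # -> 128 (fallback)
print(resolve_head_size(SimpleNamespace(head_dim=64), 4096, 32))    # -> 64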
@@ -22,6 +22,11 @@ from pathlib import Path
 
 from packaging import version
 
+from tensorrt_llm import LLM as LLM_torch
+from tensorrt_llm.executor.request import LoRARequest
+from tensorrt_llm.lora_manager import LoraConfig
+from tensorrt_llm.sampling_params import SamplingParams
+
 from .trt_test_alternative import check_call, check_output, exists, is_windows
 
 
@@ -739,12 +744,28 @@ def generate_dummy_loras(
     from transformers import AutoModelForCausalLM
 
     print("Creating pseudo LoRAs...")
-    model = AutoModelForCausalLM.from_pretrained(
-        hf_model_dir,
-        torch_dtype=torch.float16,
-        device_map="auto",
-        trust_remote_code=True,
-    )
+
+    # Avoid meta tensors by loading model to CPU first (ensures all parameters are materialized)
+    try:
+        model = AutoModelForCausalLM.from_pretrained(
+            hf_model_dir,
+            torch_dtype=torch.float16,
+            device_map=None,  # Load everything to CPU first
+            trust_remote_code=True,
+            low_cpu_mem_usage=False,
+        )
+    except Exception:
+        # Fallback to auto device mapping if CPU loading fails
+        print(
+            "Warning: Loading model to CPU failed, falling back to auto device mapping"
+        )
+        model = AutoModelForCausalLM.from_pretrained(
+            hf_model_dir,
+            torch_dtype=torch.float16,
+            device_map="auto",
+            trust_remote_code=True,
+        )
+
     lora_config = LoraConfig(r=lora_rank,
                              target_modules=target_modules,
                              bias="none",
@@ -755,12 +776,57 @@ def generate_dummy_loras(
         if zero_weights:
             for param in lora_model.parameters():
                 param.data.zero_()
 
         pseudo_lora_dir = f"{lora_output_dir}/pseudo_lora_{lora_idx}"
         lora_model.save_pretrained(pseudo_lora_dir)
         lora_output_paths.append(pseudo_lora_dir)
     return lora_output_paths
 
 
+def get_test_prompts(use_code_prompts: bool = False) -> list[str]:
+    """Get test prompts for LoRA testing.
+
+    Args:
+        use_code_prompts: If True, return code-related prompts. If False, return general prompts.
+
+    Returns:
+        List of test prompts.
+    """
+    if use_code_prompts:
+        return [
+            "Write a function that outputs the fibonacci sequence.",
+            "Convert the following C++ code to Python: x = 0;x++;",
+            "Find the largest prime factor of 42.",
+            "write a unit test for this function: $(cat fib.py)",
+            "# A simple python function to remove whitespace from a string:",
+            "How to load CodeLlama from HuggingFace?",
+        ]
+    else:
+        return [
+            "Hey how are you doing today?",
+            "How is the weather in Seattle, WA?",
+            "Is it ok to fill diesel in a petrol car?",
+            "Can you check the top 5 trending songs on spotify?",
+            "What is the capital of France?",
+            "How to load CodeLlama from HuggingFace?",
+        ]
+
+
+def get_test_prompts_for_torch() -> list[str]:
+    """Get test prompts for LoRA Torch testing.
+
+    Returns:
+        List of test prompts.
+    """
+    return [
+        "Hey how are you doing today?",
+        "How is the weather in Seattle, WA?",
+        "Is it ok to fill diesel in a petrol car?",
+        "Can you check the top 5 trending songs on spotify?",
+        "What is the capital of France?",
+    ]
+
+
 def test_multi_lora_support(
         hf_model_dir,
         tllm_ckpt_dir,
@@ -815,24 +881,7 @@ def test_multi_lora_support(
     print(
         f"Build engines completed in {(build_end - build_start):.2f} seconds.")
 
-    if use_code_prompts:
-        input_prompts = [
-            "Write a function that outputs the fibonacci sequence.",
-            "Convert the following C++ code to Python: x = 0;x++;",
-            "Find the largest prime factor of 42.",
-            "write a unit test for this function: $(cat fib.py)",
-            "# A simple python function to remove whitespace from a string:",
-            "How to load CodeLlama from HuggingFace?",
-        ]
-    else:
-        input_prompts = [
-            "Hey how are you doing today?",
-            "How is the weather in Seattle, WA?",
-            "Is it ok to fill diesel in a petrol car?",
-            "Can you check the top 5 trending songs on spotify?",
-            "What is the capital of France?",
-            "How to load CodeLlama from HuggingFace?",
-        ]
+    input_prompts = get_test_prompts(use_code_prompts)
 
     print("Run inference with C++ runtime with pybind...")
     inference_start = time.time()
@@ -867,6 +916,116 @@ def test_multi_lora_support(
     )
 
 
+def test_llm_torch_multi_lora_support(
+        hf_model_dir,
+        llm_venv,
+        num_loras=2,
+        lora_rank=8,
+        target_hf_modules=["q_proj", "k_proj", "v_proj"],
+        target_trtllm_modules=["attn_q", "attn_k", "attn_v"],
+        zero_lora_weights=True,
+        tensor_parallel_size=1,
+        pipeline_parallel_size=1,
+        expected_outputs=None):
+    """Test multi-LoRA support with LLM-API Torch backend."""
+
+    # if expected_outputs is None:
+    #     raise ValueError("expected_outputs must be provided for exact validation")
+
+    start_time = time.time()
+    print("Creating dummy LoRAs...")
+    lora_start = time.time()
+
+    lora_paths = generate_dummy_loras(
+        hf_model_dir=hf_model_dir,
+        lora_output_dir=llm_venv.get_working_directory(),
+        num_loras=num_loras,
+        lora_rank=lora_rank,
+        target_modules=target_hf_modules,
+        zero_weights=zero_lora_weights)
+    lora_end = time.time()
+    print(
+        f"Creating dummy LoRAs completed in {(lora_end - lora_start):.2f} seconds."
+    )
+
+    print("Initializing LLM_torch with LoRA support...")
+    init_start = time.time()
+
+    lora_config = LoraConfig(lora_dir=lora_paths,
+                             max_lora_rank=lora_rank,
+                             max_loras=num_loras,
+                             max_cpu_loras=num_loras,
+                             lora_target_modules=target_trtllm_modules)
+
+    input_prompts = get_test_prompts_for_torch()
+
+    with LLM_torch(
+            model=hf_model_dir,
+            lora_config=lora_config,
+            tensor_parallel_size=tensor_parallel_size,
+            pipeline_parallel_size=pipeline_parallel_size,
+            dtype="bfloat16",
+            max_batch_size=8,  # From original test
+            max_input_len=512,  # From original test
+            max_seq_len=562,  # From original test
+            max_beam_width=1  # From original test
+    ) as llm:
+
+        init_end = time.time()
+        print(
+            f"LLM_torch initialization completed in {(init_end - init_start):.2f} seconds."
+        )
+
+        print("Running inference with LLM-API Torch backend...")
+        inference_start = time.time()
+
+        # Create LoRA requests for different adapters
+        lora_requests = []
+        for i in range(len(input_prompts)):
+            if i % 2 == 1:  # Add some requests without LoRA
+                lora_requests.append(None)
+            else:  # With LoRA
+                lora_requests.append(
+                    LoRARequest(f"lora-{i}", i,
+                                lora_paths[i % len(lora_paths)]))
+
+        sampling_params = SamplingParams(max_tokens=30,
+                                         top_p=0.5,
+                                         top_k=0,
+                                         temperature=0.0)
+
+        outputs = llm.generate(input_prompts,
+                               sampling_params=sampling_params,
+                               lora_request=lora_requests)
+
+        inference_end = time.time()
+        print(
+            f"Inference completed in {(inference_end - inference_start):.2f} seconds."
+        )
+
+        # Validate exact outputs
+        print("Validating exact outputs...")
+        assert len(outputs) == len(expected_outputs), \
+            f"Expected {len(expected_outputs)} outputs, got {len(outputs)}"
+
+        for i, (output, expected) in enumerate(zip(outputs, expected_outputs)):
+            actual_text = output.outputs[0].text
+            print(f"Prompt {i+1}: {input_prompts[i]}")
+            print(
+                f"LoRA: {lora_requests[i].lora_int_id if lora_requests[i] else 'None'}"
+            )
+            print(f"Expected: {expected}")
+            print(f"Actual: {actual_text}")
+            print("-" * 50)
+
+            # Exact string comparison
+            assert actual_text == expected, \
+                f"Output {i+1} mismatch:\nExpected: {expected!r}\nActual: {actual_text!r}"
+
+    total_time = time.time() - start_time
+    print(f"Total test execution time: {total_time:.2f} seconds")
+
+
 def get_dummy_spec_decoding_heads(hf_model_dir,
                                   save_dir,
                                   mode='medusa',
@@ -1015,6 +1015,9 @@ def llama_model_root(request):
     elif request.param == "llama-3.1-8b-instruct-hf-fp8":
         llama_model_root = os.path.join(models_root, "llama-3.1-model",
                                         "Llama-3.1-8B-Instruct-FP8")
+    elif request.param == "llama-3.1-8b-instruct":
+        llama_model_root = os.path.join(models_root, "llama-3.1-model",
+                                        "Llama-3.1-8B-Instruct")
     elif request.param == "llama-3.1-8b-hf-nvfp4":
         llama_model_root = os.path.join(models_root, "nvfp4-quantized",
                                         "Meta-Llama-3.1-8B")
@@ -1024,9 +1027,18 @@ def llama_model_root(request):
     elif request.param == "llama-3.2-1b":
         llama_model_root = os.path.join(models_root, "llama-3.2-models",
                                         "Llama-3.2-1B")
+    elif request.param == "llama-3.2-1b-instruct":
+        llama_model_root = os.path.join(models_root, "llama-3.2-models",
+                                        "Llama-3.2-1B-Instruct")
     elif request.param == "llama-3.2-3b":
         llama_model_root = os.path.join(models_root, "llama-3.2-models",
                                         "Llama-3.2-3B")
+    elif request.param == "llama-3.2-3b-instruct":
+        llama_model_root = os.path.join(models_root, "llama-3.2-models",
+                                        "Llama-3.2-3B-Instruct")
+    elif request.param == "llama-3.3-70b-instruct":
+        llama_model_root = os.path.join(models_root, "llama-3.3-models",
+                                        "Llama-3.3-70B-Instruct")
     assert os.path.exists(
         llama_model_root
     ), f"{llama_model_root} does not exist under NFS LLM_MODELS_ROOT dir"
@@ -1323,6 +1335,11 @@ def llm_lora_model_root(request):
         elif item == "komt-mistral-7b-v1-lora":
             model_root_list.append(
                 os.path.join(models_root, "komt-mistral-7b-v1-lora"))
+        elif item == "Llama-3_3-Nemotron-Super-49B-v1-lora-adapter_NIM_r32":
+            model_root_list.append(
+                os.path.join(
+                    models_root, "nemotron-nas",
+                    "Llama-3_3-Nemotron-Super-49B-v1-lora-adapter_NIM_r32"))
 
     return ",".join(model_root_list)
 
@@ -1363,6 +1380,8 @@ def llm_mistral_model_root(request):
     model_root = os.path.join(models_root, "mistral-7b-v0.1")
     if request.param == "mistral-7b-v0.1":
         model_root = os.path.join(models_root, "mistral-7b-v0.1")
+    if request.param == "mistral-nemo-instruct-2407":
+        model_root = os.path.join(models_root, "Mistral-Nemo-Instruct-2407")
     if request.param == "komt-mistral-7b-v1":
         model_root = os.path.join(models_root, "komt-mistral-7b-v1")
     if request.param == "mistral-7b-v0.3":
@@ -25,9 +25,10 @@ import defs.ci_profiler
 import pytest
 from defs.common import (convert_weights, generate_summary_cmd,
                          get_cpp_benchmark, get_trt_llm_lib_dir, parse_output,
-                         quantize_data, similar, test_multi_lora_support,
-                         venv_check_call, venv_check_output,
-                         venv_mpi_check_call)
+                         quantize_data, similar,
+                         test_llm_torch_multi_lora_support,
+                         test_multi_lora_support, venv_check_call,
+                         venv_check_output, venv_mpi_check_call)
 # yapf: disable
 from defs.conftest import (get_device_count, get_device_memory,
                            get_host_total_memory, get_sm_version,
@@ -4039,6 +4040,87 @@ def test_llama_3_x_fp8_with_bf16_lora(llama_example_root, llm_datasets_root,
     )
 
 
+@skip_pre_ada
+@pytest.mark.skip_less_device_memory(80000)
+@pytest.mark.parametrize("llama_model_root", [
+    'llama-v3-8b-instruct-hf',
+    'llama-3.1-8b-instruct',
+    'llama-3.2-1b-instruct',
+    'llama-3.2-3b-instruct',
+    'llama-3.3-70b-instruct',
+],
+                         indirect=True)
+def test_llama_3_x_with_bf16_lora_torch(llama_example_root, llm_datasets_root,
+                                        qcache_dir_without_install_package,
+                                        llm_venv, engine_dir, llama_model_root):
+    """Run Llama models with multiple dummy LoRAs using LLM-API Torch backend."""
+
+    if "llama-3.3-70b-instruct" in llama_model_root.lower():
+        tensor_parallel_size = 8
+        if get_device_count() < 8:
+            pytest.skip(
+                "Skipping: llama-3.3-70b-instruct model requires 8 GPUs")
+    else:
+        tensor_parallel_size = 1
+
+    expected_outputs = {
+        'llama-v3-8b-instruct-hf': [
+            " I hope you're having a great day! I just wanted to reach out and say hi, and see if you're doing okay. I know things",
+            " Seattle, Washington is known for its mild and wet climate, with over 200 days of precipitation per year. The city experiences a significant amount of rainfall",
+            " No, it is not recommended to fill diesel in a petrol car. Diesel and petrol are two different types of fuel, and using the wrong type of",
+            " I'm curious to know what's currently popular.\nI can help you with that! As of now, the top 5 trending songs on Spotify are",
+            " Paris\nWhat is the capital of Germany? Berlin\nWhat is the capital of Italy? Rome\nWhat is the capital of Spain? Madrid\nWhat"
+        ],
+        'llama-3.1-8b-instruct': [
+            " I'm doing pretty well, thanks for asking. I just got back from a great vacation in Hawaii and I'm still feeling pretty relaxed. I'm",
+            " Seattle, Washington is known for its rainy and overcast weather, but the city's climate is actually quite mild and temperate. The city experiences a",
+            " | What happens if you put diesel in a petrol car?\nFilling a petrol car with diesel is a common mistake that can cause serious damage to the",
+            " I need to know what's hot right now.\nI can check the top 5 trending songs on Spotify for you. However, please note that the",
+            " Paris\nWhat is the capital of France?\nThe capital of France is Paris. Paris is the largest city in France and is known for its iconic landmarks"
+        ],
+        'llama-3.2-1b-instruct': [
+            " I'm doing great, thanks for asking! I just got back from a fantastic weekend getaway to the beach, and I'm feeling refreshed and rejuvenated",
+            " Right now?\nI'm planning a trip to Seattle and I want to know what the weather is like. I'm looking for a general idea of what",
+            " Filling a diesel car with petrol is not recommended, and it can cause serious damage to the engine. Diesel and petrol are two different types of fuel",
+            " based on the last 24 hours?\nI can provide you with the top 5 trending songs on Spotify based on the last 24 hours, but",
+            " Paris.\nThe capital of France is Paris. Paris is the most populous city in France and is known for its rich history, art, fashion, and"
+        ],
+        'llama-3.2-3b-instruct': [
+            " I'm doing alright, just got back from a long hike and I'm feeling pretty exhausted. Nothing like a good hike to clear the mind and get",
+            " (Current Weather)\nI'm happy to help you with the current weather in Seattle, WA! However, I'm a large language model, I don",
+            " and what are the types of fuel that can be used in a diesel engine?\nDiesel engines are designed to run on diesel fuel, which is a",
+            " and provide the 5 most popular artists on Spotify?\nAccording to Spotify's current charts, here are the top 5 trending songs and the 5",
+            " Paris\nWhat is the capital of France?\nThe capital of France is indeed Paris. Located in the north-central part of the country, Paris is a"
+        ],
+        'llama-3.3-70b-instruct': [
+            " I hope you are having a great day. I am doing well, thanks for asking. I was just thinking about how much I love the fall season",
+            " Is it always rainy?\nSeattle, WA is known for its overcast and rainy weather, but it's not always rainy. The city experiences a mild",
+            " No, it is not recommended to fill diesel in a petrol car. Diesel fuel is not designed to be used in petrol engines, and using it can",
+            " I want to know what's popular right now.\nAs of my knowledge cutoff, I don't have real-time access to current Spotify trends. However,",
+            " Paris\nWhat is the capital of Germany? Berlin\nWhat is the capital of Italy? Rome\nWhat is the capital of Spain? Madrid\nWhat"
+        ],
+    }
+
+    print("Testing with LLM-API Torch backend...")
+
+    defs.ci_profiler.start("test_llm_torch_multi_lora_support")
+
+    model_name = os.path.basename(llama_model_root).lower()
+    test_llm_torch_multi_lora_support(
+        hf_model_dir=llama_model_root,
+        llm_venv=llm_venv,
+        num_loras=2,
+        lora_rank=8,
+        target_hf_modules=["q_proj", "k_proj", "v_proj"],
+        zero_lora_weights=True,
+        tensor_parallel_size=tensor_parallel_size,
+        expected_outputs=expected_outputs[model_name])
+    defs.ci_profiler.stop("test_llm_torch_multi_lora_support")
+    print(
+        f"test_llm_torch_multi_lora_support: {defs.ci_profiler.elapsed_time_in_sec('test_llm_torch_multi_lora_support')} sec"
+    )
+
+
 @skip_pre_ada
 @pytest.mark.skip_less_device_memory(80000)
 @pytest.mark.parametrize("mistral_nemo_model_root", ['Mistral-Nemo-12b-Base'],
@@ -14,13 +14,16 @@
 # limitations under the License.
 """Module test_mistral test mistral examples."""
 import multiprocessing
-import platform
+import os
 
+import defs.ci_profiler
 import psutil
 import pytest
 from defs.common import (convert_weights, quantize_data,
+                         test_llm_torch_multi_lora_support,
                          test_multi_lora_support, venv_check_call)
-from defs.conftest import get_sm_version, skip_post_blackwell, skip_pre_ada
+from defs.conftest import (get_device_count, get_sm_version,
+                           skip_post_blackwell, skip_pre_ada)
 from defs.trt_test_alternative import check_call
 
 # skip trt flow cases on post-Blackwell-Ultra
@@ -45,25 +48,6 @@ def get_optimal_jobs():
     return optimal_jobs
 
 
-@pytest.fixture(autouse=True, scope="module")
-def mistral_example_root(llm_venv):
-    if platform.system() != "Windows":
-        # https://github.com/Dao-AILab/flash-attention/issues/345
-        # No wheel for flash-attn on windows and compilation fails locally.
-        max_jobs = get_optimal_jobs()
-        install_cmd = [
-            f"MAX_JOBS={max_jobs}",
-            "python3",
-            "-m",
-            "pip",
-            "install",
-            "--upgrade",
-            "flash-attn==2.4.2",
-        ]
-
-        check_call(" ".join(install_cmd), shell=True, env=llm_venv._new_env)
-
-
 @skip_post_blackwell #nvbug 5298661
 @pytest.mark.parametrize(
     "run_type",
@@ -295,3 +279,60 @@ def test_mistral_nemo_minitron_fp8_with_bf16_lora(
         target_trtllm_modules=["attn_q", "attn_k", "attn_v"],
         zero_lora_weights=True,
     )
+
+
+@skip_pre_ada
+@pytest.mark.skip_less_device_memory(80000)
+@pytest.mark.parametrize("llm_mistral_model_root", [
+    'mistral-7b-v0.1',
+    'mistral-nemo-instruct-2407',
+],
+                         indirect=True)
+def test_mistral_with_bf16_lora_torch(llama_example_root, llm_datasets_root,
+                                      qcache_dir_without_install_package,
+                                      llm_venv, engine_dir,
+                                      llm_mistral_model_root):
+    """Run Mistral models with multiple dummy LoRAs using LLM-API Torch backend."""
+
+    if "mistral-nemo-instruct-2407" in llm_mistral_model_root.lower():
+        tensor_parallel_size = 2
+        if get_device_count() < 2:
+            pytest.skip(
+                "Skipping: mistral-nemo-instruct-2407 model requires 2 GPUs")
+    else:
+        tensor_parallel_size = 1
+
+    expected_outputs = {
+        'mistral-7b-v0.1': [
+            "I hope you’re doing well. I’m doing well. I’m doing well. I’m doing well. I’m doing",
+            "\n\nSeattle, WA Weather Forecast. Today's weather in Seattle, WA. 59°F. 15°",
+            "\n\nNo, it is not ok to fill diesel in a petrol car. Diesel is a heavier fuel than petrol and will",
+            "\n\nYes, you can check the top 5 trending songs on Spotify. To do this, go to the Spotify website and sign",
+            "\n\nParis is the capital of France.\n\nWhat is the capital of the United States?\n\nWashington, D.C."
+        ],
+        'mistral-nemo-instruct-2407': [
+            " I'm doing fine, thanks for asking! How can I assist you today? Let me know if you have any questions or just want to chat!",
+            " Seattle, WA is currently experiencing a temperature of 55°F (13°C) with a chance of rain. The weather is typically cloud",
+            " I have a 2005 Honda City. I have filled diesel in my car by mistake. I have driven the car for about 1",
+            " I'm using python and I've tried using the spotipy library but I can't seem to get it to work. I'm not sure if it",
+            " Paris\n\nThe capital of France is Paris. It is the largest city in the country and is known for its iconic landmarks such as the Eiffel"
+        ],
+    }
+
+    print(f"Testing {llm_mistral_model_root} with LLM-API Torch backend...")
+
+    defs.ci_profiler.start("test_llm_torch_multi_lora_support")
+    model_name = os.path.basename(llm_mistral_model_root).lower()
+    test_llm_torch_multi_lora_support(
+        hf_model_dir=llm_mistral_model_root,
+        llm_venv=llm_venv,
+        num_loras=2,
+        lora_rank=8,
+        target_hf_modules=["q_proj", "k_proj", "v_proj"],
+        zero_lora_weights=True,
+        tensor_parallel_size=tensor_parallel_size,
+        expected_outputs=expected_outputs[model_name])
+    defs.ci_profiler.stop("test_llm_torch_multi_lora_support")
+    print(
+        f"test_llm_torch_multi_lora_support: {defs.ci_profiler.elapsed_time_in_sec('test_llm_torch_multi_lora_support')} sec"
+    )
@@ -1,10 +1,18 @@
+import os
 from pathlib import Path
 
+import defs.ci_profiler
 import pytest
-from defs.common import convert_weights, venv_check_call, venv_mpi_check_call
+from defs.common import (convert_weights, test_llm_torch_multi_lora_support,
+                         venv_check_call, venv_mpi_check_call)
 from defs.conftest import get_device_memory, get_sm_version
 from defs.trt_test_alternative import check_call
 
+from tensorrt_llm import LLM
+from tensorrt_llm.executor.request import LoRARequest
+from tensorrt_llm.lora_manager import LoraConfig
+from tensorrt_llm.sampling_params import SamplingParams
+
 # skip trt flow cases on post-Blackwell-Ultra
 if get_sm_version() >= 103:
     pytest.skip(
@@ -122,3 +130,155 @@ def test_nemotron_nas_summary_2gpu(nemotron_nas_example_root, llm_venv,
     ]
 
     venv_mpi_check_call(llm_venv, mpi_cmd, summary_cmd)
+
+
+@pytest.mark.skip_less_device_memory(80000)
+@pytest.mark.parametrize("nemotron_nas_model_root", [
+    "Llama-3.1-Nemotron-Nano-8B-v1",
+],
+                         indirect=True)
+def test_nemotron_nano_8b_lora_torch(nemotron_nas_example_root, llm_venv,
+                                     nemotron_nas_model_root, llm_datasets_root,
+                                     llm_rouge_root, engine_dir, cmodel_dir):
+    """Run Nemotron Nano 8B with multiple dummy LoRAs using LLM-API Torch backend."""
+
+    expected_outputs = {
+        'llama-3.1-nemotron-nano-8b-v1': [
+            " I am having a bit of a problem with my computer. The screen is black, but my monitor is still giving me the same signals. The brightness",
+            " How is the climate like? What are some of the typical foods and drinks of the region? What is the economy like? How does the city compare",
+            " I have heard that it's possible but can be dangerous. What are the potential risks? Are there any safety guidelines? I should probably check some references",
+            " I can't do that right now. But I can suggest that if you're interested in music trends, you can check Spotify's \"Discover Weekly\"",
+            " The capital of France is Paris. But wait, I think there's another city called Paris. No, no, that's the same city. Maybe"
+        ],
+    }
+
+    print("Testing with LLM-API Torch backend...")
+
+    defs.ci_profiler.start("test_llm_torch_multi_lora_support")
+
+    model_name = os.path.basename(nemotron_nas_model_root).lower()
+    test_llm_torch_multi_lora_support(
+        hf_model_dir=nemotron_nas_model_root,
+        llm_venv=llm_venv,
+        num_loras=2,
+        lora_rank=8,
+        target_hf_modules=["q_proj", "k_proj", "v_proj"],
+        zero_lora_weights=True,
+        tensor_parallel_size=1,
+        expected_outputs=expected_outputs[model_name])
+    defs.ci_profiler.stop("test_llm_torch_multi_lora_support")
+    print(
+        f"test_llm_torch_multi_lora_support: {defs.ci_profiler.elapsed_time_in_sec('test_llm_torch_multi_lora_support')} sec"
+    )
+
+
+@pytest.mark.skip_less_device(4)
+@pytest.mark.skip_less_device_memory(80000)
+@pytest.mark.parametrize("nemotron_nas_model_root", [
+    "Llama-3_3-Nemotron-Super-49B-v1",
+],
+                         indirect=True)
+@pytest.mark.parametrize(
+    "llm_lora_model_root",
+    ['Llama-3_3-Nemotron-Super-49B-v1-lora-adapter_NIM_r32'],
+    indirect=True)
+def test_nemotron_super_49b_real_lora_torch(nemotron_nas_example_root, llm_venv,
+                                            nemotron_nas_model_root,
+                                            llm_lora_model_root,
+                                            llm_datasets_root, llm_rouge_root,
+                                            engine_dir, cmodel_dir):
+    """Run Nemotron Super 49B with real LoRA adapters using LLM-API Torch backend."""
+
+    print("Testing Nemotron Super 49B with real LoRA adapters...")
+
+    print(f"Using real LoRA from: {llm_lora_model_root}")
+
+    defs.ci_profiler.start("test_nemotron_real_lora_torch")
+
+    lora_config = LoraConfig(
+        lora_dir=[llm_lora_model_root],
+        max_lora_rank=32,  # From adapter_config.json: "r": 32
+        max_loras=1,
+        max_cpu_loras=1,
+    )
+
+    with LLM(model=nemotron_nas_model_root,
+             lora_config=lora_config,
+             tensor_parallel_size=4,
+             dtype="bfloat16",
+             max_batch_size=2,
+             max_input_len=512,
+             max_seq_len=1024,
+             max_beam_width=1,
+             load_format="dummy") as llm:
+
+        prompts = [
+            "What is the capital of France?",
+            "Explain quantum computing in simple terms."
+        ]
+
+        sampling_params = SamplingParams(max_tokens=50,
+                                         temperature=0.7,
+                                         top_p=0.9)
+
+        lora_request = [
+            LoRARequest("nemotron-lora", 0, llm_lora_model_root),
+            LoRARequest("nemotron-lora", 1, llm_lora_model_root)
+        ]
+
+        print("Running inference with real LoRA adapter...")
+        outputs = llm.generate(prompts,
+                               sampling_params,
+                               lora_request=lora_request)
+
+        for i, output in enumerate(outputs):
+            print(f"Prompt {i+1}: {prompts[i]}")
+            print(f"Response {i+1}: {output.outputs[0].text}")
+            print("-" * 50)
+
+        assert len(outputs) == 2
+        assert len(outputs[0].outputs) > 0
+        assert len(outputs[1].outputs) > 0
+        assert len(outputs[0].outputs[0].text) > 0
+        assert len(outputs[1].outputs[0].text) > 0
+
+    defs.ci_profiler.stop("test_nemotron_real_lora_torch")
+    print(
+        f"test_nemotron_real_lora_torch: {defs.ci_profiler.elapsed_time_in_sec('test_nemotron_real_lora_torch')} sec"
+    )
+
+
+@pytest.mark.skip(reason="TODO: Test OOMs on 8 GPUs - to fix")
+@pytest.mark.skip_less_device(8)
+@pytest.mark.skip_less_device_memory(80000)
+@pytest.mark.parametrize("nemotron_nas_model_root", [
+    "Llama-3_1-Nemotron-Ultra-253B-v1",
+],
+                         indirect=True)
+def test_nemotron_ultra_253b_lora_torch(nemotron_nas_example_root, llm_venv,
+                                        nemotron_nas_model_root,
+                                        llm_datasets_root, llm_rouge_root,
+                                        engine_dir, cmodel_dir):
+    """Run Nemotron Ultra 253B with multiple dummy LoRAs using LLM-API Torch backend."""
+
+    expected_outputs = {
+        'Llama-3_1-Nemotron-Ultra-253B-v1': ["...", "...", "...", "...", "..."],
+    }
+
+    print("Testing with LLM-API Torch backend...")
+
+    defs.ci_profiler.start("test_llm_torch_multi_lora_support")
+    model_name = os.path.basename(nemotron_nas_model_root).lower()
+    test_llm_torch_multi_lora_support(
+        hf_model_dir=nemotron_nas_model_root,
+        llm_venv=llm_venv,
+        num_loras=2,
+        lora_rank=8,
+        target_hf_modules=["q_proj", "k_proj", "v_proj"],
+        zero_lora_weights=True,
+        tensor_parallel_size=8,
+        expected_outputs=expected_outputs[model_name])
+    defs.ci_profiler.stop("test_llm_torch_multi_lora_support")
+    print(
+        f"test_llm_torch_multi_lora_support: {defs.ci_profiler.elapsed_time_in_sec('test_llm_torch_multi_lora_support')} sec"
+    )
@@ -15,8 +15,10 @@
 import csv
 import os
 
+import defs.ci_profiler
 import pytest
 from defs.common import (convert_weights, quantize_data,
+                         test_llm_torch_multi_lora_support,
                          test_multi_lora_support, venv_check_call,
                          venv_mpi_check_call)
 from defs.conftest import (get_device_memory, get_sm_version, skip_fp8_pre_ada,
@@ -446,3 +448,35 @@ def test_phi_fp8_with_bf16_lora(llm_phi_model_root,
         target_trtllm_modules=trtllm_target_modules[model_name],
         zero_lora_weights=True,
     )
+
+
+@pytest.mark.skip(
+    reason="TODO: Resolve an import issue with transformers's LossKwargs")
+@skip_pre_ada
+@pytest.mark.skip_less_device_memory(80000)
+@pytest.mark.parametrize("llm_phi_model_root", ['Phi-4-mini-instruct'],
+                         indirect=True)
+def test_phi_4_mini_instruct_with_bf16_lora_torch(
+        phi_example_root, llm_datasets_root, qcache_dir_without_install_package,
+        llm_venv, engine_dir, llm_phi_model_root):
+    """Run Phi-4-mini-instruct with multiple dummy LoRAs using LLM-API Torch backend."""
+
+    expected_outputs = {
+        'Phi-4-mini-instruct': ["...", "...", "...", "...", "..."],
+    }
+
+    print("Testing with LLM-API Torch backend...")
+
+    defs.ci_profiler.start("test_llm_torch_multi_lora_support")
+    model_name = os.path.basename(llm_phi_model_root).lower()
+    test_llm_torch_multi_lora_support(
+        hf_model_dir=llm_phi_model_root,
+        llm_venv=llm_venv,
+        num_loras=2,
+        lora_rank=8,
+        target_hf_modules=["qkv_proj"],
+        target_trtllm_modules=["attn_qkv"],
+        zero_lora_weights=True,
+        tensor_parallel_size=1,
+        expected_outputs=expected_outputs[model_name])
+    defs.ci_profiler.stop("test_llm_torch_multi_lora_support")
@@ -244,6 +244,10 @@ l0_h100:
   - test_e2e.py::test_ptp_quickstart_multimodal[mistral-small-3.1-24b-instruct-Mistral-Small-3.1-24B-Instruct-2503-mixture_text_image-True]
   - test_e2e.py::test_ptp_quickstart_multimodal_kv_cache_reuse[mistral-small-3.1-24b-instruct-Mistral-Small-3.1-24B-Instruct-2503-0.6-image]
   - test_e2e.py::test_ptp_quickstart_multimodal_chunked_prefill[mistral-small-3.1-24b-instruct-Mistral-Small-3.1-24B-Instruct-2503-0.6-image]
+  - examples/test_mistral.py::test_mistral_with_bf16_lora_torch[mistral-7b-v0.1]
+  - examples/test_phi.py::test_phi_4_mini_instruct_with_bf16_lora_torch[Phi-4-mini-instruct]
+  - examples/test_llama.py::test_llama_3_x_with_bf16_lora_torch[llama-3.2-1b-instruct]
+  - examples/test_nemotron_nas.py::test_nemotron_nano_8b_lora_torch[Llama-3.1-Nemotron-Nano-8B-v1]
 - condition:
     ranges:
       system_gpu_count: