TensorRT-LLMs/tests/_torch/multi_gpu/test_model.py

import os
import sys
from difflib import SequenceMatcher

import pytest
import torch

from tensorrt_llm import SamplingParams
from tensorrt_llm._torch import LLM
from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig
from tensorrt_llm.llmapi.llm_utils import CalibConfig
from tensorrt_llm.llmapi.utils import get_total_gpu_memory
from tensorrt_llm.models.modeling_utils import QuantAlgo, QuantConfig
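
# Make the shared test helpers under tests/ importable.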
sys.path.append(os.path.join(os.path.dirname(__file__), '../..'))
from utils.llm_data import llm_models_root
from utils.util import getSMVersion

MAX_SEQ_LEN = 2048


def similar(a, b, threshold=0.9):
    """Return True if the similarity ratio between a and b is at least threshold."""
    return SequenceMatcher(None, a, b).ratio() >= threshold


@pytest.mark.parametrize("model_name", ["llama-3.1-model/Meta-Llama-3.1-8B"],
ids=["llama-3.1-8b"])
@pytest.mark.parametrize("quant", ["bf16", "fp8", "fp8_kv_cache"])
@pytest.mark.parametrize("tp_size", [1, 4], ids=["tp1", "tp4"])
@pytest.mark.parametrize("torch_compile", [True, False],
ids=["torch_compile", "eager"])
def test_model(model_name, quant, tp_size, torch_compile):
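    # Map each quantization test id to its QuantConfig; "bf16" runs unquantized.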
    quant_configs = {
        "bf16":
        QuantConfig(),
        "fp8":
        QuantConfig(quant_algo=QuantAlgo.FP8),
        "fp8_kv_cache":
        QuantConfig(
            quant_algo=QuantAlgo.FP8,
            kv_cache_quant_algo=QuantAlgo.FP8,
        ),
    }
    quant_config = quant_configs[quant]
    is_fp8 = quant_config.quant_algo == QuantAlgo.FP8
    is_fp8_kv_cache = quant_config.kv_cache_quant_algo == QuantAlgo.FP8
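
    # Skip configurations this machine cannot run: too few GPUs, pre-Hopper
    # GPUs for FP8, or insufficient device memory.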
    if torch.cuda.device_count() < tp_size:
        pytest.skip(f"Not enough GPUs available, need {tp_size} "
                    f"but only have {torch.cuda.device_count()}")
    if is_fp8 and getSMVersion() < 90:
        pytest.skip(f"FP8 is not supported on SM version {getSMVersion()}")

    # 8GB weight + 8GB KV cache + 8GB cache_indirection (TRT engine only) = 24GB
    if is_fp8 and get_total_gpu_memory(0) < 24 * 1024**3:
        pytest.skip("Not enough GPU memory to run FP8 model")
    # 16GB weight + 8GB KV cache + 8GB cache_indirection (TRT engine only) = 32GB
    if not is_fp8 and get_total_gpu_memory(0) < 32 * 1024**3:
        pytest.skip("Not enough GPU memory to run BF16 model")

    prompts = [
        "The president of the United States is",
    ]
    expected_outputs = [
        " the head of state and head of government of the",
    ]
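
    # Compile the model with torch.compile when requested; CUDA-graph padding
    # is enabled alongside it, with graphs captured for batch size 4.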
    pytorch_config = PyTorchConfig(
        torch_compile_enabled=torch_compile,
        cuda_graph_padding_enabled=torch_compile,
        cuda_graph_batch_sizes=[4],
    )
    if is_fp8_kv_cache:
        pytorch_config.kv_cache_dtype = "fp8"

    model_dir = str(llm_models_root() / model_name)
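    # FP8 runs load a pre-quantized checkpoint instead of the BF16 weights.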
    if is_fp8:
        fp8_model_names = {
            "llama-3.1-model/Meta-Llama-3.1-8B":
            "llama-3.1-model/Llama-3.1-8B-Instruct-FP8"
        }
        model_dir = str(llm_models_root() / fp8_model_names[model_name])

    llm = LLM(
        model=model_dir,
        tensor_parallel_size=tp_size,
        quant_config=quant_config,
        pytorch_backend_config=pytorch_config,
        calib_config=CalibConfig(calib_dataset=str(llm_models_root() /
                                                   "datasets/cnn_dailymail")),
    )
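
    # Generate up to 10 tokens per prompt and fuzzy-match each completion
    # against its expected continuation.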
    with llm:
        outputs = llm.generate(
            prompts,
            sampling_params=SamplingParams(max_tokens=10),
        )

    assert len(outputs) == len(expected_outputs), "Output length mismatch"
    for output, expected in zip(outputs, expected_outputs):
        output_text = output.outputs[0].text
        print(output_text)
        print(output.outputs[0].token_ids)
        assert similar(
            output_text,
            expected), f"Expected '{expected}' but got '{output_text}'"
if __name__ == '__main__':
    test_model("llama-3.1-model/Meta-Llama-3.1-8B",
               "bf16",
               1,
               torch_compile=False)
    test_model("llama-3.1-model/Meta-Llama-3.1-8B",
               "bf16",
               4,
               torch_compile=False)
    test_model("llama-3.1-model/Meta-Llama-3.1-8B",
               "fp8",
               1,
               torch_compile=False)
    test_model("llama-3.1-model/Meta-Llama-3.1-8B",
               "fp8",
               4,
               torch_compile=False)
    test_model("llama-3.1-model/Meta-Llama-3.1-8B",
               "fp8_kv_cache",
               1,
               torch_compile=False)
    test_model("llama-3.1-model/Meta-Llama-3.1-8B",
               "fp8_kv_cache",
               4,
               torch_compile=False)