TensorRT-LLMs/tests/unittest/llmapi/test_llm_quant.py
xiweny 6979afa6f2
test: reorganize tests folder hierarchy (#2996)
1. move TRT path tests to 'trt' folder
2. optimize some import usage
2025-03-27 12:07:53 +08:00

76 lines
2.7 KiB
Python

import pytest
from tensorrt_llm.llmapi import LLM, KvCacheConfig, SamplingParams
from tensorrt_llm.llmapi.llm_utils import CalibConfig, QuantAlgo, QuantConfig
# isort: off
from .test_llm import cnn_dailymail_path, llama_model_path, get_model_path
from utils.util import skip_blackwell, skip_pre_blackwell, skip_pre_hopper
# isort: on
@skip_blackwell
def test_llm_int4_awq_quantization():
    """Build the Llama test model with INT4 AWQ (W4A16) quantization and check
    the text generated for a short prompt.

    Calibration data is taken from ``cnn_dailymail_path``.
    """
    quant_config = QuantConfig(quant_algo=QuantAlgo.W4A16_AWQ)
    assert quant_config.quant_mode.has_any_quant()
    calib_config = CalibConfig(calib_dataset=cnn_dailymail_path)

    prompts = ["A B C"]
    sampling_params = SamplingParams(max_tokens=6)
    # Use the LLM as a context manager so it is shut down when the test ends
    # (pass or fail) instead of holding resources into the next test.
    with LLM(llama_model_path,
             quant_config=quant_config,
             calib_config=calib_config) as llm:
        outputs = llm.generate(prompts, sampling_params=sampling_params)
        # Without this check the loop below would pass vacuously if
        # generate() returned fewer results than prompts.
        assert len(outputs) == len(prompts)
        for output in outputs:
            print(output)
            assert output.outputs[0].text == "D E F G H I"
@skip_pre_hopper
def test_llm_fp8_quantization():
    """Build the Llama test model with FP8 weight quantization and an FP8 KV
    cache, then check the text generated for a short prompt.

    Calibration data is taken from ``cnn_dailymail_path``.
    """
    quant_config = QuantConfig(quant_algo=QuantAlgo.FP8,
                               kv_cache_quant_algo=QuantAlgo.FP8)
    assert quant_config.quant_mode.has_any_quant()
    calib_config = CalibConfig(calib_dataset=cnn_dailymail_path)

    prompts = ["A B C"]
    sampling_params = SamplingParams(max_tokens=6)
    # Use the LLM as a context manager so it is shut down when the test ends
    # (pass or fail) instead of holding resources into the next test.
    with LLM(llama_model_path,
             quant_config=quant_config,
             calib_config=calib_config) as llm:
        outputs = llm.generate(prompts, sampling_params=sampling_params)
        # Without this check the loop below would pass vacuously if
        # generate() returned fewer results than prompts.
        assert len(outputs) == len(prompts)
        for output in outputs:
            print(output)
            assert output.outputs[0].text == "D E F G H I"
@skip_pre_blackwell
def test_llm_nvfp4_quantization():
    """Build the Llama test model with NVFP4 weight quantization and an FP8 KV
    cache, then check the text generated for a short prompt.

    Calibration data is taken from ``cnn_dailymail_path``.
    """
    quant_config = QuantConfig(quant_algo=QuantAlgo.NVFP4,
                               kv_cache_quant_algo=QuantAlgo.FP8)
    assert quant_config.quant_mode.has_any_quant()
    calib_config = CalibConfig(calib_dataset=cnn_dailymail_path)

    prompts = ["A B C"]
    sampling_params = SamplingParams(max_tokens=6)
    # Use the LLM as a context manager so it is shut down when the test ends
    # (pass or fail) instead of holding resources into the next test.
    with LLM(llama_model_path,
             quant_config=quant_config,
             calib_config=calib_config) as llm:
        outputs = llm.generate(prompts, sampling_params=sampling_params)
        # Without this check the loop below would pass vacuously if
        # generate() returned fewer results than prompts.
        assert len(outputs) == len(prompts)
        for output in outputs:
            print(output)
            assert output.outputs[0].text == "D E F G H I"
@skip_pre_hopper
@pytest.mark.skip("https://nvbugs/5027953")
def test_llm_fp8_quantization_modelOpt_ckpt():
    """Load the "Llama-3.1-8B-Instruct-FP8" checkpoint without passing a
    QuantConfig and check the text generated for a short prompt.

    NOTE(review): the checkpoint name suggests it is already FP8-quantized
    (by ModelOpt, per the test name) -- confirm.
    """
    llama_fp8_model_path = get_model_path(
        "llama-3.1-model/Llama-3.1-8B-Instruct-FP8")

    prompts = ["A B C"]
    sampling_params = SamplingParams(max_tokens=6)
    # Use the LLM as a context manager so it is shut down when the test ends
    # (pass or fail) instead of holding resources into the next test.
    with LLM(llama_fp8_model_path,
             kv_cache_config=KvCacheConfig(
                 free_gpu_memory_fraction=0.4)) as llm:
        outputs = llm.generate(prompts, sampling_params=sampling_params)
        # Without this check the loop below would pass vacuously if
        # generate() returned fewer results than prompts.
        assert len(outputs) == len(prompts)
        for output in outputs:
            print(output)
            # Unlike the other tests, the expected text starts with a space;
            # presumably tokenizer-dependent for this model -- confirm.
            assert output.outputs[0].text == " D E F G H I"
if __name__ == "__main__":
    # Run the selected tests through pytest instead of calling them directly.
    # Direct calls ignore pytest markers, so the hardware skip decorators and
    # the known-bug skip on test_llm_fp8_quantization_modelOpt_ckpt
    # (https://nvbugs/5027953) would be bypassed.
    raise SystemExit(
        pytest.main([
            f"{__file__}::test_llm_int4_awq_quantization",
            f"{__file__}::test_llm_fp8_quantization_modelOpt_ckpt",
        ]))