TensorRT-LLM/tests/llmapi/test_llm_quant.py
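"""Tests for quantization through the TensorRT-LLM LLM API.

Each test builds or loads a quantized Llama model, generates six tokens from
the prompt "A B C", and checks the expected alphabetical continuation.
"""
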
import pytest

from tensorrt_llm.llmapi import LLM, KvCacheConfig, SamplingParams
from tensorrt_llm.llmapi.llm_utils import CalibConfig, QuantAlgo, QuantConfig

# isort: off
from test_llm import cnn_dailymail_path, llama_model_path, get_model_path
from utils.util import skip_blackwell, skip_pre_ampere, skip_pre_blackwell, skip_pre_hopper
# isort: on


@skip_pre_ampere
@skip_blackwell
def test_llm_int4_awq_quantization():
    # INT4 AWQ weight-only quantization (W4A16), calibrated on CNN/DailyMail.
    quant_config = QuantConfig(quant_algo=QuantAlgo.W4A16_AWQ)
    assert quant_config.quant_mode.has_any_quant()
    calib_config = CalibConfig(calib_dataset=cnn_dailymail_path)

    llm = LLM(llama_model_path,
              quant_config=quant_config,
              calib_config=calib_config)

    sampling_params = SamplingParams(max_tokens=6)
    for output in llm.generate(["A B C"], sampling_params=sampling_params):
        print(output)
        assert output.outputs[0].text == "D E F G H I"
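
# A minimal sketch (not part of the test suite) of reusing a quantized build:
# assuming this version of the LLM API exposes llm.save(), the calibrated
# engine can be written out once and reloaded without repeating calibration.
# The output directory name is hypothetical.
#
#   llm = LLM(llama_model_path,
#             quant_config=QuantConfig(quant_algo=QuantAlgo.W4A16_AWQ),
#             calib_config=CalibConfig(calib_dataset=cnn_dailymail_path))
#   llm.save("./llama-w4a16-awq-engine")   # hypothetical path
#   llm = LLM("./llama-w4a16-awq-engine")  # reload; no re-calibration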


@skip_pre_hopper
def test_llm_fp8_quantization():
    # FP8 weights/activations with an FP8 KV cache (requires Hopper or newer).
    quant_config = QuantConfig(quant_algo=QuantAlgo.FP8,
                               kv_cache_quant_algo=QuantAlgo.FP8)
    assert quant_config.quant_mode.has_any_quant()
    calib_config = CalibConfig(calib_dataset=cnn_dailymail_path)

    llm = LLM(llama_model_path,
              quant_config=quant_config,
              calib_config=calib_config)

    sampling_params = SamplingParams(max_tokens=6)
    for output in llm.generate(["A B C"], sampling_params=sampling_params):
        print(output)
        assert output.outputs[0].text == "D E F G H I"


@skip_pre_blackwell
def test_llm_nvfp4_quantization():
    # NVFP4 weights/activations with an FP8 KV cache (requires Blackwell).
    quant_config = QuantConfig(quant_algo=QuantAlgo.NVFP4,
                               kv_cache_quant_algo=QuantAlgo.FP8)
    assert quant_config.quant_mode.has_any_quant()
    calib_config = CalibConfig(calib_dataset=cnn_dailymail_path)

    llm = LLM(llama_model_path,
              quant_config=quant_config,
              calib_config=calib_config)

    sampling_params = SamplingParams(max_tokens=6)
    for output in llm.generate(["A B C"], sampling_params=sampling_params):
        print(output)
        assert output.outputs[0].text == "D E F G H I"


@skip_pre_hopper
@pytest.mark.skip("https://nvbugs/5027953")
def test_llm_fp8_quantization_modelOpt_ckpt():
    # Loads a checkpoint already quantized to FP8 by NVIDIA ModelOpt, so no
    # QuantConfig or calibration pass is needed here.
    llama_fp8_model_path = get_model_path(
        "llama-3.1-model/Llama-3.1-8B-Instruct-FP8")
    llm = LLM(llama_fp8_model_path,
              kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.4))

    sampling_params = SamplingParams(max_tokens=6)
    for output in llm.generate(["A B C"], sampling_params=sampling_params):
        print(output)
        assert output.outputs[0].text == " D E F G H I"
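
# free_gpu_memory_fraction bounds the KV cache to a share of the GPU memory
# still free after the model is loaded; a sketch of a tighter budget for
# memory-constrained runs (the fraction value is illustrative):
#
#   llm = LLM(llama_fp8_model_path,
#             kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.2))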


if __name__ == "__main__":
    test_llm_int4_awq_quantization()
    test_llm_fp8_quantization_modelOpt_ckpt()
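
# These tests are normally collected by pytest; a direct-invocation sketch,
# assuming test_llm and utils.util are importable from the working directory:
#
#   pytest tests/llmapi/test_llm_quant.py -k "awq or fp8" -s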