TensorRT-LLM/tests/hlapi/test_llm_quant.py
import os
import sys

from tensorrt_llm.hlapi.llm import LLM, ModelConfig, SamplingParams
from tensorrt_llm.quantization import QuantAlgo

# Make the parent tests/ directory importable so utils.util resolves
# when this file is run directly.
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
from utils.util import skip_pre_ampere, skip_pre_hopper

# Support both package-relative import (pytest) and direct invocation.
try:
    from .test_llm import llama_model_path
except ImportError:
    from test_llm import llama_model_path


@skip_pre_ampere
def test_llm_int4_awq_quantization():
    # Enable INT4 AWQ weight-only quantization and check that it
    # registers as an active quantization mode.
    config = ModelConfig(llama_model_path)
    config.quant_config.quant_algo = QuantAlgo.W4A16_AWQ
    assert config.quant_config.quant_mode.has_any_quant()

    llm = LLM(config)

    sampling_params = SamplingParams(max_new_tokens=6)
    for output in llm.generate(["A B C"], sampling_params=sampling_params):
        print(output)
        assert output.text == "D E F G H I"
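

# A minimal sketch, not part of the original suite: the same flow with INT8
# weight-only quantization (QuantAlgo.W8A16). The function name avoids the
# test_ prefix so pytest does not collect it, and no exact output is asserted
# since only the AWQ/FP8 completions are pinned above.
@skip_pre_ampere
def example_llm_int8_weight_only_quantization():
    config = ModelConfig(llama_model_path)
    config.quant_config.quant_algo = QuantAlgo.W8A16
    assert config.quant_config.quant_mode.has_any_quant()

    llm = LLM(config)

    sampling_params = SamplingParams(max_new_tokens=6)
    for output in llm.generate(["A B C"], sampling_params=sampling_params):
        print(output)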


@skip_pre_hopper
def test_llm_fp8_quantization():
    # Quantize both the weights/activations and the KV cache to FP8
    # (requires Hopper or newer).
    config = ModelConfig(llama_model_path)
    config.quant_config.quant_algo = QuantAlgo.FP8
    config.quant_config.kv_cache_quant_algo = QuantAlgo.FP8
    # Keep precision-sensitive modules (embeddings, head, router) out of
    # the FP8 conversion.
    config.quant_config.exclude_modules = [
        'lm_head', 'router', 'vocab_embedding', 'position_embedding',
        'block_embedding'
    ]
    assert config.quant_config.quant_mode.has_any_quant()

    llm = LLM(config)

    sampling_params = SamplingParams(max_new_tokens=6)
    for output in llm.generate(["A B C"], sampling_params=sampling_params):
        print(output)
        assert output.text == "D E F G H I"


if __name__ == "__main__":
    test_llm_int4_awq_quantization()
    test_llm_fp8_quantization()
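
# Usage (a sketch; assumes pytest is installed and a local LLaMA checkpoint is
# resolvable via test_llm.llama_model_path):
#   pytest tests/hlapi/test_llm_quant.py -s
#   python tests/hlapi/test_llm_quant.py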