TensorRT-LLMs/tests/hlapi/test_llm_quant.py
Kaiyu Xie 4bb65f216f
Update TensorRT-LLM (#1274)
* Update TensorRT-LLM

---------

Co-authored-by: meghagarwal <16129366+megha95@users.noreply.github.com>
Co-authored-by: Shixiaowei02 <39303645+Shixiaowei02@users.noreply.github.com>
2024-03-12 18:15:52 +08:00

44 lines
1.2 KiB
Python

import os
import sys
import tempfile
from tensorrt_llm.hlapi.llm import LLM, ModelConfig
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
from utils.llm_data import llm_models_root
from utils.util import skip_pre_ampere, skip_pre_hopper
# Checkpoint directory used by the quantization tests, resolved under the
# shared model root from llm_models_root() and stored as a plain str.
# (The "-hf" suffix suggests a Hugging Face-format checkpoint — confirm.)
_LLAMA_7B_RELPATH = "llama-models/llama-7b-hf"
llama_model_path = str(llm_models_root() / _LLAMA_7B_RELPATH)
@skip_pre_ampere
def test_llm_int4_awq_quantization():
    """Quantize the LLaMA-7B checkpoint with per-group INT4 weights and save it.

    Configures per-group INT4 weight quantization (also applied to the LM
    head), checks that the quant config reports an active quantization mode,
    constructs an ``LLM`` from the config and saves it to a temporary
    directory, asserting that the save actually produced files.
    """
    config = ModelConfig(llama_model_path)
    # Per-group INT4 weight-only quantization; given the test name this is
    # presumably the AWQ (W4A16) mode — confirm against QuantConfig.
    config.quant_config.init_from_description(quantize_weights=True,
                                              use_int4_weights=True,
                                              per_group=True)
    config.quant_config.quantize_lm_head = True
    assert config.quant_config.has_any_quant()

    llm = LLM(config)
    with tempfile.TemporaryDirectory() as tmpdir:
        llm.save(tmpdir)
        # Without this check a save that silently wrote nothing would pass.
        assert os.listdir(tmpdir), (
            "llm.save() wrote nothing to the engine directory")
@skip_pre_hopper
def test_llm_fp8_quantization():
    """Quantize the LLaMA-7B checkpoint to FP8 (with FP8 KV cache) and save it.

    Enables FP8 QDQ quantization and an FP8 KV cache on the quant config,
    checks that a quantization mode is active, constructs an ``LLM`` from the
    config and saves it to a temporary directory, asserting that the save
    actually produced files.
    """
    config = ModelConfig(llama_model_path)
    config.quant_config.set_fp8_qdq()
    config.quant_config.set_fp8_kv_cache()
    assert config.quant_config.has_any_quant()

    llm = LLM(config)
    with tempfile.TemporaryDirectory() as tmpdir:
        llm.save(tmpdir)
        # Without this check a save that silently wrote nothing would pass.
        assert os.listdir(tmpdir), (
            "llm.save() wrote nothing to the engine directory")
if __name__ == "__main__":
    # Run both tests directly, in order, without going through pytest.
    # NOTE(review): skip_pre_ampere / skip_pre_hopper are presumably pytest
    # marks, which have no effect on a plain call — confirm before running
    # this script on older GPUs.
    for _quant_test in (test_llm_int4_awq_quantization,
                        test_llm_fp8_quantization):
        _quant_test()