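'''High-level API tests for TensorRT-LLM with LLaMA: build an engine from the
llama-7b-hf checkpoint, save and reload it from disk, and run synchronous,
asynchronous, and streaming generation through GenerationExecutor.

The tests expect the checkpoint under llm_models_root() and are gated by the
@force_ampere marker.
'''
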
import asyncio
import os
import sys
import tempfile
from contextlib import contextmanager

from profile_utils import profile
from transformers import AutoTokenizer

import tensorrt_llm
from tensorrt_llm.builder import BuildConfig, build
from tensorrt_llm.executor import GenerationExecutor, SamplingParams
from tensorrt_llm.models import LLaMAForCausalLM
from tensorrt_llm.models.llama.config import LLaMAConfig

sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
from utils.llm_data import llm_models_root
from utils.util import force_ampere

tensorrt_llm.logger.set_level('verbose')

batch_input_text = [
    "Born in north-east France, Soyer trained as a",
    "What is large language model?"
]
batch_output_text_expected = [
    "chef in Paris and London before moving to New York",
    "\nLarge language model is a model that is"
]
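
# The expected outputs above are the continuations llama-7b-hf is expected to
# produce for the prompts; the tests only check that each generation ends with
# the corresponding suffix.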


@contextmanager
def workspace(suffix, prefix="./trtllm_workspace"):
    keep_workspace = os.environ.get("TRTLLM_KEEP", False)
    if not keep_workspace:
        temp = tempfile.TemporaryDirectory(suffix)
        yield temp.name
    else:
        temp = f"{prefix}/{suffix}"
        os.makedirs(temp, exist_ok=True)
        yield temp
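
# workspace() honors TRTLLM_KEEP: set it to a non-empty value to keep the
# engine directory under ./trtllm_workspace instead of using a temporary
# directory that is removed afterwards.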


# 233s on ipp1-1197: loading weights 37s, network/engine 27s, save engine: 35s, load engine (14GB) about 100s
@profile("save-and-load")
@force_ampere
def test_save_load():
    '''Build a LLaMA engine, save it to engine_dir, then reload it from disk
    through GenerationExecutor and generate with it.

    Saving is optional, but it lets users store the engine in any folder they
    want and reuse it later.
    '''
    max_batch_size, max_isl, max_osl = 8, 256, 256
    hf_model_dir = str(llm_models_root() / "llama-models/llama-7b-hf")

    with workspace("llama-save-load") as engine_dir:
        # Build and run with one LLaMA object.
        llama = LLaMAForCausalLM.from_hugging_face(hf_model_dir)
        build_config = BuildConfig(max_batch_size=max_batch_size,
                                   max_input_len=max_isl,
                                   max_seq_len=max_osl + max_isl,
                                   plugin_config=llama.default_plugin_config())
        build_config.plugin_config.gemm_plugin = 'auto'  # faster build
        engine = build(llama, build_config)
        engine.save(engine_dir)

        tokenizer = AutoTokenizer.from_pretrained(hf_model_dir)

        # Use a context manager so __exit__ releases the resources immediately.
        with GenerationExecutor.create(engine_dir) as executor:
            batch_input_ids = [
                tokenizer.encode(inp) for inp in batch_input_text
            ]
            outputs = executor.generate(
                batch_input_ids, sampling_params=SamplingParams(max_tokens=10))

            for idx, output in enumerate(outputs):
                tensorrt_llm.logger.info(f"Input: {batch_input_text[idx]}")
                output_text = tokenizer.decode(output.outputs[0].token_ids)
                tensorrt_llm.logger.info(f'Output: {output_text}')
                # Note: the output contains everything from the input, so only
                # the suffix is compared here.
                assert output_text.endswith(
                    batch_output_text_expected[idx]
                ), f"Expected: {batch_output_text_expected[idx]!r} Got: {output_text!r}"


@profile(tag="fake-weights")
@force_ampere
def test_high_level_fake_weights():
    '''Sanity check that the build flow works.'''
    max_batch_size, max_isl, max_osl = 8, 256, 256
    hf_model_dir = str(llm_models_root() / "llama-models/llama-7b-hf")

    # Fake weights, skipping engine save and load, to make the sanity test
    # faster.
    config = LLaMAConfig.from_hugging_face(hf_model_dir)
    llama = LLaMAForCausalLM(config)
    build_config = BuildConfig(max_batch_size=max_batch_size,
                               max_input_len=max_isl,
                               max_seq_len=max_osl + max_isl,
                               plugin_config=llama.default_plugin_config())
    build_config.plugin_config.gemm_plugin = 'auto'  # faster build
    build(llama, build_config)


@force_ampere
def test_async_io():
    '''Run generate_async concurrently, in non-streaming and streaming modes.'''
    max_batch_size, max_isl, max_osl = 8, 256, 256
    hf_model_dir = str(llm_models_root() / "llama-models/llama-7b-hf")

    llama = LLaMAForCausalLM.from_hugging_face(hf_model_dir)
    build_config = BuildConfig(max_batch_size=max_batch_size,
                               max_input_len=max_isl,
                               max_seq_len=max_osl + max_isl)
    build_config.plugin_config.gemm_plugin = 'auto'  # faster build
    engine = build(llama, build_config)

    engine_dir = "llama-ifb"
    engine_temp = tempfile.TemporaryDirectory(engine_dir)
    engine_dir = engine_temp.name
    engine.save(engine_dir)

    tokenizer = AutoTokenizer.from_pretrained(hf_model_dir)

    async def main():
        with GenerationExecutor.create(engine_dir) as async_engine:

            async def generate_and_print(idx, inp):
                # Non-streaming: wait for the full generation, then check it.
                result = async_engine.generate_async(
                    tokenizer.encode(inp),
                    sampling_params=SamplingParams(max_tokens=10),
                    streaming=False)
                await result.aresult()
                output_text = tokenizer.decode(result.outputs[0].token_ids)
                tensorrt_llm.logger.info(output_text)
                assert output_text.endswith(batch_output_text_expected[idx])

                # Streaming: consume partial results as they are produced.
                async for stream in async_engine.generate_async(
                        tokenizer.encode(inp),
                        sampling_params=SamplingParams(max_tokens=10),
                        streaming=True):
                    output_text = tokenizer.decode(stream.outputs[0].token_ids)
                    tensorrt_llm.logger.info(
                        f"prompt: {inp!r}, generation: {output_text!r}")

            loop = asyncio.get_running_loop()
            tasks = []
            # Submit many requests concurrently.
            for idx, inp in enumerate(batch_input_text):
                task = loop.create_task(generate_and_print(idx, inp))
                tasks.append(task)

            # Wait until all tasks are done.
            await asyncio.gather(*tasks)

    asyncio.run(main())


if __name__ == "__main__":
    test_save_load()
    test_async_io()
    test_high_level_fake_weights()