# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
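"""Send many copies of one rendered chat prompt to a TensorRT-LLM engine and
check that generation is deterministic: when --check_deterministic_accuracy
is set, the number of unique decoded responses across all requests must not
exceed --deterministic_accuracy_threshold."""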
import argparse
import json
import os
import time
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, List

import jinja2
from transformers import AutoTokenizer

import tensorrt_llm.bindings.executor as trtllm


def get_args():
    """Parse command-line options for the determinism check."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--engine_dir", type=str, default="engine_outputs")
    parser.add_argument("--tokenizer_dir", type=str, default="tokenizer_dir")
    parser.add_argument("--payload", type=str, default="./payload.json")
    parser.add_argument("--concurrency", type=int, default=50)
    parser.add_argument("--check_deterministic_accuracy",
                        action="store_true",
                        default=False)
    parser.add_argument("--deterministic_accuracy_threshold",
                        type=int,
                        default=1)
    parser.add_argument("--batch", action="store_true", default=False)
    parser.add_argument("--wait", type=float, default=0.0)
    parser.add_argument("--output", type=str, default='out-strs')
    return parser.parse_args()


def create_engine(engine):
    """Build a trtllm.Executor for the given engine directory, configured for
    in-flight batching with KV-cache block reuse enabled and the CUDA-graph /
    multi-block perf knobs disabled."""
    parallel_config = trtllm.ParallelConfig(
        communication_type=trtllm.CommunicationType.MPI,
        communication_mode=trtllm.CommunicationMode.LEADER)
    trt_scheduler_config = trtllm.SchedulerConfig(
        trtllm.CapacitySchedulerPolicy.GUARANTEED_NO_EVICT)
    kv_cache_config = trtllm.KvCacheConfig(
        free_gpu_memory_fraction=0.9,
        enable_block_reuse=True,
    )
    extend_runtime_perf_knob_config = trtllm.ExtendedRuntimePerfKnobConfig()
    extend_runtime_perf_knob_config.cuda_graph_mode = False
    extend_runtime_perf_knob_config.multi_block_mode = False
    executor_config = trtllm.ExecutorConfig(
        1,  # max_beam_width
        iter_stats_max_iterations=100,
        # nvbugs/4662826
        request_stats_max_iterations=0,
        parallel_config=parallel_config,
        # normalize_log_probs=False,
        batching_type=trtllm.BatchingType.INFLIGHT,
        # batching_type=trtllm.BatchingType.STATIC,
        scheduler_config=trt_scheduler_config,
        kv_cache_config=kv_cache_config,
        enable_chunked_context=True,
        extended_runtime_perf_knob_config=extend_runtime_perf_knob_config,
    )

    return trtllm.Executor(model_path=engine,
                           model_type=trtllm.ModelType.DECODER_ONLY,
                           executor_config=executor_config)


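# Usage sketch (paths illustrative): in LEADER communication mode, typically
# only the leader rank can enqueue work, so callers should gate on
# can_enqueue_requests() as main() does below:
#
#   executor = create_engine("engine_outputs")
#   if executor.can_enqueue_requests():
#       ...  # build and submit requests

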
def create_request(payload_file, template_str, tokenizer):
    """Render a chat prompt from the payload file through the tokenizer's
    jinja2 chat template and wrap it in a trtllm.Request with greedy
    (top_k=1) sampling."""
    # Default payload; 'messages' is overwritten from payload_file below.
    json_data = {
        'model': 'my-model',
        'messages': [
            {
                'role': 'user',
                'content': 'Hello there how are you?',
            },
            {
                'role': 'assistant',
                'content': 'Good and you?',
            },
            {
                'role': 'user',
                'content': 'Whats your name?',
            },
        ],
        'max_tokens': 1024,
        'temperature': 0,
        # 'top_k': 1,
        # 'nvext': {"top_k": 1},
        'stream': False,
    }

    json_data['messages'][2]['content'] = """
Classify the sentiment expressed in the following text and provide the response in a single word Positive/Negative/Neutral. Explain your answer in 2 lines.
TEXT:: Today I will exaggerate, will be melodramatic (mostly the case when I am excited) and be naive (as always). Just came out from the screening of the Avengers Endgame ("Endgame")! The journey had started in the year 2008, when Tony Stark, during his capture in a cave in Afghanistan, had created a combat suit and came out of his captivity.
Then the combat suit made of iron was perfected and Tony Stark officially became the Iron Man!! The Marvel Cinematic Universe ("MCU") thus was initiated. The journey continued since then and in 2012 all the MCU heroes came together and formed the original "Avengers" (so much fun and good it was).
21 Movies in the MCU and culminating into the Infinity War (2018) and finally into the Endgame! The big adventure for me started from Jurassic Park and then came Titanic, Lagaan, Dark Knight; and then came the Avengers in 2012. Saw my absolute favorite Sholay in the hall in 2014. In the above-mentioned genre, there are good movies, great movies and then there is the Endgame.
Today after a long long time, I came out of the hall with 100% happiness, satisfaction and over the top excitement/emotions. The movie is Epic, Marvel (in the real sense) and perfect culmination of the greatest cinematic saga of all time. It is amazing, humorous, emotional and has mind-blowing action! It is one of the finest Superhero Movie of all time.
Just pure Awesome! It's intelligent!
    """
    with open(payload_file, 'r') as f:
        msg_system = json.load(f)
    # Merge the first two payload messages into a single user turn and keep
    # the remaining turns as-is.
    msg_user = []
    msg_user.append({
        "role":
        "user",
        "content":
        msg_system[0]["content"] + "\n\n" + msg_system[1]["content"]
    })
    msg_user.extend(msg_system[2:])
    json_data['messages'] = msg_user

    environment = jinja2.Environment()
    template = environment.from_string(template_str)
    json_data['bos_token'] = '<s>'
    json_data['eos_token'] = '</s>'
    prompt = template.render(json_data)

    tokens = tokenizer.encode(prompt)

    sample_params = trtllm.SamplingConfig(
        beam_width=1,  # beam_width=1 for inflight batching
        top_k=1,  # SizeType topK
        top_p=1.0,
        top_p_min=None,
        top_p_reset_ids=None,  # SizeType topPResetIds
        top_p_decay=None,  # FloatType topPDecay
        seed=1234,
        temperature=1,
        min_tokens=1,  # SizeType minLength
        beam_search_diversity_rate=None,  # FloatType beamSearchDiversityRate
        repetition_penalty=1,  # FloatType repetitionPenalty
        presence_penalty=0,  # FloatType presencePenalty
        frequency_penalty=0,  # FloatType frequencyPenalty
        length_penalty=1,  # FloatType lengthPenalty
        early_stopping=None,  # SizeType earlyStopping. Controls beam search, so irrelevant until we have beam_width > 1
    )
    # sample_params = trtllm.SamplingConfig(temperature=0, seed=1234)

    return trtllm.Request(
        # Drop the first token; encode() typically prepends BOS and the
        # rendered template already starts with bos_token.
        input_token_ids=tokens[1:],
        max_tokens=1024,
        sampling_config=sample_params,
        streaming=False,
        stop_words=None,
        # stop_words=[[2]],  # </s>
    ), prompt


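# Assumed payload.json shape (inferred from the indexing in create_request,
# not from a documented schema): a JSON list of chat messages whose first two
# entries are merged into a single user turn, e.g.
#
#   [
#     {"role": "system", "content": "<instructions>"},
#     {"role": "user", "content": "<task>"},
#     {"role": "user", "content": "<further turns...>"}
#   ]

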
def get_tokenizer(tokenizer_file):
    return AutoTokenizer.from_pretrained(tokenizer_file)


def get_template(tokenizer_file):
    """Read the raw jinja2 chat template string from tokenizer_config.json."""
    with open(os.path.join(tokenizer_file,
                           "tokenizer_config.json")) as tok_config:
        cfg = json.load(tok_config)
    return cfg['chat_template']


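# Note: recent transformers tokenizers can render the same template via
# tokenizer.apply_chat_template(messages, tokenize=False); reading the raw
# 'chat_template' field instead lets create_request inject bos_token /
# eos_token into the render context explicitly.

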
def enqueue_requests(pool, executor, request, concurrency=50, wait=0):
    """Submit `concurrency` copies of the request via the thread pool,
    optionally sleeping `wait` seconds between submissions to stagger
    arrival times."""
    for _ in range(concurrency):
        _ = pool.submit(executor.enqueue_request, request)
        if wait > 0:
            time.sleep(wait)


def main():
    args = get_args()
    executor = create_engine(args.engine_dir)
    if executor.can_enqueue_requests():
        template = get_template(args.tokenizer_dir)
        tokenizer = get_tokenizer(args.tokenizer_dir)
        concurrency = int(args.concurrency)

        request, prompt = create_request(args.payload, template, tokenizer)
        os.makedirs(args.output, exist_ok=True)
        with open(os.path.join(args.output, "prompt.txt"), 'w') as f:
            f.write(prompt)

        try:
            for _ in range(1):  # single experiment pass
                outputs: Dict[str, List[trtllm.Result]] = {}
                num_finished = 0

                if not args.batch:
                    # Submit the copies one by one from a thread pool.
                    with ThreadPoolExecutor(max_workers=concurrency) as pool:
                        enqueue_requests(pool,
                                         executor,
                                         request,
                                         concurrency=concurrency,
                                         wait=args.wait)
                else:
                    # Submit all copies in a single batched call.
                    executor.enqueue_requests(
                        [request for _ in range(concurrency)])
                while num_finished < concurrency:
                    responses = executor.await_responses()
                    for response in responses:
                        if response.has_error():
                            outputs[response.request_id] = response.error_msg
                            num_finished += 1
                        else:
                            result = response.result
                            if result.is_final:
                                num_finished += 1
                            if response.request_id not in outputs:
                                outputs[response.request_id] = []
                            outputs[response.request_id].append(result)
                output_strs = {}
                for req_id, output in outputs.items():
                    if isinstance(output, str):
                        # A str entry is an error message stored above.
                        raise RuntimeError(output)
                    elif isinstance(output, list):
                        if len(output) != 1:
                            raise RuntimeError("Expected list size of 1")
                        output_strs[req_id] = tokenizer.decode(
                            output[0].output_token_ids[0])
                        with open(os.path.join(args.output, f"{req_id}.out"),
                                  "w") as f:
                            f.write(output_strs[req_id])
                    else:
                        raise RuntimeError("Unexpected output")

                output_set = set(output_strs.values())
                num_unique_responses = len(output_set)
                if args.check_deterministic_accuracy:
                    assert num_unique_responses <= args.deterministic_accuracy_threshold, (
                        f"Expected num unique responses <= "
                        f"{args.deterministic_accuracy_threshold} but got "
                        f"{num_unique_responses}")
                result_str = (f"Num unique responses in {len(outputs)}: "
                              f"{num_unique_responses}")
                print(result_str)
                with open(os.path.join(args.output, "num_outputs"), 'w') as f:
                    f.write(result_str + '\n')
        finally:
            executor.shutdown()


if __name__ == "__main__":
    main()
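
# Example invocation (script and directory names are illustrative):
#
#   python check_llm_determinism.py \
#       --engine_dir ./engine_outputs \
#       --tokenizer_dir ./tokenizer_dir \
#       --payload ./payload.json \
#       --concurrency 50 \
#       --check_deterministic_accuracy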