TensorRT-LLMs/benchmarks/python/kv_cache_offload/benchmark.py

import argparse
import datetime
import json
import random
import time

import tensorrt_llm.bindings.executor as trtllm
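
# Request configuration shared by every request in the benchmark: keep the
# prompt tokens in each response (exclude_input_from_output=False) so the full
# conversation history can be fed back as the next turn's input, and sample
# with a single beam.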
output_config = trtllm.OutputConfig()
output_config.exclude_input_from_output = False
sampling_config = trtllm.SamplingConfig(1)
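

# Generate random tokens that stand in for fresh user input on each
# multi-turn round.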
def generate_random_tokens(rounds=10, count=64) -> list[list[int]]:
    ret = []
    for _ in range(rounds):
        ret.append([random.randint(0, 1000) for _ in range(count)])
    return ret


# Read input tokens and expected output lengths from a JSON dataset file
def read_input_json(input_dataset_path: str,
                    num_users: int) -> tuple[list[list[int]], list[int]]:
    with open(input_dataset_path, "r") as f:
        data = json.load(f)
    input_tokens = []
    output_lens = []
    for n in range(num_users):
        sample = data["samples"][n]
        input_tokens.append(sample["input_ids"])
        output_lens.append(sample["output_len"])
    return input_tokens, output_lens


# Prepare and enqueue the requests
def enqueue_requests(args: argparse.Namespace, executor: trtllm.Executor,
                     input_tokens) -> list[int]:
    request_ids = []
    for tokens in input_tokens:
        req = trtllm.Request(input_token_ids=tokens,
                             max_tokens=args.output_len,
                             streaming=False,
                             sampling_config=sampling_config,
                             output_config=output_config)
        req_id = executor.enqueue_request(req)
        request_ids.append(req_id)
    return request_ids
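

# Approximate time-to-first-token (TTFT) from the per-iteration statistics:
# among iterations whose KV-cache hit rate exceeds 1% (i.e. iterations that
# reuse previously computed cache blocks), report the latency of the second one.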
def get_TTFT(stats_queue):
    iter_latency = []
    cache_hit_rates = []
    for stats in stats_queue:
        iter_latency.append(stats.iter_latency_ms)
        cache_hit_rates.append(stats.kv_cache_stats.cache_hit_rate)
    TTFT_idx = [i for i, x in enumerate(cache_hit_rates) if x > 0.01][1]
    return iter_latency[TTFT_idx]


# Wait for responses and store output tokens
def wait_for_responses(args: argparse.Namespace, request_ids: list[int],
                       executor: trtllm.Executor) -> list[list[int]]:
    output_tokens = {req_id: [] for req_id in request_ids}
    num_finished = 0
    iterations = 0
    while num_finished < len(request_ids) and iterations < args.timeout_ms:
        responses = executor.await_responses(
            datetime.timedelta(milliseconds=args.timeout_ms))
        for response in responses:
            req_id = response.request_id
            if not response.has_error():
                result = response.result
                num_finished += 1 if result.is_final else 0
                for beam_tokens in result.output_token_ids:
                    output_tokens[req_id].extend(beam_tokens)
            else:
                raise RuntimeError(
                    str(req_id) + " encountered error: " + response.error_msg)
        # Count polling iterations so the loop guard above can actually trip.
        iterations += 1
    return list(output_tokens.values())


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Executor Bindings Example")
    parser.add_argument("--n", type=int, required=True, help="Number of users")
    parser.add_argument("--free_gpu_memory_fraction",
                        required=False,
                        type=float,
                        default=0.9,
                        help="free_gpu_memory_fraction")
    parser.add_argument("--kv_host_cache_bytes",
                        required=False,
                        type=int,
                        default=55000000000,
                        help="host_cache_size")
    parser.add_argument("--model_path",
                        type=str,
                        required=True,
                        help="Directory containing model engine")
    parser.add_argument("--input_dataset_path",
                        type=str,
                        required=True,
                        help="JSON file containing the input tokens")
    parser.add_argument("--beam_width",
                        type=int,
                        required=False,
                        default=1,
                        help="The beam width")
    parser.add_argument("--streaming",
                        default=False,
                        action="store_true",
                        help="Operate in streaming mode")
    parser.add_argument("--output_len",
                        type=int,
                        required=False,
                        default=64,
                        help="The number of tokens to be generated for output.")
    parser.add_argument("--rounds",
                        type=int,
                        required=False,
                        default=10,
                        help="How many rounds of user input to run.")
    parser.add_argument(
        "--timeout_ms",
        type=int,
        required=False,
        default=10000,
        help="The maximum time to wait for all responses, in milliseconds")
    parser.add_argument(
        "--log_iteration_data",
        action="store_true",
        help="Print the verbose iteration status data (default: False).")
    args = parser.parse_args()
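
    # KV-cache configuration: enable block reuse across requests and set aside a
    # host-memory pool of --kv_host_cache_bytes bytes so that blocks evicted from
    # GPU memory can be offloaded to the host and onboarded again later instead
    # of being recomputed.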
    kv_cache_config = trtllm.KvCacheConfig(
        enable_block_reuse=True,
        free_gpu_memory_fraction=args.free_gpu_memory_fraction,
        host_cache_size=args.kv_host_cache_bytes)
    executor_config = trtllm.ExecutorConfig(args.beam_width,
                                            kv_cache_config=kv_cache_config)

    # Create the executor.
    executor = trtllm.Executor(args.model_path, trtllm.ModelType.DECODER_ONLY,
                               executor_config)

    new_inputs = [generate_random_tokens(args.rounds) for _ in range(args.n)]
    stats_queue = []

    if executor.can_enqueue_requests():
        ## Process long context to generate kvcache
        context_tokens, _ = read_input_json(args.input_dataset_path, args.n)
        # Enqueue the requests
        request_ids = enqueue_requests(args, executor, context_tokens)
        # Wait for the responses
        output_tokens = wait_for_responses(args, request_ids, executor)
        stats_queue.extend(executor.get_latest_iteration_stats())

        # Start the multi-turn runs
        ## Start timing
        start_time = time.time()
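
        # Each round's prompt is the previous full output (which already contains
        # the earlier context, since exclude_input_from_output=False) followed by
        # that user's new random tokens, so earlier KV-cache blocks can be reused
        # (or onboarded from the host cache) rather than recomputed.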
        for r in range(args.rounds):
            current_input_tokens = [
                output_tokens[i] + new_inputs[i][r] for i in range(args.n)
            ]
            # Enqueue the requests
            request_ids = enqueue_requests(args, executor, current_input_tokens)
            # Wait for the responses
            output_tokens = wait_for_responses(args, request_ids, executor)
            stats_queue.extend(executor.get_latest_iteration_stats())

        ## End timing
        end_time = time.time()
        elapsed_time = (end_time - start_time) * 1000
        print(f"E2E TIME: {elapsed_time:.2f} (ms)")
        print(f"TTFT: {get_TTFT(stats_queue)} (ms)")

        if args.log_iteration_data:
            for stats in stats_queue:
                print(stats.to_json_str())