# Usage: python iteration_log_parser.py --filename <filename> --total_rank <total_rank> --csv
# The log file should be generated by redirecting both the server-side and the
# device-side output to the same file.
# The benchmark scripts live under tensorrt_llm/serve/scripts/.
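# The iteration log line this parser consumes is assumed to look roughly like
# the following (reconstructed from the parsing in IterStateManager.update();
# the values and surrounding text are illustrative, not copied from a real log):
#   [I] iter = 42, rank = 0, host_step_time = 12.34 ms, states = {'num_ctx_tokens': 128, 'num_ctx_requests': 2, 'num_generation_tokens': 0}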
import argparse
import json

class IterState:
    """Per-iteration state aggregated across all ranks."""

    def __init__(self, total_rank, iter_idx):
        self.total_rank = total_rank
        # -1 marks a rank that has not reported this iteration yet.
        self.ctx_num = [-1] * total_rank
        self.gen_num = [-1] * total_rank
        self.request_num = [-1] * total_rank
        self.iter_time = -1
        self.iter_idx = iter_idx

    @property
    def ctx_phase(self):
        # Every rank must have reported before the phase can be decided.
        assert all(i != -1 for i in self.ctx_num), self.ctx_num
        # An iteration counts as context phase if any rank processed context tokens.
        return any(i > 0 for i in self.ctx_num)

    @property
    def gen_phase(self):
        return not self.ctx_phase

    def update(self, rank, state_dict, elapsed_time):
        self.ctx_num[rank] = state_dict["num_ctx_tokens"]
        self.request_num[rank] = state_dict["num_ctx_requests"]
        self.gen_num[rank] = state_dict["num_generation_tokens"]
        # Rank 0's host step time stands in for the whole iteration.
        if rank == 0:
            self.iter_time = elapsed_time

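# A minimal sketch of how one iteration is classified (made-up values):
#   state = IterState(total_rank=2, iter_idx=0)
#   state.update(0, {"num_ctx_tokens": 64, "num_ctx_requests": 1,
#                    "num_generation_tokens": 0}, elapsed_time=10.0)
#   state.update(1, {"num_ctx_tokens": 0, "num_ctx_requests": 0,
#                    "num_generation_tokens": 32}, elapsed_time=9.5)
#   state.ctx_phase  # True: rank 0 processed context tokens
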
class IterStateManager:
    """Collects IterState objects for one benchmark round and prints a summary."""

    def __init__(self, total_rank):
        self.iters: dict[int, IterState] = {}
        self.total_rank = total_rank
        self.benchmark_duration = 0
        self.ttft_median = 0
        self.ttft_p99 = 0

    def update(self, line):
        if "states = " not in line:
            return
        rank = int(line.split("rank =")[1].strip().split(",")[0].strip())
        # The states dict is logged with single quotes; swap them so that
        # json.loads() accepts it.
        states = json.loads(line.split("states =")[1].strip().replace("'", '"'))
        elapsed_time = float(
            line.split("host_step_time =")[1].strip().split("ms,")[0].strip())
        iter_idx = int(line.split("iter =")[1].strip().split(",")[0].strip())
        if iter_idx not in self.iters:
            self.iters[iter_idx] = IterState(self.total_rank, iter_idx)
        self.iters[iter_idx].update(rank, states, elapsed_time)

    def clear(self):
        self.iters = {}
        self.benchmark_duration = 0
        self.ttft_median = 0
        self.ttft_p99 = 0

    def print_result(self, prefix="", csv=False):
        ctx_phase = {"iter_num": 0, "time": 0, "request_num": 0}
        gen_phase = {"iter_num": 0, "time": 0}
        for iter_state in self.iters.values():
            if iter_state.ctx_phase:
                ctx_phase["iter_num"] += 1
                ctx_phase["request_num"] += iter_state.request_num[0]
                ctx_phase["time"] += iter_state.iter_time
            else:
                gen_phase["iter_num"] += 1
                gen_phase["time"] += iter_state.iter_time
        total_time = ctx_phase["time"] + gen_phase["time"]
        if csv:
            # Columns: ctx_iters, ctx_time_s, gen_iters, gen_time_s,
            # total_time_s, ttft_median_ms, ttft_p99_ms
            print(f"{ctx_phase['iter_num']},{ctx_phase['time'] / 1000:.2f},"
                  f"{gen_phase['iter_num']},{gen_phase['time'] / 1000:.2f},"
                  f"{total_time / 1000:.2f},"
                  f"{self.ttft_median:.2f},{self.ttft_p99:.2f}")
        else:
            print(f"{prefix}ctx_phase: {ctx_phase}, gen_phase: {gen_phase}, "
                  f"ctx_phase_ratio: {ctx_phase['time'] / total_time * 100.0:.2f}%, "
                  f"Total time: {total_time}, "
                  f"TTFT_median: {self.ttft_median:.2f}, "
                  f"TTFT_p99: {self.ttft_p99:.2f}")

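# main() keys on the following marker lines; the quoted prefixes are the
# literal substrings matched in the loop below, while the exact surrounding
# text in a real log may differ:
#   "Starting main benchmark run"
#   "[I] iter = ..." (one line per rank per iteration)
#   "Serving Benchmark Result"
#   "Benchmark duration (s): <float>"
#   "Median TTFT (ms): <float>"
#   "P99 TTFT (ms): <float>"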
def main(filename, total_rank, csv):
    with open(filename, "r", encoding="unicode_escape") as f:
        round_idx = 0
        # Parser state: 0 = before the first run, 1 = inside a benchmark run,
        # 2 = the run's "Serving Benchmark Result" block has been reached.
        start = 0
        iter_state_manager = IterStateManager(total_rank)
        for line in f:
            if "Starting main benchmark run" in line:
                if start == 2:
                    # Print the result of the previous round.
                    iter_state_manager.print_result(f"Round: {round_idx}, ", csv)
                    round_idx += 1
                start = 1
                iter_state_manager.clear()
            elif "[I] iter =" in line and start == 1:
                iter_state_manager.update(line)
            elif "Serving Benchmark Result" in line:
                start = 2
            elif "Benchmark duration" in line:
                iter_state_manager.benchmark_duration = float(
                    line.split("Benchmark duration (s): ")[1].strip())
            elif "Median TTFT" in line:
                iter_state_manager.ttft_median = float(
                    line.split("Median TTFT (ms): ")[1].strip())
            elif "P99 TTFT" in line:
                iter_state_manager.ttft_p99 = float(
                    line.split("P99 TTFT (ms): ")[1].strip())

    if start == 2:
        # Print the result of the last round.
        iter_state_manager.print_result(f"Round: {round_idx}, ", csv)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--filename",
                        type=str,
                        required=True,
                        help="Path to the combined server/device log file.")
    parser.add_argument("--total_rank",
                        type=int,
                        default=1,
                        help="Total number of ranks that wrote to the log.")
    parser.add_argument("--csv",
                        action="store_true",
                        help="Print results as CSV instead of free-form text.")
    args = parser.parse_args()
    main(args.filename, args.total_rank, args.csv)