# SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TensorRT LLM perf sanity tests."""

import contextlib
import copy
import glob
import io
import os
import re
import socket
import subprocess
import time
from typing import Dict, List, NamedTuple, Optional, Tuple

import pytest
import requests
import yaml
from test_common.http_utils import wait_for_endpoint_ready

from defs.trt_test_alternative import print_error, print_info
from tensorrt_llm._utils import get_free_port

from ..conftest import get_llm_root, llm_models_root
from .open_search_db_utils import (
    SCENARIO_MATCH_FIELDS,
    add_id,
    check_perf_regression,
    get_common_values,
    get_history_data,
    get_job_info,
    post_new_perf_data,
    prepare_baseline_data,
    prepare_regressive_test_cases,
)
from .utils import collect_and_clean_myelin_time

# Model paths under the local directory synced from the internal LLM models repo
MODEL_PATH_DICT = {
    "deepseek_r1_fp8": "DeepSeek-R1/DeepSeek-R1",
    "deepseek_r1_nvfp4": "DeepSeek-R1/DeepSeek-R1-FP4",
    "deepseek_r1_0528_fp8": "DeepSeek-R1/DeepSeek-R1-0528/",
    "deepseek_r1_0528_fp4": "DeepSeek-R1/DeepSeek-R1-0528-FP4/",
    "deepseek_r1_0528_fp4_v2": "DeepSeek-R1/DeepSeek-R1-0528-FP4-v2/",
    "deepseek_v32_fp4": "DeepSeek-V3.2-Exp-FP4-v2",
    "gpt_oss_120b_fp4": "gpt_oss/gpt-oss-120b",
    "k2_thinking_fp4": "Kimi-K2-Thinking-NVFP4",
}

SUPPORTED_GPU_TYPE = [
    "H200",
    "B200",
    "B300",
    "GB200",
    "GB300",
]

DEFAULT_TIMEOUT = 7200

# Regex patterns for parsing benchmark output metrics
# Key is the metric name used in database (e.g., "mean_e2el", "seq_throughput")
PERF_METRIC_LOG_QUERIES = {
    "seq_throughput": re.compile(r"Request throughput \(req\/s\):\s+(-?[\d\.]+)"),
    "token_throughput": re.compile(r"Output token throughput \(tok\/s\):\s+(-?[\d\.]+)"),
    "total_token_throughput": re.compile(r"Total Token throughput \(tok\/s\):\s+(-?[\d\.]+)"),
    "user_throughput": re.compile(r"User throughput \(tok\/s\):\s+(-?[\d\.]+)"),
    "mean_ttft": re.compile(r"Mean TTFT \(ms\):\s+(-?[\d\.]+)"),
    "median_ttft": re.compile(r"Median TTFT \(ms\):\s+(-?[\d\.]+)"),
    "p99_ttft": re.compile(r"P99 TTFT \(ms\):\s+(-?[\d\.]+)"),
    "mean_itl": re.compile(r"Mean ITL \(ms\):\s+(-?[\d\.]+)"),
    "median_itl": re.compile(r"Median ITL \(ms\):\s+(-?[\d\.]+)"),
    "p99_itl": re.compile(r"P99 ITL \(ms\):\s+(-?[\d\.]+)"),
    "mean_tpot": re.compile(r"Mean TPOT \(ms\):\s+(-?[\d\.]+)"),
    "median_tpot": re.compile(r"Median TPOT \(ms\):\s+(-?[\d\.]+)"),
    "p99_tpot": re.compile(r"P99 TPOT \(ms\):\s+(-?[\d\.]+)"),
    "mean_e2el": re.compile(r"Mean E2EL \(ms\):\s+(-?[\d\.]+)"),
    "median_e2el": re.compile(r"Median E2EL \(ms\):\s+(-?[\d\.]+)"),
    "p99_e2el": re.compile(r"P99 E2EL \(ms\):\s+(-?[\d\.]+)"),
}
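# Illustrative note: a benchmark summary line such as
#   "Request throughput (req/s):              42.57"
# is matched by PERF_METRIC_LOG_QUERIES["seq_throughput"], and float(match.group(1))
# yields 42.57. The exact wording is assumed to follow the summary printed by
# tensorrt_llm.serve.scripts.benchmark_serving.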


def get_model_dir(model_name: str) -> str:
    """Get model directory path from model name."""
    if model_name in MODEL_PATH_DICT:
        return os.path.join(llm_models_root(), MODEL_PATH_DICT[model_name])
    return ""


def get_dataset_path() -> str:
    """Get dataset path for benchmark."""
    return os.path.join(llm_models_root(), "datasets", "ShareGPT_V3_unfiltered_cleaned_split.json")


def to_env_dict(env_vars: str) -> Dict[str, str]:
    """Convert env vars string to dict."""
    env = {}
    for env_var in env_vars.split():
        if "=" in env_var:
            key, value = env_var.split("=", 1)
            env[key] = value
    return env
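# Example (illustrative): to_env_dict("TLLM_LOG_LEVEL=INFO FOO=1") returns
# {"TLLM_LOG_LEVEL": "INFO", "FOO": "1"}; tokens without "=" are ignored.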


def add_host_port_to_cmd(cmd: List[str], host: str, port: int) -> List[str]:
    """Add host and port to command."""
    return cmd + ["--host", host, "--port", str(port)]

class ServerConfig:
    """Configurations of trtllm-serve."""

    def __init__(self, server_config_data: dict, env_vars: str = ""):
        # Extract required fields
        self.concurrency = server_config_data.get("concurrency", 1)
        self.model_name = server_config_data["model_name"]
        self.model_path = ""
        self.env_vars = env_vars
        self.disagg_run_type = server_config_data.get("disagg_run_type", "aggr")

        # Extract optional fields with defaults
        self.tp = server_config_data.get("tensor_parallel_size", 1)
        self.ep = server_config_data.get("moe_expert_parallel_size", 1)
        self.pp = server_config_data.get("pipeline_parallel_size", 1)
        self.cp = server_config_data.get("context_parallel_size", 1)
        self.gpus = server_config_data.get("gpus", self.tp * self.cp * self.pp)
        self.gpus_per_node = server_config_data.get("gpus_per_node", 0) or self.gpus
        self.max_num_tokens = server_config_data.get("max_num_tokens", 2048)
        self.max_batch_size = server_config_data.get("max_batch_size", 512)
        self.max_seq_len = server_config_data.get("max_seq_len", 0)
        self.disable_overlap_scheduler = server_config_data.get("disable_overlap_scheduler", False)
        self.num_postprocess_workers = server_config_data.get("num_postprocess_workers", 0)
        self.stream_interval = server_config_data.get("stream_interval", 10)
        self.attn_backend = server_config_data.get("attn_backend", "TRTLLM")
        self.enable_chunked_prefill = server_config_data.get("enable_chunked_prefill", False)
        self.enable_attention_dp = server_config_data.get("enable_attention_dp", False)
        self.trust_remote_code = server_config_data.get("trust_remote_code", False)
        self.enable_lm_head_tp_in_adp = server_config_data.get("enable_lm_head_tp_in_adp", False)

        # attention_dp_config
        attention_dp_config = server_config_data.get("attention_dp_config", {})
        self.attention_dp_balance = attention_dp_config.get("enable_balance", False)
        self.batching_wait_iters = attention_dp_config.get("batching_wait_iters", 0)
        self.timeout_iters = attention_dp_config.get("timeout_iters", 60)

        # moe_config
        moe_config = server_config_data.get("moe_config", {})
        self.moe_backend = moe_config.get("backend", "")
        self.moe_max_num_tokens = moe_config.get("max_num_tokens", 0)
        self.use_low_precision_moe_combine = moe_config.get("use_low_precision_moe_combine", False)
        load_balancer_config = moe_config.get("load_balancer", {})
        self.load_balancer_num_slots = load_balancer_config.get("num_slots", 0)
        self.load_balancer_layer_updates_per_iter = load_balancer_config.get(
            "layer_updates_per_iter", 0
        )

        # cuda_graph_config
        cuda_graph_config = server_config_data.get("cuda_graph_config", {})
        self.enable_cuda_graph = False
        if cuda_graph_config:
            self.enable_cuda_graph = True
            self.enable_padding = cuda_graph_config.get("enable_padding", True)
            self.cuda_graph_batch_sizes = cuda_graph_config.get("batch_sizes", [])
            self.cuda_graph_max_batch_size = cuda_graph_config.get("max_batch_size", 0)
        else:
            self.enable_padding = True
            self.cuda_graph_batch_sizes = []
            self.cuda_graph_max_batch_size = 0

        # kv_cache_config
        kv_cache_config = server_config_data.get("kv_cache_config", {})
        self.kv_cache_dtype = kv_cache_config.get("dtype", "fp8")
        self.enable_block_reuse = kv_cache_config.get("enable_block_reuse", False)
        self.free_gpu_memory_fraction = kv_cache_config.get("free_gpu_memory_fraction", 0.8)

        # cache_transceiver_config
        cache_transceiver_config = server_config_data.get("cache_transceiver_config", {})
        self.cache_transceiver_backend = cache_transceiver_config.get("backend", "")
        self.cache_transceiver_max_tokens_in_buffer = cache_transceiver_config.get(
            "max_tokens_in_buffer", 0
        )

        # Generate default name if not provided
        self.name = server_config_data.get("name", "")
        if not self.name:
            self.name = (
                f"{self.model_name}_tp{self.tp}_ep{self.ep}_pp{self.pp}_cp{self.cp}"
                f"_bs{self.max_batch_size}_attn{self.attn_backend}_moe{self.moe_backend}"
            )
            if self.cache_transceiver_backend:
                self.name += f"_spec{self.cache_transceiver_backend}"

        # speculative_config
        speculative_config = server_config_data.get("speculative_config", {})
        self.spec_decoding_type = speculative_config.get("decoding_type", "")
        self.num_nextn_predict_layers = speculative_config.get("num_nextn_predict_layers", 0)
        eagle3_value = speculative_config.get("eagle3_layers_to_capture", [])
        if isinstance(eagle3_value, int):
            self.eagle3_layers_to_capture = [eagle3_value]
        elif isinstance(eagle3_value, list):
            self.eagle3_layers_to_capture = eagle3_value
        else:
            self.eagle3_layers_to_capture = []
        self.max_draft_len = speculative_config.get("max_draft_len", 0)
        self.speculative_model_dir = speculative_config.get("speculative_model_dir", "")

        # match_mode: "config" (default) or "scenario"
        self.match_mode = server_config_data.get("match_mode", "config")

        # Store filtered config for extra_llm_api_config
        exclude_keys = [
            "mode",
            "concurrency",
            "name",
            "model_name",
            "disagg_run_type",
            "gpus",
            "gpus_per_node",
            "match_mode",
            "client_configs",
        ]
        self.extra_llm_api_config_data = {
            k: v for k, v in server_config_data.items() if k not in exclude_keys
        }

    def to_cmd(
        self, output_dir: str, numa_bind: bool = False, disagg_serving_type: str = ""
    ) -> List[str]:
        """Generate server command."""
        model_dir = get_model_dir(self.model_name)
        self.model_path = model_dir if os.path.exists(model_dir) else self.model_name
        config_filename = f"extra-llm-api-config.{self.disagg_run_type}.{self.name}.yml"
        config_path = os.path.join(output_dir, config_filename)

        numa_bind_cmd = []
        if numa_bind:
            numa_bind_cmd = ["numactl", "-m 0,1"]

        cmd = numa_bind_cmd + [
            "trtllm-serve",
            self.model_path,
            "--backend",
            "pytorch",
            "--config",
            config_path,
        ]
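        # Illustrative resulting command (paths are examples, not real defaults):
        #   trtllm-serve <model_path> --backend pytorch \
        #       --config <output_dir>/extra-llm-api-config.aggr.<server_name>.yml
        # with an optional "numactl -m 0,1" prefix when numa_bind is True.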
        return cmd

    def to_env(self) -> Dict[str, str]:
        return to_env_dict(self.env_vars)

    def to_match_keys(self) -> List[str]:
        return [
            "s_model_name",
            "l_tp",
            "l_ep",
            "l_pp",
            "l_cp",
            "l_gpus_per_node",
            "l_max_batch_size",
            "b_disable_overlap_scheduler",
            "l_num_postprocess_workers",
            "s_attn_backend",
            "b_enable_chunked_prefill",
            "b_enable_attention_dp",
            "b_enable_lm_head_tp_in_adp",
            # attention_dp_config
            "b_attention_dp_balance",
            # moe_config
            "s_moe_backend",
            # cuda_graph_config
            "b_enable_cuda_graph",
            # kv_cache_config
            "s_kv_cache_dtype",
            # cache_transceiver_config
            "s_cache_transceiver_backend",
            # speculative_config
            "s_spec_decoding_type",
            "l_num_nextn_predict_layers",
        ]

    def to_db_data(self) -> dict:
        """Convert ServerConfig to database data."""
        db_data = {
            "s_server_name": self.name,
            "s_model_name": self.model_name.lower(),
            "l_gpus": self.gpus,
            "l_tp": self.tp,
            "l_ep": self.ep,
            "l_pp": self.pp,
            "l_cp": self.cp,
            "l_gpus_per_node": self.gpus_per_node,
            "l_max_num_tokens": self.max_num_tokens,
            "l_max_batch_size": self.max_batch_size,
            "l_max_seq_len": self.max_seq_len,
            "b_disable_overlap_scheduler": self.disable_overlap_scheduler,
            "l_num_postprocess_workers": self.num_postprocess_workers,
            "l_stream_interval": self.stream_interval,
            "s_attn_backend": self.attn_backend,
            "b_enable_chunked_prefill": self.enable_chunked_prefill,
            "b_enable_attention_dp": self.enable_attention_dp,
            "b_trust_remote_code": self.trust_remote_code,
            "b_enable_lm_head_tp_in_adp": self.enable_lm_head_tp_in_adp,
            # attention_dp_config
            "b_attention_dp_balance": self.attention_dp_balance,
            "l_batching_wait_iters": self.batching_wait_iters,
            "l_timeout_iters": self.timeout_iters,
            # moe_config
            "s_moe_backend": self.moe_backend,
            "l_moe_max_num_tokens": self.moe_max_num_tokens,
            "b_use_low_precision_moe_combine": self.use_low_precision_moe_combine,
            "l_load_balancer_num_slots": self.load_balancer_num_slots,
            "l_load_balancer_layer_updates_per_iter": self.load_balancer_layer_updates_per_iter,
            # cuda_graph_config
            "b_enable_cuda_graph": self.enable_cuda_graph,
            "b_enable_padding": self.enable_padding,
            "l_cuda_graph_max_batch_size": self.cuda_graph_max_batch_size,
            "s_cuda_graph_batch_sizes": ",".join(map(str, self.cuda_graph_batch_sizes)),
            # kv_cache_config
            "s_kv_cache_dtype": self.kv_cache_dtype,
            "b_enable_block_reuse": self.enable_block_reuse,
            "d_free_gpu_memory_fraction": self.free_gpu_memory_fraction,
            # cache_transceiver_config
            "s_cache_transceiver_backend": self.cache_transceiver_backend,
            "l_cache_transceiver_max_tokens_in_buffer": self.cache_transceiver_max_tokens_in_buffer,
            # speculative_config
            "s_spec_decoding_type": self.spec_decoding_type,
            "l_num_nextn_predict_layers": self.num_nextn_predict_layers,
            "s_eagle3_layers_to_capture": ",".join(map(str, self.eagle3_layers_to_capture)),
            "l_max_draft_len": self.max_draft_len,
            "s_speculative_model_dir": self.speculative_model_dir,
            "s_server_log_link": "",
            "s_server_env_var": self.env_vars,
        }
        return db_data

    def generate_extra_llm_api_config(self) -> str:
        """Generate extra-llm-api-config.yml content."""
        config_data = dict(self.extra_llm_api_config_data)

        # Handle speculative_model_dir path conversion
        if (
            "speculative_config" in config_data
            and "speculative_model_dir" in config_data["speculative_config"]
        ):
            spec_model_dir = config_data["speculative_config"]["speculative_model_dir"]
            if spec_model_dir:
                config_data["speculative_config"]["speculative_model_dir"] = os.path.join(
                    llm_models_root(), spec_model_dir
                )

        return yaml.dump(config_data, default_flow_style=False, sort_keys=False)

class ClientConfig:
    """Configurations of benchmark client."""

    def __init__(self, client_config_data: dict, model_name: str, env_vars: str = ""):
        self.model_name = model_name
        self.concurrency = client_config_data.get("concurrency", 1)
        self.iterations = client_config_data.get("iterations", 1)
        self.isl = client_config_data.get("isl", 1024)
        self.osl = client_config_data.get("osl", 1024)
        self.random_range_ratio = client_config_data.get("random_range_ratio", 0.0)
        self.backend = client_config_data.get("backend", "openai")
        self.use_chat_template = client_config_data.get("use_chat_template", False)
        self.streaming = client_config_data.get("streaming", True)
        self.model_path = ""
        self.env_vars = env_vars

        # Generate default name if not provided
        self.name = client_config_data.get("name", "")
        if not self.name:
            self.name = f"con{self.concurrency}_iter{self.iterations}_isl{self.isl}_osl{self.osl}"

    def to_cmd(self) -> List[str]:
        """Generate benchmark command."""
        model_dir = get_model_dir(self.model_name)
        self.model_path = model_dir if os.path.exists(model_dir) else self.model_name
        dataset_path = get_dataset_path()
        benchmark_cmd = [
            "python",
            "-m",
            "tensorrt_llm.serve.scripts.benchmark_serving",
            "--model",
            self.model_path,
            "--tokenizer",
            self.model_path,
            "--dataset-name",
            "random",
            "--random-ids",
            "--num-prompts",
            str(self.concurrency * self.iterations),
            "--max-concurrency",
            str(self.concurrency),
            "--random-input-len",
            str(self.isl),
            "--random-output-len",
            str(self.osl),
            "--random-range-ratio",
            str(self.random_range_ratio),
            "--trust-remote-code",
            "--ignore-eos",
            "--percentile-metrics",
            "ttft,tpot,itl,e2el",
        ]
        if dataset_path and os.path.exists(dataset_path):
            benchmark_cmd.append("--dataset-path")
            benchmark_cmd.append(dataset_path)
        if self.backend:
            benchmark_cmd.append("--backend")
            benchmark_cmd.append(self.backend)
        if self.use_chat_template:
            benchmark_cmd.append("--use-chat-template")
        if not self.streaming:
            benchmark_cmd.append("--non-streaming")
        return benchmark_cmd

    def to_env(self) -> Dict[str, str]:
        return to_env_dict(self.env_vars)

    def to_match_keys(self) -> List[str]:
        return [
            "l_concurrency",
            "l_iterations",
            "l_isl",
            "l_osl",
            "d_random_range_ratio",
            "s_backend",
            "b_use_chat_template",
            "b_streaming",
        ]

    def to_db_data(self) -> dict:
        """Convert ClientConfig to database data."""
        db_data = {
            "s_client_name": self.name,
            "l_concurrency": self.concurrency,
            "l_iterations": self.iterations,
            "l_isl": self.isl,
            "l_osl": self.osl,
            "d_random_range_ratio": self.random_range_ratio,
            "s_backend": self.backend,
            "b_use_chat_template": self.use_chat_template,
            "b_streaming": self.streaming,
            "s_client_log_link": "",
            "s_client_env_vars": self.env_vars,
        }
        if self.backend:
            db_data["s_backend"] = self.backend
        if self.use_chat_template:
            db_data["b_use_chat_template"] = self.use_chat_template
        return db_data

class DisaggConfig:
    """Configurations for disaggregated server."""

    def __init__(
        self,
        name: str,
        disagg_serving_type: str,
        hostname: str,
        numa_bind: bool,
        timeout: int,
        benchmark_mode: str,
        model_name: str,
        hardware: dict,
        server_env_var: str,
    ):
        self.name = name
        self.disagg_serving_type = disagg_serving_type
        self.hostname = hostname
        self.numa_bind = numa_bind
        self.timeout = timeout
        self.benchmark_mode = benchmark_mode
        self.model_name = model_name
        self.hardware = hardware
        self.server_env_var = server_env_var
        self.num_ctx_servers = hardware.get("num_ctx_servers", 0)
        self.num_gen_servers = hardware.get("num_gen_servers", 0)

class AggrTestCmds(NamedTuple):
    """Commands for aggregated server perf sanity tests."""

    server_cmds: List[List[str]]
    client_cmds: Dict[int, List[List[str]]]
    timeout: int
    output_dir: str

    def run_cmd(self, server_idx: int) -> List[str]:
        """Run all clients for a server and return outputs."""
        outputs = []
        server_proc = None
        server_cmd = self.server_cmds[server_idx]

        try:
            server_hostname = "localhost"
            server_port = get_free_port()
            server_cmd_with_port = add_host_port_to_cmd(server_cmd, server_hostname, server_port)

            server_file_path = os.path.join(self.output_dir, f"trtllm-serve.{server_idx}.log")

            print_info(f"Starting server. cmd is {server_cmd_with_port}")
            with open(server_file_path, "w") as server_ctx:
                server_proc = subprocess.Popen(
                    server_cmd_with_port,
                    stdout=server_ctx,
                    stderr=subprocess.STDOUT,
                    env=copy.deepcopy(os.environ),
                )

            wait_for_endpoint_ready(
                f"http://{server_hostname}:{server_port}/health",
                timeout=self.timeout,
                server_proc=server_proc,
            )

            # Run all clients for this server
            for client_idx, client_cmd in enumerate(self.client_cmds[server_idx]):
                client_file_path = os.path.join(
                    self.output_dir, f"trtllm-benchmark.{server_idx}.{client_idx}.log"
                )

                client_cmd_with_port = add_host_port_to_cmd(
                    client_cmd, server_hostname, server_port
                )
                print_info(f"Starting client. cmd is {client_cmd_with_port}")

                output = subprocess.check_output(
                    client_cmd_with_port,
                    stderr=subprocess.STDOUT,
                    env=copy.deepcopy(os.environ),
                ).decode()

                with open(client_file_path, "w") as client_ctx:
                    client_ctx.write(output)

                outputs.append(output)

        finally:
            if server_proc:
                server_proc.terminate()
                server_proc.wait()

        return outputs

    def get_cmd_str(self, server_idx: int) -> List[str]:
        return ["aggr_server tests, please check config files"]

class DisaggTestCmds(NamedTuple):
    """Commands for multi-node disaggregated server perf sanity tests."""

    server_cmds: List[Tuple[List[str], List[str], List[str]]]
    client_cmds: Dict[int, List[List[str]]]
    timeout: int
    hostname: str
    disagg_serving_type: str
    num_ctx_servers: int
    num_gen_servers: int
    output_dir: str

    def _generate_hostname_file(self, server_idx: int, port: int):
        """Create hostname file for coordination."""
        hostnames_dir = os.path.join(self.output_dir, f"hostnames-{server_idx}")
        if not os.path.exists(hostnames_dir):
            os.makedirs(hostnames_dir, exist_ok=True)
        hostname_file = os.path.join(hostnames_dir, f"{self.disagg_serving_type}.txt")
        with open(hostname_file, "w") as f:
            f.write(f"{self.hostname}:{port}")

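    # Coordination note: each CTX_*/GEN_* worker writes "<hostname>:<port>" (e.g. an
    # illustrative "node-001:18321") to hostnames-<server_idx>/<DISAGG_SERVING_TYPE>.txt;
    # the DISAGG_SERVER rank collects these files below to build its YAML config.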
    def _generate_disagg_server_config(self, server_idx: int, disagg_server_port: int) -> str:
        """Generate disagg server config from hostname files."""
        print_info(f"Generating disagg server config for server index {server_idx}")
        hostnames_folder = os.path.join(self.output_dir, f"hostnames-{server_idx}")
        expected_count = self.num_ctx_servers + self.num_gen_servers
        start_time = time.time()
        hostnames = []

        while True:
            elapsed_time = time.time() - start_time
            print_info(
                f"Waiting for hostnames in {hostnames_folder}, "
                f"elapsed time: {elapsed_time}s, current: {len(hostnames)}, "
                f"expected: {expected_count}"
            )
            if elapsed_time > self.timeout:
                print_error(f"Time out. Hostnames files are not ready after {self.timeout}s")
                break
            time.sleep(10)
            if not os.path.exists(hostnames_folder):
                continue
            hostnames = os.listdir(hostnames_folder)
            if len(hostnames) >= expected_count:
                break

        print_info(f"All hostnames found in {hostnames_folder} after elapsed time: {elapsed_time}s")

        # Read ctx and gen hostnames
        ctx_hostnames = []
        gen_hostnames = []
        for hostname_file in hostnames:
            hostname_file_path = os.path.join(hostnames_folder, hostname_file)
            with open(hostname_file_path, "r") as f:
                hostname_port = f.read().strip()
            if hostname_file.startswith("CTX"):
                ctx_hostnames.append(hostname_port)
            elif hostname_file.startswith("GEN"):
                gen_hostnames.append(hostname_port)

        server_config = {
            "hostname": self.hostname,
            "port": disagg_server_port,
            "backend": "pytorch",
            "context_servers": {
                "num_instances": self.num_ctx_servers,
                "urls": ctx_hostnames,
            },
            "generation_servers": {
                "num_instances": self.num_gen_servers,
                "urls": gen_hostnames,
            },
        }
        config_path = os.path.join(self.output_dir, f"server_config.{server_idx}.yaml")
        with open(config_path, "w") as f:
            yaml.dump(server_config, f)
        print_info(f"Server config file {config_path} generated")
        return config_path

    def _get_disagg_server_hostname_and_port(self, server_idx: int) -> Tuple[str, int]:
        """Wait for and read disagg server config."""
        config_path = os.path.join(self.output_dir, f"server_config.{server_idx}.yaml")
        start_time = time.time()
        while True:
            if os.path.exists(config_path):
                print_info(f"Server config file found: {config_path}")
                break
            elapsed_time = time.time() - start_time
            if elapsed_time > self.timeout:
                print_error(f"Server config file {config_path} not found after {self.timeout}s")
                break
            print_info(f"Waiting for server config file, elapsed time: {elapsed_time}s")
            time.sleep(10)

        with open(config_path, "r") as f:
            server_config = yaml.safe_load(f)
        return server_config["hostname"], server_config["port"]

    def wait_for_benchmark_ready(self, benchmark_status_file: str):
        """Wait for benchmark to complete."""
        start_time = time.time()
        while True:
            if os.path.exists(benchmark_status_file):
                print_info(
                    f"Benchmark status file found, terminating server {self.disagg_serving_type}"
                )
                break
            elapsed_time = time.time() - start_time
            print_info(f"Waiting for benchmark status file, elapsed time: {elapsed_time}s")
            if elapsed_time > self.timeout:
                print_error(f"Timeout waiting for benchmark status file after {self.timeout}s")
                break
            time.sleep(10)

    def wait_for_endpoint_ready(self, url: str, server_files: List[str] = None):
        """Wait for endpoint to be ready."""
        start = time.monotonic()
        iteration = 0
        error_keywords = ["RuntimeError", "out of memory", "ValueError"]
        while True:
            iteration += 1
            elapsed_time = time.monotonic() - start
            if elapsed_time > self.timeout:
                print_error(
                    f"Timeout waiting for endpoint {url} to be ready after {self.timeout} seconds"
                )
                break
            print_info(f"Waiting for endpoint {url} to be ready, elapsed time: {elapsed_time}s")

            if server_files and iteration % 30 == 0:
                for server_file in server_files:
                    if os.path.exists(server_file):
                        try:
                            with open(server_file, "r") as f:
                                content = f.read()
                            for line in content.splitlines():
                                for keyword in error_keywords:
                                    if keyword in line:
                                        print_error(
                                            f"Found '{keyword}' in server file {server_file}: {line}"
                                        )
                        except Exception as e:
                            print_info(f"Failed to read server file {server_file}: {e}")

            try:
                time.sleep(10)
                if requests.get(url).status_code == 200:
                    print_info(f"endpoint {url} is ready")
                    return
            except Exception as err:
                print_info(f"endpoint {url} is not ready, with exception: {err}")

    def run_cmd(self, server_idx: int) -> List[str]:
        """Run commands for a server and return outputs."""
        outputs = []
        benchmark_status_file = os.path.join(self.output_dir, f"benchmark_status.{server_idx}.txt")
        port = get_free_port()

        ctx_cmd, gen_cmd, disagg_cmd = self.server_cmds[server_idx]
        if "CTX" in self.disagg_serving_type or "GEN" in self.disagg_serving_type:
            self._generate_hostname_file(server_idx, port)
            server_file_path = os.path.join(
                self.output_dir, f"trtllm-serve.{server_idx}.{self.disagg_serving_type}.log"
            )
            is_ctx = "CTX" in self.disagg_serving_type
            server_cmd = ctx_cmd if is_ctx else gen_cmd
            server_cmd = add_host_port_to_cmd(server_cmd, self.hostname, port)
            server_proc = None
            try:
                print_info(
                    f"Starting server. disagg_serving_type: {self.disagg_serving_type} cmd is {server_cmd}"
                )
                with open(server_file_path, "w") as server_ctx:
                    server_proc = subprocess.Popen(
                        server_cmd,
                        stdout=server_ctx,
                        stderr=subprocess.STDOUT,
                        env=copy.deepcopy(os.environ),
                    )
                    self.wait_for_benchmark_ready(benchmark_status_file)
            finally:
                print_info(f"Server {self.disagg_serving_type} stopped")
                if server_proc:
                    server_proc.terminate()
                    server_proc.wait()

        elif self.disagg_serving_type == "DISAGG_SERVER":
            disagg_server_file_path = os.path.join(
                self.output_dir, f"trtllm-serve.{server_idx}.{self.disagg_serving_type}.log"
            )
            disagg_server_proc = None
            try:
                self._generate_disagg_server_config(server_idx, port)
                print_info(f"Starting disagg server. cmd is {disagg_cmd}")
                with open(disagg_server_file_path, "w") as disagg_server_ctx:
                    disagg_server_proc = subprocess.Popen(
                        disagg_cmd,
                        stdout=disagg_server_ctx,
                        stderr=subprocess.STDOUT,
                        env=copy.deepcopy(os.environ),
                    )
                    self.wait_for_benchmark_ready(benchmark_status_file)
            finally:
                print_info(f"Disagg server {self.disagg_serving_type} stopped")
                if disagg_server_proc:
                    disagg_server_proc.terminate()
                    disagg_server_proc.wait()

        elif self.disagg_serving_type == "BENCHMARK":
            try:
                disagg_server_hostname, disagg_server_port = (
                    self._get_disagg_server_hostname_and_port(server_idx)
                )
                server_files = [
                    os.path.join(self.output_dir, f"trtllm-serve.{server_idx}.DISAGG_SERVER.log"),
                ]
                for ctx_idx in range(self.num_ctx_servers):
                    server_files.append(
                        os.path.join(
                            self.output_dir, f"trtllm-serve.{server_idx}.CTX_{ctx_idx}.log"
                        )
                    )
                for gen_idx in range(self.num_gen_servers):
                    server_files.append(
                        os.path.join(
                            self.output_dir, f"trtllm-serve.{server_idx}.GEN_{gen_idx}.log"
                        )
                    )
                self.wait_for_endpoint_ready(
                    f"http://{disagg_server_hostname}:{disagg_server_port}/health",
                    server_files=server_files,
                )

                # Run all clients for this server
                for client_idx, client_cmd in enumerate(self.client_cmds[server_idx]):
                    benchmark_file_path = os.path.join(
                        self.output_dir, f"trtllm-benchmark.{server_idx}.{client_idx}.log"
                    )

                    client_cmd_with_port = add_host_port_to_cmd(
                        client_cmd, disagg_server_hostname, disagg_server_port
                    )
                    print_info(f"Starting benchmark. cmd is {client_cmd_with_port}")

                    output = subprocess.check_output(
                        client_cmd_with_port,
                        env=copy.deepcopy(os.environ),
                        stderr=subprocess.STDOUT,
                    ).decode()

                    with open(benchmark_file_path, "w") as benchmark_ctx:
                        benchmark_ctx.write(output)
                    outputs.append(output)

            finally:
                with open(benchmark_status_file, "w") as status_file:
                    status_file.write("Done")

        return outputs

    def get_cmd_str(self, server_idx: int) -> List[str]:
        return ["multi-node disaggregated server tests, please check config files"]

def parse_select_pattern(select_pattern: str) -> list:
    """Parse select pattern (server config names).

    Args:
        select_pattern: Comma-separated server config names
            (e.g., "r1_fp4_v2_dep4_mtp1_1k1k,r1_fp4_v2_tep4_mtp3_1k1k,r1_fp4_v2_tp4_mtp3_1k1k").

    Returns:
        List of server config name strings.
    """
    return [name.strip() for name in select_pattern.split(",")]

class PerfSanityTestConfig:
    """Configuration for perf sanity tests."""

    def __init__(self, test_case_name: str, output_dir: str):
        self._output_dir = output_dir
        self._perf_results: Dict[int, List[Dict[str, float]]] = {}

        # Parse test case name
        self.parse_test_case_name(test_case_name)

    def parse_test_case_name(self, test_case_name: str):
        """Parse test case name into components."""
        self._test_param_labels = test_case_name

        # Extract configs from test param labels
        labels = self._test_param_labels.split("-")

        def get_gpu_type() -> str:
            try:
                # GB200 maps to dgx_b200 because dgx_b200 was mistakenly uploaded to OpenSearch in the past.
                mapping = {
                    "GB200": "dgx_b200",
                    "GB300": "gb300",
                    "B200": "b200",
                    "B300": "b300",
                }
                output = subprocess.check_output(
                    "nvidia-smi -q | grep 'Product Name' | head -1",
                    shell=True,
                    stderr=subprocess.DEVNULL,
                    text=True,
                )
                model = output.split()[-1]
                return mapping.get(model, "unsupported")
            except (subprocess.CalledProcessError, FileNotFoundError, IndexError):
                print_error("Failed to get GPU type")
                return "unsupported"

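        # get_gpu_type() keys off the last token of nvidia-smi's "Product Name" line,
        # e.g. "Product Name : NVIDIA B200" -> "B200" -> "b200" (illustrative output;
        # unknown models map to "unsupported").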
        assert len(labels) > 1, "perf_sanity test must have a config file!"
        is_disagg = "disagg" in labels[0]
        self.upload_to_db = "upload" in labels[0]
        self.gpu_type = get_gpu_type()

        if is_disagg:
            # For disagg: disagg_upload-deepseek-r1-fp4_8k1k_ctx1_gen1_dep32_bs128_eplb0_mtp0_ccb-UCX
            self.runtime = "multi_node_disagg_server"
            self.config_dir = "tests/integration/defs/perf/disagg/test_configs/disagg/perf"
            config_base = "-".join(labels[1:])
            self.config_file = (
                f"{config_base}.yaml" if not config_base.endswith(".yaml") else config_base
            )
            self.select_pattern = None
        else:
            # For aggr: aggr_upload-config_yml or aggr_upload-config_yml-server_config_name
            self.runtime = "aggr_server"
            self.config_dir = "tests/scripts/perf-sanity"
            config_base = labels[1]
            self.config_file = (
                f"{config_base}.yaml"
                if config_base and not config_base.endswith(".yaml")
                else config_base
            )
            # select_pattern is server config name (e.g., "r1_fp8_dep8_mtp1_1k1k")
            self.select_pattern = "-".join(labels[2:]) if len(labels) > 2 else None

        self.config_dir = os.getenv(
            "TRTLLM_CONFIG_FOLDER", os.path.join(get_llm_root(), self.config_dir)
        )

        # Initialize server configs
        self.server_configs: List = []
        self.server_client_configs: Dict[int, List[ClientConfig]] = {}

    def parse_config_file(self):
        """Parse config file based on runtime."""
        config_file_path = os.path.join(self.config_dir, self.config_file)

        if self.runtime == "aggr_server":
            self._parse_aggr_config_file(config_file_path)
        elif self.runtime == "multi_node_disagg_server":
            self._parse_disagg_config_file(config_file_path, self.config_file)

    def _parse_aggr_config_file(self, config_file_path: str):
        """Parse YAML config file for aggregated server."""
        # Parse selection pattern (server config names)
        if self.select_pattern:
            selected_server_names = parse_select_pattern(self.select_pattern)
        else:
            selected_server_names = None

        with open(config_file_path, "r") as f:
            config = yaml.safe_load(f)

        metadata = config.get("metadata", {})
        environment = config.get("environment", {})
        hardware = config.get("hardware", {})
        gpus_per_node = hardware.get("gpus_per_node", 0)

        model_name = metadata.get("model_name", "")
        server_env_var = environment.get("server_env_var", "")
        client_env_var = environment.get("client_env_var", "")

        server_configs = []
        server_client_configs = {}

        for server_idx, server_config_data in enumerate(config["server_configs"]):
            # Check if this server should be included based on selected_server_names
            if (
                selected_server_names is not None
                and server_config_data.get("name") not in selected_server_names
            ):
                continue

            server_config_data["model_name"] = (
                model_name
                if "model_name" not in server_config_data
                else server_config_data["model_name"]
            )
            server_config_data["concurrency"] = -1
            server_config_data["gpus_per_node"] = gpus_per_node

            server_config = ServerConfig(server_config_data, server_env_var)
            server_id = len(server_configs)
            server_configs.append(server_config)

            client_configs = []
            for client_config_data in server_config_data["client_configs"]:
                client_config = ClientConfig(
                    client_config_data, server_config_data["model_name"], client_env_var
                )
                client_configs.append(client_config)

            server_client_configs[server_id] = client_configs

        self.server_configs = server_configs
        self.server_client_configs = server_client_configs

    def _parse_disagg_config_file(self, config_file_path: str, config_file: str):
        """Parse YAML config file for disaggregated server."""
        disagg_serving_type = os.environ.get("DISAGG_SERVING_TYPE", "BENCHMARK")

        # Get config file base name (without extension)
        config_file_base_name = os.path.splitext(config_file)[0]

        with open(config_file_path, "r") as f:
            config = yaml.safe_load(f)

        metadata = config.get("metadata", {})
        hardware = config.get("hardware", {})
        benchmark = config.get("benchmark", {})
        environment = config.get("environment", {})
        slurm_config = config.get("slurm", {})
        worker_config = config.get("worker_config", {})

        timeout = slurm_config.get("timeout", DEFAULT_TIMEOUT)
        numa_bind = slurm_config.get("numa_bind", False)
        gpus_per_node = hardware.get("gpus_per_node", 0)
        model_name = metadata.get("model_name", "")
        assert model_name, "model_name is required in metadata section"

        benchmark_mode = benchmark.get("mode", "e2e")
        if "gen_only" in benchmark_mode:
            hardware["num_ctx_servers"] = 0

        worker_env_var = environment.get("worker_env_var", "")
        server_env_var = environment.get("server_env_var", "")
        client_env_var = environment.get("client_env_var", "")

        # Parse concurrency_list - can be string or list
        concurrency_str = benchmark.get("concurrency_list", "1")
        if isinstance(concurrency_str, str):
            concurrency_values = [int(x) for x in concurrency_str.split()]
        elif isinstance(concurrency_str, list):
            concurrency_values = [int(x) for x in concurrency_str]
        else:
            concurrency_values = [int(concurrency_str)]

        # Gen only mode only runs max concurrency
        if "gen_only" in benchmark_mode:
            concurrency_values = [max(concurrency_values)]

        # Create ctx server config
        ctx_server_config_data = {
            "concurrency": max(concurrency_values),
            "name": config_file_base_name,
            "model_name": model_name,
            "gpus_per_node": gpus_per_node,
            "disagg_run_type": "ctx",
            **worker_config.get("ctx", {}),
        }

        # Create gen server config
        gen_server_config_data = {
            "concurrency": max(concurrency_values),
            "name": config_file_base_name,
            "model_name": model_name,
            "gpus_per_node": gpus_per_node,
            "disagg_run_type": "gen",
            **worker_config.get("gen", {}),
        }

        ctx_server_config = ServerConfig(ctx_server_config_data, worker_env_var)
        gen_server_config = ServerConfig(gen_server_config_data, worker_env_var)

        # Create disagg config
        disagg_config = DisaggConfig(
            name=config_file_base_name,
            disagg_serving_type=disagg_serving_type,
            hostname=socket.gethostname(),
            numa_bind=numa_bind,
            timeout=timeout,
            benchmark_mode=benchmark_mode,
            model_name=model_name,
            hardware=hardware,
            server_env_var=server_env_var,
        )

        # server_configs is a list with one element (tuple of ctx, gen, disagg config)
        self.server_configs = [(ctx_server_config, gen_server_config, disagg_config)]

        # Create client configs for each concurrency value
        client_configs = []
        for concurrency in concurrency_values:
            client_config_data = {
                "concurrency": concurrency,
                "iterations": benchmark.get("multi_round", 1),
                "isl": benchmark.get("input_length", 1024),
                "osl": benchmark.get("output_length", 1024),
                "random_range_ratio": benchmark.get("benchmark_ratio", 0.0),
                "backend": "openai",
                "use_chat_template": False,
                "streaming": benchmark.get("streaming", True),
            }
            client_config = ClientConfig(client_config_data, model_name, client_env_var)
            client_configs.append(client_config)

        self.server_client_configs = {0: client_configs}

    def get_commands(self):
        """Get commands based on runtime."""
        perf_sanity_output_dir = os.path.join(self._output_dir, self._test_param_labels)
        os.makedirs(perf_sanity_output_dir, exist_ok=True)

        if self.runtime == "aggr_server":
            return self._get_aggr_commands(perf_sanity_output_dir)
        elif self.runtime == "multi_node_disagg_server":
            return self._get_disagg_commands(perf_sanity_output_dir)

    def _get_aggr_commands(self, output_dir: str):
        """Get commands for aggregated server."""
        server_cmds = []
        client_cmds = {}

        for server_idx, client_configs in self.server_client_configs.items():
            server_config = self.server_configs[server_idx]
            server_cmd = server_config.to_cmd(output_dir)

            # Generate extra-llm-api-config.yml
            config_content = server_config.generate_extra_llm_api_config()
            config_filename = f"extra-llm-api-config.aggr.{server_config.name}.yml"
            config_path = os.path.join(output_dir, config_filename)
            with open(config_path, "w") as f:
                f.write(config_content)

            server_cmds.append(server_cmd)
            client_cmds[server_idx] = []

            for client_config in client_configs:
                client_cmd = client_config.to_cmd()
                client_cmds[server_idx].append(client_cmd)

        return AggrTestCmds(
            server_cmds=server_cmds,
            client_cmds=client_cmds,
            timeout=DEFAULT_TIMEOUT,
            output_dir=output_dir,
        )

    def _get_disagg_commands(self, output_dir: str):
        """Get commands for disaggregated server."""
        server_cmds = []
        client_cmds = {}

        for server_idx, (ctx_config, gen_config, disagg_config) in enumerate(self.server_configs):
            numa_bind = disagg_config.numa_bind
            timeout = disagg_config.timeout
            disagg_serving_type = disagg_config.disagg_serving_type

            # Generate ctx server command
            ctx_cmd = ctx_config.to_cmd(output_dir, numa_bind, "CTX")
            if "CTX" in disagg_serving_type:
                config_content = ctx_config.generate_extra_llm_api_config()
                config_path = os.path.join(
                    output_dir, f"extra-llm-api-config.ctx.{ctx_config.name}.yml"
                )
                with open(config_path, "w") as f:
                    f.write(config_content)

            # Generate gen server command
            gen_cmd = gen_config.to_cmd(output_dir, numa_bind, "GEN")
            if "GEN" in disagg_serving_type:
                config_content = gen_config.generate_extra_llm_api_config()
                config_path = os.path.join(
                    output_dir, f"extra-llm-api-config.gen.{gen_config.name}.yml"
                )
                with open(config_path, "w") as f:
                    f.write(config_content)

            # Generate disagg server command
            disagg_cmd = [
                "trtllm-serve",
                "disaggregated",
                "-c",
                f"{output_dir}/server_config.{server_idx}.yaml",
                "-t",
                str(timeout),
                "-r",
                str(timeout),
            ]

            server_cmds.append((ctx_cmd, gen_cmd, disagg_cmd))

            # Add client commands
            client_cmds[server_idx] = []
            for client_config in self.server_client_configs[server_idx]:
                client_cmd = client_config.to_cmd()
                client_cmds[server_idx].append(client_cmd)

        disagg_config = self.server_configs[0][2]
        return DisaggTestCmds(
            server_cmds=server_cmds,
            client_cmds=client_cmds,
            timeout=disagg_config.timeout,
            hostname=disagg_config.hostname,
            disagg_serving_type=disagg_config.disagg_serving_type,
            num_ctx_servers=disagg_config.num_ctx_servers,
            num_gen_servers=disagg_config.num_gen_servers,
            output_dir=output_dir,
        )

    def run_ex(self, commands) -> Dict[int, List[str]]:
        """Run commands and collect outputs."""
        outputs = {}

        for server_idx in range(len(commands.server_cmds)):
            try:
                with io.StringIO() as buf:
                    with contextlib.redirect_stdout(buf):
                        server_outputs = commands.run_cmd(server_idx)
                        for output in server_outputs:
                            print(collect_and_clean_myelin_time(output))

                    # Check for errors in each output
                    for output in server_outputs:
                        self._check_benchmark_output_for_errors(output)

                    print(buf.getvalue())

                outputs[server_idx] = server_outputs

            except Exception as e:
                print_error(f"Test command failed for server {server_idx}. Error: {e}")
                if isinstance(e, subprocess.CalledProcessError):
                    print_error("--- stdout ---")
                    if e.stdout:
                        print_error(e.stdout.decode() if isinstance(e.stdout, bytes) else e.stdout)
                    print_error("--------------")
                outputs[server_idx] = []

        return outputs

    def _check_benchmark_output_for_errors(self, output: str) -> None:
        """Check whether the benchmark output contains error messages."""
        if not output:
            return

        # Check for non-zero failed requests
        failed_requests_match = re.search(r"Failed requests:\s+(\d+)", output)
        if failed_requests_match:
            failed_count = int(failed_requests_match.group(1))
            if failed_count > 0:
                error_msg = f"Benchmark output contains {failed_count} failed requests."
                raise Exception(error_msg)

        # Check for explicit failure markers
        if "!FAILED REQUESTS!" in output or "!CHECK LOG FOR ERRORS!" in output:
            error_msg = "Benchmark output contains failure markers."
            raise Exception(error_msg)

    def get_perf_result(self, outputs: Dict[int, List[str]]):
        """Parse performance results from outputs."""

        def parse_metrics_from_output(output: str) -> Optional[Dict[str, float]]:
            """Parse all metrics from a single output string."""
            metrics = {}
            for line in output.split("\n"):
                for metric_type, regex in PERF_METRIC_LOG_QUERIES.items():
                    if metric_type in metrics:
                        continue
                    match = regex.search(line)
                    if match:
                        metrics[metric_type] = float(match.group(1))
                        break
            return metrics

        self._perf_results = {}
        for server_idx, client_configs in self.server_client_configs.items():
            self._perf_results[server_idx] = []
            server_outputs = outputs.get(server_idx, [])
            for output in server_outputs:
                metrics = parse_metrics_from_output(output)
                self._perf_results[server_idx].append(metrics)

    def check_test_failure(self):
        """Check if any server failed based on perf results."""
        error_msg = ""
        for server_idx, client_configs in self.server_client_configs.items():
            server_perf_results = self._perf_results.get(server_idx, [])
            if len(server_perf_results) != len(client_configs):
                error_msg += (
                    f"Server {server_idx} has {len(server_perf_results)} perf results "
                    f"but {len(client_configs)} client configs. "
                )
            for client_idx, metrics in enumerate(server_perf_results):
                if len(metrics) != len(PERF_METRIC_LOG_QUERIES):
                    error_msg += (
                        f"Some metrics in Server {server_idx} Client {client_idx} are missing. "
                        f"The parsed metrics are {metrics}. "
                    )

        if error_msg:
            raise Exception(error_msg)

        print_info("All servers passed")

    def upload_test_results_to_database(self):
        """Upload test results and baseline to database."""

        def add_prefix(key: str, prefix_name: str) -> str:
            type_prefix = key[0:2]
            rest = key[2:]
            return f"{type_prefix}{prefix_name}_{rest}"

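        # Example (illustrative): add_prefix("l_tp", "ctx") -> "l_ctx_tp"; the two-character
        # type prefix ("l_", "s_", "b_", "d_") is kept and the role name is inserted after it.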
        def add_list_prefix(config_list: List, prefix_name: str) -> List:
            return [add_prefix(key, prefix_name) for key in config_list]

        def add_dict_prefix(config_dict: dict, prefix_name: str) -> dict:
            return {add_prefix(key, prefix_name): value for key, value in config_dict.items()}

        match_keys = []
        is_scenario_mode = False

        if self.runtime == "aggr_server":
            job_config = get_job_info()
            is_post_merge = job_config["b_is_post_merge"]
            new_data_dict = {}
            cmd_idx = 0
            for server_idx, client_configs in self.server_client_configs.items():
                server_config = self.server_configs[server_idx]
                server_config_dict = server_config.to_db_data()
                server_perf_results = self._perf_results.get(server_idx, [])
                # Skip if server failed
                if len(server_perf_results) != len(client_configs):
                    cmd_idx += len(client_configs)
                    continue

                for client_idx, client_config in enumerate(client_configs):
                    client_config_dict = client_config.to_db_data()

                    # Skip if metrics missing
                    if server_perf_results[client_idx] is None:
                        print_info(
                            f"Skipped posting command {cmd_idx}'s test results since some metrics are missing."
                        )
                        cmd_idx += 1
                        continue

                    new_data = {
                        "s_gpu_type": self.gpu_type,
                        "s_runtime": "multi_node_aggr_server"
                        if server_config.gpus != server_config.gpus_per_node
                        else "aggr_server",
                    }
                    new_data.update(job_config)
                    new_data.update(server_config_dict)
                    new_data.update(client_config_dict)
                    # Add test_case_name for convenient filtering on OpenSearch
                    new_data["s_test_case_name"] = f"{server_config.name}-{client_config.name}"

                    for metric_name in PERF_METRIC_LOG_QUERIES:
                        new_data[f"d_{metric_name}"] = server_perf_results[client_idx][metric_name]

                    add_id(new_data)
                    new_data_dict[cmd_idx] = new_data
                    cmd_idx += 1

                    if not match_keys:
                        if server_config.match_mode == "scenario":
                            match_keys = SCENARIO_MATCH_FIELDS.copy()
                            is_scenario_mode = True
                        else:
                            match_keys.extend(["s_gpu_type", "s_runtime"])
                            match_keys.extend(server_config.to_match_keys())
                            match_keys.extend(client_config.to_match_keys())

        elif self.runtime == "multi_node_disagg_server":
            # Only BENCHMARK node uploads
            if self.server_configs[0][2].disagg_serving_type != "BENCHMARK":
                return

            job_config = get_job_info()
            is_post_merge = job_config["b_is_post_merge"]
            new_data_dict = {}
            cmd_idx = 0

            for server_idx, (ctx_config, gen_config, disagg_config) in enumerate(
                self.server_configs
            ):
                client_configs = self.server_client_configs[server_idx]
                server_perf_results = self._perf_results.get(server_idx, [])
                # Skip if server failed
                if len(server_perf_results) != len(client_configs):
                    cmd_idx += len(client_configs)
                    continue

                for client_idx, client_config in enumerate(client_configs):
                    # Skip if metrics missing
                    if server_perf_results[client_idx] is None:
                        print_info(
                            f"Skipped posting command {cmd_idx}'s test results since some metrics are missing."
                        )
                        cmd_idx += 1
                        continue

                    # Get server configs with prefixed keys
                    ctx_server_config_dict = add_dict_prefix(ctx_config.to_db_data(), "ctx")
                    gen_server_config_dict = add_dict_prefix(gen_config.to_db_data(), "gen")
                    client_config_dict = client_config.to_db_data()

                    num_ctx_servers = disagg_config.num_ctx_servers
                    num_gen_servers = disagg_config.num_gen_servers

                    new_data = {
                        "s_gpu_type": self.gpu_type,
                        "s_runtime": "multi_node_disagg_server",
                        "s_benchmark_mode": disagg_config.benchmark_mode,
                        "s_server_env_var": disagg_config.server_env_var,
                        "l_num_ctx_servers": num_ctx_servers,
                        "l_num_gen_servers": num_gen_servers,
                    }
                    new_data.update(job_config)
                    if num_ctx_servers > 0:
                        new_data.update(ctx_server_config_dict)
                    if num_gen_servers > 0:
                        new_data.update(gen_server_config_dict)
                    new_data.update(client_config_dict)
                    # Add test_case_name for convenient filtering on OpenSearch
                    new_data["s_test_case_name"] = f"{disagg_config.name}-{client_config.name}"

                    for metric_name in PERF_METRIC_LOG_QUERIES:
                        new_data[f"d_{metric_name}"] = server_perf_results[client_idx][metric_name]

                    add_id(new_data)
                    new_data_dict[cmd_idx] = new_data
                    cmd_idx += 1

                    if not match_keys:
                        match_keys.extend(
                            [
                                "s_gpu_type",
                                "s_runtime",
                                "s_benchmark_mode",
                                "l_num_ctx_servers",
                                "l_num_gen_servers",
                            ]
                        )
                        if num_ctx_servers > 0:
                            match_keys.extend(add_list_prefix(ctx_config.to_match_keys(), "ctx"))
                        if num_gen_servers > 0:
                            match_keys.extend(add_list_prefix(gen_config.to_match_keys(), "gen"))
                        match_keys.extend(client_config.to_match_keys())
        else:
            return

        if not new_data_dict:
            print_info("No data to upload to database.")
            return

        # Find common values across all data entries to narrow down query scope
        common_values_dict = get_common_values(new_data_dict, match_keys)

        # Get history data for each cmd_idx
        history_baseline_dict, history_data_dict = get_history_data(
            new_data_dict, match_keys, common_values_dict
        )

        # Update regression info in new_data_dict
        prepare_regressive_test_cases(history_baseline_dict, new_data_dict)

        if is_post_merge:
            # Prepare new baseline data for post-merge
            new_baseline_data_dict = prepare_baseline_data(
                history_baseline_dict, history_data_dict, new_data_dict
            )
        else:
            # Pre-merge does not need to upload baseline data
            new_baseline_data_dict = None

        if self.upload_to_db:
            # Upload the new perf data and baseline data to database
            post_new_perf_data(new_baseline_data_dict, new_data_dict)

        check_perf_regression(new_data_dict, fail_on_regression=is_scenario_mode)

# Perf sanity test case parameters
AGG_TEST_TYPES = ["aggr_upload", "aggr"]
DISAGG_TEST_TYPES = ["disagg_upload", "disagg"]

AGGR_CONFIG_FOLDER = "tests/scripts/perf-sanity"
DISAGG_CONFIG_FOLDER = "tests/integration/defs/perf/disagg/test_configs/disagg/perf"


def get_server_config_names(yaml_path: str) -> List[str]:
    """Read a YAML file and return the list of server_config names."""
    try:
        with open(yaml_path, "r") as f:
            data = yaml.safe_load(f)
        if data and "server_configs" in data:
            return [config.get("name", "") for config in data["server_configs"]]
    except Exception:
        pass
    return []


def get_yaml_files_with_server_names(directory: str) -> Dict[str, List[str]]:
    """Scan directory for YAML files and return dict of {basename: [server_config_names]}."""
    yaml_files = glob.glob(os.path.join(directory, "*.yaml"))
    result = {}
    for yaml_path in sorted(yaml_files):
        basename = os.path.splitext(os.path.basename(yaml_path))[0]
        server_names = get_server_config_names(yaml_path)
        result[basename] = server_names
    return result


def get_aggr_test_cases() -> List[str]:
    """Generate aggr test cases based on actual server_config names in YAML files."""
    llm_root = get_llm_root()
    aggr_config_dir = os.path.join(llm_root, AGGR_CONFIG_FOLDER)
    yaml_server_names = get_yaml_files_with_server_names(aggr_config_dir)

    test_cases = []
    for config_yml, server_names in yaml_server_names.items():
        for test_type in AGG_TEST_TYPES:
            # Case without select_pattern (runs all server configs)
            test_cases.append(f"{test_type}-{config_yml}")

            # Cases with single server config name
            for server_name in server_names:
                test_cases.append(f"{test_type}-{config_yml}-{server_name}")

    return test_cases
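# Example of generated aggr test IDs (illustrative; actual names depend on the YAML files
# and server_config names present), e.g. "aggr-<config_yml>" and
# "aggr_upload-<config_yml>-r1_fp8_dep8_mtp1_1k1k".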


def get_disagg_test_cases() -> List[str]:
    """Generate disagg test cases."""
    llm_root = get_llm_root()
    disagg_config_dir = os.path.join(llm_root, DISAGG_CONFIG_FOLDER)
    yaml_files = glob.glob(os.path.join(disagg_config_dir, "*.yaml"))
    basenames = sorted([os.path.splitext(os.path.basename(f))[0] for f in yaml_files])

    test_cases = []
    for config_yml in basenames:
        for test_type in DISAGG_TEST_TYPES:
            test_cases.append(f"{test_type}-{config_yml}")

    return test_cases


# Hardcoded multi-test test cases from test db.
MULTI_TEST_TEST_CASES = []

# Generate all test case combinations
# For aggr: {test_type}-{config_yml}, {test_type}-{config_yml}-{server_config_name}
# For disagg: {test_type}-{config_yml}
PERF_SANITY_TEST_CASES = get_aggr_test_cases() + get_disagg_test_cases() + MULTI_TEST_TEST_CASES


@pytest.mark.parametrize("perf_sanity_test_case", PERF_SANITY_TEST_CASES)
def test_e2e(output_dir, perf_sanity_test_case):
    # Create config and parse test case name
    config = PerfSanityTestConfig(perf_sanity_test_case, output_dir)

    # Parse config file to get server_configs and server_client_configs
    config.parse_config_file()

    # Get commands
    commands = config.get_commands()

    # Run commands and collect outputs
    outputs = config.run_ex(commands)

    # For disagg mode, only BENCHMARK node parses results and uploads
    if config.runtime == "multi_node_disagg_server":
        disagg_config = config.server_configs[0][2]
        if disagg_config.disagg_serving_type != "BENCHMARK":
            print_info(
                f"Disagg serving type is {disagg_config.disagg_serving_type}, skipping perf result parsing and upload."
            )
            return

    # Parse performance results
    config.get_perf_result(outputs)

    # Check for test failures
    config.check_test_failure()

    # Upload results to database
    config.upload_test_results_to_database()