[https://nvbugs/5820874][fix] Adjust deepgemm tuning buckets to cover larger num_tokens's scope (#11259)
Signed-off-by: Chenfei Zhang <chenfeiz@nvidia.com>
parent 719e82c429
commit eae480b713
@@ -1450,17 +1450,17 @@ def _(
 def deep_gemm_gen_tuning_buckets(x: int):
     buckets = tuple(range(8, 128, 8))
+    # Clamp x to be between 4096 and 8192.
     if x >= 128:
+        x = min(x, 8192)
+        x = max(x, 4096)
         buckets += tuple(range(128, x, 128))
     return buckets


 class fp8SwapABGemmRunner(TunableRunner):
-    tuning_config = TuningConfig(
-        dynamic_tensor_specs=(DynamicTensorSpec(
-            0, 0, deep_gemm_gen_tuning_buckets), ),
-        tune_max_num_tokens=4096,
-    )
+    tuning_config = TuningConfig(dynamic_tensor_specs=(DynamicTensorSpec(
+        0, 0, deep_gemm_gen_tuning_buckets), ), )

     def __init__(self, output_dtype: torch.dtype, disable_ue8m0_cast: bool):
         self.output_dtype = output_dtype
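
The clamp guarantees that the generated buckets always extend to at least 4096 and never past 8192, whatever value the tuner passes in; that is the larger num_tokens scope the commit title refers to. A standalone sketch (it reuses only the function shown above; the asserted values follow from range() semantics):

    def deep_gemm_gen_tuning_buckets(x: int):
        buckets = tuple(range(8, 128, 8))
        # Clamp x to be between 4096 and 8192.
        if x >= 128:
            x = min(x, 8192)
            x = max(x, 4096)
            buckets += tuple(range(128, x, 128))
        return buckets

    # A small x still yields buckets up to just below 4096 ...
    assert deep_gemm_gen_tuning_buckets(256)[-1] == 3968
    # ... and a huge x is capped just below 8192.
    assert deep_gemm_gen_tuning_buckets(100_000)[-1] == 8064
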
@@ -1477,9 +1477,7 @@ class fp8SwapABGemmRunner(TunableRunner):
         inputs: List[torch.Tensor],
         profile: OptimizationProfile,
     ) -> List[int]:
-        # Encode swap_ab as False (0) and True (1). Currently enabled when GEMM m <= 128.
-        input, _, _ = inputs
-        return [0, 1] if input.shape[0] <= 128 else [0]
+        return [0]

     def forward(
         self,
@@ -1494,8 +1492,7 @@ class fp8SwapABGemmRunner(TunableRunner):
             dtype=self.output_dtype,
         )

-        forward_func = deep_gemm.fp8_gemm_ntt if tactic == 1 else deep_gemm.fp8_gemm_nt
-        forward_func(
+        deep_gemm.fp8_gemm_nt(
             (a, a_sf),
             (weight, weight_scale),
             output,
@@ -1511,14 +1508,13 @@ def fp8_swap_ab_gemm(
     weight_scale: torch.Tensor,
     output_dtype: torch.dtype = torch.bfloat16,
     disable_ue8m0_cast: bool = False,
-    tune_max_num_tokens: int = 4096,
 ) -> torch.Tensor:
     tuner = AutoTuner.get()
     fp8_swap_ab_gemm_runner = fp8SwapABGemmRunner(
         output_dtype,
         disable_ue8m0_cast,
     )
-    fp8SwapABGemmRunner.tuning_config.tune_max_num_tokens = tune_max_num_tokens

     _, best_tactic = tuner.choose_one(
         "trtllm::fp8_swap_ab_gemm",
         [fp8_swap_ab_gemm_runner],
@@ -1538,7 +1534,6 @@ def _(
     weight_scale: torch.Tensor,
     output_dtype: torch.dtype = torch.bfloat16,
     disable_ue8m0_cast: bool = False,
-    tune_max_num_tokens: int = 4096,
 ) -> torch.Tensor:
     return input.new_empty((input.size(0), weight.size(0)), dtype=output_dtype)

@@ -549,20 +549,26 @@ class AggrTestCmds(NamedTuple):
         server_cmd_with_port = add_host_port_to_cmd(server_cmd, server_hostname, server_port)

         server_file_path = os.path.join(self.output_dir, f"trtllm-serve.{server_idx}.log")
+        server_error_file_path = os.path.join(
+            self.output_dir, f"trtllm-serve.{server_idx}.error.log"
+        )

         print_info(f"Starting server. cmd is {server_cmd_with_port}")
-        with open(server_file_path, "w") as server_ctx:
+        with (
+            open(server_file_path, "w") as server_ctx,
+            open(server_error_file_path, "w") as server_err_ctx,
+        ):
             server_proc = subprocess.Popen(
                 server_cmd_with_port,
                 stdout=server_ctx,
-                stderr=subprocess.STDOUT,
+                stderr=server_err_ctx,
                 env=copy.deepcopy(os.environ),
             )

             wait_for_endpoint_ready(
                 f"http://{server_hostname}:{server_port}/health",
                 timeout=self.timeout,
-                check_files=[server_file_path],
+                check_files=[server_file_path, server_error_file_path],
                 server_proc=server_proc,
             )
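
For context, the rewritten with-block uses Python 3.10+ parenthesized context managers to split the server's stdout and stderr into separate files instead of merging stderr into stdout. A minimal sketch of the pattern, with a hypothetical shell command standing in for the real trtllm-serve invocation:

    import subprocess

    with (
        open("server.log", "w") as out_ctx,
        open("server.error.log", "w") as err_ctx,
    ):
        proc = subprocess.Popen(
            ["sh", "-c", "echo ready; echo boom >&2"],  # hypothetical stand-in
            stdout=out_ctx,  # regular output goes to the .log file
            stderr=err_ctx,  # tracebacks land in the .error.log file
        )
        proc.wait()
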
@@ -571,20 +577,27 @@ class AggrTestCmds(NamedTuple):
         client_file_path = os.path.join(
             self.output_dir, f"trtllm-benchmark.{server_idx}.{client_idx}.log"
         )
+        client_error_file_path = os.path.join(
+            self.output_dir, f"trtllm-benchmark.{server_idx}.{client_idx}.error.log"
+        )

         client_cmd_with_port = add_host_port_to_cmd(
             client_cmd, server_hostname, server_port
         )
         print_info(f"Starting client. cmd is {client_cmd_with_port}")

-        output = subprocess.check_output(
+        result = subprocess.run(
             client_cmd_with_port,
-            stderr=subprocess.STDOUT,
+            capture_output=True,
             env=copy.deepcopy(os.environ),
-        ).decode()
+            check=True,
+        )
+        output = result.stdout.decode()

         with open(client_file_path, "w") as client_ctx:
             client_ctx.write(output)
+        with open(client_error_file_path, "w") as client_err_ctx:
+            client_err_ctx.write(result.stderr.decode())

         outputs.append(output)
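
Swapping subprocess.check_output for subprocess.run(capture_output=True, check=True) keeps the raise-on-nonzero-exit behavior while exposing stdout and stderr separately, both on the result and on the CalledProcessError. A sketch under the same assumptions, again with a hypothetical command:

    import subprocess

    try:
        result = subprocess.run(
            ["sh", "-c", "echo metrics; echo warning >&2"],  # hypothetical
            capture_output=True,  # stdout and stderr captured separately as bytes
            check=True,           # raises CalledProcessError on a non-zero exit code
        )
        output = result.stdout.decode()
        errors = result.stderr.decode()
    except subprocess.CalledProcessError as e:
        print(e.stdout, e.stderr)  # both captured streams survive on the exception
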
@@ -723,7 +736,10 @@ class DisaggTestCmds(NamedTuple):
         if "CTX" in self.disagg_serving_type or "GEN" in self.disagg_serving_type:
             self._generate_hostname_file(server_idx, port)
             server_file_path = os.path.join(
-                self.output_dir, f"trtllm-serve.{server_idx}.{self.disagg_serving_type}.log"
+                self.output_dir, f"trtllm-serve.{self.disagg_serving_type}.{server_idx}.log"
             )
+            server_error_file_path = os.path.join(
+                self.output_dir, f"trtllm-serve.{self.disagg_serving_type}.{server_idx}.error.log"
+            )
             is_ctx = "CTX" in self.disagg_serving_type
             server_cmd = ctx_cmd if is_ctx else gen_cmd
@@ -732,11 +748,14 @@ class DisaggTestCmds(NamedTuple):
             print_info(
                 f"Starting server. disagg_serving_type: {self.disagg_serving_type} cmd is {server_cmd}"
             )
-            with open(server_file_path, "w") as server_ctx:
+            with (
+                open(server_file_path, "w") as server_ctx,
+                open(server_error_file_path, "w") as server_err_ctx,
+            ):
                 server_proc = subprocess.Popen(
                     server_cmd,
                     stdout=server_ctx,
-                    stderr=subprocess.STDOUT,
+                    stderr=server_err_ctx,
                     env=copy.deepcopy(os.environ),
                 )
                 self.wait_for_benchmark_ready(benchmark_status_file)
@@ -747,16 +766,22 @@ class DisaggTestCmds(NamedTuple):

         elif self.disagg_serving_type == "DISAGG_SERVER":
             disagg_server_file_path = os.path.join(
-                self.output_dir, f"trtllm-serve.{server_idx}.{self.disagg_serving_type}.log"
+                self.output_dir, f"trtllm-serve.{self.disagg_serving_type}.{server_idx}.log"
             )
+            disagg_server_error_file_path = os.path.join(
+                self.output_dir, f"trtllm-serve.{self.disagg_serving_type}.{server_idx}.error.log"
+            )
             try:
                 self._generate_disagg_server_config(server_idx, port)
                 print_info(f"Starting disagg server. cmd is {disagg_cmd}")
-                with open(disagg_server_file_path, "w") as disagg_server_ctx:
+                with (
+                    open(disagg_server_file_path, "w") as disagg_server_ctx,
+                    open(disagg_server_error_file_path, "w") as disagg_server_err_ctx,
+                ):
                     disagg_server_proc = subprocess.Popen(
                         disagg_cmd,
                         stdout=disagg_server_ctx,
-                        stderr=subprocess.STDOUT,
+                        stderr=disagg_server_err_ctx,
                         env=copy.deepcopy(os.environ),
                     )
                     self.wait_for_benchmark_ready(benchmark_status_file)
@@ -770,21 +795,28 @@ class DisaggTestCmds(NamedTuple):
             disagg_server_hostname, disagg_server_port = (
                 self._get_disagg_server_hostname_and_port(server_idx)
             )
-            server_files = [
-                os.path.join(self.output_dir, f"trtllm-serve.{server_idx}.DISAGG_SERVER.log"),
-            ]
-            for ctx_idx in range(self.num_ctx_servers):
-                server_files.append(
-                    os.path.join(
-                        self.output_dir, f"trtllm-serve.{server_idx}.CTX_{ctx_idx}.log"
-                    )
-                )
-            for gen_idx in range(self.num_gen_servers):
-                server_files.append(
-                    os.path.join(
-                        self.output_dir, f"trtllm-serve.{server_idx}.GEN_{gen_idx}.log"
-                    )
-                )
+            server_files = (
+                [
+                    os.path.join(
+                        self.output_dir, f"trtllm-serve.DISAGG_SERVER.{server_idx}.log"
+                    ),
+                    os.path.join(
+                        self.output_dir, f"trtllm-serve.DISAGG_SERVER.{server_idx}.error.log"
+                    ),
+                ]
+                + [
+                    os.path.join(
+                        self.output_dir, f"trtllm-serve.CTX_{ctx_idx}.{server_idx}.log"
+                    )
+                    for ctx_idx in range(self.num_ctx_servers)
+                ]
+                + [
+                    os.path.join(
+                        self.output_dir, f"trtllm-serve.GEN_{gen_idx}.{server_idx}.log"
+                    )
+                    for gen_idx in range(self.num_gen_servers)
+                ]
+            )
             wait_for_endpoint_ready(
                 f"http://{disagg_server_hostname}:{disagg_server_port}/health",
                 timeout=self.timeout,
@@ -796,20 +828,27 @@ class DisaggTestCmds(NamedTuple):
             benchmark_file_path = os.path.join(
                 self.output_dir, f"trtllm-benchmark.{server_idx}.{client_idx}.log"
             )
+            benchmark_error_file_path = os.path.join(
+                self.output_dir, f"trtllm-benchmark.{server_idx}.{client_idx}.error.log"
+            )

             client_cmd_with_port = add_host_port_to_cmd(
                 client_cmd, disagg_server_hostname, disagg_server_port
             )
             print_info(f"Starting benchmark. cmd is {client_cmd_with_port}")

-            output = subprocess.check_output(
+            result = subprocess.run(
                 client_cmd_with_port,
+                capture_output=True,
                 env=copy.deepcopy(os.environ),
-                stderr=subprocess.STDOUT,
-            ).decode()
+                check=True,
+            )
+            output = result.stdout.decode()

             with open(benchmark_file_path, "w") as benchmark_ctx:
                 benchmark_ctx.write(output)
+            with open(benchmark_error_file_path, "w") as benchmark_err_ctx:
+                benchmark_err_ctx.write(result.stderr.decode())
             outputs.append(output)

         finally:
@@ -1197,11 +1236,21 @@ class PerfSanityTestConfig:

                 except Exception as e:
                     print_error(f"Test command failed for server {server_idx}. Error: {e}")
-                    if isinstance(e, subprocess.CalledProcessError):
-                        print_error("--- stdout ---")
-                        if e.stdout:
-                            print_error(e.stdout.decode() if isinstance(e.stdout, bytes) else e.stdout)
-                        print_error("--------------")
+                    # Print content of trtllm-serve error log files
+                    error_log_pattern = os.path.join(
+                        commands.output_dir, f"trtllm-serve*{server_idx}.error.log"
+                    )
+                    error_log_files = glob.glob(error_log_pattern)
+                    for error_log_file in error_log_files:
+                        if os.path.exists(error_log_file):
+                            print_error(f"--- {error_log_file} ---")
+                            with open(error_log_file, "r") as f:
+                                content = f.read()
+                            if content.strip():
+                                print_error(content)
+                            else:
+                                print_error("(empty)")
+                            print_error("-" * len(f"--- {error_log_file} ---"))
                     outputs[server_idx] = []

         return outputs
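
The glob pattern f"trtllm-serve*{server_idx}.error.log" is loose enough to match both the aggregated and the disaggregated naming schemes introduced earlier in this diff. An illustration with hypothetical paths:

    import glob
    import os

    output_dir = "/tmp/perf"  # hypothetical
    server_idx = 0
    pattern = os.path.join(output_dir, f"trtllm-serve*{server_idx}.error.log")
    # Matches, for example:
    #   /tmp/perf/trtllm-serve.0.error.log                (aggregated server)
    #   /tmp/perf/trtllm-serve.DISAGG_SERVER.0.error.log  (disagg frontend)
    #   /tmp/perf/trtllm-serve.CTX_1.0.error.log          (disagg context worker)
    for error_log_file in sorted(glob.glob(pattern)):
        print(error_log_file)
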
@@ -20,6 +20,7 @@ l0_dgx_b200_perf_sanity:
   # - perf/test_perf_sanity.py::test_e2e[aggr_upload-deepseek_r1_fp8_blackwell-r1_fp8_dep8_mtp1_1k1k] TIMEOUT (90) # failed
   - perf/test_perf_sanity.py::test_e2e[aggr_upload-deepseek_r1_fp8_blackwell-r1_fp8_tp8_mtp3_8k1k]
   # - perf/test_perf_sanity.py::test_e2e[aggr_upload-deepseek_r1_fp8_blackwell-r1_fp8_dep8_mtp1_8k1k] TIMEOUT (90) # failed
+  - perf/test_perf_sanity.py::test_e2e[aggr_upload-deepseek_r1_fp8_blackwell-r1_fp8_tp8_6k1k] TIMEOUT (90)
   # deepseek_r1_fp4_v2_blackwell
   - perf/test_perf_sanity.py::test_e2e[aggr_upload-deepseek_r1_fp4_v2_blackwell-r1_fp4_v2_tp4_mtp3_1k1k]
   - perf/test_perf_sanity.py::test_e2e[aggr_upload-deepseek_r1_fp4_v2_blackwell-r1_fp4_v2_tp4_mtp3_8k1k]
@@ -134,3 +134,31 @@ server_configs:
         osl: 1024
         backend: "openai"
         dataset_file: datasets/perf-ci/deepseek_r1-8k1k-20480-ratio-1_for_serve.json
+
+  # 6k1k configs - TP8 with TRTLLM, MTP1
+  - name: "r1_fp8_tp8_6k1k"
+    model_name: "deepseek_r1_0528_fp8"
+    tensor_parallel_size: 8
+    moe_expert_parallel_size: 1
+    pipeline_parallel_size: 1
+    max_batch_size: 512
+    max_num_tokens: 8192
+    attn_backend: "TRTLLM"
+    enable_attention_dp: false
+    moe_config:
+      backend: 'TRTLLM'
+    cuda_graph_config:
+      enable_padding: true
+      max_batch_size: 64
+    kv_cache_config:
+      dtype: 'fp8'
+      enable_block_reuse: false
+      free_gpu_memory_fraction: 0.8
+    client_configs:
+      - name: "con64_iter10_6k1k"
+        concurrency: 64
+        iterations: 10
+        isl: 6144
+        osl: 1024
+        backend: "openai"
+        random_range_ratio: 0.2
@@ -4,7 +4,7 @@ import time

 import requests

-ERROR_KEYWORDS = ["RuntimeError", "out of memory", "ValueError"]
+ERROR_KEYWORDS = ["RuntimeError", "out of memory", "ValueError", "FileNotFoundError"]


 def wait_for_endpoint_ready(
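
Presumably these keywords are what wait_for_endpoint_ready scans the check_files for, so the new "FileNotFoundError" entry lets startup fail fast when, for example, a dataset or config path is missing. A hedged sketch of such a scan (scan_log_for_errors is a hypothetical name, not the real helper):

    ERROR_KEYWORDS = ["RuntimeError", "out of memory", "ValueError", "FileNotFoundError"]

    def scan_log_for_errors(path: str) -> list[str]:
        # Return every error keyword that appears in the log file.
        with open(path, "r", errors="replace") as f:
            content = f.read()
        return [kw for kw in ERROR_KEYWORDS if kw in content]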