[https://nvbugs/5820874][fix] Adjust deepgemm tuning buckets to cover a larger num_tokens range (#11259)

Signed-off-by: Chenfei Zhang <chenfeiz@nvidia.com>
Author: chenfeiz0326, 2026-02-05 23:12:38 +08:00 (committed by GitHub)
parent 719e82c429
commit eae480b713
5 changed files with 119 additions and 46 deletions

View File

@@ -1450,17 +1450,17 @@ def _(
def deep_gemm_gen_tuning_buckets(x: int):
buckets = tuple(range(8, 128, 8))
# Clamp x to be between 4096 and 8192.
if x >= 128:
x = min(x, 8192)
x = max(x, 4096)
buckets += tuple(range(128, x, 128))
return buckets
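
For reference, here is the bucket generator above as a standalone sketch with sample values; the demo call at the bottom is illustrative only:

def deep_gemm_gen_tuning_buckets(x: int):
    # Fine-grained buckets for small token counts: 8, 16, ..., 120.
    buckets = tuple(range(8, 128, 8))
    if x >= 128:
        # Clamp so the coarse 128-stride buckets always reach 4096,
        # but never extend past 8192.
        x = min(x, 8192)
        x = max(x, 4096)
        buckets += tuple(range(128, x, 128))
    return buckets

buckets = deep_gemm_gen_tuning_buckets(6144)
print(buckets[:4])  # (8, 16, 24, 32)
print(buckets[-1])  # 6016 -- the largest 128-stride bucket below 6144
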
class fp8SwapABGemmRunner(TunableRunner):
tuning_config = TuningConfig(
dynamic_tensor_specs=(DynamicTensorSpec(
0, 0, deep_gemm_gen_tuning_buckets), ),
tune_max_num_tokens=4096,
)
tuning_config = TuningConfig(dynamic_tensor_specs=(DynamicTensorSpec(
0, 0, deep_gemm_gen_tuning_buckets), ), )
def __init__(self, output_dtype: torch.dtype, disable_ue8m0_cast: bool):
self.output_dtype = output_dtype
@@ -1477,9 +1477,7 @@ class fp8SwapABGemmRunner(TunableRunner):
inputs: List[torch.Tensor],
profile: OptimizationProfile,
) -> List[int]:
# Encode swap_ab as False (0) and True (1). Currently enabled when GEMM m <= 128.
input, _, _ = inputs
return [0, 1] if input.shape[0] <= 128 else [0]
return [0]
def forward(
self,
@@ -1494,8 +1492,7 @@ class fp8SwapABGemmRunner(TunableRunner):
dtype=self.output_dtype,
)
forward_func = deep_gemm.fp8_gemm_ntt if tactic == 1 else deep_gemm.fp8_gemm_nt
forward_func(
deep_gemm.fp8_gemm_nt(
(a, a_sf),
(weight, weight_scale),
output,
@@ -1511,14 +1508,13 @@ def fp8_swap_ab_gemm(
weight_scale: torch.Tensor,
output_dtype: torch.dtype = torch.bfloat16,
disable_ue8m0_cast: bool = False,
tune_max_num_tokens: int = 4096,
) -> torch.Tensor:
tuner = AutoTuner.get()
fp8_swap_ab_gemm_runner = fp8SwapABGemmRunner(
output_dtype,
disable_ue8m0_cast,
)
fp8SwapABGemmRunner.tuning_config.tune_max_num_tokens = tune_max_num_tokens
_, best_tactic = tuner.choose_one(
"trtllm::fp8_swap_ab_gemm",
[fp8_swap_ab_gemm_runner],
@@ -1538,7 +1534,6 @@ def _(
weight_scale: torch.Tensor,
output_dtype: torch.dtype = torch.bfloat16,
disable_ue8m0_cast: bool = False,
tune_max_num_tokens: int = 4096,
) -> torch.Tensor:
return input.new_empty((input.size(0), weight.size(0)), dtype=output_dtype)
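
To see why the wider bucket range matters for workloads like the 6k1k config added later in this commit, a small self-contained sketch follows; the nearest-bucket rounding policy here is an assumption for illustration, not the AutoTuner's actual algorithm:

import bisect

def gen_buckets(x: int):
    # Mirrors deep_gemm_gen_tuning_buckets above (with the post-change clamp).
    buckets = tuple(range(8, 128, 8))
    if x >= 128:
        x = max(min(x, 8192), 4096)
        buckets += tuple(range(128, x, 128))
    return buckets

def nearest_bucket(buckets, num_tokens):
    # Assumed rounding policy: smallest bucket >= num_tokens, else the largest.
    idx = bisect.bisect_left(buckets, num_tokens)
    return buckets[idx] if idx < len(buckets) else buckets[-1]

# Under a hard tune_max_num_tokens=4096, buckets stopped at 3968, so a
# 6144-token shape was profiled far from its real size.
old_buckets = tuple(range(8, 128, 8)) + tuple(range(128, 4096, 128))
print(nearest_bucket(old_buckets, 6144))        # 3968
print(nearest_bucket(gen_buckets(8192), 6144))  # 6144
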

View File

@@ -549,20 +549,26 @@ class AggrTestCmds(NamedTuple):
server_cmd_with_port = add_host_port_to_cmd(server_cmd, server_hostname, server_port)
server_file_path = os.path.join(self.output_dir, f"trtllm-serve.{server_idx}.log")
server_error_file_path = os.path.join(
self.output_dir, f"trtllm-serve.{server_idx}.error.log"
)
print_info(f"Starting server. cmd is {server_cmd_with_port}")
with open(server_file_path, "w") as server_ctx:
with (
open(server_file_path, "w") as server_ctx,
open(server_error_file_path, "w") as server_err_ctx,
):
server_proc = subprocess.Popen(
server_cmd_with_port,
stdout=server_ctx,
stderr=subprocess.STDOUT,
stderr=server_err_ctx,
env=copy.deepcopy(os.environ),
)
wait_for_endpoint_ready(
f"http://{server_hostname}:{server_port}/health",
timeout=self.timeout,
check_files=[server_file_path],
check_files=[server_file_path, server_error_file_path],
server_proc=server_proc,
)
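
The motivation for splitting the streams, shown in isolation with hypothetical paths: with the old stderr=subprocess.STDOUT, server tracebacks were interleaved into stdout, where a scanner looking for a dedicated error file cannot find them:

import os
import subprocess

out_path, err_path = "server.log", "server.error.log"  # hypothetical paths

with open(out_path, "w") as out_f, open(err_path, "w") as err_f:
    proc = subprocess.Popen(
        ["python", "-c", "import sys; print('ready'); print('boom', file=sys.stderr)"],
        stdout=out_f,  # normal serving output
        stderr=err_f,  # tracebacks land here, separately scannable
        env=os.environ.copy(),
    )
    proc.wait()

print(open(err_path).read())  # "boom", isolated from stdout noise
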
@@ -571,20 +577,27 @@ class AggrTestCmds(NamedTuple):
client_file_path = os.path.join(
self.output_dir, f"trtllm-benchmark.{server_idx}.{client_idx}.log"
)
client_error_file_path = os.path.join(
self.output_dir, f"trtllm-benchmark.{server_idx}.{client_idx}.error.log"
)
client_cmd_with_port = add_host_port_to_cmd(
client_cmd, server_hostname, server_port
)
print_info(f"Starting client. cmd is {client_cmd_with_port}")
output = subprocess.check_output(
    client_cmd_with_port,
    stderr=subprocess.STDOUT,
    env=copy.deepcopy(os.environ),
).decode()
result = subprocess.run(
    client_cmd_with_port,
    capture_output=True,
    env=copy.deepcopy(os.environ),
    check=True,
)
output = result.stdout.decode()
with open(client_file_path, "w") as client_ctx:
client_ctx.write(output)
with open(client_error_file_path, "w") as client_err_ctx:
client_err_ctx.write(result.stderr.decode())
outputs.append(output)
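
The switch from check_output to run(capture_output=True, check=True) keeps stderr on its own attribute of the raised exception, which the error handler later in this file prints; a minimal demonstration:

import subprocess

try:
    subprocess.run(
        ["python", "-c", "import sys; sys.exit('client blew up')"],
        capture_output=True,
        check=True,
    )
except subprocess.CalledProcessError as e:
    # check_output(stderr=STDOUT) fused this text into stdout; with
    # capture_output=True it arrives separately.
    print(e.returncode)       # 1
    print(e.stderr.decode())  # client blew up
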
@@ -723,7 +736,10 @@ class DisaggTestCmds(NamedTuple):
if "CTX" in self.disagg_serving_type or "GEN" in self.disagg_serving_type:
self._generate_hostname_file(server_idx, port)
server_file_path = os.path.join(
self.output_dir, f"trtllm-serve.{server_idx}.{self.disagg_serving_type}.log"
self.output_dir, f"trtllm-serve.{self.disagg_serving_type}.{server_idx}.log"
)
server_error_file_path = os.path.join(
self.output_dir, f"trtllm-serve.{self.disagg_serving_type}.{server_idx}.error.log"
)
is_ctx = "CTX" in self.disagg_serving_type
server_cmd = ctx_cmd if is_ctx else gen_cmd
@@ -732,11 +748,14 @@ class DisaggTestCmds(NamedTuple):
print_info(
f"Starting server. disagg_serving_type: {self.disagg_serving_type} cmd is {server_cmd}"
)
with open(server_file_path, "w") as server_ctx:
with (
open(server_file_path, "w") as server_ctx,
open(server_error_file_path, "w") as server_err_ctx,
):
server_proc = subprocess.Popen(
server_cmd,
stdout=server_ctx,
stderr=subprocess.STDOUT,
stderr=server_err_ctx,
env=copy.deepcopy(os.environ),
)
self.wait_for_benchmark_ready(benchmark_status_file)
@@ -747,16 +766,22 @@ class DisaggTestCmds(NamedTuple):
elif self.disagg_serving_type == "DISAGG_SERVER":
disagg_server_file_path = os.path.join(
self.output_dir, f"trtllm-serve.{server_idx}.{self.disagg_serving_type}.log"
self.output_dir, f"trtllm-serve.{self.disagg_serving_type}.{server_idx}.log"
)
disagg_server_error_file_path = os.path.join(
self.output_dir, f"trtllm-serve.{self.disagg_serving_type}.{server_idx}.error.log"
)
try:
self._generate_disagg_server_config(server_idx, port)
print_info(f"Starting disagg server. cmd is {disagg_cmd}")
with open(disagg_server_file_path, "w") as disagg_server_ctx:
with (
open(disagg_server_file_path, "w") as disagg_server_ctx,
open(disagg_server_error_file_path, "w") as disagg_server_err_ctx,
):
disagg_server_proc = subprocess.Popen(
disagg_cmd,
stdout=disagg_server_ctx,
stderr=subprocess.STDOUT,
stderr=disagg_server_err_ctx,
env=copy.deepcopy(os.environ),
)
self.wait_for_benchmark_ready(benchmark_status_file)
@@ -770,21 +795,28 @@ class DisaggTestCmds(NamedTuple):
disagg_server_hostname, disagg_server_port = (
self._get_disagg_server_hostname_and_port(server_idx)
)
server_files = [
    os.path.join(self.output_dir, f"trtllm-serve.{server_idx}.DISAGG_SERVER.log"),
]
for ctx_idx in range(self.num_ctx_servers):
    server_files.append(
        os.path.join(
            self.output_dir, f"trtllm-serve.{server_idx}.CTX_{ctx_idx}.log"
        )
    )
for gen_idx in range(self.num_gen_servers):
    server_files.append(
        os.path.join(
            self.output_dir, f"trtllm-serve.{server_idx}.GEN_{gen_idx}.log"
        )
    )
server_files = (
    [
        os.path.join(
            self.output_dir, f"trtllm-serve.DISAGG_SERVER.{server_idx}.log"
        ),
        os.path.join(
            self.output_dir, f"trtllm-serve.DISAGG_SERVER.{server_idx}.error.log"
        ),
    ]
    + [
        os.path.join(
            self.output_dir, f"trtllm-serve.CTX_{ctx_idx}.{server_idx}.log"
        )
        for ctx_idx in range(self.num_ctx_servers)
    ]
    + [
        os.path.join(
            self.output_dir, f"trtllm-serve.GEN_{gen_idx}.{server_idx}.log"
        )
        for gen_idx in range(self.num_gen_servers)
    ]
)
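
The rename moves the serving type ahead of the server index; expanded for, say, server_idx=0 with two context servers and one generation server (values illustrative), the list now reads:

server_idx, num_ctx_servers, num_gen_servers = 0, 2, 1  # illustrative values
server_files = (
    [f"trtllm-serve.DISAGG_SERVER.{server_idx}.log",
     f"trtllm-serve.DISAGG_SERVER.{server_idx}.error.log"]
    + [f"trtllm-serve.CTX_{i}.{server_idx}.log" for i in range(num_ctx_servers)]
    + [f"trtllm-serve.GEN_{i}.{server_idx}.log" for i in range(num_gen_servers)]
)
print(server_files)
# ['trtllm-serve.DISAGG_SERVER.0.log', 'trtllm-serve.DISAGG_SERVER.0.error.log',
#  'trtllm-serve.CTX_0.0.log', 'trtllm-serve.CTX_1.0.log', 'trtllm-serve.GEN_0.0.log']
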
wait_for_endpoint_ready(
f"http://{disagg_server_hostname}:{disagg_server_port}/health",
timeout=self.timeout,
@@ -796,20 +828,27 @@ class DisaggTestCmds(NamedTuple):
benchmark_file_path = os.path.join(
self.output_dir, f"trtllm-benchmark.{server_idx}.{client_idx}.log"
)
benchmark_error_file_path = os.path.join(
self.output_dir, f"trtllm-benchmark.{server_idx}.{client_idx}.error.log"
)
client_cmd_with_port = add_host_port_to_cmd(
client_cmd, disagg_server_hostname, disagg_server_port
)
print_info(f"Starting benchmark. cmd is {client_cmd_with_port}")
output = subprocess.check_output(
    client_cmd_with_port,
    env=copy.deepcopy(os.environ),
    stderr=subprocess.STDOUT,
).decode()
result = subprocess.run(
    client_cmd_with_port,
    capture_output=True,
    env=copy.deepcopy(os.environ),
    check=True,
)
output = result.stdout.decode()
with open(benchmark_file_path, "w") as benchmark_ctx:
benchmark_ctx.write(output)
with open(benchmark_error_file_path, "w") as benchmark_err_ctx:
benchmark_err_ctx.write(result.stderr.decode())
outputs.append(output)
finally:
@@ -1197,11 +1236,21 @@ class PerfSanityTestConfig:
except Exception as e:
print_error(f"Test command failed for server {server_idx}. Error: {e}")
if isinstance(e, subprocess.CalledProcessError):
print_error("--- stdout ---")
if e.stdout:
print_error(e.stdout.decode() if isinstance(e.stdout, bytes) else e.stdout)
print_error("--------------")
# Print content of trtllm-serve error log files
error_log_pattern = os.path.join(
commands.output_dir, f"trtllm-serve*{server_idx}.error.log"
)
error_log_files = glob.glob(error_log_pattern)
for error_log_file in error_log_files:
if os.path.exists(error_log_file):
print_error(f"--- {error_log_file} ---")
with open(error_log_file, "r") as f:
content = f.read()
if content.strip():
print_error(content)
else:
print_error("(empty)")
print_error("-" * len(f"--- {error_log_file} ---"))
outputs[server_idx] = []
return outputs
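
A plausible reason for the type-before-index rename: the glob above, trtllm-serve*{server_idx}.error.log, only matches names that end with the index. A quick fnmatch check (file names illustrative):

import fnmatch

server_idx = 0
pattern = f"trtllm-serve*{server_idx}.error.log"

new_names = [
    "trtllm-serve.DISAGG_SERVER.0.error.log",
    "trtllm-serve.CTX_1.0.error.log",
    "trtllm-serve.GEN_0.0.error.log",
]
old_name = "trtllm-serve.0.CTX_1.error.log"  # old index-first ordering

print([n for n in new_names if fnmatch.fnmatch(n, pattern)])  # all three match
print(fnmatch.fnmatch(old_name, pattern))                     # False -- missed
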

View File

@@ -20,6 +20,7 @@ l0_dgx_b200_perf_sanity:
# - perf/test_perf_sanity.py::test_e2e[aggr_upload-deepseek_r1_fp8_blackwell-r1_fp8_dep8_mtp1_1k1k] TIMEOUT (90) # failed
- perf/test_perf_sanity.py::test_e2e[aggr_upload-deepseek_r1_fp8_blackwell-r1_fp8_tp8_mtp3_8k1k]
# - perf/test_perf_sanity.py::test_e2e[aggr_upload-deepseek_r1_fp8_blackwell-r1_fp8_dep8_mtp1_8k1k] TIMEOUT (90) # failed
- perf/test_perf_sanity.py::test_e2e[aggr_upload-deepseek_r1_fp8_blackwell-r1_fp8_tp8_6k1k] TIMEOUT (90)
# deepseek_r1_fp4_v2_blackwell
- perf/test_perf_sanity.py::test_e2e[aggr_upload-deepseek_r1_fp4_v2_blackwell-r1_fp4_v2_tp4_mtp3_1k1k]
- perf/test_perf_sanity.py::test_e2e[aggr_upload-deepseek_r1_fp4_v2_blackwell-r1_fp4_v2_tp4_mtp3_8k1k]

View File

@@ -134,3 +134,31 @@ server_configs:
osl: 1024
backend: "openai"
dataset_file: datasets/perf-ci/deepseek_r1-8k1k-20480-ratio-1_for_serve.json
# 6k1k configs - TP8 with TRTLLM, MTP1
- name: "r1_fp8_tp8_6k1k"
model_name: "deepseek_r1_0528_fp8"
tensor_parallel_size: 8
moe_expert_parallel_size: 1
pipeline_parallel_size: 1
max_batch_size: 512
max_num_tokens: 8192
attn_backend: "TRTLLM"
enable_attention_dp: false
moe_config:
backend: 'TRTLLM'
cuda_graph_config:
enable_padding: true
max_batch_size: 64
kv_cache_config:
dtype: 'fp8'
enable_block_reuse: false
free_gpu_memory_fraction: 0.8
client_configs:
- name: "con64_iter10_6k1k"
concurrency: 64
iterations: 10
isl: 6144
osl: 1024
backend: "openai"
random_range_ratio: 0.2
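
Back-of-the-envelope numbers connecting this config to the bucket change above, assuming prefill token counts roughly track the 6144-token ISL:

isl, max_num_tokens = 6144, 8192

old_top_bucket = max(range(128, 4096, 128))  # 3968 under the old 4096 cap
new_top_bucket = max(range(128, 8192, 128))  # 8064 after clamping to 8192

print(isl <= old_top_bucket)  # False: a 6k prefill had no nearby bucket
print(isl <= new_top_bucket)  # True: now profiled close to its real shape
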

View File

@@ -4,7 +4,7 @@ import time
import requests
ERROR_KEYWORDS = ["RuntimeError", "out of memory", "ValueError"]
ERROR_KEYWORDS = ["RuntimeError", "out of memory", "ValueError", "FileNotFoundError"]
def wait_for_endpoint_ready(
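
Since the hunk only shows the keyword list changing, here is a minimal sketch of how a readiness poller can use ERROR_KEYWORDS and the check_files/server_proc arguments seen earlier; it illustrates the pattern and is not the file's actual implementation:

import time
import requests

ERROR_KEYWORDS = ["RuntimeError", "out of memory", "ValueError", "FileNotFoundError"]

def wait_for_endpoint_ready_sketch(url, timeout=600, check_files=(), server_proc=None):
    deadline = time.time() + timeout
    while time.time() < deadline:
        # Fail fast if a monitored log already shows a fatal keyword.
        for path in check_files:
            try:
                text = open(path).read()
            except OSError:
                continue
            for keyword in ERROR_KEYWORDS:
                if keyword in text:
                    raise RuntimeError(f"'{keyword}' found in {path}")
        # Fail fast if the server process died before becoming healthy.
        if server_proc is not None and server_proc.poll() is not None:
            raise RuntimeError("server exited before becoming healthy")
        try:
            if requests.get(url, timeout=5).status_code == 200:
                return
        except requests.RequestException:
            pass
        time.sleep(5)
    raise TimeoutError(f"{url} not ready within {timeout}s")
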