[TRTLLM-6675][infra] Nixl test completion (#6623)
Signed-off-by: Bo Deng <deemod@nvidia.com>
parent 232a39de1f
commit d289d85bff
@@ -581,7 +581,7 @@ def save_to_pytorch_benchmark_format(args: argparse.Namespace,
     pt_records = convert_to_pytorch_benchmark_format(
         args=args,
         metrics={k: [results[k]]
-                 for k in metrics},
+                 for k in metrics if k in results},
         extra_info={
             k: results[k]
             for k in results if k not in metrics and k not in ignored_metrics
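The only functional change in this hunk is the added `if k in results` guard. A minimal, self-contained sketch of why it matters; the `metrics` and `results` values below are hypothetical, not taken from the benchmark script:

# Hypothetical data: one requested metric ("itl") was not produced by this run.
metrics = ["ttft", "e2el", "itl"]
results = {"ttft": 12.3, "e2el": 456.7}

# Old comprehension: {k: [results[k]] for k in metrics} raises KeyError: 'itl'.
# The guarded comprehension simply skips metrics missing from results:
filtered = {k: [results[k]] for k in metrics if k in results}
assert filtered == {"ttft": [12.3], "e2el": [456.7]}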
@@ -453,6 +453,40 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
     MODEL_NAME = "deepseek-ai/DeepSeek-V3-Lite"
     MODEL_PATH = f"{llm_models_root()}/DeepSeek-V3-Lite/bf16"
 
+    def test_nixl_backend(self):
+        ctx_server_config = {
+            "disable_overlap_scheduler": True,
+            "cache_transceiver_config": {
+                "backend": "nixl"
+            }
+        }
+        gen_server_config = {
+            "disable_overlap_scheduler": True,
+            "cache_transceiver_config": {
+                "backend": "nixl"
+            }
+        }
+        disaggregated_server_config = {
+            "hostname": "localhost",
+            "port": 8000,
+            "backend": "pytorch",
+            "context_servers": {
+                "num_instances": 1,
+                "urls": ["localhost:8001"]
+            },
+            "generation_servers": {
+                "num_instances": 1,
+                "urls": ["localhost:8002"]
+            }
+        }
+        with launch_disaggregated_llm(disaggregated_server_config,
+                                      ctx_server_config, gen_server_config,
+                                      self.MODEL_PATH) as llm:
+            task = MMLU(self.MODEL_NAME)
+            task.evaluate(llm)
+            task = GSM8K(self.MODEL_NAME)
+            task.evaluate(llm)
+
     @parametrize_with_ids("overlap_scheduler", [True, False])
     @parametrize_with_ids("mtp_nextn",
                           [0, pytest.param(2, marks=skip_pre_hopper)])
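For orientation only, a rough sketch of how a `disaggregated_server_config` dict like the one above can be written out and served. The YAML-dump-plus-`trtllm-serve disaggregated -c` flow mirrors the benchmark helper added later in this commit; the file name here is made up, and the accuracy test itself hands the dicts straight to `launch_disaggregated_llm` instead.

import yaml

disaggregated_server_config = {
    "hostname": "localhost",
    "port": 8000,
    "backend": "pytorch",
    "context_servers": {"num_instances": 1, "urls": ["localhost:8001"]},
    "generation_servers": {"num_instances": 1, "urls": ["localhost:8002"]},
}
# Hypothetical file name; the tests in this commit use a temporary directory.
with open("disagg_config.yaml", "w", encoding="utf-8") as f:
    yaml.dump(disaggregated_server_config, f)
# Then, as in run_disaggregated_benchmark below:
#   trtllm-serve disaggregated --server_start_timeout 900 -c disagg_config.yaml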
@@ -550,6 +584,42 @@ class TestQwen3_8B(LlmapiAccuracyTestHarness):
     MODEL_NAME = "Qwen3/Qwen3-8B"
     MODEL_PATH = f"{llm_models_root()}/Qwen3/Qwen3-8B-FP8"
 
+    def test_nixl_backend(self):
+        ctx_server_config = {
+            "disable_overlap_scheduler": True,
+            "cache_transceiver_config": {
+                "backend": "nixl"
+            }
+        }
+        gen_server_config = {
+            "disable_overlap_scheduler": True,
+            "cache_transceiver_config": {
+                "backend": "nixl"
+            }
+        }
+        ctx_server_config["cache_transceiver_config"]
+        ctx_server_config["cache_transceiver_config"]
+        disaggregated_server_config = {
+            "hostname": "localhost",
+            "port": 8000,
+            "backend": "pytorch",
+            "context_servers": {
+                "num_instances": 1,
+                "urls": ["localhost:8001"]
+            },
+            "generation_servers": {
+                "num_instances": 1,
+                "urls": ["localhost:8002"]
+            }
+        }
+        with launch_disaggregated_llm(disaggregated_server_config,
+                                      ctx_server_config, gen_server_config,
+                                      self.MODEL_PATH) as llm:
+            task = MMLU(self.MODEL_NAME)
+            task.evaluate(llm)
+            task = GSM8K(self.MODEL_NAME)
+            task.evaluate(llm)
+
     @pytest.mark.parametrize("overlap_scheduler", [False, True])
     def test_auto_dtype(self, overlap_scheduler):
         ctx_server_config = {
@@ -14,11 +14,14 @@
 # limitations under the License.
 
 import os
 import re
 import subprocess
 import tempfile
 
 import pytest
-from defs.conftest import skip_arm, skip_no_hopper
-from defs.trt_test_alternative import check_call, popen
+import yaml
+from defs.conftest import llm_models_root, skip_arm, skip_no_hopper
+from defs.trt_test_alternative import check_call, check_output, popen
+
+from tensorrt_llm.logger import logger
 
@@ -1051,3 +1054,227 @@ def test_disaggregated_deepseek_v3_lite_fp8_tp1_two_mtp(
                            "deepseek_v3_lite_fp8_tp1_two_mtp",
                            env=llm_venv._new_env,
                            cwd=llm_venv.get_working_directory())
+
+
+@pytest.fixture(scope="module")
+def benchmark_root():
+    llm_root = os.getenv("LLM_ROOT")
+    return os.path.join(llm_root, "tensorrt_llm", "serve", "scripts")
+
+
+@pytest.fixture(scope="module")
+def shared_gpt_path():
+    DEFAULT_LLM_MODEL_ROOT = os.path.join("/scratch.trt_llm_data", "llm-models")
+    LLM_MODELS_ROOT = os.environ.get("LLM_MODELS_ROOT", DEFAULT_LLM_MODEL_ROOT)
+    return os.path.join(LLM_MODELS_ROOT, "datasets",
+                        "ShareGPT_V3_unfiltered_cleaned_split.json")
+
+
+@pytest.fixture(scope="function")
+def benchmark_model_root(request):
+    models_root = llm_models_root()
+    if (request.param == "DeepSeek-V3-Lite-fp8"):
+        model_path = os.path.join(models_root, "DeepSeek-V3-Lite", "fp8")
+    elif (request.param == "DeepSeek-V3-Lite-bf16"):
+        model_path = os.path.join(models_root, "DeepSeek-V3-Lite", "bf16")
+    elif request.param == "llama-v3-8b-hf":
+        model_path = os.path.join(models_root, "llama-models-v3", "8B")
+    elif request.param == "llama-3.1-8b-instruct-hf-fp8":
+        model_path = os.path.join(models_root, "llama-3.1-model",
+                                  "Llama-3.1-8B-Instruct-FP8")
+    else:
+        raise ValueError(f"Failed to find the model: {request.param}")
+    return model_path
+
+
+def run_disaggregated_benchmark(example_dir,
+                                config_file,
+                                benchmark_root,
+                                benchmark_model_root,
+                                shared_gpt_path,
+                                env=None,
+                                cwd=None):
+    """Run disaggregated test with given configuration."""
+    run_env = env.copy()
+    run_env["UCX_TLS"] = "^ib"
+    num_rank = 2
+    workers_cmd = [
+        'mpirun', '--allow-run-as-root', '--oversubscribe', '-n',
+        str(num_rank), 'trtllm-serve', 'disaggregated_mpi_worker', '-c',
+        config_file
+    ]
+
+    server_start_timeout = 900
+    server_cmd = [
+        'trtllm-serve', 'disaggregated', '--server_start_timeout',
+        str(server_start_timeout), '-c', config_file
+    ]
+    try:
+        with (  # Start workers
+                open('output_workers.log', 'w') as output_workers,
+                popen(workers_cmd,
+                      stdout=output_workers,
+                      stderr=subprocess.STDOUT,
+                      env=run_env,
+                      cwd=cwd) as workers_proc,
+                # Start server
+                open('output_disagg.log', 'w') as output_disagg,
+                popen(server_cmd,
+                      stdout=output_disagg,
+                      stderr=subprocess.STDOUT,
+                      env=run_env,
+                      cwd=cwd) as server_proc):
+            # Ensure the server has started
+            client_dir = f"{example_dir}/clients"
+            client_cmd = [
+                'python3', f'{client_dir}/disagg_client.py', '-c',
+                f'{example_dir}/disagg_config.yaml', '-p',
+                f'{client_dir}/prompts.json', '--ignore-eos',
+                '--server-start-timeout',
+                str(server_start_timeout)
+            ]
+            # Warm up
+            check_call(client_cmd,
+                       env=env,
+                       poll_procs=[workers_proc, server_proc])
+            # Start benchmark
+            benchmark_script = os.path.join(benchmark_root,
+                                            "benchmark_serving.py")
+            benchmark_cmd = [
+                'python3',
+                benchmark_script,
+                '--model',
+                benchmark_model_root,
+                '--tokenizer',
+                benchmark_model_root,
+                '--dataset-name',
+                'random',
+                '--dataset-path',
+                shared_gpt_path,
+                '--random-input-len',
+                '256',
+                '--random-output-len',
+                '64',
+                '--random-prefix-len',
+                '0',
+                '--num-prompts',
+                '320',
+                '--max-concurrency',
+                '32',
+                '--host',
+                'localhost',
+                '--port',
+                '8000',
+                '--ignore-eos',
+                '--no-test-input',
+                '--percentile-metrics',
+                'e2el,ttft',
+            ]
+            # Warm up, then collect the measured output
+            check_call(benchmark_cmd, env=env)
+            output = check_output(benchmark_cmd, env=env)
+            e2el_pattern = r"Median E2EL \(ms\):\s*(\d+\.?\d*)"
+            ttft_pattern = r"Median TTFT \(ms\):\s*(\d+\.?\d*)"
+            e2el_match = re.search(e2el_pattern, output)
+            ttft_match = re.search(ttft_pattern, output)
+            if e2el_match and ttft_match:
+                median_e2el = float(e2el_match.group(1))
+                median_ttft = float(ttft_match.group(1))
+                return median_e2el, median_ttft
+            else:
+                raise ValueError("No benchmark result found")
+
+    except Exception:
+        # Print outputs on error
+        logger.error("-------- Workers output --------")
+        with open('output_workers.log', 'r') as f:
+            logger.error(f.read())
+
+        logger.error("-------- Disagg server output --------")
+        with open('output_disagg.log', 'r') as f:
+            logger.error(f.read())
+        raise
+    finally:
+        server_proc.terminate()
+        workers_proc.terminate()
+        server_proc.wait()
+        workers_proc.wait()
+
+
+def get_config_for_benchmark(model_root, backend):
+    serve_config = {
+        "model": model_root,
+        "hostname": "localhost",
+        "port": 8000,
+        "backend": "pytorch",
+        "context_servers": {
+            "num_instances": 1,
+            "max_batch_size": 2,
+            "max_num_tokens": 384,
+            "max_seq_len": 320,
+            "tensor_parallel_size": 1,
+            "pipeline_parallel_size": 1,
+            "disable_overlap_scheduler": True,
+            "cache_transceiver_config": {
+                "backend": backend,
+                "max_tokens_in_buffer": 512,
+            },
+            "urls": ["localhost:8001"]
+        },
+        "generation_servers": {
+            "num_instances": 1,
+            "tensor_parallel_size": 1,
+            "pipeline_parallel_size": 1,
+            "max_batch_size": 2,
+            "max_num_tokens": 384,
+            "max_seq_len": 320,
+            "cache_transceiver_config": {
+                "backend": backend,
+                "max_tokens_in_buffer": 512,
+            },
+            "urls": ["localhost:8002"]
+        }
+    }
+    return serve_config
+
+
+@pytest.mark.parametrize("benchmark_model_root", [
+    'DeepSeek-V3-Lite-fp8', 'DeepSeek-V3-Lite-bf16', 'llama-v3-8b-hf',
+    'llama-3.1-8b-instruct-hf-fp8'
+],
+                         indirect=True)
+def test_disaggregated_benchmark_on_diff_backends(
+        disaggregated_test_root, disaggregated_example_root, llm_venv,
+        benchmark_model_root, benchmark_root, shared_gpt_path):
+    nixl_config = get_config_for_benchmark(benchmark_model_root, "nixl")
+    ucx_config = get_config_for_benchmark(benchmark_model_root, "ucx")
+    temp_dir = tempfile.TemporaryDirectory()
+    nixl_config_path = os.path.join(temp_dir.name, "nixl_config.yaml")
+    ucx_config_path = os.path.join(temp_dir.name, "ucx_config.yaml")
+    with open(nixl_config_path, 'w', encoding='utf-8') as f:
+        yaml.dump(nixl_config, f)
+    with open(ucx_config_path, 'w', encoding='utf-8') as f:
+        yaml.dump(ucx_config, f)
+
+    env = llm_venv._new_env.copy()
+    nixl_e2el, nixl_ttft = run_disaggregated_benchmark(
+        disaggregated_example_root,
+        nixl_config_path,
+        benchmark_root,
+        benchmark_model_root,
+        shared_gpt_path,
+        env=env,
+        cwd=llm_venv.get_working_directory())
+    ucx_e2el, ucx_ttft = run_disaggregated_benchmark(
+        disaggregated_example_root,
+        ucx_config_path,
+        benchmark_root,
+        benchmark_model_root,
+        shared_gpt_path,
+        env=env,
+        cwd=llm_venv.get_working_directory())
+    print(f"Nixl E2EL: {nixl_e2el} ms, UCX E2EL: {ucx_e2el} ms")
+    print(f"Nixl TTFT: {nixl_ttft} ms, UCX TTFT: {ucx_ttft} ms")
+
+    assert ucx_e2el > 0 and nixl_e2el > 0 and nixl_e2el < 1.05 * ucx_e2el
+    assert ucx_ttft > 0 and nixl_ttft > 0 and nixl_ttft < 1.05 * ucx_ttft
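A short sketch of the metric-parsing step above, runnable on its own; the sample output lines are assumed to follow the format implied by the regexes, not copied from a real benchmark_serving.py run.

import re

# Assumed example of the benchmark summary lines the regexes expect.
sample_output = """
Median TTFT (ms):            45.67
Median E2EL (ms):          1234.5
"""
e2el_pattern = r"Median E2EL \(ms\):\s*(\d+\.?\d*)"
ttft_pattern = r"Median TTFT \(ms\):\s*(\d+\.?\d*)"
median_e2el = float(re.search(e2el_pattern, sample_output).group(1))
median_ttft = float(re.search(ttft_pattern, sample_output).group(1))
assert (median_e2el, median_ttft) == (1234.5, 45.67)

The closing assertions then require both backends to report positive medians and let NIXL pass only if its median E2EL and TTFT stay within 5% of the UCX numbers.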
@@ -545,6 +545,8 @@ accuracy/test_llm_api_pytorch.py::TestPhi4MM::test_auto_dtype
 accuracy/test_llm_api_pytorch.py::TestPhi4MM::test_auto_dtype_long_rope
 accuracy/test_llm_api_pytorch.py::TestPhi4MiniInstruct::test_auto_dtype
 accuracy/test_llm_api_pytorch.py::TestEXAONE4::test_auto_dtype
+accuracy/test_disaggregated_serving.py::TestQwen3_8B::test_nixl_backend
+accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_nixl_backend
 
 test_e2e.py::test_llama_e2e[use_cpp_session-remove_input_padding-]
 test_e2e.py::test_llama_e2e[use_py_session-remove_input_padding-]
@@ -103,6 +103,8 @@ accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[latency_moe_trtl
 accuracy/test_llm_api_pytorch.py::TestQwen3_235B_A22B::test_nvfp4[latency_moe_trtllm_eagle3]
 accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_fp8_block_scales[latency]
 accuracy/test_llm_api_pytorch.py::TestPhi4MiniInstruct::test_auto_dtype
+accuracy/test_disaggregated_serving.py::TestQwen3_8B::test_nixl_backend
+accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_nixl_backend
 disaggregated/test_disaggregated.py::test_disaggregated_cache_aware_balance[TinyLlama-1.1B-Chat-v1.0]
 disaggregated/test_disaggregated.py::test_disaggregated_cuda_graph[TinyLlama-1.1B-Chat-v1.0]
 disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_attention_dp_one_mtp[DeepSeek-V3-Lite-fp8]
@@ -70,3 +70,9 @@ l0_dgx_b200:
 - accuracy/test_llm_api_pytorch.py::TestLlama4ScoutInstruct::test_fp4[tp4-cuda_graph=True]
 - disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_ucx[DeepSeek-V3-Lite-fp8]
 - disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_nixl[DeepSeek-V3-Lite-fp8]
+- disaggregated/test_disaggregated.py::test_disaggregated_benchmark_on_diff_backends[DeepSeek-V3-Lite-bf16]
+- disaggregated/test_disaggregated.py::test_disaggregated_benchmark_on_diff_backends[llama-v3-8b-hf]
+- disaggregated/test_disaggregated.py::test_disaggregated_benchmark_on_diff_backends[llama-3.1-8b-instruct-hf-fp8]
+- disaggregated/test_disaggregated.py::test_disaggregated_benchmark_on_diff_backends[DeepSeek-V3-Lite-fp8]
+- accuracy/test_disaggregated_serving.py::TestQwen3_8B::test_nixl_backend
+- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_nixl_backend
@@ -50,6 +50,8 @@ l0_dgx_h100:
 - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[GSM8K-gen_tp=2-ctx_pp=2]
 - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[MMLU-gen_tp=1-ctx_pp=2]
 - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[MMLU-gen_tp=2-ctx_pp=2]
+- accuracy/test_disaggregated_serving.py::TestQwen3_8B::test_nixl_backend
+- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_nixl_backend
 - test_e2e.py::test_ptp_quickstart_advanced_bs1
 - test_e2e.py::test_ptp_quickstart_advanced_deepseek_v3_lite_4gpus_adp_balance[DeepSeek-V3-Lite-FP8-DeepSeek-V3-Lite/fp8]
 - condition:
@@ -107,6 +109,10 @@ l0_dgx_h100:
 - disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_mpi[DeepSeek-V3-Lite-fp8]
 - disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_ucx[DeepSeek-V3-Lite-fp8]
 - disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_nixl[DeepSeek-V3-Lite-fp8]
+- disaggregated/test_disaggregated.py::test_disaggregated_benchmark_on_diff_backends[DeepSeek-V3-Lite-bf16]
+- disaggregated/test_disaggregated.py::test_disaggregated_benchmark_on_diff_backends[llama-v3-8b-hf]
+- disaggregated/test_disaggregated.py::test_disaggregated_benchmark_on_diff_backends[llama-3.1-8b-instruct-hf-fp8]
+- disaggregated/test_disaggregated.py::test_disaggregated_benchmark_on_diff_backends[DeepSeek-V3-Lite-fp8]
 - disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_attention_dp[DeepSeek-V3-Lite-fp8]
 - disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_attention_dp_overlap[DeepSeek-V3-Lite-fp8]
 - disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_attention_dp_one[DeepSeek-V3-Lite-fp8]