Mirror of https://github.com/NVIDIA/TensorRT-LLM.git, synced 2026-01-14 06:27:45 +08:00
[https://nvbugs/5552889][fix] fix: Prevent empty batch when using attention DP with disagg (#8372)
Signed-off-by: Patrice Castonguay <55748270+pcastonguay@users.noreply.github.com>
Signed-off-by: Mike Iovine <6158008+mikeiovine@users.noreply.github.com>
parent 9bcd2e6c0a
commit 65c138108e
@@ -821,12 +821,7 @@ class PyExecutor:
                     f'{len(scheduled_batch.generation_requests)} generation requests'
                 )

-                if self.enable_attention_dp:
-                    tp_batch_sizes = self.dist.tp_allgather(
-                        scheduled_batch.batch_size)
-                    can_queue = 0 not in tp_batch_sizes
-                else:
-                    can_queue = scheduled_batch.batch_size > 0
+                can_queue = self._can_queue(scheduled_batch)

                 if not can_queue:
                     self.micro_batches[microbatch_id] = None
@@ -1004,6 +999,16 @@ class PyExecutor:
                 self.send_handles[microbatch_id].wait()
                 self.send_handles[microbatch_id] = None

+    def _can_queue(self, scheduled_batch):
+
+        if self.enable_attention_dp:
+            tp_batch_sizes = self.dist.tp_allgather(scheduled_batch.batch_size)
+            can_queue = 0 not in tp_batch_sizes
+        else:
+            can_queue = scheduled_batch.batch_size > 0
+
+        return can_queue
+
     def _prepare_and_schedule_batch(self):
         new_requests = self._fetch_and_activate_new_requests()
         if self.should_stop_processing:
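The rule the new _can_queue helper enforces is small but easy to miss in the diff: under attention DP every TP rank schedules its own local batch, so the queue/skip decision must be unanimous or the ranks fall out of step on the next collective. The sketch below is illustrative only; a plain list stands in for the result of dist.tp_allgather(scheduled_batch.batch_size), and the function is not code from this patch.

# Minimal sketch, not code from this patch: the executor's dist.tp_allgather is
# replaced by a plain list of per-rank batch sizes so the decision rule is visible.
def can_queue(tp_batch_sizes, enable_attention_dp, local_batch_size):
    if enable_attention_dp:
        # Every rank sees the same gathered sizes, so every rank reaches the same verdict.
        return 0 not in tp_batch_sizes
    return local_batch_size > 0

# Rank 1 scheduled nothing this iteration, so all ranks skip it together instead of
# one rank entering the forward pass with an empty batch.
assert can_queue([3, 0], enable_attention_dp=True, local_batch_size=3) is False
assert can_queue([3, 2], enable_attention_dp=True, local_batch_size=3) is True

The remaining hunks wire this helper into each place the executor previously checked scheduled_batch.batch_size directly, and re-check it after the KV connector may have modified the scheduled batch.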
@@ -1126,8 +1131,8 @@ class PyExecutor:

                 finished_requests = []

-                if scheduled_batch.batch_size > 0 or (
-                        self.enable_attention_dp and self.dist.tp_size > 1):
+                can_queue = self._can_queue(scheduled_batch)
+                if can_queue:
                     if self.kv_cache_transceiver:
                         # For generation requests which have completed KV cache transfer
                         self._prepare_disagg_gen_transmission_complete(
@@ -1139,8 +1144,11 @@ class PyExecutor:

                         self._kv_connector_start_batch(scheduled_batch)

-                if scheduled_batch.batch_size > 0 or (
-                        self.enable_attention_dp and self.dist.tp_size > 1):
+                # if using a kv connector, we need to call can_queue again since scheduled_batch might have changed
+                if self.kv_connector_manager:
+                    can_queue = self._can_queue(scheduled_batch)
+
+                if can_queue:
                     # init_disagg_gen_requests must be before drafter loop, otherwise draft requests do not have initialized matchers.
                     # init_disagg_gen_requests must be before engine forward, where the prev_seq_slot is updated.
                     if self.guided_decoder is not None:
@@ -1298,7 +1306,8 @@ class PyExecutor:

                 self._pause_requests(scheduled_batch.paused_requests)

-                if scheduled_batch.batch_size > 0:
+                can_queue = self._can_queue(scheduled_batch)
+                if can_queue:
                     if self.kv_cache_transceiver:
                         # For generation requests which have completed KV cache transfer
                         self._prepare_disagg_gen_transmission_complete(
@@ -1307,7 +1316,11 @@ class PyExecutor:

                         self._kv_connector_start_batch(scheduled_batch)

-                if scheduled_batch.batch_size > 0:
+                # if using a kv connector, we need to call can_queue again since scheduled_batch might have changed
+                if self.kv_connector_manager:
+                    can_queue = self._can_queue(scheduled_batch)
+
+                if can_queue:

                     # The generation requests that are do not have batch_idx,
                     # needs to be in front of the batch due to the assumptions
@@ -0,0 +1,61 @@
+hostname: localhost
+port: 8000
+model: DeepSeek-V3-Lite/bf16
+backend: "pytorch"
+context_servers:
+  num_instances: 1
+  build_config:
+    max_batch_size: 10
+    max_num_tokens: 512
+    max_seq_len: 768
+  max_batch_size: 10
+  max_num_tokens: 512
+  max_seq_len: 768
+  tensor_parallel_size: 2
+  moe_expert_parallel_size: 2
+  enable_attention_dp: true
+  pipeline_parallel_size: 1
+  print_iter_log: true
+  cuda_graph_config: null
+  disable_overlap_scheduler: true
+  kv_cache_config:
+    enable_block_reuse: false
+    free_gpu_memory_fraction: 0.05
+    max_tokens: 512
+  cache_transceiver_config:
+    max_tokens_in_buffer: 8448
+    backend: DEFAULT
+  urls:
+    - "localhost:8001"
+generation_servers:
+  num_instances: 1
+  build_config:
+    max_batch_size: 1
+    max_num_tokens: 2048
+    max_seq_len: 2560
+  tensor_parallel_size: 1
+  moe_expert_parallel_size: 1
+  enable_attention_dp: false
+  enable_lm_head_tp_in_adp: false
+  pipeline_parallel_size: 1
+  max_batch_size: 1
+  max_num_tokens: 2048
+  max_seq_len: 2560
+  cuda_graph_config:
+    enable_padding: true
+    batch_sizes:
+      - 1
+  print_iter_log: true
+  kv_cache_config:
+    enable_block_reuse: false
+    free_gpu_memory_fraction: 0.7
+    max_tokens: 2560
+  moe_config:
+    backend: CUTLASS
+  cache_transceiver_config:
+    max_tokens_in_buffer: 8448
+    backend: DEFAULT
+  stream_interval: 1
+  num_postprocess_workers: 1
+  urls:
+    - "localhost:8002"
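The new config pins the context servers to attention DP across two TP ranks while the generation server runs a single rank capped at batch size 1, which is the combination the empty-batch fix targets. The snippet below is a quick, illustrative sanity check of those settings; it assumes PyYAML is available and that the YAML above is saved locally under the file name registered in config_map in the next hunk.

# Illustrative check only: assumes PyYAML is installed and the YAML above was saved
# locally as disagg_config_deepseek_v3_lite_empty_batch.yaml (the name referenced by
# the config_map entry added below).
import yaml

with open("disagg_config_deepseek_v3_lite_empty_batch.yaml") as f:
    cfg = yaml.safe_load(f)

ctx, gen = cfg["context_servers"], cfg["generation_servers"]
# Context servers: attention DP across TP=2 ranks, so one rank can be left without
# local work while its peer still has requests, the situation the fix guards against.
assert ctx["enable_attention_dp"] is True and ctx["tensor_parallel_size"] == 2
# Generation server: a single rank capped at batch size 1.
assert gen["max_batch_size"] == 1 and gen["enable_attention_dp"] is False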
@@ -261,6 +261,10 @@ def get_test_config(test_desc, example_dir, test_root):
         (4,
          f"{test_configs_root}/disagg_config_ctxtp1_gentp1_deepseek_v3_lite_one_mtp_ctxpp2_gentp2.yaml"
          ),
+        "deepseek_v3_lite_bf16_empty_batch":
+        (3,
+         f"{test_configs_root}/disagg_config_deepseek_v3_lite_empty_batch.yaml"
+         ),
     }

     if test_desc not in config_map:
@@ -1530,14 +1534,19 @@ def run_disaggregated_benchmark(example_dir,
                                 benchmark_model_root,
                                 shared_gpt_path,
                                 env=None,
-                                cwd=None):
+                                cwd=None,
+                                num_ranks=2,
+                                random_input_len=16,
+                                random_output_len=64,
+                                num_prompts=100,
+                                max_concurrency=32,
+                                skip_warmup=False):
     """Run disaggregated test with given configuration."""
     run_env = env.copy()
     run_env["UCX_TLS"] = "^ib"
-    num_rank = 2
     workers_cmd = [
         'mpirun', '--allow-run-as-root', '--oversubscribe', '-n',
-        str(num_rank), 'trtllm-serve', 'disaggregated_mpi_worker', '-c',
+        str(num_ranks), 'trtllm-serve', 'disaggregated_mpi_worker', '-c',
         config_file
     ]

@@ -1589,15 +1598,15 @@ def run_disaggregated_benchmark(example_dir,
         '--dataset-path',
         shared_gpt_path,
         '--random-input-len',
-        '256',
+        str(random_input_len),
         '--random-output-len',
-        '64',
+        str(random_output_len),
         '--random-prefix-len',
         '0',
         '--num-prompts',
-        '320',
+        str(num_prompts),
         '--max-concurrency',
-        '32',
+        str(max_concurrency),
         '--host',
         'localhost',
         '--port',
@@ -1608,7 +1617,8 @@ def run_disaggregated_benchmark(example_dir,
         'e2el,ttft',
     ]
     # warm up
-    check_call(benchmark_cmd, env=env)
+    if not skip_warmup:
+        check_call(benchmark_cmd, env=env)
     output = check_output(benchmark_cmd, env=env)
     e2el_pattern = r"Median E2EL \(ms\):\s*(\d+\.?\d*)"
     ttft_pattern = r"Median TTFT \(ms\):\s*(\d+\.?\d*)"
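For reference, the Median E2EL / Median TTFT values are pulled out of the benchmark's stdout with the two regular expressions shown above; the snippet below demonstrates that extraction on a made-up output string (not real benchmark output).

# Illustrative only: the sample string is fabricated output, not a real benchmark log.
import re

e2el_pattern = r"Median E2EL \(ms\):\s*(\d+\.?\d*)"
ttft_pattern = r"Median TTFT \(ms\):\s*(\d+\.?\d*)"

sample = "Median TTFT (ms):  98.7\nMedian E2EL (ms):  1543.2\n"
e2el = float(re.search(e2el_pattern, sample).group(1))
ttft = float(re.search(ttft_pattern, sample).group(1))
assert (ttft, e2el) == (98.7, 1543.2)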
@@ -1718,3 +1728,43 @@ def test_disaggregated_benchmark_on_diff_backends(

     assert ucx_e2el > 0 and nixl_e2el > 0 and nixl_e2el < 1.05 * ucx_e2el
     assert ucx_ttft > 0 and nixl_ttft > 0 and nixl_ttft < 1.05 * ucx_ttft
+
+
+@pytest.mark.parametrize("benchmark_model_root", ['DeepSeek-V3-Lite-bf16'],
+                         indirect=True)
+def test_disaggregated_deepseek_v3_lite_bf16_empty_batch(
+        disaggregated_example_root, llm_venv, benchmark_model_root,
+        benchmark_root, shared_gpt_path):
+
+    src_dst_dict = {
+        benchmark_model_root:
+        f"{llm_venv.get_working_directory()}/DeepSeek-V3-Lite/bf16",
+    }
+    for src, dst in src_dst_dict.items():
+        if not os.path.islink(dst):
+            os.makedirs(os.path.dirname(dst), exist_ok=True)
+            os.symlink(src, dst, target_is_directory=True)
+
+    test_desc = "deepseek_v3_lite_bf16_empty_batch"
+    num_ranks, config_file = get_test_config(test_desc,
+                                             disaggregated_example_root,
+                                             os.path.dirname(__file__))
+
+    env = llm_venv._new_env.copy()
+    e2el, ttft = run_disaggregated_benchmark(
+        disaggregated_example_root,
+        config_file,
+        benchmark_root,
+        benchmark_model_root,
+        shared_gpt_path,
+        env=env,
+        cwd=llm_venv.get_working_directory(),
+        num_ranks=num_ranks,
+        num_prompts=10,
+        max_concurrency=10,
+        random_input_len=384,
+        random_output_len=1536,
+        skip_warmup=True)
+    print(f"E2EL: {e2el} ms, TTFT: {ttft} ms")
+
+    assert e2el > 0 and ttft > 0
@@ -158,6 +158,7 @@ l0_dgx_h100:
       - disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_overlap_cuda_graph[DeepSeek-V3-Lite-fp8]
       - disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_bf16_cache_aware_balance[DeepSeek-V3-Lite-bf16]
       - disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_bf16_conditional[DeepSeek-V3-Lite-bf16]
+      - disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_bf16_empty_batch[DeepSeek-V3-Lite-bf16]
       - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_guided_decoding[xgrammar-mtp_nextn=2]
       - disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_ctxpp2_gentp2_one_mtp[DeepSeek-V3-Lite-fp8]
      - disaggregated/test_workers.py::test_workers_conditional_disaggregation_deepseek_v3_lite_bf16[DeepSeek-V3-Lite-bf16]