[https://nvbugs/5552889][fix] fix: Prevent empty batch when using attention DP with disagg (#8372)

Signed-off-by: Patrice Castonguay <55748270+pcastonguay@users.noreply.github.com>
Signed-off-by: Mike Iovine <6158008+mikeiovine@users.noreply.github.com>
Patrice Castonguay 2025-10-15 21:11:04 -04:00 committed by Yanchao Lu
parent 9bcd2e6c0a
commit 65c138108e
4 changed files with 145 additions and 20 deletions


@ -821,12 +821,7 @@ class PyExecutor:
f'{len(scheduled_batch.generation_requests)} generation requests'
)
if self.enable_attention_dp:
tp_batch_sizes = self.dist.tp_allgather(
scheduled_batch.batch_size)
can_queue = 0 not in tp_batch_sizes
else:
can_queue = scheduled_batch.batch_size > 0
can_queue = self._can_queue(scheduled_batch)
if not can_queue:
self.micro_batches[microbatch_id] = None
@ -1004,6 +999,16 @@ class PyExecutor:
self.send_handles[microbatch_id].wait()
self.send_handles[microbatch_id] = None
def _can_queue(self, scheduled_batch):
if self.enable_attention_dp:
tp_batch_sizes = self.dist.tp_allgather(scheduled_batch.batch_size)
can_queue = 0 not in tp_batch_sizes
else:
can_queue = scheduled_batch.batch_size > 0
return can_queue
def _prepare_and_schedule_batch(self):
new_requests = self._fetch_and_activate_new_requests()
if self.should_stop_processing:
@ -1126,8 +1131,8 @@ class PyExecutor:
finished_requests = []
if scheduled_batch.batch_size > 0 or (
self.enable_attention_dp and self.dist.tp_size > 1):
can_queue = self._can_queue(scheduled_batch)
if can_queue:
if self.kv_cache_transceiver:
# For generation requests which have completed KV cache transfer
self._prepare_disagg_gen_transmission_complete(
@ -1139,8 +1144,11 @@ class PyExecutor:
self._kv_connector_start_batch(scheduled_batch)
if scheduled_batch.batch_size > 0 or (
self.enable_attention_dp and self.dist.tp_size > 1):
# If using a KV connector, we need to call _can_queue again since scheduled_batch might have changed
if self.kv_connector_manager:
can_queue = self._can_queue(scheduled_batch)
if can_queue:
# init_disagg_gen_requests must run before the drafter loop; otherwise draft requests do not have initialized matchers.
# init_disagg_gen_requests must also run before the engine forward, where the prev_seq_slot is updated.
if self.guided_decoder is not None:
@ -1298,7 +1306,8 @@ class PyExecutor:
self._pause_requests(scheduled_batch.paused_requests)
if scheduled_batch.batch_size > 0:
can_queue = self._can_queue(scheduled_batch)
if can_queue:
if self.kv_cache_transceiver:
# For generation requests which have completed KV cache transfer
self._prepare_disagg_gen_transmission_complete(
@ -1307,7 +1316,11 @@ class PyExecutor:
self._kv_connector_start_batch(scheduled_batch)
if scheduled_batch.batch_size > 0:
# If using a KV connector, we need to call _can_queue again since scheduled_batch might have changed
if self.kv_connector_manager:
can_queue = self._can_queue(scheduled_batch)
if can_queue:
# The generation requests that do not have a batch_idx
# need to be at the front of the batch due to the assumptions
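
To illustrate the behavior the new helper centralizes, below is a minimal, self-contained sketch of the same decision logic. The Batch dataclass and FakeDist helper are invented stand-ins for illustration only; the real executor uses its scheduled batch and self.dist.tp_allgather as in the hunks above. With attention DP enabled, every TP rank gathers all ranks' batch sizes and queues the micro batch only if no rank is empty, so all ranks take the same branch and stay in lockstep for the subsequent collectives.

from dataclasses import dataclass
from typing import List


@dataclass
class Batch:
    batch_size: int


class FakeDist:
    """Illustrative stand-in for the TP communicator."""

    def __init__(self, per_rank_batch_sizes: List[int]):
        self._sizes = per_rank_batch_sizes

    def tp_allgather(self, value: int) -> List[int]:
        # A real allgather returns the value contributed by every TP rank;
        # here we simply return the preconfigured per-rank batch sizes.
        return list(self._sizes)


def can_queue(batch: Batch, dist: FakeDist, enable_attention_dp: bool) -> bool:
    # Mirrors the branch logic of PyExecutor._can_queue from the diff above.
    if enable_attention_dp:
        tp_batch_sizes = dist.tp_allgather(batch.batch_size)
        return 0 not in tp_batch_sizes
    return batch.batch_size > 0


# Rank 1 has no scheduled requests, so every rank agrees to skip this iteration
# instead of one rank entering a collective the others never reach.
dist = FakeDist(per_rank_batch_sizes=[4, 0])
assert can_queue(Batch(batch_size=4), dist, enable_attention_dp=True) is False
assert can_queue(Batch(batch_size=4), dist, enable_attention_dp=False) is True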


@ -0,0 +1,61 @@
hostname: localhost
port: 8000
model: DeepSeek-V3-Lite/bf16
backend: "pytorch"
context_servers:
  num_instances: 1
  build_config:
    max_batch_size: 10
    max_num_tokens: 512
    max_seq_len: 768
  max_batch_size: 10
  max_num_tokens: 512
  max_seq_len: 768
  tensor_parallel_size: 2
  moe_expert_parallel_size: 2
  enable_attention_dp: true
  pipeline_parallel_size: 1
  print_iter_log: true
  cuda_graph_config: null
  disable_overlap_scheduler: true
  kv_cache_config:
    enable_block_reuse: false
    free_gpu_memory_fraction: 0.05
    max_tokens: 512
  cache_transceiver_config:
    max_tokens_in_buffer: 8448
    backend: DEFAULT
  urls:
    - "localhost:8001"
generation_servers:
  num_instances: 1
  build_config:
    max_batch_size: 1
    max_num_tokens: 2048
    max_seq_len: 2560
  tensor_parallel_size: 1
  moe_expert_parallel_size: 1
  enable_attention_dp: false
  enable_lm_head_tp_in_adp: false
  pipeline_parallel_size: 1
  max_batch_size: 1
  max_num_tokens: 2048
  max_seq_len: 2560
  cuda_graph_config:
    enable_padding: true
    batch_sizes:
      - 1
  print_iter_log: true
  kv_cache_config:
    enable_block_reuse: false
    free_gpu_memory_fraction: 0.7
    max_tokens: 2560
  moe_config:
    backend: CUTLASS
  cache_transceiver_config:
    max_tokens_in_buffer: 8448
    backend: DEFAULT
  stream_interval: 1
  num_postprocess_workers: 1
  urls:
    - "localhost:8002"


@ -261,6 +261,10 @@ def get_test_config(test_desc, example_dir, test_root):
(4,
f"{test_configs_root}/disagg_config_ctxtp1_gentp1_deepseek_v3_lite_one_mtp_ctxpp2_gentp2.yaml"
),
"deepseek_v3_lite_bf16_empty_batch":
(3,
f"{test_configs_root}/disagg_config_deepseek_v3_lite_empty_batch.yaml"
),
}
if test_desc not in config_map:
@ -1530,14 +1534,19 @@ def run_disaggregated_benchmark(example_dir,
benchmark_model_root,
shared_gpt_path,
env=None,
cwd=None):
cwd=None,
num_ranks=2,
random_input_len=16,
random_output_len=64,
num_prompts=100,
max_concurrency=32,
skip_warmup=False):
"""Run disaggregated test with given configuration."""
run_env = env.copy()
run_env["UCX_TLS"] = "^ib"
num_rank = 2
workers_cmd = [
'mpirun', '--allow-run-as-root', '--oversubscribe', '-n',
str(num_rank), 'trtllm-serve', 'disaggregated_mpi_worker', '-c',
str(num_ranks), 'trtllm-serve', 'disaggregated_mpi_worker', '-c',
config_file
]
@ -1589,15 +1598,15 @@ def run_disaggregated_benchmark(example_dir,
'--dataset-path',
shared_gpt_path,
'--random-input-len',
'256',
str(random_input_len),
'--random-output-len',
'64',
str(random_output_len),
'--random-prefix-len',
'0',
'--num-prompts',
'320',
str(num_prompts),
'--max-concurrency',
'32',
str(max_concurrency),
'--host',
'localhost',
'--port',
@ -1608,7 +1617,8 @@ def run_disaggregated_benchmark(example_dir,
'e2el,ttft',
]
# warm up
check_call(benchmark_cmd, env=env)
if not skip_warmup:
check_call(benchmark_cmd, env=env)
output = check_output(benchmark_cmd, env=env)
e2el_pattern = r"Median E2EL \(ms\):\s*(\d+\.?\d*)"
ttft_pattern = r"Median TTFT \(ms\):\s*(\d+\.?\d*)"
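
The helper then scrapes the median latencies from the benchmark output with the two regexes above. A small parsing sketch on a fabricated output fragment (the sample string is made up for illustration):

import re

e2el_pattern = r"Median E2EL \(ms\):\s*(\d+\.?\d*)"
ttft_pattern = r"Median TTFT \(ms\):\s*(\d+\.?\d*)"

# Fabricated benchmark output for illustration only.
sample_output = "Median TTFT (ms): 123.4\nMedian E2EL (ms): 567.8"

e2el = float(re.search(e2el_pattern, sample_output).group(1))
ttft = float(re.search(ttft_pattern, sample_output).group(1))
assert (e2el, ttft) == (567.8, 123.4)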
@ -1718,3 +1728,43 @@ def test_disaggregated_benchmark_on_diff_backends(
assert ucx_e2el > 0 and nixl_e2el > 0 and nixl_e2el < 1.05 * ucx_e2el
assert ucx_ttft > 0 and nixl_ttft > 0 and nixl_ttft < 1.05 * ucx_ttft
@pytest.mark.parametrize("benchmark_model_root", ['DeepSeek-V3-Lite-bf16'],
indirect=True)
def test_disaggregated_deepseek_v3_lite_bf16_empty_batch(
disaggregated_example_root, llm_venv, benchmark_model_root,
benchmark_root, shared_gpt_path):
src_dst_dict = {
benchmark_model_root:
f"{llm_venv.get_working_directory()}/DeepSeek-V3-Lite/bf16",
}
for src, dst in src_dst_dict.items():
if not os.path.islink(dst):
os.makedirs(os.path.dirname(dst), exist_ok=True)
os.symlink(src, dst, target_is_directory=True)
test_desc = "deepseek_v3_lite_bf16_empty_batch"
num_ranks, config_file = get_test_config(test_desc,
disaggregated_example_root,
os.path.dirname(__file__))
env = llm_venv._new_env.copy()
e2el, ttft = run_disaggregated_benchmark(
disaggregated_example_root,
config_file,
benchmark_root,
benchmark_model_root,
shared_gpt_path,
env=env,
cwd=llm_venv.get_working_directory(),
num_ranks=num_ranks,
num_prompts=10,
max_concurrency=10,
random_input_len=384,
random_output_len=1536,
skip_warmup=True)
print(f"E2EL: {e2el} ms, TTFT: {ttft} ms")
assert e2el > 0 and ttft > 0
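
Assuming the working directory is the integration-test root and the model assets, benchmark root, and GPUs are available, the new case can be selected on its own with pytest's programmatic entry point; in CI it is picked up through the l0_dgx_h100 test list below.

import pytest

# Hypothetical standalone invocation; the test id matches the entry in the test list.
pytest.main([
    "disaggregated/test_disaggregated.py::"
    "test_disaggregated_deepseek_v3_lite_bf16_empty_batch[DeepSeek-V3-Lite-bf16]",
    "-s",
])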


@ -158,6 +158,7 @@ l0_dgx_h100:
- disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_overlap_cuda_graph[DeepSeek-V3-Lite-fp8]
- disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_bf16_cache_aware_balance[DeepSeek-V3-Lite-bf16]
- disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_bf16_conditional[DeepSeek-V3-Lite-bf16]
- disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_bf16_empty_batch[DeepSeek-V3-Lite-bf16]
- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_guided_decoding[xgrammar-mtp_nextn=2]
- disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_ctxpp2_gentp2_one_mtp[DeepSeek-V3-Lite-fp8]
- disaggregated/test_workers.py::test_workers_conditional_disaggregation_deepseek_v3_lite_bf16[DeepSeek-V3-Lite-bf16]