From 65c138108ec8f2f8325ec24351faa94946ca949e Mon Sep 17 00:00:00 2001
From: Patrice Castonguay <55748270+pcastonguay@users.noreply.github.com>
Date: Wed, 15 Oct 2025 21:11:04 -0400
Subject: [PATCH] [https://nvbugs/5552889][fix] fix: Prevent empty batch when
 using attention DP with disagg (#8372)

Signed-off-by: Patrice Castonguay <55748270+pcastonguay@users.noreply.github.com>
Signed-off-by: Mike Iovine <6158008+mikeiovine@users.noreply.github.com>
---
 tensorrt_llm/_torch/pyexecutor/py_executor.py | 37 +++++++----
 ...g_config_deepseek_v3_lite_empty_batch.yaml | 61 +++++++++++++++++
 .../defs/disaggregated/test_disaggregated.py  | 66 ++++++++++++++++---
 .../test_lists/test-db/l0_dgx_h100.yml        |  1 +
 4 files changed, 145 insertions(+), 20 deletions(-)
 create mode 100644 tests/integration/defs/disaggregated/test_configs/disagg_config_deepseek_v3_lite_empty_batch.yaml

diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py
index 248599835d..0bcc090ff5 100644
--- a/tensorrt_llm/_torch/pyexecutor/py_executor.py
+++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py
@@ -821,12 +821,7 @@ class PyExecutor:
                 f'{len(scheduled_batch.generation_requests)} generation requests'
             )
 
-        if self.enable_attention_dp:
-            tp_batch_sizes = self.dist.tp_allgather(
-                scheduled_batch.batch_size)
-            can_queue = 0 not in tp_batch_sizes
-        else:
-            can_queue = scheduled_batch.batch_size > 0
+        can_queue = self._can_queue(scheduled_batch)
 
         if not can_queue:
             self.micro_batches[microbatch_id] = None
@@ -1004,6 +999,16 @@ class PyExecutor:
             self.send_handles[microbatch_id].wait()
             self.send_handles[microbatch_id] = None
 
+    def _can_queue(self, scheduled_batch):
+
+        if self.enable_attention_dp:
+            tp_batch_sizes = self.dist.tp_allgather(scheduled_batch.batch_size)
+            can_queue = 0 not in tp_batch_sizes
+        else:
+            can_queue = scheduled_batch.batch_size > 0
+
+        return can_queue
+
     def _prepare_and_schedule_batch(self):
         new_requests = self._fetch_and_activate_new_requests()
         if self.should_stop_processing:
@@ -1126,8 +1131,8 @@
 
             finished_requests = []
 
-            if scheduled_batch.batch_size > 0 or (
-                    self.enable_attention_dp and self.dist.tp_size > 1):
+            can_queue = self._can_queue(scheduled_batch)
+            if can_queue:
                 if self.kv_cache_transceiver:
                     # For generation requests which have completed KV cache transfer
                     self._prepare_disagg_gen_transmission_complete(
@@ -1139,8 +1144,11 @@
 
             self._kv_connector_start_batch(scheduled_batch)
 
-            if scheduled_batch.batch_size > 0 or (
-                    self.enable_attention_dp and self.dist.tp_size > 1):
+            # If using a KV connector, we need to call _can_queue again since scheduled_batch might have changed.
+            if self.kv_connector_manager:
+                can_queue = self._can_queue(scheduled_batch)
+
+            if can_queue:
                 # init_disagg_gen_requests must be before drafter loop, otherwise draft requests do not have initialized matchers.
                 # init_disagg_gen_requests must be before engine forward, where the prev_seq_slot is updated.
                 if self.guided_decoder is not None:
@@ -1298,7 +1306,8 @@
 
             self._pause_requests(scheduled_batch.paused_requests)
 
-            if scheduled_batch.batch_size > 0:
+            can_queue = self._can_queue(scheduled_batch)
+            if can_queue:
                 if self.kv_cache_transceiver:
                     # For generation requests which have completed KV cache transfer
                     self._prepare_disagg_gen_transmission_complete(
@@ -1307,7 +1316,11 @@
 
             self._kv_connector_start_batch(scheduled_batch)
 
-            if scheduled_batch.batch_size > 0:
+            # If using a KV connector, we need to call _can_queue again since scheduled_batch might have changed.
+            if self.kv_connector_manager:
+                can_queue = self._can_queue(scheduled_batch)
+
+            if can_queue:
                 # The generation requests that do not have a batch_idx
                 # need to be in front of the batch due to the assumptions
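
For context on the change above: with attention DP, each TP rank schedules its
own requests, so per-rank batch sizes can differ, yet all ranks must enter the
forward pass (and its collectives) together or not at all. The old checks let a
rank proceed with an empty batch in the disaggregated path; _can_queue instead
reaches consensus over an allgather of the per-rank batch sizes. Below is a
minimal standalone sketch of that rule; FakeDist and its tp_allgather are
hypothetical stand-ins for the executor's dist wrapper, not part of this patch.

from typing import List


class FakeDist:
    """Stand-in for the TP group: tp_allgather returns every rank's value."""

    def __init__(self, all_rank_batch_sizes: List[int]):
        self._sizes = all_rank_batch_sizes

    def tp_allgather(self, _local_batch_size: int) -> List[int]:
        return self._sizes


def can_queue(local_batch_size: int, dist: FakeDist,
              enable_attention_dp: bool) -> bool:
    if enable_attention_dp:
        # Every rank sees the same gathered list, so every rank makes the
        # same decision: queue the batch only if *no* rank is empty.
        return 0 not in dist.tp_allgather(local_batch_size)
    # Without attention DP only the local batch matters.
    return local_batch_size > 0


# Rank 1 scheduled nothing, so every rank must skip this iteration together.
assert can_queue(4, FakeDist([4, 0]), enable_attention_dp=True) is False
assert can_queue(4, FakeDist([4, 3]), enable_attention_dp=True) is True
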
diff --git a/tests/integration/defs/disaggregated/test_configs/disagg_config_deepseek_v3_lite_empty_batch.yaml b/tests/integration/defs/disaggregated/test_configs/disagg_config_deepseek_v3_lite_empty_batch.yaml
new file mode 100644
index 0000000000..3646377829
--- /dev/null
+++ b/tests/integration/defs/disaggregated/test_configs/disagg_config_deepseek_v3_lite_empty_batch.yaml
@@ -0,0 +1,61 @@
+hostname: localhost
+port: 8000
+model: DeepSeek-V3-Lite/bf16
+backend: "pytorch"
+context_servers:
+  num_instances: 1
+  build_config:
+    max_batch_size: 10
+    max_num_tokens: 512
+    max_seq_len: 768
+  max_batch_size: 10
+  max_num_tokens: 512
+  max_seq_len: 768
+  tensor_parallel_size: 2
+  moe_expert_parallel_size: 2
+  enable_attention_dp: true
+  pipeline_parallel_size: 1
+  print_iter_log: true
+  cuda_graph_config: null
+  disable_overlap_scheduler: true
+  kv_cache_config:
+    enable_block_reuse: false
+    free_gpu_memory_fraction: 0.05
+    max_tokens: 512
+  cache_transceiver_config:
+    max_tokens_in_buffer: 8448
+    backend: DEFAULT
+  urls:
+    - "localhost:8001"
+generation_servers:
+  num_instances: 1
+  build_config:
+    max_batch_size: 1
+    max_num_tokens: 2048
+    max_seq_len: 2560
+  tensor_parallel_size: 1
+  moe_expert_parallel_size: 1
+  enable_attention_dp: false
+  enable_lm_head_tp_in_adp: false
+  pipeline_parallel_size: 1
+  max_batch_size: 1
+  max_num_tokens: 2048
+  max_seq_len: 2560
+  cuda_graph_config:
+    enable_padding: true
+    batch_sizes:
+    - 1
+  print_iter_log: true
+  kv_cache_config:
+    enable_block_reuse: false
+    free_gpu_memory_fraction: 0.7
+    max_tokens: 2560
+  moe_config:
+    backend: CUTLASS
+  cache_transceiver_config:
+    max_tokens_in_buffer: 8448
+    backend: DEFAULT
+  stream_interval: 1
+  num_postprocess_workers: 1
+  urls:
+    - "localhost:8002"
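
Why this config reproduces the bug: the context servers run TP2 with
enable_attention_dp: true and a deliberately small KV cache (5% of GPU memory,
512 tokens), while the generation server is capped at max_batch_size: 1, so
during the benchmark individual context ranks routinely end up with zero
scheduled requests while their peer still has work. The snippet below is one
way to sanity-check the fields the repro hinges on; it is an illustrative
PyYAML script, not part of the test suite, and assumes it is run from the
test_configs directory.

import yaml

with open("disagg_config_deepseek_v3_lite_empty_batch.yaml") as f:
    cfg = yaml.safe_load(f)

ctx = cfg["context_servers"]
gen = cfg["generation_servers"]

# Attention DP across two context ranks is what allows per-rank batch
# sizes to diverge in the first place.
assert ctx["tensor_parallel_size"] == 2 and ctx["enable_attention_dp"]
# A tiny context KV cache plus a generation batch of 1 keeps scheduling
# tight enough that empty per-rank batches actually occur.
assert ctx["kv_cache_config"]["free_gpu_memory_fraction"] == 0.05
assert gen["max_batch_size"] == 1
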
run_env["UCX_TLS"] = "^ib" - num_rank = 2 workers_cmd = [ 'mpirun', '--allow-run-as-root', '--oversubscribe', '-n', - str(num_rank), 'trtllm-serve', 'disaggregated_mpi_worker', '-c', + str(num_ranks), 'trtllm-serve', 'disaggregated_mpi_worker', '-c', config_file ] @@ -1589,15 +1598,15 @@ def run_disaggregated_benchmark(example_dir, '--dataset-path', shared_gpt_path, '--random-input-len', - '256', + str(random_input_len), '--random-output-len', - '64', + str(random_output_len), '--random-prefix-len', '0', '--num-prompts', - '320', + str(num_prompts), '--max-concurrency', - '32', + str(max_concurrency), '--host', 'localhost', '--port', @@ -1608,7 +1617,8 @@ def run_disaggregated_benchmark(example_dir, 'e2el,ttft', ] # warm up - check_call(benchmark_cmd, env=env) + if not skip_warmup: + check_call(benchmark_cmd, env=env) output = check_output(benchmark_cmd, env=env) e2el_pattern = r"Median E2EL \(ms\):\s*(\d+\.?\d*)" ttft_pattern = r"Median TTFT \(ms\):\s*(\d+\.?\d*)" @@ -1718,3 +1728,43 @@ def test_disaggregated_benchmark_on_diff_backends( assert ucx_e2el > 0 and nixl_e2el > 0 and nixl_e2el < 1.05 * ucx_e2el assert ucx_ttft > 0 and nixl_ttft > 0 and nixl_ttft < 1.05 * ucx_ttft + + +@pytest.mark.parametrize("benchmark_model_root", ['DeepSeek-V3-Lite-bf16'], + indirect=True) +def test_disaggregated_deepseek_v3_lite_bf16_empty_batch( + disaggregated_example_root, llm_venv, benchmark_model_root, + benchmark_root, shared_gpt_path): + + src_dst_dict = { + benchmark_model_root: + f"{llm_venv.get_working_directory()}/DeepSeek-V3-Lite/bf16", + } + for src, dst in src_dst_dict.items(): + if not os.path.islink(dst): + os.makedirs(os.path.dirname(dst), exist_ok=True) + os.symlink(src, dst, target_is_directory=True) + + test_desc = "deepseek_v3_lite_bf16_empty_batch" + num_ranks, config_file = get_test_config(test_desc, + disaggregated_example_root, + os.path.dirname(__file__)) + + env = llm_venv._new_env.copy() + e2el, ttft = run_disaggregated_benchmark( + disaggregated_example_root, + config_file, + benchmark_root, + benchmark_model_root, + shared_gpt_path, + env=env, + cwd=llm_venv.get_working_directory(), + num_ranks=num_ranks, + num_prompts=10, + max_concurrency=10, + random_input_len=384, + random_output_len=1536, + skip_warmup=True) + print(f"E2EL: {e2el} ms, TTFT: {ttft} ms") + + assert e2el > 0 and ttft > 0 diff --git a/tests/integration/test_lists/test-db/l0_dgx_h100.yml b/tests/integration/test_lists/test-db/l0_dgx_h100.yml index f92a4e0a69..83025f2de4 100644 --- a/tests/integration/test_lists/test-db/l0_dgx_h100.yml +++ b/tests/integration/test_lists/test-db/l0_dgx_h100.yml @@ -158,6 +158,7 @@ l0_dgx_h100: - disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_overlap_cuda_graph[DeepSeek-V3-Lite-fp8] - disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_bf16_cache_aware_balance[DeepSeek-V3-Lite-bf16] - disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_bf16_conditional[DeepSeek-V3-Lite-bf16] + - disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_bf16_empty_batch[DeepSeek-V3-Lite-bf16] - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_guided_decoding[xgrammar-mtp_nextn=2] - disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_ctxpp2_gentp2_one_mtp[DeepSeek-V3-Lite-fp8] - disaggregated/test_workers.py::test_workers_conditional_disaggregation_deepseek_v3_lite_bf16[DeepSeek-V3-Lite-bf16]