[TRTLLM-5972][chore] Load balance decode token KV cache with helix parallelism (#9757)

Signed-off-by: Balaram Buddharaju <169953907+brb-nv@users.noreply.github.com>
Balaram Buddharaju 2025-12-12 06:29:05 -08:00 committed by GitHub
parent d5b9ad91c9
commit af315d8ef1
GPG Key ID: B5690EEEBB952194
7 changed files with 26 additions and 17 deletions


@@ -694,6 +694,7 @@ class ExecutorRequestQueue:
position_ids=position_ids_this_rank,
)
req.total_input_len_cp = input_len
req.seqlen_this_rank_cp = len(input_ids_this_rank)
req_with_children.append(req)
if req.child_requests:
req_with_children.extend(req.child_requests)
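For context: `total_input_len_cp` records the full prompt length across all context-parallel ranks, while the new `seqlen_this_rank_cp` records how many of those tokens this rank holds locally. A minimal sketch of that bookkeeping, assuming a simple contiguous split of the prompt across `cp_size` ranks (the actual helix partitioning in `merge_helix_requests` may differ):

```python
# Illustrative only: how total_input_len_cp and seqlen_this_rank_cp relate.
# Assumes a plain contiguous split; the real helix partitioning may differ.
def split_prompt_for_cp(input_ids: list[int], cp_rank: int, cp_size: int):
    total_input_len_cp = len(input_ids)  # full prompt length across all CP ranks
    chunk = (total_input_len_cp + cp_size - 1) // cp_size
    input_ids_this_rank = input_ids[cp_rank * chunk:(cp_rank + 1) * chunk]
    seqlen_this_rank_cp = len(input_ids_this_rank)  # tokens held locally on this rank
    return input_ids_this_rank, total_input_len_cp, seqlen_this_rank_cp

# A 10-token prompt split over 2 CP ranks: each rank holds 5 tokens locally.
ids_rank0, total_len, local_len = split_prompt_for_cp(list(range(10)), cp_rank=0, cp_size=2)
assert (total_len, local_len) == (10, 5)
```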


@@ -489,6 +489,8 @@ class LlmRequest(tensorrt_llm.bindings.internal.batch_manager.LlmRequest):
self.py_max_new_tokens = self.max_new_tokens
self.py_min_length = self.sampling_config.min_length
self.py_helix_is_inactive_rank = False
self.seqlen_this_rank_cp = 0
self.total_input_len_cp = 0
self.py_batch_idx = None
self.py_draft_pages_allocated = 0
self.py_rewind_len = 0


@@ -568,13 +568,12 @@ class PyTorchModelEngine(ModelEngine):
# Reset the global cuda graph dummy request to None in warmup.
self.cuda_graph_runner.padding_dummy_request = None
cp_type = self.mapping.cp_config.get('cp_type', None)
if cp_type is not None:
if cp_type in [CpType.ULYSSES, CpType.STAR]:
logger.info(
"[ModelEngine::warmup] Skipping warmup for cp_type: ",
cp_type.name)
return
if self.mapping.cp_size > 1:
cp_type = self.mapping.cp_config.get("cp_type", None)
logger.info(
f"[ModelEngine::warmup] Skipping warmup for cp_type: {None if cp_type is None else cp_type.name}."
)
return
self._run_torch_compile_warmup(resource_manager)
self._run_autotuner_warmup(resource_manager)
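In short, warmup is now skipped whenever context parallelism is enabled at all (`cp_size > 1`), not only for the ULYSSES and STAR cp types. A condensed, illustrative guard (not the engine's actual method):

```python
# Illustrative only: condensed form of the new warmup guard above.
def should_skip_warmup(cp_size: int, cp_type_name: str | None = None) -> bool:
    # The cp_type no longer affects the decision; any cp_size > 1 skips warmup,
    # which now also covers helix context parallelism.
    return cp_size > 1

assert should_skip_warmup(cp_size=2, cp_type_name="HELIX") is True
assert should_skip_warmup(cp_size=1) is False
```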
@@ -1671,12 +1670,12 @@ class PyTorchModelEngine(ModelEngine):
# Warmup doesn't have `total_input_len_cp` set because merge_helix_requests is not called.
if not self.is_warmup and not request.is_cuda_graph_dummy:
position_id = request.total_input_len_cp + request.py_decoding_iter - 1
# TODO: [TRTLLM-5972] Lift the limitation that last rank is always the active one for helix.
if self.mapping.cp_rank == self.mapping.cp_size - 1:
past_seen_token_num = request.orig_prompt_len + request.py_decoding_iter - 1
if request.py_helix_is_inactive_rank:
past_seen_token_num = request.seqlen_this_rank_cp
else:
# past_seen_token_num doesn't grow on inactive ranks.
past_seen_token_num = request.orig_prompt_len
# Discount the token added to active rank in resource manager as it hasn't
# been previously seen.
past_seen_token_num = request.seqlen_this_rank_cp - 1
position_ids.append(position_id)
num_cached_tokens_per_seq.append(past_seen_token_num)
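With these fields, the decode-time attention metadata becomes rank-aware: the position id is global (full prompt length plus decode steps), while the cached-token count is local, minus one on the active rank to discount the token the resource manager just appended. A hedged sketch of that accounting as a standalone helper (not the engine's actual API):

```python
# Illustrative helper mirroring the accounting in the hunk above.
def helix_decode_metadata(total_input_len_cp: int, seqlen_this_rank_cp: int,
                          decoding_iter: int, is_inactive_rank: bool):
    # Position id is global: full prompt length plus decode steps so far.
    position_id = total_input_len_cp + decoding_iter - 1
    if is_inactive_rank:
        # Inactive ranks did not receive the new decode token this step.
        past_seen_token_num = seqlen_this_rank_cp
    else:
        # The active rank's count was already bumped by the resource manager,
        # but that new token has not actually been seen by attention yet.
        past_seen_token_num = seqlen_this_rank_cp - 1
    return position_id, past_seen_token_num

# Example: a 16-token prompt split as 8 tokens per rank over 2 CP ranks, first
# decode step; the active rank's local count was bumped from 8 to 9.
assert helix_decode_metadata(16, 9, 1, is_inactive_rank=False) == (16, 8)
assert helix_decode_metadata(16, 8, 1, is_inactive_rank=True) == (16, 8)
```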


@@ -468,13 +468,17 @@ class KVCacheManager(BaseResourceManager):
req, block_ids)
for req in generation_batch:
# TODO: [TRTLLM-5972] Lift the limitation that last rank is always the active one for helix.
if self.mapping.has_cp_helix():
if self.mapping.cp_rank != self.mapping.cp_size - 1:
# Distribute the decode blocks across CP ranks in a round-robin manner.
decode_block_id = (req.py_decoding_iter -
1) // self.tokens_per_block
if decode_block_id % self.mapping.cp_size == self.mapping.cp_rank:
req.py_helix_is_inactive_rank = False
req.seqlen_this_rank_cp += 1
else:
req.py_helix_is_inactive_rank = True
# Skip allocating KV cache at decode for inactive helix ranks.
if req.py_helix_is_inactive_rank:
continue
# Skip allocating KV cache at decode for inactive helix ranks.
continue
self.impl.add_token(req.py_request_id)
for _ in range(get_draft_token_length(req)):
self.impl.add_token(req.py_request_id)
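The core of the change is the round-robin assignment above: each decode KV block is owned by exactly one CP rank (block index modulo `cp_size`), so decode-token KV cache grows evenly across ranks instead of accumulating only on the last rank. A standalone sketch of the ownership pattern, with assumed parameter values:

```python
# Illustrative only: round-robin decode-block ownership as used in the hunk above.
def active_helix_rank(decoding_iter: int, tokens_per_block: int, cp_size: int) -> int:
    # Block index of the token produced at this (1-based) decode iteration.
    decode_block_id = (decoding_iter - 1) // tokens_per_block
    # The owning rank allocates KV cache for this block; all other ranks stay inactive.
    return decode_block_id % cp_size

# Assumed example: tokens_per_block=32, cp_size=4.
# Decode tokens 1..32 land on rank 0, 33..64 on rank 1, 65..96 on rank 2, and so on.
owners = [active_helix_rank(i, tokens_per_block=32, cp_size=4) for i in (1, 32, 33, 64, 65, 129)]
assert owners == [0, 0, 1, 1, 2, 0]
```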


@@ -524,6 +524,7 @@ accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_guided_decoding
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_guided_decoding[xgrammar-mtp_nextn=2]
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_guided_decoding[llguidance-mtp_nextn=0]
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_guided_decoding[llguidance-mtp_nextn=2]
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix
accuracy/test_disaggregated_serving.py::TestGemma3_1BInstruct::test_auto_dtype[False]
accuracy/test_disaggregated_serving.py::TestGemma3_1BInstruct::test_auto_dtype[True]
accuracy/test_disaggregated_serving.py::TestGPTOSS::test_auto_dtype[True]


@@ -189,3 +189,4 @@ l0_dgx_b200:
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[pp4-fp8kv=True-attn_backend=TRTLLM-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTEDSL-mtp_nextn=2-ep4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp4_tp2pp2
- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix


@@ -407,6 +407,7 @@ class PyTorchModelEngineTestCase(unittest.TestCase):
req.sampling_config.beam_width = 1
req.py_multimodal_data = {}
req.total_input_len_cp = prompt_lens[idx] * 2
req.seqlen_this_rank_cp = prompt_lens[idx]
req.py_decoding_iter = 1
gen_requests.append(req)
scheduled_requests.generation_requests = gen_requests