Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-13 22:18:36 +08:00)
[TRTLLM-5972][chore] Load balance decode token KV cache with helix parallelism (#9757)
Signed-off-by: Balaram Buddharaju <169953907+brb-nv@users.noreply.github.com>
commit af315d8ef1 (parent d5b9ad91c9)
@@ -694,6 +694,7 @@ class ExecutorRequestQueue:
                 position_ids=position_ids_this_rank,
             )
             req.total_input_len_cp = input_len
+            req.seqlen_this_rank_cp = len(input_ids_this_rank)
             req_with_children.append(req)
             if req.child_requests:
                 req_with_children.extend(req.child_requests)
@@ -489,6 +489,8 @@ class LlmRequest(tensorrt_llm.bindings.internal.batch_manager.LlmRequest):
         self.py_max_new_tokens = self.max_new_tokens
         self.py_min_length = self.sampling_config.min_length
+        self.py_helix_is_inactive_rank = False
+        self.seqlen_this_rank_cp = 0
         self.total_input_len_cp = 0
         self.py_batch_idx = None
         self.py_draft_pages_allocated = 0
         self.py_rewind_len = 0
@@ -568,13 +568,12 @@ class PyTorchModelEngine(ModelEngine):
         # Reset the global cuda graph dummy request to None in warmup.
         self.cuda_graph_runner.padding_dummy_request = None

-        cp_type = self.mapping.cp_config.get('cp_type', None)
-        if cp_type is not None:
-            if cp_type in [CpType.ULYSSES, CpType.STAR]:
-                logger.info(
-                    "[ModelEngine::warmup] Skipping warmup for cp_type: ",
-                    cp_type.name)
-                return
+        if self.mapping.cp_size > 1:
+            cp_type = self.mapping.cp_config.get("cp_type", None)
+            logger.info(
+                f"[ModelEngine::warmup] Skipping warmup for cp_type: {None if cp_type is None else cp_type.name}."
+            )
+            return

         self._run_torch_compile_warmup(resource_manager)
         self._run_autotuner_warmup(resource_manager)
@@ -1671,12 +1670,12 @@ class PyTorchModelEngine(ModelEngine):
             # Warmup doesn't have `total_input_len_cp` set because merge_helix_requests is not called.
             if not self.is_warmup and not request.is_cuda_graph_dummy:
                 position_id = request.total_input_len_cp + request.py_decoding_iter - 1
-                # TODO: [TRTLLM-5972] Lift the limitation that last rank is always the active one for helix.
-                if self.mapping.cp_rank == self.mapping.cp_size - 1:
-                    past_seen_token_num = request.orig_prompt_len + request.py_decoding_iter - 1
+                if request.py_helix_is_inactive_rank:
+                    past_seen_token_num = request.seqlen_this_rank_cp
                 else:
-                    # past_seen_token_num doesn't grow on inactive ranks.
-                    past_seen_token_num = request.orig_prompt_len
+                    # Discount the token added to active rank in resource manager as it hasn't
+                    # been previously seen.
+                    past_seen_token_num = request.seqlen_this_rank_cp - 1

             position_ids.append(position_id)
             num_cached_tokens_per_seq.append(past_seen_token_num)
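
A toy, standalone illustration of the bookkeeping in the hunk above, under assumed values (a 6-token prompt split 3/3 across two helix CP ranks, 1-based decoding iterations); it is not code from this repository.

def decode_step_metadata(decoding_iter, total_input_len_cp, seqlen_this_rank_cp,
                         helix_is_inactive_rank):
    # Global position of the token generated at this decode step.
    position_id = total_input_len_cp + decoding_iter - 1
    if helix_is_inactive_rank:
        # The KV cache on an inactive rank did not grow this step.
        past_seen_token_num = seqlen_this_rank_cp
    else:
        # The active rank has already reserved a slot for the new token in the
        # resource manager, so it is discounted from the "previously seen" count.
        past_seen_token_num = seqlen_this_rank_cp - 1
    return position_id, past_seen_token_num

# First decode step: the active rank has just bumped seqlen_this_rank_cp from 3 to 4.
print(decode_step_metadata(1, 6, 4, False))  # (6, 3)
# The inactive rank still holds only its 3 cached prompt tokens.
print(decode_step_metadata(1, 6, 3, True))   # (6, 3)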
@@ -468,13 +468,17 @@ class KVCacheManager(BaseResourceManager):
                                                  req, block_ids)

         for req in generation_batch:
             # TODO: [TRTLLM-5972] Lift the limitation that last rank is always the active one for helix.
             if self.mapping.has_cp_helix():
-                if self.mapping.cp_rank != self.mapping.cp_size - 1:
-                    # Skip allocating KV cache at decode for inactive helix ranks.
-                    continue
+                # Distribute the decode blocks across CP ranks in a round-robin manner.
+                decode_block_id = (req.py_decoding_iter -
+                                   1) // self.tokens_per_block
+                if decode_block_id % self.mapping.cp_size == self.mapping.cp_rank:
+                    req.py_helix_is_inactive_rank = False
+                    req.seqlen_this_rank_cp += 1
+                else:
+                    req.py_helix_is_inactive_rank = True
+                # Skip allocating KV cache at decode for inactive helix ranks.
+                if req.py_helix_is_inactive_rank:
+                    continue
             self.impl.add_token(req.py_request_id)
             for _ in range(get_draft_token_length(req)):
                 self.impl.add_token(req.py_request_id)
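
A minimal, self-contained sketch of the round-robin ownership rule this hunk implements; tokens_per_block = 4 and cp_size = 2 are assumed demo values, not taken from the diff.

def decode_block_owner(py_decoding_iter: int, tokens_per_block: int, cp_size: int) -> int:
    """CP rank whose KV cache stores the block that this decode token falls into."""
    decode_block_id = (py_decoding_iter - 1) // tokens_per_block
    return decode_block_id % cp_size

tokens_per_block, cp_size = 4, 2
# Ownership alternates every tokens_per_block decode steps:
# iterations 1-4 -> rank 0, 5-8 -> rank 1, 9-12 -> rank 0, ...
for it in range(1, 13):
    print(it, decode_block_owner(it, tokens_per_block, cp_size))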
@@ -524,6 +524,7 @@ accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_guided_decoding
 accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_guided_decoding[xgrammar-mtp_nextn=2]
 accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_guided_decoding[llguidance-mtp_nextn=0]
 accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_guided_decoding[llguidance-mtp_nextn=2]
+accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix
 accuracy/test_disaggregated_serving.py::TestGemma3_1BInstruct::test_auto_dtype[False]
 accuracy/test_disaggregated_serving.py::TestGemma3_1BInstruct::test_auto_dtype[True]
 accuracy/test_disaggregated_serving.py::TestGPTOSS::test_auto_dtype[True]
@@ -189,3 +189,4 @@ l0_dgx_b200:
 - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[pp4-fp8kv=True-attn_backend=TRTLLM-torch_compile=False]
 - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTEDSL-mtp_nextn=2-ep4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False]
 - accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp4_tp2pp2
+- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix
@@ -407,6 +407,7 @@ class PyTorchModelEngineTestCase(unittest.TestCase):
             req.sampling_config.beam_width = 1
             req.py_multimodal_data = {}
             req.total_input_len_cp = prompt_lens[idx] * 2
+            req.seqlen_this_rank_cp = prompt_lens[idx]
             req.py_decoding_iter = 1
             gen_requests.append(req)
         scheduled_requests.generation_requests = gen_requests
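
A quick sanity check on the values this test wires up (illustrative only, assuming an even split over two helix CP ranks): each rank keeps prompt_lens[idx] prompt tokens while the full prompt is twice that, so the first decode step yields a position_id equal to the full prompt length.

prompt_len = 32  # assumed example value, standing in for prompt_lens[idx]
total_input_len_cp = prompt_len * 2   # full prompt across both CP ranks
seqlen_this_rank_cp = prompt_len      # this rank's share of the prompt
py_decoding_iter = 1                  # first decode step
position_id = total_input_len_cp + py_decoding_iter - 1
assert position_id == 2 * prompt_len  # the position right after the full prompt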