Mirror of https://github.com/NVIDIA/TensorRT-LLM.git, synced 2026-02-16 07:53:55 +08:00
[https://nvbugs/5888410][fix] Enable warmup for Helix CP (#11460)
Signed-off-by: Balaram Buddharaju <169953907+brb-nv@users.noreply.github.com>
parent 07cd3d4ff2
commit 9c2d23c2e5
@@ -655,13 +655,17 @@ class PyTorchModelEngine(ModelEngine):
         if self.mapping.cp_size > 1:
             cp_type = self.mapping.cp_config.get("cp_type", None)
-            logger.info(
-                f"[ModelEngine::warmup] Skipping warmup for cp_type: {None if cp_type is None else cp_type.name}."
-            )
-            return
+            if cp_type != CpType.HELIX:
+                logger.info(
+                    f"[ModelEngine::warmup] Skipping warmup for cp_type: {None if cp_type is None else cp_type.name}."
+                )
+                return
 
         self._run_torch_compile_warmup(resource_manager)
-        self._run_autotuner_warmup(resource_manager)
+        # Autotuner warmup uses context-only requests. Helix CP
+        # is decode-only and runs into issues with autotuner warmup.
+        if not self.mapping.has_cp_helix():
+            self._run_autotuner_warmup(resource_manager)
         self._run_cuda_graph_warmup(resource_manager)
 
         # Set the value back to the original value after all warmups are complete
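For readers skimming the diff: with context parallelism enabled, every cp_type other than Helix still takes the early return, while Helix now runs the torch.compile and CUDA graph warmups but skips the autotuner warmup, which builds context-only requests that a decode-only Helix setup cannot serve. The snippet below is a minimal standalone sketch of that gating; the CpType enum is trimmed down and the warmup callables are hypothetical stand-ins for the real PyTorchModelEngine helpers, not the actual implementation.

from enum import Enum, auto


class CpType(Enum):
    # Trimmed stand-in for the real CpType enum; only the member used in
    # the hunk above is modeled here.
    HELIX = auto()


def warmup_sketch(cp_size, cp_type, run_torch_compile, run_autotuner,
                  run_cuda_graph):
    # Mirrors the gating added above (illustration only, not the real method).
    helix_enabled = cp_size > 1 and cp_type == CpType.HELIX
    if cp_size > 1 and not helix_enabled:
        print(f"Skipping warmup for cp_type: "
              f"{None if cp_type is None else cp_type.name}.")
        return

    run_torch_compile()
    # Autotuner warmup uses context-only requests; Helix CP is decode-only,
    # so it skips this step but still gets CUDA graph warmup.
    if not helix_enabled:
        run_autotuner()
    run_cuda_graph()


# Example: Helix CP across 4 ranks warms up torch.compile and CUDA graphs
# but not the autotuner.
warmup_sketch(4, CpType.HELIX,
              lambda: print("torch.compile warmup"),
              lambda: print("autotuner warmup"),
              lambda: print("CUDA graph warmup"))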
@@ -2400,7 +2404,6 @@ class PyTorchModelEngine(ModelEngine):
             position_id = past_seen_token_num
             if self.mapping.has_cp_helix():
-                assert not self.is_warmup, "Warmup is not called for helix parallelism."
                 # We compute a global position_id because each helix rank has only a subset of
                 # tokens for a sequence.
                 position_id = request.total_input_len_cp + request.py_decoding_iter - 1
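Because each Helix rank holds only a slice of the sequence, the per-rank past_seen_token_num is not a valid global position, so the position is rebuilt from sequence-global fields. A purely illustrative calculation, plugging in the dummy-request values set later in this commit (token_num, cp_size, and the decode iteration are example numbers, not values from a real run):

token_num = 2           # per-rank dummy length after the Helix clamp below
cp_size = 4             # assumed number of Helix CP ranks for this example
py_decoding_iter = 1    # dummy requests start at decoding iteration 1

# total_input_len_cp as set for warmup dummy requests in this commit.
total_input_len_cp = token_num * cp_size - 1             # 2 * 4 - 1 = 7

# Formula from the hunk above.
position_id = total_input_len_cp + py_decoding_iter - 1  # 7 + 1 - 1 = 7
print(position_id)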
@@ -608,6 +608,10 @@ class KVCacheManager(BaseResourceManager):
             # during warmup.
             token_num = token_nums[
                 i] if token_nums is not None else 1 + max_num_draft_tokens
+            # Helix active rank sets past_seen_token_num = seqlen_this_rank_cp - 1
+            # in _prepare_tp_inputs; need token_num >= 2 so that doesn't go negative.
+            if self.mapping.has_cp_helix():
+                token_num = max(token_num, 2)
             encoder_input_tokens = [
                 1
             ] * token_num if self.impl.cross_kv else None
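Without draft tokens the dummy token count defaults to 1, which is too short for the Helix path described in the new comment, so the clamp raises it to 2. A tiny illustrative check of that arithmetic (the draft-token count is an example value):

max_num_draft_tokens = 0                    # example: no speculative drafting
token_num = 1 + max_num_draft_tokens        # default dummy token count -> 1

# Helix clamp added above: dummy requests need at least 2 tokens.
token_num = max(token_num, 2)               # -> 2

# Per the comment above, the active Helix rank later derives
# past_seen_token_num = seqlen_this_rank_cp - 1, where the dummy request
# sets seqlen_this_rank_cp = token_num.
past_seen_token_num = token_num - 1         # -> 1 instead of 0
print(token_num, past_seen_token_num)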
@@ -650,12 +654,14 @@ class KVCacheManager(BaseResourceManager):
                 req.py_prompt_len = req.prompt_len
                 req.seqlen_this_rank_cp = req.prompt_len
+                req.total_input_len_cp = token_num * self.mapping.cp_size - 1
                 req.py_decoding_iter = 1
             else:
                 req.py_helix_is_inactive_rank = True
                 req.prompt_len = token_num
                 req.py_prompt_len = req.prompt_len
                 req.seqlen_this_rank_cp = req.prompt_len
+                req.total_input_len_cp = token_num * self.mapping.cp_size - 1
                 req.py_decoding_iter = 1
             req.py_draft_tokens = [1] * max_num_draft_tokens
             if prepare_resource:
                 for _ in range(max_num_draft_tokens):
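Both the active and inactive Helix ranks now give their dummy requests the same lengths plus the new total_input_len_cp field, which is what the position_id computation in the model engine consumes. The sketch below restates those assignments on a throwaway dataclass; it is an illustration, not TensorRT-LLM's actual request class.

from dataclasses import dataclass


@dataclass
class DummyHelixRequest:
    # Illustrative stand-in holding only the fields assigned in the hunk above.
    prompt_len: int
    py_prompt_len: int
    seqlen_this_rank_cp: int
    total_input_len_cp: int
    py_decoding_iter: int


def make_dummy(token_num: int, cp_size: int) -> DummyHelixRequest:
    # Mirrors the assignments above; active and inactive ranks set the same
    # lengths, differing only in the inactive-rank flag (omitted here).
    return DummyHelixRequest(
        prompt_len=token_num,
        py_prompt_len=token_num,
        seqlen_this_rank_cp=token_num,
        total_input_len_cp=token_num * cp_size - 1,
        py_decoding_iter=1,
    )


req = make_dummy(token_num=2, cp_size=4)
# Same formula as the model-engine hunk earlier in this commit.
position_id = req.total_input_len_cp + req.py_decoding_iter - 1
print(position_id)  # 7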