[https://nvbugs/5888410][fix] Enable warmup for Helix CP (#11460)

Signed-off-by: Balaram Buddharaju <169953907+brb-nv@users.noreply.github.com>
Balaram Buddharaju, 2026-02-12 14:24:51 -08:00, committed by GitHub
parent 07cd3d4ff2
commit 9c2d23c2e5
2 changed files with 15 additions and 6 deletions


@@ -655,13 +655,17 @@ class PyTorchModelEngine(ModelEngine):
         if self.mapping.cp_size > 1:
             cp_type = self.mapping.cp_config.get("cp_type", None)
-            logger.info(
-                f"[ModelEngine::warmup] Skipping warmup for cp_type: {None if cp_type is None else cp_type.name}."
-            )
-            return
+            if cp_type != CpType.HELIX:
+                logger.info(
+                    f"[ModelEngine::warmup] Skipping warmup for cp_type: {None if cp_type is None else cp_type.name}."
+                )
+                return
         self._run_torch_compile_warmup(resource_manager)
-        self._run_autotuner_warmup(resource_manager)
+        # Autotuner warmup uses context-only requests. Helix CP
+        # is decode-only and runs into issues with autotuner warmup.
+        if not self.mapping.has_cp_helix():
+            self._run_autotuner_warmup(resource_manager)
         self._run_cuda_graph_warmup(resource_manager)
         # Set the value back to the original value after all warmups are complete
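For context on the new control flow: warmup now proceeds for Helix CP, any other CP type still skips it entirely, and Helix additionally bypasses only the autotuner phase. Below is a minimal runnable sketch of that gating; the Mapping stand-in, its fields, and the phase names are illustrative assumptions, not the real PyTorchModelEngine API.

from enum import Enum

class CpType(Enum):
    HELIX = "helix"
    OTHER = "other"   # placeholder; the real CpType enum has more members

# Hypothetical stand-in for the engine's mapping object; the field names
# mirror the diff, but the class itself is illustrative.
class Mapping:
    def __init__(self, cp_size=1, cp_config=None):
        self.cp_size = cp_size
        self.cp_config = cp_config or {}

    def has_cp_helix(self):
        return self.cp_size > 1 and self.cp_config.get("cp_type") == CpType.HELIX

def warmup_phases(mapping):
    # Non-Helix CP still skips warmup entirely.
    if mapping.cp_size > 1:
        if mapping.cp_config.get("cp_type", None) != CpType.HELIX:
            return []
    phases = ["torch_compile"]
    # Autotuner warmup issues context-only requests, and Helix CP is
    # decode-only, so only that one phase is skipped for Helix.
    if not mapping.has_cp_helix():
        phases.append("autotuner")
    phases.append("cuda_graph")
    return phases

assert warmup_phases(Mapping()) == ["torch_compile", "autotuner", "cuda_graph"]
assert warmup_phases(Mapping(2, {"cp_type": CpType.HELIX})) == ["torch_compile", "cuda_graph"]
assert warmup_phases(Mapping(2, {"cp_type": CpType.OTHER})) == []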
@@ -2400,7 +2404,6 @@ class PyTorchModelEngine(ModelEngine):
             position_id = past_seen_token_num
             if self.mapping.has_cp_helix():
-                assert not self.is_warmup, "Warmup is not called for helix parallelism."
                 # We compute a global position_id because each helix rank has only a subset of
                 # tokens for a sequence.
                 position_id = request.total_input_len_cp + request.py_decoding_iter - 1
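The removed assert is what makes this path reachable during warmup. To make the global-position arithmetic concrete, here is a small worked example; the numbers are illustrative assumptions, not values from the commit.

# Each Helix rank holds only a slice of the sequence, so the rank-local
# past_seen_token_num cannot give a token's global position.
cp_size = 4        # Helix CP ranks (assumed)
token_num = 8      # prompt tokens per rank in the warmup request (assumed)

# As initialized by KVCacheManager in the second file below:
total_input_len_cp = token_num * cp_size - 1    # 31

# The global position advances by one per decoding iteration
# (py_decoding_iter counts from 1):
for py_decoding_iter in (1, 2, 3):
    position_id = total_input_len_cp + py_decoding_iter - 1
    print(py_decoding_iter, position_id)        # 1 -> 31, 2 -> 32, 3 -> 33

# A rank-local value (token_num - 1 = 7) would be far too small for any
# position-dependent computation on the full sequence.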


@@ -608,6 +608,10 @@ class KVCacheManager(BaseResourceManager):
             # during warmup.
             token_num = token_nums[
                 i] if token_nums is not None else 1 + max_num_draft_tokens
+            # Helix active rank sets past_seen_token_num = seqlen_this_rank_cp - 1
+            # in _prepare_tp_inputs; need token_num >= 2 so that doesn't go negative.
+            if self.mapping.has_cp_helix():
+                token_num = max(token_num, 2)
             encoder_input_tokens = [
                 1
             ] * token_num if self.impl.cross_kv else None
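The clamp exists because the Helix decode path derives past_seen_token_num from seqlen_this_rank_cp, and a 1-token dummy request would push the downstream arithmetic below zero. A minimal sketch of the selection logic; the helper name and the is_helix flag are illustrative, not the real API.

def pick_warmup_token_num(token_nums, i, max_num_draft_tokens, is_helix):
    # Mirrors the diff: use the caller-provided count, or a 1-token
    # dummy request plus room for draft tokens.
    token_num = (token_nums[i] if token_nums is not None
                 else 1 + max_num_draft_tokens)
    # Helix sets past_seen_token_num = seqlen_this_rank_cp - 1, so a
    # 1-token warmup request would underflow downstream; clamp to 2.
    if is_helix:
        token_num = max(token_num, 2)
    return token_num

assert pick_warmup_token_num(None, 0, 0, is_helix=True) == 2
assert pick_warmup_token_num([5], 0, 0, is_helix=True) == 5
assert pick_warmup_token_num(None, 0, 0, is_helix=False) == 1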
@@ -650,12 +654,14 @@ class KVCacheManager(BaseResourceManager):
                 req.py_prompt_len = req.prompt_len
                 req.seqlen_this_rank_cp = req.prompt_len
                 req.total_input_len_cp = token_num * self.mapping.cp_size - 1
+                req.py_decoding_iter = 1
             else:
                 req.py_helix_is_inactive_rank = True
                 req.prompt_len = token_num
                 req.py_prompt_len = req.prompt_len
                 req.seqlen_this_rank_cp = req.prompt_len
                 req.total_input_len_cp = token_num * self.mapping.cp_size - 1
+                req.py_decoding_iter = 1
             req.py_draft_tokens = [1] * max_num_draft_tokens
             if prepare_resource:
                 for _ in range(max_num_draft_tokens):
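The added py_decoding_iter = 1 lines matter because the Helix position_id expression in the model engine (first file above) reads request.py_decoding_iter, which these warmup requests need initialized on both branches. A minimal sketch of the setup, using a hypothetical dataclass in place of the real request type.

from dataclasses import dataclass

# Hypothetical stand-in for the warmup request; the field names follow
# the diff, but the class itself is illustrative.
@dataclass
class WarmupRequest:
    prompt_len: int = 0
    py_prompt_len: int = 0
    seqlen_this_rank_cp: int = 0
    total_input_len_cp: int = 0
    py_decoding_iter: int = 0
    py_helix_is_inactive_rank: bool = False

def init_helix_warmup_request(token_num, cp_size, active_rank):
    req = WarmupRequest()
    req.py_helix_is_inactive_rank = not active_rank
    req.prompt_len = token_num
    req.py_prompt_len = req.prompt_len
    req.seqlen_this_rank_cp = req.prompt_len
    req.total_input_len_cp = token_num * cp_size - 1
    # Set on both branches so the Helix position_id expression
    # (total_input_len_cp + py_decoding_iter - 1) is well defined on
    # the very first warmup decode step.
    req.py_decoding_iter = 1
    return req

req = init_helix_warmup_request(token_num=2, cp_size=4, active_rank=True)
assert req.total_input_len_cp + req.py_decoding_iter - 1 == 7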