From 9c2d23c2e56e92ce1834775434884a7ccf1395c0 Mon Sep 17 00:00:00 2001
From: Balaram Buddharaju <169953907+brb-nv@users.noreply.github.com>
Date: Thu, 12 Feb 2026 14:24:51 -0800
Subject: [PATCH] [https://nvbugs/5888410][fix] Enable warmup for Helix CP
 (#11460)

Signed-off-by: Balaram Buddharaju <169953907+brb-nv@users.noreply.github.com>
---
 tensorrt_llm/_torch/pyexecutor/model_engine.py     | 15 +++++++++------
 .../_torch/pyexecutor/resource_manager.py          |  6 ++++++
 2 files changed, 15 insertions(+), 6 deletions(-)

diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py
index 5bf0c2f78b..866bc4cf05 100644
--- a/tensorrt_llm/_torch/pyexecutor/model_engine.py
+++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py
@@ -655,13 +655,17 @@ class PyTorchModelEngine(ModelEngine):
 
         if self.mapping.cp_size > 1:
             cp_type = self.mapping.cp_config.get("cp_type", None)
-            logger.info(
-                f"[ModelEngine::warmup] Skipping warmup for cp_type: {None if cp_type is None else cp_type.name}."
-            )
-            return
+            if cp_type != CpType.HELIX:
+                logger.info(
+                    f"[ModelEngine::warmup] Skipping warmup for cp_type: {None if cp_type is None else cp_type.name}."
+                )
+                return
 
         self._run_torch_compile_warmup(resource_manager)
-        self._run_autotuner_warmup(resource_manager)
+        # Autotuner warmup uses context-only requests. Helix CP
+        # is decode-only and runs into issues with autotuner warmup.
+        if not self.mapping.has_cp_helix():
+            self._run_autotuner_warmup(resource_manager)
         self._run_cuda_graph_warmup(resource_manager)
 
         # Set the value back to the original value after all warmups are complete
@@ -2400,7 +2404,6 @@ class PyTorchModelEngine(ModelEngine):
 
             position_id = past_seen_token_num
             if self.mapping.has_cp_helix():
-                assert not self.is_warmup, "Warmup is not called for helix parallelism."
                 # We compute a global position_id because each helix rank has only a subset of
                 # tokens for a sequence.
                 position_id = request.total_input_len_cp + request.py_decoding_iter - 1
diff --git a/tensorrt_llm/_torch/pyexecutor/resource_manager.py b/tensorrt_llm/_torch/pyexecutor/resource_manager.py
index df80d786de..8f1d5ff73c 100644
--- a/tensorrt_llm/_torch/pyexecutor/resource_manager.py
+++ b/tensorrt_llm/_torch/pyexecutor/resource_manager.py
@@ -608,6 +608,10 @@ class KVCacheManager(BaseResourceManager):
                 # during warmup.
                 token_num = token_nums[
                     i] if token_nums is not None else 1 + max_num_draft_tokens
+                # Helix active rank sets past_seen_token_num = seqlen_this_rank_cp - 1
+                # in _prepare_tp_inputs; need token_num >= 2 so that doesn't go negative.
+                if self.mapping.has_cp_helix():
+                    token_num = max(token_num, 2)
                 encoder_input_tokens = [
                     1
                 ] * token_num if self.impl.cross_kv else None
@@ -650,12 +654,14 @@ class KVCacheManager(BaseResourceManager):
                     req.py_prompt_len = req.prompt_len
                     req.seqlen_this_rank_cp = req.prompt_len
                     req.total_input_len_cp = token_num * self.mapping.cp_size - 1
+                    req.py_decoding_iter = 1
                 else:
                     req.py_helix_is_inactive_rank = True
                     req.prompt_len = token_num
                     req.py_prompt_len = req.prompt_len
                     req.seqlen_this_rank_cp = req.prompt_len
                     req.total_input_len_cp = token_num * self.mapping.cp_size - 1
+                    req.py_decoding_iter = 1
                 req.py_draft_tokens = [1] * max_num_draft_tokens
                 if prepare_resource:
                     for _ in range(max_num_draft_tokens):
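
Note: below is a minimal, self-contained sketch of the warmup gating this patch introduces, for readers without the TensorRT-LLM tree at hand. The Mapping, CpType, and warmup hook names are simplified stand-ins for illustration, not the project's actual API.

# Hypothetical sketch (not TensorRT-LLM code) of the new gating:
# warmup now runs under Helix CP, but the autotuner pass stays disabled there
# because it builds context-only requests while Helix is decode-only.
from dataclasses import dataclass, field
from enum import Enum, auto


class CpType(Enum):
    HELIX = auto()
    STAR = auto()


@dataclass
class Mapping:
    cp_size: int = 1
    cp_config: dict = field(default_factory=dict)

    def has_cp_helix(self) -> bool:
        return self.cp_size > 1 and self.cp_config.get("cp_type") == CpType.HELIX


def warmup(mapping: Mapping) -> list:
    ran = []
    if mapping.cp_size > 1:
        cp_type = mapping.cp_config.get("cp_type", None)
        # Before the patch, every CP setup skipped warmup; now only non-Helix CP does.
        if cp_type != CpType.HELIX:
            return ran
    ran.append("torch_compile")
    # Autotuner warmup is skipped on Helix ranks.
    if not mapping.has_cp_helix():
        ran.append("autotuner")
    ran.append("cuda_graph")
    return ran


if __name__ == "__main__":
    print(warmup(Mapping()))                               # ['torch_compile', 'autotuner', 'cuda_graph']
    print(warmup(Mapping(2, {"cp_type": CpType.STAR})))    # []
    print(warmup(Mapping(2, {"cp_type": CpType.HELIX})))   # ['torch_compile', 'cuda_graph']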