[https://nvbugs/5853720][fix] Disable cutedsl argmax kernel to fix perf regression (#11403)

Signed-off-by: Chenfei Zhang <chenfeiz@nvidia.com>
Author: chenfeiz0326
Date: 2026-02-10 18:10:38 +08:00
Parent: be88fe33be
Commit: eac56b793e

@@ -14,7 +14,6 @@ from tensorrt_llm.logger import logger
 from ..._utils import get_sm_version
 from ..attention_backend.trtllm import (AttentionBackend, TrtllmAttention,
                                         TrtllmAttentionMetadata)
-from ..cute_dsl_kernels.argmax import argmax as cute_argmax
 from ..flashinfer_utils import IS_FLASHINFER_AVAILABLE
 from ..pyexecutor.resource_manager import (BaseResourceManager,
                                            ResourceManagerType)
@@ -551,8 +550,7 @@ class SpecWorkerBase(nn.Module, ABC):
         Returns:
             draft_tokens: [num_tokens] - Sampled draft token ids (int32)
         """
-        # cute_argmax returns (M, 2) where col 0 = max value, col 1 = argmax index
-        draft_tokens = cute_argmax(logits)[:, 1].long()
+        draft_tokens = torch.argmax(logits, dim=-1)
         # Apply d2t (offsets between draft and target model dictionaries)
         if d2t is not None:
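
Note on the hunk above: the d2t table referenced in the retained comment maps draft-model vocabulary ids to target-model vocabulary ids via per-token offsets. The sketch below (not part of this patch; shapes and the exact application site are assumptions) illustrates how such an offset table is typically applied to sampled draft tokens:

```python
import torch

# Hypothetical sketch of applying d2t offsets; not taken from this patch.
# Assumption: d2t[i] holds the offset that converts draft-vocab token id i
# into the corresponding target-vocab token id.
draft_tokens = torch.tensor([5, 17, 3])           # ids in the draft vocab
d2t = torch.zeros(32000, dtype=torch.long)        # offset table, assumed shape [draft_vocab_size]
target_tokens = draft_tokens + d2t[draft_tokens]  # gather each offset, add it to the id
```
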
@@ -703,7 +701,6 @@ class SpecWorkerBase(nn.Module, ABC):
                 seed=self.seed,
                 offset=self.offset)
         else:
-            # cute_argmax returns (M, 2) where col 0 = max value, col 1 = argmax index
-            sampled_tokens = cute_argmax(logits)[:, 1].long()
+            sampled_tokens = torch.argmax(logits, dim=-1)
         return sampled_tokens
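
The functional contract is unchanged by this commit: per the removed comments, the CuTe DSL kernel returned an (M, 2) tensor where column 0 held the max value and column 1 the argmax index, so taking column 1 and casting to int64 is equivalent to `torch.argmax(logits, dim=-1)`. A minimal sketch checking that equivalence, with the real kernel replaced by a pure-PyTorch stand-in (the stand-in is an assumption built from the removed comment, not the actual kernel):

```python
import torch

def cute_argmax_reference(logits: torch.Tensor) -> torch.Tensor:
    """Pure-PyTorch stand-in for the CuTe DSL argmax kernel's documented
    contract: returns (M, 2), col 0 = max value, col 1 = argmax index."""
    values, indices = logits.max(dim=-1)
    return torch.stack([values, indices.to(values.dtype)], dim=1)

logits = torch.randn(4, 32000)  # [num_tokens, vocab_size]

# Old path (removed by this commit): take the index column, cast to int64.
old_tokens = cute_argmax_reference(logits)[:, 1].long()

# New path: plain torch.argmax, which already returns int64 indices.
new_tokens = torch.argmax(logits, dim=-1)

# Ties are vanishingly unlikely with random float logits, so the two
# paths agree element-wise here.
assert torch.equal(old_tokens, new_tokens)
```
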