diff --git a/tensorrt_llm/_torch/speculative/interface.py b/tensorrt_llm/_torch/speculative/interface.py
index 9ace3c3e98..ea623bc442 100644
--- a/tensorrt_llm/_torch/speculative/interface.py
+++ b/tensorrt_llm/_torch/speculative/interface.py
@@ -14,7 +14,6 @@ from tensorrt_llm.logger import logger
 from ..._utils import get_sm_version
 from ..attention_backend.trtllm import (AttentionBackend, TrtllmAttention,
                                         TrtllmAttentionMetadata)
-from ..cute_dsl_kernels.argmax import argmax as cute_argmax
 from ..flashinfer_utils import IS_FLASHINFER_AVAILABLE
 from ..pyexecutor.resource_manager import (BaseResourceManager,
                                            ResourceManagerType)
@@ -551,8 +550,7 @@ class SpecWorkerBase(nn.Module, ABC):
         Returns:
             draft_tokens: [num_tokens] - Sampled draft token ids (int32)
         """
-        # cute_argmax returns (M, 2) where col 0 = max value, col 1 = argmax index
-        draft_tokens = cute_argmax(logits)[:, 1].long()
+        draft_tokens = torch.argmax(logits, dim=-1)
 
         # Apply d2t (offsets between draft and target model dictionaries)
         if d2t is not None:
@@ -703,7 +701,6 @@ class SpecWorkerBase(nn.Module, ABC):
                                              seed=self.seed,
                                              offset=self.offset)
         else:
-            # cute_argmax returns (M, 2) where col 0 = max value, col 1 = argmax index
-            sampled_tokens = cute_argmax(logits)[:, 1].long()
+            sampled_tokens = torch.argmax(logits, dim=-1)
 
         return sampled_tokens
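
Note on the equivalence this diff relies on: per the removed comment, cute_argmax returned an (M, 2) tensor with the row-wise max value in column 0 and the argmax index in column 1, so cute_argmax(logits)[:, 1].long() produced the same int64 row-wise argmax as torch.argmax(logits, dim=-1). The minimal sketch below illustrates that check; reference_argmax_pair is a hypothetical stand-in written for illustration only, not the real CUTE DSL kernel.

import torch

def reference_argmax_pair(logits: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for the removed cute_argmax kernel, used only to
    # model its documented contract: an (M, 2) tensor where column 0 is the
    # row-wise max value and column 1 is the argmax index.
    values, indices = logits.max(dim=-1)
    # Indices are exactly representable in float32 for vocab sizes below
    # 2**24, so the round-trip through .long() below is lossless here.
    return torch.stack([values, indices.to(logits.dtype)], dim=1)

logits = torch.randn(8, 32000)  # [num_tokens, vocab_size]

old_path = reference_argmax_pair(logits)[:, 1].long()  # pre-diff call pattern
new_path = torch.argmax(logits, dim=-1)                # post-diff replacement

assert new_path.dtype == torch.int64  # torch.argmax already yields int64 indices
assert torch.equal(old_path, new_path)

Dropping the custom kernel in favor of torch.argmax also removes a dtype subtlety: the plain PyTorch path returns integer indices directly, with no float column to slice and no explicit .long() cast needed.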