[https://nvbugs/5853720][fix] Disable cutedsl argmax kernel to fix perf regression (#11403)
Signed-off-by: Chenfei Zhang <chenfeiz@nvidia.com>
parent be88fe33be
commit eac56b793e
@@ -14,7 +14,6 @@ from tensorrt_llm.logger import logger
 from ..._utils import get_sm_version
 from ..attention_backend.trtllm import (AttentionBackend, TrtllmAttention,
                                         TrtllmAttentionMetadata)
-from ..cute_dsl_kernels.argmax import argmax as cute_argmax
 from ..flashinfer_utils import IS_FLASHINFER_AVAILABLE
 from ..pyexecutor.resource_manager import (BaseResourceManager,
                                            ResourceManagerType)
@@ -551,8 +550,7 @@ class SpecWorkerBase(nn.Module, ABC):
         Returns:
             draft_tokens: [num_tokens] - Sampled draft token ids (int32)
         """
-        # cute_argmax returns (M, 2) where col 0 = max value, col 1 = argmax index
-        draft_tokens = cute_argmax(logits)[:, 1].long()
+        draft_tokens = torch.argmax(logits, dim=-1)

         # Apply d2t (offsets between draft and target model dictionaries)
         if d2t is not None:
@@ -703,7 +701,6 @@ class SpecWorkerBase(nn.Module, ABC):
                 seed=self.seed,
                 offset=self.offset)
         else:
-            # cute_argmax returns (M, 2) where col 0 = max value, col 1 = argmax index
-            sampled_tokens = cute_argmax(logits)[:, 1].long()
+            sampled_tokens = torch.argmax(logits, dim=-1)

         return sampled_tokens
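Both hunks swap the cute DSL argmax kernel for PyTorch's built-in argmax without changing the selected tokens; only the kernel (and, per the commit message, the performance) differs. Below is a minimal sketch of that equivalence. It assumes `logits` is a 2-D `[num_tokens, vocab_size]` float tensor, and `reference_cute_argmax` is a hypothetical PyTorch stand-in for the removed kernel (the real kernel lives in `..cute_dsl_kernels.argmax` and is not shown in this diff):

import torch

def reference_cute_argmax(logits: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for the removed cute DSL kernel: per the deleted
    # comment, it returned an (M, 2) tensor with column 0 = row-wise max value
    # and column 1 = argmax index.
    values, indices = logits.max(dim=-1)
    return torch.stack([values, indices.to(logits.dtype)], dim=-1)

logits = torch.randn(4, 32000)  # [num_tokens, vocab_size], sizes are illustrative

# Old path (removed): take column 1 of the (M, 2) result and cast to int64.
old_tokens = reference_cute_argmax(logits)[:, 1].long()

# New path (this commit): torch.argmax returns int64 indices directly.
new_tokens = torch.argmax(logits, dim=-1)

assert torch.equal(old_tokens, new_tokens)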