Mirror of https://github.com/NVIDIA/TensorRT-LLM.git
[https://nvbugs/5854860][fix] Fix cutedsl argmax on sm120 (#11181)
Signed-off-by: Dongfeng Yu <dongfengy@nvidia.com>
parent ad2d1df4a9
commit 0bd4630cd1
@@ -219,7 +219,7 @@ if IS_CUTLASS_DSL_AVAILABLE:

     @cute.jit
     def warp_argmax_redux(current_max: Float32, current_argmax: Int32):
-        """Redux-based warp argmax - only works on sm_100+ (Blackwell)."""
+        """Redux-based warp argmax - only works on sm_100f (Blackwell)."""
         warp_max = ptx_redux_sync_max_f32(current_max)
         candidate_idx = ptx_select_argmax_candidate(current_max, warp_max, current_argmax)
         winning_idx = ptx_redux_sync_min_u32(candidate_idx)
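Why "sm_100+" was the wrong wording here: per the PTX ISA, redux.sync.max.f32 and redux.sync.min.u32 are family-specific instructions of the sm_100 family (targets sm_100a/sm_100f), and sm_120 parts such as the RTX Pro 6000 Blackwell do not implement them even though 120 > 100 numerically, so a plain ">= 100" check selects the redux path on hardware that lacks it. A minimal host-side sketch (mine, not the patch's; NumPy lanes stand in for one 32-thread warp) of what the three redux steps compute:

import numpy as np

def warp_argmax_emulated(values: np.ndarray) -> tuple:
    """Emulate warp_argmax_redux for one 32-lane warp on the host."""
    assert values.shape == (32,)
    # Step 1 (redux.sync.max.f32): every lane learns the warp-wide max.
    warp_max = values.max()
    # Step 2 (ptx_select_argmax_candidate): lanes holding the max propose
    # their lane index; all other lanes propose UINT32_MAX.
    candidates = np.where(values == warp_max,
                          np.arange(32, dtype=np.uint32),
                          np.uint32(0xFFFFFFFF))
    # Step 3 (redux.sync.min.u32): the smallest proposed index wins, so
    # ties break toward the lowest index, matching torch.argmax.
    return float(warp_max), int(candidates.min())

vals = np.random.rand(32).astype(np.float32)
assert warp_argmax_emulated(vals)[1] == int(vals.argmax())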
@@ -324,7 +324,7 @@ if IS_CUTLASS_DSL_AVAILABLE:

     class ArgmaxKernel(ReductionBase):
         def __init__(self, dtype: Type[cutlass.Numeric], N: int, use_redux: bool = False):
             super().__init__(dtype, N, stage=1, reduction_dtype=cutlass.Float32)
-            # use_redux=True for Blackwell (sm_100+), False for Hopper (sm_90)
+            # use_redux=True for Blackwell (sm_100f), False for Hopper (sm_90)
             self.use_redux = use_redux

         def _calculate_threads_per_row(self):
@@ -582,6 +582,11 @@ if IS_CUTLASS_DSL_AVAILABLE:
             return True
         if N % _VOCAB_SIZE_ALIGNMENT != 0:
             return True
+        # Fall back on sm_120 - CUTLASS DSL JIT not well-supported for this setup
+        from ..._utils import get_sm_version
+
+        if get_sm_version() >= 120:
+            return True
         return False

     def argmax(x: torch.Tensor) -> torch.Tensor:
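The hunk above extends the predicate that routes argmax to a plain torch implementation whenever the CuTe DSL kernel should not be used (the enclosing function's name is cropped from this view). A self-contained sketch of the same guard pattern, with illustrative names and an assumed alignment constant:

import torch

_VOCAB_SIZE_ALIGNMENT = 8  # hypothetical; the real constant lives in the module

def _get_sm_version() -> int:
    # Compute capability packed as an int: (9,0)->90, (10,0)->100, (12,0)->120.
    major, minor = torch.cuda.get_device_capability()
    return major * 10 + minor

def _should_use_torch_fallback(x: torch.Tensor) -> bool:
    n = x.shape[-1]
    if n % _VOCAB_SIZE_ALIGNMENT != 0:
        return True
    # New in this fix: sm_120 (e.g. RTX Pro 6000 Blackwell) always falls
    # back, since the CUTLASS DSL JIT is not well-supported there.
    if _get_sm_version() >= 120:
        return True
    return False

def argmax(x: torch.Tensor) -> torch.Tensor:
    if _should_use_torch_fallback(x):
        return torch.argmax(x, dim=-1)
    raise NotImplementedError("dispatch to the compiled CuTe DSL kernel here")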
@@ -618,11 +623,12 @@ if IS_CUTLASS_DSL_AVAILABLE:
         out_tensor = convert_from_dlpack(out)
         current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)

-        # Detect compute capability: use redux instructions only on Blackwell (sm_100+)
-        # redux.sync.max.f32 is only supported on sm_100+
+        # Detect compute capability: use redux instructions only on Blackwell (sm_100f)
+        # redux.sync.max.f32 is only supported on sm_100f
         from ..._utils import get_sm_version

-        use_redux = get_sm_version() >= 100  # sm_100+ (Blackwell)
+        sm_version = get_sm_version()
+        use_redux = sm_version >= 100 and sm_version < 120

         compile_key = (dtype, N, use_redux)
         if compile_key not in _argmax_compile_cache:
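Because use_redux is part of compile_key, the redux (sm_100 family) and shuffle (sm_90, sm_120) kernel variants never collide in the JIT cache. A sketch of that caching shape; _compile_argmax_kernel and its placeholder body are hypothetical stand-ins, and only the ArgmaxKernel(dtype, N, use_redux=...) signature comes from the patch:

from typing import Callable, Dict, Tuple

import torch

_argmax_compile_cache: Dict[Tuple[type, int, bool], Callable] = {}

def _compile_argmax_kernel(dtype: type, N: int, use_redux: bool) -> Callable:
    # Hypothetical stand-in for JIT-compiling
    # ArgmaxKernel(dtype, N, use_redux=use_redux); returns a torch
    # fallback so the sketch runs as-is.
    return lambda x: torch.argmax(x, dim=-1)

def get_compiled_argmax(dtype: type, N: int, sm_version: int) -> Callable:
    # Only the sm_100 family takes the redux path; Hopper (sm_90) and
    # sm_120 Blackwell parts use the shuffle-based reduction instead.
    use_redux = 100 <= sm_version < 120
    key = (dtype, N, use_redux)
    if key not in _argmax_compile_cache:
        _argmax_compile_cache[key] = _compile_argmax_kernel(dtype, N, use_redux)
    return _argmax_compile_cache[key]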
The remaining two hunks update the test waive list (the file path is not shown in this view). First, the full:RTXPro6000D-scoped waive for the GPTOSS eagle3 test is dropped in favor of an unscoped entry under a new bug:

@@ -262,7 +262,6 @@ accuracy/test_llm_api_pytorch.py::TestLlama4MaverickInstruct::test_fp8_chunked_p
 accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype[mtp_nextn=2-overlap_scheduler=True] SKIP (https://nvbugs/5800591)
 accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[tp4-mtp_nextn=0-attention_dp=False-cuda_graph=True-overlap_scheduler=True-torch_compile=True] SKIP (https://nvbugs/5800646)
 full:RTXPro6000D/accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=0-ep4-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] SKIP (https://nvbugs/5800672)
-full:RTXPro6000D/accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3_4gpus[cutlass-one_model-overlap_scheduler] SKIP (https://nvbugs/5800679)
+accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3_4gpus[cutlass-one_model-overlap_scheduler] SKIP (https://nvbugs/5853997)
 examples/test_medusa.py::test_llm_medusa_with_qaunt_base_model_1gpu[fp8-use_cpp_session-medusa-vicuna-7b-v1.3-4-heads-float16-bs1] SKIP (https://nvbugs/5802248)
 unittest/_torch/modeling/test_modeling_llama.py::TestLlama::test_llama_verification_with_kv_cache_relocation SKIP (https://nvbugs/5804923)
Second, all twelve RTXPro6000D (sm_120) DeepSeekV3Lite waives tracked under nvbugs/5846154 are removed:

@@ -345,18 +344,6 @@ unittest/_torch/modules/test_fused_moe.py::test_fused_moe_alltoall_fp4[DeepEP] S
 unittest/_torch/modeling/test_modeling_nemotron_h.py::test_nemotron_h_cuda_graph_overlap_scheduler SKIP (https://nvbugs/5843316)
 examples/test_mistral.py::test_mistral_with_bf16_lora_torch[mistral-7b-v0.1] SKIP (https://nvbugs/5846178)
 accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_fp8_blockscale[latency] SKIP (https://nvbugs/5846024)
-full:RTXPro6000D/accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[pp4-mtp_nextn=2-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False] SKIP (https://nvbugs/5846154)
-full:RTXPro6000D/accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[tp2pp2-mtp_nextn=2-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False] SKIP (https://nvbugs/5846154)
-full:RTXPro6000D/accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=2-pp4-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] SKIP (https://nvbugs/5846154)
-full:RTXPro6000D/accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[pp4-mtp_nextn=2-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=True] SKIP (https://nvbugs/5846154)
-full:RTXPro6000D/accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[ep4-mtp_nextn=2-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=True] SKIP (https://nvbugs/5846154)
-full:RTXPro6000D/accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[tp2pp2-mtp_nextn=2-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=True] SKIP (https://nvbugs/5846154)
-full:RTXPro6000D/accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[tp2pp2-mtp_nextn=2-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=True] SKIP (https://nvbugs/5846154)
-full:RTXPro6000D/accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=2-pp4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False] SKIP (https://nvbugs/5846154)
-full:RTXPro6000D/accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[tp4-mtp_nextn=2-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=True] SKIP (https://nvbugs/5846154)
-full:RTXPro6000D/accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=2-ep4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False] SKIP (https://nvbugs/5846154)
-full:RTXPro6000D/accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=2-tp2pp2-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False] SKIP (https://nvbugs/5846154)
-full:RTXPro6000D/accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=2-tp4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False] SKIP (https://nvbugs/5846154)
 accuracy/test_disaggregated_serving.py::TestDeepSeekV32Exp::test_auto_dtype[False] SKIP (https://nvbugs/5847284)
 accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v1_kv_cache-dp4-trtllm-fp8] SKIP (https://nvbugs/5850183)
 accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v2_kv_cache-dp4-trtllm-fp8] SKIP (https://nvbugs/5850183)