diff --git a/tensorrt_llm/_torch/cute_dsl_kernels/argmax.py b/tensorrt_llm/_torch/cute_dsl_kernels/argmax.py
index 12719acd00..6c3a635e5a 100644
--- a/tensorrt_llm/_torch/cute_dsl_kernels/argmax.py
+++ b/tensorrt_llm/_torch/cute_dsl_kernels/argmax.py
@@ -219,7 +219,7 @@ if IS_CUTLASS_DSL_AVAILABLE:
 
     @cute.jit
     def warp_argmax_redux(current_max: Float32, current_argmax: Int32):
-        """Redux-based warp argmax - only works on sm_100+ (Blackwell)."""
+        """Redux-based warp argmax - only works on sm_100f (Blackwell)."""
         warp_max = ptx_redux_sync_max_f32(current_max)
         candidate_idx = ptx_select_argmax_candidate(current_max, warp_max, current_argmax)
         winning_idx = ptx_redux_sync_min_u32(candidate_idx)
@@ -324,7 +324,7 @@ if IS_CUTLASS_DSL_AVAILABLE:
     class ArgmaxKernel(ReductionBase):
         def __init__(self, dtype: Type[cutlass.Numeric], N: int, use_redux: bool = False):
             super().__init__(dtype, N, stage=1, reduction_dtype=cutlass.Float32)
-            # use_redux=True for Blackwell (sm_100+), False for Hopper (sm_90)
+            # use_redux=True for Blackwell (sm_100f), False for Hopper (sm_90)
             self.use_redux = use_redux
 
         def _calculate_threads_per_row(self):
@@ -582,6 +582,11 @@ if IS_CUTLASS_DSL_AVAILABLE:
             return True
         if N % _VOCAB_SIZE_ALIGNMENT != 0:
             return True
+        # Fall back on sm_120 - CUTLASS DSL JIT not well-supported for this setup
+        from ..._utils import get_sm_version
+
+        if get_sm_version() >= 120:
+            return True
         return False
 
     def argmax(x: torch.Tensor) -> torch.Tensor:
@@ -618,11 +623,12 @@ if IS_CUTLASS_DSL_AVAILABLE:
         out_tensor = convert_from_dlpack(out)
         current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
 
-        # Detect compute capability: use redux instructions only on Blackwell (sm_100+)
-        # redux.sync.max.f32 is only supported on sm_100+
+        # Detect compute capability: use redux instructions only on Blackwell (sm_100f)
+        # redux.sync.max.f32 is only supported on sm_100f
         from ..._utils import get_sm_version
 
-        use_redux = get_sm_version() >= 100  # sm_100+ (Blackwell)
+        sm_version = get_sm_version()
+        use_redux = sm_version >= 100 and sm_version < 120
 
         compile_key = (dtype, N, use_redux)
         if compile_key not in _argmax_compile_cache:
diff --git a/tests/integration/test_lists/waives.txt b/tests/integration/test_lists/waives.txt
index dc164d0063..6401c2ce0b 100644
--- a/tests/integration/test_lists/waives.txt
+++ b/tests/integration/test_lists/waives.txt
@@ -262,7 +262,6 @@ accuracy/test_llm_api_pytorch.py::TestLlama4MaverickInstruct::test_fp8_chunked_p
 accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype[mtp_nextn=2-overlap_scheduler=True] SKIP (https://nvbugs/5800591)
 accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[tp4-mtp_nextn=0-attention_dp=False-cuda_graph=True-overlap_scheduler=True-torch_compile=True] SKIP (https://nvbugs/5800646)
 full:RTXPro6000D/accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=0-ep4-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] SKIP (https://nvbugs/5800672)
-full:RTXPro6000D/accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3_4gpus[cutlass-one_model-overlap_scheduler] SKIP (https://nvbugs/5800679)
 accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3_4gpus[cutlass-one_model-overlap_scheduler] SKIP (https://nvbugs/5853997)
 examples/test_medusa.py::test_llm_medusa_with_qaunt_base_model_1gpu[fp8-use_cpp_session-medusa-vicuna-7b-v1.3-4-heads-float16-bs1] SKIP (https://nvbugs/5802248)
 unittest/_torch/modeling/test_modeling_llama.py::TestLlama::test_llama_verification_with_kv_cache_relocation SKIP (https://nvbugs/5804923)
@@ -345,18 +344,6 @@ unittest/_torch/modules/test_fused_moe.py::test_fused_moe_alltoall_fp4[DeepEP] S
 unittest/_torch/modeling/test_modeling_nemotron_h.py::test_nemotron_h_cuda_graph_overlap_scheduler SKIP (https://nvbugs/5843316)
 examples/test_mistral.py::test_mistral_with_bf16_lora_torch[mistral-7b-v0.1] SKIP (https://nvbugs/5846178)
 accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_fp8_blockscale[latency] SKIP (https://nvbugs/5846024)
-full:RTXPro6000D/accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[pp4-mtp_nextn=2-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False] SKIP (https://nvbugs/5846154)
-full:RTXPro6000D/accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[tp2pp2-mtp_nextn=2-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False] SKIP (https://nvbugs/5846154)
-full:RTXPro6000D/accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=2-pp4-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] SKIP (https://nvbugs/5846154)
-full:RTXPro6000D/accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[pp4-mtp_nextn=2-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=True] SKIP (https://nvbugs/5846154)
-full:RTXPro6000D/accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[ep4-mtp_nextn=2-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=True] SKIP (https://nvbugs/5846154)
-full:RTXPro6000D/accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[tp2pp2-mtp_nextn=2-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=True] SKIP (https://nvbugs/5846154)
-full:RTXPro6000D/accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[tp2pp2-mtp_nextn=2-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=True] SKIP (https://nvbugs/5846154)
-full:RTXPro6000D/accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=2-pp4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False] SKIP (https://nvbugs/5846154)
-full:RTXPro6000D/accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[tp4-mtp_nextn=2-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=True] SKIP (https://nvbugs/5846154)
-full:RTXPro6000D/accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=2-ep4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False] SKIP (https://nvbugs/5846154)
-full:RTXPro6000D/accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=2-tp2pp2-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False] SKIP (https://nvbugs/5846154)
-full:RTXPro6000D/accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=2-tp4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False] SKIP (https://nvbugs/5846154)
 accuracy/test_disaggregated_serving.py::TestDeepSeekV32Exp::test_auto_dtype[False] SKIP (https://nvbugs/5847284)
 accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v1_kv_cache-dp4-trtllm-fp8] SKIP (https://nvbugs/5850183)
 accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v2_kv_cache-dp4-trtllm-fp8] SKIP (https://nvbugs/5850183)
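
Reviewer note on the scheme the first hunk documents: `warp_argmax_redux` resolves an argmax in two warp-wide reductions with a deterministic tie-break. The sketch below is a minimal host-side model of that scheme in NumPy, not the actual CuTe DSL code; the 32-lane warp and the `redux.sync.max.f32` / `redux.sync.min.u32` reductions (which the real kernel emits via `ptx_redux_sync_max_f32` and `ptx_redux_sync_min_u32`) are modeled by plain ndarray reductions.

```python
# Minimal host-side model (NumPy) of the tie-break scheme in warp_argmax_redux.
# Assumption: each array element stands in for one lane of a 32-thread warp.
import numpy as np

UINT32_MAX = np.uint32(0xFFFFFFFF)

def warp_argmax_model(lane_values: np.ndarray):
    lane_ids = np.arange(lane_values.size, dtype=np.uint32)
    # Step 1: warp-wide max (redux.sync.max.f32 in the kernel).
    warp_max = lane_values.max()
    # Step 2: lanes holding the max propose their lane index; all other lanes
    # propose UINT32_MAX so they cannot win (ptx_select_argmax_candidate).
    candidates = np.where(lane_values == warp_max, lane_ids, UINT32_MAX)
    # Step 3: warp-wide min over candidates (redux.sync.min.u32), so ties
    # break deterministically toward the lowest lane index.
    return warp_max, int(candidates.min())

vals = np.random.rand(32).astype(np.float32)
vals[[5, 17]] = 2.0  # force a tie between lanes 5 and 17
warp_max, winner = warp_argmax_model(vals)
assert warp_max == np.float32(2.0) and winner == 5  # lowest tied lane wins
```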
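The remaining hunks change the dispatch around that kernel: sm_90 (Hopper) keeps the non-redux path, the sm_100f Blackwell family takes the redux path, and sm_120 (the RTXPro6000D class of parts that the waives.txt entries cover) now bypasses the CUTLASS DSL kernel entirely. A condensed sketch of the resulting decision tree, where `plan_argmax` is a hypothetical helper introduced here for illustration; `get_sm_version` and the `(dtype, N, use_redux)` compile-cache key are taken from the diff:

```python
# Condensed model of the dispatch logic this change introduces. plan_argmax is
# a hypothetical helper; the real code spreads this across the fallback check
# and the argmax() entry point in argmax.py.
def plan_argmax(sm_version: int, dtype: str, N: int):
    if sm_version >= 120:
        # New fallback hunk: sm_120 skips the CUTLASS DSL JIT path entirely.
        return ("torch_fallback", None)
    # Redux instructions only on the Blackwell sm_100f family, per the diff.
    use_redux = 100 <= sm_version < 120
    # Mirrors the kernel's compile cache key: (dtype, N, use_redux).
    return ("cute_dsl", (dtype, N, use_redux))

assert plan_argmax(90, "bf16", 131072) == ("cute_dsl", ("bf16", 131072, False))  # Hopper
assert plan_argmax(100, "bf16", 131072) == ("cute_dsl", ("bf16", 131072, True))  # sm_100: redux
assert plan_argmax(120, "bf16", 131072) == ("torch_fallback", None)              # sm_120: fallback
```

Because `use_redux` is part of the compile key, a Hopper-compiled kernel and a Blackwell-compiled kernel for the same `(dtype, N)` cache as distinct entries, so the gating never hands a redux binary to a device that lacks the instruction.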