Mirror of https://github.com/NVIDIA/TensorRT-LLM.git, synced 2026-01-13 22:18:36 +08:00
[https://nvbugs/5772396][fix]: WAR: Disable TinyGEMM PDL due to accuracy issue
Signed-off-by: Dongfeng Yu <dongfengy@nvidia.com>
commit cb6e93ed19
parent 6df2c8a074
@@ -178,7 +178,7 @@ void invokeFP4Quantization(int b, int m, int n, T const* input, float const* SFS
     config.stream = stream;
     cudaLaunchAttribute attrs[1];
     attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
-    attrs[0].val.programmaticStreamSerializationAllowed = false;
+    attrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL();
     config.numAttrs = 1;
     config.attrs = attrs;
     cudaLaunchKernelEx(&config, kernel_instance, b, m, n, n, input, SFScale, reinterpret_cast<uint32_t*>(output),
@@ -213,7 +213,7 @@ void invokeMxFP8Quantization(int b, int m, int n, int padded_n, T const* input,
     config.stream = stream;
     cudaLaunchAttribute attrs[1];
     attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
-    attrs[0].val.programmaticStreamSerializationAllowed = false;
+    attrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL();
     config.numAttrs = 1;
     config.attrs = attrs;
     cudaLaunchKernelEx(&config,
@@ -388,7 +388,7 @@ void computePerTokenGlobalScaleForFP4Quantization(int b, int m, int n, T const*
     config.stream = stream;
     cudaLaunchAttribute attrs[1];
     attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
-    attrs[0].val.programmaticStreamSerializationAllowed = false;
+    attrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL();
     config.numAttrs = 1;
     config.attrs = attrs;
     TLLM_CUDA_CHECK(cudaLaunchKernelEx(
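Background on the launch-site pattern these three hunks touch: cudaLaunchKernelEx takes a cudaLaunchConfig_t whose attribute array can opt a launch into programmatic dependent launch (PDL), a Hopper-generation (sm_90) feature that lets a kernel begin before the previous kernel on the same stream has fully retired. The sketch below is a minimal, self-contained illustration of that opt-in, not the repository's code; the kernel, the launch shape, and the envEnablePDL() helper (a stand-in for tensorrt_llm::common::getEnvEnablePDL(), whose TRTLLM_ENABLE_PDL variable name is an assumption here) are all illustrative.

// Minimal sketch of an env-gated PDL launch (assumptions noted above).
#include <cuda_runtime.h>
#include <cstdlib>

__global__ void dependentKernel(float* out, int n)
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n)
        out[i] = static_cast<float>(i);
}

// Stand-in for tensorrt_llm::common::getEnvEnablePDL(); the variable
// name TRTLLM_ENABLE_PDL is assumed, not taken from the diff.
static bool envEnablePDL()
{
    char const* v = std::getenv("TRTLLM_ENABLE_PDL");
    return v != nullptr && v[0] == '1';
}

void launchWithPDL(float* out, int n, cudaStream_t stream)
{
    cudaLaunchConfig_t config = {};
    config.gridDim = dim3((n + 255) / 256);
    config.blockDim = dim3(256);
    config.dynamicSmemBytes = 0;
    config.stream = stream;

    // Opt-in flag: true lets this launch overlap with the preceding
    // kernel on the stream; false falls back to ordinary stream order.
    cudaLaunchAttribute attrs[1];
    attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
    attrs[0].val.programmaticStreamSerializationAllowed = envEnablePDL();
    config.numAttrs = 1;
    config.attrs = attrs;

    cudaLaunchKernelEx(&config, dependentKernel, out, n);
}

Hardcoding the flag to false, as the removed lines did, is the one-line way to switch PDL off at a call site without touching the kernel itself.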
@@ -236,7 +236,6 @@ __global__ __launch_bounds__(384, 1) void tinygemm_kernel(__nv_bfloat16* output,
     if (!weight_warp)
     {
         cudaGridDependencySynchronize();
-        cudaTriggerProgrammaticLaunchCompletion();
     }

     for (int ki = 0; ki < K_LOOPS_DMA; ki++)
@@ -441,6 +440,9 @@ __global__ __launch_bounds__(384, 1) void tinygemm_kernel(__nv_bfloat16* output,

     if (PROFILE && blockIdx.y == 0 && threadIdx.x == 0)
         profile[blockIdx.x].complete = gclock64();
+
+    if (threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0)
+        cudaTriggerProgrammaticLaunchCompletion();
 }
 }
 #endif // end if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
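The device side of PDL is what the WAR itself changes. A kernel launched with the attribute above may start early, so before it reads its predecessor's output it must call cudaGridDependencySynchronize(); conversely, cudaTriggerProgrammaticLaunchCompletion() tells the runtime that a dependent grid may begin, and the completion event fires once every block has either made that call or exited. Moving the trigger out of the early !weight_warp path and into a single thread at the very end of tinygemm_kernel therefore delays any dependent launch until the kernel is essentially done, which is how this commit disables TinyGEMM's PDL overlap. A minimal sketch of the two primitives, with a placeholder kernel body and hypothetical names:

// Placeholder consumer kernel; only the placement of the two PDL
// primitives reflects the pattern discussed above.
#include <cuda_runtime.h>

__global__ void pdlConsumer(float const* produced, float* out, int n)
{
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
    // Block until all blocks of the preceding grid have finished, so
    // reading `produced` is safe even if this grid started early.
    cudaGridDependencySynchronize();
#endif

    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n)
        out[i] = produced[i] * 2.0f;

#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
    // Triggering here, at the very end and from one thread, means all
    // other blocks satisfy the completion condition only by exiting;
    // a dependent grid therefore waits for nearly the whole kernel.
    // Triggering early (as the removed tinygemm line did) is what
    // enables overlap, and is what the WAR takes away.
    if (blockIdx.x == 0 && threadIdx.x == 0)
        cudaTriggerProgrammaticLaunchCompletion();
#endif
}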
@@ -228,7 +228,6 @@ unittest/_torch/modules/test_fused_moe.py::test_fused_moe_alltoall[DeepEP] SKIP
 unittest/_torch/modules/test_fused_moe.py::test_fused_moe_alltoall[MNNVL] SKIP (https://nvbugs/5664904)
 test_e2e.py::test_ptp_quickstart_advanced[Nemotron-Super-49B-v1-FP8-nemotron-nas/Llama-3_3-Nemotron-Super-49B-v1-FP8] SKIP (https://nvbugs/5670469)
 test_e2e.py::test_ptp_quickstart_advanced[Nemotron-Super-49B-v1-NVFP4-nvfp4-quantized/Llama-3_3-Nemotron-Super-49B-v1_nvfp4_hf] SKIP (https://nvbugs/5670469)
 accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[dp4-cutlass-auto] SKIP (https://nvbugs/5673610)
-accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_1gpu[True-True-cutlass-auto] SKIP (https://nvbugs/5756804)
 examples/test_llama.py::test_llm_llama_1gpu_fp4[llama-3.1-70b-instruct-enable_norm_quant_fusion-enable_fused_quant-fp4_plugin-bfloat16] SKIP (https://nvbugs/5451216)
 accuracy/test_llm_api_pytorch.py::TestNemotronNas::test_auto_dtype_tp8 SKIP (https://nvbugs/5673527)
@@ -312,14 +311,11 @@ full:sm89/accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_
 accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ngram SKIP (https://nvbugs/5768068)
 test_e2e.py::test_eagle3_output_consistency_4gpus[Qwen3/saved_models_Qwen3-235B-A22B_fp8_hf-Qwen3/qwen3-235B-eagle3] SKIP (https://nvbugs/5685010)
 examples/test_mistral.py::test_mistral_with_bf16_lora_torch[mistral-7b-v0.1] SKIP (https://nvbugs/5769855)
-accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[dp4-trtllm-fp8] SKIP (https://nvbugs/5772396)
-full:sm100/accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[dp4-trtllm-auto] SKIP (https://nvbugs/5772396)
 accuracy/test_llm_api_pytorch.py::TestGLM4_6::test_nvfp4_2_model_mtp[2model_trtllm] SKIP (https://nvbugs/5772360)
 accuracy/test_llm_api_pytorch.py::TestGLM4_6::test_nvfp4_2_model_mtp[2model] SKIP (https://nvbugs/5772993)
 accuracy/test_llm_api.py::TestLlama3_1_8BInstruct::test_gather_generation_logits_cuda_graph SKIP (https://nvbugs/5772995)
 test_e2e.py::test_eagle3_output_consistency_4gpus[Qwen3/Qwen3-30B-A3B-Qwen3/Qwen3-30B-eagle3] SKIP (https://nvbugs/5685010)
 full:sm89/accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_tp_pp_symmetric[GSM8K-tp2pp2] SKIP (https://nvbugs/5773047)
-accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[enable_configurable_moe-dp4-trtllm-fp8] SKIP (https://nvbugs/5773201)
 unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py::test_sharding[GQA_Block-torch_dist_all_reduce-True-False-2] SKIP (https://nvbugs/5766982)
 accuracy/test_cli_flow.py::TestVicuna7B::test_eagle_2[cuda_graph=True-chunked_context=True] SKIP (https://nvbugs/5773185)
 accuracy/test_cli_flow.py::TestVicuna7B::test_eagle_2[cuda_graph=True-chunked_context=False] SKIP (https://nvbugs/5773185)