diff --git a/cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/fmhaRunner.cpp b/cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/fmhaRunner.cpp
index 9ff85d9d7c..7121d0fcf5 100644
--- a/cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/fmhaRunner.cpp
+++ b/cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/fmhaRunner.cpp
@@ -34,7 +34,11 @@ TllmGenFmhaRunner::TllmGenFmhaRunner(Data_type dtypeQ, Data_type dtypeKv, Data_t
     , mDtypeKv(dtypeKv)
     , mDtypeOut(dtypeOut)
 {
-    TLLM_CHECK_WITH_INFO(mSM == kSM_100, "Unsupported architecture");
+    TLLM_CHECK_WITH_INFO(mSM == kSM_100 || mSM == kSM_103, "Unsupported architecture");
+    if (mSM == kSM_103)
+    {
+        mSM = kSM_100; // Currently we use the same kernel for SM100 and SM103
+    }
     TLLM_CHECK_WITH_INFO(
         mDtypeQ == DATA_TYPE_E4M3 || mDtypeQ == DATA_TYPE_FP16 || mDtypeQ == DATA_TYPE_BF16, "Unsupported Q data type");
     TLLM_CHECK_WITH_INFO(mDtypeKv == DATA_TYPE_E4M3 || mDtypeKv == DATA_TYPE_FP16 || mDtypeKv == DATA_TYPE_BF16,
diff --git a/cpp/tensorrt_llm/thop/fp8BlockScaleMoe.cpp b/cpp/tensorrt_llm/thop/fp8BlockScaleMoe.cpp
index f48a40620f..07d7c4924e 100644
--- a/cpp/tensorrt_llm/thop/fp8BlockScaleMoe.cpp
+++ b/cpp/tensorrt_llm/thop/fp8BlockScaleMoe.cpp
@@ -39,7 +39,7 @@ at::Tensor run_fp8_block_scale_moe(at::Tensor const& routing_logits, std::option
     int64_t const routing_method_type, MoeRunnerType& moe_runner, int64_t moeConfigIndex)
 {
     auto const sm = tensorrt_llm::common::getSMVersion();
-    TORCH_CHECK(sm == 100, "Only SM100 is supported by FP8 block scale MOE");
+    TORCH_CHECK(sm == 100 || sm == 103, "Only SM100f is supported by FP8 block scale MOE");
     TORCH_CHECK(routing_logits.scalar_type() == at::ScalarType::Float, "routing_logits must be float.");
     TORCH_CHECK(routing_logits.dim() == 2, "routing_logits must be 2D.");
     TORCH_CHECK(routing_logits.sizes()[0] == hidden_states.sizes()[0],
diff --git a/cpp/tensorrt_llm/thop/fp8PerTensorScaleMoe.cpp b/cpp/tensorrt_llm/thop/fp8PerTensorScaleMoe.cpp
index b76701f788..eabf9a6f97 100644
--- a/cpp/tensorrt_llm/thop/fp8PerTensorScaleMoe.cpp
+++ b/cpp/tensorrt_llm/thop/fp8PerTensorScaleMoe.cpp
@@ -37,7 +37,7 @@ torch::Tensor fp8_per_tensor_scale_moe_runner(torch::Tensor const& routing_logit
 {
     auto const sm = tensorrt_llm::common::getSMVersion();
 
-    TORCH_CHECK(sm == 100, "Only SM100 is supported by FP8 block scale MOE");
+    TORCH_CHECK(sm == 100 || sm == 103, "Only SM100f is supported by FP8 per-tensor scale MOE");
     if (use_routing_scales_on_input)
     {
         TORCH_CHECK(routing_logits.scalar_type() == at::ScalarType::BFloat16, "routing_logits must be bfloat16.");
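
All three hunks repeat the same idea: SM100 and SM103 share kernels, so both architectures are accepted and SM103 falls back to the SM100 code paths (the "SM100f" family named in the updated messages). Below is a minimal standalone sketch of that pattern; the helper names isSM100Family and normalizeSM are hypothetical and do not exist in this patch or in TensorRT-LLM.

// Sketch only, not part of the patch: a hypothetical shared helper that the
// three call sites above could use instead of repeating `sm == 100 || sm == 103`.
#include <cstdio>

// SM100 and SM103 currently run the same kernels, so the checks above treat
// them as one family ("SM100f" in the updated error messages).
constexpr bool isSM100Family(int sm)
{
    return sm == 100 || sm == 103;
}

// SM103 is normalized to SM100 before kernel selection, mirroring the
// remapping added to TllmGenFmhaRunner's constructor.
constexpr int normalizeSM(int sm)
{
    return isSM100Family(sm) ? 100 : sm;
}

int main()
{
    for (int sm : {90, 100, 103})
    {
        std::printf("sm=%d family=%d normalized=%d\n", sm, isSM100Family(sm) ? 1 : 0, normalizeSM(sm));
    }
    return 0;
}

Centralizing the family check this way would keep the FMHA runner and both MoE entry points in sync if the family ever grows, at the cost of one extra indirection at each call site.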