Fix more SM version checks

Signed-off-by: Xiwen Yu <13230610+VALLIS-NERIA@users.noreply.github.com>
This commit is contained in:
Xiwen Yu 2025-08-22 13:27:07 +08:00
parent b7cc06cd6a
commit fa8b52ed33
3 changed files with 7 additions and 3 deletions

View File

@ -34,7 +34,11 @@ TllmGenFmhaRunner::TllmGenFmhaRunner(Data_type dtypeQ, Data_type dtypeKv, Data_t
, mDtypeKv(dtypeKv)
, mDtypeOut(dtypeOut)
{
TLLM_CHECK_WITH_INFO(mSM == kSM_100, "Unsupported architecture");
TLLM_CHECK_WITH_INFO(mSM == kSM_100 || mSM == kSM_103, "Unsupported architecture");
if (mSM == kSM_103)
{
mSM = kSM_100; // Currently we use same kernel for SM100 and SM103
}
TLLM_CHECK_WITH_INFO(
mDtypeQ == DATA_TYPE_E4M3 || mDtypeQ == DATA_TYPE_FP16 || mDtypeQ == DATA_TYPE_BF16, "Unsupported Q data type");
TLLM_CHECK_WITH_INFO(mDtypeKv == DATA_TYPE_E4M3 || mDtypeKv == DATA_TYPE_FP16 || mDtypeKv == DATA_TYPE_BF16,

View File

@ -39,7 +39,7 @@ at::Tensor run_fp8_block_scale_moe(at::Tensor const& routing_logits, std::option
int64_t const routing_method_type, MoeRunnerType& moe_runner, int64_t moeConfigIndex)
{
auto const sm = tensorrt_llm::common::getSMVersion();
TORCH_CHECK(sm == 100, "Only SM100 is supported by FP8 block scale MOE");
TORCH_CHECK(sm == 100 || sm == 103, "Only SM100f is supported by FP8 block scale MOE");
TORCH_CHECK(routing_logits.scalar_type() == at::ScalarType::Float, "routing_logits must be float.");
TORCH_CHECK(routing_logits.dim() == 2, "routing_logits must be 2D.");
TORCH_CHECK(routing_logits.sizes()[0] == hidden_states.sizes()[0],

View File

@ -37,7 +37,7 @@ torch::Tensor fp8_per_tensor_scale_moe_runner(torch::Tensor const& routing_logit
{
auto const sm = tensorrt_llm::common::getSMVersion();
TORCH_CHECK(sm == 100, "Only SM100 is supported by FP8 block scale MOE");
TORCH_CHECK(sm == 100 || sm == 103, "Only SM100f is supported by FP8 block scale MOE");
if (use_routing_scales_on_input)
{
TORCH_CHECK(routing_logits.scalar_type() == at::ScalarType::BFloat16, "routing_logits must be bfloat16.");