[fix] improve fp4_block_scale_moe_runner type check (#5681)

Signed-off-by: JieXin Liang <Alcanderian@users.noreply.github.com>
Co-authored-by: ChristinaZ <83400082+ChristinaZ@users.noreply.github.com>
This commit is contained in:
JieXin Liang 2025-07-08 13:32:14 +08:00 committed by GitHub
parent 95978e3044
commit 664bf95892
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -39,9 +39,14 @@ std::vector<torch::Tensor> fp4_block_scale_moe_runner(torch::Tensor const& routi
{
auto const sm = tensorrt_llm::common::getSMVersion();
TORCH_CHECK(sm == 100, "Only SM100 is supported by FP4 block scale MOE");
TORCH_CHECK(routing_logits.scalar_type() == at::ScalarType::Float
|| routing_logits.scalar_type() == at::ScalarType::BFloat16,
"routing_logits must be float or bfloat16.");
if (static_cast<RoutingMethodType>(routing_method_type) == RoutingMethodType::DeepSeekV3)
{
TORCH_CHECK(routing_logits.scalar_type() == at::ScalarType::Float, "routing_logits must be float");
}
else
{
TORCH_CHECK(routing_logits.scalar_type() == at::ScalarType::BFloat16, "routing_logits must be bfloat16");
}
TORCH_CHECK(routing_logits.dim() == 2, "routing_logits must be 2D.");
TORCH_CHECK(routing_logits.sizes()[1] == num_experts, "routing_logits has incorrect shape.");
if (routing_bias.has_value())