Mirror of https://github.com/NVIDIA/TensorRT-LLM.git, synced 2026-01-13 22:18:36 +08:00
[fix] improve fp4_block_scale_moe_runner type check (#5681)
Signed-off-by: JieXin Liang <Alcanderian@users.noreply.github.com>
Co-authored-by: ChristinaZ <83400082+ChristinaZ@users.noreply.github.com>
parent 95978e3044
commit 664bf95892
@@ -39,9 +39,14 @@ std::vector<torch::Tensor> fp4_block_scale_moe_runner(torch::Tensor const& routing_logits,
 {
     auto const sm = tensorrt_llm::common::getSMVersion();
     TORCH_CHECK(sm == 100, "Only SM100 is supported by FP4 block scale MOE");
-    TORCH_CHECK(routing_logits.scalar_type() == at::ScalarType::Float
-            || routing_logits.scalar_type() == at::ScalarType::BFloat16,
-        "routing_logits must be float or bfloat16.");
+    if (static_cast<RoutingMethodType>(routing_method_type) == RoutingMethodType::DeepSeekV3)
+    {
+        TORCH_CHECK(routing_logits.scalar_type() == at::ScalarType::Float, "routing_logits must be float");
+    }
+    else
+    {
+        TORCH_CHECK(routing_logits.scalar_type() == at::ScalarType::BFloat16, "routing_logits must be bfloat16");
+    }
     TORCH_CHECK(routing_logits.dim() == 2, "routing_logits must be 2D.");
     TORCH_CHECK(routing_logits.sizes()[1] == num_experts, "routing_logits has incorrect shape.");
     if (routing_bias.has_value())
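Net effect of the hunk: the runner previously accepted either float or bfloat16 routing_logits for every routing method; it now ties the accepted dtype to the routing method, requiring float for DeepSeekV3 and bfloat16 for everything else. A minimal caller-side sketch of what the stricter checks imply; this is illustrative and not part of the commit, and make_routing_logits and is_deepseek_v3 are hypothetical names:

// Hypothetical sketch, not from the commit: builds routing_logits with the
// dtype the stricter TORCH_CHECKs above now require.
#include <torch/torch.h>

torch::Tensor make_routing_logits(int64_t num_tokens, int64_t num_experts, bool is_deepseek_v3)
{
    // DeepSeekV3 routing requires float32 logits; all other routing methods
    // require bfloat16, matching the two TORCH_CHECK branches in the diff.
    auto const dtype = is_deepseek_v3 ? torch::kFloat32 : torch::kBFloat16;
    return torch::randn({num_tokens, num_experts},
        torch::TensorOptions().dtype(dtype).device(torch::kCUDA));
}

Passing the wrong dtype (for example, bfloat16 logits with DeepSeekV3 routing) now fails fast at the TORCH_CHECK instead of reaching the kernel.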