mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-25 21:22:57 +08:00
fix more sm version check
Signed-off-by: Xiwen Yu <13230610+VALLIS-NERIA@users.noreply.github.com>
This commit is contained in:
parent
b7cc06cd6a
commit
fa8b52ed33
@ -34,7 +34,11 @@ TllmGenFmhaRunner::TllmGenFmhaRunner(Data_type dtypeQ, Data_type dtypeKv, Data_t
|
||||
, mDtypeKv(dtypeKv)
|
||||
, mDtypeOut(dtypeOut)
|
||||
{
|
||||
TLLM_CHECK_WITH_INFO(mSM == kSM_100, "Unsupported architecture");
|
||||
TLLM_CHECK_WITH_INFO(mSM == kSM_100 || mSM == kSM_103, "Unsupported architecture");
|
||||
if (mSM == kSM_103)
|
||||
{
|
||||
mSM = kSM_100; // Currently we use same kernel for SM100 and SM103
|
||||
}
|
||||
TLLM_CHECK_WITH_INFO(
|
||||
mDtypeQ == DATA_TYPE_E4M3 || mDtypeQ == DATA_TYPE_FP16 || mDtypeQ == DATA_TYPE_BF16, "Unsupported Q data type");
|
||||
TLLM_CHECK_WITH_INFO(mDtypeKv == DATA_TYPE_E4M3 || mDtypeKv == DATA_TYPE_FP16 || mDtypeKv == DATA_TYPE_BF16,
|
||||
|
||||
@ -39,7 +39,7 @@ at::Tensor run_fp8_block_scale_moe(at::Tensor const& routing_logits, std::option
|
||||
int64_t const routing_method_type, MoeRunnerType& moe_runner, int64_t moeConfigIndex)
|
||||
{
|
||||
auto const sm = tensorrt_llm::common::getSMVersion();
|
||||
TORCH_CHECK(sm == 100, "Only SM100 is supported by FP8 block scale MOE");
|
||||
TORCH_CHECK(sm == 100 || sm == 103, "Only SM100f is supported by FP8 block scale MOE");
|
||||
TORCH_CHECK(routing_logits.scalar_type() == at::ScalarType::Float, "routing_logits must be float.");
|
||||
TORCH_CHECK(routing_logits.dim() == 2, "routing_logits must be 2D.");
|
||||
TORCH_CHECK(routing_logits.sizes()[0] == hidden_states.sizes()[0],
|
||||
|
||||
@ -37,7 +37,7 @@ torch::Tensor fp8_per_tensor_scale_moe_runner(torch::Tensor const& routing_logit
|
||||
{
|
||||
|
||||
auto const sm = tensorrt_llm::common::getSMVersion();
|
||||
TORCH_CHECK(sm == 100, "Only SM100 is supported by FP8 block scale MOE");
|
||||
TORCH_CHECK(sm == 100 || sm == 103, "Only SM100f is supported by FP8 block scale MOE");
|
||||
if (use_routing_scales_on_input)
|
||||
{
|
||||
TORCH_CHECK(routing_logits.scalar_type() == at::ScalarType::BFloat16, "routing_logits must be bfloat16.");
|
||||
|
||||
Loading…
Reference in New Issue
Block a user