mirror of https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
[https://nvbugs/5565565] [fix] fp8 wideep support sm103 (#8228)
Signed-off-by: Xiwen Yu <13230610+VALLIS-NERIA@users.noreply.github.com>
parent 4bac6b337e
commit d5b79268e7
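The change itself is mechanical: each hard-coded get_sm_version() == 100 check is replaced by the shared is_sm_100f() helper, so the FP8 block-scale / DeepGemm paths cover SM 103 as well as SM 100. The helper is imported from tensorrt_llm._utils in the hunks below, but its definition is not part of this diff; the following is a minimal sketch of its assumed shape, mirroring the defs.conftest.is_sm_100f added at the end of this commit (not the verified library source):

# Sketch only: assumed shape of the tensorrt_llm._utils helpers used below,
# mirroring the defs.conftest.is_sm_100f added in this commit.
import torch


def get_sm_version() -> int:
    # Compute capability of the current device as major * 10 + minor,
    # e.g. 100 for SM 100 and 103 for SM 103.
    major, minor = torch.cuda.get_device_capability()
    return major * 10 + minor


def is_sm_100f(sm_version=None) -> bool:
    # True for the SM 100 family: SM 100 and SM 103.
    if sm_version is None:
        sm_version = get_sm_version()
    return sm_version in (100, 103)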
@@ -5,7 +5,7 @@ from typing import Dict, List, Optional, Tuple, Union
 import torch

 from tensorrt_llm._mnnvl_utils import MnnvlMemory, MnnvlMoe, MoEAlltoallInfo
-from tensorrt_llm._utils import get_sm_version
+from tensorrt_llm._utils import is_sm_100f
 from tensorrt_llm.functional import AllReduceStrategy
 from tensorrt_llm.logger import logger
 from tensorrt_llm.mapping import Mapping
@@ -351,7 +351,7 @@ class WideEPMoE(MoE):
         if self.quant_config.layer_quant_mode.has_fp8_qdq():
             return FP8QDQFusedMoEMethod()
         elif self.quant_config.layer_quant_mode.has_fp8_block_scales():
-            if get_sm_version() == 100:
+            if is_sm_100f():
                 return DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm()
             else:
                 return DeepSeekFP8BlockScalesFusedMoEMethod()
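Sketched effect of the changed check above on an SM 103 device, using only names from this hunk (illustrative, not code from the repository):

# Illustrative only: quant-method selection for FP8 block scales on SM 103.
sm_version = 103

old_selects_deepgemm = sm_version == 100         # False -> DeepSeekFP8BlockScalesFusedMoEMethod
new_selects_deepgemm = sm_version in (100, 103)  # True  -> DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm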
@@ -18,7 +18,7 @@ from typing import TYPE_CHECKING, List, Optional

 import torch

-from tensorrt_llm._utils import get_sm_version
+from tensorrt_llm._utils import is_sm_100f

 if TYPE_CHECKING:
     from ..interface import MoE
@@ -224,7 +224,7 @@ class MoEOpSelector:

         # Check if we should use DeepGemm op
         # Blackwell has SM version 100
-        is_blackwell = get_sm_version() == 100
+        is_blackwell = is_sm_100f()
         has_block_fp8 = module.has_deepseek_fp8_block_scales

         if is_blackwell and has_block_fp8:
@@ -16,7 +16,7 @@ import os

 import pytest
 import torch
-from defs.conftest import get_sm_version
+from defs.conftest import get_sm_version, is_sm_100f

 from tensorrt_llm import LLM
 from tensorrt_llm._torch.modules.fused_moe.fused_moe_triton import \
@@ -2168,7 +2168,7 @@ class TestDeepSeekR1(LlmapiAccuracyTestHarness):
     def test_fp8_blockscale(self, tp_size, pp_size, ep_size, mtp_nextn, fp8kv,
                             attention_dp, cuda_graph, overlap_scheduler,
                             max_batch_size, moe_backend):
-        if get_sm_version() == 100 or get_sm_version() == 103:
+        if is_sm_100f():
             moe_backend = "DEEPGEMM" if moe_backend == "_DEFAULT" else moe_backend
             moe_config = MoeConfig(backend=moe_backend, max_num_tokens=16384)
             kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.6)
@@ -2217,7 +2217,7 @@ class TestDeepSeekR1(LlmapiAccuracyTestHarness):
                                  mtp_nextn, fp8kv, attention_dp,
                                  cuda_graph, overlap_scheduler,
                                  max_batch_size):
-        if get_sm_version() == 100:
+        if is_sm_100f():
             moe_config = MoeConfig(backend="DEEPGEMM", max_num_tokens=16384)
             kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.6)
         else:
@@ -1871,6 +1871,12 @@ def get_sm_version():
     return prop.major * 10 + prop.minor


+def is_sm_100f(sm_version=None):
+    if sm_version is None:
+        sm_version = get_sm_version()
+    return sm_version == 100 or sm_version == 103
+
+
 def get_gpu_device_list():
     "get device list"
     with tempfile.TemporaryDirectory() as temp_dirname:
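For reference, a hypothetical use of the new conftest helper to gate a test on the SM 100 family, in the same spirit as the accuracy tests above (test name and assertion are illustrative, not part of the commit):

import pytest

from defs.conftest import get_sm_version, is_sm_100f


@pytest.mark.skipif(not is_sm_100f(), reason="requires SM 100 or SM 103")
def test_deepgemm_path_gated_on_sm_100f():
    # On SM 100 and SM 103 the FP8 block-scale MoE paths select the DeepGemm variant.
    assert is_sm_100f(get_sm_version())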