refine sm version check

Signed-off-by: Xiwen Yu <13230610+VALLIS-NERIA@users.noreply.github.com>
Xiwen Yu 2025-09-10 14:51:06 +08:00
parent 2e61526d12
commit 0b73a57c33
7 changed files with 19 additions and 20 deletions

View File

@@ -40,7 +40,7 @@ from tqdm import tqdm
 from transformers import PretrainedConfig
 from tensorrt_llm._ipc_utils import can_access_peer
-from tensorrt_llm._utils import get_sm_family, get_sm_version
+from tensorrt_llm._utils import get_sm_version, is_sm_100f
 from tensorrt_llm.functional import PositionEmbeddingType
 from tensorrt_llm.llmapi.utils import enable_llm_debug
 from tensorrt_llm.mapping import Mapping
@@ -1486,8 +1486,7 @@ class DeepseekV3ForCausalLM(SpecDecOneEngineForCausalLM[DeepseekV3Model,
                     p.data.copy_(module_weights[n][:])
                 if self.model_config.quant_config.layer_quant_mode.has_fp8_block_scales(
-                ) and get_sm_family() == 100 and hasattr(
-                        module, "weight_scale"):
+                ) and is_sm_100f() and hasattr(module, "weight_scale"):
                     weight, weight_scale = resmooth_to_fp8_e8m0(
                         module.weight, module.weight_scale)
                     transfromed_scale = transform_sf_into_required_layout(
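
Background for this hunk: Blackwell's block-scaled GEMM paths consume UE8M0 (power-of-two) scale factors, so on the SM 100 family the loader resmooths FP8 block-scale weights after copying them in. A minimal sketch of the resmoothing idea, assuming only that resmooth_to_fp8_e8m0 preserves weight * scale while snapping scales to powers of two (the real kernel-level details differ):

    import torch

    def resmooth_to_e8m0_sketch(weight: torch.Tensor, scale: torch.Tensor):
        # Snap each block scale up to the next power of two (e8m0 has no mantissa) ...
        new_scale = torch.exp2(torch.ceil(torch.log2(scale)))
        # ... and fold the ratio back into the weight so weight * scale is preserved.
        new_weight = weight * (scale / new_scale)
        return new_weight, new_scale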

View File

@@ -5,7 +5,7 @@ from typing import Optional, Union, cast

 import torch
 from torch import nn

-from tensorrt_llm._utils import get_sm_family
+from tensorrt_llm._utils import get_sm_version, is_sm_100f
 from tensorrt_llm.logger import logger
 from tensorrt_llm.mapping import Mapping
@@ -568,7 +568,7 @@ def fp8_block_scaling_bmm_out(
     out: torch.Tensor,
     mat2_dequant: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
-    sm_version = get_sm_family()
+    sm_version = get_sm_version()
     if sm_version == 90 or sm_version == 89:
         mat1_fp8, mat1_scale = torch.ops.trtllm.fp8_batched_quantize_1x128_permute102(
             mat1)
@@ -579,7 +579,7 @@ def fp8_block_scaling_bmm_out(
             output)
         out.copy_(output)
-    elif sm_version == 100:
+    elif is_sm_100f(sm_version):
         torch.bmm(mat1.transpose(0, 1), mat2_dequant.transpose(1, 2), out=out)
     else:
         raise NotImplementedError(f"SM{sm_version} is not supported")
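
After the rename, Hopper/Ada still dispatch on the raw SM number while SM 100 and 103 collapse into one branch via is_sm_100f(sm_version). A condensed, illustrative view of the resulting control flow (kernel calls stubbed out; the real paths live in torch.ops.trtllm):

    def bmm_path_for(sm_version: int) -> str:
        # Which fp8_block_scaling_bmm_out branch a given arch takes (sketch only).
        if sm_version == 90 or sm_version == 89:
            return "fp8 block-scaling batched GEMM"  # Hopper / Ada kernels
        elif sm_version in (100, 103):  # i.e. is_sm_100f(sm_version)
            return "bf16 torch.bmm on mat2_dequant"
        raise NotImplementedError(f"SM{sm_version} is not supported")

    assert bmm_path_for(103) == "bf16 torch.bmm on mat2_dequant"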
@@ -894,7 +894,7 @@ class MLA(nn.Module):
                 ),
                 requires_grad=False,
             )
-            if get_sm_family() == 100:
+            if is_sm_100f():
                 assert self.dtype == torch.bfloat16
                 self.k_b_proj_trans_dequant = nn.Parameter(
                     torch.empty(
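
The bf16 k_b_proj_trans_dequant parameter exists because, per the fp8_block_scaling_bmm_out branch above, the SM 100 family falls back to a plain torch.bmm over a dequantized copy of the weight. A hedged sketch of dequantizing a 128x128 block-scaled FP8 tensor (helper name and shapes are illustrative, not tensorrt_llm's API):

    import torch

    def dequant_block_fp8_sketch(w_fp8: torch.Tensor, scales: torch.Tensor,
                                 block: int = 128) -> torch.Tensor:
        # Expand per-block scales to per-element, multiply, and cast to bf16.
        w = w_fp8.to(torch.float32)
        s = torch.repeat_interleave(scales, block, dim=0)[: w.shape[0]]
        s = torch.repeat_interleave(s, block, dim=1)[:, : w.shape[1]]
        return (w * s).to(torch.bfloat16)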

View File

@@ -4,7 +4,7 @@ from typing import Dict, List, Optional, Union

 import torch
 import torch.nn.functional as F

-from tensorrt_llm._utils import get_sm_family
+from tensorrt_llm._utils import is_sm_100f
 from ...model_config import ModelConfig
 from ...utils import AuxStreamType, Fp4QuantizedTensor
@@ -34,7 +34,7 @@ def cute_dsl_fp8_group_blockwise_gemm_ref(
     b_tmp = b.permute(1, 2, 0)
     # Note: fp8_quantize_1x128 has a different output scale shape on sm100 than on other archs, so it is handled differently here.
-    if get_sm_family() == 100:
+    if is_sm_100f():
         input_scale_tmp = a_sf.permute(1, 0).as_strided((m, w_k, 1),
                                                         (1, m, m * w_k))
     else:
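
The sm100 branch above reinterprets the (w_k, m) scale tensor as an (m, w_k, 1) view without copying: permute flips the logical axes and as_strided then addresses the original storage directly. A standalone illustration of the same trick:

    import torch

    m, w_k = 4, 3
    a_sf = torch.arange(w_k * m, dtype=torch.float32).reshape(w_k, m)  # contiguous (w_k, m)
    view = a_sf.permute(1, 0).as_strided((m, w_k, 1), (1, m, m * w_k))
    assert torch.equal(view.squeeze(-1), a_sf.t())  # same storage, no copy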

View File

@@ -7,7 +7,7 @@ import torch.nn.functional as F
 from torch import nn

 import tensorrt_llm.logger as trtllm_logger
-from tensorrt_llm._utils import get_sm_family, get_sm_version
+from tensorrt_llm._utils import get_sm_version, is_sm_100f
 from tensorrt_llm.quantization.functional import \
     preprocess_weights_for_mixed_gemm
 from tensorrt_llm.quantization.utils.fp4_utils import (
@@ -742,7 +742,7 @@ class DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm(

     def load_weights(self, module: torch.nn.Module, weights: List[Dict],
                      weight_loading_mode: MoEWeightLoadingMode):
-        if get_sm_family() == 100:
+        if is_sm_100f():
             expert_ids = set(module.initial_local_expert_ids)
             if self.need_load_shared_weights(module):
                 expert_ids.update(
@@ -759,7 +759,7 @@ class DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm(
                 weight, scale)
         super().load_weights(module, weights, weight_loading_mode)

-        if get_sm_family() == 100:
+        if is_sm_100f():
             transfromed_w3_w1_scale = transform_sf_into_required_layout(
                 module.quant_scales[0],
                 mn=module.w3_w1_weight.shape[1],

View File

@@ -21,7 +21,7 @@ from tensorrt_llm.quantization.functional import \
     preprocess_weights_for_mixed_gemm
 from tensorrt_llm.quantization.mode import QuantAlgo

-from ..._utils import get_sm_version
+from ..._utils import is_sm_100f
 from ...models.modeling_utils import QuantConfig
 from ..utils import Fp4QuantizedTensor
@@ -613,7 +613,7 @@ class FP8BlockScalesLinearMethod(LinearMethodBase):
             input = input.to(torch.bfloat16) * module.input_scale
         assert input.dtype == torch.bfloat16

-        if get_sm_version() == 100:
+        if is_sm_100f():
             if module.use_cute_dsl_blockscaling_mm:
                 # TODO (@lmin): replace with cute_dsl gemm
                 act_input_fp8, act_input_sf = torch.ops.trtllm.fp8_quantize_1x128(

View File

@@ -692,11 +692,10 @@ def get_sm_version():

 @lru_cache(maxsize=1)
-def get_sm_family():
-    sm_version = get_sm_version()
-    if sm_version == 100 or sm_version == 103:
-        return 100
-    return sm_version
+def is_sm_100f(sm_version=None):
+    if sm_version is None:
+        sm_version = get_sm_version()
+    return sm_version == 100 or sm_version == 103


 def is_trace_enabled(env_var: str):
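
The replacement helper is a plain predicate: called with no argument it queries the current device, and called with a pre-fetched SM number it just classifies it, which is what the fp8_block_scaling_bmm_out call site uses. A self-contained sketch, assuming get_sm_version has its usual compute-capability form (major * 10 + minor):

    from functools import lru_cache

    import torch

    @lru_cache(maxsize=1)
    def get_sm_version() -> int:
        # Assumed shape of tensorrt_llm._utils.get_sm_version:
        # compute capability packed as major * 10 + minor (e.g. 10.0 -> 100).
        major, minor = torch.cuda.get_device_capability()
        return major * 10 + minor

    def is_sm_100f(sm_version=None) -> bool:
        # True for the Blackwell 100 family (SM 100 and SM 103).
        if sm_version is None:
            sm_version = get_sm_version()
        return sm_version == 100 or sm_version == 103

    print(is_sm_100f(90))   # False: Hopper stays on its own code path
    print(is_sm_100f(103))  # True: SM 103 is folded into the 100 family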

View File

@@ -439,7 +439,8 @@ class TestFunctional(unittest.TestCase):
         skip_blackwell_for_fmha_tests(context_fmha_type, head_size)

         # Skip custom mask tests for Blackwell
-        if (getSMVersion() == 100 or getSMVersion == 103) and custom_mask_input:
+        if (getSMVersion() == 100
+                or getSMVersion() == 103) and custom_mask_input:
             pytest.skip("Custom masked is not supported by TRTLLM-GEN for now.")

         if num_kv_heads == 0: