diff --git a/tensorrt_llm/_torch/models/modeling_deepseekv3.py b/tensorrt_llm/_torch/models/modeling_deepseekv3.py
index a5696d6684..5b98c4eb02 100644
--- a/tensorrt_llm/_torch/models/modeling_deepseekv3.py
+++ b/tensorrt_llm/_torch/models/modeling_deepseekv3.py
@@ -40,7 +40,7 @@ from tqdm import tqdm
 from transformers import PretrainedConfig
 
 from tensorrt_llm._ipc_utils import can_access_peer
-from tensorrt_llm._utils import get_sm_family, get_sm_version
+from tensorrt_llm._utils import get_sm_version, is_sm_100f
 from tensorrt_llm.functional import PositionEmbeddingType
 from tensorrt_llm.llmapi.utils import enable_llm_debug
 from tensorrt_llm.mapping import Mapping
@@ -1486,8 +1486,7 @@ class DeepseekV3ForCausalLM(SpecDecOneEngineForCausalLM[DeepseekV3Model,
                     p.data.copy_(module_weights[n][:])
 
                 if self.model_config.quant_config.layer_quant_mode.has_fp8_block_scales(
-                ) and get_sm_family() == 100 and hasattr(
-                        module, "weight_scale"):
+                ) and is_sm_100f() and hasattr(module, "weight_scale"):
                     weight, weight_scale = resmooth_to_fp8_e8m0(
                         module.weight, module.weight_scale)
                     transfromed_scale = transform_sf_into_required_layout(
diff --git a/tensorrt_llm/_torch/modules/attention.py b/tensorrt_llm/_torch/modules/attention.py
index d73054b540..cbdb401850 100644
--- a/tensorrt_llm/_torch/modules/attention.py
+++ b/tensorrt_llm/_torch/modules/attention.py
@@ -5,7 +5,7 @@ from typing import Optional, Union, cast
 import torch
 from torch import nn
 
-from tensorrt_llm._utils import get_sm_family
+from tensorrt_llm._utils import get_sm_version, is_sm_100f
 from tensorrt_llm.logger import logger
 from tensorrt_llm.mapping import Mapping
 
@@ -568,7 +568,7 @@ def fp8_block_scaling_bmm_out(
     out: torch.Tensor,
     mat2_dequant: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
-    sm_version = get_sm_family()
+    sm_version = get_sm_version()
     if sm_version == 90 or sm_version == 89:
         mat1_fp8, mat1_scale = torch.ops.trtllm.fp8_batched_quantize_1x128_permute102(
             mat1)
@@ -579,7 +579,7 @@ def fp8_block_scaling_bmm_out(
             output)
         out.copy_(output)
-    elif sm_version == 100:
+    elif is_sm_100f(sm_version):
         torch.bmm(mat1.transpose(0, 1), mat2_dequant.transpose(1, 2), out=out)
     else:
         raise NotImplementedError(f"SM{sm_version} is not supported")
@@ -894,7 +894,7 @@ class MLA(nn.Module):
             ),
             requires_grad=False,
         )
-        if get_sm_family() == 100:
+        if is_sm_100f():
             assert self.dtype == torch.bfloat16
             self.k_b_proj_trans_dequant = nn.Parameter(
                 torch.empty(
diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cute_dsl.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cute_dsl.py
index 5237095b28..9b43f1b22b 100644
--- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cute_dsl.py
+++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cute_dsl.py
@@ -4,7 +4,7 @@ from typing import Dict, List, Optional, Union
 import torch
 import torch.nn.functional as F
 
-from tensorrt_llm._utils import get_sm_family
+from tensorrt_llm._utils import is_sm_100f
 
 from ...model_config import ModelConfig
 from ...utils import AuxStreamType, Fp4QuantizedTensor
@@ -34,7 +34,7 @@ def cute_dsl_fp8_group_blockwise_gemm_ref(
     b_tmp = b.permute(1, 2, 0)
     # Note: we have different output scale shape for fp8_quantize_1x128, so we need to handle it differently for sm100 and other archs.
-    if get_sm_family() == 100:
+    if is_sm_100f():
         input_scale_tmp = a_sf.permute(1, 0).as_strided((m, w_k, 1),
                                                         (1, m, m * w_k))
     else:
diff --git a/tensorrt_llm/_torch/modules/fused_moe/quantization.py b/tensorrt_llm/_torch/modules/fused_moe/quantization.py
index 55f37ca76e..512c71d231 100644
--- a/tensorrt_llm/_torch/modules/fused_moe/quantization.py
+++ b/tensorrt_llm/_torch/modules/fused_moe/quantization.py
@@ -7,7 +7,7 @@ import torch.nn.functional as F
 from torch import nn
 
 import tensorrt_llm.logger as trtllm_logger
-from tensorrt_llm._utils import get_sm_family, get_sm_version
+from tensorrt_llm._utils import get_sm_version, is_sm_100f
 from tensorrt_llm.quantization.functional import \
     preprocess_weights_for_mixed_gemm
 from tensorrt_llm.quantization.utils.fp4_utils import (
@@ -742,7 +742,7 @@ class DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm(
 
     def load_weights(self, module: torch.nn.Module, weights: List[Dict],
                      weight_loading_mode: MoEWeightLoadingMode):
-        if get_sm_family() == 100:
+        if is_sm_100f():
             expert_ids = set(module.initial_local_expert_ids)
             if self.need_load_shared_weights(module):
                 expert_ids.update(
@@ -759,7 +759,7 @@ class DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm(
                     weight, scale)
 
         super().load_weights(module, weights, weight_loading_mode)
-        if get_sm_family() == 100:
+        if is_sm_100f():
             transfromed_w3_w1_scale = transform_sf_into_required_layout(
                 module.quant_scales[0],
                 mn=module.w3_w1_weight.shape[1],
diff --git a/tensorrt_llm/_torch/modules/linear.py b/tensorrt_llm/_torch/modules/linear.py
index c91e4532ab..c77d208b09 100644
--- a/tensorrt_llm/_torch/modules/linear.py
+++ b/tensorrt_llm/_torch/modules/linear.py
@@ -21,7 +21,7 @@ from tensorrt_llm.quantization.functional import \
     preprocess_weights_for_mixed_gemm
 from tensorrt_llm.quantization.mode import QuantAlgo
 
-from ..._utils import get_sm_version
+from ..._utils import is_sm_100f
 from ...models.modeling_utils import QuantConfig
 from ..utils import Fp4QuantizedTensor
 
@@ -613,7 +613,7 @@ class FP8BlockScalesLinearMethod(LinearMethodBase):
             input = input.to(torch.bfloat16) * module.input_scale
         assert input.dtype == torch.bfloat16
 
-        if get_sm_version() == 100:
+        if is_sm_100f():
             if module.use_cute_dsl_blockscaling_mm:
                 # TODO (@lmin): replace with cute_dsl gemm
                 act_input_fp8, act_input_sf = torch.ops.trtllm.fp8_quantize_1x128(
diff --git a/tensorrt_llm/_utils.py b/tensorrt_llm/_utils.py
index b0f349502c..6120443e7b 100644
--- a/tensorrt_llm/_utils.py
+++ b/tensorrt_llm/_utils.py
@@ -692,11 +692,11 @@ def get_sm_version():
 
 
 @lru_cache(maxsize=1)
-def get_sm_family():
-    sm_version = get_sm_version()
-    if sm_version == 100 or sm_version == 103:
-        return 100
-    return sm_version
+def is_sm_100f(sm_version=None):
+    """Return True if sm_version (or the current device) is SM 100 or 103."""
+    if sm_version is None:
+        sm_version = get_sm_version()
+    return sm_version == 100 or sm_version == 103
 
 
 def is_trace_enabled(env_var: str):
diff --git a/tests/unittest/trt/attention/test_gpt_attention.py b/tests/unittest/trt/attention/test_gpt_attention.py
index d176046816..f6fc08c433 100644
--- a/tests/unittest/trt/attention/test_gpt_attention.py
+++ b/tests/unittest/trt/attention/test_gpt_attention.py
@@ -439,7 +439,8 @@ class TestFunctional(unittest.TestCase):
         skip_blackwell_for_fmha_tests(context_fmha_type, head_size)
 
         # Skip custom mask tests for Blackwell
-        if (getSMVersion() == 100 or getSMVersion == 103) and custom_mask_input:
+        if (getSMVersion() == 100
+                or getSMVersion() == 103) and custom_mask_input:
             pytest.skip("Custom masked is not supported by TRTLLM-GEN for now.")
 
         if num_kv_heads == 0:
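
The patch replaces the old get_sm_family() helper with is_sm_100f(), which treats SM 100 and SM 103 as one family. As a quick illustrative sketch (not part of the patch), the two helpers that remain after this change are expected to be used as follows; the wrapper functions below are hypothetical examples, not TensorRT-LLM code:

    from tensorrt_llm._utils import get_sm_version, is_sm_100f


    def uses_sm100_family_path() -> bool:
        # Old call sites wrote `get_sm_family() == 100`; the new form is a
        # plain boolean check that covers both SM 100 and SM 103.
        return is_sm_100f()


    def describe_arch() -> str:
        # An already-queried SM version can also be passed explicitly,
        # as fp8_block_scaling_bmm_out does in the attention.py hunk above.
        sm = get_sm_version()
        return "sm100-family" if is_sm_100f(sm) else f"sm{sm}"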