refine sm version check
Signed-off-by: Xiwen Yu <13230610+VALLIS-NERIA@users.noreply.github.com>
parent 2e61526d12
commit 0b73a57c33
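In short, the commit replaces the get_sm_family() helper and its scattered "== 100" comparisons with a boolean predicate is_sm_100f() that treats SM 100 and SM 103 as one family. A minimal standalone sketch of the before/after call-site pattern; the get_sm_version stub below is a placeholder for illustration only (the real helper in tensorrt_llm._utils queries the current CUDA device):

# Sketch only: mirrors the helper added in the _utils.py hunk further down.
def get_sm_version():
    # Placeholder stub; the library helper returns e.g. 89, 90, 100, 103.
    return 103

def is_sm_100f(sm_version=None):
    if sm_version is None:
        sm_version = get_sm_version()
    return sm_version == 100 or sm_version == 103

# Old call-site shape:  if get_sm_family() == 100: ...
# New call-site shape:
if is_sm_100f():
    print("SM 100 / SM 103 (Blackwell) code path")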
@@ -40,7 +40,7 @@ from tqdm import tqdm
 from transformers import PretrainedConfig

 from tensorrt_llm._ipc_utils import can_access_peer
-from tensorrt_llm._utils import get_sm_family, get_sm_version
+from tensorrt_llm._utils import get_sm_version, is_sm_100f
 from tensorrt_llm.functional import PositionEmbeddingType
 from tensorrt_llm.llmapi.utils import enable_llm_debug
 from tensorrt_llm.mapping import Mapping
@@ -1486,8 +1486,7 @@ class DeepseekV3ForCausalLM(SpecDecOneEngineForCausalLM[DeepseekV3Model,
                         p.data.copy_(module_weights[n][:])

             if self.model_config.quant_config.layer_quant_mode.has_fp8_block_scales(
-            ) and get_sm_family() == 100 and hasattr(
-                    module, "weight_scale"):
+            ) and is_sm_100f() and hasattr(module, "weight_scale"):
                 weight, weight_scale = resmooth_to_fp8_e8m0(
                     module.weight, module.weight_scale)
                 transfromed_scale = transform_sf_into_required_layout(
@@ -5,7 +5,7 @@ from typing import Optional, Union, cast
 import torch
 from torch import nn

-from tensorrt_llm._utils import get_sm_family
+from tensorrt_llm._utils import get_sm_version, is_sm_100f
 from tensorrt_llm.logger import logger
 from tensorrt_llm.mapping import Mapping

@@ -568,7 +568,7 @@ def fp8_block_scaling_bmm_out(
     out: torch.Tensor,
     mat2_dequant: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
-    sm_version = get_sm_family()
+    sm_version = get_sm_version()
     if sm_version == 90 or sm_version == 89:
         mat1_fp8, mat1_scale = torch.ops.trtllm.fp8_batched_quantize_1x128_permute102(
             mat1)
@@ -579,7 +579,7 @@ def fp8_block_scaling_bmm_out(
             output)
         out.copy_(output)

-    elif sm_version == 100:
+    elif is_sm_100f(sm_version):
         torch.bmm(mat1.transpose(0, 1), mat2_dequant.transpose(1, 2), out=out)
     else:
         raise NotImplementedError(f"SM{sm_version} is not supported")
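Read together, the two hunks above express the same family check in fp8_block_scaling_bmm_out through the new predicate: SM 103 is folded into the SM 100 branch, as it previously was via get_sm_family(). A tiny standalone sketch of just the branch selection; select_bmm_path is an illustrative name, not a function in the library:

def is_sm_100f(sm_version):
    # Same condition as the new helper in _utils.py.
    return sm_version == 100 or sm_version == 103

def select_bmm_path(sm_version):
    # Mirrors the branch structure of fp8_block_scaling_bmm_out after this change.
    if sm_version == 90 or sm_version == 89:
        return "fp8_batched_quantize_1x128_permute102 + fp8 batched GEMM"
    elif is_sm_100f(sm_version):
        return "bf16 torch.bmm against mat2_dequant"
    raise NotImplementedError(f"SM{sm_version} is not supported")

print(select_bmm_path(103))  # "bf16 torch.bmm against mat2_dequant"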
@@ -894,7 +894,7 @@ class MLA(nn.Module):
                 ),
                 requires_grad=False,
             )
-            if get_sm_family() == 100:
+            if is_sm_100f():
                 assert self.dtype == torch.bfloat16
                 self.k_b_proj_trans_dequant = nn.Parameter(
                     torch.empty(
@@ -4,7 +4,7 @@ from typing import Dict, List, Optional, Union
 import torch
 import torch.nn.functional as F

-from tensorrt_llm._utils import get_sm_family
+from tensorrt_llm._utils import is_sm_100f

 from ...model_config import ModelConfig
 from ...utils import AuxStreamType, Fp4QuantizedTensor
@@ -34,7 +34,7 @@ def cute_dsl_fp8_group_blockwise_gemm_ref(
     b_tmp = b.permute(1, 2, 0)

     # Note: we have different output scale shape for fp8_quantize_1x128, so we need to handle it differently for sm100 and other archs.
-    if get_sm_family() == 100:
+    if is_sm_100f():
         input_scale_tmp = a_sf.permute(1, 0).as_strided((m, w_k, 1),
                                                          (1, m, m * w_k))
     else:
@@ -7,7 +7,7 @@ import torch.nn.functional as F
 from torch import nn

 import tensorrt_llm.logger as trtllm_logger
-from tensorrt_llm._utils import get_sm_family, get_sm_version
+from tensorrt_llm._utils import get_sm_version, is_sm_100f
 from tensorrt_llm.quantization.functional import \
     preprocess_weights_for_mixed_gemm
 from tensorrt_llm.quantization.utils.fp4_utils import (
@@ -742,7 +742,7 @@ class DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm(

     def load_weights(self, module: torch.nn.Module, weights: List[Dict],
                      weight_loading_mode: MoEWeightLoadingMode):
-        if get_sm_family() == 100:
+        if is_sm_100f():
             expert_ids = set(module.initial_local_expert_ids)
             if self.need_load_shared_weights(module):
                 expert_ids.update(
@@ -759,7 +759,7 @@ class DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm(
                     weight, scale)
         super().load_weights(module, weights, weight_loading_mode)

-        if get_sm_family() == 100:
+        if is_sm_100f():
             transfromed_w3_w1_scale = transform_sf_into_required_layout(
                 module.quant_scales[0],
                 mn=module.w3_w1_weight.shape[1],
@@ -21,7 +21,7 @@ from tensorrt_llm.quantization.functional import \
     preprocess_weights_for_mixed_gemm
 from tensorrt_llm.quantization.mode import QuantAlgo

-from ..._utils import get_sm_version
+from ..._utils import is_sm_100f
 from ...models.modeling_utils import QuantConfig
 from ..utils import Fp4QuantizedTensor

@@ -613,7 +613,7 @@ class FP8BlockScalesLinearMethod(LinearMethodBase):
             input = input.to(torch.bfloat16) * module.input_scale
             assert input.dtype == torch.bfloat16

-        if get_sm_version() == 100:
+        if is_sm_100f():
             if module.use_cute_dsl_blockscaling_mm:
                 # TODO (@lmin): replace with cute_dsl gemm
                 act_input_fp8, act_input_sf = torch.ops.trtllm.fp8_quantize_1x128(
@@ -692,11 +692,10 @@ def get_sm_version():


-@lru_cache(maxsize=1)
-def get_sm_family():
-    sm_version = get_sm_version()
-    if sm_version == 100 or sm_version == 103:
-        return 100
-    return sm_version
+def is_sm_100f(sm_version=None):
+    if sm_version is None:
+        sm_version = get_sm_version()
+    return sm_version == 100 or sm_version == 103


 def is_trace_enabled(env_var: str):
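Note that the removed get_sm_family was wrapped in @lru_cache while the new predicate carries no decorator in this hunk; it defers to get_sm_version() when no version is passed, and call sites that already hold a version (such as fp8_block_scaling_bmm_out above) pass it explicitly. A small usage sketch, assuming the helper behaves exactly as defined here:

from tensorrt_llm._utils import get_sm_version, is_sm_100f

sm = get_sm_version()   # e.g. 89, 90, 100, 103 for the current device
is_sm_100f()            # queries the device itself (sm_version defaults to None)
is_sm_100f(sm)          # reuses an already-fetched version, no second query
is_sm_100f(103)         # True: SM 103 is folded into the 100 family
is_sm_100f(90)          # False: Hopper takes the other code paths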
@@ -439,7 +439,8 @@ class TestFunctional(unittest.TestCase):
         skip_blackwell_for_fmha_tests(context_fmha_type, head_size)

         # Skip custom mask tests for Blackwell
-        if (getSMVersion() == 100 or getSMVersion == 103) and custom_mask_input:
+        if (getSMVersion() == 100
+                or getSMVersion() == 103) and custom_mask_input:
             pytest.skip("Custom masked is not supported by TRTLLM-GEN for now.")

         if num_kv_heads == 0:
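For context on the test hunk: the removed condition compared the function object getSMVersion (no call parentheses) against 103, which is always False in Python, so the SM 103 case never triggered the skip. A tiny illustration; the stub below stands in for the real test helper:

def getSMVersion():   # stand-in stub, not the real test helper
    return 103

print(getSMVersion == 103)    # False: a function object never equals an int
print(getSMVersion() == 103)  # True: the fixed check actually calls the function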