Mirror of https://github.com/NVIDIA/TensorRT-LLM.git, synced 2026-01-14 06:27:45 +08:00.
fix sm check
Signed-off-by: Xiwen Yu <13230610+VALLIS-NERIA@users.noreply.github.com>
This commit is contained in:
parent
b8d1ee6975
commit
6133354d2d
@@ -34,7 +34,7 @@ def cute_dsl_fp8_group_blockwise_gemm_ref(
     b_tmp = b.permute(1, 2, 0)

     # Note: we have different output scale shape for fp8_quantize_1x128, so we need to handle it differently for sm100 and other archs.
-    if is_sm_100f() == 100:
+    if is_sm_100f():
         input_scale_tmp = a_sf.permute(1, 0).as_strided((m, w_k, 1),
                                                         (1, m, m * w_k))
     else:
||||
@@ -742,7 +742,7 @@ class DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm(

     def load_weights(self, module: torch.nn.Module, weights: List[Dict],
                      weight_loading_mode: MoEWeightLoadingMode):
-        if is_sm_100f() == 100:
+        if is_sm_100f():
             expert_ids = set(module.initial_local_expert_ids)
             if self.need_load_shared_weights(module):
                 expert_ids.update(
@@ -759,7 +759,7 @@ class DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm(
                         weight, scale)
         super().load_weights(module, weights, weight_loading_mode)

-        if is_sm_100f() == 100:
+        if is_sm_100f():
             transfromed_w3_w1_scale = transform_sf_into_required_layout(
                 module.quant_scales[0],
                 mn=module.w3_w1_weight.shape[1],
||||
Loading…
Reference in New Issue
Block a user