fix sm check

Signed-off-by: Xiwen Yu <13230610+VALLIS-NERIA@users.noreply.github.com>
Author: Xiwen Yu
Date:   2025-09-11 09:13:23 +08:00
parent b8d1ee6975
commit 6133354d2d
2 changed files with 3 additions and 3 deletions
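The bug being fixed: `is_sm_100f()` (by its name, a boolean check for the SM100 GPU family) was compared against the integer 100, so the condition could never be true and the SM100-specific code paths were silently skipped. A minimal sketch of the failure mode, assuming `is_sm_100f()` returns a bool (the stand-in below is hypothetical; the real helper lives in the project's utilities):

```python
def is_sm_100f() -> bool:
    # Hypothetical stand-in: report whether the device is an SM100-family GPU.
    return True

# Buggy check from before this commit: a bool never compares equal to 100,
# so the SM100 branch is unreachable even on SM100 hardware.
assert (is_sm_100f() == 100) is False

# Fixed check: use the truth value directly.
if is_sm_100f():
    print("take the SM100-specific path")
```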

View File

@@ -34,7 +34,7 @@ def cute_dsl_fp8_group_blockwise_gemm_ref(
     b_tmp = b.permute(1, 2, 0)
     # Note: we have different output scale shape for fp8_quantize_1x128, so we need to handle it differently for sm100 and other archs.
-    if is_sm_100f() == 100:
+    if is_sm_100f():
         input_scale_tmp = a_sf.permute(1, 0).as_strided((m, w_k, 1),
                                                         (1, m, m * w_k))
     else:

View File

@@ -742,7 +742,7 @@ class DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm(
     def load_weights(self, module: torch.nn.Module, weights: List[Dict],
                      weight_loading_mode: MoEWeightLoadingMode):
-        if is_sm_100f() == 100:
+        if is_sm_100f():
            expert_ids = set(module.initial_local_expert_ids)
            if self.need_load_shared_weights(module):
                expert_ids.update(
@@ -759,7 +759,7 @@ class DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm(
                    weight, scale)
        super().load_weights(module, weights, weight_loading_mode)
-        if is_sm_100f() == 100:
+        if is_sm_100f():
            transfromed_w3_w1_scale = transform_sf_into_required_layout(
                module.quant_scales[0],
                mn=module.w3_w1_weight.shape[1],