[None][chore] Remove unused get_quant_scales methods (#7687)

Signed-off-by: Aurelien Chartier <2567591+achartier@users.noreply.github.com>
Aurelien Chartier 2025-09-16 12:56:11 -07:00 committed by GitHub
parent 9befd1a72f
commit 471723bce1
2 changed files with 0 additions and 140 deletions

View File

@@ -227,10 +227,6 @@ class TritonUnquantizedFusedMoEMethod(FusedMoEMethodBase):
     def setup_quant_scales(self, module: torch.nn.Module):
         module.quant_scales = tuple()
-    def get_quant_scales(self, module: torch.nn.Module, slot_start,
-                         slot_end) -> tuple[torch.Tensor, ...]:
-        return tuple()
     def load_expert_w3_w1_weight(self, module: torch.nn.Module,
                                  w1_weight: torch.Tensor,
                                  w3_weight: torch.Tensor,
View File

@@ -329,16 +329,6 @@ class FusedMoEMethodBase(ABC):
     def setup_quant_scales(self, module: torch.nn.Module):
         raise NotImplementedError
-    @abstractmethod
-    def get_quant_scales(self, module: torch.nn.Module, slot_start,
-                         slot_end) -> tuple[torch.Tensor, ...]:
-        """
-        Get the quant scales for the given slot range.
-        Due to the special handling of slot_start and slot_end, we require the subclasses
-        to implement this method or explicitly raise NotImplementedError.
-        """
-        # TODO: remove this method, it's no longer needed
     def apply(self, module: torch.nn.Module, input: torch.Tensor, *args,
               **kwargs) -> torch.Tensor:
         """
@@ -412,10 +402,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase):
     def setup_quant_scales(self, module: torch.nn.Module):
         module.quant_scales = tuple()
-    def get_quant_scales(self, module: torch.nn.Module, slot_start,
-                         slot_end) -> tuple[torch.Tensor, ...]:
-        return tuple()
 def load_expert_fc31_input_scale_fp8_qdq(w1_input_scale, w3_input_scale,
                                          dst_fc31_input_scale: torch.Tensor):
@@ -534,15 +520,6 @@ class FP8QDQFusedMoEMethod(FusedMoEMethodBase):
             fc1_input_dequant=module.fc31_input_dequant,
         )
-    def get_quant_scales(self, module: torch.nn.Module, slot_start,
-                         slot_end) -> tuple[torch.Tensor, ...]:
-        return FusedMoEQuantScalesFP8(
-            fc1_dequant=module.fc31_dequant[slot_start:slot_end],
-            fc2_quant=module.fc2_quant,
-            fc2_dequant=module.fc2_dequant[slot_start:slot_end],
-            fc1_input_dequant=module.fc31_input_dequant,
-        )
     def load_expert_w3_w1_weight_scale_fp8_qdq(
             self, w1_weight_scale, w3_weight_scale,
             dst_w3_w1_weight_scale: torch.Tensor):
@@ -647,16 +624,6 @@ class DeepSeekFP8BlockScalesFusedMoEMethod(FusedMoEMethodBase):
             proj_weight_scales=module.w2_weight_scaling_factor,
         )
-    def get_quant_scales(self, module: torch.nn.Module, slot_start,
-                         slot_end) -> tuple[torch.Tensor, ...]:
-        assert module.smart_router
-        return FusedMoEQuantScalesDeepSeekFP8BlockScales(
-            fc_weight_scales=module.w3_w1_weight_scaling_factor.narrow(
-                0, slot_start, slot_end - slot_start),
-            proj_weight_scales=module.w2_weight_scaling_factor.narrow(
-                0, slot_start, slot_end - slot_start),
-        )
     def load_expert_all_weight_scale_fp8_block_scale(
             self, module: torch.nn.Module, weights: Dict,
             load_expert_ids: List[int], dst_w3_w1_weight_scale: torch.Tensor,
@@ -827,16 +794,6 @@ class INT8WoqPerChannelFusedMoEMethod(FusedMoEMethodBase):
             fc2_weight_scale=module.fc2_weight_scale,
         )
-    def get_quant_scales(self, module: torch.nn.Module, slot_start,
-                         slot_end) -> tuple[torch.Tensor, ...]:
-        assert module.smart_router
-        return FusedMoEQuantScalesINT8WoqPerChannel(
-            fc31_weight_scale=module.fc31_weight_scale.narrow(
-                0, slot_start, slot_end - slot_start),
-            fc2_weight_scale=module.fc2_weight_scale.narrow(
-                0, slot_start, slot_end - slot_start),
-        )
     def load_expert_w3_w1_weight(self, module: torch.nn.Module,
                                  w1_weight: torch.Tensor,
                                  w3_weight: torch.Tensor,
@@ -1012,26 +969,6 @@ class WInt4AFP8FusedMoEMethod(FusedMoEMethodBase):
             alpha_2=module.fc2_alpha,
         )
-    def get_quant_scales(self, module: torch.nn.Module, slot_start,
-                         slot_end) -> tuple[torch.Tensor, ...]:
-        assert module.smart_router
-        return FusedMoEQuantScalesW4A8(
-            scale_1_interleaved=module.fc31_weight_scale.narrow(
-                0, slot_start, slot_end - slot_start),
-            scale_2_interleaved=module.fc2_weight_scale.narrow(
-                0, slot_start, slot_end - slot_start),
-            pre_quant_scale_1=module.fc31_act_scale.narrow(
-                0, slot_start, slot_end - slot_start),
-            pre_quant_scale_2=module.fc2_act_scale.narrow(
-                0, slot_start, slot_end - slot_start),
-            zero_1=torch.Tensor(),
-            zero_2=torch.Tensor(),
-            alpha_1=module.fc31_alpha.narrow(0, slot_start,
-                                             slot_end - slot_start),
-            alpha_2=module.fc2_alpha.narrow(0, slot_start,
-                                            slot_end - slot_start),
-        )
     def load_expert_w3_w1_weight(self, module: torch.nn.Module,
                                  w1_weight: torch.Tensor,
                                  w3_weight: torch.Tensor,
@@ -1384,15 +1321,6 @@ class WFP4A16FusedMoEMethod(FusedMoEMethodBase):
             scale_1_interleaved=module.fc31_weight_scale,
             scale_2_interleaved=module.fc2_weight_scale)
-    def get_quant_scales(self, module: torch.nn.Module, slot_start,
-                         slot_end) -> tuple[torch.Tensor, ...]:
-        assert module.smart_router
-        return FusedMoEQuantScalesW4A16MXFP4(
-            scale_1_interleaved=module.fc31_weight_scale.narrow(
-                0, slot_start, slot_end - slot_start),
-            scale_2_interleaved=module.fc2_weight_scale.narrow(
-                0, slot_start, slot_end - slot_start))
     def load_expert_w3_w1_weight(self, module: torch.nn.Module,
                                  w1_weight: torch.Tensor,
                                  w3_weight: torch.Tensor,
@@ -1788,22 +1716,6 @@ class NVFP4FusedMoEMethod(FusedMoEMethodBase):
             fc2_global=module.fc2_alpha,
         )
-    def get_quant_scales(self, module: torch.nn.Module, slot_start,
-                         slot_end) -> tuple[torch.Tensor, ...]:
-        assert module.smart_router
-        return FusedMoEQuantScalesNVFP4(
-            fc1_act_global=module.fc31_input_scale,
-            fc1_weight_block=module.w3_w1_weight_scale.narrow(
-                0, slot_start, slot_end - slot_start),
-            fc1_global=module.fc31_alpha.narrow(0, slot_start,
-                                                slot_end - slot_start),
-            fc2_act_global=module.fc2_input_scale,
-            fc2_weight_block=module.w2_weight_scale.narrow(
-                0, slot_start, slot_end - slot_start),
-            fc2_global=module.fc2_alpha.narrow(0, slot_start,
-                                               slot_end - slot_start),
-        )
 class NVFP4CutlassFusedMoEMethod(NVFP4FusedMoEMethod):
     weight_dtype = FUSED_MOE_NVFP4_WEIGHT_DTYPE
@@ -1907,13 +1819,6 @@ class NVFP4TRTLLMGenFusedMoEMethod(NVFP4FusedMoEMethod):
     def setup_quant_scales(self, module: torch.nn.Module):
         module.quant_scales = tuple()
-    def get_quant_scales(self, module: torch.nn.Module, slot_start,
-                         slot_end) -> tuple[torch.Tensor, ...]:
-        """
-        The TRTLLM-Gen backend of FusedMoE does not use FusedMoEQuantScalesNVFP4.
-        """
-        raise NotImplementedError
     def load_expert_w3_w1_weight(self, module: torch.nn.Module,
                                  w1_weight: torch.Tensor,
                                  w3_weight: torch.Tensor,
@@ -2252,11 +2157,6 @@ class MXFP4WeightFusedMoEMethod(FusedMoEMethodBase):
     def setup_quant_scales(self, module: torch.nn.Module):
         pass
-    @abstractmethod
-    def get_quant_scales(self, module: torch.nn.Module, slot_start,
-                         slot_end) -> tuple[torch.Tensor, ...]:
-        pass
 class MXFP4WeightCutlassFusedMoEMethod(MXFP4WeightFusedMoEMethod):
     weight_dtype = FUSED_MOE_MXFP4_WEIGHT_DTYPE
@@ -2455,20 +2355,6 @@ class W4A8MXFP4MXFP8CutlassFusedMoEMethod(MXFP4WeightCutlassFusedMoEMethod):
             fc2_dequant_scale=module.fake_input_scale,
         )
-    def get_quant_scales(self, module: torch.nn.Module, slot_start,
-                         slot_end) -> tuple[torch.Tensor, ...]:
-        assert module.smart_router
-        return FusedMoEQuantScalesW4A8MXFP4MXFP8(
-            fc31_weight_block_scale=module.w3_w1_weight_scale.narrow(
-                0, slot_start, slot_end - slot_start),
-            fc31_dequant_scale=module.fake_input_scale.narrow(
-                0, slot_start, slot_end - slot_start),
-            fc2_weight_block_scale=module.w2_weight_scale.narrow(
-                0, slot_start, slot_end - slot_start),
-            fc2_dequant_scale=module.fake_input_scale.narrow(
-                0, slot_start, slot_end - slot_start),
-        )
 class W4A8MXFP4FP8CutlassFusedMoEMethod(MXFP4WeightCutlassFusedMoEMethod):
@@ -2554,21 +2440,6 @@ class W4A8MXFP4FP8CutlassFusedMoEMethod(MXFP4WeightCutlassFusedMoEMethod):
             fc2_dequant_scale=module.fc2_input_dequant,
         )
-    def get_quant_scales(self, module: torch.nn.Module, slot_start,
-                         slot_end) -> tuple[torch.Tensor, ...]:
-        assert module.smart_router
-        return FusedMoEQuantScalesW4A8MXFP4FP8(
-            fc31_weight_block_scale=module.w3_w1_weight_scale.narrow(
-                0, slot_start, slot_end - slot_start),
-            fc31_dequant_scale=module.fc31_input_dequant.narrow(
-                0, slot_start, slot_end - slot_start),
-            fc2_input_scale=module.fc2_input_scale,
-            fc2_weight_block_scale=module.w2_weight_scale.narrow(
-                0, slot_start, slot_end - slot_start),
-            fc2_dequant_scale=module.fc2_input_dequant.narrow(
-                0, slot_start, slot_end - slot_start),
-        )
 class MXFP4WeightTRTLLMGenFusedMoEMethod(MXFP4WeightFusedMoEMethod):
     weight_dtype = torch.uint8
@@ -2595,13 +2466,6 @@ class MXFP4WeightTRTLLMGenFusedMoEMethod(MXFP4WeightFusedMoEMethod):
     def setup_quant_scales(self, module: torch.nn.Module):
         module.quant_scales = tuple()
-    def get_quant_scales(self, module: torch.nn.Module, slot_start,
-                         slot_end) -> tuple[torch.Tensor, ...]:
-        """
-        The TRTLLM-Gen backend of FusedMoE does not use FusedMoEQuantScales.
-        """
-        raise NotImplementedError
     def load_expert_w3_w1_weight(self, module: torch.nn.Module,
                                  w1_weight: torch.Tensor,
                                  w3_weight: torch.Tensor,