Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-13 22:18:36 +08:00)
[None][chore] Remove unused get_quant_scales methods (#7687)
Signed-off-by: Aurelien Chartier <2567591+achartier@users.noreply.github.com>
parent 9befd1a72f
commit 471723bce1
@@ -227,10 +227,6 @@ class TritonUnquantizedFusedMoEMethod(FusedMoEMethodBase):
     def setup_quant_scales(self, module: torch.nn.Module):
         module.quant_scales = tuple()
 
-    def get_quant_scales(self, module: torch.nn.Module, slot_start,
-                         slot_end) -> tuple[torch.Tensor, ...]:
-        return tuple()
-
     def load_expert_w3_w1_weight(self, module: torch.nn.Module,
                                  w1_weight: torch.Tensor,
                                  w3_weight: torch.Tensor,
@@ -329,16 +329,6 @@ class FusedMoEMethodBase(ABC):
     def setup_quant_scales(self, module: torch.nn.Module):
         raise NotImplementedError
 
-    @abstractmethod
-    def get_quant_scales(self, module: torch.nn.Module, slot_start,
-                         slot_end) -> tuple[torch.Tensor, ...]:
-        """
-        Get the quant scales for the given slot range.
-        Due to the special handling of slot_start and slot_end, we require the subclasses
-        to implement this method or explicitly raise NotImplementedError.
-        """
-        # TODO: remove this method, it's no longer needed
-
     def apply(self, module: torch.nn.Module, input: torch.Tensor, *args,
               **kwargs) -> torch.Tensor:
         """
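Note: the abstract get_quant_scales contract removed above was implemented in most backends with the same slot-range slicing idiom, tensor.narrow(0, slot_start, slot_end - slot_start) on a per-expert scale tensor whose leading dimension indexes expert slots. The sketch below is a minimal, self-contained illustration of that idiom; the PerExpertScales container and its field names are hypothetical, only the narrow() call itself is taken from the diff.

from typing import NamedTuple

import torch


class PerExpertScales(NamedTuple):
    # Leading dimension indexes expert slots (hypothetical container).
    fc31_scale: torch.Tensor
    fc2_scale: torch.Tensor


def slice_scales(scales: PerExpertScales, slot_start: int,
                 slot_end: int) -> PerExpertScales:
    # narrow(dim=0, start, length) returns a view, not a copy, of the
    # requested expert slots -- the same idiom as the removed overrides.
    length = slot_end - slot_start
    return PerExpertScales(
        fc31_scale=scales.fc31_scale.narrow(0, slot_start, length),
        fc2_scale=scales.fc2_scale.narrow(0, slot_start, length),
    )


scales = PerExpertScales(torch.ones(8, 16), torch.ones(8, 16))
print(slice_scales(scales, 2, 5).fc31_scale.shape)  # torch.Size([3, 16])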
@@ -412,10 +402,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase):
     def setup_quant_scales(self, module: torch.nn.Module):
         module.quant_scales = tuple()
 
-    def get_quant_scales(self, module: torch.nn.Module, slot_start,
-                         slot_end) -> tuple[torch.Tensor, ...]:
-        return tuple()
-
 
 def load_expert_fc31_input_scale_fp8_qdq(w1_input_scale, w3_input_scale,
                                          dst_fc31_input_scale: torch.Tensor):
@@ -534,15 +520,6 @@ class FP8QDQFusedMoEMethod(FusedMoEMethodBase):
             fc1_input_dequant=module.fc31_input_dequant,
         )
 
-    def get_quant_scales(self, module: torch.nn.Module, slot_start,
-                         slot_end) -> tuple[torch.Tensor, ...]:
-        return FusedMoEQuantScalesFP8(
-            fc1_dequant=module.fc31_dequant[slot_start:slot_end],
-            fc2_quant=module.fc2_quant,
-            fc2_dequant=module.fc2_dequant[slot_start:slot_end],
-            fc1_input_dequant=module.fc31_input_dequant,
-        )
-
     def load_expert_w3_w1_weight_scale_fp8_qdq(
             self, w1_weight_scale, w3_weight_scale,
             dst_w3_w1_weight_scale: torch.Tensor):
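Note: the FP8-QDQ override removed above sliced with plain indexing (module.fc31_dequant[slot_start:slot_end]) while the other removed overrides used Tensor.narrow. The short check below is a hedged sketch with a made-up tensor showing that the two idioms select the same expert slots and both return views of the same storage, so the deleted implementations were equivalent in this respect.

import torch

t = torch.arange(12.0).reshape(6, 2)  # 6 expert slots, hypothetical scales
slot_start, slot_end = 1, 4

a = t[slot_start:slot_end]
b = t.narrow(0, slot_start, slot_end - slot_start)

assert torch.equal(a, b)
assert a.data_ptr() == b.data_ptr()  # both view the same underlying storage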
@@ -647,16 +624,6 @@ class DeepSeekFP8BlockScalesFusedMoEMethod(FusedMoEMethodBase):
             proj_weight_scales=module.w2_weight_scaling_factor,
         )
 
-    def get_quant_scales(self, module: torch.nn.Module, slot_start,
-                         slot_end) -> tuple[torch.Tensor, ...]:
-        assert module.smart_router
-        return FusedMoEQuantScalesDeepSeekFP8BlockScales(
-            fc_weight_scales=module.w3_w1_weight_scaling_factor.narrow(
-                0, slot_start, slot_end - slot_start),
-            proj_weight_scales=module.w2_weight_scaling_factor.narrow(
-                0, slot_start, slot_end - slot_start),
-        )
-
     def load_expert_all_weight_scale_fp8_block_scale(
             self, module: torch.nn.Module, weights: Dict,
             load_expert_ids: List[int], dst_w3_w1_weight_scale: torch.Tensor,
@@ -827,16 +794,6 @@ class INT8WoqPerChannelFusedMoEMethod(FusedMoEMethodBase):
             fc2_weight_scale=module.fc2_weight_scale,
         )
 
-    def get_quant_scales(self, module: torch.nn.Module, slot_start,
-                         slot_end) -> tuple[torch.Tensor, ...]:
-        assert module.smart_router
-        return FusedMoEQuantScalesINT8WoqPerChannel(
-            fc31_weight_scale=module.fc31_weight_scale.narrow(
-                0, slot_start, slot_end - slot_start),
-            fc2_weight_scale=module.fc2_weight_scale.narrow(
-                0, slot_start, slot_end - slot_start),
-        )
-
     def load_expert_w3_w1_weight(self, module: torch.nn.Module,
                                  w1_weight: torch.Tensor,
                                  w3_weight: torch.Tensor,
@@ -1012,26 +969,6 @@ class WInt4AFP8FusedMoEMethod(FusedMoEMethodBase):
             alpha_2=module.fc2_alpha,
         )
 
-    def get_quant_scales(self, module: torch.nn.Module, slot_start,
-                         slot_end) -> tuple[torch.Tensor, ...]:
-        assert module.smart_router
-        return FusedMoEQuantScalesW4A8(
-            scale_1_interleaved=module.fc31_weight_scale.narrow(
-                0, slot_start, slot_end - slot_start),
-            scale_2_interleaved=module.fc2_weight_scale.narrow(
-                0, slot_start, slot_end - slot_start),
-            pre_quant_scale_1=module.fc31_act_scale.narrow(
-                0, slot_start, slot_end - slot_start),
-            pre_quant_scale_2=module.fc2_act_scale.narrow(
-                0, slot_start, slot_end - slot_start),
-            zero_1=torch.Tensor(),
-            zero_2=torch.Tensor(),
-            alpha_1=module.fc31_alpha.narrow(0, slot_start,
-                                             slot_end - slot_start),
-            alpha_2=module.fc2_alpha.narrow(0, slot_start,
-                                            slot_end - slot_start),
-        )
-
     def load_expert_w3_w1_weight(self, module: torch.nn.Module,
                                  w1_weight: torch.Tensor,
                                  w3_weight: torch.Tensor,
@@ -1384,15 +1321,6 @@ class WFP4A16FusedMoEMethod(FusedMoEMethodBase):
             scale_1_interleaved=module.fc31_weight_scale,
             scale_2_interleaved=module.fc2_weight_scale)
 
-    def get_quant_scales(self, module: torch.nn.Module, slot_start,
-                         slot_end) -> tuple[torch.Tensor, ...]:
-        assert module.smart_router
-        return FusedMoEQuantScalesW4A16MXFP4(
-            scale_1_interleaved=module.fc31_weight_scale.narrow(
-                0, slot_start, slot_end - slot_start),
-            scale_2_interleaved=module.fc2_weight_scale.narrow(
-                0, slot_start, slot_end - slot_start))
-
     def load_expert_w3_w1_weight(self, module: torch.nn.Module,
                                  w1_weight: torch.Tensor,
                                  w3_weight: torch.Tensor,
@@ -1788,22 +1716,6 @@ class NVFP4FusedMoEMethod(FusedMoEMethodBase):
             fc2_global=module.fc2_alpha,
         )
 
-    def get_quant_scales(self, module: torch.nn.Module, slot_start,
-                         slot_end) -> tuple[torch.Tensor, ...]:
-        assert module.smart_router
-        return FusedMoEQuantScalesNVFP4(
-            fc1_act_global=module.fc31_input_scale,
-            fc1_weight_block=module.w3_w1_weight_scale.narrow(
-                0, slot_start, slot_end - slot_start),
-            fc1_global=module.fc31_alpha.narrow(0, slot_start,
-                                                slot_end - slot_start),
-            fc2_act_global=module.fc2_input_scale,
-            fc2_weight_block=module.w2_weight_scale.narrow(
-                0, slot_start, slot_end - slot_start),
-            fc2_global=module.fc2_alpha.narrow(0, slot_start,
-                                               slot_end - slot_start),
-        )
-
 
 class NVFP4CutlassFusedMoEMethod(NVFP4FusedMoEMethod):
     weight_dtype = FUSED_MOE_NVFP4_WEIGHT_DTYPE
@@ -1907,13 +1819,6 @@ class NVFP4TRTLLMGenFusedMoEMethod(NVFP4FusedMoEMethod):
     def setup_quant_scales(self, module: torch.nn.Module):
         module.quant_scales = tuple()
 
-    def get_quant_scales(self, module: torch.nn.Module, slot_start,
-                         slot_end) -> tuple[torch.Tensor, ...]:
-        """
-        The TRTLLM-Gen backend of FusedMoE does not use FusedMoEQuantScalesNVFP4.
-        """
-        raise NotImplementedError
-
     def load_expert_w3_w1_weight(self, module: torch.nn.Module,
                                  w1_weight: torch.Tensor,
                                  w3_weight: torch.Tensor,
@@ -2252,11 +2157,6 @@ class MXFP4WeightFusedMoEMethod(FusedMoEMethodBase):
     def setup_quant_scales(self, module: torch.nn.Module):
         pass
 
-    @abstractmethod
-    def get_quant_scales(self, module: torch.nn.Module, slot_start,
-                         slot_end) -> tuple[torch.Tensor, ...]:
-        pass
-
 
 class MXFP4WeightCutlassFusedMoEMethod(MXFP4WeightFusedMoEMethod):
     weight_dtype = FUSED_MOE_MXFP4_WEIGHT_DTYPE
@@ -2455,20 +2355,6 @@ class W4A8MXFP4MXFP8CutlassFusedMoEMethod(MXFP4WeightCutlassFusedMoEMethod):
             fc2_dequant_scale=module.fake_input_scale,
         )
 
-    def get_quant_scales(self, module: torch.nn.Module, slot_start,
-                         slot_end) -> tuple[torch.Tensor, ...]:
-        assert module.smart_router
-        return FusedMoEQuantScalesW4A8MXFP4MXFP8(
-            fc31_weight_block_scale=module.w3_w1_weight_scale.narrow(
-                0, slot_start, slot_end - slot_start),
-            fc31_dequant_scale=module.fake_input_scale.narrow(
-                0, slot_start, slot_end - slot_start),
-            fc2_weight_block_scale=module.w2_weight_scale.narrow(
-                0, slot_start, slot_end - slot_start),
-            fc2_dequant_scale=module.fake_input_scale.narrow(
-                0, slot_start, slot_end - slot_start),
-        )
-
 
 class W4A8MXFP4FP8CutlassFusedMoEMethod(MXFP4WeightCutlassFusedMoEMethod):
 
@@ -2554,21 +2440,6 @@ class W4A8MXFP4FP8CutlassFusedMoEMethod(MXFP4WeightCutlassFusedMoEMethod):
             fc2_dequant_scale=module.fc2_input_dequant,
         )
 
-    def get_quant_scales(self, module: torch.nn.Module, slot_start,
-                         slot_end) -> tuple[torch.Tensor, ...]:
-        assert module.smart_router
-        return FusedMoEQuantScalesW4A8MXFP4FP8(
-            fc31_weight_block_scale=module.w3_w1_weight_scale.narrow(
-                0, slot_start, slot_end - slot_start),
-            fc31_dequant_scale=module.fc31_input_dequant.narrow(
-                0, slot_start, slot_end - slot_start),
-            fc2_input_scale=module.fc2_input_scale,
-            fc2_weight_block_scale=module.w2_weight_scale.narrow(
-                0, slot_start, slot_end - slot_start),
-            fc2_dequant_scale=module.fc2_input_dequant.narrow(
-                0, slot_start, slot_end - slot_start),
-        )
-
 
 class MXFP4WeightTRTLLMGenFusedMoEMethod(MXFP4WeightFusedMoEMethod):
     weight_dtype = torch.uint8
@@ -2595,13 +2466,6 @@ class MXFP4WeightTRTLLMGenFusedMoEMethod(MXFP4WeightFusedMoEMethod):
     def setup_quant_scales(self, module: torch.nn.Module):
         module.quant_scales = tuple()
 
-    def get_quant_scales(self, module: torch.nn.Module, slot_start,
-                         slot_end) -> tuple[torch.Tensor, ...]:
-        """
-        The TRTLLM-Gen backend of FusedMoE does not use FusedMoEQuantScales.
-        """
-        raise NotImplementedError
-
     def load_expert_w3_w1_weight(self, module: torch.nn.Module,
                                  w1_weight: torch.Tensor,
                                  w3_weight: torch.Tensor,
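Note: after this removal, the quant-scale plumbing that remains in the diff is setup_quant_scales, which publishes module.quant_scales once per backend (an empty tuple for the unquantized and TRTLLM-Gen backends, a backend-specific scales tuple otherwise). The sketch below illustrates that remaining shape under assumed, illustrative class names; it is not the real TensorRT-LLM hierarchy.

from abc import ABC, abstractmethod

import torch


class MoEMethodBase(ABC):
    # Illustrative stand-in for the base method class in the diff.

    @abstractmethod
    def setup_quant_scales(self, module: torch.nn.Module) -> None:
        """Populate module.quant_scales for the backend kernel."""


class UnquantizedMethod(MoEMethodBase):

    def setup_quant_scales(self, module: torch.nn.Module) -> None:
        # Matches the unquantized backends in the diff: no scales needed.
        module.quant_scales = tuple()


module = torch.nn.Module()
UnquantizedMethod().setup_quant_scales(module)
print(module.quant_scales)  # ()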