Refactor Deepseek tp_size calculation (#3695)

Signed-off-by: Hao Lu <14827759+hlu1@users.noreply.github.com>
Authored by hlu1 on 2025-04-19 23:55:19 -07:00, committed by GitHub
parent d51ae53940
commit 17eba98445
2 changed files with 88 additions and 33 deletions


@@ -126,6 +126,12 @@ class ModelConfig(Generic[TConfig]):
quant_config.quant_algo = QuantAlgo.FP8_BLOCK_SCALES
quant_config.exclude_modules = ["*eh_proj"]
block_size = hf_quant_config.get("weight_block_size", [])
assert tuple(block_size) == (
128,
128), "FP8_BLOCK_SCALES only supports block_size=(128,128)"
quant_config.group_size = block_size[0]
return cls(pretrained_config=pretrained_config,
quant_config=quant_config,
quant_config_dict=layer_quant_config,
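For reference, the block-size check above can be exercised on its own. A minimal sketch, assuming an HF quantization config shaped like DeepSeek-V3's (weight_block_size: [128, 128]); derive_group_size is an illustrative helper, not part of ModelConfig:

def derive_group_size(hf_quant_config: dict) -> int:
    # FP8_BLOCK_SCALES in this path only supports square 128x128 weight blocks,
    # so group_size is simply the first block dimension.
    block_size = hf_quant_config.get("weight_block_size", [])
    assert tuple(block_size) == (
        128,
        128), "FP8_BLOCK_SCALES only supports block_size=(128,128)"
    return block_size[0]

assert derive_group_size({"weight_block_size": [128, 128]}) == 128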


@@ -272,22 +272,15 @@ class Deepseekv3MoE(nn.Module):
model_config=model_config,
aux_stream=aux_stream_dict[AuxStreamType.MoeChunkingOverlap])
self.shared_output_scale = None
# The block scale size is 128, which requires shared_expert_intermediate_size to be divisible by 128.
assert shared_expert_intermediate_size % 128 == 0
if self.use_dp:
# If using attention DP, the shared experts also use DP instead of TP.
shared_tp_size = 1
else:
# Due to the restriction of block scale size (i.e., 128), the supported TP sizes only include 1, 2, 4, 8, and 16.
# The math.gcd operation ensures that shared_tp_size falls in the supported TP sizes.
shared_tp_size = math.gcd(
shared_expert_intermediate_size // 128,
model_config.mapping.tp_size,
)
# If shared_tp_size has been overridden, the output of shared experts needs to be scaled down accordingly before all-reduce.
if shared_tp_size != model_config.mapping.tp_size:
self.shared_output_scale = shared_tp_size / model_config.mapping.tp_size
self.mapping = model_config.mapping
# FIXME: incompatible with mixed quantization mode (including excluding modules from quantization)
block_size = 1
if model_config.quant_config and model_config.quant_config.group_size is not None:
block_size = model_config.quant_config.group_size
shared_tp_size, self.shared_output_scale = self._compute_shared_expert_tp_size(
shared_expert_intermediate_size, block_size)
self.shared_experts = GatedMLP(
hidden_size=hidden_size,
@@ -298,7 +291,6 @@ class Deepseekv3MoE(nn.Module):
overridden_tp_size=shared_tp_size,
reduce_output=False)
self.mapping = model_config.mapping
self.all_reduce = AllReduce(self.mapping)
self.aux_stream = aux_stream_dict[AuxStreamType.MoeShared]
self.event_dict = {
@@ -306,6 +298,42 @@ class Deepseekv3MoE(nn.Module):
for key in [EventType.Main, EventType.MoeShared]
}
def _compute_shared_expert_tp_size(self, intermediate_size: int,
block_size: int) -> tuple:
"""
In the case of DeepSeek-R1, the TP size of the shared-expert MLP is capped by intermediate_size // block_size.
For example, when intermediate_size is 2048 and the block scale size is 128,
the supported TP sizes are limited to {1, 2, 4, 8, 16} because 2048 / 128 = 16.
Args:
intermediate_size (int): MLP intermediate size.
block_size (int): The quantization block scale size. For the DeepSeek FP8 recipe,
it's 128. For NVFP4, it's 16.
Returns:
tuple: The (shared_tp_size, shared_output_scale) pair; shared_output_scale is None when the TP size is not overridden.
"""
assert intermediate_size % block_size == 0, "intermediate_size must be divisible by block_size."
shared_output_scale = None
# Overridden below if the shared-expert TP size is reduced relative to mapping.tp_size.
if self.use_dp:
# If using attention DP, the shared experts also use DP instead of TP.
shared_tp_size = 1
else:
# The block scale size restricts the supported TP sizes to the divisors of intermediate_size // block_size.
# The math.gcd operation ensures that shared_tp_size falls in the supported TP sizes.
shared_tp_size = math.gcd(
intermediate_size // block_size,
self.mapping.tp_size,
)
# If shared_tp_size has been overridden, the output of shared experts needs to be scaled down accordingly before all-reduce.
if shared_tp_size != self.mapping.tp_size:
shared_output_scale = shared_tp_size / self.mapping.tp_size
return shared_tp_size, shared_output_scale
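To make the shared-expert TP cap concrete, here is a self-contained sketch of the same gcd rule, using assumed DeepSeek-V3/R1 shapes (shared-expert intermediate size 2048, FP8 block scale 128); compute_shared_tp is an illustrative helper, not the class method above:

import math
from typing import Optional, Tuple

def compute_shared_tp(intermediate_size: int, block_size: int, tp_size: int,
                      use_dp: bool = False) -> Tuple[int, Optional[float]]:
    # With attention DP, the shared experts use DP instead of TP.
    if use_dp:
        return 1, None
    # Otherwise TP is capped by intermediate_size // block_size via gcd.
    assert intermediate_size % block_size == 0
    shared_tp_size = math.gcd(intermediate_size // block_size, tp_size)
    # If the TP size was reduced, the shared-expert output is scaled down
    # before the all-reduce so the sum over ranks stays correct.
    scale = None if shared_tp_size == tp_size else shared_tp_size / tp_size
    return shared_tp_size, scale

print(compute_shared_tp(2048, 128, 8))   # (8, None): cap is 16, tp_size=8 already fits
print(compute_shared_tp(2048, 128, 32))  # (16, 0.5): capped at 16, output halved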
def compute_routed_output(self, hidden_states, hidden_states_fp4,
all_rank_num_tokens, min_latency_mode):
# max-throughput
@@ -405,6 +433,7 @@ class DeepseekV3DecoderLayer(DecoderLayer):
"0") == "0"
self.enable_fusion = enable_fusion and not self.enable_attention_dp
# FIXME: incompatible with mixed quantization mode (including excluding modules from quantization)
self.is_nvfp4 = model_config.quant_config.layer_quant_mode.has_nvfp4()
has_tp = mapping.has_tp()
has_pp = mapping.has_pp()
@@ -427,22 +456,11 @@ class DeepseekV3DecoderLayer(DecoderLayer):
model_config=model_config,
aux_stream_dict=aux_stream_dict)
else:
# The block scale size is 128, which requires intermediate_size to be divisible by 128.
assert config.intermediate_size % 128 == 0
if self.enable_attention_dp:
# If using attention DP, the MLP also uses DP instead of TP.
self.mlp_tp_size = 1
else:
# Due to the restriction of block scale size (i.e., 128), the supported TP sizes only include 1, 2, 4, 8, and 16.
# To avoid the costly inter-node all-reduce, we further restrict TP size to be divisible by gpus_per_node.
# The two math.gcd operations ensure that mlp_tp_size falls in the candidate TP sizes.
self.mlp_tp_size = math.gcd(
math.gcd(
config.intermediate_size // 128,
mapping.tp_size,
),
mapping.gpus_per_node, # Avoid costly inter-node TP
)
block_size = 1
if model_config.quant_config and model_config.quant_config.group_size is not None:
block_size = model_config.quant_config.group_size
self.mlp_tp_size = self._compute_mlp_tp_size(
config.intermediate_size, block_size)
self.fusion_config.PRE_MLP_FUSION = self.enable_fusion and has_tp and self.is_nvfp4
self.fusion_config.POST_MLP_FUSION = self.enable_fusion and self.mlp_tp_size > 1 and not has_pp
@@ -479,6 +497,37 @@ class DeepseekV3DecoderLayer(DecoderLayer):
if not self.deepseek_allreduce_disabled:
self.deepseek_allreduce = DeepseekAllReduce(self.mapping)
def _compute_mlp_tp_size(self, intermediate_size: int,
block_size: int) -> int:
"""
For DeepSeek-R1, the MLP TP size is capped by intermediate_size // block_size
and is further restricted to a divisor of gpus_per_node to avoid expensive inter-node all-reduce.
Args:
intermediate_size (int): MLP intermediate size.
block_size (int): The quantization block scale size. For the DeepSeek FP8 recipe,
it's 128. For NVFP4, it's 16.
Returns:
int: The computed tp_size.
"""
assert intermediate_size % block_size == 0, "intermediate_size must be divisible by block_size."
if self.enable_attention_dp:
# If using attention DP, the MLP also uses DP instead of TP.
mlp_tp_size = 1
else:
# The two math.gcd operations ensure that mlp_tp_size falls in the candidate TP sizes.
mlp_tp_size = math.gcd(
math.gcd(
intermediate_size // block_size,
self.mapping.tp_size,
),
self.mapping.gpus_per_node, # Avoid costly inter-node TP
)
return mlp_tp_size
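As a concrete check of the rule above, a self-contained sketch with assumed DeepSeek-V3/R1 numbers (dense-MLP intermediate size 18432, 8 GPUs per node); compute_mlp_tp is an illustrative helper, not the method itself:

import math

def compute_mlp_tp(intermediate_size: int, block_size: int, tp_size: int,
                   gpus_per_node: int, attention_dp: bool = False) -> int:
    # With attention DP, the MLP also uses DP instead of TP.
    if attention_dp:
        return 1
    assert intermediate_size % block_size == 0
    # Cap by intermediate_size // block_size, then keep TP within a single node
    # so the MLP all-reduce never crosses the node boundary.
    return math.gcd(math.gcd(intermediate_size // block_size, tp_size),
                    gpus_per_node)

print(compute_mlp_tp(18432, 128, 16, 8))  # FP8 (block 128): gcd(gcd(144, 16), 8) -> 8
print(compute_mlp_tp(18432, 16, 16, 8))   # NVFP4 (block 16): gcd(gcd(1152, 16), 8) -> 8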
def forward(
self,
position_ids: torch.LongTensor,