Mirror of https://github.com/NVIDIA/TensorRT-LLM.git
Refactor Deepseek tp_size calculation (#3695)
Signed-off-by: Hao Lu <14827759+hlu1@users.noreply.github.com>
Commit 17eba98445 (parent d51ae53940)
@@ -126,6 +126,12 @@ class ModelConfig(Generic[TConfig]):
                 quant_config.quant_algo = QuantAlgo.FP8_BLOCK_SCALES
                 quant_config.exclude_modules = ["*eh_proj"]
 
+                block_size = hf_quant_config.get("weight_block_size", [])
+                assert tuple(block_size) == (
+                    128,
+                    128), "FP8_BLOCK_SCALES only supports block_size=(128,128)"
+                quant_config.group_size = block_size[0]
+
         return cls(pretrained_config=pretrained_config,
                    quant_config=quant_config,
                    quant_config_dict=layer_quant_config,
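The ModelConfig hunk above records the checkpoint's block-scaling size as quant_config.group_size instead of leaving 128 hard-coded downstream. A minimal standalone sketch of the same validation, using a hypothetical hf_quant_config dict rather than the real ModelConfig code path:

# Illustrative only: mimics how the new code reads weight_block_size
# from an HF-style quantization_config and records it as group_size.
hf_quant_config = {"quant_method": "fp8", "weight_block_size": [128, 128]}  # hypothetical input

block_size = hf_quant_config.get("weight_block_size", [])
assert tuple(block_size) == (128, 128), \
    "FP8_BLOCK_SCALES only supports block_size=(128,128)"
group_size = block_size[0]  # 128; later consumed as quant_config.group_size
print(group_size)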
@@ -272,22 +272,15 @@ class Deepseekv3MoE(nn.Module):
             model_config=model_config,
             aux_stream=aux_stream_dict[AuxStreamType.MoeChunkingOverlap])
 
-        self.shared_output_scale = None
-        # The block scale size is 128, which requires shared_expert_intermediate_size to be divisible by 128.
-        assert shared_expert_intermediate_size % 128 == 0
-        if self.use_dp:
-            # If using attention DP, the shared experts also use DP instead of TP.
-            shared_tp_size = 1
-        else:
-            # Due to the restriction of block scale size (i.e., 128), the supported TP sizes only include 1, 2, 4, 8, and 16.
-            # The math.gcd operation ensures that shared_tp_size falls in the supported TP sizes.
-            shared_tp_size = math.gcd(
-                shared_expert_intermediate_size // 128,
-                model_config.mapping.tp_size,
-            )
-        # If shared_tp_size has been overridden, the output of shared experts needs to be scaled down accordingly before all-reduce.
-        if shared_tp_size != model_config.mapping.tp_size:
-            self.shared_output_scale = shared_tp_size / model_config.mapping.tp_size
+        self.mapping = model_config.mapping
+
+        # FIXME: incompatible with mixed quantization mode (including excluding modules from quantization)
+        block_size = 1
+        if model_config.quant_config and model_config.quant_config.group_size is not None:
+            block_size = model_config.quant_config.group_size
+
+        shared_tp_size, self.shared_output_scale = self._compute_shared_expert_tp_size(
+            shared_expert_intermediate_size, block_size)
 
         self.shared_experts = GatedMLP(
             hidden_size=hidden_size,
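The refactored call site derives block_size from the quantization config instead of assuming 128, so the same path covers FP8 block scales (group_size 128) and NVFP4 (group_size 16). A hedged sketch of just that selection logic, using SimpleNamespace stand-ins for model_config.quant_config (only group_size matters here):

from types import SimpleNamespace

# Hypothetical stand-ins; the real object is TensorRT-LLM's QuantConfig.
fp8_cfg = SimpleNamespace(group_size=128)
nvfp4_cfg = SimpleNamespace(group_size=16)
no_quant_cfg = None

for quant_config in (fp8_cfg, nvfp4_cfg, no_quant_cfg):
    block_size = 1  # fallback when quantization is disabled or has no group size
    if quant_config and quant_config.group_size is not None:
        block_size = quant_config.group_size
    print(block_size)  # prints 128, 16, 1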
@@ -298,7 +291,6 @@ class Deepseekv3MoE(nn.Module):
             overridden_tp_size=shared_tp_size,
             reduce_output=False)
 
-        self.mapping = model_config.mapping
         self.all_reduce = AllReduce(self.mapping)
         self.aux_stream = aux_stream_dict[AuxStreamType.MoeShared]
         self.event_dict = {
@@ -306,6 +298,42 @@ class Deepseekv3MoE(nn.Module):
             for key in [EventType.Main, EventType.MoeShared]
         }
 
+    def _compute_shared_expert_tp_size(self, intermediate_size: int,
+                                       block_size: int) -> int:
+        """
+        In the case of Deepseek-R1, the TP size of MLP is capped by intermediate_size // block_size.
+        For example, when the intermediate_size is 2048 and block scaling size is 128,
+        TP sizes are limited to {1, 2, 4, 8, 16} because of 2048/128 = 16.
+
+        Args:
+            intermediate_size (int): MLP intermediate size.
+            block_size (int): The quantization block scale size. In the case of Deepseek FP8 recipe,
+                it's 128. For NVFP4, it's 16.
+
+        Returns:
+            int: The computed tp_size.
+        """
+
+        assert intermediate_size % block_size == 0, "intermediate_size must be divisible by block_size."
+
+        shared_output_scale = None
+        # The block scale size is 128, which requires shared_expert_intermediate_size to be divisible by 128.
+        if self.use_dp:
+            # If using attention DP, the shared experts also use DP instead of TP.
+            shared_tp_size = 1
+        else:
+            # Due to the restriction of block scale size (i.e., 128), the supported TP sizes only include 1, 2, 4, 8, and 16.
+            # The math.gcd operation ensures that shared_tp_size falls in the supported TP sizes.
+            shared_tp_size = math.gcd(
+                intermediate_size // block_size,
+                self.mapping.tp_size,
+            )
+        # If shared_tp_size has been overridden, the output of shared experts needs to be scaled down accordingly before all-reduce.
+        if shared_tp_size != self.mapping.tp_size:
+            shared_output_scale = shared_tp_size / self.mapping.tp_size
+
+        return shared_tp_size, shared_output_scale
+
     def compute_routed_output(self, hidden_states, hidden_states_fp4,
                               all_rank_num_tokens, min_latency_mode):
         # max-throughput
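The docstring's example can be checked numerically: with intermediate_size = 2048 and block_size = 128 there are 16 weight blocks, so gcd(16, tp_size) always lands in {1, 2, 4, 8, 16}, and a scale-down factor is produced whenever the shared-expert TP is capped below the mapping's tp_size. A short verification sketch (illustrative, not part of the commit):

import math

blocks = 2048 // 128  # 16 weight blocks along the sharded dimension
for tp_size in [1, 2, 3, 4, 6, 8, 12, 16, 32]:
    shared_tp = math.gcd(blocks, tp_size)
    scale = None if shared_tp == tp_size else shared_tp / tp_size
    print(f"tp_size={tp_size:2d} -> shared_tp_size={shared_tp:2d}, scale={scale}")
# Every shared_tp_size printed is a divisor of 16, i.e. in {1, 2, 4, 8, 16};
# e.g. tp_size=32 gives shared_tp_size=16 and scale=0.5 before the all-reduce.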
@@ -405,6 +433,7 @@ class DeepseekV3DecoderLayer(DecoderLayer):
                                        "0") == "0"
         self.enable_fusion = enable_fusion and not self.enable_attention_dp
 
+        # FIXME: incompatible with mixed quantization mode (including excluding modules from quantization)
         self.is_nvfp4 = model_config.quant_config.layer_quant_mode.has_nvfp4()
         has_tp = mapping.has_tp()
         has_pp = mapping.has_pp()
@@ -427,22 +456,11 @@ class DeepseekV3DecoderLayer(DecoderLayer):
                 model_config=model_config,
                 aux_stream_dict=aux_stream_dict)
         else:
-            # The block scale size is 128, which requires intermediate_size to be divisible by 128.
-            assert config.intermediate_size % 128 == 0
-            if self.enable_attention_dp:
-                # If using attention DP, the MLP also uses DP instead of TP.
-                self.mlp_tp_size = 1
-            else:
-                # Due to the restriction of block scale size (i.e., 128), the supported TP sizes only include 1, 2, 4, 8, and 16.
-                # To avoid the costly inter-node all-reduce, we further restrict TP size to be divisible by gpus_per_node.
-                # The two math.gcd operations ensure that mlp_tp_size falls in the candidate TP sizes.
-                self.mlp_tp_size = math.gcd(
-                    math.gcd(
-                        config.intermediate_size // 128,
-                        mapping.tp_size,
-                    ),
-                    mapping.gpus_per_node,  # Avoid costly inter-node TP
-                )
+            block_size = 1
+            if model_config.quant_config and model_config.quant_config.group_size is not None:
+                block_size = model_config.quant_config.group_size
+            self.mlp_tp_size = self._compute_mlp_tp_size(
+                config.intermediate_size, block_size)
 
             self.fusion_config.PRE_MLP_FUSION = self.enable_fusion and has_tp and self.is_nvfp4
             self.fusion_config.POST_MLP_FUSION = self.enable_fusion and self.mlp_tp_size > 1 and not has_pp
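The removed inline code and the new _compute_mlp_tp_size method apply the same double gcd; the refactor only moves it behind the helper and generalizes the hard-coded 128 to block_size. A worked example with illustrative values, showing how the extra gcd with gpus_per_node keeps the dense-MLP TP group intra-node:

import math

intermediate_size = 18432   # example dense-MLP width; 18432 // 128 = 144 blocks
block_size = 128
tp_size = 16
gpus_per_node = 8

mlp_tp = math.gcd(math.gcd(intermediate_size // block_size, tp_size), gpus_per_node)
print(mlp_tp)  # 8: gcd(144, 16) = 16, then capped to 8 by gpus_per_node,
               # avoiding an inter-node all-reduce for the MLP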
@@ -479,6 +497,37 @@ class DeepseekV3DecoderLayer(DecoderLayer):
         if not self.deepseek_allreduce_disabled:
             self.deepseek_allreduce = DeepseekAllReduce(self.mapping)
 
+    def _compute_mlp_tp_size(self, intermediate_size: int,
+                             block_size: int) -> int:
+        """
+        For DeepSeek-R1, MLP TP size is limited by intermediate_size // block_size
+        and must also be multiples of gpus_per_node to avoid expensive inter-node allreduce.
+
+        Args:
+            intermediate_size (int): MLP intermediate size.
+            block_size (int): The quantization block scale size. In the case of Deepseek FP8 recipe,
+                it's 128. For NVFP4, it's 16.
+
+        Returns:
+            int: The computed tp_size.
+        """
+
+        assert intermediate_size % block_size == 0, "intermediate_size must be divisible by block_size."
+
+        if self.enable_attention_dp:
+            # If using attention DP, the MLP also uses DP instead of TP.
+            mlp_tp_size = 1
+        else:
+            # The two math.gcd operations ensure that mlp_tp_size falls in the candidate TP sizes.
+            mlp_tp_size = math.gcd(
+                math.gcd(
+                    intermediate_size // block_size,
+                    self.mapping.tp_size,
+                ),
+                self.mapping.gpus_per_node,  # Avoid costly inter-node TP
+            )
+        return mlp_tp_size
+
     def forward(
         self,
         position_ids: torch.LongTensor,
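For a quick sanity check of the new method in isolation, the two attributes it reads (enable_attention_dp and mapping) can be stubbed out. The snippet below is a test sketch with hypothetical stand-ins and a re-implementation of the method body for illustration; it is not part of the commit:

import math
from types import SimpleNamespace

def compute_mlp_tp_size(self, intermediate_size, block_size):
    # Mirrors DeepseekV3DecoderLayer._compute_mlp_tp_size, bound to a fake "self".
    assert intermediate_size % block_size == 0
    if self.enable_attention_dp:
        return 1
    return math.gcd(math.gcd(intermediate_size // block_size, self.mapping.tp_size),
                    self.mapping.gpus_per_node)

fake_layer = SimpleNamespace(enable_attention_dp=False,
                             mapping=SimpleNamespace(tp_size=16, gpus_per_node=8))
assert compute_mlp_tp_size(fake_layer, 2048, 128) == 8   # gcd(gcd(16, 16), 8)
assert compute_mlp_tp_size(fake_layer, 2048, 16) == 8    # NVFP4-style block size
fake_layer.enable_attention_dp = True
assert compute_mlp_tp_size(fake_layer, 2048, 128) == 1   # attention DP forces TP=1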