mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-02-01 00:31:24 +08:00)
[Deepseek] Refactor Deepseek Decoder layer (#4016)
Refactor Deepseek Decoder layer

Signed-off-by: Hao Lu <14827759+hlu1@users.noreply.github.com>
Co-authored-by: Hao Lu <14827759+hlu1@users.noreply.github.com>
parent bb766eca0a
commit 26a2679217
@@ -455,7 +455,7 @@ class Deepseekv3MoE(nn.Module):
         return True
 
     def compute_routed_output(self, hidden_states, hidden_states_fp4,
-                              all_rank_num_tokens, min_latency_mode):
+                              all_rank_num_tokens, cutlass_min_latency_mode):
         # max-throughput
         if self.use_dp and self.mapping.tp_size > 1 and not self.enable_alltoall:
             max_num_token = max(all_rank_num_tokens)
@@ -473,7 +473,7 @@ class Deepseekv3MoE(nn.Module):
 
         routed_output = self.experts(hidden_states_fp4 or hidden_states,
                                      router_logits,
-                                     min_latency_mode,
+                                     cutlass_min_latency_mode,
                                      output_dtype=hidden_states.dtype,
                                      all_rank_num_tokens=all_rank_num_tokens)
 
@@ -485,9 +485,9 @@ class Deepseekv3MoE(nn.Module):
         hidden_states_fp4: Optional[Fp4QuantizedTensor] = None,
         all_rank_num_tokens: Optional[list[int]] = None,
         final_all_reduce_params: Optional[AllReduceParams] = None,
-        min_latency_mode: Optional[bool] = False,
+        cutlass_min_latency_mode: Optional[bool] = False,
     ) -> torch.Tensor:
-        if min_latency_mode:
+        if cutlass_min_latency_mode:
             assert not self.use_dp
 
         def _compute_shared_output():
@@ -498,10 +498,9 @@ class Deepseekv3MoE(nn.Module):
             return shared_output
 
         def _compute_routed_output():
-            routed_output = self.compute_routed_output(hidden_states,
-                                                       hidden_states_fp4,
-                                                       all_rank_num_tokens,
-                                                       min_latency_mode)
+            routed_output = self.compute_routed_output(
+                hidden_states, hidden_states_fp4, all_rank_num_tokens,
+                cutlass_min_latency_mode)
             return routed_output
 
         shared_output, routed_output = maybe_execute_in_parallel(
@@ -509,7 +508,7 @@ class Deepseekv3MoE(nn.Module):
             self.event_dict[EventType.Main],
             self.event_dict[EventType.MoeShared], self.aux_stream)
 
-        if min_latency_mode:
+        if cutlass_min_latency_mode:
             return [shared_output, *routed_output]
         else:
             assert shared_output.size() == routed_output.size(
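Note on the hunk above: maybe_execute_in_parallel overlaps the shared-expert and routed-expert branches on the main and auxiliary CUDA streams. As background for readers of this diff, the Python sketch below shows one common way such overlap is wired up with events; run_in_parallel is a hypothetical stand-in, not the repository's helper, and the event/stream plumbing is an assumption about the general pattern rather than the actual implementation.

import torch

def run_in_parallel(fn_main, fn_aux, event_main, event_aux, aux_stream):
    # Hypothetical sketch of dual-stream overlap (not the repo's
    # maybe_execute_in_parallel): run fn_main on the current stream and
    # fn_aux on an auxiliary stream, synchronizing with CUDA events.
    event_main.record()                  # mark where the main stream currently is
    out_main = fn_main()                 # main-stream work (e.g. shared experts)
    with torch.cuda.stream(aux_stream):  # aux-stream work (e.g. routed experts)
        aux_stream.wait_event(event_main)
        out_aux = fn_aux()
        event_aux.record()
    # Main stream waits until the aux work is done before consuming its output.
    torch.cuda.current_stream().wait_event(event_aux)
    return out_main, out_aux

# Example wiring (assumed names):
# aux = torch.cuda.Stream()
# e0, e1 = torch.cuda.Event(), torch.cuda.Event()
# shared, routed = run_in_parallel(lambda: shared_mlp(x), lambda: moe(x), e0, e1, aux)

The main stream records an event so the auxiliary stream can start from the same point, and it waits on the auxiliary stream's event before the routed output is consumed.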
@@ -531,6 +530,7 @@ class DeepseekV3DecoderLayer(DecoderLayer):
         super().__init__()
+        self.model_config = model_config
         config = model_config.pretrained_config
 
         self.hidden_size = config.hidden_size
         self.moe_intermediate_size = config.moe_intermediate_size
         self.num_experts = config.n_routed_experts
@@ -544,16 +544,17 @@ class DeepseekV3DecoderLayer(DecoderLayer):
             model_config,
             layer_idx=layer_idx,
             aux_stream=aux_stream_dict[AuxStreamType.Attention])
-        self.fusion_config = EagerFusionConfig()
         self.enable_attention_dp = mapping.enable_attention_dp
 
         self.mlp_tp_size = mapping.tp_size
 
         pp_layer_offset = mapping.pp_layers(config.num_hidden_layers)[0]
         global_layer_idx = pp_layer_offset + layer_idx
 
-        enable_fusion = os.environ.get("TRTLLM_DEEPSEEK_EAGER_FUSION_DISABLED",
-                                       "0") == "0"
-        self.enable_fusion = enable_fusion and not self.enable_attention_dp
+        self.fusion_config = EagerFusionConfig()
+        self.enable_fusion = os.environ.get(
+            "TRTLLM_DEEPSEEK_EAGER_FUSION_DISABLED", "0") == "0"
+        self.enable_fusion &= not self.enable_attention_dp
+
         # FIXME: incompatible with mixed quantization mode (including excluding modules from quantization)
         self.is_nvfp4 = model_config.quant_config.layer_quant_mode.has_nvfp4()
@@ -584,8 +585,9 @@ class DeepseekV3DecoderLayer(DecoderLayer):
             self.mlp_tp_size = self._compute_mlp_tp_size(
                 config.intermediate_size, block_size)
 
-            self.fusion_config.PRE_MLP_FUSION = self.enable_fusion and has_tp and self.is_nvfp4
-            self.fusion_config.POST_MLP_FUSION = self.enable_fusion and self.mlp_tp_size > 1 and not has_pp
+            has_mlp_tp = self.mlp_tp_size > 1
+            self.fusion_config.PRE_MLP_FUSION = self.enable_fusion and has_mlp_tp and self.is_nvfp4
+            self.fusion_config.POST_MLP_FUSION = self.enable_fusion and has_mlp_tp and not has_pp
 
             self.mlp = GatedMLP(hidden_size=config.hidden_size,
                                 intermediate_size=config.intermediate_size,
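The two __init__ hunks above gate eager fusion on an environment variable, attention DP, MLP tensor parallelism, pipeline parallelism, and NVFP4 quantization. A minimal sketch of that gating, using hypothetical names (FusionFlags, decide_fusion_flags) rather than the repository's EagerFusionConfig:

import os
from dataclasses import dataclass

@dataclass
class FusionFlags:
    # Hypothetical container mirroring the PRE/POST MLP fusion switches above.
    pre_mlp: bool = False
    post_mlp: bool = False

def decide_fusion_flags(mlp_tp_size: int, has_pp: bool, is_nvfp4: bool,
                        attention_dp: bool) -> FusionFlags:
    # Assumed restatement of the gating in these hunks: fusion is an opt-out
    # (env var), disabled under attention DP, and only applied when the MLP is
    # actually TP-sharded; the pre-MLP fused allreduce additionally needs NVFP4.
    enable = os.environ.get("TRTLLM_DEEPSEEK_EAGER_FUSION_DISABLED", "0") == "0"
    enable &= not attention_dp
    has_mlp_tp = mlp_tp_size > 1
    return FusionFlags(pre_mlp=enable and has_mlp_tp and is_nvfp4,
                       post_mlp=enable and has_mlp_tp and not has_pp)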
@@ -643,8 +645,10 @@ class DeepseekV3DecoderLayer(DecoderLayer):
         )
         return mlp_tp_size
 
-    def _enable_latency_mode(self, num_tokens: int):
-        return num_tokens <= 128 and self.fusion_config.POST_MOE_FUSION and self.is_nvfp4 and self.model_config.moe_backend == 'CUTLASS'
+    def _enable_min_latency_mode(self, num_tokens: int):
+        return (num_tokens <= 128 and self.fusion_config.POST_MOE_FUSION
+                and self.is_nvfp4
+                and self.model_config.moe_backend == 'CUTLASS')
 
     def forward(
         self,
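The renamed _enable_min_latency_mode above keeps the same heuristic: only small batches on the CUTLASS backend, with NVFP4 quantization and the post-MoE fused allreduce available, take the min-latency path. A hedged restatement as a free function (should_use_min_latency is a hypothetical name, not repository code):

def should_use_min_latency(num_tokens: int, post_moe_fusion: bool,
                           is_nvfp4: bool, moe_backend: str) -> bool:
    # Hypothetical restatement of the check above: the CUTLASS min-latency MoE
    # path is only worthwhile for small batches (<= 128 tokens) and only when
    # the post-MoE fused allreduce and NVFP4 quantization are in play.
    return (num_tokens <= 128 and post_moe_fusion and is_nvfp4
            and moe_backend == 'CUTLASS')

# e.g. should_use_min_latency(64, True, True, 'CUTLASS')  -> True
#      should_use_min_latency(256, True, True, 'CUTLASS') -> False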
@@ -668,24 +672,78 @@ class DeepseekV3DecoderLayer(DecoderLayer):
             **kwargs,
         )
 
-        min_latency_mode = self._enable_latency_mode(hidden_states.size(0))
+        if isinstance(self.mlp, Deepseekv3MoE):
+            return self.forward_MoE(
+                hidden_states=hidden_states,
+                attn_metadata=attn_metadata,
+                residual=residual,
+            )
+        else:
+            assert isinstance(self.mlp, GatedMLP)
+            return self.forward_mlp(
+                hidden_states=hidden_states,
+                residual=residual,
+            )
-
-        hidden_states_fp4 = None
-        if self.fusion_config.PRE_MOE_FUSION:
-            if min_latency_mode:
-                hidden_states, hidden_states_act, hidden_states_sf, residual = self.allreduce(
-                    hidden_states,
-                    all_reduce_params=AllReduceParams(
-                        fusion_op=AllReduceFusionOp.
-                        RESIDUAL_RMS_NORM_OUT_QUANT_NVFP4,
-                        residual=residual,
-                        norm_weight=self.post_attention_layernorm.weight,
-                        scale=self.mlp.experts.fc31_input_scale,
-                        eps=self.post_attention_layernorm.variance_epsilon,
-                    ))
-                hidden_states_fp4 = Fp4QuantizedTensor(hidden_states_act,
-                                                       hidden_states_sf)
-            else:
+
+    def forward_MoE(
+        self,
+        hidden_states: torch.Tensor,
+        attn_metadata: AttentionMetadata,
+        residual: torch.Tensor,
+    ) -> torch.Tensor:
+
+        def _run_MoE(hidden_states, hidden_states_fp4):
+            return self.mlp(
+                hidden_states,
+                hidden_states_fp4,
+                all_rank_num_tokens=attn_metadata.all_rank_num_tokens,
+                final_all_reduce_params=AllReduceParams(
+                    enable_allreduce=not (self.fusion_config.POST_MOE_FUSION
+                                          or self.mapping.tp_size == 1)),
+                cutlass_min_latency_mode=cutlass_min_latency_mode,
+            )
+
+        cutlass_min_latency_mode = self._enable_min_latency_mode(
+            hidden_states.shape[0])
+
+        if cutlass_min_latency_mode:
+            assert self.fusion_config.PRE_MOE_FUSION and self.fusion_config.POST_MOE_FUSION
+            assert self.model_config.moe_backend == 'CUTLASS'
+
+            hidden_states, hidden_states_act, hidden_states_sf, residual = self.allreduce(
+                hidden_states,
+                all_reduce_params=AllReduceParams(
+                    fusion_op=AllReduceFusionOp.
+                    RESIDUAL_RMS_NORM_OUT_QUANT_NVFP4,
+                    residual=residual,
+                    norm_weight=self.post_attention_layernorm.weight,
+                    scale=self.mlp.experts.fc31_input_scale,
+                    eps=self.post_attention_layernorm.variance_epsilon,
+                ))
+            hidden_states_fp4 = Fp4QuantizedTensor(hidden_states_act,
+                                                   hidden_states_sf)
+
+            hidden_states = _run_MoE(hidden_states, hidden_states_fp4)
+
+            shared_output = hidden_states[0]
+            hidden_states_activated_experts = hidden_states[1]
+            num_activated_experts_per_node = hidden_states[2]
+            experts_to_token_score = hidden_states[3]
+
+            # MoE_finalize is fused into allreduce
+            hidden_states, residual = self.moe_allreduce(
+                residual,
+                self.next_layer_layernorm.weight,
+                device_num_experts=num_activated_experts_per_node,
+                scale_input=experts_to_token_score,
+                active_experts_token_input=hidden_states_activated_experts,
+                token_input=shared_output,
+                eps=self.next_layer_layernorm.variance_epsilon,
+            )
+        else:
+            if self.fusion_config.PRE_MOE_FUSION:
+                # moe_backend can be either CUTLASS or TRTLLM here
+                # TODO: unify the two min-latency MoE backends by enabling quant fusion
                 hidden_states, residual = self.allreduce(
                     hidden_states,
                     all_reduce_params=AllReduceParams(
@@ -694,7 +752,36 @@ class DeepseekV3DecoderLayer(DecoderLayer):
                         norm_weight=self.post_attention_layernorm.weight,
                         eps=self.post_attention_layernorm.variance_epsilon,
                     ))
-        elif self.fusion_config.PRE_MLP_FUSION:
+            else:
+                # No fusion
+                hidden_states, residual = self.post_attention_layernorm(
+                    hidden_states, residual)
+
+            hidden_states = _run_MoE(hidden_states, hidden_states_fp4=None)
+
+            if self.fusion_config.POST_MOE_FUSION:
+                hidden_states, residual = self.allreduce(
+                    hidden_states,
+                    all_reduce_params=AllReduceParams(
+                        fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM,
+                        residual=residual,
+                        norm_weight=self.next_layer_layernorm.weight,
+                        eps=self.next_layer_layernorm.variance_epsilon,
+                    ))
+            else:
+                if self.next_layer_layernorm is not None:
+                    hidden_states, residual = self.next_layer_layernorm(
+                        hidden_states, residual)
+
+        return hidden_states, residual
+
+    def forward_mlp(
+        self,
+        hidden_states: torch.Tensor,
+        residual: torch.Tensor,
+    ) -> torch.Tensor:
+
+        if self.fusion_config.PRE_MLP_FUSION:
             act_fp4, act_sf, residual = self.allreduce(
                 hidden_states,
                 all_reduce_params=AllReduceParams(
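The branches above fold the residual add and RMSNorm (and, in the NVFP4 variant, the quantization) into the allreduce via AllReduceFusionOp.RESIDUAL_RMS_NORM*. For reference only, here is a minimal unfused PyTorch sketch of the epilogue math being fused; in the fused kernel the input would first be all-reduced across TP ranks, and this is not the repository's implementation:

import torch

def residual_rmsnorm(x: torch.Tensor, residual: torch.Tensor,
                     norm_weight: torch.Tensor, eps: float):
    # Unfused reference for the RESIDUAL_RMS_NORM epilogue: add the residual,
    # then apply RMSNorm; the fused allreduce returns both the normalized
    # activations and the updated residual, just like the tuples unpacked above.
    new_residual = x + residual
    xf = new_residual.to(torch.float32)
    variance = xf.pow(2).mean(-1, keepdim=True)
    normed = xf * torch.rsqrt(variance + eps)
    return (normed * norm_weight.to(torch.float32)).to(x.dtype), new_residual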
@@ -710,54 +797,13 @@ class DeepseekV3DecoderLayer(DecoderLayer):
             hidden_states, residual = self.post_attention_layernorm(
                 hidden_states, residual)
 
-        if self.fusion_config.PRE_MOE_FUSION and min_latency_mode:
-            hidden_states = self.mlp(
-                hidden_states,
-                hidden_states_fp4,
-                all_rank_num_tokens=attn_metadata.all_rank_num_tokens,
-                final_all_reduce_params=AllReduceParams(enable_allreduce=not (
-                    self.fusion_config.POST_MOE_FUSION
-                    or self.fusion_config.POST_MLP_FUSION
-                    or self.mapping.tp_size == 1 or self.enable_attention_dp)),
-                min_latency_mode=min_latency_mode,
-            )
-        else:
-            hidden_states = self.mlp(
-                hidden_states,
-                all_rank_num_tokens=attn_metadata.all_rank_num_tokens,
-                final_all_reduce_params=AllReduceParams(enable_allreduce=not (
-                    self.fusion_config.POST_MOE_FUSION
-                    or self.fusion_config.POST_MLP_FUSION
-                    or self.mapping.tp_size == 1 or self.enable_attention_dp)),
-                min_latency_mode=min_latency_mode,
-            )
+        hidden_states = self.mlp(
+            hidden_states,
+            final_all_reduce_params=AllReduceParams(enable_allreduce=not (
+                self.fusion_config.POST_MLP_FUSION or self.mlp_tp_size == 1)),
+        )
 
-        if self.fusion_config.POST_MOE_FUSION:
-            if min_latency_mode:
-                shared_output = hidden_states[0]
-                hidden_states_activated_experts = hidden_states[1]
-                num_activated_experts_per_node = hidden_states[2]
-                experts_to_token_score = hidden_states[3]
-
-                hidden_states, residual = self.moe_allreduce(
-                    residual,
-                    self.next_layer_layernorm.weight,
-                    device_num_experts=num_activated_experts_per_node,
-                    scale_input=experts_to_token_score,
-                    active_experts_token_input=hidden_states_activated_experts,
-                    token_input=shared_output,
-                    eps=self.next_layer_layernorm.variance_epsilon,
-                )
-            else:
-                hidden_states, residual = self.allreduce(
-                    hidden_states,
-                    all_reduce_params=AllReduceParams(
-                        fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM,
-                        residual=residual,
-                        norm_weight=self.next_layer_layernorm.weight,
-                        eps=self.next_layer_layernorm.variance_epsilon,
-                    ))
-        elif self.fusion_config.POST_MLP_FUSION:
+        if self.fusion_config.POST_MLP_FUSION:
             hidden_states, residual = self.allreduce(
                 hidden_states,
                 all_reduce_params=AllReduceParams(
@@ -851,13 +897,14 @@ class DeepseekV3MTP(DeepseekV3DecoderLayer):
         else:
             hidden_states, residual = self.post_attention_layernorm(
                 hidden_states, residual)
-        # Fully Connected
+
+        # MoE
         hidden_states = self.mlp(
             hidden_states,
             all_rank_num_tokens=spec_metadata.all_rank_num_tokens,
-            final_all_reduce_params=AllReduceParams(enable_allreduce=not (
-                self.fusion_config.POST_MOE_FUSION or self.mapping.tp_size == 1
-                or self.enable_attention_dp)),
+            final_all_reduce_params=AllReduceParams(
+                enable_allreduce=not (self.fusion_config.POST_MOE_FUSION
+                                      or self.mapping.tp_size == 1)),
         )
 
         if self.fusion_config.POST_MOE_FUSION:
@@ -295,10 +295,10 @@ class Llama4MoE(nn.Module):
         self.aux_stream = aux_stream
 
     def compute_routed_output(self, hidden_states, all_rank_num_tokens,
-                              min_latency_mode):
+                              cutlass_min_latency_mode):
         router_logits = self.router(hidden_states)
         routed_output = self.experts(hidden_states, router_logits,
                                      min_latency_mode)
-                                     min_latency_mode)
+                                     cutlass_min_latency_mode)
         return routed_output
 
     def forward(
@@ -306,16 +306,16 @@ class Llama4MoE(nn.Module):
         hidden_states: torch.Tensor,
         all_rank_num_tokens=None,
         final_all_reduce_params: Optional[AllReduceParams] = None,
-        min_latency_mode: Optional[bool] = False,
+        cutlass_min_latency_mode: Optional[bool] = False,
     ) -> torch.Tensor:
         # Only enable multi-stream for cuda graph since switch stream has extra host overhead
         # This design is mainly for low latency use case. Need to improve for max throughput use case.
         fn0 = lambda: self.shared_expert(hidden_states)
         fn1 = lambda: self.compute_routed_output(
-            hidden_states, all_rank_num_tokens, min_latency_mode)
+            hidden_states, all_rank_num_tokens, cutlass_min_latency_mode)
         shared_output, routed_output = maybe_execute_in_parallel(
             fn0, fn1, self.moe_event[0], self.moe_event[1], self.aux_stream)
-        if min_latency_mode:
+        if cutlass_min_latency_mode:
             return [shared_output, *routed_output]
 
         assert shared_output.size() == routed_output.size(
@@ -414,12 +414,12 @@ class Llama4DecoderLayer(DecoderLayer):
         # TODO: Remove it after we fix crash on Hopper
         # major, minor = torch.cuda.get_device_capability()
         # is_blackwell = (major * 10 + minor) >= 100
-        # min_latency_mode = hidden_states.size(
+        # cutlass_min_latency_mode = hidden_states.size(
         # 0
         # ) <= 128 and self.fusion_config.POST_MOE_FUSION and is_blackwell and self.is_quanted
 
         # Temporarily disable min-latency mode for Llama4
-        min_latency_mode = False
+        cutlass_min_latency_mode = False
 
         if residual is None:
             residual = hidden_states
@@ -456,7 +456,7 @@ class Llama4DecoderLayer(DecoderLayer):
             final_all_reduce_params=AllReduceParams(enable_allreduce=not (
                 self.fusion_config.POST_MOE_FUSION or self.fusion_config.
                 POST_MLP_FUSION or self.mapping.tp_size == 1)),
-            min_latency_mode=min_latency_mode,
+            cutlass_min_latency_mode=cutlass_min_latency_mode,
         )
         if spec_metadata is not None:
             # We save the hidden states in the spec metadata here. In _prepare_draft_tokens,
@@ -467,7 +467,7 @@ class Llama4DecoderLayer(DecoderLayer):
                 hidden_states, residual)
 
         if self.fusion_config.POST_MOE_FUSION or self.fusion_config.POST_MLP_FUSION:
-            if min_latency_mode:
+            if cutlass_min_latency_mode:
                 shared_output = hidden_states[0]
                 hidden_states_activated_experts = hidden_states[1]
                 num_activated_experts_per_node = hidden_states[2]
@@ -468,8 +468,11 @@ class MLA(nn.Module):
             self.mha.update_quant_config(self.quant_config)
             self.mqa.update_quant_config(self.quant_config)
 
-        has_fp8_block_scales = self.quant_config and self.quant_config.quant_mode.has_fp8_block_scales(
-        )
+        # k_b_proj_trans's dtype must be consistent with self.kv_b_proj,
+        # which can be modified after __init__
+        has_fp8_block_scales = (
+            self.kv_b_proj.quant_config
+            and self.kv_b_proj.quant_config.quant_mode.has_fp8_block_scales())
 
         mla_weight_dtype = torch.float8_e4m3fn if has_fp8_block_scales else self.dtype
         self.k_b_proj_trans = nn.Parameter(
@@ -693,6 +696,7 @@ class MLA(nn.Module):
             dtype=q.dtype,
             device=q.device,
         )
 
+        if self.k_b_proj_trans.dtype == torch.bfloat16:
             # [num_heads, num_tokens, self.qk_nope_head_dim]
             q_nope_t = q_nope.transpose(0, 1)
@@ -265,7 +265,7 @@ class FusedMoE(nn.Module):
             equals to: dynamic quant + routing(topK, etc.) [+ fp4_allgather] + scatter + gemm1 + swiglu + gemm2 + finalizeMoeRoute [no allreduce] + reducescatter
 
         trtllm_gen backend (moe_backend="TRTLLM"):
-            min-latency mode (min_latency_mode flag of forward has no effect when trtllm_gen is used):
+            min-latency mode (cutlass_min_latency_mode flag of forward has no effect when trtllm_gen is used):
                 dynamic quant + FusedMoe Op
                 equals to: dynamic quant + routing(topK, etc.) + scatter + gemm1 + swiglu + gemm2 + finalize MoeRoute
 
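The docstring above summarizes the fused pipelines as dynamic quant + routing(topK) + scatter + gemm1 + swiglu + gemm2 + finalize. Purely to illustrate that data flow, here is a dense, unfused reference sketch (no quantization, no scatter/communication, looped over experts); the shapes and the softmax router are assumptions for the example, not the repository's kernels:

import torch
import torch.nn.functional as F

def reference_moe(x, router_logits, w1, w3, w2, top_k=2):
    # Unfused reference for the pipeline named in the docstring above:
    # routing (top-k) -> gemm1 -> swiglu -> gemm2 -> weighted finalize.
    # Assumed shapes: x [T, H], router_logits [T, E], w1/w3 [E, I, H], w2 [E, H, I].
    probs = torch.softmax(router_logits, dim=-1)               # [T, E]
    topk_p, topk_id = probs.topk(top_k, dim=-1)                # [T, K]
    out = torch.zeros_like(x)
    for k in range(top_k):
        for e in range(w1.shape[0]):
            mask = topk_id[:, k] == e
            if not mask.any():
                continue
            xe = x[mask]                                        # tokens routed to expert e
            h = F.silu(xe @ w1[e].T) * (xe @ w3[e].T)           # gemm1 + gated swiglu
            out[mask] += topk_p[mask, k].unsqueeze(-1) * (h @ w2[e].T)  # gemm2 + finalize
    return out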
@@ -689,7 +689,7 @@ class FusedMoE(nn.Module):
         self,
         x: Union[torch.Tensor, Fp4QuantizedTensor],
         router_logits: torch.Tensor,
-        min_latency_mode: bool = False,
+        cutlass_min_latency_mode: bool = False,
         output_dtype: Optional[torch.dtype] = None,
         all_rank_num_tokens=None,
     ) -> torch.Tensor:
@@ -766,7 +766,7 @@ class FusedMoE(nn.Module):
                 x_sf = reswizzle_sf(x_sf, x_row, x_col,
                                     self.scaling_vector_size)
 
-        if self.smart_router and not min_latency_mode:
+        if self.smart_router and not cutlass_min_latency_mode:
             ep_size = self.cluster_size
             ep_rank = self.cluster_rank
             expert_start = ep_rank * self.num_experts // ep_size
@@ -808,15 +808,15 @@ class FusedMoE(nn.Module):
             cluster_size=cluster_size,
             cluster_rank=cluster_rank,
             use_fp8_block_scaling=use_fp8_block_scaling,
-            min_latency_mode=min_latency_mode,
+            min_latency_mode=cutlass_min_latency_mode,
         )
 
-        if min_latency_mode:
+        if cutlass_min_latency_mode:
             assert not self.reduce_results
             return final_hidden_states
         else:
             # Custom op requires all inputs are in the same type.
-            # Only in min_latency_mode, the output is a list of tensors.
+            # Only in cutlass_min_latency_mode, the output is a list of tensors.
             # Otherwise, the output should be unpacked as a single tensor.
             final_hidden_states = final_hidden_states[0]
 
@@ -830,16 +830,17 @@ class FusedMoE(nn.Module):
         self,
         x: Union[torch.Tensor, Fp4QuantizedTensor],
         router_logits: torch.Tensor,
-        min_latency_mode: bool = False,
+        cutlass_min_latency_mode: bool = False,
         output_dtype: Optional[torch.dtype] = None,
         all_rank_num_tokens: Optional[List[int]] = None,
     ) -> torch.Tensor:
         """
-        min_latency_mode has no effect when trtllm_gen backend is enabled.
+        cutlass_min_latency_mode has no effect when trtllm_gen backend is enabled.
         """
         if self.is_cutlass():
-            return self.forward_cutlass(x, router_logits, min_latency_mode,
-                                        output_dtype, all_rank_num_tokens)
+            return self.forward_cutlass(x, router_logits,
+                                        cutlass_min_latency_mode, output_dtype,
+                                        all_rank_num_tokens)
         elif self.is_trtllm():
             return self.forward_trtllmgen(x, router_logits)
         else:
@@ -851,7 +852,7 @@ class FusedMoE(nn.Module):
         self,
         x: Union[torch.Tensor, Fp4QuantizedTensor],
         router_logits: torch.Tensor,
-        min_latency_mode: bool = False,
+        cutlass_min_latency_mode: bool = False,
         output_dtype: Optional[torch.dtype] = None,
         all_rank_num_tokens: Optional[List[int]] = None,
     ) -> torch.Tensor:
@@ -866,16 +867,16 @@ class FusedMoE(nn.Module):
         num_rows = x.shape[0]
         num_chunks = (num_rows + max_chunk_size - 1) // max_chunk_size
 
-        if min_latency_mode:
+        if cutlass_min_latency_mode:
             assert num_chunks == 1 and (
                 not self.reduce_results
-            ), "min_latency_mode must be used with a single chunk and reduce_results must be False"
+            ), "cutlass_min_latency_mode must be used with a single chunk and reduce_results must be False"
 
         if num_chunks == 1:
             outputs = self.forward_chunk(
                 x,
                 router_logits,
-                min_latency_mode,
+                cutlass_min_latency_mode,
                 output_dtype,
                 all_rank_num_tokens=all_rank_num_tokens)
             outputs = self.reducescatter_or_allreduce(outputs)
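The chunking logic above splits the input into ceil(num_rows / max_chunk_size) pieces for the max-throughput path, while the min-latency path requires exactly one chunk. A small sketch of that split, using a hypothetical helper name rather than the repository's forward:

import torch

def split_into_chunks(x: torch.Tensor, max_chunk_size: int, min_latency: bool):
    # Hypothetical helper mirroring the logic above: ceil-divide the rows into
    # chunks; the min-latency path only supports a single chunk.
    num_rows = x.shape[0]
    num_chunks = (num_rows + max_chunk_size - 1) // max_chunk_size
    if min_latency:
        assert num_chunks == 1, "min-latency mode must be used with a single chunk"
    return torch.split(x, max_chunk_size, dim=0)

# e.g. a 300-row input with max_chunk_size=128 yields chunks of 128, 128, and 44 rows.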
@@ -104,13 +104,12 @@ class GatedMLP(nn.Module):
         x: Union[torch.Tensor, Fp4QuantizedTensor],
         all_rank_num_tokens=None,
         final_all_reduce_params: Optional[AllReduceParams] = None,
-        min_latency_mode: Optional[bool] = False,
         lora_params: Optional[dict] = None,
+        **kwargs,
     ) -> torch.Tensor:
         if lora_params is not None:
             return self.forward_lora(x, all_rank_num_tokens,
-                                     final_all_reduce_params, min_latency_mode,
-                                     lora_params)
+                                     final_all_reduce_params, lora_params)
 
         if self.activation == F.silu:
             h1 = self.gate_up_proj(x)
@@ -146,7 +145,6 @@ class GatedMLP(nn.Module):
         x: Union[torch.Tensor, Fp4QuantizedTensor],
         all_rank_num_tokens=None,
         final_all_reduce_params: Optional[AllReduceParams] = None,
-        min_latency_mode: Optional[bool] = False,
         lora_params: Optional[dict] = None,
     ) -> torch.Tensor:
         assert lora_params is not None