Mirror of https://github.com/NVIDIA/TensorRT-LLM.git
feat: Optionally split MoE inputs into chunks to reduce GPU memory usage (#3104)
Signed-off-by: Jinyang Yuan <154768711+jinyangyuan-nvidia@users.noreply.github.com>
Co-authored-by: raccoonliukai <raccoonliu@tencent.com>
parent 727d78e785
commit 992d513bc6
@@ -23,6 +23,8 @@ class ModelConfig(Generic[TConfig]):
quant_config_dict: Optional[Dict[str, QuantConfig]] = None
skip_create_weights: bool = False
is_generation: bool = True
max_num_tokens: int = 8192
moe_max_num_tokens: Optional[int] = None

attn_backend: str = 'TRTLLM'
@@ -15,8 +15,7 @@ from tensorrt_llm.llmapi.utils import enable_llm_debug
from ..attention_backend import AttentionMetadata
from ..attention_backend.interface import PositionalEmbeddingParams, RopeParams
from ..distributed import (AllReduce, AllReduceFusionOp, AllReduceParams,
DeepseekAllReduce, ParallelConfig, allgather,
reducescatter)
DeepseekAllReduce, ParallelConfig, allgather)
from ..model_config import ModelConfig
from ..models.modeling_utils import MissingLayer, ModelConfig, support_pp
from ..modules.attention import MLA
@@ -29,7 +28,8 @@ from ..modules.rotary_embedding import RotaryEmbedding
from ..pipeline_interface import PipelineInterface
from ..pyexecutor.cuda_graph_runner import is_graph_capturing
from ..speculative import MTPEagleWorker, MTPSpecMetadata, MTPWorker
from ..utils import Fp4QuantizedTensor, disable_fp4_allgather
from ..utils import (AuxStreamType, EventType, Fp4QuantizedTensor,
disable_fp4_allgather)
from .modeling_utils import (DecoderModel, DecoderModelForCausalLM,
EagerFusionConfig, register_auto_model)
@@ -258,9 +258,8 @@ class Deepseekv3MoE(nn.Module):
hidden_size: int,
intermediate_size: int,
shared_expert_intermediate_size: int,
aux_stream: torch.cuda.Stream,
aux_stream_dict: Dict[AuxStreamType, torch.cuda.Stream],
dtype: Optional[torch.dtype] = None,
tune_max_num_tokens: int = 8192,
model_config: ModelConfig = ModelConfig()):
from ..distributed import AllReduce
@@ -284,8 +283,8 @@ class Deepseekv3MoE(nn.Module):
dtype=dtype,
reduce_results=
False, # In both low latency and attention dp scenarios, FusedMoE needs not to do allreduce inside op.
tune_max_num_tokens=tune_max_num_tokens,
model_config=model_config)
model_config=model_config,
aux_stream=aux_stream_dict[AuxStreamType.MoeChunkingOverlap])

self.shared_output_scale = None
if self.use_dp:
@@ -314,18 +313,11 @@ class Deepseekv3MoE(nn.Module):
pipeline_parallel_size=model_config.mapping.pp_size,
parallel_rank=model_config.mapping.rank)
self.all_reduce = AllReduce(self.parallel_config)
self.aux_stream = aux_stream
self.moe_event = [torch.cuda.Event(), torch.cuda.Event()]

def reduce_scatter(self, input_tensor, all_rank_num_tokens):
world_size = self.parallel_config.tensor_parallel_size
rank = self.parallel_config.tensor_parallel_rank
if world_size == 1:
return input_tensor
dst_tensor = input_tensor
outputs = reducescatter(dst_tensor, self.parallel_config, scatter_dim=0)
depad_tensors = outputs[:all_rank_num_tokens[rank]]
return depad_tensors
self.aux_stream = aux_stream_dict[AuxStreamType.MoeShared]
self.event_dict = {
key: torch.cuda.Event()
for key in [EventType.Main, EventType.MoeShared]
}

def compute_routed_output(self, hidden_states, all_rank_num_tokens):
if self.use_dp and self.parallel_config.tensor_parallel_size > 1:
@@ -338,10 +330,8 @@ class Deepseekv3MoE(nn.Module):
self.parallel_config,
gather_dim=0)
router_logits = self.gate(hidden_states)
routed_output = self.experts(hidden_states, router_logits)
if self.use_dp:
routed_output = self.reduce_scatter(routed_output,
all_rank_num_tokens)
routed_output = self.experts(hidden_states, router_logits,
all_rank_num_tokens)
return routed_output

def forward(
@@ -354,17 +344,17 @@ class Deepseekv3MoE(nn.Module):
# This design is mainly for low latency use case. Need to improve for max throughput use case.
do_multi_stream = is_graph_capturing()
if do_multi_stream:
self.moe_event[0].record()
self.event_dict[EventType.Main].record()
shared_output = self.shared_experts(hidden_states)
if self.shared_output_scale is not None:
shared_output *= self.shared_output_scale
if do_multi_stream:
with torch.cuda.stream(self.aux_stream):
self.moe_event[0].wait()
self.event_dict[EventType.Main].wait()
routed_output = self.compute_routed_output(
hidden_states, all_rank_num_tokens)
self.moe_event[1].record()
self.moe_event[1].wait()
self.event_dict[EventType.MoeShared].record()
self.event_dict[EventType.MoeShared].wait()
else:
routed_output = self.compute_routed_output(hidden_states,
all_rank_num_tokens)
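Note on the hunk above: the shared-expert branch stays on the main stream while the routed experts run on an auxiliary stream, synchronized with a record/wait event pair (only while capturing CUDA graphs). A minimal, self-contained sketch of that handshake follows; `overlapped_moe`, `shared_fn`, and `routed_fn` are illustrative names, not part of the commit, and production code would also need proper cross-stream memory handling, which is omitted here.

```python
import torch

def overlapped_moe(x, shared_fn, routed_fn, aux_stream, main_event, aux_event):
    # Let the aux stream start only after x is ready on the current (main) stream.
    main_event.record()
    with torch.cuda.stream(aux_stream):
        main_event.wait()
        routed = routed_fn(x)   # runs concurrently with shared_fn below
        aux_event.record()
    shared = shared_fn(x)       # main stream keeps working in the meantime
    aux_event.wait()            # main stream waits for the routed branch
    return shared + routed

if torch.cuda.is_available():
    aux = torch.cuda.Stream()
    x = torch.randn(16, 64, device='cuda')
    out = overlapped_moe(x, lambda t: 2 * t, lambda t: t + 1,
                         aux, torch.cuda.Event(), torch.cuda.Event())
```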
@@ -381,7 +371,8 @@ class Deepseekv3MoE(nn.Module):
class DeepseekV3DecoderLayer(DecoderLayer):

def __init__(self, model_config: ModelConfig[PretrainedConfig],
layer_idx: int, aux_stream: torch.cuda.Stream):
layer_idx: int, aux_stream_dict: Dict[AuxStreamType,
torch.cuda.Stream]):
super().__init__()
config = model_config.pretrained_config
self.hidden_size = config.hidden_size
@@ -390,9 +381,10 @@ class DeepseekV3DecoderLayer(DecoderLayer):
self.num_shared_experts = config.n_shared_experts
self.top_k = config.num_experts_per_tok

self.self_attn = DeepseekV3Attention(model_config,
layer_idx=layer_idx,
aux_stream=aux_stream)
self.self_attn = DeepseekV3Attention(
model_config,
layer_idx=layer_idx,
aux_stream=aux_stream_dict[AuxStreamType.Attention])
self.fusion_config = EagerFusionConfig()
self.enable_attention_dp = model_config.mapping.enable_attention_dp
self.mlp_tp_size = model_config.mapping.tp_size
@@ -420,9 +412,8 @@ class DeepseekV3DecoderLayer(DecoderLayer):
shared_expert_intermediate_size=self.moe_intermediate_size *
self.num_shared_experts,
dtype=config.torch_dtype,
tune_max_num_tokens=config.max_position_embeddings // 4,
model_config=model_config,
aux_stream=aux_stream)
aux_stream_dict=aux_stream_dict)
else:
if self.enable_attention_dp:
self.mlp_tp_size = 1
@@ -606,8 +597,9 @@ class DeepseekV3DecoderLayer(DecoderLayer):
class DeepseekV3MTP(DeepseekV3DecoderLayer):

def __init__(self, model_config: ModelConfig[PretrainedConfig],
layer_idx: int, aux_stream: torch.cuda.Stream):
super().__init__(model_config, layer_idx, aux_stream)
layer_idx: int, aux_stream_dict: Dict[AuxStreamType,
torch.cuda.Stream]):
super().__init__(model_config, layer_idx, aux_stream_dict)
config = model_config.pretrained_config
self.hidden_dim = config.hidden_size
self.moe_intermediate_size = config.moe_intermediate_size
@@ -735,14 +727,21 @@ class DeepseekV3Model(DecoderModel):
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.num_hidden_layers = config.num_hidden_layers
self.aux_stream = torch.cuda.Stream()
self.aux_stream_dict = {
key: torch.cuda.Stream()
for key in [
AuxStreamType.Attention, AuxStreamType.MoeShared,
AuxStreamType.MoeChunkingOverlap
]
}

self.embed_tokens = nn.Embedding(config.vocab_size,
config.hidden_size,
dtype=config.torch_dtype)

self.layers = nn.ModuleList([
DeepseekV3DecoderLayer(model_config, layer_idx, self.aux_stream)
DeepseekV3DecoderLayer(model_config, layer_idx,
self.aux_stream_dict)
for layer_idx in range(config.num_hidden_layers)
])
self.norm = RMSNorm(hidden_size=config.hidden_size,
@@ -842,7 +841,7 @@ class DeepseekV3ForCausalLM(DecoderModelForCausalLM[DeepseekV3Model,
assert ckpt_nextn > 0, "There is not MTP modules in the checkpoint."
if ckpt_nextn == 1:
mtp_layer = DeepseekV3MTP(model_config, self.num_hidden_layers,
self.model.aux_stream)
self.model.aux_stream_dict)
self.model.layers.append(mtp_layer)
self.mtp_worker = MTPEagleWorker(model_config.spec_config)
else:
@@ -851,7 +850,7 @@ class DeepseekV3ForCausalLM(DecoderModelForCausalLM[DeepseekV3Model,
mtp_layers = nn.ModuleList([
DeepseekV3MTP(model_config,
layer_idx + self.num_hidden_layers,
self.model.aux_stream)
self.model.aux_stream_dict)
for layer_idx in range(model_nextn)
])
self.model.layers.extend(mtp_layers)
@@ -8,7 +8,7 @@ from tensorrt_llm.functional import PositionEmbeddingType
from ..attention_backend import AttentionMetadata
from ..attention_backend.interface import PositionalEmbeddingParams, RopeParams
from ..distributed import ParallelConfig, allgather
from ..distributed import ParallelConfig
from ..model_config import ModelConfig
from ..models.modeling_utils import ModelConfig
from ..modules.attention import Attention
@@ -23,13 +23,18 @@ from .modeling_utils import (DecoderModel, DecoderModelForCausalLM,
class MixtralMoE(nn.Module):

def __init__(self, model_config: ModelConfig[PretrainedConfig]):
def __init__(
self,
model_config: ModelConfig[PretrainedConfig],
aux_stream: torch.cuda.Stream,
):
super().__init__()
config = model_config.pretrained_config
self.hidden_dim = config.hidden_size
self.ffn_dim = config.intermediate_size
self.num_experts = config.num_local_experts
self.top_k = config.num_experts_per_tok
self.enable_attention_dp = model_config.mapping.enable_attention_dp

# moe gate (linear layer) only runs in half/full precision for now
self.gate = Linear(self.hidden_dim,
@@ -45,19 +50,26 @@ class MixtralMoE(nn.Module):
routing_method=RenormalizeMoeRoutingMethod(top_k=self.top_k),
hidden_size=self.hidden_dim,
intermediate_size=self.ffn_dim,
aux_stream=aux_stream,
dtype=config.torch_dtype,
reduce_results=reduce_results,
tune_max_num_tokens=config.max_position_embeddings // 4,
model_config=model_config)

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
assert hidden_states.shape[-1] == self.hidden_dim
orig_shape = hidden_states.shape
hidden_states = hidden_states.view(-1, self.hidden_dim)
# router_logits: (batch * sequence_length, n_experts)
def forward(
self,
hidden_states: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
all_rank_num_tokens = attn_metadata.all_rank_num_tokens
if self.enable_attention_dp and len(all_rank_num_tokens) > 1:
max_num_token = max(all_rank_num_tokens)
hidden_states = torch.nn.functional.pad(
hidden_states,
(0, 0, 0, max_num_token - hidden_states.shape[0]))
router_logits = self.gate(hidden_states)
final_hidden_states = self.experts(hidden_states, router_logits)
return final_hidden_states.view(orig_shape), router_logits
final_hidden_states = self.experts(hidden_states, router_logits,
all_rank_num_tokens)
return final_hidden_states


class MixtralRotaryEmbedding(RotaryEmbedding):
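Note on the new MixtralMoE.forward above: under attention data parallelism each rank may hold a different number of tokens, so the input is zero-padded to the per-iteration maximum before being handed to the experts (which all-gather and reduce-scatter internally). A minimal sketch of just that padding step, with illustrative names that are not from the commit:

```python
import torch
import torch.nn.functional as F
from typing import List

def pad_to_max_rank_tokens(hidden_states: torch.Tensor,
                           all_rank_num_tokens: List[int]) -> torch.Tensor:
    # Pad the token dimension to the maximum token count across ranks so that
    # fixed-shape collectives can be used; the hidden dimension is untouched.
    max_num_token = max(all_rank_num_tokens)
    pad_rows = max_num_token - hidden_states.shape[0]
    if pad_rows <= 0:
        return hidden_states
    # F.pad pads the last dim first: (0, 0, 0, pad_rows) appends zero rows.
    return F.pad(hidden_states, (0, 0, 0, pad_rows))

x = torch.randn(3, 8)                       # this rank produced 3 tokens
padded = pad_to_max_rank_tokens(x, [3, 5])  # another rank produced 5
assert padded.shape == (5, 8)
```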
@@ -105,14 +117,14 @@ class MixtralAttention(Attention):
class MixtralDecoderLayer(DecoderLayer):

def __init__(self, model_config: ModelConfig[PretrainedConfig],
layer_idx: int):
layer_idx: int, aux_stream: torch.cuda.Stream):
super().__init__()
config = model_config.pretrained_config
self.hidden_size = config.hidden_size

self.self_attn = MixtralAttention(model_config, layer_idx=layer_idx)

self.block_sparse_moe = MixtralMoE(model_config)
self.block_sparse_moe = MixtralMoE(model_config, aux_stream)

self.input_layernorm = RMSNorm(hidden_size=config.hidden_size,
eps=config.rms_norm_eps,
@@ -121,7 +133,6 @@ class MixtralDecoderLayer(DecoderLayer):
self.post_attention_layernorm = RMSNorm(hidden_size=config.hidden_size,
eps=config.rms_norm_eps,
dtype=config.torch_dtype)
self.enable_attention_dp = model_config.mapping.enable_attention_dp
# TODO: add pipeline parallel config
self.parallel_config = ParallelConfig(
tensor_parallel_rank=model_config.mapping.tp_rank,
@@ -129,26 +140,6 @@ class MixtralDecoderLayer(DecoderLayer):
gpus_per_node=model_config.mapping.gpus_per_node)
self.layer_idx = layer_idx

def all_gather(self, input_tensor, attn_metadata):
rank = self.parallel_config.tensor_parallel_rank
world_size = self.parallel_config.tensor_parallel_size
all_rank_num_tokens = attn_metadata.all_rank_num_tokens
max_num_token = max(all_rank_num_tokens)
if world_size == 1:
return input_tensor, 0, max_num_token

pad_tensor = torch.nn.functional.pad(
input_tensor, (0, 0, 0, max_num_token - input_tensor.shape[0]))
outputs = allgather(pad_tensor, self.parallel_config, gather_dim=0)
depad_tensors = torch.concat([
outputs[i * max_num_token:i * max_num_token +
all_rank_num_tokens[i]] for i in range(world_size)
])

cur_rank_start = 0 if rank == 0 else sum(all_rank_num_tokens[:rank])
cur_rank_end = cur_rank_start + all_rank_num_tokens[rank]
return depad_tensors, cur_rank_start, cur_rank_end

def forward(
self,
position_ids: torch.LongTensor,
@@ -175,12 +166,7 @@ class MixtralDecoderLayer(DecoderLayer):
# Fully Connected
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
if self.enable_attention_dp:
hidden_states, cur_rank_start, cur_rank_end = self.all_gather(
hidden_states, attn_metadata)
hidden_states, _router_logits = self.block_sparse_moe(hidden_states)
if self.enable_attention_dp:
hidden_states = hidden_states[cur_rank_start:cur_rank_end]
hidden_states = self.block_sparse_moe(hidden_states, attn_metadata)
return hidden_states, residual
@@ -191,6 +177,7 @@ class MixtralModel(DecoderModel):
config = model_config.pretrained_config
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.aux_stream = torch.cuda.Stream()

self.embed_tokens = nn.Embedding(config.vocab_size,
config.hidden_size,
@@ -198,7 +185,7 @@ class MixtralModel(DecoderModel):
dtype=config.torch_dtype)

self.layers = nn.ModuleList([
MixtralDecoderLayer(model_config, layer_idx)
MixtralDecoderLayer(model_config, layer_idx, self.aux_stream)
for layer_idx in range(config.num_hidden_layers)
])
self.norm = RMSNorm(hidden_size=config.hidden_size,
@@ -5,9 +5,10 @@ import torch
from torch import nn

from ...quantization.utils.fp4_utils import float4_sf_dtype
from ..distributed import allgather
from ..distributed import allgather, reducescatter
from ..model_config import ModelConfig
from ..utils import disable_fp4_allgather, is_torch_compiling, reswizzle_sf
from ..utils import (EventType, disable_fp4_allgather, is_torch_compiling,
reswizzle_sf)
from .linear import ParallelConfig, TensorParallelMode, load_weight_shard

# The declarations aligns with moe_kernels.h
@@ -198,9 +199,9 @@ class FusedMoE(nn.Module):
top_k (int): Number of top experts to select for each input token.
hidden_size (int): Size of the hidden state.
intermediate_size (int): Size of the intermediate state.
aux_stream (torch.cuda.Stream): Auxiliary CUDA stream to overlap chunks.
dtype (Optional[torch.dtype]): Data type for the weights.
reduce_results (bool): Whether to reduce the results across devices.
tune_max_num_tokens (int): Maximum number of tokens for performance tuning.
model_config (ModelConfig): Configuration object for the model.
"""
@@ -213,8 +214,8 @@ class FusedMoE(nn.Module):
intermediate_size: int,
dtype: Optional[torch.dtype] = None,
reduce_results: bool = False,
tune_max_num_tokens: int = 8192,
model_config: ModelConfig = ModelConfig(),
aux_stream: torch.cuda.Stream = torch.cuda.Stream(),
):
from ..distributed import AllReduce
@@ -224,6 +225,12 @@ class FusedMoE(nn.Module):
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size

self.aux_stream = aux_stream
self.event_dict = {
key: torch.cuda.Event()
for key in [EventType.Main, EventType.MoeChunkingOverlap]
}

self.dtype = dtype
self.reduce_results = reduce_results
# could be modified later
@@ -256,7 +263,18 @@ class FusedMoE(nn.Module):
self.expert_start + self.expert_size_per_partition,
self.num_experts)

self.tune_max_num_tokens = tune_max_num_tokens
self.moe_max_num_tokens = model_config.moe_max_num_tokens
if self.moe_max_num_tokens is None:
self.moe_max_num_tokens = model_config.max_num_tokens
if self.use_dp:
self.moe_max_num_tokens *= model_config.mapping.world_size
# The profiler converges on the same best tactic when the number of tokens is large enough.
# To avoid long profiling time, the max number of tokens used in the profiling is capped to
# around 16k tokens per expert, which is well into the compute bound domain.
self.tune_max_num_tokens = min(
self.moe_max_num_tokens,
16384 * num_experts / routing_method.get_experts_per_token(),
)
self.has_been_profiled = False

self._weights_created = False
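Note on the profiling cap above: a small worked example, with illustrative numbers that are not from the commit:

```python
# With 256 experts, top-8 routing, and a large MoE token cap, the profiling
# range is limited to roughly 16k tokens per expert on average.
num_experts = 256
experts_per_token = 8
moe_max_num_tokens = 1_048_576

cap = 16384 * num_experts / experts_per_token        # 524288.0 tokens
tune_max_num_tokens = min(moe_max_num_tokens, cap)
assert tune_max_num_tokens == 524288.0
# 524288 tokens * 8 activated experts / 256 experts = 16384 tokens per expert.
```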
@@ -484,7 +502,18 @@ class FusedMoE(nn.Module):
outputs.append(output)
return outputs

def forward(
def reducescatter_or_allreduce(self, inputs):
outputs = inputs
if self.parallel_size > 1:
if self.use_dp:
outputs = reducescatter(inputs,
self.parallel_config,
scatter_dim=0)
elif self.reduce_results:
outputs = self.all_reduce(inputs)
return outputs

def forward_chunk(
self,
x: torch.Tensor,
router_logits: torch.Tensor,
@@ -570,11 +599,65 @@ class FusedMoE(nn.Module):
use_fp8_block_scaling=use_fp8_block_scaling,
)

if self.reduce_results and self.parallel_size > 1:
final_hidden_states = self.all_reduce(final_hidden_states)

return final_hidden_states

def forward(
self,
x: torch.Tensor,
router_logits: torch.Tensor,
all_rank_num_tokens: Optional[List[int]] = None,
) -> torch.Tensor:
max_chunk_size = self.moe_max_num_tokens
if self.use_dp:
assert all_rank_num_tokens is not None
if not disable_fp4_allgather():
max_chunk_size //= len(all_rank_num_tokens)
num_chunks = (x.shape[0] + max_chunk_size - 1) // max_chunk_size
if num_chunks == 1:
outputs = self.forward_chunk(x, router_logits)
outputs = self.reducescatter_or_allreduce(outputs)
else:
val_div = x.shape[0] // num_chunks
val_mod = x.shape[0] % num_chunks
chunk_size_list = [val_div + 1
] * val_mod + [val_div] * (num_chunks - val_mod)
x_list = x.split(chunk_size_list)
router_logits_list = router_logits.split(chunk_size_list)
outputs_list = []
self.event_dict[EventType.Main].record()
with torch.cuda.stream(self.aux_stream):
self.event_dict[EventType.Main].wait()
# Postpone reduce-scatter/all-reduce to the next iteration to achieve better overlap
for idx_chunk, (x, router_logits) in enumerate(
zip(x_list, router_logits_list)):
if idx_chunk % 2 == 0:
with torch.cuda.stream(self.aux_stream):
outputs = self.forward_chunk(x, router_logits)
if idx_chunk > 0:
outputs_list[-1] = self.reducescatter_or_allreduce(
outputs_list[-1])
else:
outputs = self.forward_chunk(x, router_logits)
with torch.cuda.stream(self.aux_stream):
outputs_list[-1] = self.reducescatter_or_allreduce(
outputs_list[-1])
outputs_list.append(outputs)
if num_chunks % 2 == 0:
outputs_list[-1] = self.reducescatter_or_allreduce(
outputs_list[-1])
else:
with torch.cuda.stream(self.aux_stream):
outputs_list[-1] = self.reducescatter_or_allreduce(
outputs_list[-1])
with torch.cuda.stream(self.aux_stream):
self.event_dict[EventType.MoeChunkingOverlap].record()
self.event_dict[EventType.MoeChunkingOverlap].wait()
outputs = torch.cat(outputs_list)
if self.use_dp:
rank = self.parallel_config.tensor_parallel_rank
outputs = outputs[:all_rank_num_tokens[rank]]
return outputs

def load_weights(self, weights: List[Dict]):
assert self._weights_created
assert len(weights) == 1
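Note on the chunked forward above: the input is split into nearly equal chunks so that no single fused-MoE call sees more than moe_max_num_tokens, even-indexed chunks run on the auxiliary stream and odd-indexed ones on the main stream, and each chunk's reduce-scatter/all-reduce is issued one iteration later on the opposite stream so communication overlaps with the next chunk's compute. A standalone sketch of just the chunk-size computation (the function name is illustrative, not from the commit):

```python
from typing import List

def balanced_chunk_sizes(num_tokens: int, max_chunk_size: int) -> List[int]:
    # Ceil-divide to get the chunk count, then spread the remainder so the
    # first `remainder` chunks carry one extra token each.
    num_chunks = (num_tokens + max_chunk_size - 1) // max_chunk_size
    base, remainder = divmod(num_tokens, num_chunks)
    return [base + 1] * remainder + [base] * (num_chunks - remainder)

# e.g. 10_000 tokens with moe_max_num_tokens = 4096 -> 3 chunks of 3334/3333/3333
assert balanced_chunk_sizes(10_000, 4096) == [3334, 3333, 3333]
assert sum(balanced_chunk_sizes(10_000, 4096)) == 10_000
```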
@@ -46,6 +46,10 @@ class PyTorchConfig:
# This is usually a net win for performance.
cuda_graph_padding_enabled: bool = False
enable_overlap_scheduler: bool = False
max_num_tokens: int = 8192
# If set, at most moe_max_num_tokens tokens will be sent to torch.ops.trtllm.fused_moe at the same time.
# If the number of tokens exceeds moe_max_num_tokens, the input tensors will be split into chunks and a for loop will be used.
moe_max_num_tokens: Optional[int] = None

attn_backend: str = 'TRTLLM'
# If true, will iterate over sampling_params of each request and use the
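The new knob is driven from this backend config. A hypothetical usage sketch follows; the import path is an assumption, and only the two fields shown in the hunk above come from the commit:

```python
# Assumed import path for the PyTorch backend config; adjust to your install.
from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig

config = PyTorchConfig(
    max_num_tokens=8192,
    # Cap each fused-MoE call at 4096 tokens; larger batches are processed
    # chunk by chunk, trading a little latency for lower peak GPU memory.
    moe_max_num_tokens=4096,
)
```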
@@ -257,6 +257,8 @@ class PyTorchModelEngine(ModelEngine):
mapping=self.mapping,
attn_backend=attn_backend,
load_format=pytorch_backend_config.load_format,
max_num_tokens=max_num_tokens,
moe_max_num_tokens=pytorch_backend_config.moe_max_num_tokens,
)
if self.pytorch_backend_config.enable_layerwise_nvtx_marker:
layerwise_nvtx_marker = LayerwiseNvtxMarker()
@@ -732,11 +734,13 @@ class PyTorchModelEngine(ModelEngine):
torch.cuda.empty_cache()

def _load_model(self, checkpoint_dir: str, load_format: LoadFormat,
**kwargs):
max_num_tokens: int, moe_max_num_tokens: int, **kwargs):
config = ModelConfig.from_pretrained(checkpoint_dir,
trust_remote_code=True,
**kwargs)
config.spec_config = self.spec_config
config.max_num_tokens = max_num_tokens
config.moe_max_num_tokens = moe_max_num_tokens

validate_and_set_kv_cache_quant(
config, self.pytorch_backend_config.kv_cache_dtype)
@@ -1,5 +1,6 @@
import os
from dataclasses import dataclass
from enum import Enum
from typing import List

import torch
@@ -10,6 +11,17 @@ from .pipeline_interface import PipelineInterface
is_torch_compiling_flag = False

aux_stream_name_list = ['Attention', 'MoeShared', 'MoeChunkingOverlap']
AuxStreamType = Enum(
'AuxStreamType',
aux_stream_name_list,
)
EventType = Enum(
'EventType',
['Main', *aux_stream_name_list],
start=0,
)


def set_torch_compiling(enable: bool):
global is_torch_compiling_flag
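These enums are what the model code above keys its auxiliary streams and synchronization events on: one stream per AuxStreamType created at the model level, one event per EventType held by each module. A small illustrative sketch (the dict names are not from the commit):

```python
from enum import Enum

import torch

aux_stream_name_list = ['Attention', 'MoeShared', 'MoeChunkingOverlap']
AuxStreamType = Enum('AuxStreamType', aux_stream_name_list)
EventType = Enum('EventType', ['Main', *aux_stream_name_list], start=0)

if torch.cuda.is_available():
    # One shared auxiliary stream per role, created once and passed down to
    # every decoder layer (compare DeepseekV3Model in the diff above).
    aux_stream_dict = {key: torch.cuda.Stream() for key in AuxStreamType}
    # Each module keeps its own events, e.g. FusedMoE uses Main + MoeChunkingOverlap.
    event_dict = {key: torch.cuda.Event()
                  for key in (EventType.Main, EventType.MoeChunkingOverlap)}
```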