Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-14 06:27:45 +08:00)
chore: some refactor on WideEP (#5727)
Signed-off-by: Dongxu Yang <78518666+dongxuy04@users.noreply.github.com>
parent 64fd64fcf2
commit dd3c736c7e
@@ -11,8 +11,7 @@ from tensorrt_llm.mapping import Mapping
 from ...distributed import allgather, reducescatter
 from ...expert_statistic import ExpertStatistic
 from ...model_config import ModelConfig
-from ...utils import (EventType, Fp4QuantizedTensor, disable_fp4_allgather,
-                      reswizzle_sf, swizzle_sf, unswizzle_sf)
+from ...utils import EventType, Fp4QuantizedTensor, swizzle_sf
 from .deep_ep_utils import buffer_pool, deep_ep_installed
 from .interface import MoE
 from .moe_load_balancer import get_moe_load_balancer
@@ -83,6 +82,12 @@ class WideEPMoE(MoE):
             weight_loading_mode=weight_loading_mode,
         )
 
+        assert self.use_dp, "Attention DP should be used with WideEP."
+        assert self.parallel_size > 1, "WideEP should only be enabled with parallel_size > 1"
+        # If True, the router weight will be multiplied on the input rather than at the end of FC2
+        self.apply_router_weight_on_input = apply_router_weight_on_input
+        assert self.apply_router_weight_on_input is False, "WideEP doesn't support apply_router_weight_on_input."
+
         self.layer_idx = layer_idx
 
         moe_load_balancer = get_moe_load_balancer()
@@ -128,9 +133,6 @@ class WideEPMoE(MoE):
         self.expert_size_per_partition = num_experts // self.ep_size
         self.num_slots = num_experts
 
-        if self.smart_router:
-            assert self.num_slots == self.num_experts, "Smart router should not have redundant slots"
-
         self.slot_start = self.ep_rank * self.expert_size_per_partition
         self.slot_end = self.slot_start + self.expert_size_per_partition
         self.initial_local_expert_ids = self.initial_global_assignments[
@@ -140,8 +142,7 @@ class WideEPMoE(MoE):
 
         max_num_tokens = model_config.max_num_tokens
         # The maximum number of tokens in MoE are multiplied by DP size when attention DP is enabled
-        if self.use_dp:
-            max_num_tokens *= model_config.mapping.world_size
+        max_num_tokens *= model_config.mapping.world_size
         self.moe_max_num_tokens = model_config.moe_max_num_tokens if model_config.moe_max_num_tokens is not None else max_num_tokens
         # The auxiliary CUDA stream and CUDA events are only used when MoE chunking is applied
         if self.moe_max_num_tokens < max_num_tokens:
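With attention DP now mandatory, the token budget in the hunk above is always the per-rank maximum multiplied by the DP world size, optionally capped by moe_max_num_tokens. A minimal sketch of that arithmetic; the helper below is hypothetical, only the names max_num_tokens, world_size and moe_max_num_tokens come from the diff:

```python
from typing import Optional


def moe_token_budget(max_num_tokens: int, world_size: int,
                     moe_max_num_tokens: Optional[int]) -> int:
    # Hypothetical helper, not part of the commit: with attention DP every rank
    # contributes its own tokens, so the MoE can see up to world_size times the
    # per-rank maximum in one iteration.
    gathered = max_num_tokens * world_size
    # An explicit moe_max_num_tokens caps the per-iteration workload; larger
    # batches are processed in chunks (see forward() further down in the diff).
    return moe_max_num_tokens if moe_max_num_tokens is not None else gathered


# Example: 4 DP ranks with 2048 tokens each -> 8192 gathered; a 4096 cap forces chunking.
assert moe_token_budget(2048, 4, None) == 8192
assert moe_token_budget(2048, 4, 4096) == 4096
```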
@@ -163,7 +164,6 @@ class WideEPMoE(MoE):
             16384 * self.num_slots // routing_method.get_experts_per_token(),
         )
         self.has_been_profiled = False
-        self.has_been_profiled_min_latency = False
 
         self.alltoall_method_type = self.select_alltoall_method_type(
             model_config.mapping, routing_method.experts_per_token, dtype,
@@ -173,14 +173,15 @@ class WideEPMoE(MoE):
             key="alltoall_method_type")
         self.use_postquant_alltoall = False
         if self.enable_alltoall:
-            assert self.use_dp and self.parallel_size > 1,\
-                "alltoall should only enabled with attention dp and parallel_size > 1"
             qm = self.quant_config.quant_mode
             self.use_postquant_alltoall = (os.environ.get(
                 "TRTLLM_MOE_POST_QUANT_ALLTOALLV", "1")
                                           == "1") and qm.has_nvfp4()
-            self.enable_alltoall_without_allgather = os.environ.get(
-                "TRTLLM_MOE_ENABLE_ALLTOALL_WITHOUT_ALLGATHER", "0") == "1"
+        # TODO: support alltoall without allgather for top_k % 4 != 0
+        self.enable_alltoall_without_allgather = (
+            os.environ.get("TRTLLM_MOE_ENABLE_ALLTOALL_WITHOUT_ALLGATHER",
+                           "1") == "1"
+        ) and self.alltoall_method_type == AlltoallMethodType.MNNVL and routing_method.experts_per_token % 4 == 0
         if self.alltoall_method_type == AlltoallMethodType.MNNVL:
             MnnvlMemory.initialize()
             self.alltoall_workspace = MnnvlMoe.get_moe_workspaces(
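The hunk above flips the default of TRTLLM_MOE_ENABLE_ALLTOALL_WITHOUT_ALLGATHER from "0" to "1" and ties it to the MNNVL path and to top_k being a multiple of 4. A standalone sketch of that gating; the env-var names and conditions come from the diff, while the helper function and the stand-in enum are illustrative assumptions:

```python
import os
from enum import Enum


class AlltoallMethodType(Enum):
    # Stand-in for the real enum used by WideEPMoE; member values are illustrative.
    NotEnabled = 0
    MNNVL = 1
    DeepEP = 2


def resolve_alltoall_flags(enable_alltoall: bool,
                           method_type: AlltoallMethodType,
                           experts_per_token: int, has_nvfp4: bool):
    # Hypothetical helper mirroring the branch above, not code from the commit.
    use_postquant_alltoall = False
    if enable_alltoall:
        # Post-quant alltoall: dispatch already-quantized NVFP4 activations;
        # default-on and only meaningful with NVFP4 quantization.
        use_postquant_alltoall = (os.environ.get(
            "TRTLLM_MOE_POST_QUANT_ALLTOALLV", "1") == "1") and has_nvfp4
    # Alltoall without a preceding allgather: default-on after this change,
    # but restricted to MNNVL and to top_k values divisible by 4 (see TODO above).
    enable_alltoall_without_allgather = (
        os.environ.get("TRTLLM_MOE_ENABLE_ALLTOALL_WITHOUT_ALLGATHER",
                       "1") == "1"
    ) and method_type == AlltoallMethodType.MNNVL and experts_per_token % 4 == 0
    return use_postquant_alltoall, enable_alltoall_without_allgather


# Example: alltoall over MNNVL, top_k=8, NVFP4 -> both paths enabled by default
# (assuming neither environment variable overrides the "1" defaults).
print(resolve_alltoall_flags(True, AlltoallMethodType.MNNVL, 8, True))
```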
@@ -204,9 +205,6 @@ class WideEPMoE(MoE):
                 f"Not available alltoall method type: {self.alltoall_method_type!r}"
             )
 
-        # If True, the router weight will be multiplied on the input rather than at the end of FC2
-        self.apply_router_weight_on_input = apply_router_weight_on_input
-
         self._weights_created = False
         if not model_config.skip_create_weights_in_init:
             self.create_weights()
@@ -218,12 +216,6 @@ class WideEPMoE(MoE):
     def _check_configs(self):
         assert self._weights_created
 
-        if self.enable_alltoall:
-            assert self.use_dp and self.parallel_size > 1,\
-                "alltoall should only enabled with attention dp and parallel_size > 1"
-
-        if self.apply_router_weight_on_input:
-            assert self.routing_method.top_k == 1, "Current walkaround only supports top-1 routing"
         if self.quant_config and self.quant_config.quant_mode.has_any_quant(
                 exclude_kv_cache=True):
             if not (self.quant_config.quant_mode.has_nvfp4()
@@ -323,24 +315,20 @@ class WideEPMoE(MoE):
             use_dp_padding: Optional[bool] = None,
     ):
         outputs = inputs
-        if self.parallel_size > 1 and not self.enable_alltoall:
-            if self.use_dp:
-                if self.enable_dummy_allreduce:
-                    self.dummy_allreduce()
-                outputs = reducescatter(
-                    inputs,
-                    self.mapping,
-                    dim=0,
-                    sizes=None if use_dp_padding else all_rank_num_tokens)
-            elif self.reduce_results:
-                outputs = self.all_reduce(inputs)
+        if not self.enable_alltoall:
+            if self.enable_dummy_allreduce:
+                self.dummy_allreduce()
+            outputs = reducescatter(
+                inputs,
+                self.mapping,
+                dim=0,
+                sizes=None if use_dp_padding else all_rank_num_tokens)
         return outputs
 
     def forward_chunk(
             self,
             x: Union[torch.Tensor, Fp4QuantizedTensor],
             router_logits: torch.Tensor,
-            cutlass_min_latency_mode: bool = False,
             output_dtype: Optional[torch.dtype] = None,
             all_rank_num_tokens: Optional[List[int]] = None,
             all_rank_max_num_tokens: Optional[int] = None,
@@ -373,17 +361,12 @@ class WideEPMoE(MoE):
         assert token_final_scales.dtype == torch.float32
         assert token_selected_experts.dtype == torch.int32
 
-        if self.apply_router_weight_on_input:
-            assert self.routing_method.top_k == 1, "Current workaround only supports top-1 routing"
-            assert x.dtype != torch.float8_e4m3fn, "Current workaround for apply_router_weight_on_input does not support fp8 input"
-            x = x * token_final_scales.to(x.dtype)
-            # TODO: remove this once we have correct fusedmoe kernel ready
-            token_final_scales = None
-
         if self.layer_load_balancer and not self.layer_load_balancer.is_static_routing(
         ) and is_first_call:
             self.layer_load_balancer.maybe_cudagraph_done_wait()
 
+        use_allgather = not self.enable_alltoall
+
         loadbalancer_local_statistic_info = None
         gathered_loadbalancer_local_statistic_info = None
         token_selected_experts_for_statistic = None
@@ -464,29 +447,37 @@ class WideEPMoE(MoE):
                 )
 
         x_sf = None
+        x_is_sf_swizzled = x.is_sf_swizzled if isinstance(
+            x, Fp4QuantizedTensor) else False
         x_row = x.shape[0]
         x_col = x.shape[1]
-        sf_swizzle = True
         if self.has_any_quant:
             if self.has_fp8_qdq:
                 x, _ = torch.ops.tensorrt_llm.static_quantize_e4m3_per_tensor(
                     x, self.fc31_input_dequant)
             elif self.has_nvfp4:
-                if not disable_fp4_allgather() or self.use_postquant_alltoall:
+                if use_allgather or self.use_postquant_alltoall:
                     if isinstance(x, Fp4QuantizedTensor):
+                        if use_allgather:
+                            assert not x.is_sf_swizzled, "Fp4QuantizedTensor should not be swizzled before allgather"
                        x, x_sf = x.fp4_tensor, x.scaling_factor
                         x_row = x.shape[0]
                         # note: we use uint8 to store 2 fp4 values
                         x_col = x.shape[1] * 2
                     else:
-                        sf_swizzle = not self.use_postquant_alltoall
+                        # for both postquant alltoall and allgather, we need non swizzle layout
+                        needed_sf_swizzle = False
                         x_row = x.shape[0]
                         x_col = x.shape[1]
                         x, x_sf = torch.ops.trtllm.fp4_quantize(
-                            x, self.fc31_input_scale, self.scaling_vector_size,
-                            False, sf_swizzle)
+                            x,
+                            self.fc31_input_scale,
+                            self.scaling_vector_size,
+                            sfUseUE8M0=False,
+                            swizzedLayout=needed_sf_swizzle)
                         if self.use_postquant_alltoall:
                             x_sf = x_sf.view((x_row, -1))
+                        x_is_sf_swizzled = needed_sf_swizzle
 
         elif self.has_deepseek_fp8_block_scales:
             use_deepseek_fp8_block_scale = True
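The bookkeeping above (`x_col = x.shape[1] * 2`) reflects that two FP4 values are packed into each uint8 element of a Fp4QuantizedTensor, so the logical width is twice the storage width. A tiny self-contained illustration; the shapes are made up:

```python
# Illustrative only: the storage tensor is uint8 with half as many columns as
# there are fp4 elements per token (shapes below are invented, not from the diff).
packed_shape = (128, 3584)       # (tokens, hidden // 2) uint8 storage
x_row = packed_shape[0]          # 128 tokens
x_col = packed_shape[1] * 2      # 7168 logical fp4 values per token
assert (x_row, x_col) == (128, 7168)
```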
@@ -498,8 +489,8 @@ class WideEPMoE(MoE):
                     f"unsupported quantization mode: {self.quant_config.quant_mode}"
                 )
 
-        if self.use_dp and self.parallel_size > 1 and not disable_fp4_allgather(
-        ) and not self.enable_alltoall:
+        if use_allgather:
+            # using allgather case.
             if self.enable_dummy_allreduce:
                 self.dummy_allreduce()
             x, x_sf, token_selected_slots, token_final_scales, gathered_token_selected_experts_for_statistic = allgather(
@@ -513,10 +504,11 @@ class WideEPMoE(MoE):
                 self.mapping,
                 dim=0,
                 sizes=None if use_dp_padding else all_rank_num_tokens)
             x_row = x.shape[0]
             # Fp4 gemm has extra scaling factor
             if x_sf is not None:
-                x_sf = reswizzle_sf(x_sf, x_row, x_col,
-                                    self.scaling_vector_size)
+                assert not x_is_sf_swizzled, "Fp4QuantizedTensor should not be swizzled before allgather"
+                x_sf = swizzle_sf(x_sf, x_row, x_col, self.scaling_vector_size)
+
         if self.layer_load_balancer and not self.layer_load_balancer.is_static_routing(
         ):
@@ -532,42 +524,22 @@ class WideEPMoE(MoE):
                 is_first_stage=is_first_call,
                 is_last_stage=is_last_call)
 
-        if self.smart_router and not cutlass_min_latency_mode:
-            ep_size = self.cluster_size
-            ep_rank = self.cluster_rank
-            expert_start = ep_rank * self.num_experts // ep_size
-            expert_end = min(self.num_experts,
-                             (ep_rank + 1) * self.num_experts // ep_size)
-            w3_w1_weight = self.w3_w1_weight.narrow(0, expert_start,
-                                                    expert_end - expert_start)
-            w2_weight = self.w2_weight.narrow(0, expert_start,
-                                              expert_end - expert_start)
-            cluster_size = self.ep_size
-            cluster_rank = self.ep_rank
-            quant_scales = self.get_quant_scales(expert_start, expert_end)
-        else:
-            ep_size = self.ep_size
-            ep_rank = self.ep_rank
-            w3_w1_weight = self.w3_w1_weight
-            w2_weight = self.w2_weight
-            cluster_size = self.cluster_size
-            cluster_rank = self.cluster_rank
-            quant_scales = self.quant_scales
+        ep_size = self.ep_size
+        ep_rank = self.ep_rank
+        w3_w1_weight = self.w3_w1_weight
+        w2_weight = self.w2_weight
+        cluster_size = self.cluster_size
+        cluster_rank = self.cluster_rank
+        quant_scales = self.quant_scales
 
         if self.use_postquant_alltoall:
+            if x_sf is not None and self.has_nvfp4:
+                assert not x_is_sf_swizzled, "Fp4 scaling factor should not be swizzled before Alltoall"
             if self.alltoall_method_type == AlltoallMethodType.MNNVL:
                 x, x_sf = self.alltoall_postquant_dispatch(
-                    x,
-                    x_sf,
-                    x_row,
-                    x_col,
-                    alltoall_info,
-                    is_sf_swizzle=sf_swizzle)
+                    x, x_sf, alltoall_info)
             elif self.alltoall_method_type == AlltoallMethodType.DeepEP:
                 if x_sf is not None:
-                    if self.has_nvfp4 and sf_swizzle:
-                        x_sf = unswizzle_sf(x_sf, x_row, x_col,
-                                            self.scaling_vector_size)
                     # Adapter between `x_sf` and DeepEP
                     # TODO: remove the adapter by adding dtype support to DeepEP
                     x_sf_dtype = x_sf.dtype
@@ -631,7 +603,7 @@ class WideEPMoE(MoE):
             enable_alltoall=self.enable_alltoall,
             use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale,
             use_w4a8_group_scaling=use_w4a8_group_scaling,
-            min_latency_mode=cutlass_min_latency_mode,
+            min_latency_mode=False,
             tune_max_num_tokens=self.tune_max_num_tokens,
         )
 
@@ -639,14 +611,9 @@ class WideEPMoE(MoE):
         ) and is_last_call:
             self.layer_load_balancer.set_cpu_stage()
 
-        if cutlass_min_latency_mode:
-            assert not self.reduce_results
-            assert not self.enable_alltoall
-        else:
-            # Custom op requires all inputs are in the same type.
-            # Only in cutlass_min_latency_mode, the output is a list of tensors.
-            # Otherwise, the output should be unpacked as a single tensor.
-            final_hidden_states = final_hidden_states[0]
+        # Only in cutlass_min_latency_mode, the output is a list of tensors.
+        # Otherwise, the output should be unpacked as a single tensor.
+        final_hidden_states = final_hidden_states[0]
 
         if self.enable_alltoall:
             if self.alltoall_method_type == AlltoallMethodType.MNNVL:
@@ -702,23 +669,13 @@ class WideEPMoE(MoE):
             all_rank_max_num_tokens: Optional[int] = None,
             use_dp_padding: Optional[bool] = None,
     ) -> torch.Tensor:
-        if self.use_dp:
-            assert all_rank_num_tokens is not None
-            assert use_dp_padding is not None
-            num_rows = sum(all_rank_num_tokens)
-        else:
-            num_rows = x.shape[0]
+        assert all_rank_num_tokens is not None
+        assert use_dp_padding is not None
+        num_rows = sum(all_rank_num_tokens)
 
         # in case of num_rows is larger than max_chunk_size, we need to split the input into multiple chunks
         num_chunks = (num_rows + self.moe_max_num_tokens -
                       1) // self.moe_max_num_tokens
-        # TODO: remove cutlass_min_latency_mode since it is not used anymore
-        cutlass_min_latency_mode = not do_finalize
-
-        if cutlass_min_latency_mode:
-            assert num_chunks == 1 and (
-                not self.reduce_results
-            ), "cutlass_min_latency_mode must be used with a single chunk and reduce_results must be False"
 
         if use_dp_padding:
             all_rank_num_tokens_padded = [all_rank_max_num_tokens
@@ -731,7 +688,6 @@ class WideEPMoE(MoE):
             outputs = self.forward_chunk(
                 x,
                 router_logits,
-                cutlass_min_latency_mode,
                 output_dtype,
                 all_rank_num_tokens=all_rank_num_tokens_padded,
                 all_rank_max_num_tokens=all_rank_max_num_tokens,
@@ -750,29 +706,24 @@ class WideEPMoE(MoE):
                     split_num_chunks - val_mod)
                 return split_chunk_size_list
 
-            if self.use_dp:
-                all_rank_chunk_size_list = [
-                    split_chunk(val, num_chunks)
-                    for val in all_rank_num_tokens_padded
-                ]
-                all_rank_num_tokens_list = [[
-                    val[idx_chunk] for val in all_rank_chunk_size_list
-                ] for idx_chunk in range(num_chunks)]
-                all_rank_max_num_tokens_list = split_chunk(all_rank_max_num_tokens,
-                                                           num_chunks)
-                chunk_size_list = all_rank_chunk_size_list[self.rank]
-                if self.enable_alltoall:
-                    all_rank_num_tokens_list = [[
-                        1 if val == 0 else val for val in val_list
-                    ] for val_list in all_rank_num_tokens_list]
-                    all_rank_max_num_tokens_list = [
-                        1 if val == 0 else val
-                        for val in all_rank_max_num_tokens_list
-                    ]
-            else:
-                all_rank_num_tokens_list = [None] * num_chunks
-                all_rank_max_num_tokens_list = [None] * num_chunks
-                chunk_size_list = split_chunk(x.shape[0], num_chunks)
+            all_rank_chunk_size_list = [
+                split_chunk(val, num_chunks)
+                for val in all_rank_num_tokens_padded
+            ]
+            all_rank_num_tokens_list = [[
+                val[idx_chunk] for val in all_rank_chunk_size_list
+            ] for idx_chunk in range(num_chunks)]
+            all_rank_max_num_tokens_list = split_chunk(
+                all_rank_max_num_tokens, num_chunks)
+            chunk_size_list = all_rank_chunk_size_list[self.rank]
+            if self.enable_alltoall:
+                all_rank_num_tokens_list = [[
+                    1 if val == 0 else val for val in val_list
+                ] for val_list in all_rank_num_tokens_list]
+                all_rank_max_num_tokens_list = [
+                    1 if val == 0 else val
+                    for val in all_rank_max_num_tokens_list
+                ]
 
             x_list = x.split(chunk_size_list)
             router_logits_list = router_logits.split(chunk_size_list)
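The chunking above splits every rank's (padded) token count into num_chunks pieces and, when alltoall is enabled, bumps any zero-sized entry to 1. A rough standalone re-implementation of the splitting helper, assuming it only needs to produce near-equal parts that sum to the input (the real split_chunk is defined inside forward() and its exact remainder handling may differ):

```python
from typing import List


def split_chunk(val: int, num_chunks: int) -> List[int]:
    # Assumed behaviour: near-equal parts, earlier chunks absorb the remainder.
    base, rem = divmod(val, num_chunks)
    return [base + 1 if i < rem else base for i in range(num_chunks)]


# Example: 3 DP ranks report [5000, 4097, 0] padded tokens and the per-iteration
# budget forces num_chunks=2; each rank splits its own count the same way.
all_rank_chunk_size_list = [split_chunk(v, 2) for v in [5000, 4097, 0]]
assert all_rank_chunk_size_list == [[2500, 2500], [2049, 2048], [0, 0]]
# Per-chunk view, as built above: chunk 0 -> [2500, 2049, 0], chunk 1 -> [2500, 2048, 0].
```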
@@ -795,10 +746,9 @@ class WideEPMoE(MoE):
                                 x,
                                 router_logits,
                                 all_rank_num_tokens=all_rank_num_tokens_list[
-                                    idx_chunk] if self.use_dp else None,
+                                    idx_chunk],
                                 all_rank_max_num_tokens=
-                                all_rank_max_num_tokens_list[idx_chunk]
-                                if self.use_dp else None,
+                                all_rank_max_num_tokens_list[idx_chunk],
                                 use_dp_padding=use_dp_padding,
                                 repeating_info=(is_first_call, is_last_call))
                         if idx_chunk > 0:
@@ -812,9 +762,9 @@ class WideEPMoE(MoE):
                             x,
                             router_logits,
                             all_rank_num_tokens=all_rank_num_tokens_list[
-                                idx_chunk] if self.use_dp else None,
+                                idx_chunk],
                             all_rank_max_num_tokens=all_rank_max_num_tokens_list[
-                                idx_chunk] if self.use_dp else None,
+                                idx_chunk],
                             use_dp_padding=use_dp_padding,
                             repeating_info=(is_first_call, is_last_call))
                         with torch.cuda.stream(self.aux_stream):
@@ -827,10 +777,9 @@ class WideEPMoE(MoE):
                     outputs = self.forward_chunk(
                         x,
                         router_logits,
-                        all_rank_num_tokens=all_rank_num_tokens_list[idx_chunk]
-                        if self.use_dp else None,
+                        all_rank_num_tokens=all_rank_num_tokens_list[idx_chunk],
                         all_rank_max_num_tokens=all_rank_max_num_tokens_list[
-                            idx_chunk] if self.use_dp else None,
+                            idx_chunk],
                         repeating_info=(is_first_call, is_last_call))
 
                 outputs_list.append(outputs)
@@ -850,9 +799,8 @@ class WideEPMoE(MoE):
                 self.event_dict[EventType.MoeChunkingOverlap].record()
                 self.event_dict[EventType.MoeChunkingOverlap].wait()
             outputs = torch.cat(outputs_list)
-        if self.use_dp:
-            rank = self.mapping.tp_rank
-            outputs = outputs[:all_rank_num_tokens[rank]]
+        rank = self.mapping.tp_rank
+        outputs = outputs[:all_rank_num_tokens[rank]]
         self.repeat_idx = 0 if self.repeat_idx == self.repeat_count - 1 else self.repeat_idx + 1
         return outputs
 
@@ -862,8 +810,8 @@ class WideEPMoE(MoE):
             token_final_scales: torch.Tensor,
             local_statistic_tensor: Optional[torch.Tensor]):
         top_k = self.routing_method.experts_per_token
-        # TODO: support alltoall without allgather for top_k % 4 != 0
-        if self.enable_alltoall_without_allgather and top_k % 4 == 0:
+
+        if self.enable_alltoall_without_allgather:
             alltoall_info, token_selected_slots, token_final_scales, gathered_local_statistic_tensor = MnnvlMoe.mnnvl_moe_alltoallv_prepare_without_allgather(
                 token_selected_slots, token_final_scales,
                 local_statistic_tensor, self.alltoall_prepare_workspace,
@@ -913,22 +861,13 @@ class WideEPMoE(MoE):
 
         return x, token_selected_slots, token_final_scales, gathered_local_statistic_tensor, alltoall_info
 
-    def alltoall_postquant_dispatch(self,
-                                    x: torch.Tensor,
-                                    x_sf: torch.Tensor,
-                                    x_row: int,
-                                    x_col: int,
-                                    alltoall_info: MoEAlltoallInfo,
-                                    is_sf_swizzle: bool = True):
+    def alltoall_postquant_dispatch(self, x: torch.Tensor, x_sf: torch.Tensor,
+                                    alltoall_info: MoEAlltoallInfo):
         x = MnnvlMoe.mnnvl_moe_alltoallv(x, alltoall_info,
                                          self.alltoall_workspace, self.ep_rank,
                                          self.ep_size)
 
         if x_sf is not None:
-            if self.has_nvfp4 and is_sf_swizzle:
-                x_sf = unswizzle_sf(x_sf, x_row, x_col,
-                                    self.scaling_vector_size)
-
             x_sf = MnnvlMoe.mnnvl_moe_alltoallv(x_sf, alltoall_info,
                                                 self.alltoall_workspace,
                                                 self.ep_rank, self.ep_size)