[Perf] Optimize MoE permute by pre-allocating buffers, 9~14% kernel performance improvement (#43014)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
Wentao Ye
2026-05-28 09:18:26 -04:00
committed by GitHub
parent 02606b0b09
commit 64e1218673
8 changed files with 480 additions and 92 deletions
@@ -10,6 +10,7 @@ from transformers import AutoConfig
from vllm.model_executor.layers.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
MoEPermuteScratch,
moe_permute,
moe_unpermute,
)
@@ -54,6 +55,15 @@ def benchmark_permute(
topk_weights, topk_ids, token_expert_indices = fused_topk(
qhidden_states, input_gating, topk, False
)
scratch = MoEPermuteScratch(
max_num_tokens=num_tokens,
topk=topk,
num_experts=num_experts,
num_local_experts=num_experts,
device=qhidden_states.device,
hidden_size=hidden_size,
hidden_dtype=qhidden_states.dtype,
)
def prepare(i: int):
input_gating.copy_(gating_output[i])
@@ -65,6 +75,7 @@ def benchmark_permute(
topk_ids=topk_ids,
n_expert=num_experts,
expert_map=None,
scratch=scratch,
)
# JIT compilation & warmup
@@ -123,6 +134,15 @@ def benchmark_unpermute(
topk_weights, topk_ids, token_expert_indices = fused_topk(
qhidden_states, input_gating, topk, False
)
scratch = MoEPermuteScratch(
max_num_tokens=num_tokens,
topk=topk,
num_experts=num_experts,
num_local_experts=num_experts,
device=qhidden_states.device,
hidden_size=hidden_size,
hidden_dtype=qhidden_states.dtype,
)
def prepare():
(
@@ -137,6 +157,7 @@ def benchmark_unpermute(
topk_ids=topk_ids,
n_expert=num_experts,
expert_map=None,
scratch=scratch,
)
# convert to fp16/bf16 as gemm output
return (
+3
View File
@@ -62,6 +62,9 @@ std::tuple<torch::Tensor, torch::Tensor> grouped_topk(
bool moe_permute_unpermute_supported();
// Size of the scratch workspace needed by the MoE permute key/value sort for
// the given number of expanded rows and experts. Callers can use it to
// pre-allocate the workspace once instead of on every permute call.
int64_t moe_permute_sort_workspace_size(int64_t num_expanded_rows,
int64_t num_experts);
void shuffle_rows(const torch::Tensor& input_tensor,
const torch::Tensor& dst2src_map,
torch::Tensor& output_tensor);
+141 -58
View File
@@ -8,6 +8,108 @@
// moe_permute kernels require at least CUDA 12.0
#if defined(CUDA_VERSION) && (CUDA_VERSION >= 12000)
namespace {
// Returns a tensor of shape `expected_sizes` with the given dtype and device.
//
// When `maybe_tensor` is empty a fresh tensor is allocated. Otherwise the
// caller-provided buffer is validated (device, dtype, contiguity, capacity)
// and a view over its leading elements, reshaped to `expected_sizes`, is
// returned; the buffer may be larger than required. `name` is only used to
// prefix the error messages.
torch::Tensor maybe_allocate_tensor(
    const std::optional<torch::Tensor>& maybe_tensor,
    at::IntArrayRef expected_sizes, torch::ScalarType dtype, c10::Device device,
    char const* name) {
  // No reusable buffer supplied: fall back to a per-call allocation.
  if (!maybe_tensor.has_value()) {
    return torch::empty(expected_sizes, torch::dtype(dtype).device(device));
  }

  const torch::Tensor& buffer = *maybe_tensor;
  const auto required_numel = c10::multiply_integers(expected_sizes);

  TORCH_CHECK(buffer.device() == device, name, " must be on the same device");
  TORCH_CHECK(buffer.scalar_type() == dtype, name, " has incorrect dtype");
  TORCH_CHECK(buffer.is_contiguous(), name, " must be contiguous");
  TORCH_CHECK(buffer.numel() >= required_numel, name,
              " is too small for the requested shape");

  // Flatten, keep only the leading `required_numel` elements, then reshape.
  return buffer.view({buffer.numel()})
      .narrow(0, 0, required_numel)
      .view(expected_sizes);
}
} // namespace
// Reports how large the sort workspace must be for `num_expanded_rows`
// key/value pairs spread over `n_expert` experts, as computed by
// CubKeyValueSorter. The value is used as the element count of an int8
// workspace tensor.
int64_t moe_permute_sort_workspace_size(int64_t num_expanded_rows,
                                        int64_t n_expert) {
  const auto workspace_size =
      CubKeyValueSorter::getWorkspaceSize(num_expanded_rows, n_expert);
  return static_cast<int64_t>(workspace_size);
}
// Shared implementation behind moe_permute and moe_permute_with_scratch.
//
// Sorts the expanded (token, expert) rows by expert id, fills
// `expert_first_token_offset`, and expands `input` into `permuted_input`,
// also writing `inv_permuted_idx` and `permuted_idx`.
//
// The four `maybe_*` arguments are optional reusable scratch buffers. Each
// one is validated and reused when provided; when it is std::nullopt a
// temporary of the required shape is allocated for this call instead.
void moe_permute_impl(
const torch::Tensor& input, // [n_token, hidden]
const torch::Tensor& topk_ids, // [n_token, topk]
const torch::Tensor& token_expert_indices, // [n_token, topk]
const std::optional<torch::Tensor>& expert_map, // [n_expert]
int64_t n_expert, int64_t n_local_expert, int64_t topk,
torch::Tensor& permuted_input, // [permuted_size, hidden]
torch::Tensor& expert_first_token_offset, // [n_local_expert + 1]
torch::Tensor& inv_permuted_idx, // [n_token, topk]
torch::Tensor& permuted_idx, // [permute_size]
const std::optional<torch::Tensor>& maybe_sort_workspace,
const std::optional<torch::Tensor>& maybe_permuted_experts_id,
const std::optional<torch::Tensor>& maybe_sorted_row_idx,
const std::optional<torch::Tensor>& maybe_topk_ids_for_sort) {
// Validate dtypes and shapes of the index/offset tensors up front, since the
// kernels below access them through raw typed pointers.
TORCH_CHECK(expert_first_token_offset.scalar_type() == at::ScalarType::Long,
"expert_first_token_offset must be int64");
TORCH_CHECK(topk_ids.scalar_type() == at::ScalarType::Int,
"topk_ids must be int32");
TORCH_CHECK(token_expert_indices.scalar_type() == at::ScalarType::Int,
"token_expert_indices must be int32");
TORCH_CHECK(inv_permuted_idx.scalar_type() == at::ScalarType::Int,
"inv_permuted_idx must be int32");
TORCH_CHECK(expert_first_token_offset.size(0) == n_local_expert + 1,
"expert_first_token_offset shape != n_local_expert+1");
TORCH_CHECK(inv_permuted_idx.sizes() == token_expert_indices.sizes(),
"token_expert_indices shape must be same as inv_permuted_idx");
// Temporaries are created on the input's device.
auto device = input.device();
auto n_token = input.sizes()[0];
auto n_hidden = input.sizes()[1];
auto expanded_rows = n_token * topk;
auto stream = at::cuda::getCurrentCUDAStream().stream();
// Resolve each scratch buffer: reuse the caller's tensor when given,
// otherwise allocate one for this call.
auto sorter_size = moe_permute_sort_workspace_size(expanded_rows, n_expert);
auto sort_workspace =
maybe_allocate_tensor(maybe_sort_workspace, {sorter_size}, torch::kInt8,
device, "sort_workspace");
auto permuted_experts_id =
maybe_allocate_tensor(maybe_permuted_experts_id, topk_ids.sizes(),
at::ScalarType::Int, device, "permuted_experts_id");
auto sorted_row_idx =
maybe_allocate_tensor(maybe_sorted_row_idx, inv_permuted_idx.sizes(),
at::ScalarType::Int, device, "sorted_row_idx");
CubKeyValueSorter sorter{};
int64_t* valid_num_ptr = nullptr;
// Without an expert map the caller's topk_ids are sorted directly.
torch::Tensor topk_ids_for_sort = topk_ids;
// pre-process kernel for expert-parallelism:
// no local expert id plus "n_expert" offset for priority to local expert
// map local expert id [n, .., n+n_local_expert-1] to [0, n_local_expert -1]
// For example, 4 expert with ep_size=2. ep_rank=1 owns global expert id
// [2,3] with expert_map[-1, -1, 0, 1], preprocess_topk_id process topk_ids
// and map global expert id [2, 3] to local_expert id [0, 1] and map global
// expert id [0, 1] ( not in ep rank=1) to [4, 5] by plus n_expert. This map
// operation is to make local expert high priority in following sort topk_ids
// and scan local expert_first_token_offset for each ep rank for next group
// gemm.
if (expert_map.has_value()) {
const int* expert_map_ptr = get_ptr<int>(expert_map.value());
// Points at the last entry of expert_first_token_offset (index
// n_local_expert); passed to the expand kernel below.
valid_num_ptr =
get_ptr<int64_t>(expert_first_token_offset) + n_local_expert;
// The ids are rewritten in place by the preprocess kernel, so work on a
// copy (in the scratch buffer when provided) and leave topk_ids untouched.
topk_ids_for_sort =
maybe_allocate_tensor(maybe_topk_ids_for_sort, topk_ids.sizes(),
at::ScalarType::Int, device, "topk_ids_for_sort");
topk_ids_for_sort.copy_(topk_ids);
preprocessTopkIdLauncher(get_ptr<int>(topk_ids_for_sort), n_token * topk,
expert_map_ptr, n_expert, stream);
}
// expert sort topk expert id and scan expert id get expert_first_token_offset
sortAndScanExpert(
get_ptr<const int>(topk_ids_for_sort), get_ptr<int>(token_expert_indices),
get_ptr<int>(permuted_experts_id), get_ptr<int>(sorted_row_idx),
get_ptr<int64_t>(expert_first_token_offset), n_token, n_expert,
n_local_expert, topk, sorter, get_ptr<int>(sort_workspace), stream);
// dispatch expandInputRowsKernelLauncher on the input dtype
MOE_DISPATCH(input.scalar_type(), [&] {
expandInputRowsKernelLauncher<scalar_t>(
get_ptr<scalar_t>(input), get_ptr<scalar_t>(permuted_input),
get_ptr<int>(sorted_row_idx), get_ptr<int>(inv_permuted_idx),
get_ptr<int>(permuted_idx), get_ptr<int64_t>(expert_first_token_offset),
n_token, valid_num_ptr, n_hidden, topk, n_local_expert, stream);
});
}
void moe_permute(
const torch::Tensor& input, // [n_token, hidden]
const torch::Tensor& topk_ids, // [n_token, topk]
@@ -18,65 +120,26 @@ void moe_permute(
torch::Tensor& expert_first_token_offset, // [n_local_expert + 1]
torch::Tensor& inv_permuted_idx, // [n_token, topk]
torch::Tensor& permuted_idx) { // [permute_size]
TORCH_CHECK(expert_first_token_offset.scalar_type() == at::ScalarType::Long,
"expert_first_token_offset must be int64");
TORCH_CHECK(topk_ids.scalar_type() == at::ScalarType::Int,
"topk_ids must be int32");
TORCH_CHECK(token_expert_indices.scalar_type() == at::ScalarType::Int,
"token_expert_indices must be int32");
TORCH_CHECK(inv_permuted_idx.scalar_type() == at::ScalarType::Int,
"inv_permuted_idx must be int32");
TORCH_CHECK(expert_first_token_offset.size(0) == n_local_expert + 1,
"expert_first_token_offset shape != n_local_expert+1")
TORCH_CHECK(inv_permuted_idx.sizes() == token_expert_indices.sizes(),
"token_expert_indices shape must be same as inv_permuted_idx");
auto n_token = input.sizes()[0];
auto n_hidden = input.sizes()[1];
auto stream = at::cuda::getCurrentCUDAStream().stream();
const long sorter_size =
CubKeyValueSorter::getWorkspaceSize(n_token * topk, n_expert);
auto sort_workspace = torch::empty(
{sorter_size},
torch::dtype(torch::kInt8).device(torch::kCUDA).requires_grad(false));
torch::Tensor topk_ids_for_sort = topk_ids;
auto permuted_experts_id = torch::empty_like(topk_ids);
auto sorted_row_idx = torch::empty_like(inv_permuted_idx);
moe_permute_impl(input, topk_ids, token_expert_indices, expert_map, n_expert,
n_local_expert, topk, permuted_input,
expert_first_token_offset, inv_permuted_idx, permuted_idx,
std::nullopt, std::nullopt, std::nullopt, std::nullopt);
}
CubKeyValueSorter sorter{};
int64_t* valid_num_ptr = nullptr;
// pre-process kernel for expert-parallelism:
// no local expert id plus "n_expert" offset for priority to local expert
// map local expert id [n, .., n+n_local_expert-1] to [0, n_local_expert -1]
// For example, 4 expert with ep_size=2. ep_rank=1 owns global expert id
// [2,3] with expert_map[-1, -1, 0, 1], preprocess_topk_id process topk_ids
// and map global expert id [2, 3] to local_expert id [0, 1] and map global
// expert id [0, 1] ( not in ep rank=1) to [4, 5] by plus n_expert. This map
// operation is to make local expert high priority in following sort topk_ids
// and scan local expert_first_token_offset for each ep rank for next group
// gemm.
if (expert_map.has_value()) {
const int* expert_map_ptr = get_ptr<int>(expert_map.value());
valid_num_ptr =
get_ptr<int64_t>(expert_first_token_offset) + n_local_expert;
topk_ids_for_sort = topk_ids.clone();
preprocessTopkIdLauncher(get_ptr<int>(topk_ids_for_sort), n_token * topk,
expert_map_ptr, n_expert, stream);
}
// expert sort topk expert id and scan expert id get expert_first_token_offset
sortAndScanExpert(
get_ptr<const int>(topk_ids_for_sort), get_ptr<int>(token_expert_indices),
get_ptr<int>(permuted_experts_id), get_ptr<int>(sorted_row_idx),
get_ptr<int64_t>(expert_first_token_offset), n_token, n_expert,
n_local_expert, topk, sorter, get_ptr<int>(sort_workspace), stream);
// dispatch expandInputRowsKernelLauncher
MOE_DISPATCH(input.scalar_type(), [&] {
expandInputRowsKernelLauncher<scalar_t>(
get_ptr<scalar_t>(input), get_ptr<scalar_t>(permuted_input),
get_ptr<int>(sorted_row_idx), get_ptr<int>(inv_permuted_idx),
get_ptr<int>(permuted_idx), get_ptr<int64_t>(expert_first_token_offset),
n_token, valid_num_ptr, n_hidden, topk, n_local_expert, stream);
});
// Variant of moe_permute that runs on caller-owned scratch buffers
// (`sort_workspace`, `permuted_experts_id`, `sorted_row_idx`,
// `topk_ids_for_sort`) instead of allocating temporaries on every call.
// The buffers are validated by moe_permute_impl and may be larger than the
// current problem requires.
void moe_permute_with_scratch(
    const torch::Tensor& input, const torch::Tensor& topk_ids,
    const torch::Tensor& token_expert_indices,
    const std::optional<torch::Tensor>& expert_map, int64_t n_expert,
    int64_t n_local_expert, int64_t topk, torch::Tensor& permuted_input,
    torch::Tensor& expert_first_token_offset, torch::Tensor& inv_permuted_idx,
    torch::Tensor& permuted_idx, torch::Tensor& sort_workspace,
    torch::Tensor& permuted_experts_id, torch::Tensor& sorted_row_idx,
    torch::Tensor& topk_ids_for_sort) {
  // The shared implementation takes optional buffers; wrap the provided ones.
  const std::optional<torch::Tensor> sort_workspace_buf{sort_workspace};
  const std::optional<torch::Tensor> permuted_experts_id_buf{
      permuted_experts_id};
  const std::optional<torch::Tensor> sorted_row_idx_buf{sorted_row_idx};
  const std::optional<torch::Tensor> topk_ids_for_sort_buf{topk_ids_for_sort};

  moe_permute_impl(input, topk_ids, token_expert_indices, expert_map, n_expert,
                   n_local_expert, topk, permuted_input,
                   expert_first_token_offset, inv_permuted_idx, permuted_idx,
                   sort_workspace_buf, permuted_experts_id_buf,
                   sorted_row_idx_buf, topk_ids_for_sort_buf);
}
void moe_unpermute(
@@ -169,6 +232,12 @@ void shuffle_rows(const torch::Tensor& input_tensor,
#else
// Fallback for builds where the CUDA >= 12.0 kernels are not compiled in:
// the check below always fails, so this function never returns a value.
int64_t moe_permute_sort_workspace_size(int64_t num_expanded_rows,
int64_t n_expert) {
TORCH_CHECK(
false, "moe_permute_sort_workspace_size is not supported on CUDA < 12.0");
}
void moe_permute(const torch::Tensor& input, const torch::Tensor& topk_ids,
const torch::Tensor& token_expert_indices,
const std::optional<torch::Tensor>& expert_map,
@@ -179,6 +248,19 @@ void moe_permute(const torch::Tensor& input, const torch::Tensor& topk_ids,
TORCH_CHECK(false, "moe_permute is not supported on CUDA < 12.0");
}
// Fallback for builds where the CUDA >= 12.0 kernels are not compiled in.
// Keeps the same signature as the real implementation so the op can still be
// registered, but unconditionally fails when called.
void moe_permute_with_scratch(
const torch::Tensor& input, const torch::Tensor& topk_ids,
const torch::Tensor& token_expert_indices,
const std::optional<torch::Tensor>& expert_map, int64_t n_expert,
int64_t n_local_expert, int64_t topk, torch::Tensor& permuted_input,
torch::Tensor& expert_first_token_offset, torch::Tensor& inv_permuted_idx,
torch::Tensor& permuted_idx, torch::Tensor& sort_workspace,
torch::Tensor& permuted_experts_id, torch::Tensor& sorted_row_idx,
torch::Tensor& topk_ids_for_sort) {
TORCH_CHECK(false,
"moe_permute_with_scratch is not supported on CUDA < 12.0");
}
void moe_unpermute(
const torch::Tensor& permuted_hidden_states,
const torch::Tensor& topk_weights, const torch::Tensor& inv_permuted_idx,
@@ -199,5 +281,6 @@ bool moe_permute_unpermute_supported() {
// Bind the CUDA implementations of the MoE permute/unpermute ops. Each name
// must match an op schema declared for this extension elsewhere.
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
m.impl("moe_permute", &moe_permute);
// Same op as moe_permute, but running on caller-provided scratch buffers.
m.impl("moe_permute_with_scratch", &moe_permute_with_scratch);
m.impl("moe_unpermute", &moe_unpermute);
}
+13
View File
@@ -100,13 +100,26 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
"expert_first_token_offset, Tensor! inv_permuted_idx, Tensor! "
"permuted_idx)->()");
m.def(
"moe_permute_with_scratch(Tensor input, Tensor topk_ids,"
"Tensor token_expert_indices, Tensor? expert_map, int n_expert,"
"int n_local_expert,"
"int topk, Tensor! permuted_input, Tensor! "
"expert_first_token_offset, Tensor! inv_permuted_idx, Tensor! "
"permuted_idx, Tensor! sort_workspace, Tensor! permuted_experts_id, "
"Tensor! sorted_row_idx, Tensor! topk_ids_for_sort)->()");
m.def(
"moe_unpermute(Tensor permuted_hidden_states, Tensor topk_weights,"
"Tensor inv_permuted_idx, Tensor? expert_first_token_offset, "
"int topk, Tensor! hidden_states)->()");
m.def("moe_permute_unpermute_supported() -> bool");
m.def(
"moe_permute_sort_workspace_size(int num_expanded_rows, int n_expert) -> "
"int");
m.impl("moe_permute_unpermute_supported", &moe_permute_unpermute_supported);
m.impl("moe_permute_sort_workspace_size", &moe_permute_sort_workspace_size);
// Row shuffle for MoE
m.def(
@@ -14,6 +14,7 @@ from vllm.model_executor.layers.fused_moe.expert_map_manager import (
determine_expert_map,
)
from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
MoEPermuteScratch,
moe_permute,
moe_permute_unpermute_supported,
moe_unpermute,
@@ -209,3 +210,79 @@ def test_moe_permute_unpermute(
)
# check unpermuted hidden
torch.testing.assert_close(result4, gold4, atol=2e-2, rtol=0)
@pytest.mark.parametrize("dtype", [torch.bfloat16])
def test_moe_permute_reuses_scratch_buffers(dtype: torch.dtype):
    """Check that moe_permute with a scratch object reuses its buffers and
    still produces the same results as the scratch-free path.

    Both scratch-backed calls return views of the same scratch storage, so
    comparing the first result with the second compares a buffer with itself.
    The values are therefore checked against a reference computed without
    scratch, once after each scratch-backed call.
    """
    if not moe_permute_unpermute_supported():
        pytest.skip("moe_permute_unpermute is not supported on this platform.")

    n_token = 64
    n_hidden = 2048
    n_expert = 16
    topk = 4

    hidden_states = torch.randn((n_token, n_hidden), device="cuda").to(dtype)
    gating_output = torch.randn((n_token, n_expert), device="cuda").to(dtype)
    _, topk_ids, _ = fused_topk(hidden_states, gating_output, topk, False)

    scratch = MoEPermuteScratch(
        max_num_tokens=n_token,
        topk=topk,
        num_experts=n_expert,
        num_local_experts=n_expert,
        device=hidden_states.device,
        hidden_size=n_hidden,
        hidden_dtype=hidden_states.dtype,
    )

    # Reference result from the allocation-per-call path (fresh tensors).
    (
        ref_permuted_hidden_states,
        _,
        ref_expert_first_token_offset,
        ref_inv_permuted_idx,
        ref_permuted_idx,
    ) = moe_permute(
        hidden_states=hidden_states,
        a1q_scale=None,
        topk_ids=topk_ids,
        n_expert=n_expert,
    )

    def assert_matches_reference(result) -> None:
        permuted_hidden_states, _, first_token_offset, inv_idx, idx = result
        torch.testing.assert_close(permuted_hidden_states, ref_permuted_hidden_states)
        torch.testing.assert_close(first_token_offset, ref_expert_first_token_offset)
        torch.testing.assert_close(inv_idx, ref_inv_permuted_idx)
        torch.testing.assert_close(idx, ref_permuted_idx)

    first = moe_permute(
        hidden_states=hidden_states,
        a1q_scale=None,
        topk_ids=topk_ids,
        n_expert=n_expert,
        scratch=scratch,
    )
    assert_matches_reference(first)

    # Second call on the same scratch: results must still be correct.
    second = moe_permute(
        hidden_states=hidden_states,
        a1q_scale=None,
        topk_ids=topk_ids,
        n_expert=n_expert,
        scratch=scratch,
    )
    assert_matches_reference(second)

    (
        permuted_hidden_states_1,
        _,
        expert_first_token_offset_1,
        inv_permuted_idx_1,
        permuted_idx_1,
    ) = first
    (
        permuted_hidden_states_2,
        _,
        expert_first_token_offset_2,
        inv_permuted_idx_2,
        permuted_idx_2,
    ) = second

    # Buffer reuse: both calls hand back views of the same scratch storage.
    assert (
        permuted_hidden_states_1.untyped_storage().data_ptr()
        == permuted_hidden_states_2.untyped_storage().data_ptr()
    )
    assert (
        expert_first_token_offset_1.untyped_storage().data_ptr()
        == expert_first_token_offset_2.untyped_storage().data_ptr()
    )
    assert (
        inv_permuted_idx_1.untyped_storage().data_ptr()
        == scratch.inv_permuted_idx.untyped_storage().data_ptr()
    )
    assert (
        inv_permuted_idx_2.untyped_storage().data_ptr()
        == scratch.inv_permuted_idx.untyped_storage().data_ptr()
    )
    assert (
        permuted_idx_1.untyped_storage().data_ptr()
        == scratch.permuted_idx.untyped_storage().data_ptr()
    )
    assert (
        permuted_idx_2.untyped_storage().data_ptr()
        == scratch.permuted_idx.untyped_storage().data_ptr()
    )
@@ -17,7 +17,9 @@ from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
MoEPermuteScratch,
moe_permute,
moe_permute_unpermute_supported,
moe_unpermute,
)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
@@ -73,6 +75,7 @@ def run_cutlass_moe_fp8(
per_out_ch: bool,
use_batched_format: bool,
topk_weights: torch.Tensor | None,
permute_scratch: MoEPermuteScratch | None,
):
a1q = hidden_states
@@ -198,6 +201,7 @@ def run_cutlass_moe_fp8(
local_E,
expert_map,
permuted_hidden_states=a1q_perm,
scratch=permute_scratch,
)
# swap_ab is a CUTLASS grouped-GEMM optimization (M <= 64 reduces padding).
swap_ab = a1q.size(0) <= 64
@@ -291,6 +295,7 @@ class CutlassExpertsFp8Base(mk.FusedMoEExpertsModular):
self.ab_strides2 = ab_strides2
self.c_strides1 = c_strides1
self.c_strides2 = ab_strides1_c_strides2
self._permute_scratch: MoEPermuteScratch | None = None
@staticmethod
def _supports_current_device() -> bool:
@@ -324,6 +329,17 @@ class CutlassExpertsFp8Base(mk.FusedMoEExpertsModular):
# Let PrepareAndFinalize::finalize() decide the impl.
return TopKWeightAndReduceDelegate()
def _get_permute_scratch(self) -> MoEPermuteScratch | None:
    """Return the cached permute scratch, creating it on first use.

    Returns None when no scratch exists yet and
    moe_permute_unpermute_supported() reports the op as unavailable.
    """
    cached = self._permute_scratch
    if cached is not None or not moe_permute_unpermute_supported():
        return cached

    # Only the metadata buffers are sized here; no hidden-state buffer is
    # requested (hidden_size / hidden_dtype are left at their defaults).
    cfg = self.moe_config
    self._permute_scratch = MoEPermuteScratch(
        max_num_tokens=cfg.max_num_tokens,
        topk=cfg.experts_per_token,
        num_experts=cfg.num_experts,
        num_local_experts=cfg.num_local_experts,
        device=torch.device(cfg.device),
    )
    return self._permute_scratch
def apply(
self,
output: torch.Tensor,
@@ -379,6 +395,7 @@ class CutlassExpertsFp8Base(mk.FusedMoEExpertsModular):
self.per_out_ch_quant,
use_batched_format,
topk_weights,
self._get_permute_scratch(),
)
@@ -1121,6 +1138,7 @@ def run_cutlass_moe_w4a8_fp8(
use_batched_format: bool,
topk_weights: torch.Tensor | None,
group_size: int,
permute_scratch: MoEPermuteScratch | None,
):
a1q = hidden_states
M = a1q.size(0)
@@ -1176,6 +1194,7 @@ def run_cutlass_moe_w4a8_fp8(
local_E,
expert_map,
permuted_hidden_states=a1q_perm,
scratch=permute_scratch,
)
# for RS gemm SwapAB is always enabled (swap logical M, N in the problem shape).
ops.get_cutlass_moe_mm_problem_sizes_from_expert_offsets(
@@ -1266,6 +1285,7 @@ class CutlassExpertsW4A8Fp8(mk.FusedMoEExpertsModular):
self.s_strides2[:, 0] = k
self.group_size = group_size
self._permute_scratch: MoEPermuteScratch | None = None
@staticmethod
def activation_format() -> mk.FusedMoEActivationFormat:
@@ -1330,6 +1350,17 @@ class CutlassExpertsW4A8Fp8(mk.FusedMoEExpertsModular):
def workspace_dtype(self, act_dtype: torch.dtype) -> torch.dtype:
return self.out_dtype if self.out_dtype is not None else act_dtype
def _get_permute_scratch(self) -> MoEPermuteScratch | None:
    """Lazily build and cache the scratch buffers used by moe_permute.

    If the scratch has not been built and moe_permute_unpermute_supported()
    is false, nothing is created and None is returned.
    """
    existing = self._permute_scratch
    if existing is not None or not moe_permute_unpermute_supported():
        return existing

    # Metadata buffers only: hidden_size / hidden_dtype keep their defaults,
    # so no permuted-hidden-states buffer is allocated.
    moe_config = self.moe_config
    self._permute_scratch = MoEPermuteScratch(
        max_num_tokens=moe_config.max_num_tokens,
        topk=moe_config.experts_per_token,
        num_experts=moe_config.num_experts,
        num_local_experts=moe_config.num_local_experts,
        device=torch.device(moe_config.device),
    )
    return self._permute_scratch
def workspace_shapes(
self,
M: int,
@@ -1409,4 +1440,5 @@ class CutlassExpertsW4A8Fp8(mk.FusedMoEExpertsModular):
use_batched_format,
topk_weights,
self.group_size,
self._get_permute_scratch(),
)
@@ -26,7 +26,9 @@ from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
)
from vllm.model_executor.layers.fused_moe.moe_fused_mul_sum import moe_fused_mul_sum
from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
MoEPermuteScratch,
moe_permute,
moe_permute_unpermute_supported,
moe_unpermute,
)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
@@ -85,6 +87,7 @@ class HummingExpertsBase(mk.FusedMoEExpertsModular):
max_num_tokens=max_num_tokens,
num_dispatchers=num_dispatchers,
)
self._permute_scratch: MoEPermuteScratch | None = None
def init_humming_moe(self):
self.compute_config = {
@@ -110,6 +113,19 @@ class HummingExpertsBase(mk.FusedMoEExpertsModular):
self.w13_tuning_config_str = json.dumps(self.w13_tuning_config)
self.w2_tuning_config_str = json.dumps(self.w2_tuning_config)
def _get_permute_scratch(self) -> MoEPermuteScratch | None:
    """Return the cached permute scratch, allocating it on first use.

    Unlike a metadata-only scratch, this one also passes hidden_size and
    hidden_dtype, so MoEPermuteScratch allocates the permuted hidden-state
    buffer as well. Returns None if nothing is cached and
    moe_permute_unpermute_supported() is false.
    """
    cached = self._permute_scratch
    if cached is not None or not moe_permute_unpermute_supported():
        return cached

    cfg = self.moe_config
    self._permute_scratch = MoEPermuteScratch(
        max_num_tokens=cfg.max_num_tokens,
        topk=cfg.experts_per_token,
        num_experts=cfg.num_experts,
        num_local_experts=cfg.num_local_experts,
        device=torch.device(cfg.device),
        hidden_size=cfg.hidden_dim,
        hidden_dtype=cfg.in_dtype,
    )
    return self._permute_scratch
def get_global_valid_shape_m(self, topk_ids: torch.Tensor):
num_tokens = topk_ids.size(0)
ctx = get_forward_context()
@@ -598,6 +614,7 @@ class HummingGroupedExperts(HummingExpertsBase):
n_expert=self.global_num_experts,
n_local_expert=self.num_experts,
expert_map=self.layer.expert_map,
scratch=self._get_permute_scratch(),
)
inputs, input_scale = HummingMethod.may_quant_input(
@@ -1,9 +1,104 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass, field
import torch
@dataclass
class MoEPermuteScratch:
# Reused metadata buffers for repeated grouped-MoE permutes.
max_num_tokens: int
topk: int
num_experts: int
num_local_experts: int
device: torch.device
hidden_size: int | None = None
hidden_dtype: torch.dtype | None = None
token_expert_indices: torch.Tensor = field(init=False)
expert_first_token_offset: torch.Tensor = field(init=False)
permuted_idx: torch.Tensor = field(init=False)
inv_permuted_idx: torch.Tensor = field(init=False)
permuted_hidden_states: torch.Tensor | None = field(init=False, default=None)
sort_workspace: torch.Tensor = field(init=False)
permuted_experts_id: torch.Tensor = field(init=False)
sorted_row_idx: torch.Tensor = field(init=False)
topk_ids_int32: torch.Tensor = field(init=False)
topk_ids_for_sort: torch.Tensor = field(init=False)
max_expanded_rows: int = field(init=False)
def __post_init__(self) -> None:
assert self.max_num_tokens > 0
assert self.topk > 0
assert self.num_experts > 0
assert self.num_local_experts > 0
if self.hidden_size is None:
assert self.hidden_dtype is None
else:
assert self.hidden_dtype is not None
self.max_expanded_rows = self.max_num_tokens * self.topk
self.token_expert_indices = torch.arange(
self.max_expanded_rows, dtype=torch.int32, device=self.device
)
self.expert_first_token_offset = torch.empty(
self.num_local_experts + 1, dtype=torch.int64, device=self.device
)
self.permuted_idx = torch.empty(
self.max_expanded_rows, dtype=torch.int32, device=self.device
)
self.inv_permuted_idx = torch.empty(
self.max_expanded_rows, dtype=torch.int32, device=self.device
)
if self.hidden_size is not None:
hidden_numel = self.max_expanded_rows * self.hidden_size
self.permuted_hidden_states = torch.empty(
hidden_numel, dtype=self.hidden_dtype, device=self.device
)
self.permuted_experts_id = torch.empty(
self.max_expanded_rows, dtype=torch.int32, device=self.device
)
self.sorted_row_idx = torch.empty(
self.max_expanded_rows, dtype=torch.int32, device=self.device
)
self.topk_ids_int32 = torch.empty(
self.max_expanded_rows, dtype=torch.int32, device=self.device
)
self.topk_ids_for_sort = torch.empty(
self.max_expanded_rows, dtype=torch.int32, device=self.device
)
sorter_size = torch.ops._moe_C.moe_permute_sort_workspace_size(
self.max_expanded_rows, self.num_experts
)
self.sort_workspace = torch.empty(
sorter_size, dtype=torch.int8, device=self.device
)
def validate(self, hidden_states: torch.Tensor, topk_ids: torch.Tensor) -> None:
n_token, n_hidden = hidden_states.shape
assert hidden_states.device == self.device
assert topk_ids.device == self.device
assert n_token <= self.max_num_tokens
assert topk_ids.size(1) == self.topk
assert topk_ids.size(0) == n_token
if self.hidden_size is not None:
assert n_hidden == self.hidden_size
assert hidden_states.dtype == self.hidden_dtype
assert self.permuted_hidden_states is not None
def token_expert_indices_view(self, n_token: int) -> torch.Tensor:
return self.token_expert_indices[: n_token * self.topk].view(n_token, self.topk)
def prepare_topk_ids(self, topk_ids: torch.Tensor) -> torch.Tensor:
if topk_ids.dtype == torch.int32:
return topk_ids
numel = topk_ids.numel()
topk_ids_int32 = self.topk_ids_int32[:numel].view_as(topk_ids)
topk_ids_int32.copy_(topk_ids)
return topk_ids_int32
def moe_permute(
hidden_states: torch.Tensor,
a1q_scale: torch.Tensor | None,
@@ -12,6 +107,7 @@ def moe_permute(
n_local_expert: int = -1,
expert_map: torch.Tensor | None = None,
permuted_hidden_states: torch.Tensor | None = None,
scratch: MoEPermuteScratch | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
This function expands and permutes activation to gather uncontinuous tokens
@@ -45,46 +141,92 @@ def moe_permute(
if n_local_expert == -1:
n_local_expert = n_expert
if permuted_hidden_states is None:
permuted_hidden_states = torch.empty(
(permuted_row_size, n_hidden),
dtype=hidden_states.dtype,
device=hidden_states.device,
)
if scratch is None:
permuted_hidden_states = torch.empty(
(permuted_row_size, n_hidden),
dtype=hidden_states.dtype,
device=hidden_states.device,
)
else:
scratch.validate(hidden_states, topk_ids)
hidden_numel = permuted_row_size * n_hidden
scratch_hidden_states = scratch.permuted_hidden_states
assert scratch_hidden_states is not None
permuted_hidden_states = scratch_hidden_states[:hidden_numel].view(
permuted_row_size, n_hidden
)
assert permuted_hidden_states.size() == (permuted_row_size, n_hidden), (
f"Expected permuted hidden states to be {(permuted_row_size, n_hidden)}"
f" but got {permuted_hidden_states.size()}"
)
token_expert_indices = torch.arange(
0, n_token * topk, dtype=torch.int32, device=hidden_states.device
).reshape((n_token, topk))
if scratch is None:
token_expert_indices = torch.arange(
0, n_token * topk, dtype=torch.int32, device=hidden_states.device
).reshape((n_token, topk))
expert_first_token_offset = torch.empty(
n_local_expert + 1, dtype=torch.int64, device=hidden_states.device
)
permuted_idx = torch.full(
(permuted_row_size,),
n_token * topk,
dtype=torch.int32,
device=hidden_states.device,
)
inv_permuted_idx = torch.empty(
(n_token, topk), dtype=torch.int32, device=hidden_states.device
)
topk_ids = topk_ids.to(torch.int32)
torch.ops._moe_C.moe_permute(
hidden_states,
topk_ids,
token_expert_indices,
expert_map,
n_expert,
n_local_expert,
topk,
permuted_hidden_states,
expert_first_token_offset,
inv_permuted_idx,
permuted_idx,
)
expert_first_token_offset = torch.empty(
n_local_expert + 1, dtype=torch.int64, device=hidden_states.device
)
permuted_idx = torch.full(
(permuted_row_size,),
n_token * topk,
dtype=torch.int32,
device=hidden_states.device,
)
inv_permuted_idx = torch.empty(
(n_token, topk), dtype=torch.int32, device=hidden_states.device
)
topk_ids_int32 = topk_ids.to(torch.int32)
torch.ops._moe_C.moe_permute(
hidden_states,
topk_ids_int32,
token_expert_indices,
expert_map,
n_expert,
n_local_expert,
topk,
permuted_hidden_states,
expert_first_token_offset,
inv_permuted_idx,
permuted_idx,
)
else:
scratch.validate(hidden_states, topk_ids)
assert n_expert == scratch.num_experts
assert n_local_expert == scratch.num_local_experts
token_expert_indices = scratch.token_expert_indices_view(n_token)
expert_first_token_offset = scratch.expert_first_token_offset
permuted_idx = scratch.permuted_idx[:permuted_row_size]
permuted_idx.fill_(permuted_row_size)
inv_permuted_idx = scratch.inv_permuted_idx[:permuted_row_size].view(
n_token, topk
)
permuted_experts_id = scratch.permuted_experts_id[:permuted_row_size].view(
n_token, topk
)
sorted_row_idx = scratch.sorted_row_idx[:permuted_row_size].view(n_token, topk)
topk_ids_for_sort = scratch.topk_ids_for_sort[:permuted_row_size].view(
n_token, topk
)
topk_ids_int32 = scratch.prepare_topk_ids(topk_ids)
torch.ops._moe_C.moe_permute_with_scratch(
hidden_states,
topk_ids_int32,
token_expert_indices,
expert_map,
n_expert,
n_local_expert,
topk,
permuted_hidden_states,
expert_first_token_offset,
inv_permuted_idx,
permuted_idx,
scratch.sort_workspace,
permuted_experts_id,
sorted_row_idx,
topk_ids_for_sort,
)
if a1q_scale is not None and a1q_scale.dim() > 1:
a1q_scale = a1q_scale[permuted_idx.clamp(max=n_token * topk - 1) // topk]