mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
[Perf] Optimize moe permute by pre-allocate buffer, 9~14% kernel performance improvement (#43014)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
@@ -10,6 +10,7 @@ from transformers import AutoConfig
|
||||
|
||||
from vllm.model_executor.layers.fused_moe import fused_topk
|
||||
from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
|
||||
MoEPermuteScratch,
|
||||
moe_permute,
|
||||
moe_unpermute,
|
||||
)
|
||||
@@ -54,6 +55,15 @@ def benchmark_permute(
|
||||
topk_weights, topk_ids, token_expert_indices = fused_topk(
|
||||
qhidden_states, input_gating, topk, False
|
||||
)
|
||||
scratch = MoEPermuteScratch(
|
||||
max_num_tokens=num_tokens,
|
||||
topk=topk,
|
||||
num_experts=num_experts,
|
||||
num_local_experts=num_experts,
|
||||
device=qhidden_states.device,
|
||||
hidden_size=hidden_size,
|
||||
hidden_dtype=qhidden_states.dtype,
|
||||
)
|
||||
|
||||
def prepare(i: int):
|
||||
input_gating.copy_(gating_output[i])
|
||||
@@ -65,6 +75,7 @@ def benchmark_permute(
|
||||
topk_ids=topk_ids,
|
||||
n_expert=num_experts,
|
||||
expert_map=None,
|
||||
scratch=scratch,
|
||||
)
|
||||
|
||||
# JIT compilation & warmup
|
||||
@@ -123,6 +134,15 @@ def benchmark_unpermute(
|
||||
topk_weights, topk_ids, token_expert_indices = fused_topk(
|
||||
qhidden_states, input_gating, topk, False
|
||||
)
|
||||
scratch = MoEPermuteScratch(
|
||||
max_num_tokens=num_tokens,
|
||||
topk=topk,
|
||||
num_experts=num_experts,
|
||||
num_local_experts=num_experts,
|
||||
device=qhidden_states.device,
|
||||
hidden_size=hidden_size,
|
||||
hidden_dtype=qhidden_states.dtype,
|
||||
)
|
||||
|
||||
def prepare():
|
||||
(
|
||||
@@ -137,6 +157,7 @@ def benchmark_unpermute(
|
||||
topk_ids=topk_ids,
|
||||
n_expert=num_experts,
|
||||
expert_map=None,
|
||||
scratch=scratch,
|
||||
)
|
||||
# convert to fp16/bf16 as gemm output
|
||||
return (
|
||||
|
||||
@@ -62,6 +62,9 @@ std::tuple<torch::Tensor, torch::Tensor> grouped_topk(
|
||||
|
||||
bool moe_permute_unpermute_supported();
|
||||
|
||||
int64_t moe_permute_sort_workspace_size(int64_t num_expanded_rows,
|
||||
int64_t num_experts);
|
||||
|
||||
void shuffle_rows(const torch::Tensor& input_tensor,
|
||||
const torch::Tensor& dst2src_map,
|
||||
torch::Tensor& output_tensor);
|
||||
|
||||
@@ -8,6 +8,108 @@
|
||||
// moe_permute kernels require at least CUDA 12.0
|
||||
#if defined(CUDA_VERSION) && (CUDA_VERSION >= 12000)
|
||||
|
||||
namespace {
|
||||
|
||||
// Returns a tensor of shape `expected_sizes` with the given dtype/device.
//
// When `maybe_tensor` is empty a fresh tensor is allocated with torch::empty.
// Otherwise the caller-provided buffer is validated (device, dtype,
// contiguity, capacity) and a view over its first
// prod(expected_sizes) elements, reshaped to `expected_sizes`, is returned, so
// no allocation happens. `name` is only used to prefix error messages.
torch::Tensor maybe_allocate_tensor(
    const std::optional<torch::Tensor>& maybe_tensor,
    at::IntArrayRef expected_sizes, torch::ScalarType dtype, c10::Device device,
    char const* name) {
  // No pre-allocated buffer supplied: fall back to a fresh allocation.
  if (!maybe_tensor.has_value()) {
    return torch::empty(expected_sizes, torch::dtype(dtype).device(device));
  }

  const torch::Tensor& buffer = *maybe_tensor;
  const auto required_numel = c10::multiply_integers(expected_sizes);

  TORCH_CHECK(buffer.device() == device, name, " must be on the same device");
  TORCH_CHECK(buffer.scalar_type() == dtype, name, " has incorrect dtype");
  TORCH_CHECK(buffer.is_contiguous(), name, " must be contiguous");
  TORCH_CHECK(buffer.numel() >= required_numel, name,
              " is too small for the requested shape");

  // The buffer may be larger than needed: flatten it, keep only the leading
  // elements, then give that prefix the requested shape.
  return buffer.view({buffer.numel()})
      .narrow(0, 0, required_numel)
      .view(expected_sizes);
}
|
||||
|
||||
} // namespace
|
||||
|
||||
int64_t moe_permute_sort_workspace_size(int64_t num_expanded_rows,
|
||||
int64_t n_expert) {
|
||||
return static_cast<int64_t>(
|
||||
CubKeyValueSorter::getWorkspaceSize(num_expanded_rows, n_expert));
|
||||
}
|
||||
|
||||
void moe_permute_impl(
|
||||
const torch::Tensor& input, // [n_token, hidden]
|
||||
const torch::Tensor& topk_ids, // [n_token, topk]
|
||||
const torch::Tensor& token_expert_indices, // [n_token, topk]
|
||||
const std::optional<torch::Tensor>& expert_map, // [n_expert]
|
||||
int64_t n_expert, int64_t n_local_expert, int64_t topk,
|
||||
torch::Tensor& permuted_input, // [permuted_size, hidden]
|
||||
torch::Tensor& expert_first_token_offset, // [n_local_expert + 1]
|
||||
torch::Tensor& inv_permuted_idx, // [n_token, topk]
|
||||
torch::Tensor& permuted_idx, // [permute_size]
|
||||
const std::optional<torch::Tensor>& maybe_sort_workspace,
|
||||
const std::optional<torch::Tensor>& maybe_permuted_experts_id,
|
||||
const std::optional<torch::Tensor>& maybe_sorted_row_idx,
|
||||
const std::optional<torch::Tensor>& maybe_topk_ids_for_sort) {
|
||||
TORCH_CHECK(expert_first_token_offset.scalar_type() == at::ScalarType::Long,
|
||||
"expert_first_token_offset must be int64");
|
||||
TORCH_CHECK(topk_ids.scalar_type() == at::ScalarType::Int,
|
||||
"topk_ids must be int32");
|
||||
TORCH_CHECK(token_expert_indices.scalar_type() == at::ScalarType::Int,
|
||||
"token_expert_indices must be int32");
|
||||
TORCH_CHECK(inv_permuted_idx.scalar_type() == at::ScalarType::Int,
|
||||
"inv_permuted_idx must be int32");
|
||||
TORCH_CHECK(expert_first_token_offset.size(0) == n_local_expert + 1,
|
||||
"expert_first_token_offset shape != n_local_expert+1");
|
||||
TORCH_CHECK(inv_permuted_idx.sizes() == token_expert_indices.sizes(),
|
||||
"token_expert_indices shape must be same as inv_permuted_idx");
|
||||
auto device = input.device();
|
||||
auto n_token = input.sizes()[0];
|
||||
auto n_hidden = input.sizes()[1];
|
||||
auto expanded_rows = n_token * topk;
|
||||
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
|
||||
auto sorter_size = moe_permute_sort_workspace_size(expanded_rows, n_expert);
|
||||
auto sort_workspace =
|
||||
maybe_allocate_tensor(maybe_sort_workspace, {sorter_size}, torch::kInt8,
|
||||
device, "sort_workspace");
|
||||
auto permuted_experts_id =
|
||||
maybe_allocate_tensor(maybe_permuted_experts_id, topk_ids.sizes(),
|
||||
at::ScalarType::Int, device, "permuted_experts_id");
|
||||
auto sorted_row_idx =
|
||||
maybe_allocate_tensor(maybe_sorted_row_idx, inv_permuted_idx.sizes(),
|
||||
at::ScalarType::Int, device, "sorted_row_idx");
|
||||
|
||||
CubKeyValueSorter sorter{};
|
||||
int64_t* valid_num_ptr = nullptr;
|
||||
torch::Tensor topk_ids_for_sort = topk_ids;
|
||||
|
||||
if (expert_map.has_value()) {
|
||||
const int* expert_map_ptr = get_ptr<int>(expert_map.value());
|
||||
valid_num_ptr =
|
||||
get_ptr<int64_t>(expert_first_token_offset) + n_local_expert;
|
||||
topk_ids_for_sort =
|
||||
maybe_allocate_tensor(maybe_topk_ids_for_sort, topk_ids.sizes(),
|
||||
at::ScalarType::Int, device, "topk_ids_for_sort");
|
||||
topk_ids_for_sort.copy_(topk_ids);
|
||||
preprocessTopkIdLauncher(get_ptr<int>(topk_ids_for_sort), n_token * topk,
|
||||
expert_map_ptr, n_expert, stream);
|
||||
}
|
||||
|
||||
sortAndScanExpert(
|
||||
get_ptr<const int>(topk_ids_for_sort), get_ptr<int>(token_expert_indices),
|
||||
get_ptr<int>(permuted_experts_id), get_ptr<int>(sorted_row_idx),
|
||||
get_ptr<int64_t>(expert_first_token_offset), n_token, n_expert,
|
||||
n_local_expert, topk, sorter, get_ptr<int>(sort_workspace), stream);
|
||||
|
||||
MOE_DISPATCH(input.scalar_type(), [&] {
|
||||
expandInputRowsKernelLauncher<scalar_t>(
|
||||
get_ptr<scalar_t>(input), get_ptr<scalar_t>(permuted_input),
|
||||
get_ptr<int>(sorted_row_idx), get_ptr<int>(inv_permuted_idx),
|
||||
get_ptr<int>(permuted_idx), get_ptr<int64_t>(expert_first_token_offset),
|
||||
n_token, valid_num_ptr, n_hidden, topk, n_local_expert, stream);
|
||||
});
|
||||
}
|
||||
|
||||
void moe_permute(
|
||||
const torch::Tensor& input, // [n_token, hidden]
|
||||
const torch::Tensor& topk_ids, // [n_token, topk]
|
||||
@@ -18,65 +120,26 @@ void moe_permute(
|
||||
torch::Tensor& expert_first_token_offset, // [n_local_expert + 1]
|
||||
torch::Tensor& inv_permuted_idx, // [n_token, topk]
|
||||
torch::Tensor& permuted_idx) { // [permute_size]
|
||||
TORCH_CHECK(expert_first_token_offset.scalar_type() == at::ScalarType::Long,
|
||||
"expert_first_token_offset must be int64");
|
||||
TORCH_CHECK(topk_ids.scalar_type() == at::ScalarType::Int,
|
||||
"topk_ids must be int32");
|
||||
TORCH_CHECK(token_expert_indices.scalar_type() == at::ScalarType::Int,
|
||||
"token_expert_indices must be int32");
|
||||
TORCH_CHECK(inv_permuted_idx.scalar_type() == at::ScalarType::Int,
|
||||
"inv_permuted_idx must be int32");
|
||||
TORCH_CHECK(expert_first_token_offset.size(0) == n_local_expert + 1,
|
||||
"expert_first_token_offset shape != n_local_expert+1")
|
||||
TORCH_CHECK(inv_permuted_idx.sizes() == token_expert_indices.sizes(),
|
||||
"token_expert_indices shape must be same as inv_permuted_idx");
|
||||
auto n_token = input.sizes()[0];
|
||||
auto n_hidden = input.sizes()[1];
|
||||
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
const long sorter_size =
|
||||
CubKeyValueSorter::getWorkspaceSize(n_token * topk, n_expert);
|
||||
auto sort_workspace = torch::empty(
|
||||
{sorter_size},
|
||||
torch::dtype(torch::kInt8).device(torch::kCUDA).requires_grad(false));
|
||||
torch::Tensor topk_ids_for_sort = topk_ids;
|
||||
auto permuted_experts_id = torch::empty_like(topk_ids);
|
||||
auto sorted_row_idx = torch::empty_like(inv_permuted_idx);
|
||||
moe_permute_impl(input, topk_ids, token_expert_indices, expert_map, n_expert,
|
||||
n_local_expert, topk, permuted_input,
|
||||
expert_first_token_offset, inv_permuted_idx, permuted_idx,
|
||||
std::nullopt, std::nullopt, std::nullopt, std::nullopt);
|
||||
}
|
||||
|
||||
CubKeyValueSorter sorter{};
|
||||
int64_t* valid_num_ptr = nullptr;
|
||||
// pre-process kernel for expert-parallelism:
|
||||
// no local expert id plus "n_expert" offset for priority to local expert
|
||||
// map local expert id [n, .., n+n_local_expert-1] to [0, n_local_expert -1]
|
||||
// For example, 4 expert with ep_size=2. ep_rank=1 owns global expert id
|
||||
// [2,3] with expert_map[-1, -1, 0, 1], preprocess_topk_id process topk_ids
|
||||
// and map global expert id [2, 3] to local_expert id [0, 1] and map global
|
||||
// expert id [0, 1] ( not in ep rank=1) to [4, 5] by plus n_expert. This map
|
||||
// operation is to make local expert high priority in following sort topk_ids
|
||||
// and scan local expert_first_token_offset for each ep rank for next group
|
||||
// gemm.
|
||||
if (expert_map.has_value()) {
|
||||
const int* expert_map_ptr = get_ptr<int>(expert_map.value());
|
||||
valid_num_ptr =
|
||||
get_ptr<int64_t>(expert_first_token_offset) + n_local_expert;
|
||||
topk_ids_for_sort = topk_ids.clone();
|
||||
preprocessTopkIdLauncher(get_ptr<int>(topk_ids_for_sort), n_token * topk,
|
||||
expert_map_ptr, n_expert, stream);
|
||||
}
|
||||
// expert sort topk expert id and scan expert id get expert_first_token_offset
|
||||
sortAndScanExpert(
|
||||
get_ptr<const int>(topk_ids_for_sort), get_ptr<int>(token_expert_indices),
|
||||
get_ptr<int>(permuted_experts_id), get_ptr<int>(sorted_row_idx),
|
||||
get_ptr<int64_t>(expert_first_token_offset), n_token, n_expert,
|
||||
n_local_expert, topk, sorter, get_ptr<int>(sort_workspace), stream);
|
||||
|
||||
// dispatch expandInputRowsKernelLauncher
|
||||
MOE_DISPATCH(input.scalar_type(), [&] {
|
||||
expandInputRowsKernelLauncher<scalar_t>(
|
||||
get_ptr<scalar_t>(input), get_ptr<scalar_t>(permuted_input),
|
||||
get_ptr<int>(sorted_row_idx), get_ptr<int>(inv_permuted_idx),
|
||||
get_ptr<int>(permuted_idx), get_ptr<int64_t>(expert_first_token_offset),
|
||||
n_token, valid_num_ptr, n_hidden, topk, n_local_expert, stream);
|
||||
});
|
||||
void moe_permute_with_scratch(
|
||||
const torch::Tensor& input, const torch::Tensor& topk_ids,
|
||||
const torch::Tensor& token_expert_indices,
|
||||
const std::optional<torch::Tensor>& expert_map, int64_t n_expert,
|
||||
int64_t n_local_expert, int64_t topk, torch::Tensor& permuted_input,
|
||||
torch::Tensor& expert_first_token_offset, torch::Tensor& inv_permuted_idx,
|
||||
torch::Tensor& permuted_idx, torch::Tensor& sort_workspace,
|
||||
torch::Tensor& permuted_experts_id, torch::Tensor& sorted_row_idx,
|
||||
torch::Tensor& topk_ids_for_sort) {
|
||||
moe_permute_impl(input, topk_ids, token_expert_indices, expert_map, n_expert,
|
||||
n_local_expert, topk, permuted_input,
|
||||
expert_first_token_offset, inv_permuted_idx, permuted_idx,
|
||||
sort_workspace, permuted_experts_id, sorted_row_idx,
|
||||
topk_ids_for_sort);
|
||||
}
|
||||
|
||||
void moe_unpermute(
|
||||
@@ -169,6 +232,12 @@ void shuffle_rows(const torch::Tensor& input_tensor,
|
||||
|
||||
#else
|
||||
|
||||
// Fallback definition used when the CUDA >= 12.0 guard above is not satisfied:
// the op still exists, but every call fails with a descriptive error.
int64_t moe_permute_sort_workspace_size(int64_t num_expanded_rows,
                                        int64_t n_expert) {
  TORCH_CHECK(false,
              "moe_permute_sort_workspace_size is not supported on CUDA < 12.0");
}
|
||||
|
||||
void moe_permute(const torch::Tensor& input, const torch::Tensor& topk_ids,
|
||||
const torch::Tensor& token_expert_indices,
|
||||
const std::optional<torch::Tensor>& expert_map,
|
||||
@@ -179,6 +248,19 @@ void moe_permute(const torch::Tensor& input, const torch::Tensor& topk_ids,
|
||||
TORCH_CHECK(false, "moe_permute is not supported on CUDA < 12.0");
|
||||
}
|
||||
|
||||
void moe_permute_with_scratch(
|
||||
const torch::Tensor& input, const torch::Tensor& topk_ids,
|
||||
const torch::Tensor& token_expert_indices,
|
||||
const std::optional<torch::Tensor>& expert_map, int64_t n_expert,
|
||||
int64_t n_local_expert, int64_t topk, torch::Tensor& permuted_input,
|
||||
torch::Tensor& expert_first_token_offset, torch::Tensor& inv_permuted_idx,
|
||||
torch::Tensor& permuted_idx, torch::Tensor& sort_workspace,
|
||||
torch::Tensor& permuted_experts_id, torch::Tensor& sorted_row_idx,
|
||||
torch::Tensor& topk_ids_for_sort) {
|
||||
TORCH_CHECK(false,
|
||||
"moe_permute_with_scratch is not supported on CUDA < 12.0");
|
||||
}
|
||||
|
||||
void moe_unpermute(
|
||||
const torch::Tensor& permuted_hidden_states,
|
||||
const torch::Tensor& topk_weights, const torch::Tensor& inv_permuted_idx,
|
||||
@@ -199,5 +281,6 @@ bool moe_permute_unpermute_supported() {
|
||||
|
||||
// Binds the CUDA dispatch-key implementations of the permute/unpermute ops.
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
  // Allocating variant: scratch tensors are created on every call.
  m.impl("moe_permute", &moe_permute);
  // Variant that reuses caller-provided scratch tensors.
  m.impl("moe_permute_with_scratch", &moe_permute_with_scratch);
  m.impl("moe_unpermute", &moe_unpermute);
}
|
||||
@@ -100,13 +100,26 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
|
||||
"expert_first_token_offset, Tensor! inv_permuted_idx, Tensor! "
|
||||
"permuted_idx)->()");
|
||||
|
||||
m.def(
|
||||
"moe_permute_with_scratch(Tensor input, Tensor topk_ids,"
|
||||
"Tensor token_expert_indices, Tensor? expert_map, int n_expert,"
|
||||
"int n_local_expert,"
|
||||
"int topk, Tensor! permuted_input, Tensor! "
|
||||
"expert_first_token_offset, Tensor! inv_permuted_idx, Tensor! "
|
||||
"permuted_idx, Tensor! sort_workspace, Tensor! permuted_experts_id, "
|
||||
"Tensor! sorted_row_idx, Tensor! topk_ids_for_sort)->()");
|
||||
|
||||
m.def(
|
||||
"moe_unpermute(Tensor permuted_hidden_states, Tensor topk_weights,"
|
||||
"Tensor inv_permuted_idx, Tensor? expert_first_token_offset, "
|
||||
"int topk, Tensor! hidden_states)->()");
|
||||
|
||||
m.def("moe_permute_unpermute_supported() -> bool");
|
||||
m.def(
|
||||
"moe_permute_sort_workspace_size(int num_expanded_rows, int n_expert) -> "
|
||||
"int");
|
||||
m.impl("moe_permute_unpermute_supported", &moe_permute_unpermute_supported);
|
||||
m.impl("moe_permute_sort_workspace_size", &moe_permute_sort_workspace_size);
|
||||
|
||||
// Row shuffle for MoE
|
||||
m.def(
|
||||
|
||||
@@ -14,6 +14,7 @@ from vllm.model_executor.layers.fused_moe.expert_map_manager import (
|
||||
determine_expert_map,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
|
||||
MoEPermuteScratch,
|
||||
moe_permute,
|
||||
moe_permute_unpermute_supported,
|
||||
moe_unpermute,
|
||||
@@ -209,3 +210,79 @@ def test_moe_permute_unpermute(
|
||||
)
|
||||
# check unpermuted hidden
|
||||
torch.testing.assert_close(result4, gold4, atol=2e-2, rtol=0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", [torch.bfloat16])
def test_moe_permute_reuses_scratch_buffers(dtype: torch.dtype):
    """Check that moe_permute with a scratch object reuses its buffers.

    Two things are verified:
      * results produced through the scratch path match the regular
        (allocating) path and are stable across repeated calls;
      * the returned tensors are backed by the same storage on every call,
        i.e. the scratch buffers are really reused.
    """
    if not moe_permute_unpermute_supported():
        pytest.skip("moe_permute_unpermute is not supported on this platform.")

    n_token = 64
    n_hidden = 2048
    n_expert = 16
    topk = 4
    hidden_states = torch.randn((n_token, n_hidden), device="cuda").to(dtype)
    gating_output = torch.randn((n_token, n_expert), device="cuda").to(dtype)
    _, topk_ids, _ = fused_topk(hidden_states, gating_output, topk, False)

    scratch = MoEPermuteScratch(
        max_num_tokens=n_token,
        topk=topk,
        num_experts=n_expert,
        num_local_experts=n_expert,
        device=hidden_states.device,
        hidden_size=n_hidden,
        hidden_dtype=hidden_states.dtype,
    )

    def permute(with_scratch: bool):
        # Returns (permuted_hidden_states, expert_first_token_offset,
        # inv_permuted_idx, permuted_idx); the scale output is dropped.
        outputs = moe_permute(
            hidden_states=hidden_states,
            a1q_scale=None,
            topk_ids=topk_ids,
            n_expert=n_expert,
            scratch=scratch if with_scratch else None,
        )
        return outputs[0], outputs[2], outputs[3], outputs[4]

    # Reference results from the path that allocates fresh tensors.
    reference = permute(with_scratch=False)

    first = permute(with_scratch=True)
    # Both scratch calls hand back views of the same buffers, so the first
    # results must be copied out before the second call overwrites them;
    # otherwise the value comparison would compare memory with itself.
    first_snapshot = tuple(t.clone() for t in first)
    second = permute(with_scratch=True)

    for expected, got_first, got_second in zip(reference, first_snapshot, second):
        torch.testing.assert_close(got_first, expected)
        torch.testing.assert_close(got_second, expected)

    (
        permuted_hidden_states_1,
        expert_first_token_offset_1,
        inv_permuted_idx_1,
        permuted_idx_1,
    ) = first
    (
        permuted_hidden_states_2,
        expert_first_token_offset_2,
        _,
        _,
    ) = second

    # Storage identity: repeated calls must not allocate new output tensors.
    assert (
        permuted_hidden_states_1.untyped_storage().data_ptr()
        == permuted_hidden_states_2.untyped_storage().data_ptr()
    )
    assert (
        expert_first_token_offset_1.untyped_storage().data_ptr()
        == expert_first_token_offset_2.untyped_storage().data_ptr()
    )
    assert (
        inv_permuted_idx_1.untyped_storage().data_ptr()
        == scratch.inv_permuted_idx.untyped_storage().data_ptr()
    )
    assert (
        permuted_idx_1.untyped_storage().data_ptr()
        == scratch.permuted_idx.untyped_storage().data_ptr()
    )
|
||||
|
||||
@@ -17,7 +17,9 @@ from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEQuantConfig,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
|
||||
MoEPermuteScratch,
|
||||
moe_permute,
|
||||
moe_permute_unpermute_supported,
|
||||
moe_unpermute,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
|
||||
@@ -73,6 +75,7 @@ def run_cutlass_moe_fp8(
|
||||
per_out_ch: bool,
|
||||
use_batched_format: bool,
|
||||
topk_weights: torch.Tensor | None,
|
||||
permute_scratch: MoEPermuteScratch | None,
|
||||
):
|
||||
a1q = hidden_states
|
||||
|
||||
@@ -198,6 +201,7 @@ def run_cutlass_moe_fp8(
|
||||
local_E,
|
||||
expert_map,
|
||||
permuted_hidden_states=a1q_perm,
|
||||
scratch=permute_scratch,
|
||||
)
|
||||
# swap_ab is a CUTLASS grouped-GEMM optimization (M <= 64 reduces padding).
|
||||
swap_ab = a1q.size(0) <= 64
|
||||
@@ -291,6 +295,7 @@ class CutlassExpertsFp8Base(mk.FusedMoEExpertsModular):
|
||||
self.ab_strides2 = ab_strides2
|
||||
self.c_strides1 = c_strides1
|
||||
self.c_strides2 = ab_strides1_c_strides2
|
||||
self._permute_scratch: MoEPermuteScratch | None = None
|
||||
|
||||
@staticmethod
|
||||
def _supports_current_device() -> bool:
|
||||
@@ -324,6 +329,17 @@ class CutlassExpertsFp8Base(mk.FusedMoEExpertsModular):
|
||||
# Let PrepareAndFinalize::finalize() decide the impl.
|
||||
return TopKWeightAndReduceDelegate()
|
||||
|
||||
def _get_permute_scratch(self) -> MoEPermuteScratch | None:
    """Return the cached permute scratch, creating it on first use.

    Returns None when moe_permute_unpermute_supported() reports that the
    permute kernels are unavailable. No hidden-state buffer is requested
    (hidden_size / hidden_dtype are left unset).
    """
    cached = self._permute_scratch
    if cached is not None or not moe_permute_unpermute_supported():
        return cached

    cfg = self.moe_config
    self._permute_scratch = MoEPermuteScratch(
        max_num_tokens=cfg.max_num_tokens,
        topk=cfg.experts_per_token,
        num_experts=cfg.num_experts,
        num_local_experts=cfg.num_local_experts,
        device=torch.device(cfg.device),
    )
    return self._permute_scratch
|
||||
|
||||
def apply(
|
||||
self,
|
||||
output: torch.Tensor,
|
||||
@@ -379,6 +395,7 @@ class CutlassExpertsFp8Base(mk.FusedMoEExpertsModular):
|
||||
self.per_out_ch_quant,
|
||||
use_batched_format,
|
||||
topk_weights,
|
||||
self._get_permute_scratch(),
|
||||
)
|
||||
|
||||
|
||||
@@ -1121,6 +1138,7 @@ def run_cutlass_moe_w4a8_fp8(
|
||||
use_batched_format: bool,
|
||||
topk_weights: torch.Tensor | None,
|
||||
group_size: int,
|
||||
permute_scratch: MoEPermuteScratch | None,
|
||||
):
|
||||
a1q = hidden_states
|
||||
M = a1q.size(0)
|
||||
@@ -1176,6 +1194,7 @@ def run_cutlass_moe_w4a8_fp8(
|
||||
local_E,
|
||||
expert_map,
|
||||
permuted_hidden_states=a1q_perm,
|
||||
scratch=permute_scratch,
|
||||
)
|
||||
# for RS gemm SwapAB is always enabled (swap logical M, N in the problem shape).
|
||||
ops.get_cutlass_moe_mm_problem_sizes_from_expert_offsets(
|
||||
@@ -1266,6 +1285,7 @@ class CutlassExpertsW4A8Fp8(mk.FusedMoEExpertsModular):
|
||||
self.s_strides2[:, 0] = k
|
||||
|
||||
self.group_size = group_size
|
||||
self._permute_scratch: MoEPermuteScratch | None = None
|
||||
|
||||
@staticmethod
|
||||
def activation_format() -> mk.FusedMoEActivationFormat:
|
||||
@@ -1330,6 +1350,17 @@ class CutlassExpertsW4A8Fp8(mk.FusedMoEExpertsModular):
|
||||
def workspace_dtype(self, act_dtype: torch.dtype) -> torch.dtype:
|
||||
return self.out_dtype if self.out_dtype is not None else act_dtype
|
||||
|
||||
def _get_permute_scratch(self) -> MoEPermuteScratch | None:
    """Return the cached permute scratch, creating it on first use.

    Returns None when moe_permute_unpermute_supported() reports that the
    permute kernels are unavailable. No hidden-state buffer is requested
    (hidden_size / hidden_dtype are left unset).
    """
    cached = self._permute_scratch
    if cached is not None or not moe_permute_unpermute_supported():
        return cached

    cfg = self.moe_config
    self._permute_scratch = MoEPermuteScratch(
        max_num_tokens=cfg.max_num_tokens,
        topk=cfg.experts_per_token,
        num_experts=cfg.num_experts,
        num_local_experts=cfg.num_local_experts,
        device=torch.device(cfg.device),
    )
    return self._permute_scratch
|
||||
|
||||
def workspace_shapes(
|
||||
self,
|
||||
M: int,
|
||||
@@ -1409,4 +1440,5 @@ class CutlassExpertsW4A8Fp8(mk.FusedMoEExpertsModular):
|
||||
use_batched_format,
|
||||
topk_weights,
|
||||
self.group_size,
|
||||
self._get_permute_scratch(),
|
||||
)
|
||||
|
||||
@@ -26,7 +26,9 @@ from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.moe_fused_mul_sum import moe_fused_mul_sum
|
||||
from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
|
||||
MoEPermuteScratch,
|
||||
moe_permute,
|
||||
moe_permute_unpermute_supported,
|
||||
moe_unpermute,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
|
||||
@@ -85,6 +87,7 @@ class HummingExpertsBase(mk.FusedMoEExpertsModular):
|
||||
max_num_tokens=max_num_tokens,
|
||||
num_dispatchers=num_dispatchers,
|
||||
)
|
||||
self._permute_scratch: MoEPermuteScratch | None = None
|
||||
|
||||
def init_humming_moe(self):
|
||||
self.compute_config = {
|
||||
@@ -110,6 +113,19 @@ class HummingExpertsBase(mk.FusedMoEExpertsModular):
|
||||
self.w13_tuning_config_str = json.dumps(self.w13_tuning_config)
|
||||
self.w2_tuning_config_str = json.dumps(self.w2_tuning_config)
|
||||
|
||||
def _get_permute_scratch(self) -> MoEPermuteScratch | None:
    """Return the cached permute scratch, creating it on first use.

    Returns None when moe_permute_unpermute_supported() reports that the
    permute kernels are unavailable. Unlike a metadata-only scratch, this one
    also requests a hidden-state buffer sized from the MoE config
    (hidden_dim / in_dtype).
    """
    cached = self._permute_scratch
    if cached is not None or not moe_permute_unpermute_supported():
        return cached

    cfg = self.moe_config
    self._permute_scratch = MoEPermuteScratch(
        max_num_tokens=cfg.max_num_tokens,
        topk=cfg.experts_per_token,
        num_experts=cfg.num_experts,
        num_local_experts=cfg.num_local_experts,
        device=torch.device(cfg.device),
        hidden_size=cfg.hidden_dim,
        hidden_dtype=cfg.in_dtype,
    )
    return self._permute_scratch
|
||||
|
||||
def get_global_valid_shape_m(self, topk_ids: torch.Tensor):
|
||||
num_tokens = topk_ids.size(0)
|
||||
ctx = get_forward_context()
|
||||
@@ -598,6 +614,7 @@ class HummingGroupedExperts(HummingExpertsBase):
|
||||
n_expert=self.global_num_experts,
|
||||
n_local_expert=self.num_experts,
|
||||
expert_map=self.layer.expert_map,
|
||||
scratch=self._get_permute_scratch(),
|
||||
)
|
||||
|
||||
inputs, input_scale = HummingMethod.may_quant_input(
|
||||
|
||||
@@ -1,9 +1,104 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
@dataclass
|
||||
class MoEPermuteScratch:
|
||||
# Reused metadata buffers for repeated grouped-MoE permutes.
|
||||
max_num_tokens: int
|
||||
topk: int
|
||||
num_experts: int
|
||||
num_local_experts: int
|
||||
device: torch.device
|
||||
hidden_size: int | None = None
|
||||
hidden_dtype: torch.dtype | None = None
|
||||
token_expert_indices: torch.Tensor = field(init=False)
|
||||
expert_first_token_offset: torch.Tensor = field(init=False)
|
||||
permuted_idx: torch.Tensor = field(init=False)
|
||||
inv_permuted_idx: torch.Tensor = field(init=False)
|
||||
permuted_hidden_states: torch.Tensor | None = field(init=False, default=None)
|
||||
sort_workspace: torch.Tensor = field(init=False)
|
||||
permuted_experts_id: torch.Tensor = field(init=False)
|
||||
sorted_row_idx: torch.Tensor = field(init=False)
|
||||
topk_ids_int32: torch.Tensor = field(init=False)
|
||||
topk_ids_for_sort: torch.Tensor = field(init=False)
|
||||
max_expanded_rows: int = field(init=False)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
assert self.max_num_tokens > 0
|
||||
assert self.topk > 0
|
||||
assert self.num_experts > 0
|
||||
assert self.num_local_experts > 0
|
||||
if self.hidden_size is None:
|
||||
assert self.hidden_dtype is None
|
||||
else:
|
||||
assert self.hidden_dtype is not None
|
||||
|
||||
self.max_expanded_rows = self.max_num_tokens * self.topk
|
||||
self.token_expert_indices = torch.arange(
|
||||
self.max_expanded_rows, dtype=torch.int32, device=self.device
|
||||
)
|
||||
self.expert_first_token_offset = torch.empty(
|
||||
self.num_local_experts + 1, dtype=torch.int64, device=self.device
|
||||
)
|
||||
self.permuted_idx = torch.empty(
|
||||
self.max_expanded_rows, dtype=torch.int32, device=self.device
|
||||
)
|
||||
self.inv_permuted_idx = torch.empty(
|
||||
self.max_expanded_rows, dtype=torch.int32, device=self.device
|
||||
)
|
||||
if self.hidden_size is not None:
|
||||
hidden_numel = self.max_expanded_rows * self.hidden_size
|
||||
self.permuted_hidden_states = torch.empty(
|
||||
hidden_numel, dtype=self.hidden_dtype, device=self.device
|
||||
)
|
||||
self.permuted_experts_id = torch.empty(
|
||||
self.max_expanded_rows, dtype=torch.int32, device=self.device
|
||||
)
|
||||
self.sorted_row_idx = torch.empty(
|
||||
self.max_expanded_rows, dtype=torch.int32, device=self.device
|
||||
)
|
||||
self.topk_ids_int32 = torch.empty(
|
||||
self.max_expanded_rows, dtype=torch.int32, device=self.device
|
||||
)
|
||||
self.topk_ids_for_sort = torch.empty(
|
||||
self.max_expanded_rows, dtype=torch.int32, device=self.device
|
||||
)
|
||||
sorter_size = torch.ops._moe_C.moe_permute_sort_workspace_size(
|
||||
self.max_expanded_rows, self.num_experts
|
||||
)
|
||||
self.sort_workspace = torch.empty(
|
||||
sorter_size, dtype=torch.int8, device=self.device
|
||||
)
|
||||
|
||||
def validate(self, hidden_states: torch.Tensor, topk_ids: torch.Tensor) -> None:
|
||||
n_token, n_hidden = hidden_states.shape
|
||||
assert hidden_states.device == self.device
|
||||
assert topk_ids.device == self.device
|
||||
assert n_token <= self.max_num_tokens
|
||||
assert topk_ids.size(1) == self.topk
|
||||
assert topk_ids.size(0) == n_token
|
||||
if self.hidden_size is not None:
|
||||
assert n_hidden == self.hidden_size
|
||||
assert hidden_states.dtype == self.hidden_dtype
|
||||
assert self.permuted_hidden_states is not None
|
||||
|
||||
def token_expert_indices_view(self, n_token: int) -> torch.Tensor:
|
||||
return self.token_expert_indices[: n_token * self.topk].view(n_token, self.topk)
|
||||
|
||||
def prepare_topk_ids(self, topk_ids: torch.Tensor) -> torch.Tensor:
|
||||
if topk_ids.dtype == torch.int32:
|
||||
return topk_ids
|
||||
numel = topk_ids.numel()
|
||||
topk_ids_int32 = self.topk_ids_int32[:numel].view_as(topk_ids)
|
||||
topk_ids_int32.copy_(topk_ids)
|
||||
return topk_ids_int32
|
||||
|
||||
|
||||
def moe_permute(
|
||||
hidden_states: torch.Tensor,
|
||||
a1q_scale: torch.Tensor | None,
|
||||
@@ -12,6 +107,7 @@ def moe_permute(
|
||||
n_local_expert: int = -1,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
permuted_hidden_states: torch.Tensor | None = None,
|
||||
scratch: MoEPermuteScratch | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
This function expands and permutes activation to gather uncontinuous tokens
|
||||
@@ -45,46 +141,92 @@ def moe_permute(
|
||||
if n_local_expert == -1:
|
||||
n_local_expert = n_expert
|
||||
if permuted_hidden_states is None:
|
||||
permuted_hidden_states = torch.empty(
|
||||
(permuted_row_size, n_hidden),
|
||||
dtype=hidden_states.dtype,
|
||||
device=hidden_states.device,
|
||||
)
|
||||
if scratch is None:
|
||||
permuted_hidden_states = torch.empty(
|
||||
(permuted_row_size, n_hidden),
|
||||
dtype=hidden_states.dtype,
|
||||
device=hidden_states.device,
|
||||
)
|
||||
else:
|
||||
scratch.validate(hidden_states, topk_ids)
|
||||
hidden_numel = permuted_row_size * n_hidden
|
||||
scratch_hidden_states = scratch.permuted_hidden_states
|
||||
assert scratch_hidden_states is not None
|
||||
permuted_hidden_states = scratch_hidden_states[:hidden_numel].view(
|
||||
permuted_row_size, n_hidden
|
||||
)
|
||||
assert permuted_hidden_states.size() == (permuted_row_size, n_hidden), (
|
||||
f"Expected permuted hidden states to be {(permuted_row_size, n_hidden)}"
|
||||
f" but got {permuted_hidden_states.size()}"
|
||||
)
|
||||
|
||||
token_expert_indices = torch.arange(
|
||||
0, n_token * topk, dtype=torch.int32, device=hidden_states.device
|
||||
).reshape((n_token, topk))
|
||||
if scratch is None:
|
||||
token_expert_indices = torch.arange(
|
||||
0, n_token * topk, dtype=torch.int32, device=hidden_states.device
|
||||
).reshape((n_token, topk))
|
||||
|
||||
expert_first_token_offset = torch.empty(
|
||||
n_local_expert + 1, dtype=torch.int64, device=hidden_states.device
|
||||
)
|
||||
permuted_idx = torch.full(
|
||||
(permuted_row_size,),
|
||||
n_token * topk,
|
||||
dtype=torch.int32,
|
||||
device=hidden_states.device,
|
||||
)
|
||||
inv_permuted_idx = torch.empty(
|
||||
(n_token, topk), dtype=torch.int32, device=hidden_states.device
|
||||
)
|
||||
topk_ids = topk_ids.to(torch.int32)
|
||||
torch.ops._moe_C.moe_permute(
|
||||
hidden_states,
|
||||
topk_ids,
|
||||
token_expert_indices,
|
||||
expert_map,
|
||||
n_expert,
|
||||
n_local_expert,
|
||||
topk,
|
||||
permuted_hidden_states,
|
||||
expert_first_token_offset,
|
||||
inv_permuted_idx,
|
||||
permuted_idx,
|
||||
)
|
||||
expert_first_token_offset = torch.empty(
|
||||
n_local_expert + 1, dtype=torch.int64, device=hidden_states.device
|
||||
)
|
||||
permuted_idx = torch.full(
|
||||
(permuted_row_size,),
|
||||
n_token * topk,
|
||||
dtype=torch.int32,
|
||||
device=hidden_states.device,
|
||||
)
|
||||
inv_permuted_idx = torch.empty(
|
||||
(n_token, topk), dtype=torch.int32, device=hidden_states.device
|
||||
)
|
||||
topk_ids_int32 = topk_ids.to(torch.int32)
|
||||
torch.ops._moe_C.moe_permute(
|
||||
hidden_states,
|
||||
topk_ids_int32,
|
||||
token_expert_indices,
|
||||
expert_map,
|
||||
n_expert,
|
||||
n_local_expert,
|
||||
topk,
|
||||
permuted_hidden_states,
|
||||
expert_first_token_offset,
|
||||
inv_permuted_idx,
|
||||
permuted_idx,
|
||||
)
|
||||
else:
|
||||
scratch.validate(hidden_states, topk_ids)
|
||||
assert n_expert == scratch.num_experts
|
||||
assert n_local_expert == scratch.num_local_experts
|
||||
token_expert_indices = scratch.token_expert_indices_view(n_token)
|
||||
expert_first_token_offset = scratch.expert_first_token_offset
|
||||
permuted_idx = scratch.permuted_idx[:permuted_row_size]
|
||||
permuted_idx.fill_(permuted_row_size)
|
||||
inv_permuted_idx = scratch.inv_permuted_idx[:permuted_row_size].view(
|
||||
n_token, topk
|
||||
)
|
||||
permuted_experts_id = scratch.permuted_experts_id[:permuted_row_size].view(
|
||||
n_token, topk
|
||||
)
|
||||
sorted_row_idx = scratch.sorted_row_idx[:permuted_row_size].view(n_token, topk)
|
||||
topk_ids_for_sort = scratch.topk_ids_for_sort[:permuted_row_size].view(
|
||||
n_token, topk
|
||||
)
|
||||
topk_ids_int32 = scratch.prepare_topk_ids(topk_ids)
|
||||
torch.ops._moe_C.moe_permute_with_scratch(
|
||||
hidden_states,
|
||||
topk_ids_int32,
|
||||
token_expert_indices,
|
||||
expert_map,
|
||||
n_expert,
|
||||
n_local_expert,
|
||||
topk,
|
||||
permuted_hidden_states,
|
||||
expert_first_token_offset,
|
||||
inv_permuted_idx,
|
||||
permuted_idx,
|
||||
scratch.sort_workspace,
|
||||
permuted_experts_id,
|
||||
sorted_row_idx,
|
||||
topk_ids_for_sort,
|
||||
)
|
||||
|
||||
if a1q_scale is not None and a1q_scale.dim() > 1:
|
||||
a1q_scale = a1q_scale[permuted_idx.clamp(max=n_token * topk - 1) // topk]
|
||||
|
||||
Reference in New Issue
Block a user