diff --git a/benchmarks/kernels/benchmark_moe_permute_unpermute.py b/benchmarks/kernels/benchmark_moe_permute_unpermute.py index 990be593299..6867274ebe2 100644 --- a/benchmarks/kernels/benchmark_moe_permute_unpermute.py +++ b/benchmarks/kernels/benchmark_moe_permute_unpermute.py @@ -10,6 +10,7 @@ from transformers import AutoConfig from vllm.model_executor.layers.fused_moe import fused_topk from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import ( + MoEPermuteScratch, moe_permute, moe_unpermute, ) @@ -54,6 +55,15 @@ def benchmark_permute( topk_weights, topk_ids, token_expert_indices = fused_topk( qhidden_states, input_gating, topk, False ) + scratch = MoEPermuteScratch( + max_num_tokens=num_tokens, + topk=topk, + num_experts=num_experts, + num_local_experts=num_experts, + device=qhidden_states.device, + hidden_size=hidden_size, + hidden_dtype=qhidden_states.dtype, + ) def prepare(i: int): input_gating.copy_(gating_output[i]) @@ -65,6 +75,7 @@ def benchmark_permute( topk_ids=topk_ids, n_expert=num_experts, expert_map=None, + scratch=scratch, ) # JIT compilation & warmup @@ -123,6 +134,15 @@ def benchmark_unpermute( topk_weights, topk_ids, token_expert_indices = fused_topk( qhidden_states, input_gating, topk, False ) + scratch = MoEPermuteScratch( + max_num_tokens=num_tokens, + topk=topk, + num_experts=num_experts, + num_local_experts=num_experts, + device=qhidden_states.device, + hidden_size=hidden_size, + hidden_dtype=qhidden_states.dtype, + ) def prepare(): ( @@ -137,6 +157,7 @@ def benchmark_unpermute( topk_ids=topk_ids, n_expert=num_experts, expert_map=None, + scratch=scratch, ) # convert to fp16/bf16 as gemm output return ( diff --git a/csrc/moe/moe_ops.h b/csrc/moe/moe_ops.h index ac0e8d59f60..ca2776c6edd 100644 --- a/csrc/moe/moe_ops.h +++ b/csrc/moe/moe_ops.h @@ -62,6 +62,9 @@ std::tuple grouped_topk( bool moe_permute_unpermute_supported(); +int64_t moe_permute_sort_workspace_size(int64_t num_expanded_rows, + int64_t num_experts); + void shuffle_rows(const torch::Tensor& input_tensor, const torch::Tensor& dst2src_map, torch::Tensor& output_tensor); diff --git a/csrc/moe/moe_permute_unpermute_op.cu b/csrc/moe/moe_permute_unpermute_op.cu index c7fcb3ecf2a..6fce009ae6d 100644 --- a/csrc/moe/moe_permute_unpermute_op.cu +++ b/csrc/moe/moe_permute_unpermute_op.cu @@ -8,6 +8,108 @@ // moe_permute kernels require at least CUDA 12.0 #if defined(CUDA_VERSION) && (CUDA_VERSION >= 12000) +namespace { + +torch::Tensor maybe_allocate_tensor( + const std::optional& maybe_tensor, + at::IntArrayRef expected_sizes, torch::ScalarType dtype, c10::Device device, + char const* name) { + auto expected_numel = c10::multiply_integers(expected_sizes); + if (maybe_tensor.has_value()) { + auto tensor = maybe_tensor.value(); + TORCH_CHECK(tensor.device() == device, name, " must be on the same device"); + TORCH_CHECK(tensor.scalar_type() == dtype, name, " has incorrect dtype"); + TORCH_CHECK(tensor.is_contiguous(), name, " must be contiguous"); + TORCH_CHECK(tensor.numel() >= expected_numel, name, + " is too small for the requested shape"); + auto flat_tensor = tensor.view({tensor.numel()}); + return flat_tensor.narrow(0, 0, expected_numel).view(expected_sizes); + } + return torch::empty(expected_sizes, torch::dtype(dtype).device(device)); +} + +} // namespace + +int64_t moe_permute_sort_workspace_size(int64_t num_expanded_rows, + int64_t n_expert) { + return static_cast( + CubKeyValueSorter::getWorkspaceSize(num_expanded_rows, n_expert)); +} + +void moe_permute_impl( + const torch::Tensor& input, // [n_token, hidden] + const torch::Tensor& topk_ids, // [n_token, topk] + const torch::Tensor& token_expert_indices, // [n_token, topk] + const std::optional& expert_map, // [n_expert] + int64_t n_expert, int64_t n_local_expert, int64_t topk, + torch::Tensor& permuted_input, // [permuted_size, hidden] + torch::Tensor& expert_first_token_offset, // [n_local_expert + 1] + torch::Tensor& inv_permuted_idx, // [n_token, topk] + torch::Tensor& permuted_idx, // [permute_size] + const std::optional& maybe_sort_workspace, + const std::optional& maybe_permuted_experts_id, + const std::optional& maybe_sorted_row_idx, + const std::optional& maybe_topk_ids_for_sort) { + TORCH_CHECK(expert_first_token_offset.scalar_type() == at::ScalarType::Long, + "expert_first_token_offset must be int64"); + TORCH_CHECK(topk_ids.scalar_type() == at::ScalarType::Int, + "topk_ids must be int32"); + TORCH_CHECK(token_expert_indices.scalar_type() == at::ScalarType::Int, + "token_expert_indices must be int32"); + TORCH_CHECK(inv_permuted_idx.scalar_type() == at::ScalarType::Int, + "inv_permuted_idx must be int32"); + TORCH_CHECK(expert_first_token_offset.size(0) == n_local_expert + 1, + "expert_first_token_offset shape != n_local_expert+1"); + TORCH_CHECK(inv_permuted_idx.sizes() == token_expert_indices.sizes(), + "token_expert_indices shape must be same as inv_permuted_idx"); + auto device = input.device(); + auto n_token = input.sizes()[0]; + auto n_hidden = input.sizes()[1]; + auto expanded_rows = n_token * topk; + auto stream = at::cuda::getCurrentCUDAStream().stream(); + + auto sorter_size = moe_permute_sort_workspace_size(expanded_rows, n_expert); + auto sort_workspace = + maybe_allocate_tensor(maybe_sort_workspace, {sorter_size}, torch::kInt8, + device, "sort_workspace"); + auto permuted_experts_id = + maybe_allocate_tensor(maybe_permuted_experts_id, topk_ids.sizes(), + at::ScalarType::Int, device, "permuted_experts_id"); + auto sorted_row_idx = + maybe_allocate_tensor(maybe_sorted_row_idx, inv_permuted_idx.sizes(), + at::ScalarType::Int, device, "sorted_row_idx"); + + CubKeyValueSorter sorter{}; + int64_t* valid_num_ptr = nullptr; + torch::Tensor topk_ids_for_sort = topk_ids; + + if (expert_map.has_value()) { + const int* expert_map_ptr = get_ptr(expert_map.value()); + valid_num_ptr = + get_ptr(expert_first_token_offset) + n_local_expert; + topk_ids_for_sort = + maybe_allocate_tensor(maybe_topk_ids_for_sort, topk_ids.sizes(), + at::ScalarType::Int, device, "topk_ids_for_sort"); + topk_ids_for_sort.copy_(topk_ids); + preprocessTopkIdLauncher(get_ptr(topk_ids_for_sort), n_token * topk, + expert_map_ptr, n_expert, stream); + } + + sortAndScanExpert( + get_ptr(topk_ids_for_sort), get_ptr(token_expert_indices), + get_ptr(permuted_experts_id), get_ptr(sorted_row_idx), + get_ptr(expert_first_token_offset), n_token, n_expert, + n_local_expert, topk, sorter, get_ptr(sort_workspace), stream); + + MOE_DISPATCH(input.scalar_type(), [&] { + expandInputRowsKernelLauncher( + get_ptr(input), get_ptr(permuted_input), + get_ptr(sorted_row_idx), get_ptr(inv_permuted_idx), + get_ptr(permuted_idx), get_ptr(expert_first_token_offset), + n_token, valid_num_ptr, n_hidden, topk, n_local_expert, stream); + }); +} + void moe_permute( const torch::Tensor& input, // [n_token, hidden] const torch::Tensor& topk_ids, // [n_token, topk] @@ -18,65 +120,26 @@ void moe_permute( torch::Tensor& expert_first_token_offset, // [n_local_expert + 1] torch::Tensor& inv_permuted_idx, // [n_token, topk] torch::Tensor& permuted_idx) { // [permute_size] - TORCH_CHECK(expert_first_token_offset.scalar_type() == at::ScalarType::Long, - "expert_first_token_offset must be int64"); - TORCH_CHECK(topk_ids.scalar_type() == at::ScalarType::Int, - "topk_ids must be int32"); - TORCH_CHECK(token_expert_indices.scalar_type() == at::ScalarType::Int, - "token_expert_indices must be int32"); - TORCH_CHECK(inv_permuted_idx.scalar_type() == at::ScalarType::Int, - "inv_permuted_idx must be int32"); - TORCH_CHECK(expert_first_token_offset.size(0) == n_local_expert + 1, - "expert_first_token_offset shape != n_local_expert+1") - TORCH_CHECK(inv_permuted_idx.sizes() == token_expert_indices.sizes(), - "token_expert_indices shape must be same as inv_permuted_idx"); - auto n_token = input.sizes()[0]; - auto n_hidden = input.sizes()[1]; - auto stream = at::cuda::getCurrentCUDAStream().stream(); - const long sorter_size = - CubKeyValueSorter::getWorkspaceSize(n_token * topk, n_expert); - auto sort_workspace = torch::empty( - {sorter_size}, - torch::dtype(torch::kInt8).device(torch::kCUDA).requires_grad(false)); - torch::Tensor topk_ids_for_sort = topk_ids; - auto permuted_experts_id = torch::empty_like(topk_ids); - auto sorted_row_idx = torch::empty_like(inv_permuted_idx); + moe_permute_impl(input, topk_ids, token_expert_indices, expert_map, n_expert, + n_local_expert, topk, permuted_input, + expert_first_token_offset, inv_permuted_idx, permuted_idx, + std::nullopt, std::nullopt, std::nullopt, std::nullopt); +} - CubKeyValueSorter sorter{}; - int64_t* valid_num_ptr = nullptr; - // pre-process kernel for expert-parallelism: - // no local expert id plus "n_expert" offset for priority to local expert - // map local expert id [n, .., n+n_local_expert-1] to [0, n_local_expert -1] - // For example, 4 expert with ep_size=2. ep_rank=1 owns global expert id - // [2,3] with expert_map[-1, -1, 0, 1], preprocess_topk_id process topk_ids - // and map global expert id [2, 3] to local_expert id [0, 1] and map global - // expert id [0, 1] ( not in ep rank=1) to [4, 5] by plus n_expert. This map - // operation is to make local expert high priority in following sort topk_ids - // and scan local expert_first_token_offset for each ep rank for next group - // gemm. - if (expert_map.has_value()) { - const int* expert_map_ptr = get_ptr(expert_map.value()); - valid_num_ptr = - get_ptr(expert_first_token_offset) + n_local_expert; - topk_ids_for_sort = topk_ids.clone(); - preprocessTopkIdLauncher(get_ptr(topk_ids_for_sort), n_token * topk, - expert_map_ptr, n_expert, stream); - } - // expert sort topk expert id and scan expert id get expert_first_token_offset - sortAndScanExpert( - get_ptr(topk_ids_for_sort), get_ptr(token_expert_indices), - get_ptr(permuted_experts_id), get_ptr(sorted_row_idx), - get_ptr(expert_first_token_offset), n_token, n_expert, - n_local_expert, topk, sorter, get_ptr(sort_workspace), stream); - - // dispatch expandInputRowsKernelLauncher - MOE_DISPATCH(input.scalar_type(), [&] { - expandInputRowsKernelLauncher( - get_ptr(input), get_ptr(permuted_input), - get_ptr(sorted_row_idx), get_ptr(inv_permuted_idx), - get_ptr(permuted_idx), get_ptr(expert_first_token_offset), - n_token, valid_num_ptr, n_hidden, topk, n_local_expert, stream); - }); +void moe_permute_with_scratch( + const torch::Tensor& input, const torch::Tensor& topk_ids, + const torch::Tensor& token_expert_indices, + const std::optional& expert_map, int64_t n_expert, + int64_t n_local_expert, int64_t topk, torch::Tensor& permuted_input, + torch::Tensor& expert_first_token_offset, torch::Tensor& inv_permuted_idx, + torch::Tensor& permuted_idx, torch::Tensor& sort_workspace, + torch::Tensor& permuted_experts_id, torch::Tensor& sorted_row_idx, + torch::Tensor& topk_ids_for_sort) { + moe_permute_impl(input, topk_ids, token_expert_indices, expert_map, n_expert, + n_local_expert, topk, permuted_input, + expert_first_token_offset, inv_permuted_idx, permuted_idx, + sort_workspace, permuted_experts_id, sorted_row_idx, + topk_ids_for_sort); } void moe_unpermute( @@ -169,6 +232,12 @@ void shuffle_rows(const torch::Tensor& input_tensor, #else +int64_t moe_permute_sort_workspace_size(int64_t num_expanded_rows, + int64_t n_expert) { + TORCH_CHECK( + false, "moe_permute_sort_workspace_size is not supported on CUDA < 12.0"); +} + void moe_permute(const torch::Tensor& input, const torch::Tensor& topk_ids, const torch::Tensor& token_expert_indices, const std::optional& expert_map, @@ -179,6 +248,19 @@ void moe_permute(const torch::Tensor& input, const torch::Tensor& topk_ids, TORCH_CHECK(false, "moe_permute is not supported on CUDA < 12.0"); } +void moe_permute_with_scratch( + const torch::Tensor& input, const torch::Tensor& topk_ids, + const torch::Tensor& token_expert_indices, + const std::optional& expert_map, int64_t n_expert, + int64_t n_local_expert, int64_t topk, torch::Tensor& permuted_input, + torch::Tensor& expert_first_token_offset, torch::Tensor& inv_permuted_idx, + torch::Tensor& permuted_idx, torch::Tensor& sort_workspace, + torch::Tensor& permuted_experts_id, torch::Tensor& sorted_row_idx, + torch::Tensor& topk_ids_for_sort) { + TORCH_CHECK(false, + "moe_permute_with_scratch is not supported on CUDA < 12.0"); +} + void moe_unpermute( const torch::Tensor& permuted_hidden_states, const torch::Tensor& topk_weights, const torch::Tensor& inv_permuted_idx, @@ -199,5 +281,6 @@ bool moe_permute_unpermute_supported() { TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { m.impl("moe_permute", &moe_permute); + m.impl("moe_permute_with_scratch", &moe_permute_with_scratch); m.impl("moe_unpermute", &moe_unpermute); } \ No newline at end of file diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp index b8145435cb1..99230f03b4b 100644 --- a/csrc/moe/torch_bindings.cpp +++ b/csrc/moe/torch_bindings.cpp @@ -100,13 +100,26 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { "expert_first_token_offset, Tensor! inv_permuted_idx, Tensor! " "permuted_idx)->()"); + m.def( + "moe_permute_with_scratch(Tensor input, Tensor topk_ids," + "Tensor token_expert_indices, Tensor? expert_map, int n_expert," + "int n_local_expert," + "int topk, Tensor! permuted_input, Tensor! " + "expert_first_token_offset, Tensor! inv_permuted_idx, Tensor! " + "permuted_idx, Tensor! sort_workspace, Tensor! permuted_experts_id, " + "Tensor! sorted_row_idx, Tensor! topk_ids_for_sort)->()"); + m.def( "moe_unpermute(Tensor permuted_hidden_states, Tensor topk_weights," "Tensor inv_permuted_idx, Tensor? expert_first_token_offset, " "int topk, Tensor! hidden_states)->()"); m.def("moe_permute_unpermute_supported() -> bool"); + m.def( + "moe_permute_sort_workspace_size(int num_expanded_rows, int n_expert) -> " + "int"); m.impl("moe_permute_unpermute_supported", &moe_permute_unpermute_supported); + m.impl("moe_permute_sort_workspace_size", &moe_permute_sort_workspace_size); // Row shuffle for MoE m.def( diff --git a/tests/kernels/moe/test_moe_permute_unpermute.py b/tests/kernels/moe/test_moe_permute_unpermute.py index 5aafb89589f..03db46851ef 100644 --- a/tests/kernels/moe/test_moe_permute_unpermute.py +++ b/tests/kernels/moe/test_moe_permute_unpermute.py @@ -14,6 +14,7 @@ from vllm.model_executor.layers.fused_moe.expert_map_manager import ( determine_expert_map, ) from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import ( + MoEPermuteScratch, moe_permute, moe_permute_unpermute_supported, moe_unpermute, @@ -209,3 +210,79 @@ def test_moe_permute_unpermute( ) # check unpermuted hidden torch.testing.assert_close(result4, gold4, atol=2e-2, rtol=0) + + +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +def test_moe_permute_reuses_scratch_buffers(dtype: torch.dtype): + if not moe_permute_unpermute_supported(): + pytest.skip("moe_permute_unpermute is not supported on this platform.") + + n_token = 64 + n_hidden = 2048 + n_expert = 16 + topk = 4 + hidden_states = torch.randn((n_token, n_hidden), device="cuda").to(dtype) + gating_output = torch.randn((n_token, n_expert), device="cuda").to(dtype) + _, topk_ids, _ = fused_topk(hidden_states, gating_output, topk, False) + + scratch = MoEPermuteScratch( + max_num_tokens=n_token, + topk=topk, + num_experts=n_expert, + num_local_experts=n_expert, + device=hidden_states.device, + hidden_size=n_hidden, + hidden_dtype=hidden_states.dtype, + ) + + first = moe_permute( + hidden_states=hidden_states, + a1q_scale=None, + topk_ids=topk_ids, + n_expert=n_expert, + scratch=scratch, + ) + second = moe_permute( + hidden_states=hidden_states, + a1q_scale=None, + topk_ids=topk_ids, + n_expert=n_expert, + scratch=scratch, + ) + + ( + permuted_hidden_states_1, + _, + expert_first_token_offset_1, + inv_permuted_idx_1, + permuted_idx_1, + ) = first + ( + permuted_hidden_states_2, + _, + expert_first_token_offset_2, + inv_permuted_idx_2, + permuted_idx_2, + ) = second + + torch.testing.assert_close(permuted_hidden_states_1, permuted_hidden_states_2) + torch.testing.assert_close(expert_first_token_offset_1, expert_first_token_offset_2) + torch.testing.assert_close(inv_permuted_idx_1, inv_permuted_idx_2) + torch.testing.assert_close(permuted_idx_1, permuted_idx_2) + + assert ( + permuted_hidden_states_1.untyped_storage().data_ptr() + == permuted_hidden_states_2.untyped_storage().data_ptr() + ) + assert ( + expert_first_token_offset_1.untyped_storage().data_ptr() + == expert_first_token_offset_2.untyped_storage().data_ptr() + ) + assert ( + inv_permuted_idx_1.untyped_storage().data_ptr() + == scratch.inv_permuted_idx.untyped_storage().data_ptr() + ) + assert ( + permuted_idx_1.untyped_storage().data_ptr() + == scratch.permuted_idx.untyped_storage().data_ptr() + ) diff --git a/vllm/model_executor/layers/fused_moe/experts/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/experts/cutlass_moe.py index 53a876721a1..28a7d283b4b 100644 --- a/vllm/model_executor/layers/fused_moe/experts/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/experts/cutlass_moe.py @@ -17,7 +17,9 @@ from vllm.model_executor.layers.fused_moe.config import ( FusedMoEQuantConfig, ) from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import ( + MoEPermuteScratch, moe_permute, + moe_permute_unpermute_supported, moe_unpermute, ) from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( @@ -73,6 +75,7 @@ def run_cutlass_moe_fp8( per_out_ch: bool, use_batched_format: bool, topk_weights: torch.Tensor | None, + permute_scratch: MoEPermuteScratch | None, ): a1q = hidden_states @@ -198,6 +201,7 @@ def run_cutlass_moe_fp8( local_E, expert_map, permuted_hidden_states=a1q_perm, + scratch=permute_scratch, ) # swap_ab is a CUTLASS grouped-GEMM optimization (M <= 64 reduces padding). swap_ab = a1q.size(0) <= 64 @@ -291,6 +295,7 @@ class CutlassExpertsFp8Base(mk.FusedMoEExpertsModular): self.ab_strides2 = ab_strides2 self.c_strides1 = c_strides1 self.c_strides2 = ab_strides1_c_strides2 + self._permute_scratch: MoEPermuteScratch | None = None @staticmethod def _supports_current_device() -> bool: @@ -324,6 +329,17 @@ class CutlassExpertsFp8Base(mk.FusedMoEExpertsModular): # Let PrepareAndFinalize::finalize() decide the impl. return TopKWeightAndReduceDelegate() + def _get_permute_scratch(self) -> MoEPermuteScratch | None: + if self._permute_scratch is None and moe_permute_unpermute_supported(): + self._permute_scratch = MoEPermuteScratch( + max_num_tokens=self.moe_config.max_num_tokens, + topk=self.moe_config.experts_per_token, + num_experts=self.moe_config.num_experts, + num_local_experts=self.moe_config.num_local_experts, + device=torch.device(self.moe_config.device), + ) + return self._permute_scratch + def apply( self, output: torch.Tensor, @@ -379,6 +395,7 @@ class CutlassExpertsFp8Base(mk.FusedMoEExpertsModular): self.per_out_ch_quant, use_batched_format, topk_weights, + self._get_permute_scratch(), ) @@ -1121,6 +1138,7 @@ def run_cutlass_moe_w4a8_fp8( use_batched_format: bool, topk_weights: torch.Tensor | None, group_size: int, + permute_scratch: MoEPermuteScratch | None, ): a1q = hidden_states M = a1q.size(0) @@ -1176,6 +1194,7 @@ def run_cutlass_moe_w4a8_fp8( local_E, expert_map, permuted_hidden_states=a1q_perm, + scratch=permute_scratch, ) # for RS gemm SwapAB is always enabled (swap logical M, N in the problem shape). ops.get_cutlass_moe_mm_problem_sizes_from_expert_offsets( @@ -1266,6 +1285,7 @@ class CutlassExpertsW4A8Fp8(mk.FusedMoEExpertsModular): self.s_strides2[:, 0] = k self.group_size = group_size + self._permute_scratch: MoEPermuteScratch | None = None @staticmethod def activation_format() -> mk.FusedMoEActivationFormat: @@ -1330,6 +1350,17 @@ class CutlassExpertsW4A8Fp8(mk.FusedMoEExpertsModular): def workspace_dtype(self, act_dtype: torch.dtype) -> torch.dtype: return self.out_dtype if self.out_dtype is not None else act_dtype + def _get_permute_scratch(self) -> MoEPermuteScratch | None: + if self._permute_scratch is None and moe_permute_unpermute_supported(): + self._permute_scratch = MoEPermuteScratch( + max_num_tokens=self.moe_config.max_num_tokens, + topk=self.moe_config.experts_per_token, + num_experts=self.moe_config.num_experts, + num_local_experts=self.moe_config.num_local_experts, + device=torch.device(self.moe_config.device), + ) + return self._permute_scratch + def workspace_shapes( self, M: int, @@ -1409,4 +1440,5 @@ class CutlassExpertsW4A8Fp8(mk.FusedMoEExpertsModular): use_batched_format, topk_weights, self.group_size, + self._get_permute_scratch(), ) diff --git a/vllm/model_executor/layers/fused_moe/experts/fused_humming_moe.py b/vllm/model_executor/layers/fused_moe/experts/fused_humming_moe.py index 203c5782829..8874228a142 100644 --- a/vllm/model_executor/layers/fused_moe/experts/fused_humming_moe.py +++ b/vllm/model_executor/layers/fused_moe/experts/fused_humming_moe.py @@ -26,7 +26,9 @@ from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( ) from vllm.model_executor.layers.fused_moe.moe_fused_mul_sum import moe_fused_mul_sum from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import ( + MoEPermuteScratch, moe_permute, + moe_permute_unpermute_supported, moe_unpermute, ) from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( @@ -85,6 +87,7 @@ class HummingExpertsBase(mk.FusedMoEExpertsModular): max_num_tokens=max_num_tokens, num_dispatchers=num_dispatchers, ) + self._permute_scratch: MoEPermuteScratch | None = None def init_humming_moe(self): self.compute_config = { @@ -110,6 +113,19 @@ class HummingExpertsBase(mk.FusedMoEExpertsModular): self.w13_tuning_config_str = json.dumps(self.w13_tuning_config) self.w2_tuning_config_str = json.dumps(self.w2_tuning_config) + def _get_permute_scratch(self) -> MoEPermuteScratch | None: + if self._permute_scratch is None and moe_permute_unpermute_supported(): + self._permute_scratch = MoEPermuteScratch( + max_num_tokens=self.moe_config.max_num_tokens, + topk=self.moe_config.experts_per_token, + num_experts=self.moe_config.num_experts, + num_local_experts=self.moe_config.num_local_experts, + device=torch.device(self.moe_config.device), + hidden_size=self.moe_config.hidden_dim, + hidden_dtype=self.moe_config.in_dtype, + ) + return self._permute_scratch + def get_global_valid_shape_m(self, topk_ids: torch.Tensor): num_tokens = topk_ids.size(0) ctx = get_forward_context() @@ -598,6 +614,7 @@ class HummingGroupedExperts(HummingExpertsBase): n_expert=self.global_num_experts, n_local_expert=self.num_experts, expert_map=self.layer.expert_map, + scratch=self._get_permute_scratch(), ) inputs, input_scale = HummingMethod.may_quant_input( diff --git a/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py b/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py index de2a392953f..df430689436 100644 --- a/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py +++ b/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py @@ -1,9 +1,104 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from dataclasses import dataclass, field + import torch +@dataclass +class MoEPermuteScratch: + # Reused metadata buffers for repeated grouped-MoE permutes. + max_num_tokens: int + topk: int + num_experts: int + num_local_experts: int + device: torch.device + hidden_size: int | None = None + hidden_dtype: torch.dtype | None = None + token_expert_indices: torch.Tensor = field(init=False) + expert_first_token_offset: torch.Tensor = field(init=False) + permuted_idx: torch.Tensor = field(init=False) + inv_permuted_idx: torch.Tensor = field(init=False) + permuted_hidden_states: torch.Tensor | None = field(init=False, default=None) + sort_workspace: torch.Tensor = field(init=False) + permuted_experts_id: torch.Tensor = field(init=False) + sorted_row_idx: torch.Tensor = field(init=False) + topk_ids_int32: torch.Tensor = field(init=False) + topk_ids_for_sort: torch.Tensor = field(init=False) + max_expanded_rows: int = field(init=False) + + def __post_init__(self) -> None: + assert self.max_num_tokens > 0 + assert self.topk > 0 + assert self.num_experts > 0 + assert self.num_local_experts > 0 + if self.hidden_size is None: + assert self.hidden_dtype is None + else: + assert self.hidden_dtype is not None + + self.max_expanded_rows = self.max_num_tokens * self.topk + self.token_expert_indices = torch.arange( + self.max_expanded_rows, dtype=torch.int32, device=self.device + ) + self.expert_first_token_offset = torch.empty( + self.num_local_experts + 1, dtype=torch.int64, device=self.device + ) + self.permuted_idx = torch.empty( + self.max_expanded_rows, dtype=torch.int32, device=self.device + ) + self.inv_permuted_idx = torch.empty( + self.max_expanded_rows, dtype=torch.int32, device=self.device + ) + if self.hidden_size is not None: + hidden_numel = self.max_expanded_rows * self.hidden_size + self.permuted_hidden_states = torch.empty( + hidden_numel, dtype=self.hidden_dtype, device=self.device + ) + self.permuted_experts_id = torch.empty( + self.max_expanded_rows, dtype=torch.int32, device=self.device + ) + self.sorted_row_idx = torch.empty( + self.max_expanded_rows, dtype=torch.int32, device=self.device + ) + self.topk_ids_int32 = torch.empty( + self.max_expanded_rows, dtype=torch.int32, device=self.device + ) + self.topk_ids_for_sort = torch.empty( + self.max_expanded_rows, dtype=torch.int32, device=self.device + ) + sorter_size = torch.ops._moe_C.moe_permute_sort_workspace_size( + self.max_expanded_rows, self.num_experts + ) + self.sort_workspace = torch.empty( + sorter_size, dtype=torch.int8, device=self.device + ) + + def validate(self, hidden_states: torch.Tensor, topk_ids: torch.Tensor) -> None: + n_token, n_hidden = hidden_states.shape + assert hidden_states.device == self.device + assert topk_ids.device == self.device + assert n_token <= self.max_num_tokens + assert topk_ids.size(1) == self.topk + assert topk_ids.size(0) == n_token + if self.hidden_size is not None: + assert n_hidden == self.hidden_size + assert hidden_states.dtype == self.hidden_dtype + assert self.permuted_hidden_states is not None + + def token_expert_indices_view(self, n_token: int) -> torch.Tensor: + return self.token_expert_indices[: n_token * self.topk].view(n_token, self.topk) + + def prepare_topk_ids(self, topk_ids: torch.Tensor) -> torch.Tensor: + if topk_ids.dtype == torch.int32: + return topk_ids + numel = topk_ids.numel() + topk_ids_int32 = self.topk_ids_int32[:numel].view_as(topk_ids) + topk_ids_int32.copy_(topk_ids) + return topk_ids_int32 + + def moe_permute( hidden_states: torch.Tensor, a1q_scale: torch.Tensor | None, @@ -12,6 +107,7 @@ def moe_permute( n_local_expert: int = -1, expert_map: torch.Tensor | None = None, permuted_hidden_states: torch.Tensor | None = None, + scratch: MoEPermuteScratch | None = None, ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor, torch.Tensor, torch.Tensor]: """ This function expands and permutes activation to gather uncontinuous tokens @@ -45,46 +141,92 @@ def moe_permute( if n_local_expert == -1: n_local_expert = n_expert if permuted_hidden_states is None: - permuted_hidden_states = torch.empty( - (permuted_row_size, n_hidden), - dtype=hidden_states.dtype, - device=hidden_states.device, - ) + if scratch is None: + permuted_hidden_states = torch.empty( + (permuted_row_size, n_hidden), + dtype=hidden_states.dtype, + device=hidden_states.device, + ) + else: + scratch.validate(hidden_states, topk_ids) + hidden_numel = permuted_row_size * n_hidden + scratch_hidden_states = scratch.permuted_hidden_states + assert scratch_hidden_states is not None + permuted_hidden_states = scratch_hidden_states[:hidden_numel].view( + permuted_row_size, n_hidden + ) assert permuted_hidden_states.size() == (permuted_row_size, n_hidden), ( f"Expected permuted hidden states to be {(permuted_row_size, n_hidden)}" f" but got {permuted_hidden_states.size()}" ) - token_expert_indices = torch.arange( - 0, n_token * topk, dtype=torch.int32, device=hidden_states.device - ).reshape((n_token, topk)) + if scratch is None: + token_expert_indices = torch.arange( + 0, n_token * topk, dtype=torch.int32, device=hidden_states.device + ).reshape((n_token, topk)) - expert_first_token_offset = torch.empty( - n_local_expert + 1, dtype=torch.int64, device=hidden_states.device - ) - permuted_idx = torch.full( - (permuted_row_size,), - n_token * topk, - dtype=torch.int32, - device=hidden_states.device, - ) - inv_permuted_idx = torch.empty( - (n_token, topk), dtype=torch.int32, device=hidden_states.device - ) - topk_ids = topk_ids.to(torch.int32) - torch.ops._moe_C.moe_permute( - hidden_states, - topk_ids, - token_expert_indices, - expert_map, - n_expert, - n_local_expert, - topk, - permuted_hidden_states, - expert_first_token_offset, - inv_permuted_idx, - permuted_idx, - ) + expert_first_token_offset = torch.empty( + n_local_expert + 1, dtype=torch.int64, device=hidden_states.device + ) + permuted_idx = torch.full( + (permuted_row_size,), + n_token * topk, + dtype=torch.int32, + device=hidden_states.device, + ) + inv_permuted_idx = torch.empty( + (n_token, topk), dtype=torch.int32, device=hidden_states.device + ) + topk_ids_int32 = topk_ids.to(torch.int32) + torch.ops._moe_C.moe_permute( + hidden_states, + topk_ids_int32, + token_expert_indices, + expert_map, + n_expert, + n_local_expert, + topk, + permuted_hidden_states, + expert_first_token_offset, + inv_permuted_idx, + permuted_idx, + ) + else: + scratch.validate(hidden_states, topk_ids) + assert n_expert == scratch.num_experts + assert n_local_expert == scratch.num_local_experts + token_expert_indices = scratch.token_expert_indices_view(n_token) + expert_first_token_offset = scratch.expert_first_token_offset + permuted_idx = scratch.permuted_idx[:permuted_row_size] + permuted_idx.fill_(permuted_row_size) + inv_permuted_idx = scratch.inv_permuted_idx[:permuted_row_size].view( + n_token, topk + ) + permuted_experts_id = scratch.permuted_experts_id[:permuted_row_size].view( + n_token, topk + ) + sorted_row_idx = scratch.sorted_row_idx[:permuted_row_size].view(n_token, topk) + topk_ids_for_sort = scratch.topk_ids_for_sort[:permuted_row_size].view( + n_token, topk + ) + topk_ids_int32 = scratch.prepare_topk_ids(topk_ids) + torch.ops._moe_C.moe_permute_with_scratch( + hidden_states, + topk_ids_int32, + token_expert_indices, + expert_map, + n_expert, + n_local_expert, + topk, + permuted_hidden_states, + expert_first_token_offset, + inv_permuted_idx, + permuted_idx, + scratch.sort_workspace, + permuted_experts_id, + sorted_row_idx, + topk_ids_for_sort, + ) if a1q_scale is not None and a1q_scale.dim() > 1: a1q_scale = a1q_scale[permuted_idx.clamp(max=n_token * topk - 1) // topk]