From 5c376ea7b3081a6b4bd0fdcb58f7915c02efe67d Mon Sep 17 00:00:00 2001 From: Yukun He <23156053+hyukn@users.noreply.github.com> Date: Sat, 20 Dec 2025 07:00:48 +0000 Subject: [PATCH] [https://nvbugs/5680133][fix] Implement customizable router for cutlass MoE during autotuning Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com> --- .../cutlass_kernels/include/moe_kernels.h | 8 +- .../cutlass_kernels/moe_gemm/moe_kernels.cu | 29 +++-- .../mixtureOfExpertsPlugin.cpp | 3 +- cpp/tensorrt_llm/thop/moeOp.cpp | 19 +-- .../kernels/mixtureOfExpertsTest.cu | 6 +- .../_torch/custom_ops/torch_custom_ops.py | 115 ++++++++++++++++-- .../modules/fused_moe/fused_moe_cutlass.py | 15 ++- 7 files changed, 162 insertions(+), 33 deletions(-) diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h b/cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h index 30db904020..fce35dace6 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h @@ -978,7 +978,9 @@ public: } } - void prepare(int num_tokens, char* workspace, void const* expert_weights, cudaStream_t stream); + void prepare(int num_tokens, char* workspace, void const* expert_weights, + void const* token_selected_experts_customized = nullptr, bool use_customized_router = false, + cudaStream_t stream = nullptr); std::map> getProfilerWorkspaces(int maxM, bool is_tma_ws); size_t getWorkspaceSize(int maxM); @@ -1002,6 +1004,7 @@ public: bool mEnableAlltoall = false; int mSampleIndex = 0; + bool mIsCustomizedRouter = false; nvinfer1::DataType mDType{}; nvinfer1::DataType mWType{}; @@ -1024,8 +1027,9 @@ public: TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType mScalingType{}; private: - void prepareRouting(int num_tokens, char* workspace, cudaStream_t stream); void prepareQuantParams(int num_tokens, char* workspace, cudaStream_t stream); + void prepareRouting(int num_tokens, char* workspace, void const* token_selected_experts_customized, + bool use_customized_router, cudaStream_t stream); void prepareTmaWsInputs(int num_tokens, char* workspace, void const* expert_weights, TmaWarpSpecializedGroupedGemmInput::EpilogueFusion fusion, bool swap_ab, cudaStream_t stream); }; diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu index 2447273752..304375b2fc 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu @@ -4472,7 +4472,8 @@ std::map> GemmProfilerBackend::getProfile return out_map; } -void GemmProfilerBackend::prepareRouting(int num_tokens, char* workspace_ptr_char, cudaStream_t stream) +void GemmProfilerBackend::prepareRouting(int num_tokens, char* workspace_ptr_char, + void const* token_selected_experts_customized, bool use_customized_router, cudaStream_t stream) { auto workspaces = getProfilerWorkspaces(num_tokens, mSM >= 90); #define GET_WS_PTR_BASE(type, name) \ @@ -4513,10 +4514,19 @@ void GemmProfilerBackend::prepareRouting(int num_tokens, char* workspace_ptr_cha int const start_expert_id = mNumExpertsPerNode * mParallelismConfig.ep_rank; uint32_t num_threads = 256; - dim3 grid_dim{(num_tokens + num_threads - 1) / num_threads, NUM_ROUTING_SAMPLES, 1}; - prepareFakeRouterBuffers<<>>( - token_selected_experts_base, num_tokens, mK, mNumExperts); - sync_check_cuda_error(stream); + if (use_customized_router) + { + // copy token selected experts 
to token_selected_experts_base + cudaMemcpyAsync(token_selected_experts_base, token_selected_experts_customized, + num_tokens * mK * sizeof(int), cudaMemcpyDeviceToDevice, stream); + } + else + { + dim3 grid_dim{(num_tokens + num_threads - 1) / num_threads, NUM_ROUTING_SAMPLES, 1}; + prepareFakeRouterBuffers<<>>( + token_selected_experts_base, num_tokens, mK, mNumExperts); + sync_check_cuda_error(stream); + } for (int64_t i = 0; i < NUM_ROUTING_SAMPLES; i++) { @@ -4726,15 +4736,16 @@ void GemmProfilerBackend::prepareTmaWsInputs(int num_tokens, char* workspace_ptr } } -void GemmProfilerBackend::prepare( - int num_tokens, char* workspace_ptr_char, void const* expert_weights, cudaStream_t stream) +void GemmProfilerBackend::prepare(int num_tokens, char* workspace_ptr_char, void const* expert_weights, + void const* token_selected_experts_customized, bool use_customized_router, cudaStream_t stream) { mSampleIndex = 0; + mIsCustomizedRouter = use_customized_router; auto workspace_size = getWorkspaceSize(num_tokens); populateRandomBuffer(workspace_ptr_char, workspace_size, stream); - prepareRouting(num_tokens, workspace_ptr_char, stream); + prepareRouting(num_tokens, workspace_ptr_char, token_selected_experts_customized, use_customized_router, stream); prepareQuantParams(num_tokens, workspace_ptr_char, stream); for (auto fusion : {TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE, TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::FINALIZE}) @@ -4762,7 +4773,7 @@ void GemmProfilerBackend::runProfiler(int original_num_tokens, Config const& tac int64_t expanded_num_tokens = original_num_tokens * mK; int64_t num_experts_per_node = mNumExpertsPerNode; - mSampleIndex = (mSampleIndex + 1) % NUM_ROUTING_SAMPLES; + mSampleIndex = mIsCustomizedRouter ? 0 : (mSampleIndex + 1) % NUM_ROUTING_SAMPLES; auto workspaces = getProfilerWorkspaces(original_num_tokens, tactic.is_tma_warp_specialized); diff --git a/cpp/tensorrt_llm/plugins/mixtureOfExperts/mixtureOfExpertsPlugin.cpp b/cpp/tensorrt_llm/plugins/mixtureOfExperts/mixtureOfExpertsPlugin.cpp index ccce348507..2c3824bd35 100644 --- a/cpp/tensorrt_llm/plugins/mixtureOfExperts/mixtureOfExpertsPlugin.cpp +++ b/cpp/tensorrt_llm/plugins/mixtureOfExperts/mixtureOfExpertsPlugin.cpp @@ -1287,7 +1287,8 @@ void MixtureOfExpertsGemmProfiler::initTmpData( int m, int n, int k, char* workspace, size_t ws_size, cudaStream_t stream) { checkInit(); - backend.prepare(m, workspace, /*expert_weights*/ nullptr, stream); + backend.prepare(m, workspace, /*expert_weights*/ nullptr, /*token_selected_experts_customized*/ nullptr, + /*use_customized_router*/ false, stream); } void MixtureOfExpertsGemmProfiler::checkInit() diff --git a/cpp/tensorrt_llm/thop/moeOp.cpp b/cpp/tensorrt_llm/thop/moeOp.cpp index 5e744804aa..2921266136 100644 --- a/cpp/tensorrt_llm/thop/moeOp.cpp +++ b/cpp/tensorrt_llm/thop/moeOp.cpp @@ -666,13 +666,13 @@ public: } // TODO Update this to be able to tell if we are profiling swiglu bias - void runGemmProfile(torch::Tensor const& input, torch::Tensor const& fc1_expert_weights, - torch::optional const& fc1_expert_biases, torch::Tensor const& fc2_expert_weights, - torch::optional const& fc2_expert_biases, int64_t const top_k, int64_t const tp_size, - int64_t const tp_rank, int64_t const ep_size, int64_t const ep_rank, int64_t const cluster_size, - int64_t const cluster_rank, bool const enable_alltoall, bool const min_latency_mode, int64_t const gemm_idx, - int64_t const profile_id, bool const do_preparation, int64_t const activation_type_int, - int64_t 
const unpadded_hidden_size) + void runGemmProfile(torch::Tensor const& input, torch::optional const& token_final_scales, + torch::Tensor const& fc1_expert_weights, torch::optional const& fc1_expert_biases, + torch::Tensor const& fc2_expert_weights, torch::optional const& fc2_expert_biases, + int64_t const top_k, int64_t const tp_size, int64_t const tp_rank, int64_t const ep_size, int64_t const ep_rank, + int64_t const cluster_size, int64_t const cluster_rank, bool const enable_alltoall, bool const min_latency_mode, + int64_t const gemm_idx, int64_t const profile_id, bool const do_preparation, int64_t const activation_type_int, + int64_t const unpadded_hidden_size, bool const use_customized_router) { std::lock_guard lock(mMutex); @@ -752,7 +752,10 @@ public: auto const cu_malloc_status = cudaMalloc(&mProfileWorkspace, profile_workspace_size); TORCH_CHECK(cu_malloc_status == cudaSuccess, "Can't allocate profile workspace for MoE GEMM profile."); - mProfiler->prepare(num_rows, mProfileWorkspace, expert_weights_ptr, stream); + void const* token_selected_experts_customized + = token_final_scales.has_value() ? token_final_scales.value().const_data_ptr() : nullptr; + mProfiler->prepare(num_rows, mProfileWorkspace, expert_weights_ptr, token_selected_experts_customized, + use_customized_router, stream); } // Profile specific tactic. Assuming at least one preparation phase has been executed already. diff --git a/cpp/tests/unit_tests/kernels/mixtureOfExpertsTest.cu b/cpp/tests/unit_tests/kernels/mixtureOfExpertsTest.cu index 01cd1c4d79..38d37f1dd1 100644 --- a/cpp/tests/unit_tests/kernels/mixtureOfExpertsTest.cu +++ b/cpp/tests/unit_tests/kernels/mixtureOfExpertsTest.cu @@ -2569,7 +2569,8 @@ TYPED_TEST(MixtureOfExpertsTest, RunProfiler) for (int64_t num_tokens : {1, 128}) { - backend.prepare(num_tokens, workspace, /*expert_weights=*/nullptr, this->mStream->get()); + backend.prepare(num_tokens, workspace, /*expert_weights=*/nullptr, + /*token_selected_experts_customized=*/nullptr, /*use_customized_router=*/false, this->mStream->get()); for (auto const& tactic : this->getAllTileConfigsToTest()) { backend.runProfiler(num_tokens, @@ -2616,7 +2617,8 @@ TEST_F(MixtureOfExpertsProfilerTest, TestGeneratedProfilerDistribution) auto workspace = this->allocBuffer(ws_size); int64_t num_experts_per_node = num_experts / ep; - backend.prepare(num_tokens, workspace, /*expert_weights=*/nullptr, mStream->get()); + backend.prepare(num_tokens, workspace, /*expert_weights=*/nullptr, + /*token_selected_experts_customized=*/nullptr, /*use_customized_router=*/false, mStream->get()); auto workspaces = backend.getProfilerWorkspaces(num_tokens, getSMVersion() >= 90 && getSMVersion() < 120); #define GET_WS_PTR(type, name) auto* name = reinterpret_cast(workspace + workspaces.at(#name).second) diff --git a/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py b/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py index 2ee8d29ccc..ff7b6df9ee 100644 --- a/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py +++ b/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py @@ -7,6 +7,8 @@ import triton # type: ignore[import] import tensorrt_llm.quantization.utils.fp4_utils as fp4_utils import tensorrt_llm.quantization.utils.fp8_utils as fp8_utils from tensorrt_llm import deep_gemm +from tensorrt_llm._torch.modules.fused_moe.routing import ( + ROUTING_METHOD_TYPE_TO_CLASS, RoutingMethodType) from tensorrt_llm._utils import get_sm_version from tensorrt_llm.functional import AllReduceFusionOp, AllReduceStrategy from tensorrt_llm.logger import logger 
@@ -30,6 +32,74 @@ def bmm_out(a: torch.Tensor, b: torch.Tensor, out: torch.Tensor) -> None: torch.bmm(a, b, out=out) +def prepare_dummy_token_selected_experts_hook( + input: torch.Tensor, + top_k: int, + num_experts: int, + n_group: Optional[int], + topk_group: Optional[int], + routed_scaling_factor: Optional[float], + routing_method_type: int = int(RoutingMethodType.Default), +): + """ + Creates a hook function that generates dummy token_selected_experts for tuning. + + Args: + input: Input tensor to determine shape and device + top_k: Number of experts per token + num_experts: Total number of experts + routing_method_type: Type of routing method to use + + Returns: + A hook function that can be used with the tuner + """ + tuner = AutoTuner.get() + if not tuner.is_tuning_mode: + return lambda inputs: inputs + + input_tensor = input[0] + + # Get routing method + routing_cls_kwargs = {} + if routing_method_type == int(RoutingMethodType.DeepSeekV3): + routing_cls_kwargs.update({ + 'n_group': + n_group, + 'topk_group': + topk_group, + 'routed_scaling_factor': + routed_scaling_factor, + 'is_fused': + False, + 'callable_e_score_correction_bias': + lambda: torch.randn( + num_experts, dtype=torch.bfloat16, device=input_tensor.device) + }) + routing_method = ROUTING_METHOD_TYPE_TO_CLASS[routing_method_type]( + top_k=top_k, **routing_cls_kwargs) + + def create_dummy_token_selected_experts( + inputs: List[torch.Tensor], ) -> List[torch.Tensor]: + input_tensor = inputs[0] # First tensor is the input + # Generate dummy routing logits with correct shape + routing_logits_for_tuner = torch.randn(input_tensor.shape[0], + num_experts, + dtype=torch.bfloat16, + device=input_tensor.device) + + # Apply routing to get properly shaped token_selected_experts + topk_ids_for_tuner, topk_weights_for_tuner = routing_method.apply( + routing_logits_for_tuner) + + # Replace the token_selected_experts tensor (inputs[1]) with our generated one + if len(inputs) > 1: + inputs[1] = topk_ids_for_tuner + + return inputs + + return create_dummy_token_selected_experts + + class MoERunner(TunableRunner): # avoid overhead of creating a new runner in forward pass runner_dict = dict() @@ -37,7 +107,9 @@ class MoERunner(TunableRunner): dynamic_tensor_specs=(DynamicTensorSpec( 0, 0, get_last_power_of_2_num_tokens_buckets, last_positive_power_of_2), ), + constraint_specs=(ConstraintSpec(1, 0, lambda shapes: shapes[0][0]), ), tune_max_num_tokens=8192, + inputs_pre_hook=None, # Will be set dynamically in fused_moe function distributed_tuning_strategy=DistributedTuningStrategy.PARALLEL, ) @@ -82,6 +154,7 @@ class MoERunner(TunableRunner): self.use_fused_finalize = use_fused_finalize self.activation_type = activation_type self.unpadded_hidden_size = unpadded_hidden_size if unpadded_hidden_size is not None else 0 + self.use_customized_router = False instance_key = (x_dtype, weight_dtype, output_dtype, use_deepseek_fp8_block_scale, use_w4_group_scaling, @@ -126,10 +199,12 @@ class MoERunner(TunableRunner): gemm_idx: int = 0, tactic: int = -1, do_preparation: bool = False, + **kwargs, ): - x, fc1_expert_weights, fc1_expert_biases, fc2_expert_weights, fc2_expert_biases = inputs + x, token_selected_experts, fc1_expert_weights, fc1_expert_biases, fc2_expert_weights, fc2_expert_biases = inputs self.fused_moe_runner.run_gemm_profile( x, + token_selected_experts, fc1_expert_weights, fc1_expert_biases, fc2_expert_weights, @@ -148,6 +223,7 @@ class MoERunner(TunableRunner): do_preparation, self.activation_type, self.unpadded_hidden_size, + 
self.use_customized_router, ) @@ -184,6 +260,10 @@ def fused_moe( tuner_num_tokens: Optional[int] = None, tuner_top_k: Optional[int] = None, activation_type: int = int(ActivationType.Swiglu), + routing_method_type: int = int(RoutingMethodType.Default), + n_group: Optional[int] = None, + topk_group: Optional[int] = None, + routed_scaling_factor: Optional[float] = None, unpadded_hidden_size: Optional[int] = None, out_tensor: Optional[torch.Tensor] = None, ) -> List[torch.Tensor]: @@ -201,6 +281,18 @@ def fused_moe( tuner_input = input tuner_top_k = token_selected_experts.size(1) + tuning_config = MoERunner.tuning_config + tuning_config.inputs_pre_hook = prepare_dummy_token_selected_experts_hook( + tuner_input, + tuner_top_k, + fc1_expert_weights.shape[0] * + ep_size, # num_experts from weight tensor shape + n_group, + topk_group, + routed_scaling_factor, + routing_method_type, + ) + # allocate workspace for profiling moe_runner = MoERunner( x_dtype=input.dtype, @@ -224,27 +316,30 @@ def fused_moe( ) MoERunner.tuning_config.tune_max_num_tokens = tune_max_num_tokens - + input_tensors = [ + tuner_input, + token_selected_experts, + fc1_expert_weights, + fc1_expert_biases, + fc2_expert_weights, + fc2_expert_biases, + ] _, gemm_tactic_1 = tuner.choose_one( "trtllm::fused_moe::gemm1", [moe_runner], MoERunner.tuning_config, - [ - tuner_input, fc1_expert_weights, fc1_expert_biases, - fc2_expert_weights, fc2_expert_biases - ], + input_tensors, gemm_idx=1, + ep_size=ep_size, ) _, gemm_tactic_2 = tuner.choose_one( "trtllm::fused_moe::gemm2", [moe_runner], MoERunner.tuning_config, - [ - tuner_input, fc1_expert_weights, fc1_expert_biases, - fc2_expert_weights, fc2_expert_biases - ], + input_tensors, gemm_idx=2, + ep_size=ep_size, ) run_moe = moe_runner.fused_moe_runner.run_moe_min_latency if min_latency_mode else moe_runner.fused_moe_runner.run_moe diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py index 71e13e1324..970e04e792 100755 --- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py @@ -24,7 +24,7 @@ from .quantization import ( W4A8MXFP4MXFP8CutlassFusedMoEMethod, WFP4A16FusedMoEMethod, WInt4AFP8FusedMoEMethod) # isort: on -from .routing import BaseMoeRoutingMethod +from .routing import BaseMoeRoutingMethod, DeepSeekV3MoeRoutingMethod class CutlassFusedMoE(MoE): @@ -433,6 +433,15 @@ class CutlassFusedMoE(MoE): elif self.has_w4a16_mxfp4: weight_dtype = torch.uint8 + if isinstance(self.routing_method, DeepSeekV3MoeRoutingMethod): + n_group = self.routing_method.routing_impl.n_group + topk_group = self.routing_method.routing_impl.topk_group + routed_scaling_factor = self.routing_method.routing_impl.routed_scaling_factor + else: + n_group = None + topk_group = None + routed_scaling_factor = None + final_hidden_states = torch.ops.trtllm.fused_moe( x, token_selected_experts, @@ -465,6 +474,10 @@ class CutlassFusedMoE(MoE): tuner_num_tokens=tuner_num_tokens, tuner_top_k=tuner_top_k, activation_type=self.activation_type, + routing_method_type=self.routing_method.routing_method_type, + n_group=n_group, + topk_group=topk_group, + routed_scaling_factor=routed_scaling_factor, unpadded_hidden_size=self.unpadded_hidden_size, out_tensor=moe_output, )
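
Illustrative note (not part of the patch): the core of prepare_dummy_token_selected_experts_hook is to rewrite the tuner's token_selected_experts input (inputs[1]) with expert IDs produced by a real routing method over random logits, so the profiled GEMMs see a realistic token-to-expert distribution rather than prepareFakeRouterBuffers' synthetic one. The sketch below shows that idea for a plain softmax top-k router using only PyTorch; make_topk_router_hook and the example shapes/expert counts are hypothetical, whereas the actual hook dispatches through ROUTING_METHOD_TYPE_TO_CLASS and also covers DeepSeekV3 group-limited routing.

import torch
from typing import List


def make_topk_router_hook(num_experts: int, top_k: int):
    """Return a hook that rewrites inputs[1] with routed expert IDs."""

    def hook(inputs: List[torch.Tensor]) -> List[torch.Tensor]:
        x = inputs[0]  # [num_tokens, hidden_size]
        # Random routing logits stand in for the model's router output.
        logits = torch.randn(x.shape[0], num_experts,
                             dtype=torch.float32, device=x.device)
        probs = torch.softmax(logits, dim=-1)
        _, topk_ids = torch.topk(probs, top_k, dim=-1)
        if len(inputs) > 1:
            inputs[1] = topk_ids.to(torch.int32)  # token_selected_experts
        return inputs

    return hook


# Hypothetical usage: a hook of this shape is what the patch installs as
# tuning_config.inputs_pre_hook; the rewritten inputs[1] is what eventually
# reaches runGemmProfile as the customized token_selected_experts buffer.
hook = make_topk_router_hook(num_experts=256, top_k=8)
dummy_inputs = [
    torch.randn(16, 1024),                  # tuner input activations
    torch.zeros(16, 8, dtype=torch.int32),  # placeholder expert IDs
]
dummy_inputs = hook(dummy_inputs)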