This commit is contained in:
Yukun He 2026-01-13 21:20:51 +08:00 committed by GitHub
commit 6eb792a901
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 162 additions and 33 deletions

View File

@ -978,7 +978,9 @@ public:
}
}
void prepare(int num_tokens, char* workspace, void const* expert_weights, cudaStream_t stream);
void prepare(int num_tokens, char* workspace, void const* expert_weights,
void const* token_selected_experts_customized = nullptr, bool use_customized_router = false,
cudaStream_t stream = nullptr);
std::map<std::string, std::pair<size_t, size_t>> getProfilerWorkspaces(int maxM, bool is_tma_ws);
size_t getWorkspaceSize(int maxM);
@ -1002,6 +1004,7 @@ public:
bool mEnableAlltoall = false;
int mSampleIndex = 0;
bool mIsCustomizedRouter = false;
nvinfer1::DataType mDType{};
nvinfer1::DataType mWType{};
@ -1024,8 +1027,9 @@ public:
TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType mScalingType{};
private:
void prepareRouting(int num_tokens, char* workspace, cudaStream_t stream);
void prepareQuantParams(int num_tokens, char* workspace, cudaStream_t stream);
void prepareRouting(int num_tokens, char* workspace, void const* token_selected_experts_customized,
bool use_customized_router, cudaStream_t stream);
void prepareTmaWsInputs(int num_tokens, char* workspace, void const* expert_weights,
TmaWarpSpecializedGroupedGemmInput::EpilogueFusion fusion, bool swap_ab, cudaStream_t stream);
};

View File

@ -4472,7 +4472,8 @@ std::map<std::string, std::pair<size_t, size_t>> GemmProfilerBackend::getProfile
return out_map;
}
void GemmProfilerBackend::prepareRouting(int num_tokens, char* workspace_ptr_char, cudaStream_t stream)
void GemmProfilerBackend::prepareRouting(int num_tokens, char* workspace_ptr_char,
void const* token_selected_experts_customized, bool use_customized_router, cudaStream_t stream)
{
auto workspaces = getProfilerWorkspaces(num_tokens, mSM >= 90);
#define GET_WS_PTR_BASE(type, name) \
@ -4513,10 +4514,19 @@ void GemmProfilerBackend::prepareRouting(int num_tokens, char* workspace_ptr_cha
int const start_expert_id = mNumExpertsPerNode * mParallelismConfig.ep_rank;
uint32_t num_threads = 256;
dim3 grid_dim{(num_tokens + num_threads - 1) / num_threads, NUM_ROUTING_SAMPLES, 1};
prepareFakeRouterBuffers<<<grid_dim, num_threads, 0, stream>>>(
token_selected_experts_base, num_tokens, mK, mNumExperts);
sync_check_cuda_error(stream);
if (use_customized_router)
{
// copy token selected experts to token_selected_experts_base
cudaMemcpyAsync(token_selected_experts_base, token_selected_experts_customized,
num_tokens * mK * sizeof(int), cudaMemcpyDeviceToDevice, stream);
}
else
{
dim3 grid_dim{(num_tokens + num_threads - 1) / num_threads, NUM_ROUTING_SAMPLES, 1};
prepareFakeRouterBuffers<<<grid_dim, num_threads, 0, stream>>>(
token_selected_experts_base, num_tokens, mK, mNumExperts);
sync_check_cuda_error(stream);
}
for (int64_t i = 0; i < NUM_ROUTING_SAMPLES; i++)
{
@ -4726,15 +4736,16 @@ void GemmProfilerBackend::prepareTmaWsInputs(int num_tokens, char* workspace_ptr
}
}
void GemmProfilerBackend::prepare(
int num_tokens, char* workspace_ptr_char, void const* expert_weights, cudaStream_t stream)
void GemmProfilerBackend::prepare(int num_tokens, char* workspace_ptr_char, void const* expert_weights,
void const* token_selected_experts_customized, bool use_customized_router, cudaStream_t stream)
{
mSampleIndex = 0;
mIsCustomizedRouter = use_customized_router;
auto workspace_size = getWorkspaceSize(num_tokens);
populateRandomBuffer(workspace_ptr_char, workspace_size, stream);
prepareRouting(num_tokens, workspace_ptr_char, stream);
prepareRouting(num_tokens, workspace_ptr_char, token_selected_experts_customized, use_customized_router, stream);
prepareQuantParams(num_tokens, workspace_ptr_char, stream);
for (auto fusion : {TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE,
TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::FINALIZE})
@ -4762,7 +4773,7 @@ void GemmProfilerBackend::runProfiler(int original_num_tokens, Config const& tac
int64_t expanded_num_tokens = original_num_tokens * mK;
int64_t num_experts_per_node = mNumExpertsPerNode;
mSampleIndex = (mSampleIndex + 1) % NUM_ROUTING_SAMPLES;
mSampleIndex = mIsCustomizedRouter ? 0 : (mSampleIndex + 1) % NUM_ROUTING_SAMPLES;
auto workspaces = getProfilerWorkspaces(original_num_tokens, tactic.is_tma_warp_specialized);

View File

@ -1287,7 +1287,8 @@ void MixtureOfExpertsGemmProfiler::initTmpData(
int m, int n, int k, char* workspace, size_t ws_size, cudaStream_t stream)
{
checkInit();
backend.prepare(m, workspace, /*expert_weights*/ nullptr, stream);
backend.prepare(m, workspace, /*expert_weights*/ nullptr, /*token_selected_experts_customized*/ nullptr,
/*use_customized_router*/ false, stream);
}
void MixtureOfExpertsGemmProfiler::checkInit()

View File

@ -666,13 +666,13 @@ public:
}
// TODO Update this to be able to tell if we are profiling swiglu bias
void runGemmProfile(torch::Tensor const& input, torch::Tensor const& fc1_expert_weights,
torch::optional<torch::Tensor> const& fc1_expert_biases, torch::Tensor const& fc2_expert_weights,
torch::optional<torch::Tensor> const& fc2_expert_biases, int64_t const top_k, int64_t const tp_size,
int64_t const tp_rank, int64_t const ep_size, int64_t const ep_rank, int64_t const cluster_size,
int64_t const cluster_rank, bool const enable_alltoall, bool const min_latency_mode, int64_t const gemm_idx,
int64_t const profile_id, bool const do_preparation, int64_t const activation_type_int,
int64_t const unpadded_hidden_size)
void runGemmProfile(torch::Tensor const& input, torch::optional<torch::Tensor> const& token_final_scales,
torch::Tensor const& fc1_expert_weights, torch::optional<torch::Tensor> const& fc1_expert_biases,
torch::Tensor const& fc2_expert_weights, torch::optional<torch::Tensor> const& fc2_expert_biases,
int64_t const top_k, int64_t const tp_size, int64_t const tp_rank, int64_t const ep_size, int64_t const ep_rank,
int64_t const cluster_size, int64_t const cluster_rank, bool const enable_alltoall, bool const min_latency_mode,
int64_t const gemm_idx, int64_t const profile_id, bool const do_preparation, int64_t const activation_type_int,
int64_t const unpadded_hidden_size, bool const use_customized_router)
{
std::lock_guard<std::mutex> lock(mMutex);
@ -752,7 +752,10 @@ public:
auto const cu_malloc_status = cudaMalloc(&mProfileWorkspace, profile_workspace_size);
TORCH_CHECK(cu_malloc_status == cudaSuccess, "Can't allocate profile workspace for MoE GEMM profile.");
mProfiler->prepare(num_rows, mProfileWorkspace, expert_weights_ptr, stream);
void const* token_selected_experts_customized
= token_final_scales.has_value() ? token_final_scales.value().const_data_ptr() : nullptr;
mProfiler->prepare(num_rows, mProfileWorkspace, expert_weights_ptr, token_selected_experts_customized,
use_customized_router, stream);
}
// Profile specific tactic. Assuming at least one preparation phase has been executed already.

View File

@ -2569,7 +2569,8 @@ TYPED_TEST(MixtureOfExpertsTest, RunProfiler)
for (int64_t num_tokens : {1, 128})
{
backend.prepare(num_tokens, workspace, /*expert_weights=*/nullptr, this->mStream->get());
backend.prepare(num_tokens, workspace, /*expert_weights=*/nullptr,
/*token_selected_experts_customized=*/nullptr, /*use_customized_router=*/false, this->mStream->get());
for (auto const& tactic : this->getAllTileConfigsToTest())
{
backend.runProfiler(num_tokens,
@ -2616,7 +2617,8 @@ TEST_F(MixtureOfExpertsProfilerTest, TestGeneratedProfilerDistribution)
auto workspace = this->allocBuffer<char>(ws_size);
int64_t num_experts_per_node = num_experts / ep;
backend.prepare(num_tokens, workspace, /*expert_weights=*/nullptr, mStream->get());
backend.prepare(num_tokens, workspace, /*expert_weights=*/nullptr,
/*token_selected_experts_customized=*/nullptr, /*use_customized_router=*/false, mStream->get());
auto workspaces = backend.getProfilerWorkspaces(num_tokens, getSMVersion() >= 90 && getSMVersion() < 120);
#define GET_WS_PTR(type, name) auto* name = reinterpret_cast<type>(workspace + workspaces.at(#name).second)

View File

@ -7,6 +7,8 @@ import triton # type: ignore[import]
import tensorrt_llm.quantization.utils.fp4_utils as fp4_utils
import tensorrt_llm.quantization.utils.fp8_utils as fp8_utils
from tensorrt_llm import deep_gemm
from tensorrt_llm._torch.modules.fused_moe.routing import (
ROUTING_METHOD_TYPE_TO_CLASS, RoutingMethodType)
from tensorrt_llm._utils import get_sm_version
from tensorrt_llm.functional import AllReduceFusionOp, AllReduceStrategy
from tensorrt_llm.logger import logger
@ -30,6 +32,74 @@ def bmm_out(a: torch.Tensor, b: torch.Tensor, out: torch.Tensor) -> None:
torch.bmm(a, b, out=out)
def prepare_dummy_token_selected_experts_hook(
input: torch.Tensor,
top_k: int,
num_experts: int,
n_group: Optional[int],
topk_group: Optional[int],
routed_scaling_factor: Optional[float],
routing_method_type: int = int(RoutingMethodType.Default),
):
"""
Creates a hook function that generates dummy token_selected_experts for tuning.
Args:
input: Input tensor to determine shape and device
top_k: Number of experts per token
num_experts: Total number of experts
routing_method_type: Type of routing method to use
Returns:
A hook function that can be used with the tuner
"""
tuner = AutoTuner.get()
if not tuner.is_tuning_mode:
return lambda inputs: inputs
input_tensor = input[0]
# Get routing method
routing_cls_kwargs = {}
if routing_method_type == int(RoutingMethodType.DeepSeekV3):
routing_cls_kwargs.update({
'n_group':
n_group,
'topk_group':
topk_group,
'routed_scaling_factor':
routed_scaling_factor,
'is_fused':
False,
'callable_e_score_correction_bias':
lambda: torch.randn(
num_experts, dtype=torch.bfloat16, device=input_tensor.device)
})
routing_method = ROUTING_METHOD_TYPE_TO_CLASS[routing_method_type](
top_k=top_k, **routing_cls_kwargs)
def create_dummy_token_selected_experts(
inputs: List[torch.Tensor], ) -> List[torch.Tensor]:
input_tensor = inputs[0] # First tensor is the input
# Generate dummy routing logits with correct shape
routing_logits_for_tuner = torch.randn(input_tensor.shape[0],
num_experts,
dtype=torch.bfloat16,
device=input_tensor.device)
# Apply routing to get properly shaped token_selected_experts
topk_ids_for_tuner, topk_weights_for_tuner = routing_method.apply(
routing_logits_for_tuner)
# Replace the token_selected_experts tensor (inputs[1]) with our generated one
if len(inputs) > 1:
inputs[1] = topk_ids_for_tuner
return inputs
return create_dummy_token_selected_experts
class MoERunner(TunableRunner):
# avoid overhead of creating a new runner in forward pass
runner_dict = dict()
@ -37,7 +107,9 @@ class MoERunner(TunableRunner):
dynamic_tensor_specs=(DynamicTensorSpec(
0, 0, get_last_power_of_2_num_tokens_buckets,
last_positive_power_of_2), ),
constraint_specs=(ConstraintSpec(1, 0, lambda shapes: shapes[0][0]), ),
tune_max_num_tokens=8192,
inputs_pre_hook=None, # Will be set dynamically in fused_moe function
distributed_tuning_strategy=DistributedTuningStrategy.PARALLEL,
)
@ -82,6 +154,7 @@ class MoERunner(TunableRunner):
self.use_fused_finalize = use_fused_finalize
self.activation_type = activation_type
self.unpadded_hidden_size = unpadded_hidden_size if unpadded_hidden_size is not None else 0
self.use_customized_router = False
instance_key = (x_dtype, weight_dtype, output_dtype,
use_deepseek_fp8_block_scale, use_w4_group_scaling,
@ -126,10 +199,12 @@ class MoERunner(TunableRunner):
gemm_idx: int = 0,
tactic: int = -1,
do_preparation: bool = False,
**kwargs,
):
x, fc1_expert_weights, fc1_expert_biases, fc2_expert_weights, fc2_expert_biases = inputs
x, token_selected_experts, fc1_expert_weights, fc1_expert_biases, fc2_expert_weights, fc2_expert_biases = inputs
self.fused_moe_runner.run_gemm_profile(
x,
token_selected_experts,
fc1_expert_weights,
fc1_expert_biases,
fc2_expert_weights,
@ -148,6 +223,7 @@ class MoERunner(TunableRunner):
do_preparation,
self.activation_type,
self.unpadded_hidden_size,
self.use_customized_router,
)
@ -184,6 +260,10 @@ def fused_moe(
tuner_num_tokens: Optional[int] = None,
tuner_top_k: Optional[int] = None,
activation_type: int = int(ActivationType.Swiglu),
routing_method_type: int = int(RoutingMethodType.Default),
n_group: Optional[int] = None,
topk_group: Optional[int] = None,
routed_scaling_factor: Optional[float] = None,
unpadded_hidden_size: Optional[int] = None,
out_tensor: Optional[torch.Tensor] = None,
) -> List[torch.Tensor]:
@ -201,6 +281,18 @@ def fused_moe(
tuner_input = input
tuner_top_k = token_selected_experts.size(1)
tuning_config = MoERunner.tuning_config
tuning_config.inputs_pre_hook = prepare_dummy_token_selected_experts_hook(
tuner_input,
tuner_top_k,
fc1_expert_weights.shape[0] *
ep_size, # num_experts from weight tensor shape
n_group,
topk_group,
routed_scaling_factor,
routing_method_type,
)
# allocate workspace for profiling
moe_runner = MoERunner(
x_dtype=input.dtype,
@ -224,27 +316,30 @@ def fused_moe(
)
MoERunner.tuning_config.tune_max_num_tokens = tune_max_num_tokens
input_tensors = [
tuner_input,
token_selected_experts,
fc1_expert_weights,
fc1_expert_biases,
fc2_expert_weights,
fc2_expert_biases,
]
_, gemm_tactic_1 = tuner.choose_one(
"trtllm::fused_moe::gemm1",
[moe_runner],
MoERunner.tuning_config,
[
tuner_input, fc1_expert_weights, fc1_expert_biases,
fc2_expert_weights, fc2_expert_biases
],
input_tensors,
gemm_idx=1,
ep_size=ep_size,
)
_, gemm_tactic_2 = tuner.choose_one(
"trtllm::fused_moe::gemm2",
[moe_runner],
MoERunner.tuning_config,
[
tuner_input, fc1_expert_weights, fc1_expert_biases,
fc2_expert_weights, fc2_expert_biases
],
input_tensors,
gemm_idx=2,
ep_size=ep_size,
)
run_moe = moe_runner.fused_moe_runner.run_moe_min_latency if min_latency_mode else moe_runner.fused_moe_runner.run_moe

View File

@ -24,7 +24,7 @@ from .quantization import (
W4A8MXFP4MXFP8CutlassFusedMoEMethod, WFP4A16FusedMoEMethod,
WInt4AFP8FusedMoEMethod)
# isort: on
from .routing import BaseMoeRoutingMethod
from .routing import BaseMoeRoutingMethod, DeepSeekV3MoeRoutingMethod
class CutlassFusedMoE(MoE):
@ -433,6 +433,15 @@ class CutlassFusedMoE(MoE):
elif self.has_w4a16_mxfp4:
weight_dtype = torch.uint8
if isinstance(self.routing_method, DeepSeekV3MoeRoutingMethod):
n_group = self.routing_method.routing_impl.n_group
topk_group = self.routing_method.routing_impl.topk_group
routed_scaling_factor = self.routing_method.routing_impl.routed_scaling_factor
else:
n_group = None
topk_group = None
routed_scaling_factor = None
final_hidden_states = torch.ops.trtllm.fused_moe(
x,
token_selected_experts,
@ -465,6 +474,10 @@ class CutlassFusedMoE(MoE):
tuner_num_tokens=tuner_num_tokens,
tuner_top_k=tuner_top_k,
activation_type=self.activation_type,
routing_method_type=self.routing_method.routing_method_type,
n_group=n_group,
topk_group=topk_group,
routed_scaling_factor=routed_scaling_factor,
unpadded_hidden_size=self.unpadded_hidden_size,
out_tensor=moe_output,
)