Merge 5c376ea7b3 into 6df2c8a074 (commit 6eb792a901)
@@ -978,7 +978,9 @@ public:
     }
     }
 
-    void prepare(int num_tokens, char* workspace, void const* expert_weights, cudaStream_t stream);
+    void prepare(int num_tokens, char* workspace, void const* expert_weights,
+        void const* token_selected_experts_customized = nullptr, bool use_customized_router = false,
+        cudaStream_t stream = nullptr);
 
     std::map<std::string, std::pair<size_t, size_t>> getProfilerWorkspaces(int maxM, bool is_tma_ws);
     size_t getWorkspaceSize(int maxM);
@@ -1002,6 +1004,7 @@ public:
     bool mEnableAlltoall = false;
 
     int mSampleIndex = 0;
+    bool mIsCustomizedRouter = false;
 
     nvinfer1::DataType mDType{};
     nvinfer1::DataType mWType{};
@@ -1024,8 +1027,9 @@ public:
    TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType mScalingType{};

private:
-    void prepareRouting(int num_tokens, char* workspace, cudaStream_t stream);
    void prepareQuantParams(int num_tokens, char* workspace, cudaStream_t stream);
+    void prepareRouting(int num_tokens, char* workspace, void const* token_selected_experts_customized,
+        bool use_customized_router, cudaStream_t stream);
    void prepareTmaWsInputs(int num_tokens, char* workspace, void const* expert_weights,
        TmaWarpSpecializedGroupedGemmInput::EpilogueFusion fusion, bool swap_ab, cudaStream_t stream);
};
@@ -4472,7 +4472,8 @@ std::map<std::string, std::pair<size_t, size_t>> GemmProfilerBackend::getProfile
    return out_map;
}

-void GemmProfilerBackend::prepareRouting(int num_tokens, char* workspace_ptr_char, cudaStream_t stream)
+void GemmProfilerBackend::prepareRouting(int num_tokens, char* workspace_ptr_char,
+    void const* token_selected_experts_customized, bool use_customized_router, cudaStream_t stream)
{
    auto workspaces = getProfilerWorkspaces(num_tokens, mSM >= 90);
#define GET_WS_PTR_BASE(type, name) \
@@ -4513,10 +4514,19 @@ void GemmProfilerBackend::prepareRouting(int num_tokens, char* workspace_ptr_cha
    int const start_expert_id = mNumExpertsPerNode * mParallelismConfig.ep_rank;

    uint32_t num_threads = 256;
-    dim3 grid_dim{(num_tokens + num_threads - 1) / num_threads, NUM_ROUTING_SAMPLES, 1};
-    prepareFakeRouterBuffers<<<grid_dim, num_threads, 0, stream>>>(
-        token_selected_experts_base, num_tokens, mK, mNumExperts);
-    sync_check_cuda_error(stream);
+    if (use_customized_router)
+    {
+        // copy token selected experts to token_selected_experts_base
+        cudaMemcpyAsync(token_selected_experts_base, token_selected_experts_customized,
+            num_tokens * mK * sizeof(int), cudaMemcpyDeviceToDevice, stream);
+    }
+    else
+    {
+        dim3 grid_dim{(num_tokens + num_threads - 1) / num_threads, NUM_ROUTING_SAMPLES, 1};
+        prepareFakeRouterBuffers<<<grid_dim, num_threads, 0, stream>>>(
+            token_selected_experts_base, num_tokens, mK, mNumExperts);
+        sync_check_cuda_error(stream);
+    }

    for (int64_t i = 0; i < NUM_ROUTING_SAMPLES; i++)
    {
@@ -4726,15 +4736,16 @@ void GemmProfilerBackend::prepareTmaWsInputs(int num_tokens, char* workspace_ptr
    }
}

-void GemmProfilerBackend::prepare(
-    int num_tokens, char* workspace_ptr_char, void const* expert_weights, cudaStream_t stream)
+void GemmProfilerBackend::prepare(int num_tokens, char* workspace_ptr_char, void const* expert_weights,
+    void const* token_selected_experts_customized, bool use_customized_router, cudaStream_t stream)
{
    mSampleIndex = 0;
+    mIsCustomizedRouter = use_customized_router;

    auto workspace_size = getWorkspaceSize(num_tokens);
    populateRandomBuffer(workspace_ptr_char, workspace_size, stream);

-    prepareRouting(num_tokens, workspace_ptr_char, stream);
+    prepareRouting(num_tokens, workspace_ptr_char, token_selected_experts_customized, use_customized_router, stream);
    prepareQuantParams(num_tokens, workspace_ptr_char, stream);
    for (auto fusion : {TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE,
        TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::FINALIZE})
@@ -4762,7 +4773,7 @@ void GemmProfilerBackend::runProfiler(int original_num_tokens, Config const& tac
    int64_t expanded_num_tokens = original_num_tokens * mK;
    int64_t num_experts_per_node = mNumExpertsPerNode;

-    mSampleIndex = (mSampleIndex + 1) % NUM_ROUTING_SAMPLES;
+    mSampleIndex = mIsCustomizedRouter ? 0 : (mSampleIndex + 1) % NUM_ROUTING_SAMPLES;

    auto workspaces = getProfilerWorkspaces(original_num_tokens, tactic.is_tma_warp_specialized);

@@ -1287,7 +1287,8 @@ void MixtureOfExpertsGemmProfiler::initTmpData(
    int m, int n, int k, char* workspace, size_t ws_size, cudaStream_t stream)
{
    checkInit();
-    backend.prepare(m, workspace, /*expert_weights*/ nullptr, stream);
+    backend.prepare(m, workspace, /*expert_weights*/ nullptr, /*token_selected_experts_customized*/ nullptr,
+        /*use_customized_router*/ false, stream);
}

void MixtureOfExpertsGemmProfiler::checkInit()
@@ -666,13 +666,13 @@ public:
    }

    // TODO Update this to be able to tell if we are profiling swiglu bias
-    void runGemmProfile(torch::Tensor const& input, torch::Tensor const& fc1_expert_weights,
-        torch::optional<torch::Tensor> const& fc1_expert_biases, torch::Tensor const& fc2_expert_weights,
-        torch::optional<torch::Tensor> const& fc2_expert_biases, int64_t const top_k, int64_t const tp_size,
-        int64_t const tp_rank, int64_t const ep_size, int64_t const ep_rank, int64_t const cluster_size,
-        int64_t const cluster_rank, bool const enable_alltoall, bool const min_latency_mode, int64_t const gemm_idx,
-        int64_t const profile_id, bool const do_preparation, int64_t const activation_type_int,
-        int64_t const unpadded_hidden_size)
+    void runGemmProfile(torch::Tensor const& input, torch::optional<torch::Tensor> const& token_final_scales,
+        torch::Tensor const& fc1_expert_weights, torch::optional<torch::Tensor> const& fc1_expert_biases,
+        torch::Tensor const& fc2_expert_weights, torch::optional<torch::Tensor> const& fc2_expert_biases,
+        int64_t const top_k, int64_t const tp_size, int64_t const tp_rank, int64_t const ep_size, int64_t const ep_rank,
+        int64_t const cluster_size, int64_t const cluster_rank, bool const enable_alltoall, bool const min_latency_mode,
+        int64_t const gemm_idx, int64_t const profile_id, bool const do_preparation, int64_t const activation_type_int,
+        int64_t const unpadded_hidden_size, bool const use_customized_router)
    {
        std::lock_guard<std::mutex> lock(mMutex);

@@ -752,7 +752,10 @@ public:
            auto const cu_malloc_status = cudaMalloc(&mProfileWorkspace, profile_workspace_size);
            TORCH_CHECK(cu_malloc_status == cudaSuccess, "Can't allocate profile workspace for MoE GEMM profile.");

-            mProfiler->prepare(num_rows, mProfileWorkspace, expert_weights_ptr, stream);
+            void const* token_selected_experts_customized
+                = token_final_scales.has_value() ? token_final_scales.value().const_data_ptr() : nullptr;
+            mProfiler->prepare(num_rows, mProfileWorkspace, expert_weights_ptr, token_selected_experts_customized,
+                use_customized_router, stream);
        }

        // Profile specific tactic. Assuming at least one preparation phase has been executed already.
@@ -2569,7 +2569,8 @@ TYPED_TEST(MixtureOfExpertsTest, RunProfiler)

    for (int64_t num_tokens : {1, 128})
    {
-        backend.prepare(num_tokens, workspace, /*expert_weights=*/nullptr, this->mStream->get());
+        backend.prepare(num_tokens, workspace, /*expert_weights=*/nullptr,
+            /*token_selected_experts_customized=*/nullptr, /*use_customized_router=*/false, this->mStream->get());
        for (auto const& tactic : this->getAllTileConfigsToTest())
        {
            backend.runProfiler(num_tokens,
@@ -2616,7 +2617,8 @@ TEST_F(MixtureOfExpertsProfilerTest, TestGeneratedProfilerDistribution)
    auto workspace = this->allocBuffer<char>(ws_size);
    int64_t num_experts_per_node = num_experts / ep;

-    backend.prepare(num_tokens, workspace, /*expert_weights=*/nullptr, mStream->get());
+    backend.prepare(num_tokens, workspace, /*expert_weights=*/nullptr,
+        /*token_selected_experts_customized=*/nullptr, /*use_customized_router=*/false, mStream->get());

    auto workspaces = backend.getProfilerWorkspaces(num_tokens, getSMVersion() >= 90 && getSMVersion() < 120);
#define GET_WS_PTR(type, name) auto* name = reinterpret_cast<type>(workspace + workspaces.at(#name).second)
@@ -7,6 +7,8 @@ import triton  # type: ignore[import]
 import tensorrt_llm.quantization.utils.fp4_utils as fp4_utils
 import tensorrt_llm.quantization.utils.fp8_utils as fp8_utils
 from tensorrt_llm import deep_gemm
+from tensorrt_llm._torch.modules.fused_moe.routing import (
+    ROUTING_METHOD_TYPE_TO_CLASS, RoutingMethodType)
 from tensorrt_llm._utils import get_sm_version
 from tensorrt_llm.functional import AllReduceFusionOp, AllReduceStrategy
 from tensorrt_llm.logger import logger
@@ -30,6 +32,74 @@ def bmm_out(a: torch.Tensor, b: torch.Tensor, out: torch.Tensor) -> None:
     torch.bmm(a, b, out=out)
 
 
+def prepare_dummy_token_selected_experts_hook(
+    input: torch.Tensor,
+    top_k: int,
+    num_experts: int,
+    n_group: Optional[int],
+    topk_group: Optional[int],
+    routed_scaling_factor: Optional[float],
+    routing_method_type: int = int(RoutingMethodType.Default),
+):
+    """
+    Creates a hook function that generates dummy token_selected_experts for tuning.
+
+    Args:
+        input: Input tensor to determine shape and device
+        top_k: Number of experts per token
+        num_experts: Total number of experts
+        routing_method_type: Type of routing method to use
+
+    Returns:
+        A hook function that can be used with the tuner
+    """
+    tuner = AutoTuner.get()
+    if not tuner.is_tuning_mode:
+        return lambda inputs: inputs
+
+    input_tensor = input[0]
+
+    # Get routing method
+    routing_cls_kwargs = {}
+    if routing_method_type == int(RoutingMethodType.DeepSeekV3):
+        routing_cls_kwargs.update({
+            'n_group':
+            n_group,
+            'topk_group':
+            topk_group,
+            'routed_scaling_factor':
+            routed_scaling_factor,
+            'is_fused':
+            False,
+            'callable_e_score_correction_bias':
+            lambda: torch.randn(
+                num_experts, dtype=torch.bfloat16, device=input_tensor.device)
+        })
+    routing_method = ROUTING_METHOD_TYPE_TO_CLASS[routing_method_type](
+        top_k=top_k, **routing_cls_kwargs)
+
+    def create_dummy_token_selected_experts(
+            inputs: List[torch.Tensor], ) -> List[torch.Tensor]:
+        input_tensor = inputs[0]  # First tensor is the input
+        # Generate dummy routing logits with correct shape
+        routing_logits_for_tuner = torch.randn(input_tensor.shape[0],
+                                               num_experts,
+                                               dtype=torch.bfloat16,
+                                               device=input_tensor.device)
+
+        # Apply routing to get properly shaped token_selected_experts
+        topk_ids_for_tuner, topk_weights_for_tuner = routing_method.apply(
+            routing_logits_for_tuner)
+
+        # Replace the token_selected_experts tensor (inputs[1]) with our generated one
+        if len(inputs) > 1:
+            inputs[1] = topk_ids_for_tuner
+
+        return inputs
+
+    return create_dummy_token_selected_experts
+
+
 class MoERunner(TunableRunner):
     # avoid overhead of creating a new runner in forward pass
     runner_dict = dict()
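
For orientation, here is a minimal, hypothetical usage sketch of the new hook factory; the tensor shapes, the placeholder_experts name, and the direct call below are illustrative assumptions, not part of the diff. In the real flow the AutoTuner invokes the returned hook on its shape-bucketed input list, and outside tuning mode the factory simply returns an identity function.

    # Illustration only: construct the hook for the default TopK-style router
    # and apply it to an input list whose second entry is the
    # token_selected_experts tensor that tuning will overwrite.
    import torch

    hidden_states = torch.randn(16, 4096, dtype=torch.bfloat16, device="cuda")
    placeholder_experts = torch.zeros(16, 8, dtype=torch.int32, device="cuda")

    hook = prepare_dummy_token_selected_experts_hook(
        hidden_states,  # the factory only reads device info from input[0]
        top_k=8,
        num_experts=64,
        n_group=None,
        topk_group=None,
        routed_scaling_factor=None,
        routing_method_type=int(RoutingMethodType.Default),
    )

    # In tuning mode, inputs[1] is replaced by expert ids produced by the
    # selected routing method applied to random bfloat16 logits.
    tuner_inputs = hook([hidden_states, placeholder_experts])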
@@ -37,7 +107,9 @@ class MoERunner(TunableRunner):
         dynamic_tensor_specs=(DynamicTensorSpec(
             0, 0, get_last_power_of_2_num_tokens_buckets,
             last_positive_power_of_2), ),
+        constraint_specs=(ConstraintSpec(1, 0, lambda shapes: shapes[0][0]), ),
         tune_max_num_tokens=8192,
+        inputs_pre_hook=None,  # Will be set dynamically in fused_moe function
         distributed_tuning_strategy=DistributedTuningStrategy.PARALLEL,
     )
 
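
A brief sketch of how the new constraint appears to behave, assuming ConstraintSpec(tensor_idx, dim_idx, fn_of_shapes) semantics inferred from the adjacent DynamicTensorSpec usage (an assumption, not stated in this diff): whatever token count the tuner picks for input 0, input 1 (token_selected_experts) is given the same number of rows.

    # Illustration only; ConstraintSpec semantics are assumed, not taken from this diff.
    constraint = lambda shapes: shapes[0][0]       # dim 0 of input 0 (num tokens)

    profiled_shapes = [(4096, 7168), (4096, 8)]    # [hidden_states, token_selected_experts]
    assert constraint(profiled_shapes) == 4096     # token_selected_experts gets 4096 rows too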
@@ -82,6 +154,7 @@ class MoERunner(TunableRunner):
         self.use_fused_finalize = use_fused_finalize
         self.activation_type = activation_type
         self.unpadded_hidden_size = unpadded_hidden_size if unpadded_hidden_size is not None else 0
+        self.use_customized_router = False
 
         instance_key = (x_dtype, weight_dtype, output_dtype,
                         use_deepseek_fp8_block_scale, use_w4_group_scaling,
@@ -126,10 +199,12 @@ class MoERunner(TunableRunner):
         gemm_idx: int = 0,
         tactic: int = -1,
         do_preparation: bool = False,
+        **kwargs,
     ):
-        x, fc1_expert_weights, fc1_expert_biases, fc2_expert_weights, fc2_expert_biases = inputs
+        x, token_selected_experts, fc1_expert_weights, fc1_expert_biases, fc2_expert_weights, fc2_expert_biases = inputs
         self.fused_moe_runner.run_gemm_profile(
             x,
+            token_selected_experts,
             fc1_expert_weights,
             fc1_expert_biases,
             fc2_expert_weights,
@@ -148,6 +223,7 @@ class MoERunner(TunableRunner):
             do_preparation,
             self.activation_type,
             self.unpadded_hidden_size,
+            self.use_customized_router,
         )
 
 
@@ -184,6 +260,10 @@ def fused_moe(
     tuner_num_tokens: Optional[int] = None,
     tuner_top_k: Optional[int] = None,
     activation_type: int = int(ActivationType.Swiglu),
+    routing_method_type: int = int(RoutingMethodType.Default),
+    n_group: Optional[int] = None,
+    topk_group: Optional[int] = None,
+    routed_scaling_factor: Optional[float] = None,
     unpadded_hidden_size: Optional[int] = None,
     out_tensor: Optional[torch.Tensor] = None,
 ) -> List[torch.Tensor]:
@@ -201,6 +281,18 @@
         tuner_input = input
         tuner_top_k = token_selected_experts.size(1)
 
+    tuning_config = MoERunner.tuning_config
+    tuning_config.inputs_pre_hook = prepare_dummy_token_selected_experts_hook(
+        tuner_input,
+        tuner_top_k,
+        fc1_expert_weights.shape[0] *
+        ep_size,  # num_experts from weight tensor shape
+        n_group,
+        topk_group,
+        routed_scaling_factor,
+        routing_method_type,
+    )
+
     # allocate workspace for profiling
     moe_runner = MoERunner(
         x_dtype=input.dtype,
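
For context, a small worked computation for the num_experts argument above (the values are made up): under expert parallelism the local fc1 weight tensor only holds the experts owned by this rank, so the global expert count handed to the hook is recovered by multiplying the local count by ep_size.

    # Hypothetical numbers, for illustration only.
    local_num_experts = 32                        # fc1_expert_weights.shape[0] on this rank
    ep_size = 8                                   # expert-parallel world size
    num_experts = local_num_experts * ep_size     # 256 experts passed to the routing method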
@@ -224,27 +316,30 @@
     )
 
     MoERunner.tuning_config.tune_max_num_tokens = tune_max_num_tokens
 
+    input_tensors = [
+        tuner_input,
+        token_selected_experts,
+        fc1_expert_weights,
+        fc1_expert_biases,
+        fc2_expert_weights,
+        fc2_expert_biases,
+    ]
     _, gemm_tactic_1 = tuner.choose_one(
         "trtllm::fused_moe::gemm1",
         [moe_runner],
         MoERunner.tuning_config,
-        [
-            tuner_input, fc1_expert_weights, fc1_expert_biases,
-            fc2_expert_weights, fc2_expert_biases
-        ],
+        input_tensors,
         gemm_idx=1,
         ep_size=ep_size,
     )
 
     _, gemm_tactic_2 = tuner.choose_one(
         "trtllm::fused_moe::gemm2",
         [moe_runner],
         MoERunner.tuning_config,
-        [
-            tuner_input, fc1_expert_weights, fc1_expert_biases,
-            fc2_expert_weights, fc2_expert_biases
-        ],
+        input_tensors,
         gemm_idx=2,
         ep_size=ep_size,
     )
 
     run_moe = moe_runner.fused_moe_runner.run_moe_min_latency if min_latency_mode else moe_runner.fused_moe_runner.run_moe
 
@@ -24,7 +24,7 @@ from .quantization import (
     W4A8MXFP4MXFP8CutlassFusedMoEMethod, WFP4A16FusedMoEMethod,
     WInt4AFP8FusedMoEMethod)
 # isort: on
-from .routing import BaseMoeRoutingMethod
+from .routing import BaseMoeRoutingMethod, DeepSeekV3MoeRoutingMethod
 
 
 class CutlassFusedMoE(MoE):
@@ -433,6 +433,15 @@ class CutlassFusedMoE(MoE):
         elif self.has_w4a16_mxfp4:
             weight_dtype = torch.uint8
 
+        if isinstance(self.routing_method, DeepSeekV3MoeRoutingMethod):
+            n_group = self.routing_method.routing_impl.n_group
+            topk_group = self.routing_method.routing_impl.topk_group
+            routed_scaling_factor = self.routing_method.routing_impl.routed_scaling_factor
+        else:
+            n_group = None
+            topk_group = None
+            routed_scaling_factor = None
+
         final_hidden_states = torch.ops.trtllm.fused_moe(
             x,
             token_selected_experts,
@@ -465,6 +474,10 @@ class CutlassFusedMoE(MoE):
             tuner_num_tokens=tuner_num_tokens,
             tuner_top_k=tuner_top_k,
             activation_type=self.activation_type,
+            routing_method_type=self.routing_method.routing_method_type,
+            n_group=n_group,
+            topk_group=topk_group,
+            routed_scaling_factor=routed_scaling_factor,
             unpadded_hidden_size=self.unpadded_hidden_size,
             out_tensor=moe_output,
         )