[https://nvbugs/5412562][feat] Allocate MoE workspace only when necessary (release/1.0 retargeted) (#6955)

Signed-off-by: Yilin Fan <206948969+nv-yilinf@users.noreply.github.com>
This commit is contained in:
Yilin Fan 2025-08-17 17:50:35 -07:00 committed by GitHub
parent 33fce8ece5
commit 7f7a301f6e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -308,8 +308,8 @@ public:
std::vector<int64_t> output_shape = {num_rows, hidden_size}; std::vector<int64_t> output_shape = {num_rows, hidden_size};
auto output = torch::empty(output_shape, input.options().dtype(mOutputDtype)); auto output = torch::empty(output_shape, input.options().dtype(mOutputDtype));
WorkspaceInfo workspace_info = getWorkspaceInfo(num_rows, hidden_size, inter_size, num_experts_total, WorkspaceInfo const& workspace_info = getWorkspaceInfo(num_rows, hidden_size, inter_size, num_experts_total,
static_cast<int>(experts_per_token), activation_type, parallelism_config, min_latency_mode); static_cast<int>(experts_per_token), activation_type, parallelism_config, min_latency_mode, stream);
auto const quant_params = getQuantParams(num_experts_on_rank, hidden_size, inter_size, quant_scales); auto const quant_params = getQuantParams(num_experts_on_rank, hidden_size, inter_size, quant_scales);
kernels::MoeMinLatencyParams min_latency_params{}; kernels::MoeMinLatencyParams min_latency_params{};
@ -439,8 +439,8 @@ public:
min_latency_params.experts_to_token_score = static_cast<float*>(experts_to_token_score.data_ptr()); min_latency_params.experts_to_token_score = static_cast<float*>(experts_to_token_score.data_ptr());
min_latency_params.active_expert_global_ids = static_cast<int*>(active_expert_global_ids.data_ptr()); min_latency_params.active_expert_global_ids = static_cast<int*>(active_expert_global_ids.data_ptr());
WorkspaceInfo workspace_info = getWorkspaceInfo(num_rows, hidden_size, inter_size, num_experts_total, WorkspaceInfo const& workspace_info = getWorkspaceInfo(num_rows, hidden_size, inter_size, num_experts_total,
static_cast<int>(experts_per_token), activation_type, parallelism_config, min_latency_mode); static_cast<int>(experts_per_token), activation_type, parallelism_config, min_latency_mode, stream);
auto const quant_params = getQuantParams(num_experts_on_rank, hidden_size, inter_size, quant_scales); auto const quant_params = getQuantParams(num_experts_on_rank, hidden_size, inter_size, quant_scales);
@ -577,6 +577,7 @@ private:
// e.g. 16 nvfp4 elements are packed into a single int64 element // e.g. 16 nvfp4 elements are packed into a single int64 element
int64_t mInnerDimMultiplier; int64_t mInnerDimMultiplier;
char* mProfileWorkspace = nullptr; char* mProfileWorkspace = nullptr;
WorkspaceInfo workspace_info;
bool mUseDeepSeekFP8BlockScaling = false; bool mUseDeepSeekFP8BlockScaling = false;
bool mUseW4A8GroupScaling = false; bool mUseW4A8GroupScaling = false;
@ -622,9 +623,9 @@ private:
mKernelRunner->setTactic(best_gemm1_profile, best_gemm2_profile); mKernelRunner->setTactic(best_gemm1_profile, best_gemm2_profile);
} }
WorkspaceInfo getWorkspaceInfo(int64_t const num_rows, int64_t const hidden_size, int64_t const inter_size, WorkspaceInfo const& getWorkspaceInfo(int64_t const num_rows, int64_t const hidden_size, int64_t const inter_size,
int num_experts, int experts_per_token, ActivationType activation_type, int num_experts, int experts_per_token, ActivationType activation_type,
kernels::MOEParallelismConfig const& parallelismConfig, bool min_latency_mode) kernels::MOEParallelismConfig const& parallelismConfig, bool min_latency_mode, cudaStream_t stream)
{ {
size_t moe_workspace_size = mKernelRunner->getWorkspaceSize(num_rows, hidden_size, inter_size, num_experts, size_t moe_workspace_size = mKernelRunner->getWorkspaceSize(num_rows, hidden_size, inter_size, num_experts,
experts_per_token, activation_type, parallelismConfig, /* use_lora */ false, mUseDeepSeekFP8BlockScaling, experts_per_token, activation_type, parallelismConfig, /* use_lora */ false, mUseDeepSeekFP8BlockScaling,
@ -633,15 +634,29 @@ private:
std::vector<size_t> workspaces{moe_workspace_size, src_to_dest_map_size}; std::vector<size_t> workspaces{moe_workspace_size, src_to_dest_map_size};
size_t total_workspace_size = common::calculateTotalWorkspaceSize(workspaces.data(), workspaces.size()); int64_t const total_workspace_size = common::calculateTotalWorkspaceSize(workspaces.data(), workspaces.size());
WorkspaceInfo info{}; bool is_capturing = tensorrt_llm::common::isCapturing(stream);
info.workspace = torch::empty({static_cast<long>(total_workspace_size)}, // Always allocate workspace when capturing cuda graph to avoid illegal memory access during replay
torch::dtype(torch::kInt8).device(torch::kCUDA).requires_grad(false)); if (is_capturing || workspace_info.workspace.numel() < total_workspace_size)
info.src_to_dest_map {
= common::nextWorkspacePtr(static_cast<int8_t*>(info.workspace.data_ptr()), moe_workspace_size); if (is_capturing)
{
TLLM_LOG_DEBUG(
"Allocating MoE workspace with %ld bytes size during cuda graph capture", total_workspace_size);
}
else
{
TLLM_LOG_DEBUG("MoE workspace size is not enough, increase the size from %ld bytes to %ld bytes",
workspace_info.workspace.numel(), total_workspace_size);
}
workspace_info.workspace = torch::empty({static_cast<long>(total_workspace_size)},
torch::dtype(torch::kInt8).device(torch::kCUDA).requires_grad(false));
}
workspace_info.src_to_dest_map
= common::nextWorkspacePtr(static_cast<int8_t*>(workspace_info.workspace.data_ptr()), moe_workspace_size);
return info; return workspace_info;
} }
kernels::QuantParams getQuantParams(int64_t const num_experts_on_rank, int64_t const hidden_size, kernels::QuantParams getQuantParams(int64_t const num_experts_on_rank, int64_t const hidden_size,