mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
[https://nvbugs/5412562][feat] Allocate MoE workspace only when necessary (release/1.0 retargeted) (#6955)
Signed-off-by: Yilin Fan <206948969+nv-yilinf@users.noreply.github.com>
This commit is contained in:
parent
33fce8ece5
commit
7f7a301f6e
@ -308,8 +308,8 @@ public:
|
||||
std::vector<int64_t> output_shape = {num_rows, hidden_size};
|
||||
auto output = torch::empty(output_shape, input.options().dtype(mOutputDtype));
|
||||
|
||||
WorkspaceInfo workspace_info = getWorkspaceInfo(num_rows, hidden_size, inter_size, num_experts_total,
|
||||
static_cast<int>(experts_per_token), activation_type, parallelism_config, min_latency_mode);
|
||||
WorkspaceInfo const& workspace_info = getWorkspaceInfo(num_rows, hidden_size, inter_size, num_experts_total,
|
||||
static_cast<int>(experts_per_token), activation_type, parallelism_config, min_latency_mode, stream);
|
||||
|
||||
auto const quant_params = getQuantParams(num_experts_on_rank, hidden_size, inter_size, quant_scales);
|
||||
kernels::MoeMinLatencyParams min_latency_params{};
|
||||
@ -439,8 +439,8 @@ public:
|
||||
min_latency_params.experts_to_token_score = static_cast<float*>(experts_to_token_score.data_ptr());
|
||||
min_latency_params.active_expert_global_ids = static_cast<int*>(active_expert_global_ids.data_ptr());
|
||||
|
||||
WorkspaceInfo workspace_info = getWorkspaceInfo(num_rows, hidden_size, inter_size, num_experts_total,
|
||||
static_cast<int>(experts_per_token), activation_type, parallelism_config, min_latency_mode);
|
||||
WorkspaceInfo const& workspace_info = getWorkspaceInfo(num_rows, hidden_size, inter_size, num_experts_total,
|
||||
static_cast<int>(experts_per_token), activation_type, parallelism_config, min_latency_mode, stream);
|
||||
|
||||
auto const quant_params = getQuantParams(num_experts_on_rank, hidden_size, inter_size, quant_scales);
|
||||
|
||||
@ -577,6 +577,7 @@ private:
|
||||
// e.g. 16 nvfp4 elements are packed into a single int64 element
|
||||
int64_t mInnerDimMultiplier;
|
||||
char* mProfileWorkspace = nullptr;
|
||||
WorkspaceInfo workspace_info;
|
||||
|
||||
bool mUseDeepSeekFP8BlockScaling = false;
|
||||
bool mUseW4A8GroupScaling = false;
|
||||
@ -622,9 +623,9 @@ private:
|
||||
mKernelRunner->setTactic(best_gemm1_profile, best_gemm2_profile);
|
||||
}
|
||||
|
||||
WorkspaceInfo getWorkspaceInfo(int64_t const num_rows, int64_t const hidden_size, int64_t const inter_size,
|
||||
WorkspaceInfo const& getWorkspaceInfo(int64_t const num_rows, int64_t const hidden_size, int64_t const inter_size,
|
||||
int num_experts, int experts_per_token, ActivationType activation_type,
|
||||
kernels::MOEParallelismConfig const& parallelismConfig, bool min_latency_mode)
|
||||
kernels::MOEParallelismConfig const& parallelismConfig, bool min_latency_mode, cudaStream_t stream)
|
||||
{
|
||||
size_t moe_workspace_size = mKernelRunner->getWorkspaceSize(num_rows, hidden_size, inter_size, num_experts,
|
||||
experts_per_token, activation_type, parallelismConfig, /* use_lora */ false, mUseDeepSeekFP8BlockScaling,
|
||||
@ -633,15 +634,29 @@ private:
|
||||
|
||||
std::vector<size_t> workspaces{moe_workspace_size, src_to_dest_map_size};
|
||||
|
||||
size_t total_workspace_size = common::calculateTotalWorkspaceSize(workspaces.data(), workspaces.size());
|
||||
int64_t const total_workspace_size = common::calculateTotalWorkspaceSize(workspaces.data(), workspaces.size());
|
||||
|
||||
WorkspaceInfo info{};
|
||||
info.workspace = torch::empty({static_cast<long>(total_workspace_size)},
|
||||
torch::dtype(torch::kInt8).device(torch::kCUDA).requires_grad(false));
|
||||
info.src_to_dest_map
|
||||
= common::nextWorkspacePtr(static_cast<int8_t*>(info.workspace.data_ptr()), moe_workspace_size);
|
||||
bool is_capturing = tensorrt_llm::common::isCapturing(stream);
|
||||
// Always allocate workspace when capturing cuda graph to avoid illegal memory access during replay
|
||||
if (is_capturing || workspace_info.workspace.numel() < total_workspace_size)
|
||||
{
|
||||
if (is_capturing)
|
||||
{
|
||||
TLLM_LOG_DEBUG(
|
||||
"Allocating MoE workspace with %ld bytes size during cuda graph capture", total_workspace_size);
|
||||
}
|
||||
else
|
||||
{
|
||||
TLLM_LOG_DEBUG("MoE workspace size is not enough, increase the size from %ld bytes to %ld bytes",
|
||||
workspace_info.workspace.numel(), total_workspace_size);
|
||||
}
|
||||
workspace_info.workspace = torch::empty({static_cast<long>(total_workspace_size)},
|
||||
torch::dtype(torch::kInt8).device(torch::kCUDA).requires_grad(false));
|
||||
}
|
||||
workspace_info.src_to_dest_map
|
||||
= common::nextWorkspacePtr(static_cast<int8_t*>(workspace_info.workspace.data_ptr()), moe_workspace_size);
|
||||
|
||||
return info;
|
||||
return workspace_info;
|
||||
}
|
||||
|
||||
kernels::QuantParams getQuantParams(int64_t const num_experts_on_rank, int64_t const hidden_size,
|
||||
|
||||
Loading…
Reference in New Issue
Block a user