Mirror of https://github.com/NVIDIA/TensorRT-LLM.git, synced 2026-01-14 06:27:45 +08:00
[https://nvbugs/5412562][feat] Allocate MoE workspace only when necessary (release/1.0 retargeted) (#6955)
Signed-off-by: Yilin Fan <206948969+nv-yilinf@users.noreply.github.com>
parent 33fce8ece5
commit 7f7a301f6e
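The change in brief: the MoE torch op used to allocate a fresh workspace tensor on every forward call; the hunks below instead cache the workspace as a class member (`workspace_info`) and reallocate only when the cached buffer is too small, or when a CUDA graph is being captured on the stream, where a fresh allocation keeps the captured pointers valid on replay. A minimal standalone sketch of this grow-only caching pattern (the `ScratchCache` type and its names are illustrative, not the file's actual code):

#include <torch/torch.h>

// Illustrative grow-only scratch cache; mirrors the pattern the diff below
// applies to WorkspaceInfo, with hypothetical names.
struct ScratchCache
{
    torch::Tensor buffer; // starts empty, numel() == 0

    torch::Tensor& get(int64_t bytes, bool is_capturing)
    {
        // Reallocate when the cached buffer is too small. During CUDA graph
        // capture, always allocate: the graph bakes in device pointers, so
        // replay must not depend on a buffer a later call may replace.
        if (is_capturing || buffer.numel() < bytes)
        {
            buffer = torch::empty(
                {bytes}, torch::dtype(torch::kInt8).device(torch::kCUDA).requires_grad(false));
        }
        return buffer;
    }
};

Caching the buffer on the object is also what makes the new `WorkspaceInfo const&` return type safe: the returned reference now outlives the call instead of dangling off a stack local.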
@@ -308,8 +308,8 @@ public:
         std::vector<int64_t> output_shape = {num_rows, hidden_size};
         auto output = torch::empty(output_shape, input.options().dtype(mOutputDtype));

-        WorkspaceInfo workspace_info = getWorkspaceInfo(num_rows, hidden_size, inter_size, num_experts_total,
-            static_cast<int>(experts_per_token), activation_type, parallelism_config, min_latency_mode);
+        WorkspaceInfo const& workspace_info = getWorkspaceInfo(num_rows, hidden_size, inter_size, num_experts_total,
+            static_cast<int>(experts_per_token), activation_type, parallelism_config, min_latency_mode, stream);

         auto const quant_params = getQuantParams(num_experts_on_rank, hidden_size, inter_size, quant_scales);
         kernels::MoeMinLatencyParams min_latency_params{};
@@ -439,8 +439,8 @@ public:
         min_latency_params.experts_to_token_score = static_cast<float*>(experts_to_token_score.data_ptr());
         min_latency_params.active_expert_global_ids = static_cast<int*>(active_expert_global_ids.data_ptr());

-        WorkspaceInfo workspace_info = getWorkspaceInfo(num_rows, hidden_size, inter_size, num_experts_total,
-            static_cast<int>(experts_per_token), activation_type, parallelism_config, min_latency_mode);
+        WorkspaceInfo const& workspace_info = getWorkspaceInfo(num_rows, hidden_size, inter_size, num_experts_total,
+            static_cast<int>(experts_per_token), activation_type, parallelism_config, min_latency_mode, stream);

         auto const quant_params = getQuantParams(num_experts_on_rank, hidden_size, inter_size, quant_scales);

@@ -577,6 +577,7 @@ private:
     // e.g. 16 nvfp4 elements are packed into a single int64 element
     int64_t mInnerDimMultiplier;
     char* mProfileWorkspace = nullptr;
+    WorkspaceInfo workspace_info;

     bool mUseDeepSeekFP8BlockScaling = false;
     bool mUseW4A8GroupScaling = false;
@@ -622,9 +623,9 @@ private:
         mKernelRunner->setTactic(best_gemm1_profile, best_gemm2_profile);
     }

-    WorkspaceInfo getWorkspaceInfo(int64_t const num_rows, int64_t const hidden_size, int64_t const inter_size,
+    WorkspaceInfo const& getWorkspaceInfo(int64_t const num_rows, int64_t const hidden_size, int64_t const inter_size,
         int num_experts, int experts_per_token, ActivationType activation_type,
-        kernels::MOEParallelismConfig const& parallelismConfig, bool min_latency_mode)
+        kernels::MOEParallelismConfig const& parallelismConfig, bool min_latency_mode, cudaStream_t stream)
     {
         size_t moe_workspace_size = mKernelRunner->getWorkspaceSize(num_rows, hidden_size, inter_size, num_experts,
             experts_per_token, activation_type, parallelismConfig, /* use_lora */ false, mUseDeepSeekFP8BlockScaling,
@@ -633,15 +634,29 @@ private:

         std::vector<size_t> workspaces{moe_workspace_size, src_to_dest_map_size};

-        size_t total_workspace_size = common::calculateTotalWorkspaceSize(workspaces.data(), workspaces.size());
+        int64_t const total_workspace_size = common::calculateTotalWorkspaceSize(workspaces.data(), workspaces.size());

-        WorkspaceInfo info{};
-        info.workspace = torch::empty({static_cast<long>(total_workspace_size)},
-            torch::dtype(torch::kInt8).device(torch::kCUDA).requires_grad(false));
-        info.src_to_dest_map
-            = common::nextWorkspacePtr(static_cast<int8_t*>(info.workspace.data_ptr()), moe_workspace_size);
+        bool is_capturing = tensorrt_llm::common::isCapturing(stream);
+        // Always allocate workspace when capturing cuda graph to avoid illegal memory access during replay
+        if (is_capturing || workspace_info.workspace.numel() < total_workspace_size)
+        {
+            if (is_capturing)
+            {
+                TLLM_LOG_DEBUG(
+                    "Allocating MoE workspace with %ld bytes size during cuda graph capture", total_workspace_size);
+            }
+            else
+            {
+                TLLM_LOG_DEBUG("MoE workspace size is not enough, increase the size from %ld bytes to %ld bytes",
+                    workspace_info.workspace.numel(), total_workspace_size);
+            }
+            workspace_info.workspace = torch::empty({static_cast<long>(total_workspace_size)},
+                torch::dtype(torch::kInt8).device(torch::kCUDA).requires_grad(false));
+        }
+        workspace_info.src_to_dest_map
+            = common::nextWorkspacePtr(static_cast<int8_t*>(workspace_info.workspace.data_ptr()), moe_workspace_size);

-        return info;
+        return workspace_info;
     }

     kernels::QuantParams getQuantParams(int64_t const num_experts_on_rank, int64_t const hidden_size,
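The new `stream` parameter feeds `tensorrt_llm::common::isCapturing`, whose definition is outside this diff. Assuming it is a thin wrapper over the CUDA runtime's stream-capture query, it would look roughly like this (a sketch, not the repository's actual helper):

#include <cuda_runtime_api.h>

// Sketch of a capture check built on cudaStreamIsCapturing, which reports
// cudaStreamCaptureStatusNone / Active / Invalidated for a stream.
bool isCapturing(cudaStream_t stream)
{
    cudaStreamCaptureStatus status = cudaStreamCaptureStatusNone;
    if (cudaStreamIsCapturing(stream, &status) != cudaSuccess)
    {
        return false; // treat a failed query as "not capturing"
    }
    return status == cudaStreamCaptureStatusActive;
}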