Mirror of https://github.com/NVIDIA/TensorRT-LLM.git, synced 2026-01-13 22:18:36 +08:00
fix: Reduce memory usage in fused moe op associated with AutoTuning and fix moe fallback issue. (#3793)
* Reduce memory usage in the fused moe op associated with AutoTuning:
  * Replace the pre-defined bucket-size strategy with a generating function based on tune_max_num_tokens.
  * Add workspace-freeing logic to the min_latency_mode fused moe path.
* Fix fused_moe fallback issue (#3652): min_latency_mode is only set to False during the warmup phase, so when it becomes True during inference, all tactics fall back to the default one, causing a perf regression.

Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com>
parent e0691e6e27
commit ab2f663101
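To illustrate the bucket-generation change described in the commit message, here is a minimal sketch of a generating function driven by tune_max_num_tokens. It is an assumption-based reconstruction, not code from the commit: the name gen_power_of_2_buckets is hypothetical, and only the descending power-of-2 loop visible in the get_power_of_2_num_tokens_buckets hunk below is taken from the diff.

from typing import Tuple


def gen_power_of_2_buckets(tune_max_num_tokens: int) -> Tuple[int, ...]:
    # Descending power-of-2 token buckets, largest bucket first.
    # tune_max_num_tokens is assumed here to already be a power of 2.
    num_token_buckets = []
    m = tune_max_num_tokens
    while m >= 1:
        num_token_buckets.append(m)
        m //= 2
    return tuple(num_token_buckets)


# gen_power_of_2_buckets(8192) ->
# (8192, 4096, 2048, 1024, 512, 256, 128, 64, 32, 16, 8, 4, 2, 1),
# which matches the bucket tuple added to the TuningConfig in the diff below.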
@@ -161,17 +161,11 @@ public:
         torch::optional<torch::Tensor> input_sf, int64_t const tp_size, int64_t const tp_rank, int64_t const ep_size,
         int64_t const ep_rank, bool min_latency_mode, torch::optional<c10::ArrayRef<int64_t>> profile_ids)
     {
-        // Free the profile workspace to save memory
-        if (mProfileWorkspace != nullptr)
-        {
-            auto const cu_free_status = cudaFree(mProfileWorkspace);
-            TORCH_CHECK(
-                cu_free_status == cudaSuccess, "Can't free profile workspace for MoE GEMM profile before runMoe.");
-            mProfileWorkspace = nullptr;
-        }
-
         std::lock_guard<std::mutex> lock(mMutex);
 
+        // Free the profile workspace to save memory
+        freeProfileWorkspace();
+
         CHECK_INPUT(input, mActivationDtype)
         CHECK_INPUT(token_selected_experts, at::ScalarType::Int)
         if (token_final_scales)
@@ -248,6 +242,9 @@ public:
     {
         std::lock_guard<std::mutex> lock(mMutex);
 
+        // Free the profile workspace to save memory
+        freeProfileWorkspace();
+
         CHECK_INPUT(input, mActivationDtype)
         CHECK_INPUT(token_selected_experts, at::ScalarType::Int)
         if (token_final_scales)
@@ -375,13 +372,7 @@ public:
             hidden_size, inter_size, GROUP_SIZE, tensorrt_llm::ActivationType::Swiglu, USE_BIAS, USE_LORA,
             min_latency_mode, parallelism_config);
 
-        if (mProfileWorkspace != nullptr)
-        {
-            auto const cu_free_status = cudaFree(mProfileWorkspace);
-            TORCH_CHECK(cu_free_status == cudaSuccess,
-                "Can't free profile workspace for MoE GEMM profile during memory reallocation.");
-            mProfileWorkspace = nullptr;
-        }
+        freeProfileWorkspace();
         size_t profile_workspace_size = mProfiler->getWorkspaceSize(num_rows);
         auto const cu_malloc_status = cudaMalloc(&mProfileWorkspace, profile_workspace_size);
         TORCH_CHECK(cu_malloc_status == cudaSuccess, "Can't allocate profile workspace for MoE GEMM profile.");
@@ -416,6 +407,17 @@ private:
     using Profile = tensorrt_llm::cutlass_extensions::CutlassGemmConfig;
     std::vector<Profile> mAllProfiles;
 
+    void freeProfileWorkspace()
+    {
+        if (mProfileWorkspace != nullptr)
+        {
+            auto const cu_free_status = cudaFree(mProfileWorkspace);
+            TORCH_CHECK(cu_free_status == cudaSuccess,
+                "Can't free profile workspace for MoE GEMM profile during memory reallocation.");
+            mProfileWorkspace = nullptr;
+        }
+    }
+
     void setRunnerProfiles(torch::optional<c10::ArrayRef<int64_t>> profile_ids)
     {
         if (mUseFp8BlockScaling)
@@ -120,13 +120,15 @@ def fused_moe(
     # TODO: only profile for min_latency_mode = False due to the error in the moe_kernels
     tuning_config = TuningConfig(dynamic_tensors=(
         # input, dim 0, all valid buckets, map a seq_len to power of 2 bucket index
-        (0, 0, ((16384, 8192, 4096, 2048, 1024, 512, 256, 128, 64, 32, 16, 8, 4,
-                 2, 1), next_positive_power_of_2)),
+        (0, 0, ((8192, 4096, 2048, 1024, 512, 256, 128, 64, 32, 16, 8, 4, 2, 1),
+                next_positive_power_of_2)),
         # min_latency_tensor, dim 0, (0 for False, 1 for True), map to it self
         (2, 0, ((0, ), lambda x: x)),
     ))
 
-    min_latency_tensor = torch.empty(1) if min_latency_mode else torch.empty(0)
+    # TODO: set min_latency_mode always to False due to the error in the moe_kernels
+    min_latency_tensor = torch.empty(0)
 
     # allocate workspace for profiling
     moe_runner = MoERunner(
         x_dtype=input.dtype,
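The bucket mapping used by the tuner above can be read as: round each observed token count up to the next power of two before selecting a profiled bucket. Below is a minimal sketch of that rounding rule, assuming next_positive_power_of_2 behaves as its name suggests; the real implementation lives elsewhere in the repository and may differ.

def next_positive_power_of_2(x: int) -> int:
    # Smallest power of two >= x, with values below 1 mapped to 1 (illustrative assumption).
    if x < 1:
        return 1
    return 1 << (x - 1).bit_length()


# Example: a batch of 300 tokens would be profiled with the 512-token bucket.
assert next_positive_power_of_2(300) == 512
assert next_positive_power_of_2(512) == 512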
@@ -197,7 +197,7 @@ def get_power_of_2_num_tokens_buckets(max_num_tokens) -> List[int]:
         num_token_buckets.append(m)
         m //= 2
 
-    return num_token_buckets
+    return tuple(num_token_buckets)
 
 
 def get_last_power_of_2_num_tokens_buckets(max_num_tokens) -> List[int]:
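One effect of the list-to-tuple change in the last hunk is that the bucket collection becomes hashable, which matters if it is ever used as part of a cache or lookup key; whether that is the actual motivation here is an assumption, and the snippet below only demonstrates the general Python behavior.

# Tuples can key a dict (e.g. a tactic cache); lists cannot.
cache = {}
buckets_tuple = (8192, 4096, 2048, 1024)
cache[buckets_tuple] = "profiled tactics"  # OK

try:
    cache[[8192, 4096, 2048, 1024]] = "profiled tactics"
except TypeError as e:
    print(f"list key rejected: {e}")  # unhashable type: 'list'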