fix: Reduce memory usage in fused moe op associated with AutoTuning and fix moe fallback issue. (#3793)

* Reduce memory usage in the fused MoE op associated with AutoTuning.
* Replace the pre-defined bucket size list with a generating function based on tune_max_num_tokens (see the sketch after this list).
* Add logic to free the profiling workspace in the min_latency_mode fused MoE path.
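A minimal sketch of the generating-function idea (hypothetical name generate_power_of_2_buckets, not the actual TensorRT-LLM code; assumes tune_max_num_tokens is itself a power of two):

    def generate_power_of_2_buckets(tune_max_num_tokens: int) -> tuple:
        """Derive descending power-of-2 bucket sizes from the tuning cap,
        replacing a hard-coded bucket tuple."""
        buckets = []
        m = tune_max_num_tokens  # assumed to be a power of two
        while m >= 1:
            buckets.append(m)
            m //= 2
        return tuple(buckets)

    # generate_power_of_2_buckets(8192)
    # -> (8192, 4096, 2048, 1024, 512, 256, 128, 64, 32, 16, 8, 4, 2, 1)

Consistent with the diff below, the largest profiled bucket also drops from 16384 to 8192, which plausibly accounts for much of the memory saving.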

Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com>

* Fix fused_moe fallback issue. (#3652)

min_latency_mode is only set to False during the warmup phase. Thus, when it becomes True during inference, all tactics fall back to the default one, causing a perf regression.
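A sketch of the failure mode with hypothetical names (the real autotuner is structured differently): if tuned tactics are cached under a key that includes min_latency_mode, and warmup only ever profiles with the flag set to False, then every True lookup during inference misses the cache and silently falls back to the default tactic.

    # Hypothetical tactic cache keyed on (bucket, min_latency_mode).
    tactic_cache = {}
    DEFAULT_TACTIC = -1

    def profile_best_tactic(bucket):
        return bucket  # stand-in for a real profiling pass

    def warmup(buckets):
        for bucket in buckets:
            # Warmup only ever profiles min_latency_mode=False.
            tactic_cache[(bucket, False)] = profile_best_tactic(bucket)

    def select_tactic(bucket, min_latency_mode):
        # (bucket, True) was never profiled, so every min-latency lookup
        # misses and returns the slow default tactic.
        return tactic_cache.get((bucket, min_latency_mode), DEFAULT_TACTIC)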

Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com>

---------

Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com>
Yukun He, 2025-04-24 10:14:26 +08:00, committed by GitHub
parent e0691e6e27
commit ab2f663101
3 changed files with 24 additions and 20 deletions

--- File 1 of 3 ---

@@ -161,17 +161,11 @@ public:
         torch::optional<torch::Tensor> input_sf, int64_t const tp_size, int64_t const tp_rank, int64_t const ep_size,
         int64_t const ep_rank, bool min_latency_mode, torch::optional<c10::ArrayRef<int64_t>> profile_ids)
     {
-        // Free the profile workspace to save memory
-        if (mProfileWorkspace != nullptr)
-        {
-            auto const cu_free_status = cudaFree(mProfileWorkspace);
-            TORCH_CHECK(
-                cu_free_status == cudaSuccess, "Can't free profile workspace for MoE GEMM profile before runMoe.");
-            mProfileWorkspace = nullptr;
-        }
         std::lock_guard<std::mutex> lock(mMutex);
+        // Free the profile workspace to save memory
+        freeProfileWorkspace();
         CHECK_INPUT(input, mActivationDtype)
         CHECK_INPUT(token_selected_experts, at::ScalarType::Int)
         if (token_final_scales)
@@ -248,6 +242,9 @@ public:
     {
         std::lock_guard<std::mutex> lock(mMutex);
+        // Free the profile workspace to save memory
+        freeProfileWorkspace();
         CHECK_INPUT(input, mActivationDtype)
         CHECK_INPUT(token_selected_experts, at::ScalarType::Int)
         if (token_final_scales)
@@ -375,13 +372,7 @@ public:
             hidden_size, inter_size, GROUP_SIZE, tensorrt_llm::ActivationType::Swiglu, USE_BIAS, USE_LORA,
             min_latency_mode, parallelism_config);
-        if (mProfileWorkspace != nullptr)
-        {
-            auto const cu_free_status = cudaFree(mProfileWorkspace);
-            TORCH_CHECK(cu_free_status == cudaSuccess,
-                "Can't free profile workspace for MoE GEMM profile during memory reallocation.");
-            mProfileWorkspace = nullptr;
-        }
+        freeProfileWorkspace();
         size_t profile_workspace_size = mProfiler->getWorkspaceSize(num_rows);
         auto const cu_malloc_status = cudaMalloc(&mProfileWorkspace, profile_workspace_size);
         TORCH_CHECK(cu_malloc_status == cudaSuccess, "Can't allocate profile workspace for MoE GEMM profile.");
@@ -416,6 +407,17 @@ private:
     using Profile = tensorrt_llm::cutlass_extensions::CutlassGemmConfig;
     std::vector<Profile> mAllProfiles;
 
+    void freeProfileWorkspace()
+    {
+        if (mProfileWorkspace != nullptr)
+        {
+            auto const cu_free_status = cudaFree(mProfileWorkspace);
+            TORCH_CHECK(cu_free_status == cudaSuccess,
+                "Can't free profile workspace for MoE GEMM profile during memory reallocation.");
+            mProfileWorkspace = nullptr;
+        }
+    }
+
     void setRunnerProfiles(torch::optional<c10::ArrayRef<int64_t>> profile_ids)
     {
         if (mUseFp8BlockScaling)

--- File 2 of 3 ---

@@ -120,13 +120,15 @@ def fused_moe(
     # TODO: only profile for min_latency_mode = False due to the error in the moe_kernels
     tuning_config = TuningConfig(dynamic_tensors=(
         # input, dim 0, all valid buckets, map a seq_len to power of 2 bucket index
-        (0, 0, ((16384, 8192, 4096, 2048, 1024, 512, 256, 128, 64, 32, 16, 8, 4,
-                 2, 1), next_positive_power_of_2)),
+        (0, 0, ((8192, 4096, 2048, 1024, 512, 256, 128, 64, 32, 16, 8, 4, 2, 1),
+                next_positive_power_of_2)),
         # min_latency_tensor, dim 0, (0 for False, 1 for True), map to it self
         (2, 0, ((0, ), lambda x: x)),
     ))
-    min_latency_tensor = torch.empty(1) if min_latency_mode else torch.empty(0)
+    # TODO: set min_latency_mode always to False due to the error in the moe_kernels
+    min_latency_tensor = torch.empty(0)
     # allocate workspace for profiling
     moe_runner = MoERunner(
         x_dtype=input.dtype,
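The bucket mapping above relies on next_positive_power_of_2 to snap a runtime token count onto one of the profiled buckets. A plausible implementation of that helper (the actual TensorRT-LLM definition may differ):

    def next_positive_power_of_2(x: int) -> int:
        """Round x up to the nearest power of two (minimum 1)."""
        if x < 1:
            return 1
        return 1 << (x - 1).bit_length()

    # A batch of 300 tokens reuses the tactic profiled for the 512 bucket.
    assert next_positive_power_of_2(300) == 512
    assert next_positive_power_of_2(512) == 512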

--- File 3 of 3 ---

@@ -197,7 +197,7 @@ def get_power_of_2_num_tokens_buckets(max_num_tokens) -> List[int]:
         num_token_buckets.append(m)
         m //= 2
-    return num_token_buckets
+    return tuple(num_token_buckets)
 
 
 def get_last_power_of_2_num_tokens_buckets(max_num_tokens) -> List[int]:
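The annotation still reads List[int] even though the body now returns a tuple; one plausible reason for the change (an assumption, not stated in the commit) is that tuples are hashable and can serve as cache or dedup keys, while lists cannot:

    buckets = (8192, 4096, 2048)
    cache = {buckets: "profiled"}        # a tuple is a valid dict key
    try:
        cache[[8192, 4096, 2048]] = "x"  # a list is not
    except TypeError as err:
        print(err)  # unhashable type: 'list'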