[https://nvbugs/5443053][fix] Disable finalize fusion when Lora is used

Signed-off-by: Jiagan Cheng <jiaganc@nvidia.com>
Jiagan Cheng 2025-08-31 18:28:09 -07:00 committed by Xiwen Yu
parent 0fb835d7c2
commit 8d5a7ea5b3
3 changed files with 6 additions and 6 deletions
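
The same gate is applied in both the C++ plugin and the PyTorch CUTLASS MoE backend below: the fused finalize epilogue is used only when the kernel selection allows it, it is not disabled explicitly, and no LoRA adapter is attached. A minimal Python sketch of the combined condition (the standalone function and its parameter names are illustrative, not the repository's API):

    def should_use_fused_finalize(
        experts_per_token: int,
        use_deterministic_kernels: bool,
        env_disable_finalize_fusion: bool,
        has_lora: bool,
    ) -> bool:
        """Mirror of the gating logic in the diff: fuse the finalize step only
        when kernel selection allows it, it is not disabled explicitly, and no
        LoRA adapter is attached."""
        kernel_ok = experts_per_token < 3 or not use_deterministic_kernels
        return kernel_ok and not env_disable_finalize_fusion and not has_lora

    # With LoRA attached the fused epilogue is always skipped.
    assert should_use_fused_finalize(2, False, False, has_lora=True) is False
    assert should_use_fused_finalize(2, False, False, has_lora=False) is True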


@@ -334,8 +334,9 @@ void MixtureOfExpertsPlugin::init()
             static_cast<int>(mType), static_cast<int>(mWeightType), static_cast<int>(mOutputType));
     }
+    // Finalize fusion should be disabled if Lora is used.
     mMOERunner->use_fused_finalize_
-        = (mExpertsPerToken < 3 || !mUseDeterministicKernels) && !getEnvMOEDisableFinalizeFusion();
+        = (mExpertsPerToken < 3 || !mUseDeterministicKernels) && !getEnvMOEDisableFinalizeFusion() && !hasLora();
     mGemmId1 = GemmIDMoe{1, mNumExperts, mExpertsPerToken, mParallelismConfig, mExpertHiddenSize, mExpertInterSize,
         mGroupSize, mActivationType, mType, mWeightType, mQuantMode, !mMOERunner->use_fused_finalize_};
@@ -535,9 +536,9 @@ void MixtureOfExpertsPlugin::configurePlugin(nvinfer1::DynamicPluginTensorDesc c
     }
     mGemmId1 = GemmIDMoe{1, mNumExperts, mExpertsPerToken, mParallelismConfig, mExpertHiddenSize, mExpertInterSize,
-        mGroupSize, mActivationType, mType, mWeightType, mQuantMode};
+        mGroupSize, mActivationType, mType, mWeightType, mQuantMode, !mMOERunner->use_fused_finalize_};
     mGemmId2 = GemmIDMoe{2, mNumExperts, mExpertsPerToken, mParallelismConfig, mExpertHiddenSize, mExpertInterSize,
-        mGroupSize, mActivationType, mType, mWeightType, mQuantMode};
+        mGroupSize, mActivationType, mType, mWeightType, mQuantMode, !mMOERunner->use_fused_finalize_};
     if (hasLora())
     {
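
configurePlugin() now constructs GemmIDMoe with !mMOERunner->use_fused_finalize_, matching what init() already did, so tactic selection distinguishes the fused and unfused finalize paths. A hedged Python sketch of why the flag belongs in the GEMM identity, using a hypothetical tactic cache rather than the plugin's real data structures:

    from dataclasses import dataclass

    @dataclass(frozen=True)
    class GemmId:
        """Hypothetical stand-in for GemmIDMoe; frozen so it can key a dict."""
        gemm_idx: int
        num_experts: int
        experts_per_token: int
        needs_unfused_finalize: bool  # analogous to !use_fused_finalize_

    tactic_cache: dict[GemmId, str] = {}
    tactic_cache[GemmId(2, 8, 2, needs_unfused_finalize=False)] = "fused_epilogue_tactic"

    # A LoRA run disables finalize fusion, so its GEMM id differs and it will
    # not accidentally reuse a tactic profiled for the fused epilogue.
    lora_id = GemmId(2, 8, 2, needs_unfused_finalize=True)
    assert lora_id not in tactic_cache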


@@ -150,7 +150,8 @@ class CutlassFusedMoE(MoE):
         # If True, the router weight will be multiplied on the input rather than at the end of FC2
         self.apply_router_weight_on_input = apply_router_weight_on_input
-        self.use_fused_finalize = not model_config.moe_disable_finalize_fusion
+        # Finalize fusion should be disabled if Lora is used.
+        self.use_fused_finalize = not model_config.moe_disable_finalize_fusion and model_config.lora_config is None
         self._weights_created = False
         if not model_config.skip_create_weights_in_init:
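
On the PyTorch side the rule is derived from the model config: fusion stays on only when it is not disabled and no LoRA config is present. A small self-contained sketch of that expression (the SimpleNamespace objects stand in for the real ModelConfig):

    from types import SimpleNamespace

    def compute_use_fused_finalize(model_config) -> bool:
        # Same expression as the diff: fusion must be off when LoRA is configured.
        return (not model_config.moe_disable_finalize_fusion
                and model_config.lora_config is None)

    no_lora = SimpleNamespace(moe_disable_finalize_fusion=False, lora_config=None)
    with_lora = SimpleNamespace(moe_disable_finalize_fusion=False,
                                lora_config=object())  # any non-None LoRA config disables fusion

    assert compute_use_fused_finalize(no_lora) is True
    assert compute_use_fused_finalize(with_lora) is False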


@@ -1020,8 +1020,6 @@ class TestMoE(unittest.TestCase):
             product(["float16", "bfloat16", "int4", "int8"], ["gelu", "geglu"],
                     [True], [32, 64])),
         name_func=unittest_name_func)
-    @pytest.mark.skip(
-        "https://nvbugswb.nvidia.com/NVBugs5/redir.aspx?url=/5443053")
     def test_mlp_lora_comparison(self, dtype_str, actfn, use_plugin, lora_rank):
         """This test uses one expert and compares the result to a plain MLP"""
         torch.random.manual_seed(42)
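
With the skip marker removed, test_mlp_lora_comparison runs again across the parameter grid in the decorator. For reference, a minimal self-contained example of the same parameterized.expand(product(...), name_func=...) pattern (the class, test body, and name function here are illustrative, not the repository's test code):

    import unittest
    from itertools import product

    from parameterized import parameterized

    def unittest_name_func(testcase_func, param_num, param):
        # Build a readable test name such as test_grid_0_float16_gelu_True_32.
        return f"{testcase_func.__name__}_{param_num}_" + "_".join(str(p) for p in param.args)

    class TestGrid(unittest.TestCase):

        @parameterized.expand(
            list(product(["float16", "bfloat16"], ["gelu", "geglu"], [True], [32, 64])),
            name_func=unittest_name_func)
        def test_grid(self, dtype_str, actfn, use_plugin, lora_rank):
            # One test method is generated per (dtype, activation, plugin, rank) tuple.
            self.assertIn(lora_rank, (32, 64))

    if __name__ == "__main__":
        unittest.main()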