From ffc13bd3253539b00334548336beb0ca694d6530 Mon Sep 17 00:00:00 2001
From: chenfeiz0326
Date: Fri, 9 May 2025 17:29:24 +0800
Subject: [PATCH] Cherry-pick: Use multi-threading to load MoE expert weights
 (#4137)

* Use multi-threading to load MoE expert weights

Signed-off-by: Po-Han Huang

* Update code formatting

Signed-off-by: Chenfei Zhang

* Update code formatting

Signed-off-by: Chenfei Zhang

---------

Signed-off-by: Po-Han Huang
Signed-off-by: Chenfei Zhang
Co-authored-by: Po-Han Huang
---
 tensorrt_llm/_torch/modules/fused_moe.py | 29 ++++++++++++++++++++----
 1 file changed, 24 insertions(+), 5 deletions(-)

diff --git a/tensorrt_llm/_torch/modules/fused_moe.py b/tensorrt_llm/_torch/modules/fused_moe.py
index 7007f471ac..0adbaca045 100755
--- a/tensorrt_llm/_torch/modules/fused_moe.py
+++ b/tensorrt_llm/_torch/modules/fused_moe.py
@@ -1,5 +1,6 @@
 import math
 import os
+import threading
 from enum import Enum
 from typing import Dict, List, NamedTuple, Optional, Union
 
@@ -1167,6 +1168,12 @@ class FusedMoE(nn.Module):
                                                epilogue_tile_m)
             dst_w2_weight.copy_(w2_weight_shard.view(dst_w2_weight.dtype))
 
+        # Use multi-threading to load expert weights in parallel.
+        # Even though CPython has a global interpreter lock (GIL),
+        # it's still faster to load weights in parallel because it can
+        # utilize CPU memory bandwidth better.
+        threads = []
+
         for expert_id in range(self.expert_start, self.expert_end):
             expert_idx = expert_id - self.expert_start
 
@@ -1187,11 +1194,23 @@
             is_trtllm_nvfp4 = self.is_trtllm(
             ) and self.quant_config.quant_mode.has_nvfp4()
-            load_expert_w3_w1_weight(w1_weight, w3_weight,
-                                     self.w3_w1_weight.data[expert_idx],
-                                     is_trtllm_nvfp4)
-            load_expert_w2_weight(w2_weight, self.w2_weight.data[expert_idx],
-                                  is_trtllm_nvfp4)
+
+            thread = threading.Thread(target=load_expert_w3_w1_weight,
+                                      args=(w1_weight, w3_weight,
+                                            self.w3_w1_weight.data[expert_idx],
+                                            is_trtllm_nvfp4))
+            thread.start()
+            threads.append(thread)
+
+            thread = threading.Thread(target=load_expert_w2_weight,
+                                      args=(w2_weight,
+                                            self.w2_weight.data[expert_idx],
+                                            is_trtllm_nvfp4))
+            thread.start()
+            threads.append(thread)
+
+        for thread in threads:
+            thread.join()
 
         if self.quant_config and self.quant_config.quant_mode.has_any_quant(
                 exclude_kv_cache=True):
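
Note on the technique (not part of the patch): threading helps here even under
the GIL because PyTorch releases the GIL inside its native kernels, so the
memcpy work of concurrent Tensor.copy_ calls can overlap. The following is a
minimal, self-contained sketch of the same spawn-all-then-join pattern; the
names, shapes, and copy_expert_weight helper are illustrative only, not the
actual FusedMoE buffers or API.

    import threading

    import torch


    def copy_expert_weight(dst: torch.Tensor, src: torch.Tensor) -> None:
        # Tensor.copy_ runs in a native kernel that releases the GIL,
        # so several of these copies can make progress concurrently.
        dst.copy_(src)


    # Hypothetical sizes standing in for per-expert weight shards.
    num_experts, rows, cols = 8, 1024, 4096
    srcs = [torch.randn(rows, cols) for _ in range(num_experts)]
    dsts = [torch.empty(rows, cols) for _ in range(num_experts)]

    # Spawn one thread per copy, collecting them for a later join.
    threads = []
    for dst, src in zip(dsts, srcs):
        t = threading.Thread(target=copy_expert_weight, args=(dst, src))
        t.start()
        threads.append(t)

    # Join only after all threads have started; joining inside the
    # spawn loop would serialize the copies and lose the benefit.
    for t in threads:
        t.join()

    assert all(torch.equal(d, s) for d, s in zip(dsts, srcs))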