Cherry-pick: Use multi-threading to load MoE expert weights (#4137)

* Use multi-threading to load MoE expert weights

Signed-off-by: Po-Han Huang <pohanh@nvidia.com>

* Update code formatting

Signed-off-by: Chenfei Zhang <chenfeiz@nvidia.com>

* Update code formatting

Signed-off-by: Chenfei Zhang <chenfeiz@nvidia.com>

---------

Signed-off-by: Po-Han Huang <pohanh@nvidia.com>
Signed-off-by: Chenfei Zhang <chenfeiz@nvidia.com>
Co-authored-by: Po-Han Huang <pohanh@nvidia.com>
This commit is contained in:
chenfeiz0326 2025-05-09 17:29:24 +08:00 committed by GitHub
parent 0f01826dde
commit ffc13bd325
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,5 +1,6 @@
import math
import os
import threading
from enum import Enum
from typing import Dict, List, NamedTuple, Optional, Union
@ -1167,6 +1168,12 @@ class FusedMoE(nn.Module):
epilogue_tile_m)
dst_w2_weight.copy_(w2_weight_shard.view(dst_w2_weight.dtype))
# Use multi-threading to load expert weights in parallel.
# Even though CPython has global interpreter lock (GIL),
# it's still faster to load weights in parallel because it can utilize
# CPU memory bandwidth better.
threads = []
for expert_id in range(self.expert_start, self.expert_end):
expert_idx = expert_id - self.expert_start
@ -1187,11 +1194,23 @@ class FusedMoE(nn.Module):
is_trtllm_nvfp4 = self.is_trtllm(
) and self.quant_config.quant_mode.has_nvfp4()
load_expert_w3_w1_weight(w1_weight, w3_weight,
self.w3_w1_weight.data[expert_idx],
is_trtllm_nvfp4)
load_expert_w2_weight(w2_weight, self.w2_weight.data[expert_idx],
is_trtllm_nvfp4)
thread = threading.Thread(target=load_expert_w3_w1_weight,
args=(w1_weight, w3_weight,
self.w3_w1_weight.data[expert_idx],
is_trtllm_nvfp4))
thread.start()
threads.append(thread)
thread = threading.Thread(target=load_expert_w2_weight,
args=(w2_weight,
self.w2_weight.data[expert_idx],
is_trtllm_nvfp4))
thread.start()
threads.append(thread)
for thread in threads:
thread.join()
if self.quant_config and self.quant_config.quant_mode.has_any_quant(
exclude_kv_cache=True):