mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
Cherry-pick: Use multi-threading to load MoE expert weights (#4137)
* Use multi-threading to load MoE expert weights Signed-off-by: Po-Han Huang <pohanh@nvidia.com> * Update code formatting Signed-off-by: Chenfei Zhang <chenfeiz@nvidia.com> * Update code formatting Signed-off-by: Chenfei Zhang <chenfeiz@nvidia.com> --------- Signed-off-by: Po-Han Huang <pohanh@nvidia.com> Signed-off-by: Chenfei Zhang <chenfeiz@nvidia.com> Co-authored-by: Po-Han Huang <pohanh@nvidia.com>
This commit is contained in:
parent
0f01826dde
commit
ffc13bd325
@ -1,5 +1,6 @@
|
||||
import math
|
||||
import os
|
||||
import threading
|
||||
from enum import Enum
|
||||
from typing import Dict, List, NamedTuple, Optional, Union
|
||||
|
||||
@ -1167,6 +1168,12 @@ class FusedMoE(nn.Module):
|
||||
epilogue_tile_m)
|
||||
dst_w2_weight.copy_(w2_weight_shard.view(dst_w2_weight.dtype))
|
||||
|
||||
# Use multi-threading to load expert weights in parallel.
|
||||
# Even though CPython has global interpreter lock (GIL),
|
||||
# it's still faster to load weights in parallel because it can utilize
|
||||
# CPU memory bandwidth better.
|
||||
threads = []
|
||||
|
||||
for expert_id in range(self.expert_start, self.expert_end):
|
||||
expert_idx = expert_id - self.expert_start
|
||||
|
||||
@ -1187,11 +1194,23 @@ class FusedMoE(nn.Module):
|
||||
|
||||
is_trtllm_nvfp4 = self.is_trtllm(
|
||||
) and self.quant_config.quant_mode.has_nvfp4()
|
||||
load_expert_w3_w1_weight(w1_weight, w3_weight,
|
||||
self.w3_w1_weight.data[expert_idx],
|
||||
is_trtllm_nvfp4)
|
||||
load_expert_w2_weight(w2_weight, self.w2_weight.data[expert_idx],
|
||||
is_trtllm_nvfp4)
|
||||
|
||||
thread = threading.Thread(target=load_expert_w3_w1_weight,
|
||||
args=(w1_weight, w3_weight,
|
||||
self.w3_w1_weight.data[expert_idx],
|
||||
is_trtllm_nvfp4))
|
||||
thread.start()
|
||||
threads.append(thread)
|
||||
|
||||
thread = threading.Thread(target=load_expert_w2_weight,
|
||||
args=(w2_weight,
|
||||
self.w2_weight.data[expert_idx],
|
||||
is_trtllm_nvfp4))
|
||||
thread.start()
|
||||
threads.append(thread)
|
||||
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
|
||||
if self.quant_config and self.quant_config.quant_mode.has_any_quant(
|
||||
exclude_kv_cache=True):
|
||||
|
||||
Loading…
Reference in New Issue
Block a user