Mirror of https://github.com/NVIDIA/TensorRT-LLM.git, synced 2026-01-14 06:27:45 +08:00
[TRTLLM-6835][fix] Fix potential hang caused by python multiprocessing when prefetching weights (#6927)
Signed-off-by: Lance Liao <108499334+lancelly@users.noreply.github.com>
Signed-off-by: Wangshanshan <30051912+dominicshanshan@users.noreply.github.com>
parent 261ffacfa4
commit 704fca4178
@@ -1,6 +1,7 @@
 import glob
 import multiprocessing
 import os
+from concurrent.futures import ThreadPoolExecutor
 from typing import Any, List
 
 import psutil
@@ -128,7 +129,7 @@ class HfWeightLoader(BaseWeightLoader):
         if len(local_file_names) == 0:
             return
 
-        max_processes = min(multiprocessing.cpu_count() * 2, 16,
-                            len(local_file_names))
-        with multiprocessing.Pool(processes=max_processes) as pool:
-            pool.map(self._prefetch_one_file, local_file_names)
+        max_workers = min(multiprocessing.cpu_count() * 2, 16,
+                          len(local_file_names))
+        with ThreadPoolExecutor(max_workers=max_workers) as executor:
+            list(executor.map(self._prefetch_one_file, local_file_names))
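
The hang the commit addresses comes from multiprocessing.Pool, which forks the parent process; forking a process that already runs threads or holds low-level locks can leave the child workers stuck before they ever execute the prefetch task. Since prefetching weight files is I/O-bound, a thread pool gives the same read overlap without forking. The sketch below reproduces that pattern in isolation: prefetch_files and the chunked-read body of _prefetch_one_file are illustrative assumptions, while the worker-count heuristic and the executor.map call mirror the diff above.

import glob
import multiprocessing
from concurrent.futures import ThreadPoolExecutor
from typing import List


def _prefetch_one_file(file_name: str) -> None:
    # Read the file sequentially so the OS page cache is warm before the
    # real weight-loading pass; the 16 MiB chunk size is an arbitrary choice.
    chunk_size = 16 * 1024 * 1024
    with open(file_name, "rb") as f:
        while f.read(chunk_size):
            pass


def prefetch_files(local_file_names: List[str]) -> None:
    if len(local_file_names) == 0:
        return
    # Same heuristic as the patch: twice the CPU count, capped at 16 and at
    # the number of files to prefetch.
    max_workers = min(multiprocessing.cpu_count() * 2, 16,
                      len(local_file_names))
    # Threads avoid the fork performed by multiprocessing.Pool, which is the
    # suspected source of the hang; prefetching is I/O-bound, so threads are
    # enough to overlap the reads.
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        # list() drains the iterator so any worker exception is re-raised
        # here instead of being silently discarded.
        list(executor.map(_prefetch_one_file, local_file_names))


if __name__ == "__main__":
    # Hypothetical checkpoint directory; adjust the glob to the real layout.
    prefetch_files(glob.glob("/path/to/checkpoint/*.safetensors"))

Keeping multiprocessing imported only for cpu_count() matches the patched file, which still sizes the pool from the CPU count even though no extra processes are spawned.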