Mirror of https://github.com/NVIDIA/TensorRT-LLM.git, synced 2026-01-14 06:27:45 +08:00
[TRTLLM-6835][fix] Fix potential hang caused by python multiprocessing when prefetching weights (#6927)
Signed-off-by: Lance Liao <108499334+lancelly@users.noreply.github.com>
parent 7f7a301f6e
commit d9b9b5d053
@@ -1,6 +1,7 @@
 import glob
 import multiprocessing
 import os
+from concurrent.futures import ThreadPoolExecutor
 from typing import Any, List

 import psutil
@@ -120,7 +121,7 @@ class HfWeightLoader(BaseWeightLoader):
         if len(local_file_names) == 0:
             return

-        max_processes = min(multiprocessing.cpu_count() * 2, 16,
+        max_workers = min(multiprocessing.cpu_count() * 2, 16,
                             len(local_file_names))
-        with multiprocessing.Pool(processes=max_processes) as pool:
-            pool.map(self._prefetch_one_file, local_file_names)
+        with ThreadPoolExecutor(max_workers=max_workers) as executor:
+            list(executor.map(self._prefetch_one_file, local_file_names))
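The change swaps a fork-based multiprocessing.Pool for a ThreadPoolExecutor when prefetching weight files. Forking a process whose other threads hold locks (easy to hit once CUDA, logging, or other runtime threads are active) can leave the child deadlocked, which is the kind of hang this commit avoids; prefetching is I/O-bound, so threads are sufficient and the GIL is not a bottleneck. The sketch below shows the new pattern in isolation. It is a minimal standalone illustration under assumptions, not the repository's code: prefetch_files and the chunked-read body of _prefetch_one_file are hypothetical stand-ins, since the diff does not show the real _prefetch_one_file implementation.

import multiprocessing
from concurrent.futures import ThreadPoolExecutor
from typing import List


def _prefetch_one_file(file_name: str) -> None:
    # Hypothetical stand-in: stream the file once so its pages land in the
    # OS page cache before the real weight-loading pass reads it.
    chunk_size = 16 * 1024 * 1024
    with open(file_name, "rb") as f:
        while f.read(chunk_size):
            pass


def prefetch_files(local_file_names: List[str]) -> None:
    if len(local_file_names) == 0:
        return

    # Threads instead of forked worker processes: fork() copies only the
    # calling thread, so locks held by other threads of the parent can stay
    # locked forever in the child and hang the pool.
    max_workers = min(multiprocessing.cpu_count() * 2, 16,
                      len(local_file_names))
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        # list() drains the iterator so worker exceptions propagate here.
        list(executor.map(_prefetch_one_file, local_file_names))

Capping the pool at min(2 × CPU count, 16, number of files) mirrors the sizing in the diff: it avoids oversubscribing small machines while never starting more workers than there are files to prefetch.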