diff --git a/tensorrt_llm/_torch/models/checkpoints/hf/weight_loader.py b/tensorrt_llm/_torch/models/checkpoints/hf/weight_loader.py index ba4703875e..bc90849c24 100644 --- a/tensorrt_llm/_torch/models/checkpoints/hf/weight_loader.py +++ b/tensorrt_llm/_torch/models/checkpoints/hf/weight_loader.py @@ -1,6 +1,7 @@ import glob import multiprocessing import os +from concurrent.futures import ThreadPoolExecutor from typing import Any, List import psutil @@ -120,7 +121,7 @@ class HfWeightLoader(BaseWeightLoader): if len(local_file_names) == 0: return - max_processes = min(multiprocessing.cpu_count() * 2, 16, - len(local_file_names)) - with multiprocessing.Pool(processes=max_processes) as pool: - pool.map(self._prefetch_one_file, local_file_names) + max_workers = min(multiprocessing.cpu_count() * 2, 16, + len(local_file_names)) + with ThreadPoolExecutor(max_workers=max_workers) as executor: + list(executor.map(self._prefetch_one_file, local_file_names))