From 704fca41789494b70de4c00b38952ac0ce282ac0 Mon Sep 17 00:00:00 2001
From: Liao Lanyu <108499334+lancelly@users.noreply.github.com>
Date: Mon, 18 Aug 2025 10:20:09 +0800
Subject: [PATCH] [TRTLLM-6835][fix] Fix potential hang caused by python multiprocessing when prefetching weights (#6927)

Signed-off-by: Lance Liao <108499334+lancelly@users.noreply.github.com>
Signed-off-by: Wangshanshan <30051912+dominicshanshan@users.noreply.github.com>
---
 .../_torch/models/checkpoints/hf/weight_loader.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/tensorrt_llm/_torch/models/checkpoints/hf/weight_loader.py b/tensorrt_llm/_torch/models/checkpoints/hf/weight_loader.py
index f3992dab78..d6660c95a7 100644
--- a/tensorrt_llm/_torch/models/checkpoints/hf/weight_loader.py
+++ b/tensorrt_llm/_torch/models/checkpoints/hf/weight_loader.py
@@ -1,6 +1,7 @@
 import glob
 import multiprocessing
 import os
+from concurrent.futures import ThreadPoolExecutor
 from typing import Any, List
 
 import psutil
@@ -128,7 +129,7 @@ class HfWeightLoader(BaseWeightLoader):
         if len(local_file_names) == 0:
             return
 
-        max_processes = min(multiprocessing.cpu_count() * 2, 16,
-                            len(local_file_names))
-        with multiprocessing.Pool(processes=max_processes) as pool:
-            pool.map(self._prefetch_one_file, local_file_names)
+        max_workers = min(multiprocessing.cpu_count() * 2, 16,
+                          len(local_file_names))
+        with ThreadPoolExecutor(max_workers=max_workers) as executor:
+            list(executor.map(self._prefetch_one_file, local_file_names))