From 25bdc301625a9c8037b77db2b170b0002790ee03 Mon Sep 17 00:00:00 2001
From: Yukun He <23156053+hyukn@users.noreply.github.com>
Date: Mon, 19 Jan 2026 13:41:56 +0800
Subject: [PATCH] [https://nvbugs/5782112][fix] Cherry-pick #10633: Fix hanging
 issue for MNNVL Allreduce under PP (#10750)

Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com>
Signed-off-by: Wangshanshan <30051912+dominicshanshan@users.noreply.github.com>
---
 tensorrt_llm/_torch/distributed/ops.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tensorrt_llm/_torch/distributed/ops.py b/tensorrt_llm/_torch/distributed/ops.py
index 711968a92e..84468dc612 100644
--- a/tensorrt_llm/_torch/distributed/ops.py
+++ b/tensorrt_llm/_torch/distributed/ops.py
@@ -515,7 +515,7 @@ class MNNVLAllReduce(nn.Module):
     This class handles the MNNVL-specific allreduce operations, which can be
     more efficient for certain operations when using NVLink for multi-node
     communication.
     """
-    allreduce_mnnvl_workspaces: Dict[int, Dict] = {}
+    allreduce_mnnvl_workspaces: Dict[Mapping, Dict] = {}
     def __init__(self, mapping: Mapping, dtype: torch.dtype):
         super().__init__()
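
For context, below is a minimal sketch of the idea the key-type change suggests: caching the shared allreduce workspace per parallel Mapping rather than per plain int, so that pipeline-parallel stages that happen to share the same tensor-parallel width do not resolve to the same cache entry. This is an illustrative assumption about the mechanism behind the fix, not TensorRT-LLM's actual implementation; the names ParallelMapping, tp_size, pp_rank, and get_workspace are hypothetical stand-ins.

# Illustrative sketch only -- not TensorRT-LLM's implementation. A class-level
# workspace cache keyed by the full parallel Mapping gives each pipeline stage
# its own entry; a cache keyed by a single int (e.g. tp_size) would let
# different PP stages collide on one entry and synchronize against the wrong
# peers, which is one way an allreduce can hang under PP.
from dataclasses import dataclass
from typing import Dict


@dataclass(frozen=True)   # frozen dataclass is hashable, so it can key a dict
class ParallelMapping:    # hypothetical stand-in for tensorrt_llm's Mapping
    tp_size: int          # tensor-parallel group width
    pp_rank: int          # pipeline stage of this rank


class MNNVLAllReduceSketch:
    # One workspace per distinct parallel configuration, shared across instances.
    allreduce_mnnvl_workspaces: Dict[ParallelMapping, Dict] = {}

    @classmethod
    def get_workspace(cls, mapping: ParallelMapping) -> Dict:
        # Two mappings with the same tp_size but different pp_rank hash to
        # different keys, so each pipeline stage allocates its own buffer.
        if mapping not in cls.allreduce_mnnvl_workspaces:
            cls.allreduce_mnnvl_workspaces[mapping] = {"buffer": bytearray(1 << 20)}
        return cls.allreduce_mnnvl_workspaces[mapping]


# Usage: stages 0 and 1 share tp_size=4 but receive separate workspaces.
ws0 = MNNVLAllReduceSketch.get_workspace(ParallelMapping(tp_size=4, pp_rank=0))
ws1 = MNNVLAllReduceSketch.get_workspace(ParallelMapping(tp_size=4, pp_rank=1))
assert ws0 is not ws1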