[https://nvbugs/5782112][fix] Cherry-pick #10633: Fix hanging issue for MNNVL Allreduce under PP (#10750)

Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com>
Signed-off-by: Wangshanshan <30051912+dominicshanshan@users.noreply.github.com>
Yukun He 2026-01-19 13:41:56 +08:00 committed by Yanchao Lu
parent 2b3bb2e9b0
commit 25bdc30162

@@ -515,7 +515,7 @@ class MNNVLAllReduce(nn.Module):
     This class handles the MNNVL-specific allreduce operations, which can be more efficient
     for certain operations when using NVLink for multi-node communication.
     """
-    allreduce_mnnvl_workspaces: Dict[int, Dict] = {}
+    allreduce_mnnvl_workspaces: Dict[Mapping, Dict] = {}

     def __init__(self, mapping: Mapping, dtype: torch.dtype):
         super().__init__()
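
The one-line change above rekeys the class-level workspace cache from an `int` to the full `Mapping` object. A plausible reading (an assumption, not stated in the diff) is that an integer key let two distinct parallel configurations, e.g. different pipeline stages, collide on the same workspace, which can cause a hang when ranks disagree about which buffer to use. The sketch below illustrates the caching pattern with a hypothetical, simplified `Mapping` stand-in; the real TensorRT-LLM `Mapping` class has many more fields.

```python
from dataclasses import dataclass
from typing import Dict

# Hypothetical, simplified stand-in for TensorRT-LLM's Mapping.
# frozen=True makes it hashable, so it can serve as a dict key.
@dataclass(frozen=True)
class Mapping:
    tp_size: int
    pp_rank: int

class AllReduceSketch:
    # Class-level workspace cache, keyed by the full Mapping.
    # Keying by a single int (e.g. tp_size) would make two pipeline
    # stages with the same tp_size share one workspace; keying by the
    # Mapping keeps each parallel configuration's workspace separate.
    workspaces: Dict[Mapping, dict] = {}

    def __init__(self, mapping: Mapping):
        if mapping not in AllReduceSketch.workspaces:
            # Placeholder for the real per-mapping buffer allocation.
            AllReduceSketch.workspaces[mapping] = {"buffer": object()}
        self.workspace = AllReduceSketch.workspaces[mapping]

# Two pipeline stages with identical tp_size get distinct workspaces.
a = AllReduceSketch(Mapping(tp_size=4, pp_rank=0))
b = AllReduceSketch(Mapping(tp_size=4, pp_rank=1))
assert a.workspace is not b.workspace
```

With an `int` key derived from `tp_size` alone, both constructions above would have returned the same cached entry, which is the kind of sharing the fix appears to rule out.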