diff --git a/tensorrt_llm/_torch/distributed/ops.py b/tensorrt_llm/_torch/distributed/ops.py index 711968a92e..84468dc612 100644 --- a/tensorrt_llm/_torch/distributed/ops.py +++ b/tensorrt_llm/_torch/distributed/ops.py @@ -515,7 +515,7 @@ class MNNVLAllReduce(nn.Module): This class handles the MNNVL-specific allreduce operations, which can be more efficient for certain operations when using NVLink for multi-node communication. """ - allreduce_mnnvl_workspaces: Dict[int, Dict] = {} + allreduce_mnnvl_workspaces: Dict[Mapping, Dict] = {} def __init__(self, mapping: Mapping, dtype: torch.dtype): super().__init__()