diff --git a/tensorrt_llm/_torch/distributed/__init__.py b/tensorrt_llm/_torch/distributed/__init__.py index ba0ccd77fb..82f5a23b61 100644 --- a/tensorrt_llm/_torch/distributed/__init__.py +++ b/tensorrt_llm/_torch/distributed/__init__.py @@ -1,7 +1,8 @@ +from tensorrt_llm.functional import AllReduceFusionOp + from .communicator import Distributed, MPIDist, PPComm, TorchDist -from .ops import (AllReduce, AllReduceFusionOp, AllReduceParams, - AllReduceStrategy, MoEAllReduce, allgather, reducescatter, - userbuffers_allreduce_finalize) +from .ops import (AllReduce, AllReduceParams, AllReduceStrategy, MoEAllReduce, + allgather, reducescatter, userbuffers_allreduce_finalize) __all__ = [ "allgather",