mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
[https://nvbugs/5498967][fix] Downgrade NCCL (#7556)
Signed-off-by: yizhang-nv <187001205+yizhang-nv@users.noreply.github.com>
This commit is contained in:
parent
2d5f0e1038
commit
4658b778ef
@ -83,6 +83,8 @@ communicator* UserBufferAllocator::comm()
|
|||||||
return mUbComm;
|
return mUbComm;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#if NCCL_VERSION_CODE >= NCCL_VERSION(2, 27, 0)
|
||||||
|
|
||||||
void NCCLUserBufferAllocator::initialize(tensorrt_llm::runtime::WorldConfig const& worldConfig)
|
void NCCLUserBufferAllocator::initialize(tensorrt_llm::runtime::WorldConfig const& worldConfig)
|
||||||
{
|
{
|
||||||
if (!isInitialized())
|
if (!isInitialized())
|
||||||
@ -244,6 +246,18 @@ bool NCCLHelper::isLoaded() const
|
|||||||
return mIsLoaded;
|
return mIsLoaded;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#else
|
||||||
|
|
||||||
|
void NCCLUserBufferAllocator::initialize(tensorrt_llm::runtime::WorldConfig const& worldConfig)
|
||||||
|
{
|
||||||
|
TLLM_CHECK_WITH_INFO(false,
|
||||||
|
"NCCL symmetric is not supported for nccl version < 2.27. Please upgrade nccl to 2.27 or higher and rebuild "
|
||||||
|
"tensorrt_llm or disable nccl symmetric");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
||||||
bool UserBufferAllocator::use_nccl_symmetric = false;
|
bool UserBufferAllocator::use_nccl_symmetric = false;
|
||||||
|
|
||||||
}; // namespace tensorrt_llm::runtime::ub
|
}; // namespace tensorrt_llm::runtime::ub
|
||||||
|
|||||||
@ -35,13 +35,22 @@ struct UBBuffer
|
|||||||
void* addr;
|
void* addr;
|
||||||
int handle;
|
int handle;
|
||||||
size_t size;
|
size_t size;
|
||||||
|
#if NCCL_VERSION_CODE >= NCCL_VERSION(2, 27, 0)
|
||||||
ncclWindow_t window;
|
ncclWindow_t window;
|
||||||
|
#endif
|
||||||
|
|
||||||
UBBuffer(void* a = nullptr, int h = -1, size_t s = 0, ncclWindow_t w = nullptr)
|
UBBuffer(void* a = nullptr, int h = -1, size_t s = 0
|
||||||
|
#if NCCL_VERSION_CODE >= NCCL_VERSION(2, 27, 0)
|
||||||
|
,
|
||||||
|
ncclWindow_t w = nullptr
|
||||||
|
#endif
|
||||||
|
)
|
||||||
: addr(a)
|
: addr(a)
|
||||||
, handle(h)
|
, handle(h)
|
||||||
, size(s)
|
, size(s)
|
||||||
|
#if NCCL_VERSION_CODE >= NCCL_VERSION(2, 27, 0)
|
||||||
, window(w)
|
, window(w)
|
||||||
|
#endif
|
||||||
{
|
{
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -77,6 +86,8 @@ protected:
|
|||||||
tensorrt_llm::runtime::WorldConfig mWorldConfig;
|
tensorrt_llm::runtime::WorldConfig mWorldConfig;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
#if NCCL_VERSION_CODE >= NCCL_VERSION(2, 27, 0)
|
||||||
|
|
||||||
class NCCLHelper
|
class NCCLHelper
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
@ -125,6 +136,13 @@ private:
|
|||||||
std::shared_ptr<ncclComm_t> mComm;
|
std::shared_ptr<ncclComm_t> mComm;
|
||||||
static std::unique_ptr<NCCLHelper> mNCCLHelper;
|
static std::unique_ptr<NCCLHelper> mNCCLHelper;
|
||||||
};
|
};
|
||||||
|
#else
|
||||||
|
class NCCLUserBufferAllocator : public UserBufferAllocator
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
void initialize(tensorrt_llm::runtime::WorldConfig const& world_config) override;
|
||||||
|
};
|
||||||
|
#endif
|
||||||
|
|
||||||
#else
|
#else
|
||||||
using communicator = void;
|
using communicator = void;
|
||||||
|
|||||||
@ -9,9 +9,8 @@ CUDA_VER="12.9" # 12.9.1
|
|||||||
# Keep the installation for cuDNN if users want to install PyTorch with source codes.
|
# Keep the installation for cuDNN if users want to install PyTorch with source codes.
|
||||||
# PyTorch 2.x can compile with cuDNN v9.
|
# PyTorch 2.x can compile with cuDNN v9.
|
||||||
CUDNN_VER="9.10.2.21-1"
|
CUDNN_VER="9.10.2.21-1"
|
||||||
# NGC PyTorch 25.06 image uses NCCL 2.27.3, while NCCL 2.27.5 resolves a perf regression issue.
|
# Downgrade NCCL version to 2.25.1 temporarily
|
||||||
# Use NCCL version 2.27.5 instead.
|
NCCL_VER="2.25.1-1+cuda12.8"
|
||||||
NCCL_VER="2.27.5-1+cuda12.9"
|
|
||||||
CUBLAS_VER="12.9.1.4-1"
|
CUBLAS_VER="12.9.1.4-1"
|
||||||
# Align with the pre-installed CUDA / NVCC / NVRTC versions from
|
# Align with the pre-installed CUDA / NVCC / NVRTC versions from
|
||||||
# https://docs.nvidia.com/cuda/cuda-toolkit-release-notes/index.html
|
# https://docs.nvidia.com/cuda/cuda-toolkit-release-notes/index.html
|
||||||
|
|||||||
@ -12,7 +12,7 @@
|
|||||||
# NB: Typically, the suffix indicates the PR whose CI pipeline generated the images. In case that
|
# NB: Typically, the suffix indicates the PR whose CI pipeline generated the images. In case that
|
||||||
# images are adopted from PostMerge pipelines, the abbreviated commit hash is used instead.
|
# images are adopted from PostMerge pipelines, the abbreviated commit hash is used instead.
|
||||||
IMAGE_NAME=urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm
|
IMAGE_NAME=urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm
|
||||||
LLM_DOCKER_IMAGE=urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.06-py3-x86_64-ubuntu24.04-trt10.11.0.33-skip-tritondevel-202508201630-pre-test
|
LLM_DOCKER_IMAGE=urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.06-py3-x86_64-ubuntu24.04-trt10.11.0.33-skip-tritondevel-202509051530-7556
|
||||||
LLM_SBSA_DOCKER_IMAGE=urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.06-py3-aarch64-ubuntu24.04-trt10.11.0.33-skip-tritondevel-202508201630-pre-test
|
LLM_SBSA_DOCKER_IMAGE=urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.06-py3-aarch64-ubuntu24.04-trt10.11.0.33-skip-tritondevel-202509051530-7556
|
||||||
LLM_ROCKYLINUX8_PY310_DOCKER_IMAGE=urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:cuda-12.9.1-devel-rocky8-x86_64-rocky8-py310-trt10.11.0.33-skip-tritondevel-202508201630-pre-test
|
LLM_ROCKYLINUX8_PY310_DOCKER_IMAGE=urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:cuda-12.9.1-devel-rocky8-x86_64-rocky8-py310-trt10.11.0.33-skip-tritondevel-202509051530-7556
|
||||||
LLM_ROCKYLINUX8_PY312_DOCKER_IMAGE=urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:cuda-12.9.1-devel-rocky8-x86_64-rocky8-py312-trt10.11.0.33-skip-tritondevel-202508201630-pre-test
|
LLM_ROCKYLINUX8_PY312_DOCKER_IMAGE=urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:cuda-12.9.1-devel-rocky8-x86_64-rocky8-py312-trt10.11.0.33-skip-tritondevel-202509051530-7556
|
||||||
|
|||||||
@ -2334,6 +2334,8 @@ class PyTorchModelEngine(ModelEngine):
|
|||||||
if not ub.ub_supported():
|
if not ub.ub_supported():
|
||||||
return False
|
return False
|
||||||
use_nccl_symmetric = self.pytorch_backend_config.allreduce_strategy == "NCCL_SYMMETRIC"
|
use_nccl_symmetric = self.pytorch_backend_config.allreduce_strategy == "NCCL_SYMMETRIC"
|
||||||
|
if use_nccl_symmetric:
|
||||||
|
return False
|
||||||
ub.initialize_userbuffers_manager(
|
ub.initialize_userbuffers_manager(
|
||||||
self.mapping.tp_size, self.mapping.pp_size, self.mapping.cp_size,
|
self.mapping.tp_size, self.mapping.pp_size, self.mapping.cp_size,
|
||||||
self.mapping.rank, self.mapping.gpus_per_node,
|
self.mapping.rank, self.mapping.gpus_per_node,
|
||||||
|
|||||||
@ -184,6 +184,9 @@ def row_linear_residual_norm_fusion_forward(
|
|||||||
def test_row_linear_residual_norm_fusion(seq_len, hidden_size, dtype, strategy,
|
def test_row_linear_residual_norm_fusion(seq_len, hidden_size, dtype, strategy,
|
||||||
fusion):
|
fusion):
|
||||||
|
|
||||||
|
if strategy == AllReduceStrategy.NCCL_SYMMETRIC:
|
||||||
|
pytest.skip("NCCL symmetric is not supported for nccl version < 2.27.")
|
||||||
|
|
||||||
torch.manual_seed(42)
|
torch.manual_seed(42)
|
||||||
tensor_parallel_size = 2
|
tensor_parallel_size = 2
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user