/*
 * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tensorrt_llm/common/opUtils.h"
#include "tensorrt_llm/common/ncclUtils.h"
#include "tensorrt_llm/runtime/utils/mpiTags.h"
#include "tensorrt_llm/runtime/utils/mpiUtils.h"

#include "cuda.h"
#include <cstdlib>
#include <functional>
#include <map>
#include <memory>
#include <mutex>
#include <set>
#include <sstream>
#include <thread>
#include <tuple>
#include <unordered_map>

TRTLLM_NAMESPACE_BEGIN

#if ENABLE_MULTI_DEVICE

std::unordered_map<nvinfer1::DataType, ncclDataType_t>* getDtypeMap()
{
    static std::unordered_map<nvinfer1::DataType, ncclDataType_t> dtypeMap = {
        {nvinfer1::DataType::kFLOAT, ncclFloat32},
        {nvinfer1::DataType::kHALF, ncclFloat16},
        {nvinfer1::DataType::kBF16, ncclBfloat16},
        {nvinfer1::DataType::kFP8, ncclInt8},
        {nvinfer1::DataType::kBOOL, ncclInt8},
        {nvinfer1::DataType::kINT32, ncclInt32},
        {nvinfer1::DataType::kINT64, ncclInt64},
        {nvinfer1::DataType::kUINT8, ncclUint8},
        {nvinfer1::DataType::kINT8, ncclInt8},
    };
    return &dtypeMap;
}

namespace
{

// Get the NCCL unique ID for a group of ranks. The first rank in the group creates the ID
// and sends it to every other rank over the MPI session; all other ranks receive it.
ncclUniqueId getUniqueId(std::set<int> const& group)
{
    auto const rank = COMM_SESSION.getRank();
    TLLM_LOG_TRACE("%s start for rank %d", __PRETTY_FUNCTION__, rank);
    ncclUniqueId id;
    if (rank == *group.begin())
    {
        NCCLCHECK_THROW(ncclGetUniqueId(&id));
        for (auto it = std::next(std::begin(group), 1); it != group.end(); ++it)
        {
            COMM_SESSION.sendValue(id, *it, tensorrt_llm::mpi::MpiTag::kDefault);
        }
    }
    else
    {
        COMM_SESSION.recvValue(id, *group.begin(), tensorrt_llm::mpi::MpiTag::kDefault);
    }
    TLLM_LOG_TRACE("%s stop for rank %d", __PRETTY_FUNCTION__, rank);
    return id;
}
} // namespace

std::shared_ptr<ncclComm_t> getComm(std::set<int> const& group)
{
    auto const rank = COMM_SESSION.getRank();
    TLLM_LOG_TRACE("%s start for rank %d", __PRETTY_FUNCTION__, rank);
    static std::map<std::set<int>, std::shared_ptr<ncclComm_t>> commMap;
    static std::mutex mutex;
    std::lock_guard<std::mutex> lock(mutex);
    std::ostringstream oss;
    int index = 0;
    for (auto const& currentRank : group)
    {
        if (index != 0)
        {
            oss << ",";
        }
        oss << currentRank;
        index++;
    }
    auto groupStr = oss.str();
    auto it = commMap.find(group);
    if (it != commMap.end())
    {
        auto ncclComm = it->second;
        TLLM_LOG_TRACE("NCCL comm for group(%s) is cached for rank %d", groupStr.c_str(), rank);
        return ncclComm;
    }

    TLLM_LOG_TRACE("Init NCCL comm for group(%s) for rank %d", groupStr.c_str(), rank);
    ncclUniqueId id = getUniqueId(group);
    int groupRank = 0;
    for (auto const& currentRank : group)
    {
        if (rank == currentRank)
            break;
        ++groupRank;
    }
    TLLM_CHECK(static_cast<size_t>(groupRank) < group.size());

    std::shared_ptr<ncclComm_t> ncclComm(new ncclComm_t,
        [](ncclComm_t* comm)
        {
            if (!comm)
            {
                return;
            }
            // STEP 1: Clean up resources and destroy the NCCL communicator if it is valid.
            if (*comm)
            {
                // Clean up all registered resources FIRST.
                tensorrt_llm::common::nccl_util::NcclCommResourceManager::getInstance().cleanupResources(*comm);
                // Now destroy the NCCL communicator.
                ncclResult_t result = ncclCommDestroy(*comm);
                if (result != ncclSuccess)
                {
                    TLLM_LOG_WARNING("ncclCommDestroy failed with error: %d", result);
                }
                // Clear the communicator value before freeing the pointer.
                *comm = nullptr;
            }
            // STEP 2: Always free the pointer memory (regardless of whether *comm was valid).
            delete comm;
        });
#if defined(_WIN32)
    // Need static connection initialization for accurate KV cache size estimation.
    if (getenv("NCCL_RUNTIME_CONNECT") == nullptr)
        _putenv_s("NCCL_RUNTIME_CONNECT", "0");
    // Disable graph register to avoid startup hangs.
    if (getenv("NCCL_GRAPH_REGISTER") == nullptr)
        _putenv_s("NCCL_GRAPH_REGISTER", "0");
#else
    setenv("NCCL_RUNTIME_CONNECT", "0", 0);
    setenv("NCCL_GRAPH_REGISTER", "0", 0);
#endif // _WIN32
    NCCLCHECK_THROW(ncclCommInitRank(ncclComm.get(), group.size(), id, groupRank));
    commMap[group] = ncclComm;
    TLLM_LOG_TRACE("%s stop for rank %d", __PRETTY_FUNCTION__, rank);
    return ncclComm;
}
#endif // ENABLE_MULTI_DEVICE
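
// Usage sketch (illustrative only, not compiled as part of this file): how a caller might
// obtain a cached communicator and run a collective with it. The dtype mapping comes from
// getDtypeMap() above; `buffer`, `count`, and `stream` are hypothetical caller-side names.
//
//   std::set<int> group{0, 1, 2, 3};                   // world ranks participating
//   std::shared_ptr<ncclComm_t> comm = getComm(group); // created once per group, then cached
//   ncclDataType_t dtype = (*getDtypeMap())[nvinfer1::DataType::kHALF];
//   NCCLCHECK_THROW(ncclAllReduce(buffer, buffer, count, dtype, ncclSum, *comm, stream));
//
// Because commMap holds owning shared_ptrs, a cached communicator stays alive until the map
// itself is destroyed at process exit (or the entry is replaced); only then does the custom
// deleter above run and destroy the NCCL communicator.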
void const* tensorrt_llm::common::op::getCommSessionHandle()
{
#if ENABLE_MULTI_DEVICE
    return &COMM_SESSION;
#else
    return nullptr;
#endif // ENABLE_MULTI_DEVICE
}

namespace
{

using tensorrt_llm::common::op::hash;

// Get the current CUDA context; a default context is created if there is none.
inline CUcontext getCurrentCudaCtx()
{
    CUcontext ctx{};
    CUresult err = cuCtxGetCurrent(&ctx);
    if (err == CUDA_ERROR_NOT_INITIALIZED || ctx == nullptr)
    {
        TLLM_CUDA_CHECK(cudaFree(nullptr));
        err = cuCtxGetCurrent(&ctx);
    }
    TLLM_CHECK(err == CUDA_SUCCESS);
    return ctx;
}

// Helper to create a per-CUDA-context and per-thread singleton managed by std::shared_ptr.
// Unlike conventional singletons, a singleton created this way is released when it is no
// longer needed, instead of on process exit.
// Objects of this class shall always be declared static / global, and shall never own CUDA
// resources.
template <typename T>
class PerCudaCtxPerThreadSingletonCreator
{
public:
    using CreatorFunc = std::function<std::unique_ptr<T>()>;
    using DeleterFunc = std::function<void(T*)>;

    // A creator returning std::unique_ptr is by design: it forces separation of the memory
    // for T from the memory for the shared_ptr control block. So when T is released while we
    // still hold an observing weak_ptr in mObservers, the memory block for T can be freed.
    // The creator itself must not own CUDA resources; only the object it creates may.
    PerCudaCtxPerThreadSingletonCreator(CreatorFunc creator, DeleterFunc deleter)
        : mCreator{std::move(creator)}
        , mDeleter{std::move(deleter)}
        , mObservers{new std::unordered_map<CacheKey, std::weak_ptr<T>, hash<CacheKey>>()}
    {
    }

    ~PerCudaCtxPerThreadSingletonCreator()
    {
        std::lock_guard<std::mutex> lk{mMutex};
        delete mObservers;
        mObservers = nullptr;
    }

    std::shared_ptr<T> operator()()
    {
        std::lock_guard<std::mutex> lk{mMutex};
        CUcontext ctx{getCurrentCudaCtx()};
        std::thread::id thread = std::this_thread::get_id();
        auto const key = std::make_tuple(ctx, thread);
        std::shared_ptr<T> result = (*mObservers)[key].lock();
        if (result == nullptr)
        {
            TLLM_LOG_TRACE("creating singleton instance for CUDA context %p and thread %zu", ctx,
                std::hash<std::thread::id>{}(thread));
            // Create the resource and register an observer.
            result = std::shared_ptr<T>{mCreator().release(),
                [this, key](T* obj)
                {
                    if (obj == nullptr)
                    {
                        return;
                    }
                    mDeleter(obj);
                    if (mObservers == nullptr)
                    {
                        return;
                    }
                    // Clear the observer to avoid growth of mObservers, in case users create and
                    // destroy CUDA contexts frequently.
                    std::shared_ptr<T> observedObjHolder; // Delay destruction to avoid deadlock.
                    std::lock_guard<std::mutex> lk{mMutex};
                    // Must check the observer again because another thread may have created a new
                    // instance for this key just before we locked mMutex. We can't infer that the
                    // observer is stale from the fact that obj is destroyed, because shared_ptr
                    // ref-count checking and observer removal are not one atomic operation, and the
                    // observer may have been changed to observe another instance.
                    auto it = mObservers->find(key);
                    if (it == mObservers->end())
                    {
                        return;
                    }
                    observedObjHolder = it->second.lock();
                    if (observedObjHolder == nullptr)
                    {
                        mObservers->erase(it);
                    }
                }};
            (*mObservers)[key] = result;
        }
        else
        {
            TLLM_LOG_TRACE("singleton instance for CUDA context %p and thread %zu is cached", ctx,
                std::hash<std::thread::id>{}(thread));
        }
        return result;
    }

private:
    CreatorFunc mCreator;
    DeleterFunc mDeleter;
    mutable std::mutex mMutex;
    // CUDA resources are per-context and per-thread.
    using CacheKey = std::tuple<CUcontext, std::thread::id>;
    std::unordered_map<CacheKey, std::weak_ptr<T>, hash<CacheKey>>* mObservers;
};
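
// Usage sketch (illustrative only): the intended pattern for this helper, mirroring
// getCublasHandle()/getCublasLtHandle() below. `FooHandle`, `fooCreate`, and `fooDestroy` are
// hypothetical stand-ins for a per-context CUDA library handle type and its C create/destroy API.
//
//   std::shared_ptr<FooHandle> getFooHandle()
//   {
//       static PerCudaCtxPerThreadSingletonCreator<FooHandle> creator(
//           []() -> auto
//           {
//               auto handle = std::make_unique<FooHandle>();
//               fooCreate(handle.get()); // acquire the CUDA resource
//               return handle;
//           },
//           [](FooHandle* handle)
//           {
//               fooDestroy(*handle); // release the CUDA resource
//               delete handle;
//           });
//       return creator(); // same (context, thread) pair => same instance while any owner lives
//   }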
// Structure to hold device memory information.
struct MemoryInfo
{
    size_t free_mb;
    size_t total_mb;
    float free_percent;
};

// Helper function to query the current device memory information.
MemoryInfo getMemoryInfo()
{
    size_t free_mem = 0, total_mem = 0;
    TLLM_CUDA_CHECK(cudaMemGetInfo(&free_mem, &total_mem));
    size_t const free_mb = free_mem / (1024 * 1024);
    size_t const total_mb = total_mem / (1024 * 1024);
    float const free_percent = (total_mem > 0) ? (static_cast<float>(free_mem) / total_mem * 100.0f) : 0.0f;
    return {free_mb, total_mb, free_percent};
}

// Helper function to log current memory usage.
void logMemoryUsage(char const* operation, CUcontext ctx)
{
    auto const mem = getMemoryInfo();
    TLLM_LOG_DEBUG("%s: Context=%p, Free Memory=%zu MB (%.1f%%), Total=%zu MB", operation, ctx, mem.free_mb,
        mem.free_percent, mem.total_mb);
}

// Helper function to throw a descriptive error, including memory state, when handle creation fails.
void throwCublasErrorWithMemInfo(char const* operation, CUcontext ctx, cublasStatus_t status)
{
    auto const mem = getMemoryInfo();
    TLLM_THROW(
        "Failed to create %s. "
        "Status: %d, Context: %p, Free Memory: %zu MB (%.1f%%), Total: %zu MB. "
        "Consider reducing kv_cache_config.free_gpu_memory_fraction.",
        operation, status, ctx, mem.free_mb, mem.free_percent, mem.total_mb);
}

} // namespace

std::shared_ptr<cublasHandle_t> getCublasHandle()
{
    static PerCudaCtxPerThreadSingletonCreator<cublasHandle_t> creator(
        []() -> auto
        {
            CUcontext ctx = getCurrentCudaCtx();
            logMemoryUsage("Creating cublas handle", ctx);
            auto handle = std::make_unique<cublasHandle_t>();
            auto status = cublasCreate(handle.get());
            if (status != CUBLAS_STATUS_SUCCESS)
            {
                throwCublasErrorWithMemInfo("cublas handle", ctx, status);
            }
            return handle;
        },
        [](cublasHandle_t* handle)
        {
            auto status = cublasDestroy(*handle);
            if (status != CUBLAS_STATUS_SUCCESS)
            {
                TLLM_LOG_WARNING("Failed to destroy cublas handle. Status: %d", status);
            }
            delete handle;
        });
    return creator();
}

std::shared_ptr<cublasLtHandle_t> getCublasLtHandle()
{
    static PerCudaCtxPerThreadSingletonCreator<cublasLtHandle_t> creator(
        []() -> auto
        {
            CUcontext ctx = getCurrentCudaCtx();
            logMemoryUsage("Creating cublasLt handle", ctx);
            auto handle = std::make_unique<cublasLtHandle_t>();
            auto status = cublasLtCreate(handle.get());
            if (status != CUBLAS_STATUS_SUCCESS)
            {
                throwCublasErrorWithMemInfo("cublasLt handle", ctx, status);
            }
            return handle;
        },
        [](cublasLtHandle_t* handle)
        {
            auto status = cublasLtDestroy(*handle);
            if (status != CUBLAS_STATUS_SUCCESS)
            {
                TLLM_LOG_WARNING("Failed to destroy cublasLt handle. Status: %d", status);
            }
            delete handle;
        });
    return creator();
}

TRTLLM_NAMESPACE_END
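
// Usage sketch (illustrative only): callers typically fetch the handle at the call site and keep
// the shared_ptr alive for the duration of the GEMM; `stream` is a hypothetical caller-side
// cudaStream_t. Handles are created lazily per (CUDA context, thread) and released once the last
// shared_ptr owner drops, rather than at process exit.
//
//   auto cublasHandle = getCublasHandle();
//   cublasSetStream(*cublasHandle, stream);
//   // ... cublasGemmEx(*cublasHandle, ...) ...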