/*
 * Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "cutlass_extensions/gemm_configs.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/kernels/cutlass_kernels/include/allreduce_gemm_runner.h"
#include "tensorrt_llm/runtime/ipcNvlsMemory.h"
#include "tensorrt_llm/thop/thUtils.h"

#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/EmptyTensor.h>
#include <c10/cuda/CUDAStream.h>
#include <torch/custom_class.h>
#include <torch/library.h>

#include <cstdlib>
#include <mutex>
#include <set>
#include <sstream>
#include <unordered_map>

using tensorrt_llm::kernels::opened_cutlass_kernels::GemmAllReduceImplRunner;
using tensorrt_llm::kernels::opened_cutlass_kernels::GemmAllReduceImplInterface;
using tensorrt_llm::kernels::opened_cutlass_kernels::GemmTypes;
using tensorrt_llm::kernels::opened_cutlass_kernels::PersistentWorkspaceInterface;

namespace
{

// Identifies one NVLS allocation: the local device plus the set of ranks in the group.
struct AllocationKey
{
    int64_t device_index;
    std::set<int> group;

    bool operator==(AllocationKey const& other) const
    {
        return device_index == other.device_index && group == other.group;
    }

    std::string toString() const
    {
        std::stringstream ss;
        ss << "AllocationKey(device: " << device_index << ", group: [";
        for (int rank : group)
        {
            ss << rank << ", ";
        }
        ss << "])";
        return ss.str();
    }
};

struct AllocationKeyHash
{
    size_t operator()(AllocationKey const& key) const
    {
        size_t seed = 0;
        // Hash the device index
        hash_combine(seed, key.device_index);
        // Hash the set elements
        for (auto const& elem : key.group)
        {
            hash_combine(seed, elem);
        }
        return seed;
    }

private:
    // Boost-style hash combiner.
    template <typename T>
    static void hash_combine(size_t& seed, T const& val)
    {
        seed ^= std::hash<T>()(val) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
    }
};

// RAII wrapper so the NVLS allocation is freed when the wrapper is destroyed.
class IpcNvlsHandleWrapper
{
public:
    IpcNvlsHandleWrapper(size_t size, std::set<int> groups)
        : mSize(size)
    {
        mHandle = tensorrt_llm::runtime::ipcNvlsAllocate(size, groups);
    }

    tensorrt_llm::runtime::IpcNvlsHandle* getHandle() const
    {
        return mHandle;
    }

    size_t getSize() const
    {
        return mSize;
    }

    ~IpcNvlsHandleWrapper()
    {
        tensorrt_llm::runtime::ipcNvlsFree(mHandle);
    }

private:
    size_t mSize;
    tensorrt_llm::runtime::IpcNvlsHandle* mHandle;
};

std::once_flag init_flag;

size_t getPreferredWorkspaceSize()
{
    // Default: 128 MiB. The environment variable can only raise this, never lower it.
    static size_t preferredWorkspaceSize = 134217728;
    std::call_once(init_flag,
        []()
        {
            char const* envWorkspaceSize = std::getenv("TRTLLM_GEMM_ALLREDUCE_WORKSPACE_SIZE");
            size_t workspaceSize = 0;
            if (envWorkspaceSize != nullptr)
            {
                // strtoull instead of atoi: sizes above 2 GiB would overflow an int.
                workspaceSize = std::strtoull(envWorkspaceSize, nullptr, 10);
            }
            preferredWorkspaceSize = std::max(preferredWorkspaceSize, workspaceSize);
        });
    return preferredWorkspaceSize;
}
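// Illustrative only: a larger workspace can be requested before process start,
// e.g. TRTLLM_GEMM_ALLREDUCE_WORKSPACE_SIZE=268435456 for 256 MiB. The value is
// read once per process (std::call_once above), so setting it after the first
// GEMM+AllReduce call has no effect.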
class GemmAllreduceNvlsMemoryManager
{
public:
    GemmAllreduceNvlsMemoryManager()
    {
        TLLM_LOG_DEBUG("GemmAllreduceNvlsMemoryManager constructor");
    }

    ~GemmAllreduceNvlsMemoryManager()
    {
        TLLM_LOG_DEBUG("GemmAllreduceNvlsMemoryManager destructor");
    }

    std::pair<PersistentWorkspaceInterface*, tensorrt_llm::runtime::IpcNvlsHandle*> getWorkspace(
        GemmAllReduceImplInterface* runner, GemmAllReduceImplInterface::ProblemArgs const& problem,
        AllocationKey const& key)
    {
        int M = std::get<0>(problem.problem_size);
        int N = std::get<1>(problem.problem_size);
        // Output D is half or bfloat16, i.e. 2 bytes per element.
        size_t requiredSize = static_cast<size_t>(M) * N * 2;
        size_t preferredWorkspaceSize = getPreferredWorkspaceSize();
        if (requiredSize > preferredWorkspaceSize)
        {
            std::stringstream ss;
            ss << "Please set TRTLLM_GEMM_ALLREDUCE_WORKSPACE_SIZE to at least " << requiredSize << " bytes";
            TLLM_THROW("%s", ss.str().c_str());
        }

        auto handle = mHandles[key];
        if (handle == nullptr)
        {
            TLLM_LOG_DEBUG("Creating allreduce workspace for %s", key.toString().c_str());
            handle = std::make_shared<IpcNvlsHandleWrapper>(preferredWorkspaceSize, key.group);

            // Size the persistent workspace for the largest problem that fits the buffer.
            GemmAllReduceImplInterface::ProblemArgs tmpArgs;
            int maxN = 16384;
            int maxM = preferredWorkspaceSize / (maxN * 2);
            tmpArgs.argProblemShape(maxM, maxN, 512, 1)
                .argRanks(problem.rank, problem.ranks)
                .argLaunchConfig(runner->getSupportedLaunchConfigs()[0]);
            auto workspace = runner->getPersistentWorkspace(tmpArgs);
            workspace->allocate();
            mWorkspaces[key] = workspace;
            mHandles[key] = handle;
        }
        return std::make_pair(mWorkspaces[key].get(), mHandles[key]->getHandle());
    }

private:
    std::unordered_map<AllocationKey, std::shared_ptr<PersistentWorkspaceInterface>, AllocationKeyHash> mWorkspaces;
    std::unordered_map<AllocationKey, std::shared_ptr<IpcNvlsHandleWrapper>, AllocationKeyHash> mHandles;
};

GemmAllreduceNvlsMemoryManager* getGemmAllreduceNvlsMemoryManager()
{
    static GemmAllreduceNvlsMemoryManager gNvlsMemoryManager;
    return &gNvlsMemoryManager;
}

at::Tensor runGemmImpl(GemmAllReduceImplInterface* runner, GemmAllReduceImplInterface::ProblemArgs& problem,
    at::ScalarType outputDtype, c10::cuda::CUDAStream stream)
{
    AllocationKey key{stream.device_index(), problem.ranks};
    auto [workspace, handle] = getGemmAllreduceNvlsMemoryManager()->getWorkspace(runner, problem, key);

    // The kernel writes D through the NVLS unicast/multicast views; the reduced
    // result is then copied out of the shared buffer into a fresh output tensor.
    problem.argD((void*) handle->uc_ptr, (void*) handle->mc_ptr, (void**) handle->ipc_uc_ptrs.data());
    problem.argWorkspace(workspace);
    runner->run(problem, stream);

    size_t dSize = static_cast<size_t>(std::get<0>(problem.problem_size)) * std::get<1>(problem.problem_size)
        * c10::elementSize(outputDtype);
    auto D = at::detail::empty_cuda({std::get<0>(problem.problem_size), std::get<1>(problem.problem_size)},
        outputDtype, stream.device(), std::nullopt);
    TLLM_CUDA_CHECK(cudaMemcpyAsync(
        D.data_ptr(), reinterpret_cast<void*>(handle->uc_ptr), dSize, cudaMemcpyDeviceToDevice, stream));
    return D;
}

} // namespace
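// Design note: NVLS buffers and persistent workspaces are cached per
// (device, rank-group) key and live until process teardown, so only the first
// call for a given group pays the multicast allocation cost; later calls reuse
// the same buffer regardless of problem shape.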
namespace torch_ext
{

class Fp4GemmAllreduceRunner : public torch::CustomClassHolder
{
public:
    explicit Fp4GemmAllreduceRunner(at::ScalarType outputDtype, int64_t rank, torch::List<int64_t> group)
        : mOutputDtype(outputDtype)
        , mRank(rank)
    {
        for (int64_t r : group)
        {
            mGroup.insert(static_cast<int>(r));
        }

        if (outputDtype == at::ScalarType::Half)
        {
            // Assumed GemmTypes arguments: FP4 (e2m1) A/B operands with half-precision C/D.
            using Traits = GemmTypes<cutlass::float_e2m1_t, cutlass::float_e2m1_t, cutlass::half_t, cutlass::half_t>;
            mRunner = std::make_shared<GemmAllReduceImplRunner<Traits>>();
        }
        else if (outputDtype == at::ScalarType::BFloat16)
        {
            // Assumed GemmTypes arguments: FP4 (e2m1) A/B operands with bfloat16 C/D.
            using Traits
                = GemmTypes<cutlass::float_e2m1_t, cutlass::float_e2m1_t, cutlass::bfloat16_t, cutlass::bfloat16_t>;
            mRunner = std::make_shared<GemmAllReduceImplRunner<Traits>>();
        }
        else
        {
            TLLM_THROW("Unsupported output dtype: %s", torch::toString(outputDtype));
        }
        mConfigs = mRunner->getSupportedLaunchConfigs();
    }

    at::Tensor runGemm(at::Tensor const& mat1, at::Tensor const& mat2, at::Tensor const& mat1Scale,
        at::Tensor const& mat2Scale, at::Tensor const& alpha, int64_t configIdx) const
    {
        if (configIdx < 0)
        {
            configIdx = 0;
        }
        TORCH_CHECK(configIdx < int64_t(mConfigs.size()), "configIdx out of bounds");

        const int64_t M = mat1.size(0);
        const int64_t N = mat2.size(0);
        // Two FP4 values are packed per byte, so the logical K is twice the stored width.
        const int64_t K = mat1.size(1) * 2;

        GemmAllReduceImplInterface::ProblemArgs problemArgs;
        problemArgs.argProblemShape(M, N, K, 1);
        problemArgs.argA(mat1.data_ptr());
        problemArgs.argB(mat2.data_ptr());
        problemArgs.argAScale(mat1Scale.data_ptr());
        problemArgs.argBScale(mat2Scale.data_ptr());
        problemArgs.argC(nullptr);
        problemArgs.argAlphaPtr(reinterpret_cast<float const*>(alpha.const_data_ptr()));
        problemArgs.argBeta(0.f);
        problemArgs.argRanks(mRank, mGroup);
        problemArgs.argLaunchConfig(mConfigs[configIdx]);

        auto stream = at::cuda::getCurrentCUDAStream(mat1.get_device());
        return runGemmImpl(mRunner.get(), problemArgs, mOutputDtype, stream);
    }

    int64_t getNumConfigs() const
    {
        return static_cast<int64_t>(mConfigs.size());
    }

private:
    at::ScalarType mOutputDtype;
    int mRank;
    std::set<int> mGroup;
    std::shared_ptr<GemmAllReduceImplInterface> mRunner{nullptr};
    std::vector<GemmAllReduceImplInterface::LaunchConfig> mConfigs;
};

} // namespace torch_ext

TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
    m.class_<torch_ext::Fp4GemmAllreduceRunner>("Fp4GemmAllreduceRunner")
        .def(torch::init<at::ScalarType, int64_t, torch::List<int64_t>>())
        .def("run_gemm", &torch_ext::Fp4GemmAllreduceRunner::runGemm)
        .def("get_num_configs", &torch_ext::Fp4GemmAllreduceRunner::getNumConfigs);
}
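// Illustrative usage sketch (assumes a two-rank NVLink group and pre-quantized
// FP4 inputs; the tensor names are placeholders, not part of this file):
//
//   torch::List<int64_t> group({0, 1});
//   auto runner = c10::make_intrusive<torch_ext::Fp4GemmAllreduceRunner>(
//       at::ScalarType::Half, /*rank=*/0, group);
//   // configIdx < 0 falls back to config 0; to autotune, time each of
//   // runner->getNumConfigs() configs and keep the fastest index.
//   at::Tensor out = runner->runGemm(mat1, mat2, mat1Scale, mat2Scale, alpha, /*configIdx=*/-1);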