From 7b818de700822124c5b1a309d9e10822878c71de Mon Sep 17 00:00:00 2001
From: xavier-nvidia
Date: Wed, 16 Jul 2025 10:27:32 -0700
Subject: [PATCH] Fix TMA error with GEMM+AR on TP=2 (#6071)

Signed-off-by: Xavier Simmons
---
 .../allreduce_gemm/allreduce_gemm_impl_sm100.h             | 5 -----
 .../allreduce_gemm/allreduce_gemm_impl_sm90.h              | 5 -----
 .../communication/sm90_allreduce_nvls_warpspecialized.hpp  | 2 +-
 .../plugins/gemmAllReducePlugin/gemmAllReducePlugin.cpp    | 7 +++++--
 .../plugins/gemmAllReducePlugin/gemmAllReducePlugin.h      | 2 +-
 .../gemmAllReducePlugin/gemmAllReducePluginProfiler.cpp    | 5 +++--
 cpp/tensorrt_llm/runtime/ipcNvlsMemory.cu                  | 7 +++++--
 7 files changed, 15 insertions(+), 18 deletions(-)

diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/allreduce_gemm/allreduce_gemm_impl_sm100.h b/cpp/tensorrt_llm/kernels/cutlass_kernels/allreduce_gemm/allreduce_gemm_impl_sm100.h
index ed18541d0a..a4be82607a 100644
--- a/cpp/tensorrt_llm/kernels/cutlass_kernels/allreduce_gemm/allreduce_gemm_impl_sm100.h
+++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/allreduce_gemm/allreduce_gemm_impl_sm100.h
@@ -221,9 +221,6 @@ public:
             {
                 MPI_group_barrier(_ranks);
             }
-
-            TLLM_CUDA_CHECK(cudaStreamCreate(&_memcpy_stream));
-            TLLM_CUDA_CHECK(cudaEventCreate(&_fork_join_event));
         }
 
         int free() override
@@ -267,8 +264,6 @@ public:
         DeviceAllocationNvls _tile_barriers;
         DeviceAllocationNvls _completion_barriers;
         DeviceAllocationNvls _stage_buf;
-        cudaStream_t _memcpy_stream;
-        cudaEvent_t _fork_join_event;
     };
 
     GemmAllReduceImplTwoshot_Sm100()
diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/allreduce_gemm/allreduce_gemm_impl_sm90.h b/cpp/tensorrt_llm/kernels/cutlass_kernels/allreduce_gemm/allreduce_gemm_impl_sm90.h
index ab867b69a8..fb446b451d 100644
--- a/cpp/tensorrt_llm/kernels/cutlass_kernels/allreduce_gemm/allreduce_gemm_impl_sm90.h
+++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/allreduce_gemm/allreduce_gemm_impl_sm90.h
@@ -186,9 +186,6 @@ public:
             {
                 MPI_group_barrier(_ranks);
             }
-
-            TLLM_CUDA_CHECK(cudaStreamCreate(&_memcpy_stream));
-            TLLM_CUDA_CHECK(cudaEventCreate(&_fork_join_event));
         }
 
         int free() override
@@ -232,8 +229,6 @@ public:
         DeviceAllocationNvls _tile_barriers;
         DeviceAllocationNvls _completion_barriers;
        DeviceAllocationNvls _stage_buf;
-        cudaStream_t _memcpy_stream;
-        cudaEvent_t _fork_join_event;
     };
 
     GemmAllReduceImplTwoshot_Sm90()
diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/allreduce_gemm/communication/sm90_allreduce_nvls_warpspecialized.hpp b/cpp/tensorrt_llm/kernels/cutlass_kernels/allreduce_gemm/communication/sm90_allreduce_nvls_warpspecialized.hpp
index 61be032b62..b126beebfe 100644
--- a/cpp/tensorrt_llm/kernels/cutlass_kernels/allreduce_gemm/communication/sm90_allreduce_nvls_warpspecialized.hpp
+++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/allreduce_gemm/communication/sm90_allreduce_nvls_warpspecialized.hpp
@@ -201,7 +201,7 @@ public:
         auto [M, N, K, L] = problem_shape;
         auto [m, n, k, l] = tile_coord;
 
-        if (!tile_valid(m, n) || params_ptr->world_size <= 2)
+        if (!tile_valid(m, n) || params_ptr->world_size <= 1)
         {
             return; // nothing to do
         }
diff --git a/cpp/tensorrt_llm/plugins/gemmAllReducePlugin/gemmAllReducePlugin.cpp b/cpp/tensorrt_llm/plugins/gemmAllReducePlugin/gemmAllReducePlugin.cpp
index 8d80827b90..4cec38b046 100644
--- a/cpp/tensorrt_llm/plugins/gemmAllReducePlugin/gemmAllReducePlugin.cpp
+++ b/cpp/tensorrt_llm/plugins/gemmAllReducePlugin/gemmAllReducePlugin.cpp
@@ -108,6 +108,8 @@ void GemmAllReducePlugin::allocatePersistentWorkspace()
 {
     TLLM_CHECK(mOptions.maxProblemShape.isInitialized());
+    mWorkspaceKey = "gemm_allreduce_workspace_m" + std::to_string(mOptions.maxProblemShape.maxM);
+
     cutlass_kernels::GemmAllReduceImplInterface::LaunchConfig smallest_tile_config
         = mGemm->getSupportedLaunchConfigs()[0];
 
     cutlass_kernels::GemmAllReduceImplInterface::ProblemArgs args;
@@ -123,7 +125,7 @@ void GemmAllReducePlugin::allocatePersistentWorkspace()
 
     // Register and allocate workspace
     mWorkspace = static_cast<GemmAllReducePersistentWorkspace*>(
-        getPluginRegistry()->acquirePluginResource(mWorkspaceKey, &unallocated_resource));
+        getPluginRegistry()->acquirePluginResource(mWorkspaceKey.c_str(), &unallocated_resource));
     TLLM_CHECK(mWorkspace != nullptr);
 }
 
@@ -395,6 +397,7 @@ int GemmAllReducePlugin::enqueue(PluginTensorDesc const* inputDesc, PluginTensor
     auto const N = utils::computeNDimension(mOptions.transB, inputDesc[1].dims);
     auto const K = mOptions.transA ? inputDesc[0].dims.d[0] : inputDesc[0].dims.d[nbDimsA - 1];
 
+    TLLM_CHECK_WITH_INFO(M <= mOptions.maxProblemShape.maxM, "GemmAllReducePlugin M > maxM.");
     TLLM_CHECK_WITH_INFO(M > 0, "GemmAllReducePlugin M is 0.");
     TLLM_CHECK_WITH_INFO(N > 0, "GemmAllReducePlugin N is 0.");
     TLLM_CHECK_WITH_INFO(K > 0, "GemmAllReducePlugin K is 0.");
@@ -513,7 +516,7 @@ void GemmAllReducePlugin::terminate() noexcept
     // free mWorkspace
     if (mWorkspace)
     {
-        getPluginRegistry()->releasePluginResource(mWorkspaceKey);
+        getPluginRegistry()->releasePluginResource(mWorkspaceKey.c_str());
         mWorkspace = nullptr;
     }
 }
diff --git a/cpp/tensorrt_llm/plugins/gemmAllReducePlugin/gemmAllReducePlugin.h b/cpp/tensorrt_llm/plugins/gemmAllReducePlugin/gemmAllReducePlugin.h
index 4cd2a77a5c..4579262460 100644
--- a/cpp/tensorrt_llm/plugins/gemmAllReducePlugin/gemmAllReducePlugin.h
+++ b/cpp/tensorrt_llm/plugins/gemmAllReducePlugin/gemmAllReducePlugin.h
@@ -154,7 +154,7 @@ private:
     int mNbOutputs = 0;
     std::map mTypedInstantiators;
 
-    char const* mWorkspaceKey = "gemm_allreduce_workspace";
+    std::string mWorkspaceKey;
     std::shared_ptr mGemm;
     // Params that are initialized during configurePlugin()
     GemmAllReducePersistentWorkspace* mWorkspace = nullptr;
diff --git a/cpp/tensorrt_llm/plugins/gemmAllReducePlugin/gemmAllReducePluginProfiler.cpp b/cpp/tensorrt_llm/plugins/gemmAllReducePlugin/gemmAllReducePluginProfiler.cpp
index d6e0f3b8ac..3ce75b77ce 100644
--- a/cpp/tensorrt_llm/plugins/gemmAllReducePlugin/gemmAllReducePluginProfiler.cpp
+++ b/cpp/tensorrt_llm/plugins/gemmAllReducePlugin/gemmAllReducePluginProfiler.cpp
@@ -60,8 +60,9 @@ void GemmAllReducePluginProfiler::deserializeFromOwnFile(GemmIdCore gemmId, Gemm
 
 bool GemmAllReducePluginProfiler::useProfiler()
 {
-    char const* envDir = getenv("GEMM_AR_PLUGIN_PROFILE_DIR");
-    return envDir != nullptr;
+    // char const* envDir = getenv("GEMM_AR_PLUGIN_PROFILE_DIR");
+    // return envDir != nullptr;
+    return false;
 }
 
 std::string GemmAllReducePluginProfiler::getCacheFileName(GemmIdCore gemmId)
diff --git a/cpp/tensorrt_llm/runtime/ipcNvlsMemory.cu b/cpp/tensorrt_llm/runtime/ipcNvlsMemory.cu
index c685966148..031ac92168 100644
--- a/cpp/tensorrt_llm/runtime/ipcNvlsMemory.cu
+++ b/cpp/tensorrt_llm/runtime/ipcNvlsMemory.cu
@@ -295,6 +295,7 @@ public:
         // Clean up
         MPI_Group_free(&new_group);
         MPI_Group_free(&world_group);
+        MPI_Comm_free(&new_comm);
 
         return nvls_handle;
     }
@@ -401,14 +402,14 @@ void MPI_group_barrier(std::set<int> group)
     MPI_Comm new_comm;
 
     // Get the group of the world communicator
-    MPI_Comm_group(MPI_COMM_WORLD, &world_group);
+    MPI_Comm_group(COMM_SESSION, &world_group);
 
     // Create a new group containing only the ranks we want
     std::vector<int> ranks(group.begin(), group.end());
     MPI_Group_incl(world_group, ranks.size(), ranks.data(), &new_group);
 
     // Create a new communicator from the group
-    MPI_Comm_create_group(MPI_COMM_WORLD, new_group, 0, &new_comm);
+    MPI_Comm_create_group(COMM_SESSION, new_group, 0, &new_comm);
 
     // Use the new communicator for the barrier
     MPI_Barrier(new_comm);
@@ -510,6 +511,8 @@ IpcNvlsHandle* ipcNvlsAllocate(size_t size, std::set<int> group)
 
     MPI_Barrier(new_comm);
 
+    MPI_Comm_free(&new_comm);
+
     return handle;
 #else
     TLLM_THROW("ipcNvlsAllocate needs to be compiled with ENABLE_MULTI_DEVICE");