Fix TMA error with GEMM+AR on TP=2 (#6071)

Signed-off-by: Xavier Simmons <xsimmons@nvidia.com>
This commit is contained in:
xavier-nvidia 2025-07-16 10:27:32 -07:00 committed by GitHub
parent ac0b3f8b66
commit 7b818de700
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 15 additions and 18 deletions

View File

@ -221,9 +221,6 @@ public:
{
MPI_group_barrier(_ranks);
}
TLLM_CUDA_CHECK(cudaStreamCreate(&_memcpy_stream));
TLLM_CUDA_CHECK(cudaEventCreate(&_fork_join_event));
}
int free() override
@ -267,8 +264,6 @@ public:
DeviceAllocationNvls<BarrierT> _tile_barriers;
DeviceAllocationNvls<BarrierT> _completion_barriers;
DeviceAllocationNvls<ElementD> _stage_buf;
cudaStream_t _memcpy_stream;
cudaEvent_t _fork_join_event;
};
GemmAllReduceImplTwoshot_Sm100()

View File

@ -186,9 +186,6 @@ public:
{
MPI_group_barrier(_ranks);
}
TLLM_CUDA_CHECK(cudaStreamCreate(&_memcpy_stream));
TLLM_CUDA_CHECK(cudaEventCreate(&_fork_join_event));
}
int free() override
@ -232,8 +229,6 @@ public:
DeviceAllocationNvls<BarrierT> _tile_barriers;
DeviceAllocationNvls<BarrierT> _completion_barriers;
DeviceAllocationNvls<ElementD> _stage_buf;
cudaStream_t _memcpy_stream;
cudaEvent_t _fork_join_event;
};
GemmAllReduceImplTwoshot_Sm90()

View File

@ -201,7 +201,7 @@ public:
auto [M, N, K, L] = problem_shape;
auto [m, n, k, l] = tile_coord;
if (!tile_valid(m, n) || params_ptr->world_size <= 1)
if (!tile_valid(m, n) || params_ptr->world_size <= 2)
{
return; // nothing to do
}

View File

@ -108,6 +108,8 @@ void GemmAllReducePlugin::allocatePersistentWorkspace()
{
TLLM_CHECK(mOptions.maxProblemShape.isInitialized());
mWorkspaceKey = "gemm_allreduce_workspace_m" + std::to_string(mOptions.maxProblemShape.maxM);
cutlass_kernels::GemmAllReduceImplInterface::LaunchConfig smallest_tile_config
= mGemm->getSupportedLaunchConfigs()[0];
cutlass_kernels::GemmAllReduceImplInterface::ProblemArgs args;
@ -123,7 +125,7 @@ void GemmAllReducePlugin::allocatePersistentWorkspace()
// Register and allocate workspace
mWorkspace = static_cast<GemmAllReducePersistentWorkspace*>(
getPluginRegistry()->acquirePluginResource(mWorkspaceKey, &unallocated_resource));
getPluginRegistry()->acquirePluginResource(mWorkspaceKey.c_str(), &unallocated_resource));
TLLM_CHECK(mWorkspace != nullptr);
}
@ -395,6 +397,7 @@ int GemmAllReducePlugin::enqueue(PluginTensorDesc const* inputDesc, PluginTensor
auto const N = utils::computeNDimension(mOptions.transB, inputDesc[1].dims);
auto const K = mOptions.transA ? inputDesc[0].dims.d[0] : inputDesc[0].dims.d[nbDimsA - 1];
TLLM_CHECK_WITH_INFO(M <= mOptions.maxProblemShape.maxM, "GemmAllReducePlugin M > maxM.");
TLLM_CHECK_WITH_INFO(M > 0, "GemmAllReducePlugin M is 0.");
TLLM_CHECK_WITH_INFO(N > 0, "GemmAllReducePlugin N is 0.");
TLLM_CHECK_WITH_INFO(K > 0, "GemmAllReducePlugin K is 0.");
@ -513,7 +516,7 @@ void GemmAllReducePlugin::terminate() noexcept
// free mWorkspace
if (mWorkspace)
{
getPluginRegistry()->releasePluginResource(mWorkspaceKey);
getPluginRegistry()->releasePluginResource(mWorkspaceKey.c_str());
mWorkspace = nullptr;
}
}

View File

@ -154,7 +154,7 @@ private:
int mNbOutputs = 0;
std::map<KeyType, ValueType> mTypedInstantiators;
char const* mWorkspaceKey = "gemm_allreduce_workspace";
std::string mWorkspaceKey;
std::shared_ptr<cutlass_kernels::GemmAllReduceImplInterface> mGemm;
// Params that are initialized during configurePlugin()
GemmAllReducePersistentWorkspace* mWorkspace = nullptr;

View File

@ -60,8 +60,9 @@ void GemmAllReducePluginProfiler::deserializeFromOwnFile(GemmIdCore gemmId, Gemm
bool GemmAllReducePluginProfiler::useProfiler()
{
char const* envDir = getenv("GEMM_AR_PLUGIN_PROFILE_DIR");
return envDir != nullptr;
// char const* envDir = getenv("GEMM_AR_PLUGIN_PROFILE_DIR");
// return envDir != nullptr;
return false;
}
std::string GemmAllReducePluginProfiler::getCacheFileName(GemmIdCore gemmId)

View File

@ -295,6 +295,7 @@ public:
// Clean up
MPI_Group_free(&new_group);
MPI_Group_free(&world_group);
MPI_Comm_free(&new_comm);
return nvls_handle;
}
@ -401,14 +402,14 @@ void MPI_group_barrier(std::set<int> group)
MPI_Comm new_comm;
// Get the group of the world communicator
MPI_Comm_group(MPI_COMM_WORLD, &world_group);
MPI_Comm_group(COMM_SESSION, &world_group);
// Create a new group containing only the ranks we want
std::vector<int> ranks(group.begin(), group.end());
MPI_Group_incl(world_group, ranks.size(), ranks.data(), &new_group);
// Create a new communicator from the group
MPI_Comm_create_group(MPI_COMM_WORLD, new_group, 0, &new_comm);
MPI_Comm_create_group(COMM_SESSION, new_group, 0, &new_comm);
// Use the new communicator for the barrier
MPI_Barrier(new_comm);
@ -510,6 +511,8 @@ IpcNvlsHandle* ipcNvlsAllocate(size_t size, std::set<int> group)
MPI_Barrier(new_comm);
MPI_Comm_free(&new_comm);
return handle;
#else
TLLM_THROW("ipcNvlsAllocate needs to be compiled with ENABLE_MULTI_DEVICE");