Mirror of https://github.com/NVIDIA/TensorRT-LLM.git
Fix TMA error with GEMM+AR on TP=2 (#6071)
Signed-off-by: Xavier Simmons <xsimmons@nvidia.com>
parent ac0b3f8b66
commit 7b818de700
@@ -221,9 +221,6 @@ public:
         {
             MPI_group_barrier(_ranks);
         }
-
-        TLLM_CUDA_CHECK(cudaStreamCreate(&_memcpy_stream));
-        TLLM_CUDA_CHECK(cudaEventCreate(&_fork_join_event));
     }
 
     int free() override
@@ -267,8 +264,6 @@ public:
    DeviceAllocationNvls<BarrierT> _tile_barriers;
    DeviceAllocationNvls<BarrierT> _completion_barriers;
    DeviceAllocationNvls<ElementD> _stage_buf;
-   cudaStream_t _memcpy_stream;
-   cudaEvent_t _fork_join_event;
};

GemmAllReduceImplTwoshot_Sm100()
@@ -186,9 +186,6 @@ public:
         {
             MPI_group_barrier(_ranks);
         }
-
-        TLLM_CUDA_CHECK(cudaStreamCreate(&_memcpy_stream));
-        TLLM_CUDA_CHECK(cudaEventCreate(&_fork_join_event));
     }
 
     int free() override
@@ -232,8 +229,6 @@ public:
    DeviceAllocationNvls<BarrierT> _tile_barriers;
    DeviceAllocationNvls<BarrierT> _completion_barriers;
    DeviceAllocationNvls<ElementD> _stage_buf;
-   cudaStream_t _memcpy_stream;
-   cudaEvent_t _fork_join_event;
};

GemmAllReduceImplTwoshot_Sm90()
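Note: the removed `_memcpy_stream`/`_fork_join_event` members look like the standard CUDA fork/join idiom, where auxiliary work runs on a side stream and is joined back into the main stream through an event. The sketch below only illustrates that general pattern for context; the function and variable names are placeholders (not the deleted code), and error checking is omitted.

```cpp
#include <cuda_runtime.h>

// Generic fork/join between a main stream and a side copy stream (illustrative only).
void forkJoinCopy(void* dst, void const* src, size_t bytes, cudaStream_t mainStream)
{
    cudaStream_t copyStream;
    cudaEvent_t forkJoinEvent;
    cudaStreamCreate(&copyStream);
    cudaEventCreateWithFlags(&forkJoinEvent, cudaEventDisableTiming);

    // Fork: the side stream waits until mainStream reaches this point.
    cudaEventRecord(forkJoinEvent, mainStream);
    cudaStreamWaitEvent(copyStream, forkJoinEvent, 0);

    // Side-stream work (here just an async device-to-device copy).
    cudaMemcpyAsync(dst, src, bytes, cudaMemcpyDeviceToDevice, copyStream);

    // Join: mainStream waits for the side stream to finish.
    cudaEventRecord(forkJoinEvent, copyStream);
    cudaStreamWaitEvent(mainStream, forkJoinEvent, 0);

    cudaEventDestroy(forkJoinEvent);
    cudaStreamDestroy(copyStream);
}
```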
@@ -201,7 +201,7 @@ public:
        auto [M, N, K, L] = problem_shape;
        auto [m, n, k, l] = tile_coord;

-       if (!tile_valid(m, n) || params_ptr->world_size <= 1)
+       if (!tile_valid(m, n) || params_ptr->world_size <= 2)
        {
            return; // nothing to do
        }
@@ -108,6 +108,8 @@ void GemmAllReducePlugin::allocatePersistentWorkspace()
 {
     TLLM_CHECK(mOptions.maxProblemShape.isInitialized());
 
+    mWorkspaceKey = "gemm_allreduce_workspace_m" + std::to_string(mOptions.maxProblemShape.maxM);
+
     cutlass_kernels::GemmAllReduceImplInterface::LaunchConfig smallest_tile_config
         = mGemm->getSupportedLaunchConfigs()[0];
     cutlass_kernels::GemmAllReduceImplInterface::ProblemArgs args;
@@ -123,7 +125,7 @@ void GemmAllReducePlugin::allocatePersistentWorkspace()
 
     // Register and allocate workspace
     mWorkspace = static_cast<GemmAllReducePersistentWorkspace*>(
-        getPluginRegistry()->acquirePluginResource(mWorkspaceKey, &unallocated_resource));
+        getPluginRegistry()->acquirePluginResource(mWorkspaceKey.c_str(), &unallocated_resource));
     TLLM_CHECK(mWorkspace != nullptr);
 }
 
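Note: switching `mWorkspaceKey` from a fixed string literal to a `std::string` built from `maxM` presumably keeps plugins with different maximum problem shapes from sharing a single cached workspace under one registry key. A minimal sketch of that keyed-caching idea, using a plain map as a hypothetical stand-in for TensorRT's plugin resource registry:

```cpp
#include <map>
#include <memory>
#include <string>

// Hypothetical stand-in for a keyed resource registry: one workspace per key,
// so including maxM in the key separates allocations sized for different shapes.
struct Workspace
{
    size_t bytes;
};

std::shared_ptr<Workspace> acquireWorkspace(int maxM, size_t bytesNeeded)
{
    static std::map<std::string, std::shared_ptr<Workspace>> registry;

    std::string key = "gemm_allreduce_workspace_m" + std::to_string(maxM);
    auto& slot = registry[key];
    if (!slot)
    {
        slot = std::make_shared<Workspace>(Workspace{bytesNeeded}); // first caller allocates
    }
    return slot; // later callers with the same maxM reuse the same workspace
}
```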
@@ -395,6 +397,7 @@ int GemmAllReducePlugin::enqueue(PluginTensorDesc const* inputDesc, PluginTensor
     auto const N = utils::computeNDimension(mOptions.transB, inputDesc[1].dims);
     auto const K = mOptions.transA ? inputDesc[0].dims.d[0] : inputDesc[0].dims.d[nbDimsA - 1];
 
+    TLLM_CHECK_WITH_INFO(M <= mOptions.maxProblemShape.maxM, "GemmAllReducePlugin M > maxM.");
     TLLM_CHECK_WITH_INFO(M > 0, "GemmAllReducePlugin M is 0.");
     TLLM_CHECK_WITH_INFO(N > 0, "GemmAllReducePlugin N is 0.");
     TLLM_CHECK_WITH_INFO(K > 0, "GemmAllReducePlugin K is 0.");
@@ -513,7 +516,7 @@ void GemmAllReducePlugin::terminate() noexcept
     // free mWorkspace
     if (mWorkspace)
     {
-        getPluginRegistry()->releasePluginResource(mWorkspaceKey);
+        getPluginRegistry()->releasePluginResource(mWorkspaceKey.c_str());
         mWorkspace = nullptr;
     }
 }
@@ -154,7 +154,7 @@ private:
     int mNbOutputs = 0;
 
     std::map<KeyType, ValueType> mTypedInstantiators;
-    char const* mWorkspaceKey = "gemm_allreduce_workspace";
+    std::string mWorkspaceKey;
     std::shared_ptr<cutlass_kernels::GemmAllReduceImplInterface> mGemm;
     // Params that are initialized during configurePlugin()
     GemmAllReducePersistentWorkspace* mWorkspace = nullptr;
@@ -60,8 +60,9 @@ void GemmAllReducePluginProfiler::deserializeFromOwnFile(GemmIdCore gemmId, Gemm
 
 bool GemmAllReducePluginProfiler::useProfiler()
 {
-    char const* envDir = getenv("GEMM_AR_PLUGIN_PROFILE_DIR");
-    return envDir != nullptr;
+    // char const* envDir = getenv("GEMM_AR_PLUGIN_PROFILE_DIR");
+    // return envDir != nullptr;
+    return false;
 }
 
 std::string GemmAllReducePluginProfiler::getCacheFileName(GemmIdCore gemmId)
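Note: before this change, `useProfiler()` was gated on the `GEMM_AR_PLUGIN_PROFILE_DIR` environment variable; it now unconditionally returns false, with the old gate left commented out. For reference, that environment-variable gate amounts to the following (the free-function name here is hypothetical):

```cpp
#include <cstdlib>

// Returns true only when GEMM_AR_PLUGIN_PROFILE_DIR is set in the environment.
bool profilingDirRequested()
{
    char const* envDir = std::getenv("GEMM_AR_PLUGIN_PROFILE_DIR");
    return envDir != nullptr;
}
```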
@@ -295,6 +295,7 @@ public:
         // Clean up
         MPI_Group_free(&new_group);
         MPI_Group_free(&world_group);
+        MPI_Comm_free(&new_comm);
 
         return nvls_handle;
     }
@@ -401,14 +402,14 @@ void MPI_group_barrier(std::set<int> group)
     MPI_Comm new_comm;
 
     // Get the group of the world communicator
-    MPI_Comm_group(MPI_COMM_WORLD, &world_group);
+    MPI_Comm_group(COMM_SESSION, &world_group);
 
     // Create a new group containing only the ranks we want
     std::vector<int> ranks(group.begin(), group.end());
     MPI_Group_incl(world_group, ranks.size(), ranks.data(), &new_group);
 
     // Create a new communicator from the group
-    MPI_Comm_create_group(MPI_COMM_WORLD, new_group, 0, &new_comm);
+    MPI_Comm_create_group(COMM_SESSION, new_group, 0, &new_comm);
 
     // Use the new communicator for the barrier
     MPI_Barrier(new_comm);
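Note: `MPI_group_barrier` synchronizes only a subset of ranks by deriving a temporary communicator from the parent communicator's group (now `COMM_SESSION`, TensorRT-LLM's session communicator, rather than `MPI_COMM_WORLD`). A self-contained sketch of the same MPI sequence, including the handle cleanup this commit also adds in `ipcNvlsAllocate`; `MPI_COMM_WORLD` is used here only to keep the sketch standalone:

```cpp
#include <mpi.h>
#include <set>
#include <vector>

// Barrier over a subset of ranks. Every rank in `group` (and only those ranks)
// must call this, since MPI_Comm_create_group is collective over the group.
void groupBarrier(std::set<int> const& group)
{
    MPI_Group worldGroup;
    MPI_Group newGroup;
    MPI_Comm newComm;

    MPI_Comm_group(MPI_COMM_WORLD, &worldGroup);

    std::vector<int> ranks(group.begin(), group.end());
    MPI_Group_incl(worldGroup, static_cast<int>(ranks.size()), ranks.data(), &newGroup);

    MPI_Comm_create_group(MPI_COMM_WORLD, newGroup, /*tag=*/0, &newComm);

    MPI_Barrier(newComm);

    // Free the group and communicator handles so repeated calls do not leak them.
    MPI_Group_free(&newGroup);
    MPI_Group_free(&worldGroup);
    MPI_Comm_free(&newComm);
}
```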
@@ -510,6 +511,8 @@ IpcNvlsHandle* ipcNvlsAllocate(size_t size, std::set<int> group)
 
     MPI_Barrier(new_comm);
 
+    MPI_Comm_free(&new_comm);
+
     return handle;
 #else
     TLLM_THROW("ipcNvlsAllocate needs to be compiled with ENABLE_MULTI_DEVICE");