[TRTLLM-4647][fix] Fix the no fusion allreduce hanging (#4594)

Signed-off-by: Shiyu Li <shili@nvidia.com>
Shiyu Li 2025-06-04 18:26:13 -07:00 committed by GitHub
parent 8433091630
commit b0d287c9b7
6 changed files with 56 additions and 45 deletions

View File

@@ -80,6 +80,15 @@ __global__ void twoshot_allreduce_kernel(T* output_ptr, T* shard_ptr, T** input_
uint32_t input_offset = buffer_flags[0] * buffer_size;
uint32_t clear_offset = buffer_flags[1] * buffer_size;
if (wait_for_results)
{
__syncthreads();
if (threadIdx.x == 0)
{
atomicAdd(offset_access_ptr, 1);
}
}
if (elt < token_dim)
{
// Scatter token
@@ -312,13 +321,14 @@ __global__ void __launch_bounds__(128, 1)
int offsets[NUM_INPUTS][DIM / (1 * ELTS_PER_THREAD * NUM_THREADS)];
cudaTriggerProgrammaticLaunchCompletion();
uint32_t* offset_access_ptr = &buffer_flags[3];
// Buffer size is M * N, and we need two buffers for reduce-scatter and allgather
uint32_t buffer_size = buffer_flags[2];
uint32_t buffer_offset = buffer_flags[0] * (buffer_size << 1);
T_IN const* input = &buffer_input[buffer_offset + buffer_size];
cudaTriggerProgrammaticLaunchCompletion();
__syncthreads();
if (threadIdx.x == 0)
{

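Note: the kernels above index buffer_flags at four fixed slots: [0] selects the buffer currently used as input, [1] selects the buffer being cleared for the next iteration, [2] holds the per-buffer size, and [3] is the counter each block bumps with atomicAdd once wait_for_results is set. A minimal host-side sketch of that layout follows; the helper name and dtype choice are illustrative assumptions, not part of this commit.

import torch

def make_buffer_flags(buffer_size: int, device: str = "cuda") -> torch.Tensor:
    # Illustrative layout inferred from the kernel reads above (assumption):
    #   [0] index of the buffer currently used as input
    #   [1] index of the buffer being cleared for the next iteration
    #   [2] per-buffer size in elements (M * N)
    #   [3] counter bumped via atomicAdd(offset_access_ptr, 1) when wait_for_results is set
    flags = torch.zeros(4, dtype=torch.int32, device=device)  # kernel reinterprets this as uint32_t
    flags[1] = 1            # start by clearing the second buffer
    flags[2] = buffer_size  # elements per buffer
    return flags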
View File

@@ -30,19 +30,6 @@ namespace tensorrt_llm::runtime
namespace
{
#define CUCHECK(cmd) \
do \
{ \
CUresult retval = cmd; \
if (retval != CUDA_SUCCESS) \
{ \
const char* error_string; \
cuGetErrorString(retval, &error_string); \
printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, error_string); \
exit(EXIT_FAILURE); \
} \
} while (0)
// An efficient implementation assuming gran is a power of 2
inline size_t roundUp(size_t val, size_t gran)
{
@@ -66,7 +53,7 @@ McastDeviceMemory::McastDeviceMemory(
cudaSetDevice(mDeviceIdx);
// Check if the device supports multicasting
int multicast_supported{0};
CUCHECK(cuDeviceGetAttribute(&multicast_supported, CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED, mDeviceIdx));
TLLM_CU_CHECK(cuDeviceGetAttribute(&multicast_supported, CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED, mDeviceIdx));
if (multicast_supported == 0)
{
TLLM_THROW("[McastDeviceMemory] Device does not support multicasting.");
@@ -83,7 +70,7 @@ McastDeviceMemory::McastDeviceMemory(
{
// For multi-node, we also need to check if fabric handle is supported
int fabric_handle_supported{0};
CUCHECK(cuDeviceGetAttribute(
TLLM_CU_CHECK(cuDeviceGetAttribute(
&fabric_handle_supported, CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_FABRIC_SUPPORTED, mDeviceIdx));
if (fabric_handle_supported == 0)
{
@@ -142,12 +129,15 @@ void McastDeviceMemory::allocMnMcastMem(size_t bufSize)
prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
prop.location.id = mDeviceIdx;
prop.allocFlags.gpuDirectRDMACapable = 1;
size_t granularity{0};
TLLM_CU_CHECK(cuMemGetAllocationGranularity(&granularity, &prop, CU_MEM_ALLOC_GRANULARITY_MINIMUM));
size_t alloc_granularity{0}, mc_granularity{0};
TLLM_CU_CHECK(cuMemGetAllocationGranularity(&alloc_granularity, &prop, CU_MEM_ALLOC_GRANULARITY_MINIMUM));
// Round up the buffer size to the allocation granularity
mAllocationSize = roundUp(bufSize + kSIGNAL_PAD_SIZE, granularity);
mAllocationSize = roundUp(bufSize + kSIGNAL_PAD_SIZE, alloc_granularity);
CUmulticastObjectProp mcProp = {.numDevices = mGroupSize, .size = mAllocationSize, .handleTypes = handle_type};
TLLM_CU_CHECK(cuMulticastGetGranularity(&mc_granularity, &mcProp, CU_MULTICAST_GRANULARITY_RECOMMENDED));
mAllocationSize = roundUp(mAllocationSize, mc_granularity);
mUcHandles.resize(mGroupSize);
// Allocates local gpu memory
TLLM_CU_CHECK(cuMemCreate(&(mUcHandles[mGroupRank]), mAllocationSize, &prop, 0));
@@ -170,8 +160,6 @@ void McastDeviceMemory::allocMnMcastMem(size_t bufSize)
cudaFreeHost(exphndl);
// Initialize multicasting
CUmulticastObjectProp mcProp
= {.numDevices = mGroupSize, .size = mAllocationSize, .handleTypes = CU_MEM_HANDLE_TYPE_FABRIC};
CUmemFabricHandle* fabric_handle;
cudaMallocHost(&fabric_handle, sizeof(CUmemFabricHandle));
if (mGroupRank == 0)
@@ -192,7 +180,7 @@ void McastDeviceMemory::allocMnMcastMem(size_t bufSize)
// Bind memory addresses
mUcPtrs.resize(mGroupSize);
CUdeviceptr ptr;
TLLM_CU_CHECK(cuMemAddressReserve(&ptr, mAllocationSize * mGroupSize, 0, 0, 0));
TLLM_CU_CHECK(cuMemAddressReserve(&ptr, mAllocationSize * mGroupSize, mc_granularity, 0ULL, 0));
CUmemAccessDesc accessDesc = {};
accessDesc.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
accessDesc.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
@@ -206,7 +194,7 @@ void McastDeviceMemory::allocMnMcastMem(size_t bufSize)
TLLM_CU_CHECK(cuMemSetAccess(ptr, mAllocationSize * mGroupSize, &accessDesc, 1));
// Bind MC Pointers
TLLM_CU_CHECK(cuMemAddressReserve(&mMcPtr, mAllocationSize, 0, 0U, 0));
TLLM_CU_CHECK(cuMemAddressReserve(&mMcPtr, mAllocationSize, mc_granularity, 0ULL, 0));
TLLM_CU_CHECK(cuMemMap(mMcPtr, mAllocationSize, 0, mMcHandle, 0));
TLLM_CU_CHECK(cuMemSetAccess(mMcPtr, mAllocationSize, &accessDesc, 1));

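Note: the sizing logic above now rounds the allocation twice, first to the physical allocation granularity and then to the multicast granularity, and that multicast granularity is also passed as the alignment argument to cuMemAddressReserve. A quick sketch of the rounding arithmetic (the granularity and pad values below are placeholders; the real ones come from cuMemGetAllocationGranularity and cuMulticastGetGranularity):

def round_up(val: int, gran: int) -> int:
    # Mirrors the C++ roundUp helper above; assumes gran is a power of two.
    return (val + gran - 1) & ~(gran - 1)

alloc_granularity = 2 * 1024 * 1024   # placeholder for CU_MEM_ALLOC_GRANULARITY_MINIMUM
mc_granularity = 32 * 1024 * 1024     # placeholder for CU_MULTICAST_GRANULARITY_RECOMMENDED
signal_pad = 2048                     # placeholder for kSIGNAL_PAD_SIZE
buf_size = 50_000_000

allocation_size = round_up(buf_size + signal_pad, alloc_granularity)  # physical granularity first
allocation_size = round_up(allocation_size, mc_granularity)           # then multicast granularity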
View File

@@ -1044,12 +1044,13 @@ std::vector<torch::Tensor> moe_allreduce(torch::Tensor const& residual, torch::T
}
at::Tensor mnnvlTwoShotAllReduce(
at::Tensor& output, at::Tensor& input, at::Tensor& comm_buffer, at::Tensor& buffer_flags, bool wait_for_results)
at::Tensor& input, at::Tensor& comm_buffer, at::Tensor& buffer_flags, bool wait_for_results)
{
auto* mcast_mem = tensorrt_llm::common::findMcastDevMemBuffer(comm_buffer.data_ptr());
TORCH_CHECK(mcast_mem != nullptr, "two_shot_all_reduce: comm_buffer must be obtained from a mcastBuffer instance.");
auto const dtype = tensorrt_llm::runtime::TorchUtils::dataType(input.scalar_type());
at::Tensor output = torch::empty_like(input);
auto allreduce_params = tensorrt_llm::kernels::mnnvl::AllReduceParams();
allreduce_params.dtype = dtype;
@@ -1071,35 +1072,42 @@ at::Tensor mnnvlTwoShotAllReduce(
return output;
}
void twoShotRMSNorm(torch::Tensor& prenorm_output, torch::Tensor& normed_output, torch::Tensor const& input,
torch::Tensor const& gamma, double epsilon, torch::Tensor const& residual, torch::Tensor& buffer_flags)
std::vector<torch::Tensor> twoShotRMSNorm(torch::Tensor const& comm_buf, torch::Tensor const& gamma, double epsilon,
torch::Tensor const& residual, torch::Tensor& buffer_flags)
{
auto const dtype = tensorrt_llm::runtime::TorchUtils::dataType(input.scalar_type());
auto const dtype = tensorrt_llm::runtime::TorchUtils::dataType(comm_buf.scalar_type());
auto rmsnorm_params = tensorrt_llm::kernels::mnnvl::RMSNormParams();
// The input is the communication buffer, so we need to get the shape from the residual
torch::Tensor normed_output = torch::empty_like(residual);
torch::Tensor prenorm_output = torch::empty_like(residual);
rmsnorm_params.dtype = dtype;
rmsnorm_params.residual_output = prenorm_output.data_ptr();
rmsnorm_params.output = normed_output.data_ptr();
rmsnorm_params.input = input.data_ptr();
rmsnorm_params.input = comm_buf.data_ptr();
rmsnorm_params.gamma = gamma.data_ptr();
rmsnorm_params.epsilon = epsilon;
rmsnorm_params.residual = residual.data_ptr();
rmsnorm_params.buffer_flags = reinterpret_cast<uint32_t*>(buffer_flags.data_ptr());
rmsnorm_params.batch = normed_output.size(0);
rmsnorm_params.hidden_dim = normed_output.size(1);
rmsnorm_params.stream = at::cuda::getCurrentCUDAStream(input.get_device());
rmsnorm_params.stream = at::cuda::getCurrentCUDAStream(comm_buf.get_device());
tensorrt_llm::kernels::mnnvl::twoshot_rmsnorm_op(rmsnorm_params);
return {normed_output, prenorm_output};
}
} // namespace torch_ext
TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
m.def(
"mnnvl_twoshot_allreduce(Tensor(output!) output, Tensor(input!) input, Tensor(comm_buf!) comm_buffer, "
"mnnvl_twoshot_allreduce(Tensor(input!) input, Tensor(comm_buf!) comm_buffer, "
"Tensor(buffer_flags!) buffer_flags, bool wait_for_result) -> Tensor");
m.def(
"mnnvl_twoshot_rmsnorm(Tensor prenorm_output, Tensor normed_output, Tensor input, Tensor gamma, "
"float epsilon, Tensor residual, Tensor buffer_flags) -> ()");
"mnnvl_twoshot_rmsnorm(Tensor comm_buf, Tensor gamma, "
"float epsilon, Tensor residual, Tensor buffer_flags) -> Tensor[]");
m.def(
"allreduce("
"Tensor input,"

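Note: with the rebinding above, both custom ops now allocate and return their outputs instead of writing into caller-provided tensors. A minimal call sketch against the new schemas follows; the shapes, dtypes, and epsilon are placeholders, and comm_buffer / buffer_flags must come from the mcastBuffer machinery rather than the stubs shown here.

import torch

hidden = 4096
x = torch.randn(8, hidden, dtype=torch.bfloat16, device="cuda")
gamma = torch.ones(hidden, dtype=torch.bfloat16, device="cuda")
residual = torch.randn_like(x)
comm_buffer = ...   # must be a view of the multicast communication buffer (not shown here)
buffer_flags = ...  # uint32 flag tensor managed alongside the comm buffer

# No-fusion path: the all-reduced result is now returned directly.
out = torch.ops.trtllm.mnnvl_twoshot_allreduce(x, comm_buffer, buffer_flags, True)

# Fused path: run the all-reduce without waiting, then let the RMSNorm kernel
# consume the comm buffer and return (normed_output, prenorm_output).
torch.ops.trtllm.mnnvl_twoshot_allreduce(x, comm_buffer, buffer_flags, False)
normed, prenorm = torch.ops.trtllm.mnnvl_twoshot_rmsnorm(
    comm_buffer, gamma, 1e-5, residual, buffer_flags)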
View File

@@ -54,6 +54,18 @@ def _register_fake():
else:
return [torch.empty_like(input)]
# MNNVL Allreduce
@torch.library.register_fake("trtllm::mnnvl_twoshot_allreduce")
def _(input, buffer, buffer_flags, wait_for_results):
output = input.new_empty(input.shape)
return output
@torch.library.register_fake("trtllm::mnnvl_twoshot_rmsnorm")
def _(comm_buf, gamma, eps, residual, buffer_flags):
output = residual.new_empty(residual.shape)
residual_out = residual.new_empty(residual.shape)
return [output, residual_out]
@torch.library.register_fake("trtllm::moe_allreduce")
def _(residual, norm_weight, device_num_experts, scale_input,
active_experts_token_input, token_input, workspace, rank, nranks,

View File

@@ -345,8 +345,7 @@ class MNNVLAllReduce(nn.Module):
buffer_mnnvl = self.buffer_mnnvl.view(3, 2, -1, shape[-1])
if fusion_op == AllReduceFusionOp.NONE:
torch.ops.trtllm.mnnvl_twoshot_allreduce(
output,
output = torch.ops.trtllm.mnnvl_twoshot_allreduce(
input,
buffer_mnnvl,
self.buffer_flags_mnnvl,
@@ -355,19 +354,16 @@ class MNNVLAllReduce(nn.Module):
return output.view(shape)
elif fusion_op == AllReduceFusionOp.RESIDUAL_RMS_NORM:
torch.ops.trtllm.mnnvl_twoshot_allreduce(
output,
input,
buffer_mnnvl,
self.buffer_flags_mnnvl,
False,
)
residual_in = all_reduce_params.residual
residual_out = torch.empty_like(input)
torch.ops.trtllm.mnnvl_twoshot_rmsnorm(
residual_out, output, buffer_mnnvl,
all_reduce_params.norm_weight, all_reduce_params.eps,
residual_in, self.buffer_flags_mnnvl)
output, residual_out = torch.ops.trtllm.mnnvl_twoshot_rmsnorm(
buffer_mnnvl, all_reduce_params.norm_weight,
all_reduce_params.eps, residual_in, self.buffer_flags_mnnvl)
return output.view(shape), residual_out.view(shape)
return None

View File

@@ -156,9 +156,6 @@ def row_linear_residual_norm_fusion_forward(
)
def test_row_linear_residual_norm_fusion(seq_len, hidden_size, fusion):
if not fusion:
pytest.skip("skip no fusion test")
torch.manual_seed(42)
dtype = torch.bfloat16
tensor_parallel_size = 2