Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-14 06:27:45 +08:00)
[TRTLLM-4647][fix] Fix the no fusion allreduce hanging (#4594)
Signed-off-by: Shiyu Li <shili@nvidia.com>
parent 8433091630
commit b0d287c9b7
@@ -80,6 +80,15 @@ __global__ void twoshot_allreduce_kernel(T* output_ptr, T* shard_ptr, T** input_
     uint32_t input_offset = buffer_flags[0] * buffer_size;
     uint32_t clear_offset = buffer_flags[1] * buffer_size;

+    if (wait_for_results)
+    {
+        __syncthreads();
+        if (threadIdx.x == 0)
+        {
+            atomicAdd(offset_access_ptr, 1);
+        }
+    }
+
     if (elt < token_dim)
     {
         // Scatter token
@@ -312,13 +321,14 @@ __global__ void __launch_bounds__(128, 1)

     int offsets[NUM_INPUTS][DIM / (1 * ELTS_PER_THREAD * NUM_THREADS)];

+    cudaTriggerProgrammaticLaunchCompletion();
+
+    uint32_t* offset_access_ptr = &buffer_flags[3];
     // Buffer size is M * N, and we need two buffers for reduce-scatter and allgather
     uint32_t buffer_size = buffer_flags[2];
     uint32_t buffer_offset = buffer_flags[0] * (buffer_size << 1);
     T_IN const* input = &buffer_input[buffer_offset + buffer_size];
-
-    cudaTriggerProgrammaticLaunchCompletion();

     __syncthreads();
     if (threadIdx.x == 0)
     {
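The two kernel hunks above all index a small `buffer_flags` array: `buffer_flags[0]` and `buffer_flags[1]` pick which slot of the double-buffered workspace to read and to clear, `buffer_flags[2]` holds the per-buffer size, and `&buffer_flags[3]` serves as the access counter that this fix now bumps with `atomicAdd` even on the no-fusion (`wait_for_results`) path. Below is a minimal host-side sketch of that offset arithmetic; the flag meanings are inferred from the arithmetic shown in the hunks, the values are illustrative, and the real flags live in device memory and are rotated by the kernels/runtime, not by this snippet.

```python
# Host-side sketch of the buffer_flags arithmetic used by the kernels above (illustrative only).
buffer_flags = [0, 1, 4096, 0]  # inferred layout: [current_idx, clear_idx, buffer_size, access_counter]

buffer_size = buffer_flags[2]
input_offset = buffer_flags[0] * buffer_size          # where this iteration reads/writes
clear_offset = buffer_flags[1] * buffer_size          # which slot gets cleared for reuse
buffer_offset = buffer_flags[0] * (buffer_size << 1)  # rmsnorm kernel: two buffers per slot

buffer_flags[3] += 1  # stands in for atomicAdd(offset_access_ptr, 1) on the wait_for_results path
print(input_offset, clear_offset, buffer_offset, buffer_flags[3])
```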
@@ -30,19 +30,6 @@ namespace tensorrt_llm::runtime

 namespace
 {
-#define CUCHECK(cmd) \
-    do \
-    { \
-        CUresult retval = cmd; \
-        if (retval != CUDA_SUCCESS) \
-        { \
-            const char* error_string; \
-            cuGetErrorString(retval, &error_string); \
-            printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, error_string); \
-            exit(EXIT_FAILURE); \
-        } \
-    } while (0)
-
 // An efficient implementation assuming gran is a power of 2
 inline size_t roundUp(size_t val, size_t gran)
 {
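The `roundUp` helper kept as context here is documented as "an efficient implementation assuming gran is a power of 2"; its body falls outside the hunk. The standard bit trick for that case is sketched below in Python (illustrative only, not copied from the source):

```python
def round_up(val: int, gran: int) -> int:
    """Round val up to a multiple of gran, assuming gran is a power of two."""
    assert gran > 0 and gran & (gran - 1) == 0  # power-of-two precondition
    return (val + gran - 1) & ~(gran - 1)

assert round_up(300, 256) == 512
assert round_up(512, 256) == 512
```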
@@ -66,7 +53,7 @@ McastDeviceMemory::McastDeviceMemory(
     cudaSetDevice(mDeviceIdx);
     // Check if the device support multicasting
     int multicast_supported{0};
-    CUCHECK(cuDeviceGetAttribute(&multicast_supported, CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED, mDeviceIdx));
+    TLLM_CU_CHECK(cuDeviceGetAttribute(&multicast_supported, CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED, mDeviceIdx));
     if (multicast_supported == 0)
     {
         TLLM_THROW("[McastDeviceMemory] Device does not support multicasting.");
@@ -83,7 +70,7 @@ McastDeviceMemory::McastDeviceMemory(
     {
         // For multi-node, we also need to check if fabric handle is supported
         int fabric_handle_supported{0};
-        CUCHECK(cuDeviceGetAttribute(
+        TLLM_CU_CHECK(cuDeviceGetAttribute(
             &fabric_handle_supported, CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_FABRIC_SUPPORTED, mDeviceIdx));
         if (fabric_handle_supported == 0)
         {
@@ -142,12 +129,15 @@ void McastDeviceMemory::allocMnMcastMem(size_t bufSize)
     prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
     prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
     prop.location.id = mDeviceIdx;
     prop.allocFlags.gpuDirectRDMACapable = 1;

-    size_t granularity{0};
-    TLLM_CU_CHECK(cuMemGetAllocationGranularity(&granularity, &prop, CU_MEM_ALLOC_GRANULARITY_MINIMUM));
+    size_t alloc_granularity{0}, mc_granularity{0};
+    TLLM_CU_CHECK(cuMemGetAllocationGranularity(&alloc_granularity, &prop, CU_MEM_ALLOC_GRANULARITY_MINIMUM));
     // Round up the buffer size for grnularity
-    mAllocationSize = roundUp(bufSize + kSIGNAL_PAD_SIZE, granularity);
-
+    mAllocationSize = roundUp(bufSize + kSIGNAL_PAD_SIZE, alloc_granularity);
+    CUmulticastObjectProp mcProp = {.numDevices = mGroupSize, .size = mAllocationSize, .handleTypes = handle_type};
+    TLLM_CU_CHECK(cuMulticastGetGranularity(&mc_granularity, &mcProp, CU_MULTICAST_GRANULARITY_RECOMMENDED));
+    mAllocationSize = roundUp(mAllocationSize, mc_granularity);
     mUcHandles.resize(mGroupSize);
     // Allocates local gpu memory
     TLLM_CU_CHECK(cuMemCreate(&(mUcHandles[mGroupRank]), mAllocationSize, &prop, 0));
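This hunk is the core of the allocation change: the buffer is first rounded up to the allocation granularity, and then additionally rounded up to the recommended multicast granularity before handles are created and addresses reserved. The same arithmetic as a runnable sketch, with made-up sizes and granularities standing in for `kSIGNAL_PAD_SIZE` and the values returned by `cuMemGetAllocationGranularity` / `cuMulticastGetGranularity`:

```python
def round_up(val: int, gran: int) -> int:
    # Same power-of-two assumption as the roundUp helper above.
    return (val + gran - 1) & ~(gran - 1)

SIGNAL_PAD_SIZE = 2048        # illustrative stand-in for kSIGNAL_PAD_SIZE
alloc_granularity = 2 << 20   # e.g. 2 MiB minimum allocation granularity
mc_granularity = 32 << 20     # e.g. 32 MiB recommended multicast granularity

buf_size = 5 << 20
allocation_size = round_up(buf_size + SIGNAL_PAD_SIZE, alloc_granularity)  # 6 MiB
allocation_size = round_up(allocation_size, mc_granularity)                # 32 MiB
print(allocation_size)
```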
@@ -170,8 +160,6 @@ void McastDeviceMemory::allocMnMcastMem(size_t bufSize)
     cudaFreeHost(exphndl);

     // Initialize multicasting
-    CUmulticastObjectProp mcProp
-        = {.numDevices = mGroupSize, .size = mAllocationSize, .handleTypes = CU_MEM_HANDLE_TYPE_FABRIC};
     CUmemFabricHandle* fabric_handle;
     cudaMallocHost(&fabric_handle, sizeof(CUmemFabricHandle));
     if (mGroupRank == 0)
@@ -192,7 +180,7 @@ void McastDeviceMemory::allocMnMcastMem(size_t bufSize)
     // Bind memory addresses
     mUcPtrs.resize(mGroupSize);
     CUdeviceptr ptr;
-    TLLM_CU_CHECK(cuMemAddressReserve(&ptr, mAllocationSize * mGroupSize, 0, 0, 0));
+    TLLM_CU_CHECK(cuMemAddressReserve(&ptr, mAllocationSize * mGroupSize, mc_granularity, 0ULL, 0));
     CUmemAccessDesc accessDesc = {};
     accessDesc.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
     accessDesc.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
@@ -206,7 +194,7 @@ void McastDeviceMemory::allocMnMcastMem(size_t bufSize)
     TLLM_CU_CHECK(cuMemSetAccess(ptr, mAllocationSize * mGroupSize, &accessDesc, 1));

     // Bind MC Pointers
-    TLLM_CU_CHECK(cuMemAddressReserve(&mMcPtr, mAllocationSize, 0, 0U, 0));
+    TLLM_CU_CHECK(cuMemAddressReserve(&mMcPtr, mAllocationSize, mc_granularity, 0ULL, 0));
     TLLM_CU_CHECK(cuMemMap(mMcPtr, mAllocationSize, 0, mMcHandle, 0));
     TLLM_CU_CHECK(cuMemSetAccess(mMcPtr, mAllocationSize, &accessDesc, 1));

@@ -1044,12 +1044,13 @@ std::vector<torch::Tensor> moe_allreduce(torch::Tensor const& residual, torch::T
 }

 at::Tensor mnnvlTwoShotAllReduce(
-    at::Tensor& output, at::Tensor& input, at::Tensor& comm_buffer, at::Tensor& buffer_flags, bool wait_for_results)
+    at::Tensor& input, at::Tensor& comm_buffer, at::Tensor& buffer_flags, bool wait_for_results)
 {
     auto* mcast_mem = tensorrt_llm::common::findMcastDevMemBuffer(comm_buffer.data_ptr());
     TORCH_CHECK(mcast_mem != nullptr, "two_shot_all_reduce: comm_buffer must be obtained from a mcastBuffer instance.");

     auto const dtype = tensorrt_llm::runtime::TorchUtils::dataType(input.scalar_type());
+    at::Tensor output = torch::empty_like(input);

     auto allreduce_params = tensorrt_llm::kernels::mnnvl::AllReduceParams();
     allreduce_params.dtype = dtype;
@@ -1071,35 +1072,42 @@ at::Tensor mnnvlTwoShotAllReduce(
     return output;
 }

-void twoShotRMSNorm(torch::Tensor& prenorm_output, torch::Tensor& normed_output, torch::Tensor const& input,
-    torch::Tensor const& gamma, double epsilon, torch::Tensor const& residual, torch::Tensor& buffer_flags)
+std::vector<torch::Tensor> twoShotRMSNorm(torch::Tensor const& comm_buf, torch::Tensor const& gamma, double epsilon,
+    torch::Tensor const& residual, torch::Tensor& buffer_flags)
 {
-    auto const dtype = tensorrt_llm::runtime::TorchUtils::dataType(input.scalar_type());
+    auto const dtype = tensorrt_llm::runtime::TorchUtils::dataType(comm_buf.scalar_type());
     auto rmsnorm_params = tensorrt_llm::kernels::mnnvl::RMSNormParams();

+    // Input is the communication buffer so we need to get the shape from residual
+    torch::Tensor normed_output = torch::empty_like(residual);
+    torch::Tensor prenorm_output = torch::empty_like(residual);
+
     rmsnorm_params.dtype = dtype;
     rmsnorm_params.residual_output = prenorm_output.data_ptr();
     rmsnorm_params.output = normed_output.data_ptr();
-    rmsnorm_params.input = input.data_ptr();
+    rmsnorm_params.input = comm_buf.data_ptr();
     rmsnorm_params.gamma = gamma.data_ptr();
     rmsnorm_params.epsilon = epsilon;
     rmsnorm_params.residual = residual.data_ptr();
     rmsnorm_params.buffer_flags = reinterpret_cast<uint32_t*>(buffer_flags.data_ptr());
     rmsnorm_params.batch = normed_output.size(0);
     rmsnorm_params.hidden_dim = normed_output.size(1);
-    rmsnorm_params.stream = at::cuda::getCurrentCUDAStream(input.get_device());
+    rmsnorm_params.stream = at::cuda::getCurrentCUDAStream(comm_buf.get_device());

     tensorrt_llm::kernels::mnnvl::twoshot_rmsnorm_op(rmsnorm_params);

+    return {normed_output, prenorm_output};
 }
 } // namespace torch_ext

 TORCH_LIBRARY_FRAGMENT(trtllm, m)
 {
     m.def(
-        "mnnvl_twoshot_allreduce(Tensor(output!) output, Tensor(input!) input, Tensor(comm_buf!) comm_buffer, "
+        "mnnvl_twoshot_allreduce(Tensor(input!) input, Tensor(comm_buf!) comm_buffer, "
         "Tensor(buffer_flags!) buffer_flags, bool wait_for_result) -> Tensor");
     m.def(
-        "mnnvl_twoshot_rmsnorm(Tensor prenorm_output, Tensor normed_output, Tensor input, Tensor gamma, "
-        "float epsilon, Tensor residual, Tensor buffer_flags) -> ()");
+        "mnnvl_twoshot_rmsnorm(Tensor comm_buf, Tensor gamma, "
+        "float epsilon, Tensor residual, Tensor buffer_flags) -> Tensor[]");
     m.def(
         "allreduce("
         "Tensor input,"
@@ -54,6 +54,18 @@ def _register_fake():
         else:
             return [torch.empty_like(input)]

+    #MNNVL Allreduce
+    @torch.library.register_fake("trtllm::mnnvl_twoshot_allreduce")
+    def _(input, buffer, buffer_flags, wait_for_results):
+        output = input.new_empty(input.shape)
+        return output
+
+    @torch.library.register_fake("trtllm::mnnvl_twoshot_rmsnorm")
+    def _(comm_buf, gamma, eps, residual, buffer_flags):
+        output = residual.new_empty(residual.shape)
+        residual_out = residual.new_empty(residual.shape)
+        return [output, residual_out]
+
     @torch.library.register_fake("trtllm::moe_allreduce")
     def _(residual, norm_weight, device_num_experts, scale_input,
           active_experts_token_input, token_input, workspace, rank, nranks,
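The fake (meta) registrations added above only describe output shapes and dtypes, so fake-tensor tracing and torch.compile can trace through the new op signatures without launching the real multi-GPU kernels. Here is a self-contained toy showing the same pattern; `toylib::twoshot_allreduce` is a hypothetical op made up for illustration (the real ops live in the `trtllm` namespace and require the built extension), and `register_fake` needs a reasonably recent PyTorch:

```python
import torch

# Hypothetical op mirroring the shape contract declared for trtllm::mnnvl_twoshot_allreduce.
torch.library.define("toylib::twoshot_allreduce", "(Tensor input) -> Tensor")

@torch.library.impl("toylib::twoshot_allreduce", "CPU")
def _toy_cpu(input):
    return input.clone()  # stand-in for the real reduction

@torch.library.register_fake("toylib::twoshot_allreduce")
def _toy_fake(input):
    # Same rule as the registration above: output has the input's shape and dtype.
    return input.new_empty(input.shape)

x = torch.randn(4, 8)
print(torch.ops.toylib.twoshot_allreduce(x).shape)  # torch.Size([4, 8])
```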
@@ -345,8 +345,7 @@ class MNNVLAllReduce(nn.Module):
         buffer_mnnvl = self.buffer_mnnvl.view(3, 2, -1, shape[-1])

         if fusion_op == AllReduceFusionOp.NONE:
-            torch.ops.trtllm.mnnvl_twoshot_allreduce(
-                output,
+            output = torch.ops.trtllm.mnnvl_twoshot_allreduce(
                 input,
                 buffer_mnnvl,
                 self.buffer_flags_mnnvl,
@@ -355,19 +354,16 @@ class MNNVLAllReduce(nn.Module):
             return output.view(shape)
         elif fusion_op == AllReduceFusionOp.RESIDUAL_RMS_NORM:
             torch.ops.trtllm.mnnvl_twoshot_allreduce(
-                output,
                 input,
                 buffer_mnnvl,
                 self.buffer_flags_mnnvl,
                 False,
             )
             residual_in = all_reduce_params.residual
-            residual_out = torch.empty_like(input)

-            torch.ops.trtllm.mnnvl_twoshot_rmsnorm(
-                residual_out, output, buffer_mnnvl,
-                all_reduce_params.norm_weight, all_reduce_params.eps,
-                residual_in, self.buffer_flags_mnnvl)
+            output, residual_out = torch.ops.trtllm.mnnvl_twoshot_rmsnorm(
+                buffer_mnnvl, all_reduce_params.norm_weight,
+                all_reduce_params.eps, residual_in, self.buffer_flags_mnnvl)
             return output.view(shape), residual_out.view(shape)
         return None
@@ -156,9 +156,6 @@ def row_linear_residual_norm_fusion_forward(
 )
 def test_row_linear_residual_norm_fusion(seq_len, hidden_size, fusion):

-    if not fusion:
-        pytest.skip("skip no fusion test")
-
     torch.manual_seed(42)
     dtype = torch.bfloat16
     tensor_parallel_size = 2