Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-13 22:18:36 +08:00)
[https://nvbugs/5729697][fix] MNNVL Allreduce: use CUDA runtime instead of Macro to get SM version. (#10062)
Signed-off-by: Shiyu Li <shili@nvidia.com>
commit 3ddc9d2b48 (parent 48c875f8ea)
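Background on the fix: `__CUDA_ARCH__` is only defined while nvcc compiles device code, so in host-side dispatch logic every `#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))` guard evaluated false and silently fell back to the pre-Hopper path, disabling thread block clusters (CGA) even on SM90 GPUs. This commit moves those host-side decisions to a runtime SM query, cached in a function-local static so the device is queried once per process; the guards that remain in device code are untouched, since the macro is the right tool there. A minimal sketch of such a query, assuming the standard CUDA attribute API (the actual body of `tensorrt_llm::common::getSMVersion()` may differ):

#include <cuda_runtime.h>

// Hypothetical stand-in for tensorrt_llm::common::getSMVersion().
int getSMVersionSketch()
{
    int device = 0;
    cudaGetDevice(&device);
    int major = 0;
    int minor = 0;
    cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device);
    cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, device);
    return major * 10 + minor; // e.g. 90 on Hopper (SM90)
}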
@@ -230,59 +230,62 @@ inline __device__ __host__ T divUp(T m, T n)
 // Return (block_size, cluster_size, loads_per_thread)
 std::tuple<int, int, int> adjustGridConfig(int numTokens, int dim, int eltsPerThread)
 {
     // Start with preferred block_size and cluster_size
-#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
-    int clusterSize = 8;
-#else
-    int clusterSize = 1;
-#endif
+    static int SM = tensorrt_llm::common::getSMVersion();
+
+    int clusterSize = SM >= 90 ? 8 : 1;
     int blockSize = 128;
     // ========================== Adjust the grid configuration ==========================
     int threadsNeeded = divUp(dim, eltsPerThread);
     int loadsPerThread = 1;
 
     blockSize = divUp(threadsNeeded, clusterSize);
-#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
-    while (threadsNeeded % clusterSize != 0 && clusterSize > 1)
+    if (clusterSize > 1)
     {
-        clusterSize /= 2;
+        while (threadsNeeded % clusterSize != 0 && clusterSize > 1)
+        {
+            clusterSize /= 2;
+        }
+        blockSize = divUp(threadsNeeded, clusterSize);
+        while (blockSize < 128 && clusterSize >= 2)
+        {
+            blockSize *= 2;
+            clusterSize /= 2;
+        }
+        int smCount = getMultiProcessorCount();
+        while (numTokens * clusterSize > smCount && clusterSize > 1 && blockSize <= 512)
+        {
+            blockSize *= 2;
+            clusterSize /= 2;
+        }
     }
-    blockSize = divUp(threadsNeeded, clusterSize);
-    while (blockSize < 128 && clusterSize >= 2)
-    {
-        blockSize *= 2;
-        clusterSize /= 2;
-    }
-    int smCount = getMultiProcessorCount();
-    while (numTokens * clusterSize > smCount && clusterSize > 1 && blockSize <= 512)
-    {
-        blockSize *= 2;
-        clusterSize /= 2;
-    }
-#endif
 
     // Trying to scale up use multiple loads or CGA
     while (blockSize > 1024)
     {
-#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
-        if (clusterSize < 8)
+        // Scale up with CGA if supported
+        if (SM >= 90)
         {
-            clusterSize = clusterSize << 1;
+            if (clusterSize < 8)
+            {
+                clusterSize = clusterSize << 1;
+            }
+            else
+            {
+                break;
+            }
         }
         else
         {
-            break;
+            if (loadsPerThread < 8)
+            {
+                loadsPerThread += 1;
+            }
+            else
+            {
+                break;
+            }
         }
-#else
-        if (loadsPerThread < 8)
-        {
-            loadsPerThread += 1;
-        }
-        else
-        {
-            break;
-        }
-#endif
         blockSize = divUp(threadsNeeded, clusterSize * loadsPerThread);
     }
     return {blockSize, clusterSize, loadsPerThread};
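To see the new heuristic end to end, here is a self-contained host-side mirror of the logic above, with the SM version and SM count passed in as plain parameters instead of queried from CUDA (those two stubs are the only departure from the diff):

#include <cstdio>
#include <tuple>

static int divUp(int m, int n)
{
    return (m + n - 1) / n;
}

std::tuple<int, int, int> adjustGridConfigSketch(int numTokens, int dim, int eltsPerThread, int smVersion, int smCount)
{
    int clusterSize = smVersion >= 90 ? 8 : 1; // runtime check instead of #if __CUDA_ARCH__
    int blockSize = 128;
    int threadsNeeded = divUp(dim, eltsPerThread);
    int loadsPerThread = 1;

    blockSize = divUp(threadsNeeded, clusterSize);
    if (clusterSize > 1)
    {
        // Shrink the cluster until it divides the work evenly.
        while (threadsNeeded % clusterSize != 0 && clusterSize > 1)
        {
            clusterSize /= 2;
        }
        blockSize = divUp(threadsNeeded, clusterSize);
        // Keep blocks at >= 128 threads, and avoid oversubscribing the SMs.
        while (blockSize < 128 && clusterSize >= 2)
        {
            blockSize *= 2;
            clusterSize /= 2;
        }
        while (numTokens * clusterSize > smCount && clusterSize > 1 && blockSize <= 512)
        {
            blockSize *= 2;
            clusterSize /= 2;
        }
    }
    // Oversized blocks: grow the cluster on SM90+, otherwise add loads per thread.
    while (blockSize > 1024)
    {
        if (smVersion >= 90)
        {
            if (clusterSize < 8)
            {
                clusterSize = clusterSize << 1;
            }
            else
            {
                break;
            }
        }
        else if (loadsPerThread < 8)
        {
            loadsPerThread += 1;
        }
        else
        {
            break;
        }
        blockSize = divUp(threadsNeeded, clusterSize * loadsPerThread);
    }
    return {blockSize, clusterSize, loadsPerThread};
}

int main()
{
    // bf16 gives eltsPerThread = sizeof(float4) / 2 = 8, so hidden 16384 needs
    // 2048 threads; on a hypothetical 132-SM Hopper this yields (256, 8, 1).
    auto [block, cluster, loads] = adjustGridConfigSketch(8, 16384, 8, 90, 132);
    printf("block=%d cluster=%d loads=%d\n", block, cluster, loads);
    return 0;
}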
@@ -420,9 +423,9 @@ __global__ void __launch_bounds__(1024) oneshotAllreduceFusionKernel(T* outputPt
     }
     float blockSum = blockReduceSum<float, true>(threadSum);
 
+    __shared__ float sharedVal[8]; // Temporary variable to share the sum within block
     float fullSum = blockSum;
 #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
-    __shared__ float sharedVal[8]; // Temporary variable to share the sum within block
     namespace cg = cooperative_groups;
     cg::cluster_group cluster = cg::this_cluster();
     int const numBlocks = cluster.num_blocks();
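For context on the `sharedVal[8]` scratch buffer and the cooperative-groups cluster: on SM90, blocks of a cluster can read each other's shared memory, which lets per-block partial sums be combined without a round trip through global memory. A minimal device-side sketch of that pattern, assuming a cluster launch of at most 8 blocks; this illustrates the mechanism, not the kernel's exact reduction:

#include <cooperative_groups.h>

namespace cg = cooperative_groups;

// Sum per-block partial sums across a thread block cluster (SM90+ only).
__device__ float clusterReduceSumSketch(float blockSum)
{
    __shared__ float sharedVal[8]; // one slot per block in the cluster
    cg::cluster_group cluster = cg::this_cluster();
    unsigned int const rank = cluster.block_rank();
    if (threadIdx.x == 0)
    {
        sharedVal[rank] = blockSum; // publish this block's partial sum
    }
    cluster.sync(); // every block's slot is now visible cluster-wide
    float fullSum = 0.0f;
    for (unsigned int b = 0; b < cluster.num_blocks(); ++b)
    {
        // Map our sharedVal symbol into block b's shared memory and read its slot.
        float const* peer = cluster.map_shared_rank(sharedVal, b);
        fullSum += peer[b];
    }
    cluster.sync(); // keep peer shared memory alive until all blocks have read it
    return fullSum;
}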
@@ -459,6 +462,8 @@ using detail::adjustGridConfig;
 
 void oneshotAllreduceFusionOp(AllReduceFusionParams const& params)
 {
+
+    static int const kSMVersion = tensorrt_llm::common::getSMVersion();
     int const numTokens = params.numTokens;
     int const tokenDim = params.tokenDim;
     int const eltsPerThread = sizeof(float4) / getDTypeSize(params.dType);
@@ -466,38 +471,31 @@ void oneshotAllreduceFusionOp(AllReduceFusionParams const& params)
     auto [blockSize, clusterSize, loadsPerThread] = adjustGridConfig(numTokens, tokenDim, eltsPerThread);
     dim3 grid(numTokens, clusterSize, 1);
 
-    TLLM_CHECK_WITH_INFO(blockSize <= 1024 && loadsPerThread == 1,
-        "Hidden Dimension %d exceeds the maximum supported hidden dimension (%d)", tokenDim,
-#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
-        1024 * 8 * eltsPerThread);
-#else
-        1024 * eltsPerThread);
-#endif
-
     TLLM_LOG_DEBUG(
         "[MNNVL AllReduceOneShot] Dispatch: grid size: (%d, %d, 1), block_size: %d, cluster_size: %d, "
         "loads_per_thread: %d, "
         "threads_needed: %d",
         numTokens, clusterSize, blockSize, clusterSize, loadsPerThread, divUp(tokenDim, eltsPerThread));
 
+    TLLM_CHECK_WITH_INFO(blockSize <= 1024 && loadsPerThread == 1,
+        "Hidden Dimension %d exceeds the maximum supported hidden dimension (%d)", tokenDim,
+        1024 * (kSMVersion >= 90 ? 8 : 1) * eltsPerThread);
+
     cudaLaunchAttribute attrs[2];
     attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
     attrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL() ? 1 : 0;
-#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
     attrs[1].id = cudaLaunchAttributeClusterDimension;
     attrs[1].val.clusterDim.x = 1;
     attrs[1].val.clusterDim.y = clusterSize;
     attrs[1].val.clusterDim.z = 1;
-#endif
 
-    cudaLaunchConfig_t config
-    {
-        .gridDim = grid, .blockDim = blockSize, .dynamicSmemBytes = 0, .stream = params.stream, .attrs = attrs,
-#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
-        .numAttrs = 2,
-#else
-        .numAttrs = 1,
-#endif
+    cudaLaunchConfig_t config{
+        .gridDim = grid,
+        .blockDim = blockSize,
+        .dynamicSmemBytes = 0,
+        .stream = params.stream,
+        .attrs = attrs,
+        .numAttrs = kSMVersion >= 90 ? 2U : 1U,
     };
 
 #define LAUNCH_ALLREDUCE_KERNEL(WORLD_SIZE, T, RMSNORM) \
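The host-side pattern that replaces the preprocessor guards is worth spelling out: the cluster-dimension attribute is always filled in, but whether the driver ever sees it is decided at runtime through `numAttrs`, so pre-SM90 devices are never handed a cluster dimension they cannot honor. A reduced sketch of that launch path (kernel body and PDL flag stubbed out; not the exact TensorRT-LLM code):

#include <cuda_runtime.h>

__global__ void allreduceKernelStub() {}

// Attach the cluster attribute only when the runtime SM version allows it.
cudaError_t launchWithOptionalCluster(int smVersion, dim3 grid, int blockSize, int clusterSize, cudaStream_t stream)
{
    cudaLaunchAttribute attrs[2];
    attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
    attrs[0].val.programmaticStreamSerializationAllowed = 0; // PDL disabled in this sketch
    attrs[1].id = cudaLaunchAttributeClusterDimension;
    attrs[1].val.clusterDim.x = 1;
    attrs[1].val.clusterDim.y = clusterSize;
    attrs[1].val.clusterDim.z = 1;

    cudaLaunchConfig_t config{};
    config.gridDim = grid;
    config.blockDim = dim3(blockSize, 1, 1);
    config.dynamicSmemBytes = 0;
    config.stream = stream;
    config.attrs = attrs;
    config.numAttrs = smVersion >= 90 ? 2u : 1u; // the runtime decision
    return cudaLaunchKernelEx(&config, allreduceKernelStub);
}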
@@ -831,9 +829,9 @@ __global__ __launch_bounds__(1024) void rmsNormLamport(T_IN* outputPreNorm, T_OU
     float blockSum = blockReduceSum<float, true>(threadSum);
 
     float fullSum = blockSum;
+    __shared__ float sharedVal[8];
+    // Use CGA Reduction if supported
 #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
-    __shared__ float sharedVal[8];
     int const numBlocks = cluster.num_blocks();
     if (numBlocks > 1)
     {
@@ -876,6 +874,11 @@ __global__ __launch_bounds__(1024) void rmsNormLamport(T_IN* outputPreNorm, T_OU
     }
     constexpr int kELTS_SIZE = sizeof(T_IN);
 
+    // Issue ACQBLK at the end. Assuming preceding kernel will not modify the buffer_flags.
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+    cudaGridDependencySynchronize();
+#endif
+
     // Update the buffer pointers
     flag.waitAndUpdate({static_cast<uint32_t>(divUp<uint32_t>(numTokens, worldSize) * worldSize * dim * kELTS_SIZE),
         static_cast<uint32_t>(numTokens * dim * kELTS_SIZE), 0, 0});
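The newly added `cudaGridDependencySynchronize()` belongs to programmatic dependent launch (PDL): a kernel launched with `programmaticStreamSerializationAllowed` may begin before its predecessor on the stream has fully finished, so it must explicitly wait before reading anything the predecessor wrote, here the buffer flags. Note this call is device code, which is why it keeps its `__CUDA_ARCH__` guard. A sketch of the pattern with hypothetical names:

// Overlap independent work with the tail of the previous kernel, then
// synchronize before touching its outputs (SM90+).
__global__ void pdlConsumerSketch(float const* producerOutput, float* result, int n)
{
    int const i = blockIdx.x * blockDim.x + threadIdx.x;
    // ... work that does not depend on the previous kernel could run here ...
    cudaGridDependencySynchronize(); // wait for the predecessor grid to complete
    if (i < n)
    {
        result[i] = producerOutput[i]; // now safe to read
    }
}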
@@ -883,6 +886,7 @@ __global__ __launch_bounds__(1024) void rmsNormLamport(T_IN* outputPreNorm, T_OU
 
 void twoshotAllreduceFusionOp(AllReduceFusionParams const& params)
 {
+    static int const kSMVersion = tensorrt_llm::common::getSMVersion();
     int const numTokens = params.numTokens;
     int const tokenDim = params.tokenDim;
     int const numEltsPerThread = sizeof(float4) / getDTypeSize(params.dType);
@@ -959,17 +963,13 @@ void twoshotAllreduceFusionOp(AllReduceFusionParams const& params)
     rnConfig.attrs = rnAttrs;
     rnAttrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
     rnAttrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL() ? 1 : 0;
-#ifndef DISABLE_CGA
     rnAttrs[1].id = cudaLaunchAttributeClusterDimension;
     rnAttrs[1].val.clusterDim.x = 1;
     rnAttrs[1].val.clusterDim.y = rnClusterSize;
     rnAttrs[1].val.clusterDim.z = 1;
-    rnConfig.numAttrs = 2;
-#else
-    rnConfig.numAttrs = 1;
-#endif
+    rnConfig.numAttrs = (kSMVersion >= 90) ? 2U : 1U;
 
-    bool const rnUseCGA = rnClusterSize > 1;
+    bool const rnUseCGA = kSMVersion >= 90 && rnClusterSize > 1;
     int const dimPadded = divUp(tokenDim, numEltsPerThread * rnNumThreads) * numEltsPerThread * rnNumThreads;
     int const iters = dimPadded / rnNumThreads;
 
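A worked instance of the padding arithmetic, taking the test value tokenDim = 7176 with bf16 (numEltsPerThread = 8) and assuming, for illustration only, rnNumThreads = 128:

    dimPadded = divUp(7176, 8 * 128) * 8 * 128 = 8 * 1024 = 8192
    iters     = dimPadded / rnNumThreads       = 8192 / 128 = 64

i.e., the hidden dimension is rounded up to the next multiple of one full vectorized block pass (8 * 128 = 1024 elements).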
@@ -179,7 +179,7 @@ def row_linear_residual_norm_fusion_forward(
     ], # Test for max_num_token fallback
     ids=lambda x: f"seqlen:{x}",
 )
-@pytest.mark.parametrize("hidden_size", [8, 2880, 7168, 7176, 8192],
+@pytest.mark.parametrize("hidden_size", [8, 2880, 7168, 7176, 8192, 16384],
     ids=lambda x: f"hidden:{x}")
 @pytest.mark.parametrize("dtype", [torch.bfloat16],
     ids=lambda x: f"dtype:{torch.finfo(x).dtype}")
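The new hidden_size = 16384 case is the one that exercises the fixed dispatch: with bf16, eltsPerThread = sizeof(float4) / 2 = 8, so 16384 / 8 = 2048 threads are needed, more than a single 1024-thread block. That only passes the check `1024 * (kSMVersion >= 90 ? 8 : 1) * eltsPerThread` through the cluster path (maximum 1024 * 8 * 8 = 65536 on SM90, versus 8192 pre-SM90), which is exactly the path the host-side `__CUDA_ARCH__` guard used to disable.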