Mirror of https://github.com/NVIDIA/TensorRT-LLM.git
Support RingAttention in the BertAttention plugin and the DiT model (#3661)
Support ring attention for the bert_attention plugin and the DiT model.

Signed-off-by: ChunhuanLin <lch_xdu@163.com>
This commit is contained in:
parent 9afe510367
commit 9477661f4c

139 cpp/tensorrt_llm/kernels/recoverFromRingAtten.cu Normal file
@ -0,0 +1,139 @@
/*
 * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tensorrt_llm/common/cudaTypeUtils.cuh"
#include "tensorrt_llm/kernels/recoverFromRingAtten.h"

#include "math.h"
#include <cooperative_groups.h>
#include <cuda/barrier>

using namespace tensorrt_llm::common;

namespace tensorrt_llm
{
namespace kernels
{

template <typename Tout>
__global__ void reduce4ring_attention(
    // this is the accumulated results for all finished ring attention blocks
    Tout* __restrict__ accu_output,       // b x s_block x h x d
    float* __restrict__ accu_softmax_sum, // b x s_block x h
    float* __restrict__ accu_max,         // b x s_block x h
    // this is the new ring attention block results
    Tout* __restrict__ output,            // b x s_block x h x d
    float* __restrict__ softmax_sum,      // b x s_block x h
    float* __restrict__ max,              // b x s_block x h
    // necessary constant parameters
    int const b, int const s_block, int const h, int const d, int const block_seq_len, int* cu_seqlens)
{
    auto block = cooperative_groups::this_thread_block();
    int batchid = blockIdx.x;
    int block_seq_idx = blockIdx.y;
    int block_s_start = block_seq_idx * block_seq_len;
    int block_s_end = (block_seq_idx + 1) * block_seq_len;
    block_s_end = s_block < block_s_end ? s_block : block_s_end;
    int64_t output_start_offset = batchid * s_block * d + block_s_start * d;
    int64_t lm_start_offset = batchid * s_block + block_s_start;

    __shared__ cuda::barrier<cuda::thread_scope::thread_scope_block> barrier;
    if (block.thread_rank() == 0)
    {
        init(&barrier, block.size());
    }
    block.sync();

    int s_len = block_s_end - block_s_start;
    int laneid = threadIdx.x % 32;
    int local_warpid = threadIdx.x / 32;
    int warp_num = blockDim.x / 32;
    int loop_on_s = (s_len + warp_num * 32 - 1) / (warp_num * 32);
    for (int l = 0; l < loop_on_s; l++)
    {
        int s_ = local_warpid + warp_num * laneid + l * warp_num * 32;
        float scaled_my_ss1_ = 1.0, scaled_my_ss2_ = 1.0;
        if (s_ < s_len)
        {
            uint64_t lm_start_offset_ = lm_start_offset + s_;
            float my_accu_ss = accu_softmax_sum[lm_start_offset_] == 0.0 ? 1.0 : accu_softmax_sum[lm_start_offset_];
            float my_ss = softmax_sum[lm_start_offset_] == 0.0 ? 1.0 : softmax_sum[lm_start_offset_];

            float cur_max = (accu_max[lm_start_offset_] > max[lm_start_offset_]) ? accu_max[lm_start_offset_]
                                                                                 : max[lm_start_offset_];
            float scale1 = exp(accu_max[lm_start_offset_] - cur_max);
            float scale2 = exp(max[lm_start_offset_] - cur_max);
            float cur_softmax_sum = my_accu_ss * scale1 + my_ss * scale2;
            if (cur_softmax_sum == 0)
                cur_softmax_sum = 1.0;
            scaled_my_ss1_ = scale1 * my_accu_ss / cur_softmax_sum;
            scaled_my_ss2_ = scale2 * my_ss / cur_softmax_sum;
            accu_softmax_sum[lm_start_offset_] = cur_softmax_sum;
            accu_max[lm_start_offset_] = cur_max;
        }
        int sid = l * warp_num * 32 + local_warpid;
        int s_end = (l + 1) * warp_num * 32 < s_len ? (l + 1) * warp_num * 32 : s_len;
        for (int ss = 0;; ss++)
        {
            uint64_t output_start_offset_ = output_start_offset + sid * d;
            float scaled_my_ss1 = __shfl_sync(0xffffffff, scaled_my_ss1_, ss, 32);
            float scaled_my_ss2 = __shfl_sync(0xffffffff, scaled_my_ss2_, ss, 32);
            for (int eid = laneid; eid < d; eid += 32)
            {
                accu_output[output_start_offset_ + eid]
                    = (float) accu_output[output_start_offset_ + eid] * scaled_my_ss1
                    + (float) output[output_start_offset_ + eid] * scaled_my_ss2;
            }
            sid += warp_num;
            if (sid >= s_end)
                break;
        }
    }
    barrier.arrive_and_wait();
    return;
}

template <typename Tout>
void invokeRecoverFromRA(Tout* accu_output, float* accu_softmax_stats, Tout* output, float* softmax_stats, int b, int s,
    int h, int d, int* cu_seqlens, cudaStream_t stream)
{
    float* accu_softmax_sum = accu_softmax_stats;
    float* accu_softmax_max = accu_softmax_stats + b * s * h;
    float* softmax_sum = softmax_stats;
    float* softmax_max = softmax_stats + b * s * h;

    int threads_per_block = 128;
    int saturated_s_block_dim = 3000 / b + 1;
    s = s * h;
    int block_seq_len = (s / saturated_s_block_dim + 255) / 256 * 256;
    block_seq_len = block_seq_len < 256 ? 256 : block_seq_len;
    int dim_s = (s + block_seq_len - 1) / block_seq_len;

    dim3 block_num(b, dim_s, 1);
    reduce4ring_attention<Tout><<<block_num, threads_per_block, 0, stream>>>(accu_output, accu_softmax_sum,
        accu_softmax_max, output, softmax_sum, softmax_max, b, s, h, d, block_seq_len, cu_seqlens);
}

#define INSTANTIATE_RECOVER_RA(Tout)                                                                                   \
    template void invokeRecoverFromRA(Tout* accu_output, float* accu_softmax_stats, Tout* output,                     \
        float* softmax_stats, int b, int s, int h, int d, int* cu_seqlens, cudaStream_t stream)
INSTANTIATE_RECOVER_RA(float);
INSTANTIATE_RECOVER_RA(half);
#ifdef ENABLE_BF16
INSTANTIATE_RECOVER_RA(__nv_bfloat16);
#endif
} // namespace kernels
} // namespace tensorrt_llm
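The kernel above is the standard online-softmax merge used by ring attention: the running output and the newly received block's output are each rescaled by exp(max - new_max) and blended in proportion to their rescaled softmax sums, so that after the last ring step accu_output equals attention over the full, unsplit key/value sequence. A minimal NumPy sketch of that per-position math follows; the function name and shapes ((b, s, h) for the stats, (b, s, h, d) for the outputs) are assumptions for illustration, not part of the commit.

# Minimal NumPy sketch (hypothetical names) of the merge performed by reduce4ring_attention.
import numpy as np

def merge_ring_block(accu_out, accu_sum, accu_max, blk_out, blk_sum, blk_max):
    # Zero sums are treated as 1.0, mirroring the kernel's guard for empty rows.
    accu_sum = np.where(accu_sum == 0.0, 1.0, accu_sum)
    blk_sum = np.where(blk_sum == 0.0, 1.0, blk_sum)
    new_max = np.maximum(accu_max, blk_max)
    scale1 = np.exp(accu_max - new_max)  # rescales the accumulated result
    scale2 = np.exp(blk_max - new_max)   # rescales the incoming block
    new_sum = accu_sum * scale1 + blk_sum * scale2
    new_sum = np.where(new_sum == 0.0, 1.0, new_sum)
    w1 = scale1 * accu_sum / new_sum     # weight applied to accu_out
    w2 = scale2 * blk_sum / new_sum      # weight applied to blk_out
    merged = accu_out * w1[..., None] + blk_out * w2[..., None]
    return merged, new_sum, new_max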
34 cpp/tensorrt_llm/kernels/recoverFromRingAtten.h Normal file
@ -0,0 +1,34 @@
/*
 * Copyright (c) 2020-2024, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include "tensorrt_llm/common/cudaUtils.h"
#include <assert.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

namespace tensorrt_llm
{
namespace kernels
{

template <typename Tout>
void invokeRecoverFromRA(Tout* accu_output, float* accu_softmax_stats, Tout* output, float* softmax_stats, int b, int s,
    int h, int d, int* cu_seqlens, cudaStream_t stream);

} // namespace kernels
} // namespace tensorrt_llm
@ -16,6 +16,7 @@
|
||||
*/
|
||||
#include "bertAttentionPlugin.h"
|
||||
#include "tensorrt_llm/kernels/gptKernels.h"
|
||||
#include "tensorrt_llm/kernels/recoverFromRingAtten.h"
|
||||
#include "tensorrt_llm/kernels/sageAttentionKernels.h"
|
||||
#include "tensorrt_llm/kernels/unfusedAttentionKernels.h"
|
||||
#include "tensorrt_llm/runtime/iBuffer.h"
|
||||
@ -35,7 +36,7 @@ std::vector<nvinfer1::PluginField> BertAttentionPluginCreator::mPluginAttributes
|
||||
BertAttentionPlugin::BertAttentionPlugin(int num_heads, int head_size, float q_scaling,
|
||||
ContextFMHAType context_fmha_type, nvinfer1::DataType type, bool do_relative_attention, int max_distance,
|
||||
bool remove_padding, bool sage_attn, int sage_attn_q_block_size, int sage_attn_k_block_size,
|
||||
int sage_attn_v_block_size)
|
||||
int sage_attn_v_block_size, int cp_size, int cp_rank, std::set<int> cp_group)
|
||||
: mNumHeads(num_heads)
|
||||
, mHeadSize(head_size)
|
||||
, mQScaling(q_scaling)
|
||||
@ -46,6 +47,9 @@ BertAttentionPlugin::BertAttentionPlugin(int num_heads, int head_size, float q_s
|
||||
, mEnableContextFMHA(context_fmha_type != ContextFMHAType::DISABLED)
|
||||
, mFMHAForceFP32Acc(context_fmha_type == ContextFMHAType::ENABLED_WITH_FP32_ACC)
|
||||
, mSageAttn(sage_attn)
|
||||
, mCpSize(cp_size)
|
||||
, mCpRank(cp_rank)
|
||||
, mCpGroup(std::move(cp_group))
|
||||
{
|
||||
// pre-check whether FMHA is supported in order to save memory allocation
|
||||
if (mEnableContextFMHA)
|
||||
@ -84,6 +88,11 @@ BertAttentionPlugin::BertAttentionPlugin(int num_heads, int head_size, float q_s
|
||||
mSageAttnKBlockSize, mSageAttnVBlockSize);
|
||||
}
|
||||
}
|
||||
|
||||
if (mCpGroup.size() > 1 && !mEnableContextFMHA)
{
    TLLM_LOG_ERROR("Unfused MHA does not support context parallel yet.");
}
|
||||
}
|
||||
|
||||
// Parameterized constructor
|
||||
@ -104,6 +113,15 @@ BertAttentionPlugin::BertAttentionPlugin(void const* data, size_t length)
|
||||
read(d, mSageAttnQBlockSize);
|
||||
read(d, mSageAttnKBlockSize);
|
||||
read(d, mSageAttnVBlockSize);
|
||||
read(d, mCpSize);
|
||||
read(d, mCpRank);
|
||||
mCpGroup.clear();
|
||||
int groupItem = 0;
|
||||
while (d != a + length)
|
||||
{
|
||||
read(d, groupItem);
|
||||
mCpGroup.insert(groupItem);
|
||||
}
|
||||
|
||||
TLLM_CHECK_WITH_INFO(d == a + length,
|
||||
"Expected length (%d) != real length (%d). This is often "
|
||||
@ -207,7 +225,20 @@ size_t BertAttentionPlugin::getWorkspaceSize(nvinfer1::PluginTensorDesc const* i
|
||||
? (batch_size * input_seq_len * mNumHeads * paddedHeadSize * sizeof(__nv_bfloat16))
|
||||
: sage_quant_space_size;
|
||||
|
||||
int const NUM_BUFFERS = 18;
|
||||
// workspace for RingAttention ping-pong buffer
|
||||
bool const enableRingAttn = (mCpGroup.size() > 1);
|
||||
const size_t ring_q_buf_size = enableRingAttn ? size * batch_size * input_seq_len * local_hidden_units_ : 0;
|
||||
const size_t ring_kv_buf_size = enableRingAttn
|
||||
? 2 * size * batch_size * input_seq_len * local_hidden_units_ + sizeof(int) * (batch_size + 1)
|
||||
: 0;
|
||||
const size_t ring_softmax_stats_buf_size
|
||||
= enableRingAttn ? 2 * sizeof(float) * batch_size * input_seq_len * mNumHeads : 0;
|
||||
const size_t ring_softmax_stats_accu_buf_size
|
||||
= enableRingAttn ? 2 * sizeof(float) * batch_size * input_seq_len * mNumHeads : 0;
|
||||
const size_t ring_block_output_size = enableRingAttn ? size * batch_size * input_seq_len * local_hidden_units_ : 0;
|
||||
|
||||
int const NUM_BUFFERS = 24;
|
||||
|
||||
size_t workspaces[NUM_BUFFERS];
|
||||
workspaces[0] = CUBLAS_WORKSPACE_SIZE;
|
||||
workspaces[1] = attention_mask_size;
|
||||
@ -227,6 +258,12 @@ size_t BertAttentionPlugin::getWorkspaceSize(nvinfer1::PluginTensorDesc const* i
|
||||
workspaces[15] = scale_bmm1_device_size;
|
||||
workspaces[16] = scale_bmm2_device_size;
|
||||
workspaces[17] = sage_quant_space_size;
|
||||
workspaces[18] = ring_q_buf_size;
|
||||
workspaces[19] = ring_kv_buf_size; // kv1
|
||||
workspaces[20] = ring_kv_buf_size; // kv2
|
||||
workspaces[21] = ring_softmax_stats_buf_size;
|
||||
workspaces[22] = ring_softmax_stats_accu_buf_size;
|
||||
workspaces[23] = ring_block_output_size;
|
||||
|
||||
return tc::calculateTotalWorkspaceSize(workspaces, NUM_BUFFERS);
|
||||
}
|
||||
@ -315,6 +352,15 @@ int BertAttentionPlugin::enqueueImpl(nvinfer1::PluginTensorDesc const* inputDesc
|
||||
? (batch_size * input_seq_len * mNumHeads * paddedHeadSize * sizeof(__nv_bfloat16))
|
||||
: sage_quant_space_size;
|
||||
|
||||
bool const enableRingAttn = (mCpGroup.size() > 1);
|
||||
const size_t ring_q_buf_size = enableRingAttn ? sizeof(T) * batch_size * input_seq_len * local_hidden_units_ : 0;
|
||||
const size_t ring_kv_buf_size
|
||||
= enableRingAttn ? 2 * sizeof(T) * batch_size * input_seq_len * local_hidden_units_ : 0;
|
||||
const size_t ring_softmax_stats_buf_size
|
||||
= enableRingAttn ? 2 * sizeof(float) * batch_size * input_seq_len * mNumHeads : 0;
|
||||
const size_t ring_block_output_size
|
||||
= enableRingAttn ? sizeof(T) * batch_size * input_seq_len * local_hidden_units_ : 0;
|
||||
|
||||
// Workspace pointer shift
|
||||
int8_t* workspace_byte_ptr = reinterpret_cast<int8_t*>(workspace);
|
||||
size_t offset = CUBLAS_WORKSPACE_SIZE;
|
||||
@ -344,6 +390,18 @@ int BertAttentionPlugin::enqueueImpl(nvinfer1::PluginTensorDesc const* inputDesc
|
||||
void* sage_quant_space_ptr
|
||||
= reinterpret_cast<void*>(tc::nextWorkspacePtr(workspace_byte_ptr, offset, sage_quant_space_size));
|
||||
|
||||
T* ring_q_buf_ = reinterpret_cast<T*>(tc::nextWorkspacePtr(workspace_byte_ptr, offset, ring_q_buf_size));
|
||||
T* ring_kv_buf_1_ = reinterpret_cast<T*>(
|
||||
tc::nextWorkspacePtr(workspace_byte_ptr, offset, ring_kv_buf_size + sizeof(int) * (batch_size + 1)));
|
||||
T* ring_kv_buf_2_ = reinterpret_cast<T*>(
|
||||
tc::nextWorkspacePtr(workspace_byte_ptr, offset, ring_kv_buf_size + sizeof(int) * (batch_size + 1)));
|
||||
float* ring_softmax_stats_buf_
|
||||
= reinterpret_cast<float*>(tc::nextWorkspacePtr(workspace_byte_ptr, offset, ring_softmax_stats_buf_size));
|
||||
float* ring_softmax_accu_stats_buf_
|
||||
= reinterpret_cast<float*>(tc::nextWorkspacePtr(workspace_byte_ptr, offset, ring_softmax_stats_buf_size));
|
||||
T* ring_block_output_
|
||||
= reinterpret_cast<T*>(tc::nextWorkspacePtr(workspace_byte_ptr, offset, ring_block_output_size));
|
||||
|
||||
// build attention_mask, cu_seqlens, and padding_offset tensors
|
||||
BuildDecoderInfoParams<T> params{};
|
||||
params.seqQOffsets = cu_seqlens;
|
||||
@ -382,176 +440,279 @@ int BertAttentionPlugin::enqueueImpl(nvinfer1::PluginTensorDesc const* inputDesc
|
||||
// We update mEnableContextFMHA in constructor to check this condition
|
||||
if (mEnableContextFMHA)
|
||||
{
|
||||
if (mSageAttn && mHeadSize == 72 && mSageAttnQBlockSize == 64 && mSageAttnKBlockSize == 64
|
||||
&& mSageAttnVBlockSize == 256)
|
||||
if (enableRingAttn)
|
||||
{
|
||||
sage_quant<72, 80, 64, 64, 256, __nv_bfloat16, __nv_fp8_e4m3, float>(
|
||||
// host var
|
||||
batch_size, mNumHeads, input_seq_len, true, true,
|
||||
// device var
|
||||
// q k v
|
||||
attention_input, attention_input + mNumHeads * mHeadSize, attention_input + 2 * mNumHeads * mHeadSize,
|
||||
// stride
|
||||
3 * mNumHeads * mHeadSize, 3 * mNumHeads * mHeadSize, 3 * mNumHeads * mHeadSize, cu_seqlens, cu_seqlens,
|
||||
sage_quant_space_ptr,
|
||||
// quant q k v
|
||||
quanted_qkv_ptr, quanted_qkv_ptr + mNumHeads * paddedHeadSize,
|
||||
quanted_qkv_ptr + 2 * mNumHeads * paddedHeadSize,
|
||||
// quanted_qkv_ptr, quanted_qkv_ptr + mNumHeads * mHeadSize, context,
|
||||
3 * mNumHeads * paddedHeadSize, 3 * mNumHeads * paddedHeadSize, 3 * mNumHeads * paddedHeadSize,
|
||||
// scales
|
||||
q_scale_ptr, k_scale_ptr, v_scale_ptr, stream);
|
||||
// make sure the padding part of key/value buffer is 0
|
||||
cudaMemsetAsync(ring_kv_buf_1_, 0,
|
||||
reinterpret_cast<int8_t*>(ring_kv_buf_2_) - reinterpret_cast<int8_t*>(ring_kv_buf_1_), stream);
|
||||
|
||||
sync_check_cuda_error(stream);
|
||||
}
|
||||
if (mSageAttn && mHeadSize == 80 && mSageAttnQBlockSize == 64 && mSageAttnKBlockSize == 64
|
||||
&& mSageAttnVBlockSize == 256)
|
||||
{
|
||||
sage_quant<80, 80, 64, 64, 256, __nv_bfloat16, __nv_fp8_e4m3, float>(
|
||||
// host var
|
||||
batch_size, mNumHeads, input_seq_len, true, true,
|
||||
// device var
|
||||
// q k v
|
||||
attention_input, attention_input + mNumHeads * mHeadSize, attention_input + 2 * mNumHeads * mHeadSize,
|
||||
// stride
|
||||
3 * mNumHeads * mHeadSize, 3 * mNumHeads * mHeadSize, 3 * mNumHeads * mHeadSize, cu_seqlens, cu_seqlens,
|
||||
sage_quant_space_ptr,
|
||||
// quant q k v
|
||||
quanted_qkv_ptr, quanted_qkv_ptr + mNumHeads * paddedHeadSize,
|
||||
quanted_qkv_ptr + 2 * mNumHeads * paddedHeadSize,
|
||||
// quanted_qkv_ptr, quanted_qkv_ptr + mNumHeads * mHeadSize, context,
|
||||
3 * mNumHeads * paddedHeadSize, 3 * mNumHeads * paddedHeadSize, 3 * mNumHeads * paddedHeadSize,
|
||||
// scales
|
||||
q_scale_ptr, k_scale_ptr, v_scale_ptr, stream);
|
||||
cudaMemcpyAsync(ring_q_buf_, attention_input, ring_q_buf_size, cudaMemcpyDeviceToDevice, stream);
|
||||
cudaMemcpyAsync(ring_kv_buf_1_,
|
||||
const_cast<char*>(reinterpret_cast<char const*>(attention_input)) + ring_q_buf_size, ring_kv_buf_size,
|
||||
cudaMemcpyDeviceToDevice, stream);
|
||||
cudaMemcpyAsync(reinterpret_cast<char*>(ring_kv_buf_1_) + ring_kv_buf_size, cu_seqlens,
|
||||
sizeof(int) * (batch_size + 1), cudaMemcpyDeviceToDevice, stream);
|
||||
// init softmax_stats
|
||||
cudaMemsetAsync(ring_softmax_accu_stats_buf_, 0, ring_softmax_stats_buf_size, stream);
|
||||
|
||||
sync_check_cuda_error(stream);
|
||||
}
|
||||
if (mSageAttn && mHeadSize == 128 && mSageAttnQBlockSize == 64 && mSageAttnKBlockSize == 64
|
||||
&& mSageAttnVBlockSize == 256)
|
||||
{
|
||||
sage_quant<128, 128, 64, 64, 256, __nv_bfloat16, __nv_fp8_e4m3, float>(
|
||||
// host var
|
||||
batch_size, mNumHeads, input_seq_len, true, true,
|
||||
// device var
|
||||
// q k v
|
||||
attention_input, attention_input + mNumHeads * mHeadSize, attention_input + 2 * mNumHeads * mHeadSize,
|
||||
// stride
|
||||
3 * mNumHeads * mHeadSize, 3 * mNumHeads * mHeadSize, 3 * mNumHeads * mHeadSize, cu_seqlens, cu_seqlens,
|
||||
sage_quant_space_ptr,
|
||||
// quant q k v
|
||||
quanted_qkv_ptr, quanted_qkv_ptr + mNumHeads * paddedHeadSize,
|
||||
quanted_qkv_ptr + 2 * mNumHeads * paddedHeadSize,
|
||||
// quanted_qkv_ptr, quanted_qkv_ptr + mNumHeads * mHeadSize, context,
|
||||
3 * mNumHeads * paddedHeadSize, 3 * mNumHeads * paddedHeadSize, 3 * mNumHeads * paddedHeadSize,
|
||||
// scales
|
||||
q_scale_ptr, k_scale_ptr, v_scale_ptr, stream);
|
||||
#if ENABLE_MULTI_DEVICE
|
||||
// relative position of prev/next rank in cp group
|
||||
int prev_rank = mCpRank > 0 ? mCpRank - 1 : mCpGroup.size() - 1;
|
||||
int next_rank = (mCpRank == static_cast<int>(mCpGroup.size() - 1)) ? 0 : mCpRank + 1;
|
||||
#endif // ENABLE_MULTI_DEVICE
|
||||
|
||||
sync_check_cuda_error(stream);
|
||||
}
|
||||
if (mSageAttn && mHeadSize == 128 && mSageAttnQBlockSize == 64 && mSageAttnKBlockSize == 32
|
||||
&& mSageAttnVBlockSize == 32)
|
||||
{
|
||||
sage_quant<128, 128, 64, 32, 32, __nv_bfloat16, __nv_fp8_e4m3, float>(
|
||||
// host var
|
||||
batch_size, mNumHeads, input_seq_len, true, true,
|
||||
// device var
|
||||
// q k v
|
||||
attention_input, attention_input + mNumHeads * mHeadSize, attention_input + 2 * mNumHeads * mHeadSize,
|
||||
// stride
|
||||
3 * mNumHeads * mHeadSize, 3 * mNumHeads * mHeadSize, 3 * mNumHeads * mHeadSize, cu_seqlens, cu_seqlens,
|
||||
sage_quant_space_ptr,
|
||||
// quant q k v
|
||||
quanted_qkv_ptr, quanted_qkv_ptr + mNumHeads * paddedHeadSize,
|
||||
quanted_qkv_ptr + 2 * mNumHeads * paddedHeadSize,
|
||||
// quanted_qkv_ptr, quanted_qkv_ptr + mNumHeads * mHeadSize, context,
|
||||
3 * mNumHeads * paddedHeadSize, 3 * mNumHeads * paddedHeadSize, 3 * mNumHeads * paddedHeadSize,
|
||||
// scales
|
||||
q_scale_ptr, k_scale_ptr, v_scale_ptr, stream);
|
||||
common::check_cuda_error(cudaStreamCreate(&mNcclStream));
|
||||
common::check_cuda_error(cudaStreamSynchronize(stream));
|
||||
|
||||
sync_check_cuda_error(stream);
|
||||
}
|
||||
if (mSageAttn && mHeadSize == 80 && mSageAttnQBlockSize == 64 && mSageAttnKBlockSize == 32
|
||||
&& mSageAttnVBlockSize == 32)
|
||||
{
|
||||
sage_quant<80, 80, 64, 32, 32, __nv_bfloat16, __nv_fp8_e4m3, float>(
|
||||
// host var
|
||||
batch_size, mNumHeads, input_seq_len, true, true,
|
||||
// device var
|
||||
// q k v
|
||||
attention_input, attention_input + mNumHeads * mHeadSize, attention_input + 2 * mNumHeads * mHeadSize,
|
||||
// stride
|
||||
3 * mNumHeads * mHeadSize, 3 * mNumHeads * mHeadSize, 3 * mNumHeads * mHeadSize, cu_seqlens, cu_seqlens,
|
||||
sage_quant_space_ptr,
|
||||
// quant q k v
|
||||
quanted_qkv_ptr, quanted_qkv_ptr + mNumHeads * paddedHeadSize,
|
||||
quanted_qkv_ptr + 2 * mNumHeads * paddedHeadSize,
|
||||
// quanted_qkv_ptr, quanted_qkv_ptr + mNumHeads * mHeadSize, context,
|
||||
3 * mNumHeads * paddedHeadSize, 3 * mNumHeads * paddedHeadSize, 3 * mNumHeads * paddedHeadSize,
|
||||
// scales
|
||||
q_scale_ptr, k_scale_ptr, v_scale_ptr, stream);
|
||||
|
||||
sync_check_cuda_error(stream);
|
||||
}
|
||||
if (mSageAttn && mHeadSize == 72 && mSageAttnQBlockSize == 64 && mSageAttnKBlockSize == 32
|
||||
&& mSageAttnVBlockSize == 32)
|
||||
{
|
||||
sage_quant<72, 80, 64, 32, 32, __nv_bfloat16, __nv_fp8_e4m3, float>(
|
||||
// host var
|
||||
batch_size, mNumHeads, input_seq_len, true, true,
|
||||
// device var
|
||||
// q k v
|
||||
attention_input, attention_input + mNumHeads * mHeadSize, attention_input + 2 * mNumHeads * mHeadSize,
|
||||
// stride
|
||||
3 * mNumHeads * mHeadSize, 3 * mNumHeads * mHeadSize, 3 * mNumHeads * mHeadSize, cu_seqlens, cu_seqlens,
|
||||
sage_quant_space_ptr,
|
||||
// quant q k v
|
||||
quanted_qkv_ptr, quanted_qkv_ptr + mNumHeads * paddedHeadSize,
|
||||
quanted_qkv_ptr + 2 * mNumHeads * paddedHeadSize,
|
||||
// quanted_qkv_ptr, quanted_qkv_ptr + mNumHeads * mHeadSize, context,
|
||||
3 * mNumHeads * paddedHeadSize, 3 * mNumHeads * paddedHeadSize, 3 * mNumHeads * paddedHeadSize,
|
||||
// scales
|
||||
q_scale_ptr, k_scale_ptr, v_scale_ptr, stream);
|
||||
|
||||
sync_check_cuda_error(stream);
|
||||
}
|
||||
|
||||
// Construct the fmha params for running kernels.
|
||||
MHARunnerParams fmhaParams{};
|
||||
fmhaParams.b = request_batch_size;
|
||||
fmhaParams.qSeqLen = request_seq_len;
|
||||
fmhaParams.kvSeqLen = request_seq_len;
|
||||
fmhaParams.totalQSeqLen = request_batch_size * request_seq_len;
|
||||
// Device buffer pointers.
|
||||
fmhaParams.qkvPtr = attention_input;
|
||||
fmhaParams.outputPtr = context_buf_;
|
||||
fmhaParams.cuQSeqLenPtr = cu_seqlens;
|
||||
fmhaParams.cuKvSeqLenPtr = cu_seqlens;
|
||||
fmhaParams.tileCounterPtr = fmha_tile_counter_ptr;
|
||||
fmhaParams.stream = stream;
|
||||
if (mSageAttn)
|
||||
{
|
||||
if (paddedHeadSize != mHeadSize)
|
||||
fmhaParams.outputPtr = sage_quant_space_ptr;
|
||||
fmhaParams.qkvPtr = quanted_qkv_ptr;
|
||||
fmhaParams.scaleBmm1Ptr = scale_bmm1_ptr;
|
||||
fmhaParams.scaleBmm2Ptr = scale_bmm2_ptr;
|
||||
fmhaParams.qScalePtr = q_scale_ptr;
|
||||
fmhaParams.kScalePtr = k_scale_ptr;
|
||||
fmhaParams.vScalePtr = v_scale_ptr;
|
||||
fmhaParams.qMaxNBlock = (input_seq_len + mSageAttnQBlockSize - 1) / mSageAttnQBlockSize;
|
||||
fmhaParams.kMaxNBlock = (input_seq_len + mSageAttnKBlockSize - 1) / mSageAttnKBlockSize;
|
||||
fmhaParams.vMaxNBlock = (input_seq_len + mSageAttnVBlockSize - 1) / mSageAttnVBlockSize;
|
||||
}
|
||||
|
||||
// Run the fmha kernel.
|
||||
mFMHARunner->run(fmhaParams);
|
||||
sync_check_cuda_error(stream);
|
||||
if (mSageAttn)
|
||||
{
|
||||
if (paddedHeadSize != mHeadSize && mHeadSize == 72)
|
||||
uint32_t* fmha_scheduler_counter_h = (uint32_t*) malloc(sizeof(uint32_t));
|
||||
cudaMemcpyAsync(
|
||||
fmha_scheduler_counter_h, fmha_tile_counter_ptr, sizeof(uint32_t), cudaMemcpyDeviceToHost, stream);
|
||||
for (size_t iter = 0; iter < mCpGroup.size(); ++iter)
|
||||
{
|
||||
unpadding<80, 72, __nv_bfloat16>(batch_size, mNumHeads, input_seq_len, sage_quant_space_ptr,
|
||||
mNumHeads * 72, mNumHeads * 80, cu_seqlens, context_buf_, stream);
|
||||
// KV buffer used by fmha
|
||||
T* ring_fmha_kv_buf_ = (iter % 2 == 0) ? ring_kv_buf_1_ : ring_kv_buf_2_;
|
||||
#if ENABLE_MULTI_DEVICE
|
||||
T* ring_send_kv_buf_ = (iter % 2 == 0) ? ring_kv_buf_1_ : ring_kv_buf_2_;
|
||||
T* ring_recv_kv_buf_ = (iter % 2 == 0) ? ring_kv_buf_2_ : ring_kv_buf_1_;
|
||||
if (iter < mCpGroup.size() - 1)
|
||||
{
|
||||
NCCLCHECK(ncclGroupStart());
|
||||
TLLM_CHECK_WITH_INFO(mNcclComm.get() != nullptr, "mNcclComm should be initialized before being used");
|
||||
NCCLCHECK(ncclSend(ring_send_kv_buf_,
|
||||
ring_kv_buf_size / sizeof(T) + sizeof(int) / sizeof(T) * (batch_size + 1),
|
||||
(*getDtypeMap())[inputDesc[0].type], next_rank, *mNcclComm, mNcclStream));
|
||||
NCCLCHECK(ncclRecv(ring_recv_kv_buf_,
|
||||
ring_kv_buf_size / sizeof(T) + sizeof(int) / sizeof(T) * (batch_size + 1),
|
||||
(*getDtypeMap())[inputDesc[0].type], prev_rank, *mNcclComm, mNcclStream));
|
||||
NCCLCHECK(ncclGroupEnd());
|
||||
}
|
||||
#else
|
||||
TLLM_LOG_ERROR("Please set ENABLE_MULTI_DEVICE to enable RingAttention");
|
||||
return 1;
|
||||
#endif // ENABLE_MULTI_DEVICE
|
||||
// Construct the fmha params for running kernels.
|
||||
MHARunnerParams fmhaParams{};
|
||||
fmhaParams.b = request_batch_size;
|
||||
fmhaParams.qSeqLen = request_seq_len;
|
||||
fmhaParams.kvSeqLen = request_seq_len;
|
||||
fmhaParams.totalQSeqLen = request_batch_size * request_seq_len;
|
||||
// Device buffer pointers.
|
||||
fmhaParams.qPtr = ring_q_buf_;
|
||||
fmhaParams.kvPtr = ring_fmha_kv_buf_;
|
||||
if (iter == 0)
|
||||
{
|
||||
fmhaParams.outputPtr = context_buf_;
|
||||
fmhaParams.softmaxStatsPtr = ring_softmax_accu_stats_buf_;
|
||||
}
|
||||
else
|
||||
{
|
||||
cudaMemsetAsync(ring_softmax_stats_buf_, 0, ring_softmax_stats_buf_size, stream);
|
||||
fmhaParams.outputPtr = ring_block_output_;
|
||||
fmhaParams.softmaxStatsPtr = ring_softmax_stats_buf_;
|
||||
}
|
||||
fmhaParams.cuQSeqLenPtr = cu_seqlens;
|
||||
fmhaParams.cuKvSeqLenPtr
|
||||
= reinterpret_cast<int*>(reinterpret_cast<char*>(ring_fmha_kv_buf_) + ring_kv_buf_size);
|
||||
|
||||
fmhaParams.tileCounterPtr = fmha_tile_counter_ptr;
|
||||
fmhaParams.stream = stream;
|
||||
// Run the fmha kernel.
|
||||
cudaMemsetAsync(fmhaParams.outputPtr, 0, ring_block_output_size, stream);
|
||||
cudaMemcpyAsync(fmhaParams.tileCounterPtr, fmha_scheduler_counter_h, sizeof(uint32_t),
|
||||
cudaMemcpyHostToDevice, stream);
|
||||
mFMHARunner->run(fmhaParams);
|
||||
if (iter != 0)
|
||||
{
|
||||
invokeRecoverFromRA<T>((T*) context_buf_, (float*) ring_softmax_accu_stats_buf_,
|
||||
(T*) ring_block_output_, (float*) ring_softmax_stats_buf_, fmhaParams.b, fmhaParams.qSeqLen,
|
||||
mNumHeads, mHeadSize, cu_seqlens, stream);
|
||||
}
|
||||
cudaStreamSynchronize(stream);
|
||||
cudaStreamSynchronize(mNcclStream);
|
||||
}
|
||||
common::check_cuda_error(cudaStreamDestroy(mNcclStream));
|
||||
free(fmha_scheduler_counter_h);
|
||||
}
|
||||
|
||||
else
|
||||
{
|
||||
if (mSageAttn && mHeadSize == 72 && mSageAttnQBlockSize == 64 && mSageAttnKBlockSize == 64
|
||||
&& mSageAttnVBlockSize == 256)
|
||||
{
|
||||
sage_quant<72, 80, 64, 64, 256, __nv_bfloat16, __nv_fp8_e4m3, float>(
|
||||
// host var
|
||||
batch_size, mNumHeads, input_seq_len, true, true,
|
||||
// device var
|
||||
// q k v
|
||||
attention_input, attention_input + mNumHeads * mHeadSize,
|
||||
attention_input + 2 * mNumHeads * mHeadSize,
|
||||
// stride
|
||||
3 * mNumHeads * mHeadSize, 3 * mNumHeads * mHeadSize, 3 * mNumHeads * mHeadSize, cu_seqlens,
|
||||
cu_seqlens, sage_quant_space_ptr,
|
||||
// quant q k v
|
||||
quanted_qkv_ptr, quanted_qkv_ptr + mNumHeads * paddedHeadSize,
|
||||
quanted_qkv_ptr + 2 * mNumHeads * paddedHeadSize,
|
||||
// quanted_qkv_ptr, quanted_qkv_ptr + mNumHeads * mHeadSize, context,
|
||||
3 * mNumHeads * paddedHeadSize, 3 * mNumHeads * paddedHeadSize, 3 * mNumHeads * paddedHeadSize,
|
||||
// scales
|
||||
q_scale_ptr, k_scale_ptr, v_scale_ptr, stream);
|
||||
|
||||
sync_check_cuda_error(stream);
|
||||
}
|
||||
if (mSageAttn && mHeadSize == 80 && mSageAttnQBlockSize == 64 && mSageAttnKBlockSize == 64
|
||||
&& mSageAttnVBlockSize == 256)
|
||||
{
|
||||
sage_quant<80, 80, 64, 64, 256, __nv_bfloat16, __nv_fp8_e4m3, float>(
|
||||
// host var
|
||||
batch_size, mNumHeads, input_seq_len, true, true,
|
||||
// device var
|
||||
// q k v
|
||||
attention_input, attention_input + mNumHeads * mHeadSize,
|
||||
attention_input + 2 * mNumHeads * mHeadSize,
|
||||
// stride
|
||||
3 * mNumHeads * mHeadSize, 3 * mNumHeads * mHeadSize, 3 * mNumHeads * mHeadSize, cu_seqlens,
|
||||
cu_seqlens, sage_quant_space_ptr,
|
||||
// quant q k v
|
||||
quanted_qkv_ptr, quanted_qkv_ptr + mNumHeads * paddedHeadSize,
|
||||
quanted_qkv_ptr + 2 * mNumHeads * paddedHeadSize,
|
||||
// quanted_qkv_ptr, quanted_qkv_ptr + mNumHeads * mHeadSize, context,
|
||||
3 * mNumHeads * paddedHeadSize, 3 * mNumHeads * paddedHeadSize, 3 * mNumHeads * paddedHeadSize,
|
||||
// scales
|
||||
q_scale_ptr, k_scale_ptr, v_scale_ptr, stream);
|
||||
|
||||
sync_check_cuda_error(stream);
|
||||
}
|
||||
if (mSageAttn && mHeadSize == 128 && mSageAttnQBlockSize == 64 && mSageAttnKBlockSize == 64
|
||||
&& mSageAttnVBlockSize == 256)
|
||||
{
|
||||
sage_quant<128, 128, 64, 64, 256, __nv_bfloat16, __nv_fp8_e4m3, float>(
|
||||
// host var
|
||||
batch_size, mNumHeads, input_seq_len, true, true,
|
||||
// device var
|
||||
// q k v
|
||||
attention_input, attention_input + mNumHeads * mHeadSize,
|
||||
attention_input + 2 * mNumHeads * mHeadSize,
|
||||
// stride
|
||||
3 * mNumHeads * mHeadSize, 3 * mNumHeads * mHeadSize, 3 * mNumHeads * mHeadSize, cu_seqlens,
|
||||
cu_seqlens, sage_quant_space_ptr,
|
||||
// quant q k v
|
||||
quanted_qkv_ptr, quanted_qkv_ptr + mNumHeads * paddedHeadSize,
|
||||
quanted_qkv_ptr + 2 * mNumHeads * paddedHeadSize,
|
||||
// quanted_qkv_ptr, quanted_qkv_ptr + mNumHeads * mHeadSize, context,
|
||||
3 * mNumHeads * paddedHeadSize, 3 * mNumHeads * paddedHeadSize, 3 * mNumHeads * paddedHeadSize,
|
||||
// scales
|
||||
q_scale_ptr, k_scale_ptr, v_scale_ptr, stream);
|
||||
|
||||
sync_check_cuda_error(stream);
|
||||
}
|
||||
if (mSageAttn && mHeadSize == 128 && mSageAttnQBlockSize == 64 && mSageAttnKBlockSize == 32
|
||||
&& mSageAttnVBlockSize == 32)
|
||||
{
|
||||
sage_quant<128, 128, 64, 32, 32, __nv_bfloat16, __nv_fp8_e4m3, float>(
|
||||
// host var
|
||||
batch_size, mNumHeads, input_seq_len, true, true,
|
||||
// device var
|
||||
// q k v
|
||||
attention_input, attention_input + mNumHeads * mHeadSize,
|
||||
attention_input + 2 * mNumHeads * mHeadSize,
|
||||
// stride
|
||||
3 * mNumHeads * mHeadSize, 3 * mNumHeads * mHeadSize, 3 * mNumHeads * mHeadSize, cu_seqlens,
|
||||
cu_seqlens, sage_quant_space_ptr,
|
||||
// quant q k v
|
||||
quanted_qkv_ptr, quanted_qkv_ptr + mNumHeads * paddedHeadSize,
|
||||
quanted_qkv_ptr + 2 * mNumHeads * paddedHeadSize,
|
||||
// quanted_qkv_ptr, quanted_qkv_ptr + mNumHeads * mHeadSize, context,
|
||||
3 * mNumHeads * paddedHeadSize, 3 * mNumHeads * paddedHeadSize, 3 * mNumHeads * paddedHeadSize,
|
||||
// scales
|
||||
q_scale_ptr, k_scale_ptr, v_scale_ptr, stream);
|
||||
|
||||
sync_check_cuda_error(stream);
|
||||
}
|
||||
if (mSageAttn && mHeadSize == 80 && mSageAttnQBlockSize == 64 && mSageAttnKBlockSize == 32
|
||||
&& mSageAttnVBlockSize == 32)
|
||||
{
|
||||
sage_quant<80, 80, 64, 32, 32, __nv_bfloat16, __nv_fp8_e4m3, float>(
|
||||
// host var
|
||||
batch_size, mNumHeads, input_seq_len, true, true,
|
||||
// device var
|
||||
// q k v
|
||||
attention_input, attention_input + mNumHeads * mHeadSize,
|
||||
attention_input + 2 * mNumHeads * mHeadSize,
|
||||
// stride
|
||||
3 * mNumHeads * mHeadSize, 3 * mNumHeads * mHeadSize, 3 * mNumHeads * mHeadSize, cu_seqlens,
|
||||
cu_seqlens, sage_quant_space_ptr,
|
||||
// quant q k v
|
||||
quanted_qkv_ptr, quanted_qkv_ptr + mNumHeads * paddedHeadSize,
|
||||
quanted_qkv_ptr + 2 * mNumHeads * paddedHeadSize,
|
||||
// quanted_qkv_ptr, quanted_qkv_ptr + mNumHeads * mHeadSize, context,
|
||||
3 * mNumHeads * paddedHeadSize, 3 * mNumHeads * paddedHeadSize, 3 * mNumHeads * paddedHeadSize,
|
||||
// scales
|
||||
q_scale_ptr, k_scale_ptr, v_scale_ptr, stream);
|
||||
|
||||
sync_check_cuda_error(stream);
|
||||
}
|
||||
if (mSageAttn && mHeadSize == 72 && mSageAttnQBlockSize == 64 && mSageAttnKBlockSize == 32
|
||||
&& mSageAttnVBlockSize == 32)
|
||||
{
|
||||
sage_quant<72, 80, 64, 32, 32, __nv_bfloat16, __nv_fp8_e4m3, float>(
|
||||
// host var
|
||||
batch_size, mNumHeads, input_seq_len, true, true,
|
||||
// device var
|
||||
// q k v
|
||||
attention_input, attention_input + mNumHeads * mHeadSize,
|
||||
attention_input + 2 * mNumHeads * mHeadSize,
|
||||
// stride
|
||||
3 * mNumHeads * mHeadSize, 3 * mNumHeads * mHeadSize, 3 * mNumHeads * mHeadSize, cu_seqlens,
|
||||
cu_seqlens, sage_quant_space_ptr,
|
||||
// quant q k v
|
||||
quanted_qkv_ptr, quanted_qkv_ptr + mNumHeads * paddedHeadSize,
|
||||
quanted_qkv_ptr + 2 * mNumHeads * paddedHeadSize,
|
||||
// quanted_qkv_ptr, quanted_qkv_ptr + mNumHeads * mHeadSize, context,
|
||||
3 * mNumHeads * paddedHeadSize, 3 * mNumHeads * paddedHeadSize, 3 * mNumHeads * paddedHeadSize,
|
||||
// scales
|
||||
q_scale_ptr, k_scale_ptr, v_scale_ptr, stream);
|
||||
|
||||
sync_check_cuda_error(stream);
|
||||
}
|
||||
|
||||
// Construct the fmha params for running kernels.
|
||||
MHARunnerParams fmhaParams{};
|
||||
fmhaParams.b = request_batch_size;
|
||||
fmhaParams.qSeqLen = request_seq_len;
|
||||
fmhaParams.kvSeqLen = request_seq_len;
|
||||
fmhaParams.totalQSeqLen = request_batch_size * request_seq_len;
|
||||
// Device buffer pointers.
|
||||
fmhaParams.qkvPtr = attention_input;
|
||||
fmhaParams.outputPtr = context_buf_;
|
||||
fmhaParams.cuQSeqLenPtr = cu_seqlens;
|
||||
fmhaParams.cuKvSeqLenPtr = cu_seqlens;
|
||||
fmhaParams.tileCounterPtr = fmha_tile_counter_ptr;
|
||||
fmhaParams.stream = stream;
|
||||
if (mSageAttn)
|
||||
{
|
||||
if (paddedHeadSize != mHeadSize)
|
||||
fmhaParams.outputPtr = sage_quant_space_ptr;
|
||||
fmhaParams.qkvPtr = quanted_qkv_ptr;
|
||||
fmhaParams.scaleBmm1Ptr = scale_bmm1_ptr;
|
||||
fmhaParams.scaleBmm2Ptr = scale_bmm2_ptr;
|
||||
fmhaParams.qScalePtr = q_scale_ptr;
|
||||
fmhaParams.kScalePtr = k_scale_ptr;
|
||||
fmhaParams.vScalePtr = v_scale_ptr;
|
||||
fmhaParams.qMaxNBlock = (input_seq_len + mSageAttnQBlockSize - 1) / mSageAttnQBlockSize;
|
||||
fmhaParams.kMaxNBlock = (input_seq_len + mSageAttnKBlockSize - 1) / mSageAttnKBlockSize;
|
||||
fmhaParams.vMaxNBlock = (input_seq_len + mSageAttnVBlockSize - 1) / mSageAttnVBlockSize;
|
||||
}
|
||||
|
||||
// Run the fmha kernel.
|
||||
mFMHARunner->run(fmhaParams);
|
||||
sync_check_cuda_error(stream);
|
||||
if (mSageAttn)
|
||||
{
|
||||
if (paddedHeadSize != mHeadSize && mHeadSize == 72)
|
||||
{
|
||||
unpadding<80, 72, __nv_bfloat16>(batch_size, mNumHeads, input_seq_len, sage_quant_space_ptr,
|
||||
mNumHeads * 72, mNumHeads * 80, cu_seqlens, context_buf_, stream);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
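To summarize the ring-attention path added to enqueueImpl above: the local Q stays in place while the K/V blocks travel around the context-parallel ring on a dedicated NCCL stream, and each received block's partial result is folded into the running output with invokeRecoverFromRA. Below is a hedged, Python-style sketch of that schedule; the callables (send, recv, run_fmha, merge_into_accumulator) are hypothetical stand-ins, not TensorRT-LLM APIs.

# Hypothetical sketch of the ring ping-pong schedule; not the actual plugin code.
def ring_attention_schedule(q, kv_bufs, cp_size, next_rank, prev_rank,
                            send, recv, run_fmha, merge_into_accumulator):
    for it in range(cp_size):
        kv_compute = kv_bufs[it % 2]        # K/V block consumed by FMHA this step
        kv_recv = kv_bufs[(it + 1) % 2]     # K/V block arriving from the previous rank
        if it < cp_size - 1:
            send(kv_compute, next_rank)     # overlapped with the attention below
            recv(kv_recv, prev_rank)        # the buffer also carries that block's cu_seqlens
        block_out, block_stats = run_fmha(q, kv_compute)
        if it == 0:
            accu_out, accu_stats = block_out, block_stats
        else:
            # online-softmax merge (see the NumPy sketch after recoverFromRingAtten.cu)
            accu_out, accu_stats = merge_into_accumulator(accu_out, accu_stats, block_out, block_stats)
    return accu_out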
|
||||
@ -780,6 +941,12 @@ int BertAttentionPlugin::initialize() noexcept
|
||||
fmhaParams.headSize = paddedHeadSize;
|
||||
}
|
||||
|
||||
if (mCpGroup.size() > 1)
|
||||
{
|
||||
fmhaParams.attentionInputLayout = AttentionInputLayout::Q_CONTIGUOUS_KV;
|
||||
fmhaParams.saveSoftmax = true;
|
||||
}
|
||||
|
||||
// Load kernels from the pre-compiled cubins.
|
||||
mFMHARunner.reset(new FusedMHARunnerV2(fmhaParams));
|
||||
|
||||
@ -787,6 +954,15 @@ int BertAttentionPlugin::initialize() noexcept
|
||||
mEnableContextFMHA = mFMHARunner->isFmhaSupported();
|
||||
}
|
||||
|
||||
#if ENABLE_MULTI_DEVICE
|
||||
if (mCpGroup.size() > 1 && COMM_SESSION.getSize() > 1)
|
||||
{
|
||||
TLLM_LOG_TRACE("%s start for rank %d", __PRETTY_FUNCTION__, COMM_SESSION.getRank());
|
||||
mNcclComm = getComm(mCpGroup);
|
||||
TLLM_LOG_TRACE("%s stop for rank %d", __PRETTY_FUNCTION__, COMM_SESSION.getRank());
|
||||
}
|
||||
#endif // ENABLE_MULTI_DEVICE
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
@ -800,7 +976,7 @@ size_t BertAttentionPlugin::getSerializationSize() const noexcept
|
||||
return sizeof(mNumHeads) + sizeof(mHeadSize) + sizeof(mQScaling) + sizeof(mQKHalfAccum) + sizeof(mEnableContextFMHA)
|
||||
+ sizeof(mFMHAForceFP32Acc) + sizeof(mType) + sizeof(mRelativeAttention) + sizeof(mMaxDistance)
|
||||
+ sizeof(mRemovePadding) + sizeof(mSageAttn) + sizeof(mSageAttnQBlockSize) + sizeof(mSageAttnKBlockSize)
|
||||
+ sizeof(mSageAttnVBlockSize);
|
||||
+ sizeof(mSageAttnVBlockSize) + sizeof(mCpSize) + sizeof(mCpRank) + sizeof(int32_t) * mCpGroup.size();
|
||||
}
|
||||
|
||||
void BertAttentionPlugin::serialize(void* buffer) const noexcept
|
||||
@ -820,6 +996,12 @@ void BertAttentionPlugin::serialize(void* buffer) const noexcept
|
||||
write(d, mSageAttnQBlockSize);
|
||||
write(d, mSageAttnKBlockSize);
|
||||
write(d, mSageAttnVBlockSize);
|
||||
write(d, mCpSize);
|
||||
write(d, mCpRank);
|
||||
for (auto it = mCpGroup.begin(); it != mCpGroup.end(); ++it)
|
||||
{
|
||||
write(d, *it);
|
||||
}
|
||||
TLLM_CHECK(d == a + getSerializationSize());
|
||||
}
|
||||
|
||||
@ -844,6 +1026,9 @@ BertAttentionPluginCreator::BertAttentionPluginCreator()
|
||||
mPluginAttributes.emplace_back(PluginField("sage_attn_q_block_size", nullptr, PluginFieldType::kINT32));
|
||||
mPluginAttributes.emplace_back(PluginField("sage_attn_k_block_size", nullptr, PluginFieldType::kINT32));
|
||||
mPluginAttributes.emplace_back(PluginField("sage_attn_v_block_size", nullptr, PluginFieldType::kINT32));
|
||||
mPluginAttributes.emplace_back(PluginField("cp_size", nullptr, PluginFieldType::kINT32));
|
||||
mPluginAttributes.emplace_back(PluginField("cp_rank", nullptr, PluginFieldType::kINT32));
|
||||
mPluginAttributes.emplace_back(PluginField("cp_group", nullptr, PluginFieldType::kINT32));
|
||||
|
||||
mFC.nbFields = mPluginAttributes.size();
|
||||
mFC.fields = mPluginAttributes.data();
|
||||
@ -879,6 +1064,9 @@ IPluginV2* BertAttentionPluginCreator::createPlugin(char const* name, PluginFiel
|
||||
int sage_attn_q_block_size{};
|
||||
int sage_attn_k_block_size{};
|
||||
int sage_attn_v_block_size{};
|
||||
int cp_size{};
|
||||
int cp_rank{};
|
||||
std::set<int> cp_group{};
|
||||
|
||||
// Read configurations from each fields
|
||||
for (int i = 0; i < fc->nbFields; ++i)
|
||||
@ -948,12 +1136,32 @@ IPluginV2* BertAttentionPluginCreator::createPlugin(char const* name, PluginFiel
|
||||
TLLM_CHECK(fields[i].type == PluginFieldType::kINT32);
|
||||
sage_attn_v_block_size = static_cast<int>(*(static_cast<int const*>(fields[i].data)));
|
||||
}
|
||||
else if (!strcmp(attrName, "cp_size"))
|
||||
{
|
||||
TLLM_CHECK(fields[i].type == PluginFieldType::kINT32);
|
||||
cp_size = static_cast<int>(*(static_cast<int const*>(fields[i].data)));
|
||||
}
|
||||
else if (!strcmp(attrName, "cp_rank"))
|
||||
{
|
||||
TLLM_CHECK(fields[i].type == PluginFieldType::kINT32);
|
||||
cp_rank = static_cast<int>(*(static_cast<int const*>(fields[i].data)));
|
||||
}
|
||||
else if (!strcmp(attrName, "cp_group"))
|
||||
{
|
||||
TLLM_CHECK(fields[i].type == PluginFieldType::kINT32);
|
||||
auto const* r = static_cast<int const*>(fields[i].data);
|
||||
for (int j = 0; j < fields[i].length; ++j)
|
||||
{
|
||||
cp_group.insert(*r);
|
||||
++r;
|
||||
}
|
||||
}
|
||||
}
|
||||
try
|
||||
{
|
||||
auto* obj = new BertAttentionPlugin(num_heads, head_size, q_scaling, context_fmha_type, type,
|
||||
do_relative_attention, max_distance, remove_padding, sage_attn, sage_attn_q_block_size,
|
||||
sage_attn_k_block_size, sage_attn_v_block_size);
|
||||
sage_attn_k_block_size, sage_attn_v_block_size, cp_size, cp_rank, cp_group);
|
||||
obj->setPluginNamespace(mNamespace.c_str());
|
||||
return obj;
|
||||
}
|
||||
|
||||
@ -21,7 +21,9 @@
|
||||
#include "tensorrt_llm/kernels/contextFusedMultiHeadAttention/fmhaRunner.h"
|
||||
#include "tensorrt_llm/kernels/gptKernels.h"
|
||||
#include "tensorrt_llm/plugins/common/plugin.h"
|
||||
#include "tensorrt_llm/runtime/utils/mpiUtils.h"
|
||||
#include <cassert>
|
||||
#include <cuda_runtime.h>
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
@ -37,7 +39,8 @@ public:
|
||||
BertAttentionPlugin(int num_heads, int head_size, float q_scaling,
|
||||
tensorrt_llm::kernels::ContextFMHAType context_fmha_type, nvinfer1::DataType type,
|
||||
bool do_relative_attention = false, int max_distance = 0, bool remove_padding = false, bool sage_attn = false,
|
||||
int sage_attn_q_block_size = 0, int sage_attn_k_block_size = 0, int sage_attn_v_block_size = 0);
|
||||
int sage_attn_q_block_size = 0, int sage_attn_k_block_size = 0, int sage_attn_v_block_size = 0, int cp_size = 1,
|
||||
int cp_rank = 0, std::set<int> cp_group = {});
|
||||
|
||||
BertAttentionPlugin(void const* data, size_t length);
|
||||
|
||||
@ -101,6 +104,15 @@ private:
|
||||
|
||||
int mSM = tensorrt_llm::common::getSMVersion();
|
||||
|
||||
// comm group for RingAttention
|
||||
int mCpSize = 1;
|
||||
int mCpRank = 0;
|
||||
std::set<int> mCpGroup = {};
|
||||
#if ENABLE_MULTI_DEVICE
|
||||
std::shared_ptr<ncclComm_t> mNcclComm;
|
||||
#endif // ENABLE_MULTI_DEVICE
|
||||
cudaStream_t mNcclStream;
|
||||
|
||||
// The default copy constructor will leave them as nullptr. clone() shall initialize it.
|
||||
UniqPtrWNullCopy<tensorrt_llm::kernels::FusedMHARunnerV2> mFMHARunner;
|
||||
UniqPtrWNullCopy<tensorrt_llm::common::CublasMMWrapper> mCublasWrapper;
|
||||
|
||||
@ -4449,7 +4449,10 @@ def bert_attention(tensor: Tensor,
|
||||
sage_attn: bool = False,
|
||||
sage_attn_q_block_size: int = 0,
|
||||
sage_attn_k_block_size: int = 0,
|
||||
sage_attn_v_block_size: int = 0) -> Tuple[Tensor]:
|
||||
sage_attn_v_block_size: int = 0,
|
||||
cp_group: list[int] = None,
|
||||
cp_size: int = 1,
|
||||
cp_rank: int = 0) -> Tuple[Tensor]:
|
||||
'''
|
||||
Add an operation that performs the multi-head attention in BERT.
|
||||
|
||||
@ -4516,6 +4519,15 @@ def bert_attention(tensor: Tensor,
|
||||
sage_attn_v_quant_size: int = 0
|
||||
dynamic quant block size along sequence dimension of v tensor. Each quant block will share one scale.
|
||||
|
||||
cp_group: list[int] = None
The communication group (list of ranks) used for context parallelism.

cp_size: int = 1
The number of ranks in the context-parallel group.

cp_rank: int = 0
The rank of the current process within the context-parallel group.
|
||||
|
||||
Returns:
|
||||
The tensor produced by that layer.
|
||||
'''
|
||||
@ -4570,10 +4582,31 @@ def bert_attention(tensor: Tensor,
|
||||
np.array(sage_attn_v_block_size, dtype=np.int32),
|
||||
trt.PluginFieldType.INT32)
|
||||
|
||||
if cp_size > 1:
|
||||
# transpose q,k,v inside qkv to make kv contiguous, which is required by ring attention
|
||||
# (b, s, 3d)
|
||||
query, key, value = chunk(tensor, 3, dim=-1)
|
||||
bs = shape(query, 0)
|
||||
seq_len = shape(query, 1)
|
||||
# (b, s, d) -> (b, s, 2d) -> (2b, s, d)
|
||||
kv = concat([key, value],
|
||||
dim=-1).view(concat((2 * bs, seq_len, query.shape[-1])))
|
||||
tensor = concat((query, kv),
|
||||
dim=0).view(concat((bs, seq_len, query.shape[-1] * 3)))
|
||||
|
||||
cp_size = trt.PluginField("cp_size", np.array(cp_size, dtype=np.int32),
|
||||
trt.PluginFieldType.INT32)
|
||||
cp_rank = trt.PluginField("cp_rank", np.array(cp_rank, dtype=np.int32),
|
||||
trt.PluginFieldType.INT32)
|
||||
cp_group = cp_group or [0]
|
||||
cp_group = np.array(cp_group, dtype=np.int32)
|
||||
cp_group = trt.PluginField("cp_group", cp_group, trt.PluginFieldType.INT32)
|
||||
|
||||
pfc = trt.PluginFieldCollection([
|
||||
nheads, head_size, q_scaling, context_fmha_type, pf_type,
|
||||
do_relative_attention, max_distance, remove_padding, sage_attn,
|
||||
sage_attn_q_block_size, sage_attn_k_block_size, sage_attn_v_block_size
|
||||
sage_attn_q_block_size, sage_attn_k_block_size, sage_attn_v_block_size,
|
||||
cp_size, cp_rank, cp_group
|
||||
])
|
||||
|
||||
attn_plug = attn_plg_creator.create_plugin("padding_attn", pfc)
|
||||
|
||||
@ -1608,6 +1608,7 @@ class BertAttention(Module):
|
||||
tp_rank=0,
|
||||
cp_group=None,
|
||||
cp_size=1,
|
||||
cp_rank=0,
|
||||
relative_attention=False,
|
||||
max_distance=0,
|
||||
num_buckets=0,
|
||||
@ -1628,6 +1629,7 @@ class BertAttention(Module):
|
||||
self.tp_rank = tp_rank
|
||||
self.cp_group = cp_group
|
||||
self.cp_size = cp_size
|
||||
self.cp_rank = cp_rank
|
||||
|
||||
self.num_layers = num_layers
|
||||
self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
|
||||
@ -1725,7 +1727,6 @@ class BertAttention(Module):
|
||||
if default_net().plugin_config.bert_attention_plugin:
|
||||
# TRT plugin mode
|
||||
assert input_lengths is not None
|
||||
assert self.cp_size == 1
|
||||
assert get_sm_version() < 100 or get_sm_version() >= 120, \
|
||||
"bert_attention_plugin does not support SM100"
|
||||
context = bert_attention(
|
||||
@ -1738,7 +1739,10 @@ class BertAttention(Module):
|
||||
max_distance=self.max_distance,
|
||||
relative_attention_bias=self.rel_attn_table.value
|
||||
if self.relative_attention else None,
|
||||
max_input_length=max_input_length)
|
||||
max_input_length=max_input_length,
|
||||
cp_group=self.cp_group,
|
||||
cp_size=self.cp_size,
|
||||
cp_rank=self.cp_rank)
|
||||
else:
|
||||
# plain TRT mode
|
||||
def transpose_for_scores(x):
|
||||
|
||||
@ -134,6 +134,7 @@ class DiTBlock(Module):
|
||||
tp_rank=mapping.tp_rank,
|
||||
cp_group=mapping.cp_group,
|
||||
cp_size=mapping.cp_size,
|
||||
cp_rank=mapping.cp_rank,
|
||||
dtype=dtype,
|
||||
quant_mode=quant_mode)
|
||||
self.norm2 = LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
@ -299,6 +300,7 @@ class DiT(PretrainedModel):
|
||||
if self.mapping.cp_size > 1:
|
||||
assert x.shape[1] % self.mapping.cp_size == 0
|
||||
x = chunk(x, self.mapping.cp_size, dim=1)[self.mapping.cp_rank]
|
||||
input_lengths = input_lengths // self.mapping.cp_size
|
||||
for block in self.blocks:
|
||||
x = block(x, c, input_lengths) # (N, T, D)
|
||||
self.register_network_output('before_final_layer', x)
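For intuition, the context-parallel split in DiT.forward above simply gives each rank an equal, contiguous slice of the latent token sequence before the transformer blocks run. A hypothetical NumPy illustration (shapes chosen arbitrarily):

# Hypothetical illustration of the cp split in DiT.forward (cp_size = 4 ranks).
import numpy as np

N, T, D, cp_size = 1, 1024, 1152, 4
x = np.zeros((N, T, D), dtype=np.float32)
assert T % cp_size == 0
local_slices = np.split(x, cp_size, axis=1)   # each (N, T // cp_size, D)
x_rank2 = local_slices[2]                     # slice processed by cp_rank == 2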