Support RingAttention in the BertAttention plugin and the DiT model (#3661)

Support ring attention for the bert_attention plugin and the DiT model.

Signed-off-by: ChunhuanLin <lch_xdu@163.com>
forrestl 2025-05-09 08:06:54 +08:00 committed by GitHub
parent 9afe510367
commit 9477661f4c
7 changed files with 605 additions and 173 deletions


@@ -0,0 +1,139 @@
/*
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/common/cudaTypeUtils.cuh"
#include "tensorrt_llm/kernels/recoverFromRingAtten.h"
#include "math.h"
#include <cooperative_groups.h>
#include <cuda/barrier>
using namespace tensorrt_llm::common;
namespace tensorrt_llm
{
namespace kernels
{
template <typename Tout>
__global__ void reduce4ring_attention(
// the accumulated results of all finished ring-attention blocks
Tout* __restrict__ accu_output, // b x s_block x h x d
float* __restrict__ accu_softmax_sum, // b x s_block x h
float* __restrict__ accu_max, // b x s_block x h
// the results of the new ring-attention block
Tout* __restrict__ output, // b x s_block x h x d
float* __restrict__ softmax_sum, // b x s_block x h
float* __restrict__ max, // b x s_block x h
// necessary constant parameters
int const b, int const s_block, int const h, int const d, int const block_seq_len, int* cu_seqlens)
{
auto block = cooperative_groups::this_thread_block();
int batchid = blockIdx.x;
int block_seq_idx = blockIdx.y;
int block_s_start = block_seq_idx * block_seq_len;
int block_s_end = (block_seq_idx + 1) * block_seq_len;
block_s_end = s_block < block_s_end ? s_block : block_s_end;
int64_t output_start_offset = batchid * s_block * d + block_s_start * d;
int64_t lm_start_offset = batchid * s_block + block_s_start;
__shared__ cuda::barrier<cuda::thread_scope::thread_scope_block> barrier;
if (block.thread_rank() == 0)
{
init(&barrier, block.size());
}
block.sync();
int s_len = block_s_end - block_s_start;
int laneid = threadIdx.x % 32;
int local_warpid = threadIdx.x / 32;
int warp_num = blockDim.x / 32;
int loop_on_s = (s_len + warp_num * 32 - 1) / (warp_num * 32);
for (int l = 0; l < loop_on_s; l++)
{
int s_ = local_warpid + warp_num * laneid + l * warp_num * 32;
float scaled_my_ss1_ = 1.0, scaled_my_ss2_ = 1.0;
if (s_ < s_len)
{
uint64_t lm_start_offset_ = lm_start_offset + s_;
float my_accu_ss = accu_softmax_sum[lm_start_offset_] == 0.0 ? 1.0 : accu_softmax_sum[lm_start_offset_];
float my_ss = softmax_sum[lm_start_offset_] == 0.0 ? 1.0 : softmax_sum[lm_start_offset_];
float cur_max = (accu_max[lm_start_offset_] > max[lm_start_offset_]) ? accu_max[lm_start_offset_]
: max[lm_start_offset_];
float scale1 = exp(accu_max[lm_start_offset_] - cur_max);
float scale2 = exp(max[lm_start_offset_] - cur_max);
float cur_softmax_sum = my_accu_ss * scale1 + my_ss * scale2;
if (cur_softmax_sum == 0)
cur_softmax_sum = 1.0;
scaled_my_ss1_ = scale1 * my_accu_ss / cur_softmax_sum;
scaled_my_ss2_ = scale2 * my_ss / cur_softmax_sum;
accu_softmax_sum[lm_start_offset_] = cur_softmax_sum;
accu_max[lm_start_offset_] = cur_max;
}
int sid = l * warp_num * 32 + local_warpid;
int s_end = (l + 1) * warp_num * 32 < s_len ? (l + 1) * warp_num * 32 : s_len;
for (int ss = 0;; ss++)
{
uint64_t output_start_offset_ = output_start_offset + sid * d;
float scaled_my_ss1 = __shfl_sync(0xffffffff, scaled_my_ss1_, ss, 32);
float scaled_my_ss2 = __shfl_sync(0xffffffff, scaled_my_ss2_, ss, 32);
for (int eid = laneid; eid < d; eid += 32)
{
accu_output[output_start_offset_ + eid]
= (float) accu_output[output_start_offset_ + eid] * scaled_my_ss1
+ (float) output[output_start_offset_ + eid] * scaled_my_ss2;
}
sid += warp_num;
if (sid >= s_end)
break;
}
}
barrier.arrive_and_wait();
return;
}
template <typename Tout>
void invokeRecoverFromRA(Tout* accu_output, float* accu_softmax_stats, Tout* output, float* softmax_stats, int b, int s,
int h, int d, int* cu_seqlens, cudaStream_t stream)
{
float* accu_softmax_sum = accu_softmax_stats;
float* accu_softmax_max = accu_softmax_stats + b * s * h;
float* softmax_sum = softmax_stats;
float* softmax_max = softmax_stats + b * s * h;
int threads_per_block = 128;
int saturated_s_block_dim = 3000 / b + 1;
s = s * h;
int block_seq_len = (s / saturated_s_block_dim + 255) / 256 * 256;
block_seq_len = block_seq_len < 256 ? 256 : block_seq_len;
int dim_s = (s + block_seq_len - 1) / block_seq_len;
dim3 block_num(b, dim_s, 1);
reduce4ring_attention<Tout><<<block_num, threads_per_block, 0, stream>>>(accu_output, accu_softmax_sum,
accu_softmax_max, output, softmax_sum, softmax_max, b, s, h, d, block_seq_len, cu_seqlens);
}
#define INSTANTIATE_RECOVER_RA(Tout) \
template void invokeRecoverFromRA(Tout* accu_output, float* accu_softmax_stats, Tout* output, \
float* softmax_stats, int b, int s, int h, int d, int* cu_seqlens, cudaStream_t stream)
INSTANTIATE_RECOVER_RA(float);
INSTANTIATE_RECOVER_RA(half);
#ifdef ENABLE_BF16
INSTANTIATE_RECOVER_RA(__nv_bfloat16);
#endif
} // namespace kernels
} // namespace tensorrt_llm
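
For reference, the update that reduce4ring_attention applies per (token, head) is the usual streaming-softmax merge of two partial attention results, with $m$, $\ell$, $O$ denoting the running max, softmax sum, and output:

$$m' = \max(m_{\mathrm{acc}}, m_{\mathrm{blk}}), \qquad \ell' = \ell_{\mathrm{acc}}\, e^{m_{\mathrm{acc}} - m'} + \ell_{\mathrm{blk}}\, e^{m_{\mathrm{blk}} - m'},$$
$$O' = \frac{\ell_{\mathrm{acc}}\, e^{m_{\mathrm{acc}} - m'}}{\ell'}\, O_{\mathrm{acc}} + \frac{\ell_{\mathrm{blk}}\, e^{m_{\mathrm{blk}} - m'}}{\ell'}\, O_{\mathrm{blk}}.$$

These expressions correspond to cur_max, cur_softmax_sum, and the scaled_my_ss1_/scaled_my_ss2_ blend in the kernel above; the zero-sum guards simply avoid dividing by an empty (fully padded) row.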


@@ -0,0 +1,34 @@
/*
* Copyright (c) 2020-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "tensorrt_llm/common/cudaUtils.h"
#include <assert.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
namespace tensorrt_llm
{
namespace kernels
{
template <typename Tout>
void invokeRecoverFromRA(Tout* accu_output, float* accu_softmax_stats, Tout* output, float* softmax_stats, int b, int s,
int h, int d, int* cu_seqlens, cudaStream_t stream);
} // namespace kernels
} // namespace tensorrt_llm


@@ -16,6 +16,7 @@
*/
#include "bertAttentionPlugin.h"
#include "tensorrt_llm/kernels/gptKernels.h"
#include "tensorrt_llm/kernels/recoverFromRingAtten.h"
#include "tensorrt_llm/kernels/sageAttentionKernels.h"
#include "tensorrt_llm/kernels/unfusedAttentionKernels.h"
#include "tensorrt_llm/runtime/iBuffer.h"
@@ -35,7 +36,7 @@ std::vector<nvinfer1::PluginField> BertAttentionPluginCreator::mPluginAttributes
BertAttentionPlugin::BertAttentionPlugin(int num_heads, int head_size, float q_scaling,
ContextFMHAType context_fmha_type, nvinfer1::DataType type, bool do_relative_attention, int max_distance,
bool remove_padding, bool sage_attn, int sage_attn_q_block_size, int sage_attn_k_block_size,
int sage_attn_v_block_size)
int sage_attn_v_block_size, int cp_size, int cp_rank, std::set<int> cp_group)
: mNumHeads(num_heads)
, mHeadSize(head_size)
, mQScaling(q_scaling)
@@ -46,6 +47,9 @@ BertAttentionPlugin::BertAttentionPlugin(int num_heads, int head_size, float q_s
, mEnableContextFMHA(context_fmha_type != ContextFMHAType::DISABLED)
, mFMHAForceFP32Acc(context_fmha_type == ContextFMHAType::ENABLED_WITH_FP32_ACC)
, mSageAttn(sage_attn)
, mCpSize(cp_size)
, mCpRank(cp_rank)
, mCpGroup(std::move(cp_group))
{
// pre-check whether FMHA is supported in order to save memory allocation
if (mEnableContextFMHA)
@@ -84,6 +88,11 @@ BertAttentionPlugin::BertAttentionPlugin(int num_heads, int head_size, float q_s
mSageAttnKBlockSize, mSageAttnVBlockSize);
}
}
if (cp_group.size() > 1 && !mEnableContextFMHA)
{
TLLM_LOG_ERROR("Unfused MHA do not support context parallel now.");
}
}
// Parameterized constructor
@@ -104,6 +113,15 @@ BertAttentionPlugin::BertAttentionPlugin(void const* data, size_t length)
read(d, mSageAttnQBlockSize);
read(d, mSageAttnKBlockSize);
read(d, mSageAttnVBlockSize);
read(d, mCpSize);
read(d, mCpRank);
mCpGroup.clear();
int groupItem = 0;
while (d != a + length)
{
read(d, groupItem);
mCpGroup.insert(groupItem);
}
TLLM_CHECK_WITH_INFO(d == a + length,
"Expected length (%d) != real length (%d). This is often "
@@ -207,7 +225,20 @@ size_t BertAttentionPlugin::getWorkspaceSize(nvinfer1::PluginTensorDesc const* i
? (batch_size * input_seq_len * mNumHeads * paddedHeadSize * sizeof(__nv_bfloat16))
: sage_quant_space_size;
int const NUM_BUFFERS = 18;
// workspace for RingAttention ping-pong buffer
bool const enableRingAttn = (mCpGroup.size() > 1);
const size_t ring_q_buf_size = enableRingAttn ? size * batch_size * input_seq_len * local_hidden_units_ : 0;
const size_t ring_kv_buf_size = enableRingAttn
? 2 * size * batch_size * input_seq_len * local_hidden_units_ + sizeof(int) * (batch_size + 1)
: 0;
const size_t ring_softmax_stats_buf_size
= enableRingAttn ? 2 * sizeof(float) * batch_size * input_seq_len * mNumHeads : 0;
const size_t ring_softmax_stats_accu_buf_size
= enableRingAttn ? 2 * sizeof(float) * batch_size * input_seq_len * mNumHeads : 0;
const size_t ring_block_output_size = enableRingAttn ? size * batch_size * input_seq_len * local_hidden_units_ : 0;
int const NUM_BUFFERS = 24;
size_t workspaces[NUM_BUFFERS];
workspaces[0] = CUBLAS_WORKSPACE_SIZE;
workspaces[1] = attention_mask_size;
@@ -227,6 +258,12 @@ size_t BertAttentionPlugin::getWorkspaceSize(nvinfer1::PluginTensorDesc const* i
workspaces[15] = scale_bmm1_device_size;
workspaces[16] = scale_bmm2_device_size;
workspaces[17] = sage_quant_space_size;
workspaces[18] = ring_q_buf_size;
workspaces[19] = ring_kv_buf_size; // kv1
workspaces[20] = ring_kv_buf_size; // kv2
workspaces[21] = ring_softmax_stats_buf_size;
workspaces[22] = ring_softmax_stats_accu_buf_size;
workspaces[23] = ring_block_output_size;
return tc::calculateTotalWorkspaceSize(workspaces, NUM_BUFFERS);
}
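
As a rough accounting of the six buffers added above, a hypothetical Python helper that mirrors the sizes computed in getWorkspaceSize (here `size` stands for sizeof(T) of the plugin's activation type, and 4 for sizeof(int)/sizeof(float)):

def ring_workspace_bytes(size, batch_size, input_seq_len, local_hidden_units, num_heads):
    # q stays resident; KV uses two ping-pong buffers, each with cu_seqlens appended.
    q_buf = size * batch_size * input_seq_len * local_hidden_units
    kv_buf = 2 * size * batch_size * input_seq_len * local_hidden_units + 4 * (batch_size + 1)
    stats = 2 * 4 * batch_size * input_seq_len * num_heads      # softmax sum + max, fp32
    block_out = size * batch_size * input_seq_len * local_hidden_units
    return q_buf + 2 * kv_buf + 2 * stats + block_out           # per-step stats + accumulated stats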
@@ -315,6 +352,15 @@ int BertAttentionPlugin::enqueueImpl(nvinfer1::PluginTensorDesc const* inputDesc
? (batch_size * input_seq_len * mNumHeads * paddedHeadSize * sizeof(__nv_bfloat16))
: sage_quant_space_size;
bool const enableRingAttn = (mCpGroup.size() > 1);
const size_t ring_q_buf_size = enableRingAttn ? sizeof(T) * batch_size * input_seq_len * local_hidden_units_ : 0;
const size_t ring_kv_buf_size
= enableRingAttn ? 2 * sizeof(T) * batch_size * input_seq_len * local_hidden_units_ : 0;
const size_t ring_softmax_stats_buf_size
= enableRingAttn ? 2 * sizeof(float) * batch_size * input_seq_len * mNumHeads : 0;
const size_t ring_block_output_size
= enableRingAttn ? sizeof(T) * batch_size * input_seq_len * local_hidden_units_ : 0;
// Workspace pointer shift
int8_t* workspace_byte_ptr = reinterpret_cast<int8_t*>(workspace);
size_t offset = CUBLAS_WORKSPACE_SIZE;
@@ -344,6 +390,18 @@ int BertAttentionPlugin::enqueueImpl(nvinfer1::PluginTensorDesc const* inputDesc
void* sage_quant_space_ptr
= reinterpret_cast<void*>(tc::nextWorkspacePtr(workspace_byte_ptr, offset, sage_quant_space_size));
T* ring_q_buf_ = reinterpret_cast<T*>(tc::nextWorkspacePtr(workspace_byte_ptr, offset, ring_q_buf_size));
T* ring_kv_buf_1_ = reinterpret_cast<T*>(
tc::nextWorkspacePtr(workspace_byte_ptr, offset, ring_kv_buf_size + sizeof(int) * (batch_size + 1)));
T* ring_kv_buf_2_ = reinterpret_cast<T*>(
tc::nextWorkspacePtr(workspace_byte_ptr, offset, ring_kv_buf_size + sizeof(int) * (batch_size + 1)));
float* ring_softmax_stats_buf_
= reinterpret_cast<float*>(tc::nextWorkspacePtr(workspace_byte_ptr, offset, ring_softmax_stats_buf_size));
float* ring_softmax_accu_stats_buf_
= reinterpret_cast<float*>(tc::nextWorkspacePtr(workspace_byte_ptr, offset, ring_softmax_stats_buf_size));
T* ring_block_output_
= reinterpret_cast<T*>(tc::nextWorkspacePtr(workspace_byte_ptr, offset, ring_block_output_size));
// build attention_mask, cu_seqlens, and padding_offset tensors
BuildDecoderInfoParams<T> params{};
params.seqQOffsets = cu_seqlens;
@@ -382,176 +440,279 @@ int BertAttentionPlugin::enqueueImpl(nvinfer1::PluginTensorDesc const* inputDesc
// We update mEnableContextFMHA in constructor to check this condition
if (mEnableContextFMHA)
{
if (mSageAttn && mHeadSize == 72 && mSageAttnQBlockSize == 64 && mSageAttnKBlockSize == 64
&& mSageAttnVBlockSize == 256)
if (enableRingAttn)
{
sage_quant<72, 80, 64, 64, 256, __nv_bfloat16, __nv_fp8_e4m3, float>(
// host var
batch_size, mNumHeads, input_seq_len, true, true,
// device var
// q k v
attention_input, attention_input + mNumHeads * mHeadSize, attention_input + 2 * mNumHeads * mHeadSize,
// stride
3 * mNumHeads * mHeadSize, 3 * mNumHeads * mHeadSize, 3 * mNumHeads * mHeadSize, cu_seqlens, cu_seqlens,
sage_quant_space_ptr,
// quant q k v
quanted_qkv_ptr, quanted_qkv_ptr + mNumHeads * paddedHeadSize,
quanted_qkv_ptr + 2 * mNumHeads * paddedHeadSize,
// quanted_qkv_ptr, quanted_qkv_ptr + mNumHeads * mHeadSize, context,
3 * mNumHeads * paddedHeadSize, 3 * mNumHeads * paddedHeadSize, 3 * mNumHeads * paddedHeadSize,
// scales
q_scale_ptr, k_scale_ptr, v_scale_ptr, stream);
// make sure the padding part of key/value buffer is 0
cudaMemsetAsync(ring_kv_buf_1_, 0,
reinterpret_cast<int8_t*>(ring_kv_buf_2_) - reinterpret_cast<int8_t*>(ring_kv_buf_1_), stream);
sync_check_cuda_error(stream);
}
if (mSageAttn && mHeadSize == 80 && mSageAttnQBlockSize == 64 && mSageAttnKBlockSize == 64
&& mSageAttnVBlockSize == 256)
{
sage_quant<80, 80, 64, 64, 256, __nv_bfloat16, __nv_fp8_e4m3, float>(
// host var
batch_size, mNumHeads, input_seq_len, true, true,
// device var
// q k v
attention_input, attention_input + mNumHeads * mHeadSize, attention_input + 2 * mNumHeads * mHeadSize,
// stride
3 * mNumHeads * mHeadSize, 3 * mNumHeads * mHeadSize, 3 * mNumHeads * mHeadSize, cu_seqlens, cu_seqlens,
sage_quant_space_ptr,
// quant q k v
quanted_qkv_ptr, quanted_qkv_ptr + mNumHeads * paddedHeadSize,
quanted_qkv_ptr + 2 * mNumHeads * paddedHeadSize,
// quanted_qkv_ptr, quanted_qkv_ptr + mNumHeads * mHeadSize, context,
3 * mNumHeads * paddedHeadSize, 3 * mNumHeads * paddedHeadSize, 3 * mNumHeads * paddedHeadSize,
// scales
q_scale_ptr, k_scale_ptr, v_scale_ptr, stream);
cudaMemcpyAsync(ring_q_buf_, attention_input, ring_q_buf_size, cudaMemcpyDeviceToDevice, stream);
cudaMemcpyAsync(ring_kv_buf_1_,
const_cast<char*>(reinterpret_cast<char const*>(attention_input)) + ring_q_buf_size, ring_kv_buf_size,
cudaMemcpyDeviceToDevice, stream);
cudaMemcpyAsync(reinterpret_cast<char*>(ring_kv_buf_1_) + ring_kv_buf_size, cu_seqlens,
sizeof(int) * (batch_size + 1), cudaMemcpyDeviceToDevice, stream);
// init softmax_stats
cudaMemsetAsync(ring_softmax_accu_stats_buf_, 0, ring_softmax_stats_buf_size, stream);
sync_check_cuda_error(stream);
}
if (mSageAttn && mHeadSize == 128 && mSageAttnQBlockSize == 64 && mSageAttnKBlockSize == 64
&& mSageAttnVBlockSize == 256)
{
sage_quant<128, 128, 64, 64, 256, __nv_bfloat16, __nv_fp8_e4m3, float>(
// host var
batch_size, mNumHeads, input_seq_len, true, true,
// device var
// q k v
attention_input, attention_input + mNumHeads * mHeadSize, attention_input + 2 * mNumHeads * mHeadSize,
// stride
3 * mNumHeads * mHeadSize, 3 * mNumHeads * mHeadSize, 3 * mNumHeads * mHeadSize, cu_seqlens, cu_seqlens,
sage_quant_space_ptr,
// quant q k v
quanted_qkv_ptr, quanted_qkv_ptr + mNumHeads * paddedHeadSize,
quanted_qkv_ptr + 2 * mNumHeads * paddedHeadSize,
// quanted_qkv_ptr, quanted_qkv_ptr + mNumHeads * mHeadSize, context,
3 * mNumHeads * paddedHeadSize, 3 * mNumHeads * paddedHeadSize, 3 * mNumHeads * paddedHeadSize,
// scales
q_scale_ptr, k_scale_ptr, v_scale_ptr, stream);
#if ENABLE_MULTI_DEVICE
// relative position of prev/next rank in cp group
int prev_rank = mCpRank > 0 ? mCpRank - 1 : mCpGroup.size() - 1;
int next_rank = (mCpRank == static_cast<int>(mCpGroup.size() - 1)) ? 0 : mCpRank + 1;
#endif // ENABLE_MULTI_DEVICE
sync_check_cuda_error(stream);
}
if (mSageAttn && mHeadSize == 128 && mSageAttnQBlockSize == 64 && mSageAttnKBlockSize == 32
&& mSageAttnVBlockSize == 32)
{
sage_quant<128, 128, 64, 32, 32, __nv_bfloat16, __nv_fp8_e4m3, float>(
// host var
batch_size, mNumHeads, input_seq_len, true, true,
// device var
// q k v
attention_input, attention_input + mNumHeads * mHeadSize, attention_input + 2 * mNumHeads * mHeadSize,
// stride
3 * mNumHeads * mHeadSize, 3 * mNumHeads * mHeadSize, 3 * mNumHeads * mHeadSize, cu_seqlens, cu_seqlens,
sage_quant_space_ptr,
// quant q k v
quanted_qkv_ptr, quanted_qkv_ptr + mNumHeads * paddedHeadSize,
quanted_qkv_ptr + 2 * mNumHeads * paddedHeadSize,
// quanted_qkv_ptr, quanted_qkv_ptr + mNumHeads * mHeadSize, context,
3 * mNumHeads * paddedHeadSize, 3 * mNumHeads * paddedHeadSize, 3 * mNumHeads * paddedHeadSize,
// scales
q_scale_ptr, k_scale_ptr, v_scale_ptr, stream);
common::check_cuda_error(cudaStreamCreate(&mNcclStream));
common::check_cuda_error(cudaStreamSynchronize(stream));
sync_check_cuda_error(stream);
}
if (mSageAttn && mHeadSize == 80 && mSageAttnQBlockSize == 64 && mSageAttnKBlockSize == 32
&& mSageAttnVBlockSize == 32)
{
sage_quant<80, 80, 64, 32, 32, __nv_bfloat16, __nv_fp8_e4m3, float>(
// host var
batch_size, mNumHeads, input_seq_len, true, true,
// device var
// q k v
attention_input, attention_input + mNumHeads * mHeadSize, attention_input + 2 * mNumHeads * mHeadSize,
// stride
3 * mNumHeads * mHeadSize, 3 * mNumHeads * mHeadSize, 3 * mNumHeads * mHeadSize, cu_seqlens, cu_seqlens,
sage_quant_space_ptr,
// quant q k v
quanted_qkv_ptr, quanted_qkv_ptr + mNumHeads * paddedHeadSize,
quanted_qkv_ptr + 2 * mNumHeads * paddedHeadSize,
// quanted_qkv_ptr, quanted_qkv_ptr + mNumHeads * mHeadSize, context,
3 * mNumHeads * paddedHeadSize, 3 * mNumHeads * paddedHeadSize, 3 * mNumHeads * paddedHeadSize,
// scales
q_scale_ptr, k_scale_ptr, v_scale_ptr, stream);
sync_check_cuda_error(stream);
}
if (mSageAttn && mHeadSize == 72 && mSageAttnQBlockSize == 64 && mSageAttnKBlockSize == 32
&& mSageAttnVBlockSize == 32)
{
sage_quant<72, 80, 64, 32, 32, __nv_bfloat16, __nv_fp8_e4m3, float>(
// host var
batch_size, mNumHeads, input_seq_len, true, true,
// device var
// q k v
attention_input, attention_input + mNumHeads * mHeadSize, attention_input + 2 * mNumHeads * mHeadSize,
// stride
3 * mNumHeads * mHeadSize, 3 * mNumHeads * mHeadSize, 3 * mNumHeads * mHeadSize, cu_seqlens, cu_seqlens,
sage_quant_space_ptr,
// quant q k v
quanted_qkv_ptr, quanted_qkv_ptr + mNumHeads * paddedHeadSize,
quanted_qkv_ptr + 2 * mNumHeads * paddedHeadSize,
// quanted_qkv_ptr, quanted_qkv_ptr + mNumHeads * mHeadSize, context,
3 * mNumHeads * paddedHeadSize, 3 * mNumHeads * paddedHeadSize, 3 * mNumHeads * paddedHeadSize,
// scales
q_scale_ptr, k_scale_ptr, v_scale_ptr, stream);
sync_check_cuda_error(stream);
}
// Construct the fmha params for running kernels.
MHARunnerParams fmhaParams{};
fmhaParams.b = request_batch_size;
fmhaParams.qSeqLen = request_seq_len;
fmhaParams.kvSeqLen = request_seq_len;
fmhaParams.totalQSeqLen = request_batch_size * request_seq_len;
// Device buffer pointers.
fmhaParams.qkvPtr = attention_input;
fmhaParams.outputPtr = context_buf_;
fmhaParams.cuQSeqLenPtr = cu_seqlens;
fmhaParams.cuKvSeqLenPtr = cu_seqlens;
fmhaParams.tileCounterPtr = fmha_tile_counter_ptr;
fmhaParams.stream = stream;
if (mSageAttn)
{
if (paddedHeadSize != mHeadSize)
fmhaParams.outputPtr = sage_quant_space_ptr;
fmhaParams.qkvPtr = quanted_qkv_ptr;
fmhaParams.scaleBmm1Ptr = scale_bmm1_ptr;
fmhaParams.scaleBmm2Ptr = scale_bmm2_ptr;
fmhaParams.qScalePtr = q_scale_ptr;
fmhaParams.kScalePtr = k_scale_ptr;
fmhaParams.vScalePtr = v_scale_ptr;
fmhaParams.qMaxNBlock = (input_seq_len + mSageAttnQBlockSize - 1) / mSageAttnQBlockSize;
fmhaParams.kMaxNBlock = (input_seq_len + mSageAttnKBlockSize - 1) / mSageAttnKBlockSize;
fmhaParams.vMaxNBlock = (input_seq_len + mSageAttnVBlockSize - 1) / mSageAttnVBlockSize;
}
// Run the fmha kernel.
mFMHARunner->run(fmhaParams);
sync_check_cuda_error(stream);
if (mSageAttn)
{
if (paddedHeadSize != mHeadSize && mHeadSize == 72)
uint32_t* fmha_scheduler_counter_h = (uint32_t*) malloc(sizeof(uint32_t));
cudaMemcpyAsync(
fmha_scheduler_counter_h, fmha_tile_counter_ptr, sizeof(uint32_t), cudaMemcpyDeviceToHost, stream);
for (size_t iter = 0; iter < mCpGroup.size(); ++iter)
{
unpadding<80, 72, __nv_bfloat16>(batch_size, mNumHeads, input_seq_len, sage_quant_space_ptr,
mNumHeads * 72, mNumHeads * 80, cu_seqlens, context_buf_, stream);
// KV buffer used by fmha
T* ring_fmha_kv_buf_ = (iter % 2 == 0) ? ring_kv_buf_1_ : ring_kv_buf_2_;
#if ENABLE_MULTI_DEVICE
T* ring_send_kv_buf_ = (iter % 2 == 0) ? ring_kv_buf_1_ : ring_kv_buf_2_;
T* ring_recv_kv_buf_ = (iter % 2 == 0) ? ring_kv_buf_2_ : ring_kv_buf_1_;
if (iter < mCpGroup.size() - 1)
{
NCCLCHECK(ncclGroupStart());
TLLM_CHECK_WITH_INFO(mNcclComm.get() != nullptr, "mNcclComm should be initialized before used");
NCCLCHECK(ncclSend(ring_send_kv_buf_,
ring_kv_buf_size / sizeof(T) + sizeof(int) / sizeof(T) * (batch_size + 1),
(*getDtypeMap())[inputDesc[0].type], next_rank, *mNcclComm, mNcclStream));
NCCLCHECK(ncclRecv(ring_recv_kv_buf_,
ring_kv_buf_size / sizeof(T) + sizeof(int) / sizeof(T) * (batch_size + 1),
(*getDtypeMap())[inputDesc[0].type], prev_rank, *mNcclComm, mNcclStream));
NCCLCHECK(ncclGroupEnd());
}
#else
TLLM_LOG_ERROR("Please set ENABLE_MULTI_DEVICE to enable RingAttention");
return 1;
#endif // ENABLE_MULTI_DEVICE
// Construct the fmha params for running kernels.
MHARunnerParams fmhaParams{};
fmhaParams.b = request_batch_size;
fmhaParams.qSeqLen = request_seq_len;
fmhaParams.kvSeqLen = request_seq_len;
fmhaParams.totalQSeqLen = request_batch_size * request_seq_len;
// Device buffer pointers.
fmhaParams.qPtr = ring_q_buf_;
fmhaParams.kvPtr = ring_fmha_kv_buf_;
if (iter == 0)
{
fmhaParams.outputPtr = context_buf_;
fmhaParams.softmaxStatsPtr = ring_softmax_accu_stats_buf_;
}
else
{
cudaMemsetAsync(ring_softmax_stats_buf_, 0, ring_softmax_stats_buf_size, stream);
fmhaParams.outputPtr = ring_block_output_;
fmhaParams.softmaxStatsPtr = ring_softmax_stats_buf_;
}
fmhaParams.cuQSeqLenPtr = cu_seqlens;
fmhaParams.cuKvSeqLenPtr
= reinterpret_cast<int*>(reinterpret_cast<char*>(ring_fmha_kv_buf_) + ring_kv_buf_size);
fmhaParams.tileCounterPtr = fmha_tile_counter_ptr;
fmhaParams.stream = stream;
// Run the fmha kernel.
cudaMemsetAsync(fmhaParams.outputPtr, 0, ring_block_output_size, stream);
cudaMemcpyAsync(fmhaParams.tileCounterPtr, fmha_scheduler_counter_h, sizeof(uint32_t),
cudaMemcpyHostToDevice, stream);
mFMHARunner->run(fmhaParams);
if (iter != 0)
{
invokeRecoverFromRA<T>((T*) context_buf_, (float*) ring_softmax_accu_stats_buf_,
(T*) ring_block_output_, (float*) ring_softmax_stats_buf_, fmhaParams.b, fmhaParams.qSeqLen,
mNumHeads, mHeadSize, cu_seqlens, stream);
}
cudaStreamSynchronize(stream);
cudaStreamSynchronize(mNcclStream);
}
common::check_cuda_error(cudaStreamDestroy(mNcclStream));
free(fmha_scheduler_counter_h);
}
else
{
if (mSageAttn && mHeadSize == 72 && mSageAttnQBlockSize == 64 && mSageAttnKBlockSize == 64
&& mSageAttnVBlockSize == 256)
{
sage_quant<72, 80, 64, 64, 256, __nv_bfloat16, __nv_fp8_e4m3, float>(
// host var
batch_size, mNumHeads, input_seq_len, true, true,
// device var
// q k v
attention_input, attention_input + mNumHeads * mHeadSize,
attention_input + 2 * mNumHeads * mHeadSize,
// stride
3 * mNumHeads * mHeadSize, 3 * mNumHeads * mHeadSize, 3 * mNumHeads * mHeadSize, cu_seqlens,
cu_seqlens, sage_quant_space_ptr,
// quant q k v
quanted_qkv_ptr, quanted_qkv_ptr + mNumHeads * paddedHeadSize,
quanted_qkv_ptr + 2 * mNumHeads * paddedHeadSize,
// quanted_qkv_ptr, quanted_qkv_ptr + mNumHeads * mHeadSize, context,
3 * mNumHeads * paddedHeadSize, 3 * mNumHeads * paddedHeadSize, 3 * mNumHeads * paddedHeadSize,
// scales
q_scale_ptr, k_scale_ptr, v_scale_ptr, stream);
sync_check_cuda_error(stream);
}
if (mSageAttn && mHeadSize == 80 && mSageAttnQBlockSize == 64 && mSageAttnKBlockSize == 64
&& mSageAttnVBlockSize == 256)
{
sage_quant<80, 80, 64, 64, 256, __nv_bfloat16, __nv_fp8_e4m3, float>(
// host var
batch_size, mNumHeads, input_seq_len, true, true,
// device var
// q k v
attention_input, attention_input + mNumHeads * mHeadSize,
attention_input + 2 * mNumHeads * mHeadSize,
// stride
3 * mNumHeads * mHeadSize, 3 * mNumHeads * mHeadSize, 3 * mNumHeads * mHeadSize, cu_seqlens,
cu_seqlens, sage_quant_space_ptr,
// quant q k v
quanted_qkv_ptr, quanted_qkv_ptr + mNumHeads * paddedHeadSize,
quanted_qkv_ptr + 2 * mNumHeads * paddedHeadSize,
// quanted_qkv_ptr, quanted_qkv_ptr + mNumHeads * mHeadSize, context,
3 * mNumHeads * paddedHeadSize, 3 * mNumHeads * paddedHeadSize, 3 * mNumHeads * paddedHeadSize,
// scales
q_scale_ptr, k_scale_ptr, v_scale_ptr, stream);
sync_check_cuda_error(stream);
}
if (mSageAttn && mHeadSize == 128 && mSageAttnQBlockSize == 64 && mSageAttnKBlockSize == 64
&& mSageAttnVBlockSize == 256)
{
sage_quant<128, 128, 64, 64, 256, __nv_bfloat16, __nv_fp8_e4m3, float>(
// host var
batch_size, mNumHeads, input_seq_len, true, true,
// device var
// q k v
attention_input, attention_input + mNumHeads * mHeadSize,
attention_input + 2 * mNumHeads * mHeadSize,
// stride
3 * mNumHeads * mHeadSize, 3 * mNumHeads * mHeadSize, 3 * mNumHeads * mHeadSize, cu_seqlens,
cu_seqlens, sage_quant_space_ptr,
// quant q k v
quanted_qkv_ptr, quanted_qkv_ptr + mNumHeads * paddedHeadSize,
quanted_qkv_ptr + 2 * mNumHeads * paddedHeadSize,
// quanted_qkv_ptr, quanted_qkv_ptr + mNumHeads * mHeadSize, context,
3 * mNumHeads * paddedHeadSize, 3 * mNumHeads * paddedHeadSize, 3 * mNumHeads * paddedHeadSize,
// scales
q_scale_ptr, k_scale_ptr, v_scale_ptr, stream);
sync_check_cuda_error(stream);
}
if (mSageAttn && mHeadSize == 128 && mSageAttnQBlockSize == 64 && mSageAttnKBlockSize == 32
&& mSageAttnVBlockSize == 32)
{
sage_quant<128, 128, 64, 32, 32, __nv_bfloat16, __nv_fp8_e4m3, float>(
// host var
batch_size, mNumHeads, input_seq_len, true, true,
// device var
// q k v
attention_input, attention_input + mNumHeads * mHeadSize,
attention_input + 2 * mNumHeads * mHeadSize,
// stride
3 * mNumHeads * mHeadSize, 3 * mNumHeads * mHeadSize, 3 * mNumHeads * mHeadSize, cu_seqlens,
cu_seqlens, sage_quant_space_ptr,
// quant q k v
quanted_qkv_ptr, quanted_qkv_ptr + mNumHeads * paddedHeadSize,
quanted_qkv_ptr + 2 * mNumHeads * paddedHeadSize,
// quanted_qkv_ptr, quanted_qkv_ptr + mNumHeads * mHeadSize, context,
3 * mNumHeads * paddedHeadSize, 3 * mNumHeads * paddedHeadSize, 3 * mNumHeads * paddedHeadSize,
// scales
q_scale_ptr, k_scale_ptr, v_scale_ptr, stream);
sync_check_cuda_error(stream);
}
if (mSageAttn && mHeadSize == 80 && mSageAttnQBlockSize == 64 && mSageAttnKBlockSize == 32
&& mSageAttnVBlockSize == 32)
{
sage_quant<80, 80, 64, 32, 32, __nv_bfloat16, __nv_fp8_e4m3, float>(
// host var
batch_size, mNumHeads, input_seq_len, true, true,
// device var
// q k v
attention_input, attention_input + mNumHeads * mHeadSize,
attention_input + 2 * mNumHeads * mHeadSize,
// stride
3 * mNumHeads * mHeadSize, 3 * mNumHeads * mHeadSize, 3 * mNumHeads * mHeadSize, cu_seqlens,
cu_seqlens, sage_quant_space_ptr,
// quant q k v
quanted_qkv_ptr, quanted_qkv_ptr + mNumHeads * paddedHeadSize,
quanted_qkv_ptr + 2 * mNumHeads * paddedHeadSize,
// quanted_qkv_ptr, quanted_qkv_ptr + mNumHeads * mHeadSize, context,
3 * mNumHeads * paddedHeadSize, 3 * mNumHeads * paddedHeadSize, 3 * mNumHeads * paddedHeadSize,
// scales
q_scale_ptr, k_scale_ptr, v_scale_ptr, stream);
sync_check_cuda_error(stream);
}
if (mSageAttn && mHeadSize == 72 && mSageAttnQBlockSize == 64 && mSageAttnKBlockSize == 32
&& mSageAttnVBlockSize == 32)
{
sage_quant<72, 80, 64, 32, 32, __nv_bfloat16, __nv_fp8_e4m3, float>(
// host var
batch_size, mNumHeads, input_seq_len, true, true,
// device var
// q k v
attention_input, attention_input + mNumHeads * mHeadSize,
attention_input + 2 * mNumHeads * mHeadSize,
// stride
3 * mNumHeads * mHeadSize, 3 * mNumHeads * mHeadSize, 3 * mNumHeads * mHeadSize, cu_seqlens,
cu_seqlens, sage_quant_space_ptr,
// quant q k v
quanted_qkv_ptr, quanted_qkv_ptr + mNumHeads * paddedHeadSize,
quanted_qkv_ptr + 2 * mNumHeads * paddedHeadSize,
// quanted_qkv_ptr, quanted_qkv_ptr + mNumHeads * mHeadSize, context,
3 * mNumHeads * paddedHeadSize, 3 * mNumHeads * paddedHeadSize, 3 * mNumHeads * paddedHeadSize,
// scales
q_scale_ptr, k_scale_ptr, v_scale_ptr, stream);
sync_check_cuda_error(stream);
}
// Construct the fmha params for running kernels.
MHARunnerParams fmhaParams{};
fmhaParams.b = request_batch_size;
fmhaParams.qSeqLen = request_seq_len;
fmhaParams.kvSeqLen = request_seq_len;
fmhaParams.totalQSeqLen = request_batch_size * request_seq_len;
// Device buffer pointers.
fmhaParams.qkvPtr = attention_input;
fmhaParams.outputPtr = context_buf_;
fmhaParams.cuQSeqLenPtr = cu_seqlens;
fmhaParams.cuKvSeqLenPtr = cu_seqlens;
fmhaParams.tileCounterPtr = fmha_tile_counter_ptr;
fmhaParams.stream = stream;
if (mSageAttn)
{
if (paddedHeadSize != mHeadSize)
fmhaParams.outputPtr = sage_quant_space_ptr;
fmhaParams.qkvPtr = quanted_qkv_ptr;
fmhaParams.scaleBmm1Ptr = scale_bmm1_ptr;
fmhaParams.scaleBmm2Ptr = scale_bmm2_ptr;
fmhaParams.qScalePtr = q_scale_ptr;
fmhaParams.kScalePtr = k_scale_ptr;
fmhaParams.vScalePtr = v_scale_ptr;
fmhaParams.qMaxNBlock = (input_seq_len + mSageAttnQBlockSize - 1) / mSageAttnQBlockSize;
fmhaParams.kMaxNBlock = (input_seq_len + mSageAttnKBlockSize - 1) / mSageAttnKBlockSize;
fmhaParams.vMaxNBlock = (input_seq_len + mSageAttnVBlockSize - 1) / mSageAttnVBlockSize;
}
// Run the fmha kernel.
mFMHARunner->run(fmhaParams);
sync_check_cuda_error(stream);
if (mSageAttn)
{
if (paddedHeadSize != mHeadSize && mHeadSize == 72)
{
unpadding<80, 72, __nv_bfloat16>(batch_size, mNumHeads, input_seq_len, sage_quant_space_ptr,
mNumHeads * 72, mNumHeads * 80, cu_seqlens, context_buf_, stream);
}
}
}
}
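
To make the control flow of the loop above easier to follow, here is a minimal single-process NumPy sketch of the same schedule. All names are illustrative: the plugin exchanges the KV ping-pong buffers with ncclSend/ncclRecv on mNcclStream instead of iterating over shards locally, and the merge step is performed by invokeRecoverFromRA.

import numpy as np

def local_attention(q, k, v):
    # One FMHA call on a single KV shard: returns output, row max, and softmax sum.
    scores = q @ k.T / np.sqrt(q.shape[-1])        # (s_q, s_kv)
    row_max = scores.max(axis=-1, keepdims=True)
    p = np.exp(scores - row_max)
    softmax_sum = p.sum(axis=-1, keepdims=True)
    return (p / softmax_sum) @ v, row_max, softmax_sum

def ring_attention(q, kv_shards):
    # q stays resident on the rank; one KV shard arrives per ring step.
    acc_out = acc_max = acc_sum = None
    for k, v in kv_shards:
        out, m, s = local_attention(q, k, v)
        if acc_out is None:                        # iter == 0: write directly
            acc_out, acc_max, acc_sum = out, m, s
            continue
        # Streaming-softmax merge, the same update as reduce4ring_attention.
        new_max = np.maximum(acc_max, m)
        w_acc = np.exp(acc_max - new_max) * acc_sum
        w_blk = np.exp(m - new_max) * s
        new_sum = w_acc + w_blk
        acc_out = (w_acc * acc_out + w_blk * out) / new_sum
        acc_max, acc_sum = new_max, new_sum
    return acc_out

rng = np.random.default_rng(0)
s_local, d, cp_size = 8, 4, 2
q = rng.standard_normal((s_local, d))
k = rng.standard_normal((s_local * cp_size, d))
v = rng.standard_normal((s_local * cp_size, d))
shards = [(k[i * s_local:(i + 1) * s_local], v[i * s_local:(i + 1) * s_local]) for i in range(cp_size)]
ref, _, _ = local_attention(q, k, v)               # attention over the full sequence
assert np.allclose(ring_attention(q, shards), ref)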
@@ -780,6 +941,12 @@ int BertAttentionPlugin::initialize() noexcept
fmhaParams.headSize = paddedHeadSize;
}
if (mCpGroup.size() > 1)
{
fmhaParams.attentionInputLayout = AttentionInputLayout::Q_CONTIGUOUS_KV;
fmhaParams.saveSoftmax = true;
}
// Load kernels from the pre-compiled cubins.
mFMHARunner.reset(new FusedMHARunnerV2(fmhaParams));
@@ -787,6 +954,15 @@ int BertAttentionPlugin::initialize() noexcept
mEnableContextFMHA = mFMHARunner->isFmhaSupported();
}
#if ENABLE_MULTI_DEVICE
if (mCpGroup.size() > 1 && COMM_SESSION.getSize() > 1)
{
TLLM_LOG_TRACE("%s start for rank %d", __PRETTY_FUNCTION__, COMM_SESSION.getRank());
mNcclComm = getComm(mCpGroup);
TLLM_LOG_TRACE("%s stop for rank %d", __PRETTY_FUNCTION__, COMM_SESSION.getRank());
}
#endif // ENABLE_MULTI_DEVICE
return 0;
}
@@ -800,7 +976,7 @@ size_t BertAttentionPlugin::getSerializationSize() const noexcept
return sizeof(mNumHeads) + sizeof(mHeadSize) + sizeof(mQScaling) + sizeof(mQKHalfAccum) + sizeof(mEnableContextFMHA)
+ sizeof(mFMHAForceFP32Acc) + sizeof(mType) + sizeof(mRelativeAttention) + sizeof(mMaxDistance)
+ sizeof(mRemovePadding) + sizeof(mSageAttn) + sizeof(mSageAttnQBlockSize) + sizeof(mSageAttnKBlockSize)
+ sizeof(mSageAttnVBlockSize);
+ sizeof(mSageAttnVBlockSize) + sizeof(mCpSize) + sizeof(mCpRank) + sizeof(int32_t) * mCpGroup.size();
}
void BertAttentionPlugin::serialize(void* buffer) const noexcept
@@ -820,6 +996,12 @@ void BertAttentionPlugin::serialize(void* buffer) const noexcept
write(d, mSageAttnQBlockSize);
write(d, mSageAttnKBlockSize);
write(d, mSageAttnVBlockSize);
write(d, mCpSize);
write(d, mCpRank);
for (auto it = mCpGroup.begin(); it != mCpGroup.end(); ++it)
{
write(d, *it);
}
TLLM_CHECK(d == a + getSerializationSize());
}
@@ -844,6 +1026,9 @@ BertAttentionPluginCreator::BertAttentionPluginCreator()
mPluginAttributes.emplace_back(PluginField("sage_attn_q_block_size", nullptr, PluginFieldType::kINT32));
mPluginAttributes.emplace_back(PluginField("sage_attn_k_block_size", nullptr, PluginFieldType::kINT32));
mPluginAttributes.emplace_back(PluginField("sage_attn_v_block_size", nullptr, PluginFieldType::kINT32));
mPluginAttributes.emplace_back(PluginField("cp_size", nullptr, PluginFieldType::kINT32));
mPluginAttributes.emplace_back(PluginField("cp_rank", nullptr, PluginFieldType::kINT32));
mPluginAttributes.emplace_back(PluginField("cp_group", nullptr, PluginFieldType::kINT32));
mFC.nbFields = mPluginAttributes.size();
mFC.fields = mPluginAttributes.data();
@@ -879,6 +1064,9 @@ IPluginV2* BertAttentionPluginCreator::createPlugin(char const* name, PluginFiel
int sage_attn_q_block_size{};
int sage_attn_k_block_size{};
int sage_attn_v_block_size{};
int cp_size{};
int cp_rank{};
std::set<int> cp_group{};
// Read configurations from each fields
for (int i = 0; i < fc->nbFields; ++i)
@@ -948,12 +1136,32 @@ IPluginV2* BertAttentionPluginCreator::createPlugin(char const* name, PluginFiel
TLLM_CHECK(fields[i].type == PluginFieldType::kINT32);
sage_attn_v_block_size = static_cast<int>(*(static_cast<int const*>(fields[i].data)));
}
else if (!strcmp(attrName, "cp_size"))
{
TLLM_CHECK(fields[i].type == PluginFieldType::kINT32);
cp_size = static_cast<int>(*(static_cast<int const*>(fields[i].data)));
}
else if (!strcmp(attrName, "cp_rank"))
{
TLLM_CHECK(fields[i].type == PluginFieldType::kINT32);
cp_rank = static_cast<int>(*(static_cast<int const*>(fields[i].data)));
}
else if (!strcmp(attrName, "cp_group"))
{
TLLM_CHECK(fields[i].type == PluginFieldType::kINT32);
auto const* r = static_cast<int const*>(fields[i].data);
for (int j = 0; j < fields[i].length; ++j)
{
cp_group.insert(*r);
++r;
}
}
}
try
{
auto* obj = new BertAttentionPlugin(num_heads, head_size, q_scaling, context_fmha_type, type,
do_relative_attention, max_distance, remove_padding, sage_attn, sage_attn_q_block_size,
sage_attn_k_block_size, sage_attn_v_block_size);
sage_attn_k_block_size, sage_attn_v_block_size, cp_size, cp_rank, cp_group);
obj->setPluginNamespace(mNamespace.c_str());
return obj;
}


@@ -21,7 +21,9 @@
#include "tensorrt_llm/kernels/contextFusedMultiHeadAttention/fmhaRunner.h"
#include "tensorrt_llm/kernels/gptKernels.h"
#include "tensorrt_llm/plugins/common/plugin.h"
#include "tensorrt_llm/runtime/utils/mpiUtils.h"
#include <cassert>
#include <cuda_runtime.h>
#include <set>
#include <string>
#include <vector>
@@ -37,7 +39,8 @@ public:
BertAttentionPlugin(int num_heads, int head_size, float q_scaling,
tensorrt_llm::kernels::ContextFMHAType context_fmha_type, nvinfer1::DataType type,
bool do_relative_attention = false, int max_distance = 0, bool remove_padding = false, bool sage_attn = false,
int sage_attn_q_block_size = 0, int sage_attn_k_block_size = 0, int sage_attn_v_block_size = 0);
int sage_attn_q_block_size = 0, int sage_attn_k_block_size = 0, int sage_attn_v_block_size = 0, int cp_size = 1,
int cp_rank = 0, std::set<int> cp_group = {});
BertAttentionPlugin(void const* data, size_t length);
@@ -101,6 +104,15 @@ private:
int mSM = tensorrt_llm::common::getSMVersion();
// comm group for RingAttention
int mCpSize = 1;
int mCpRank = 0;
std::set<int> mCpGroup = {};
#if ENABLE_MULTI_DEVICE
std::shared_ptr<ncclComm_t> mNcclComm;
#endif // ENABLE_MULTI_DEVICE
cudaStream_t mNcclStream;
// The default copy constructor will leave them as nullptr. clone() shall initialize it.
UniqPtrWNullCopy<tensorrt_llm::kernels::FusedMHARunnerV2> mFMHARunner;
UniqPtrWNullCopy<tensorrt_llm::common::CublasMMWrapper> mCublasWrapper;


@@ -4449,7 +4449,10 @@ def bert_attention(tensor: Tensor,
sage_attn: bool = False,
sage_attn_q_block_size: int = 0,
sage_attn_k_block_size: int = 0,
sage_attn_v_block_size: int = 0) -> Tuple[Tensor]:
sage_attn_v_block_size: int = 0,
cp_group: list[int] = None,
cp_size: int = 1,
cp_rank: int = 0) -> Tuple[Tensor]:
'''
Add an operation that performs the multi-head attention in BERT.
@@ -4516,6 +4519,15 @@ def bert_attention(tensor: Tensor,
sage_attn_v_quant_size: int = 0
dynamic quant block size along sequence dimension of v tensor. Each quant block will share one scale.
cp_group: list[int] = None
The communication group used for context parallelism (ring attention)
cp_size: int = 1
The number of ranks in the context-parallel communication group
cp_rank: int = 0
The rank of this worker within the context-parallel communication group
Returns:
The tensor produced by that layer.
'''
@@ -4570,10 +4582,31 @@ def bert_attention(tensor: Tensor,
np.array(sage_attn_v_block_size, dtype=np.int32),
trt.PluginFieldType.INT32)
if cp_size > 1:
# re-pack q, k, v inside qkv so that kv is contiguous in memory, as required by ring attention
# (b, s, 3d)
query, key, value = chunk(tensor, 3, dim=-1)
bs = shape(query, 0)
seq_len = shape(query, 1)
# (b, s, d) -> (b, s, 2d) -> (2b, s, d)
kv = concat([key, value],
dim=-1).view(concat((2 * bs, seq_len, query.shape[-1])))
tensor = concat((query, kv),
dim=0).view(concat((bs, seq_len, query.shape[-1] * 3)))
cp_size = trt.PluginField("cp_size", np.array(cp_size, dtype=np.int32),
trt.PluginFieldType.INT32)
cp_rank = trt.PluginField("cp_rank", np.array(cp_rank, dtype=np.int32),
trt.PluginFieldType.INT32)
cp_group = cp_group or [0]
cp_group = np.array(cp_group, dtype=np.int32)
cp_group = trt.PluginField("cp_group", cp_group, trt.PluginFieldType.INT32)
pfc = trt.PluginFieldCollection([
nheads, head_size, q_scaling, context_fmha_type, pf_type,
do_relative_attention, max_distance, remove_padding, sage_attn,
sage_attn_q_block_size, sage_attn_k_block_size, sage_attn_v_block_size
sage_attn_q_block_size, sage_attn_k_block_size, sage_attn_v_block_size,
cp_size, cp_rank, cp_group
])
attn_plug = attn_plg_creator.create_plugin("padding_attn", pfc)
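
A NumPy illustration of what the cp_size > 1 re-packing above does to the underlying memory (shapes are illustrative, padded inputs assumed; the real transform is built from TRT chunk/concat/view ops): the per-token interleaved [Q | K | V] layout becomes one Q block followed by a contiguous KV block, which is exactly what the plugin's ring path copies into ring_q_buf_ and ring_kv_buf_1_.

import numpy as np

b, s, d = 2, 3, 4                                # d = num_heads * head_size, illustrative
qkv = np.arange(b * s * 3 * d, dtype=np.float32).reshape(b, s, 3 * d)
q, k, v = np.split(qkv, 3, axis=-1)              # each (b, s, d)
kv = np.concatenate([k, v], axis=-1)             # (b, s, 2d): K and V adjacent per token
packed = np.concatenate([q.reshape(-1), kv.reshape(-1)])  # flat: Q block, then KV block
# enqueueImpl slices the re-packed input the same way:
#   ring_q_buf_    <- first  b * s * d elements (all Q tokens)
#   ring_kv_buf_1_ <- next 2 * b * s * d elements (contiguous KV)
assert np.array_equal(packed[:b * s * d].reshape(b, s, d), q)
assert np.array_equal(packed[b * s * d:].reshape(b, s, 2 * d), kv)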


@@ -1608,6 +1608,7 @@ class BertAttention(Module):
tp_rank=0,
cp_group=None,
cp_size=1,
cp_rank=0,
relative_attention=False,
max_distance=0,
num_buckets=0,
@@ -1628,6 +1629,7 @@ class BertAttention(Module):
self.tp_rank = tp_rank
self.cp_group = cp_group
self.cp_size = cp_size
self.cp_rank = cp_rank
self.num_layers = num_layers
self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
@@ -1725,7 +1727,6 @@ class BertAttention(Module):
if default_net().plugin_config.bert_attention_plugin:
# TRT plugin mode
assert input_lengths is not None
assert self.cp_size == 1
assert get_sm_version() < 100 or get_sm_version() >= 120, \
"bert_attention_plugin does not support SM100"
context = bert_attention(
@@ -1738,7 +1739,10 @@ class BertAttention(Module):
max_distance=self.max_distance,
relative_attention_bias=self.rel_attn_table.value
if self.relative_attention else None,
max_input_length=max_input_length)
max_input_length=max_input_length,
cp_group=self.cp_group,
cp_size=self.cp_size,
cp_rank=self.cp_rank)
else:
# plain TRT mode
def transpose_for_scores(x):


@@ -134,6 +134,7 @@ class DiTBlock(Module):
tp_rank=mapping.tp_rank,
cp_group=mapping.cp_group,
cp_size=mapping.cp_size,
cp_rank=mapping.cp_rank,
dtype=dtype,
quant_mode=quant_mode)
self.norm2 = LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
@@ -299,6 +300,7 @@ class DiT(PretrainedModel):
if self.mapping.cp_size > 1:
assert x.shape[1] % self.mapping.cp_size == 0
x = chunk(x, self.mapping.cp_size, dim=1)[self.mapping.cp_rank]
input_lengths = input_lengths // self.mapping.cp_size
for block in self.blocks:
x = block(x, c, input_lengths) # (N, T, D)
self.register_network_output('before_final_layer', x)
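
Finally, a small NumPy sketch of the sequence sharding DiT.forward performs when cp_size > 1 (illustrative shapes): each cp rank keeps only its contiguous slice of the token dimension and scales input_lengths down accordingly, so every rank runs BertAttention on s / cp_size local tokens while the ring exchange provides the remote KV blocks.

import numpy as np

cp_size, cp_rank = 2, 1                              # this rank's position in mapping.cp_group
x = np.zeros((4, 256, 1152), dtype=np.float32)       # (batch, tokens, hidden), illustrative
input_lengths = np.full((4,), 256, dtype=np.int32)

assert x.shape[1] % cp_size == 0
x_local = np.split(x, cp_size, axis=1)[cp_rank]      # (4, 128, 1152) kept on this rank
input_lengths_local = input_lengths // cp_size       # each rank reports 128 local tokens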