fix: fix for cp > kvHeadNum (#3002)

* fix for cp > kvHeadNum

Signed-off-by: Dylan Chen <191843203+DylanChen-NV@users.noreply.github.com>

* fix for None kv_head_num

Signed-off-by: Dylan Chen <191843203+DylanChen-NV@users.noreply.github.com>

---------

Signed-off-by: Dylan Chen <191843203+DylanChen-NV@users.noreply.github.com>
This commit is contained in:
DylanChen-NV 2025-03-26 12:39:02 +08:00 committed by GitHub
parent 25f2434495
commit 1ac0566a93
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
16 changed files with 346 additions and 238 deletions

View File

@ -166,9 +166,8 @@ bool AttentionOp::convertMMHAParamsToXQAParams(tensorrt_llm::kernels::XQAParams&
xqaParams = {};
xqaParams.data_type = ConvertMMHAToXQAParamsHelper<T, KVCacheBuffer>::data_type;
// TODO(ziqingc): A better description for these parameters affected by CP size
xqaParams.num_q_heads = mNumHeads / mCpSize; // when we use CP, the MHA part is spilt like TP
xqaParams.num_kv_heads = mNumKVHeads / mCpSize; // when we use CP, the MHA part is spilt like TP
xqaParams.num_q_heads = mNumAttnHeads;
xqaParams.num_kv_heads = mNumAttnKVHeads;
xqaParams.head_size = mHeadSize;
xqaParams.unidirectional = mUnidirectional;
xqaParams.q_scaling = mQScaling;
@ -265,15 +264,151 @@ bool AttentionOp::convertMMHAParamsToXQAParams(tensorrt_llm::kernels::XQAParams&
return true;
}
template <typename T>
int AttentionOp::ulyssesContextPreprocess(T const* input, T* output, T* buffer, EnqueueContextParams<T> const& params,
int const* cu_q_seqlens, int const* cu_cp_partial_seqlens, cudaStream_t stream)
{
int32_t partialTokenNum = 0;
int32_t maxPartialLength = 0;
for (int32_t batchIdx = 0; batchIdx < params.batch_size; ++batchIdx)
{
int32_t partialLength = (params.host_context_lengths[batchIdx] + mCpSize - 1) / mCpSize;
maxPartialLength = std::max(maxPartialLength, partialLength);
partialTokenNum += partialLength;
}
auto const partialHeads = mNumAttnHeads + 2 * mNumAttnKVHeads;
// full request: [bs, seqlen, head, headSize]
//
// input of cp: [bs, partialLength, head, headSize]
// view_1 as [bs, partialLength, cpSize_Head, partialHead, headSize]
// transpose_1 as [cpSize_Head, bs, partialLenth, partialHead, headSize]
// all-to-all to get [cpSize_Length, bs, partialLength, partialHead, headSize]
// transpose_2 to [bs, cpSize_Length, partialLength, partialHead, headSize]
// view_2 as [bs, totalLength, partialHead, headSize]
// and this is same to the input under TP.
//
// when we use remove_input_padding, bs and length are fused into numTokens. So, we need to
// insert the cpSize_Length dimension of transpose_2 into numTokens directly like
// input of cp: [partialNumTokens, head, headSize]
// view_1 as [partialNumTokens, cpSize_Head, partialHead, headSize]
// transpose_1 as [cpSize_Head, partialNumTokens, partialHead, headSize]
// all-to-all to get [cpSize_Length, partialNumTokens, partialHead, headSize]
// transpose_2 as [NumTokens, partialHead, headSize]
// and this is same to the input under TP.
// view_1 + transpose_1
invokeCpTranspose(output, buffer, input, partialTokenNum, mCpSize, mNumAttnHeads, mNumAttnKVHeads,
mUlyssesMQABroadcast, getHeadSize(), mCpRank, stream);
sync_check_cuda_error(stream);
// Do all to all
#if ENABLE_MULTI_DEVICE
ncclGroupStart();
for (int cpIdx = 0; cpIdx < mCpSize; cpIdx++)
{
if (cpIdx != mCpRank)
{
NCCLCHECK(ncclSend(output + cpIdx * (partialTokenNum * getHeadSize() * partialHeads),
(partialTokenNum * getHeadSize() * partialHeads), (*getDtypeMap())[mType], cpIdx, *mCpNcclComm,
stream));
NCCLCHECK(ncclRecv(buffer + cpIdx * (partialTokenNum * getHeadSize() * partialHeads),
(partialTokenNum * getHeadSize() * partialHeads), (*getDtypeMap())[mType], cpIdx, *mCpNcclComm,
stream));
}
}
ncclGroupEnd();
sync_check_cuda_error(stream);
#endif // ENABLE_MULTI_DEVICE
// transpose_2 + view_2
invokeCpTranspose2(output, buffer, params.context_lengths, cu_q_seqlens, cu_cp_partial_seqlens, mCpSize,
maxPartialLength, params.batch_size, partialHeads, getHeadSize(), stream);
return 0;
}
template <typename T>
int AttentionOp::ulyssesContextPostprocess(T* input, T* output, T* buffer, EnqueueContextParams<T> const& params,
int const* cu_q_seqlens, int const* cu_cp_partial_seqlens, cudaStream_t stream)
{
// After FMHA, we get result [numTokens(bs, cp, paritalLength), partialHead, headSize]
// transpose_2_reverse: [cpSize_Length, partialTokens(bs, partialLength), partialHead, headSize]
// all-to-all: [cpSize_Head, partialTokens, partialHead, headSize]
// transpose_1_reverse: [partialTokens, cpSize_Head, partialHead, headSize]
// view: [partialTokens, head, headSize]
int32_t maxPartialLength = 0;
int32_t partialTokenNum = 0;
for (int32_t batchIdx = 0; batchIdx < params.batch_size; ++batchIdx)
{
int32_t partialLength = (params.host_context_lengths[batchIdx] + mCpSize - 1) / mCpSize;
maxPartialLength = std::max(maxPartialLength, partialLength);
partialTokenNum += partialLength;
}
// transpose_2_reverse
if (mFP8ContextFMHA)
{
invokeCpTransposeToSeqMajor2(reinterpret_cast<__nv_fp8_e4m3*>(buffer),
reinterpret_cast<__nv_fp8_e4m3 const*>(input), params.context_lengths, cu_q_seqlens, cu_cp_partial_seqlens,
mCpSize, maxPartialLength, params.batch_size, mNumAttnHeads, getHeadSize(), stream);
}
else
{
invokeCpTransposeToSeqMajor2(buffer, input, params.context_lengths, cu_q_seqlens, cu_cp_partial_seqlens,
mCpSize, maxPartialLength, params.batch_size, mNumAttnHeads, getHeadSize(), stream);
}
// all-to-all
#if ENABLE_MULTI_DEVICE
size_t const elementNum = partialTokenNum * getHeadSize() * mNumAttnHeads;
ncclGroupStart();
for (int cpIdx = 0; cpIdx < mCpSize; cpIdx++)
{
if (cpIdx != mCpRank)
{
if (mFP8ContextFMHA)
{
NCCLCHECK(ncclSend(reinterpret_cast<__nv_fp8_e4m3*>(buffer) + cpIdx * elementNum, elementNum, ncclInt8,
cpIdx, *mCpNcclComm, stream));
NCCLCHECK(ncclRecv(reinterpret_cast<__nv_fp8_e4m3*>(input) + cpIdx * elementNum, elementNum, ncclInt8,
cpIdx, *mCpNcclComm, stream));
}
else
{
NCCLCHECK(ncclSend(
buffer + cpIdx * elementNum, elementNum, (*getDtypeMap())[mType], cpIdx, *mCpNcclComm, stream));
NCCLCHECK(ncclRecv(
input + cpIdx * elementNum, elementNum, (*getDtypeMap())[mType], cpIdx, *mCpNcclComm, stream));
}
}
}
ncclGroupEnd();
#endif // ENABLE_MULTI_DEVICE
// transpose_1_reverse + view
if (mFP8ContextFMHA)
{
invokeCpTransposeToSeqMajor<__nv_fp8_e4m3>(reinterpret_cast<__nv_fp8_e4m3*>(output),
reinterpret_cast<__nv_fp8_e4m3 const*>(buffer), reinterpret_cast<__nv_fp8_e4m3 const*>(input),
partialTokenNum, mCpSize, mNumAttnHeads, getHeadSize(), mCpRank, stream);
}
else
{
invokeCpTransposeToSeqMajor<T>(
(T*) output, buffer, input, partialTokenNum, mCpSize, mNumAttnHeads, getHeadSize(), mCpRank, stream);
}
return 0;
}
template <typename T>
int AttentionOp::ulyssesGenerationPreprocess(
int32_t batch_beam, T* mhaInput, T* mhaOutput, T*& input, cudaStream_t stream)
T const* input, T* output, T* buffer, int32_t batch_beam, cudaStream_t stream)
{
if (mCpSize <= 1)
return 0;
auto const partialQHeads = mNumHeads / mCpSize;
auto const partialKVHeads = mNumKVHeads / mCpSize;
auto const partialTokenNum = (batch_beam + mCpSize - 1) / mCpSize;
// attention_input shape: [partialTokenNum, numHeads, headSize]
@ -284,27 +419,26 @@ int AttentionOp::ulyssesGenerationPreprocess(
// do transpose_1
// [1, mNumHeads + 2*mNumKVHeads, headSize]
// -> (view) [1, cpSize * partialQHeads + cpSize * partialKVHeads + cpSize * partilKVHeads,
// -> (view) [1, cpSize * mNumAttnHeads + cpSize * mNumAttnKVHeads + cpSize * partilKVHeads,
// headSize]
// -> (transpose) [cpSize, 1, partialQHeads + partialKvHeads + partialKVHeads, headSize]
invokeCpTranspose(mhaOutput, mhaInput, input, partialTokenNum, mCpSize, partialQHeads, partialKVHeads, mHeadSize,
mCpRank, stream);
// -> (transpose) [cpSize, 1, mNumAttnHeads + mNumAttnKVHeads + mNumAttnKVHeads, headSize]
invokeCpTranspose(buffer, output, input, partialTokenNum, mCpSize, mNumAttnHeads, mNumAttnKVHeads,
mUlyssesMQABroadcast, mHeadSize, mCpRank, stream);
sync_check_cuda_error(stream);
// Do all to all
#if ENABLE_MULTI_DEVICE
auto const totalHeads = mNumHeads + 2 * mNumKVHeads;
auto const partialHeads = totalHeads / mCpSize;
auto const partialHeads = mNumAttnHeads + 2 * mNumAttnKVHeads;
ncclGroupStart();
for (int cpIdx = 0; cpIdx < mCpSize; cpIdx++)
{
if (cpIdx != mCpRank)
{
NCCLCHECK(ncclSend(mhaOutput + cpIdx * (partialTokenNum * getHeadSize() * partialHeads),
NCCLCHECK(ncclSend(buffer + cpIdx * (partialTokenNum * getHeadSize() * partialHeads),
(partialTokenNum * getHeadSize() * partialHeads), (*getDtypeMap())[mType], cpIdx, *mCpNcclComm,
stream));
NCCLCHECK(ncclRecv(mhaInput + cpIdx * (partialTokenNum * getHeadSize() * partialHeads),
NCCLCHECK(ncclRecv(output + cpIdx * (partialTokenNum * getHeadSize() * partialHeads),
(partialTokenNum * getHeadSize() * partialHeads), (*getDtypeMap())[mType], cpIdx, *mCpNcclComm,
stream));
}
@ -312,14 +446,11 @@ int AttentionOp::ulyssesGenerationPreprocess(
ncclGroupEnd();
sync_check_cuda_error(stream);
#endif // ENABLE_MULTI_DEVICE
input = mhaInput;
return 0;
}
template <typename T>
int AttentionOp::ulyssesGenerationPostprocess(
int32_t batch_beam, T* mhaInput, T* mhaOutput, void* output, cudaStream_t stream)
int AttentionOp::ulyssesGenerationPostprocess(T* input, T* output, T* buffer, int32_t batch_beam, cudaStream_t stream)
{
if (mCpSize <= 1)
return 0;
@ -330,12 +461,11 @@ int AttentionOp::ulyssesGenerationPostprocess(
// transpose_1_reverse: [partialTokens, cpSize_Head, partialHead, headSize]
// view: [partialTokens, head, headSize]
auto partialHeads = mNumHeads / mCpSize;
auto const partialTokenNum = (batch_beam + mCpSize - 1) / mCpSize;
// do all-to-all
#if ENABLE_MULTI_DEVICE
size_t const elementNum = partialTokenNum * getHeadSize() * partialHeads;
size_t const elementNum = partialTokenNum * getHeadSize() * mNumAttnHeads;
ncclGroupStart();
for (int cpIdx = 0; cpIdx < mCpSize; cpIdx++)
{
@ -343,17 +473,17 @@ int AttentionOp::ulyssesGenerationPostprocess(
{
if (mFP8ContextFMHA)
{
NCCLCHECK(ncclSend(reinterpret_cast<__nv_fp8_e4m3*>(mhaOutput) + cpIdx * elementNum, elementNum,
ncclInt8, cpIdx, *mCpNcclComm, stream));
NCCLCHECK(ncclRecv(reinterpret_cast<__nv_fp8_e4m3*>(mhaInput) + cpIdx * elementNum, elementNum,
ncclInt8, cpIdx, *mCpNcclComm, stream));
NCCLCHECK(ncclSend(reinterpret_cast<__nv_fp8_e4m3*>(input) + cpIdx * elementNum, elementNum, ncclInt8,
cpIdx, *mCpNcclComm, stream));
NCCLCHECK(ncclRecv(reinterpret_cast<__nv_fp8_e4m3*>(buffer) + cpIdx * elementNum, elementNum, ncclInt8,
cpIdx, *mCpNcclComm, stream));
}
else
{
NCCLCHECK(ncclSend(
mhaOutput + cpIdx * elementNum, elementNum, (*getDtypeMap())[mType], cpIdx, *mCpNcclComm, stream));
input + cpIdx * elementNum, elementNum, (*getDtypeMap())[mType], cpIdx, *mCpNcclComm, stream));
NCCLCHECK(ncclRecv(
mhaInput + cpIdx * elementNum, elementNum, (*getDtypeMap())[mType], cpIdx, *mCpNcclComm, stream));
buffer + cpIdx * elementNum, elementNum, (*getDtypeMap())[mType], cpIdx, *mCpNcclComm, stream));
}
}
}
@ -364,15 +494,14 @@ int AttentionOp::ulyssesGenerationPostprocess(
if (mFP8ContextFMHA)
{
invokeCpTransposeToSeqMajor<__nv_fp8_e4m3>(reinterpret_cast<__nv_fp8_e4m3*>(output),
reinterpret_cast<__nv_fp8_e4m3 const*>(mhaOutput), reinterpret_cast<__nv_fp8_e4m3 const*>(mhaInput),
partialTokenNum, mCpSize, partialHeads, getHeadSize(), mCpRank, stream);
reinterpret_cast<__nv_fp8_e4m3 const*>(input), reinterpret_cast<__nv_fp8_e4m3 const*>(buffer),
partialTokenNum, mCpSize, mNumAttnHeads, getHeadSize(), mCpRank, stream);
}
else
{
invokeCpTransposeToSeqMajor<T>(
(T*) output, mhaOutput, mhaInput, partialTokenNum, mCpSize, partialHeads, getHeadSize(), mCpRank, stream);
(T*) output, input, buffer, partialTokenNum, mCpSize, mNumAttnHeads, getHeadSize(), mCpRank, stream);
}
sync_check_cuda_error(stream);
return 0;
}
@ -532,8 +661,8 @@ int AttentionOp::getHeadSize(bool checkInit) const
size_t AttentionOp::getWorkspaceSizeForContext(nvinfer1::DataType type, int32_t max_num_seq, int32_t input_seq_length,
int32_t cross_kv_length, int32_t max_num_tokens) const noexcept
{
int const local_hidden_units_qo = mNumHeads / mCpSize * getHeadSize();
int const local_hidden_units_kv = mNumKVHeads / mCpSize * getHeadSize();
int const local_hidden_units_qo = mNumAttnHeads * getHeadSize();
int const local_hidden_units_kv = mNumAttnKVHeads * getHeadSize();
auto const size = tensorrt_llm::runtime::BufferDataType(type).getSize();
@ -815,9 +944,8 @@ int AttentionOp::mlaGeneration(
tllmRunnerParams.mHeadDimQk = mMLAParams.kv_lora_rank + mMLAParams.qk_rope_head_dim;
tllmRunnerParams.mHeadDimV = mMLAParams.kv_lora_rank;
auto const num_q_heads = mNumHeads / mCpSize;
auto const num_q_heads = mNumAttnHeads;
tllmRunnerParams.mNumHeadsQ = num_q_heads;
// const auto num_kv_heads_tllm_runner_params = mNumKVHeads / mCpSize;
tllmRunnerParams.mNumHeadsKv = num_kv_heads;
tllmRunnerParams.mNumHeadsQPerKv = num_q_heads / num_kv_heads;
@ -1007,13 +1135,13 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
int const headSize = getHeadSize();
int const local_hidden_units_qo = mNumHeads * headSize;
int const local_hidden_units_kv = mNumKVHeads / mCpSize * headSize;
int const local_hidden_units_kv = mNumAttnKVHeads * headSize;
PositionEmbeddingType const position_embedding_type = mPositionEmbeddingType;
float const q_scaling = mQScaling;
KVCacheBuffer kv_cache_buffer;
auto const elemSize = mKVCacheQuantMode.hasKvCacheQuant() ? sizeof(int8_t) : sizeof(T);
auto sizePerToken = mNumKVHeads / mCpSize * headSize * elemSize;
auto sizePerToken = mNumAttnKVHeads * headSize * elemSize;
if (useKVCache())
{
if constexpr (std::is_same_v<KVCacheBuffer, KVBlockArray>)
@ -1252,72 +1380,13 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
// do all-to-all for params.attention_input, need to split on kv head
// [token_num // cp_size, kv_heads, head_size] -> [token_num, kv_heads // cp_size, head_size]
T* attention_input = const_cast<T*>(params.attention_input);
if (mCpSize > 1)
if (mCpSize > 1 && mAttnTpSize > 1 && mAttnCpSize == 1)
{
int32_t partialTokenNum = 0;
int32_t maxPartialLength = 0;
for (int32_t batchIdx = 0; batchIdx < params.batch_size; ++batchIdx)
{
int32_t partialLength = (params.host_context_lengths[batchIdx] + mCpSize - 1) / mCpSize;
maxPartialLength = std::max(maxPartialLength, partialLength);
partialTokenNum += partialLength;
}
auto const totalHeads = mNumHeads + 2 * mNumKVHeads;
auto const partialHeads = totalHeads / mCpSize;
auto const partialQHeads = mNumHeads / mCpSize;
auto const partialKVHeads = mNumKVHeads / mCpSize;
// full request: [bs, seqlen, head, headSize]
//
// input of cp: [bs, partialLength, head, headSize]
// view_1 as [bs, partialLength, cpSize_Head, partialHead, headSize]
// transpose_1 as [cpSize_Head, bs, partialLenth, partialHead, headSize]
// all-to-all to get [cpSize_Length, bs, partialLength, partialHead, headSize]
// transpose_2 to [bs, cpSize_Length, partialLength, partialHead, headSize]
// view_2 as [bs, totalLength, partialHead, headSize]
// and this is same to the input under TP.
//
// when we use remove_input_padding, bs and length are fused into numTokens. So, we need to
// insert the cpSize_Length dimension of transpose_2 into numTokens directly like
// input of cp: [partialNumTokens, head, headSize]
// view_1 as [partialNumTokens, cpSize_Head, partialHead, headSize]
// transpose_1 as [cpSize_Head, partialNumTokens, partialHead, headSize]
// all-to-all to get [cpSize_Length, partialNumTokens, partialHead, headSize]
// transpose_2 as [NumTokens, partialHead, headSize]
// and this is same to the input under TP.
// view_1 + transpose_1
invokeCpTranspose(gatherInBuffer, gatherOutBuffer, params.attention_input, partialTokenNum, mCpSize,
partialQHeads, partialKVHeads, getHeadSize(), mCpRank, stream);
sync_check_cuda_error(stream);
// Do all to all
#if ENABLE_MULTI_DEVICE
ncclGroupStart();
for (int cpIdx = 0; cpIdx < mCpSize; cpIdx++)
{
if (cpIdx != mCpRank)
{
NCCLCHECK(ncclSend(gatherInBuffer + cpIdx * (partialTokenNum * getHeadSize() * partialHeads),
(partialTokenNum * getHeadSize() * partialHeads), (*getDtypeMap())[mType], cpIdx, *mCpNcclComm,
stream));
NCCLCHECK(ncclRecv(gatherOutBuffer + cpIdx * (partialTokenNum * getHeadSize() * partialHeads),
(partialTokenNum * getHeadSize() * partialHeads), (*getDtypeMap())[mType], cpIdx, *mCpNcclComm,
stream));
}
}
ncclGroupEnd();
sync_check_cuda_error(stream);
#endif // ENABLE_MULTI_DEVICE
// transpose_2 + view_2
invokeCpTranspose2(gatherInBuffer, gatherOutBuffer, params.context_lengths, cu_q_seqlens,
cu_cp_partial_seqlens, mCpSize, maxPartialLength, params.batch_size, partialHeads, getHeadSize(),
stream);
this->template ulyssesContextPreprocess<T>(
attention_input, gatherInBuffer, gatherOutBuffer, params, cu_q_seqlens, cu_cp_partial_seqlens, stream);
attention_input = gatherInBuffer;
sync_check_cuda_error(stream);
}
sync_check_cuda_error(stream);
bool const enablePagedKVContextFMHA = mPagedKVCache && mPagedContextFMHA;
TLLM_CHECK_WITH_INFO(!(mKVCacheQuantMode.hasInt8KvCache() && enablePagedKVContextFMHA),
@ -1363,9 +1432,9 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
preprocessingParams.token_num = params.num_tokens;
preprocessingParams.remove_padding = mRemovePadding;
preprocessingParams.cross_attention = isCrossAttention();
preprocessingParams.head_num = mNumHeads / mCpSize;
preprocessingParams.kv_head_num = mNumKVHeads / mCpSize;
preprocessingParams.qheads_per_kv_head = mNumHeads / mNumKVHeads;
preprocessingParams.head_num = mNumAttnHeads;
preprocessingParams.kv_head_num = mNumAttnKVHeads;
preprocessingParams.qheads_per_kv_head = mNumAttnHeads / mNumAttnKVHeads;
preprocessingParams.size_per_head = getHeadSize();
preprocessingParams.rotary_embedding_dim = mRotaryEmbeddingDim;
preprocessingParams.rotary_embedding_base = mRotaryEmbeddingBase;
@ -1479,79 +1548,10 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
invokeKvCachePostprocessing(preprocessingParams, stream);
sync_check_cuda_error(stream);
if (mCpSize > 1)
if (mCpSize > 1 && mAttnTpSize > 1 && mAttnCpSize == 1)
{
// After FMHA, we get result [numTokens(bs, cp, paritalLength), partialHead, headSize]
// transpose_2_reverse: [cpSize_Length, partialTokens(bs, partialLength), partialHead, headSize]
// all-to-all: [cpSize_Head, partialTokens, partialHead, headSize]
// transpose_1_reverse: [partialTokens, cpSize_Head, partialHead, headSize]
// view: [partialTokens, head, headSize]
int32_t maxPartialLength = 0;
int32_t partialTokenNum = 0;
for (int32_t batchIdx = 0; batchIdx < params.batch_size; ++batchIdx)
{
int32_t partialLength = (params.host_context_lengths[batchIdx] + mCpSize - 1) / mCpSize;
maxPartialLength = std::max(maxPartialLength, partialLength);
partialTokenNum += partialLength;
}
auto partialHeads = mNumHeads / mCpSize;
// transpose_2_reverse
if (mFP8ContextFMHA)
{
invokeCpTransposeToSeqMajor2(reinterpret_cast<__nv_fp8_e4m3*>(gatherInBuffer),
reinterpret_cast<__nv_fp8_e4m3 const*>(gatherOutBuffer), params.context_lengths, cu_q_seqlens,
cu_cp_partial_seqlens, mCpSize, maxPartialLength, params.batch_size, partialHeads, getHeadSize(),
stream);
}
else
{
invokeCpTransposeToSeqMajor2(gatherInBuffer, gatherOutBuffer, params.context_lengths, cu_q_seqlens,
cu_cp_partial_seqlens, mCpSize, maxPartialLength, params.batch_size, partialHeads, getHeadSize(),
stream);
}
// all-to-all
#if ENABLE_MULTI_DEVICE
size_t const elementNum = partialTokenNum * getHeadSize() * partialHeads;
ncclGroupStart();
for (int cpIdx = 0; cpIdx < mCpSize; cpIdx++)
{
if (cpIdx != mCpRank)
{
if (mFP8ContextFMHA)
{
NCCLCHECK(ncclSend(reinterpret_cast<__nv_fp8_e4m3*>(gatherInBuffer) + cpIdx * elementNum,
elementNum, ncclInt8, cpIdx, *mCpNcclComm, stream));
NCCLCHECK(ncclRecv(reinterpret_cast<__nv_fp8_e4m3*>(gatherOutBuffer) + cpIdx * elementNum,
elementNum, ncclInt8, cpIdx, *mCpNcclComm, stream));
}
else
{
NCCLCHECK(ncclSend(gatherInBuffer + cpIdx * elementNum, elementNum, (*getDtypeMap())[mType],
cpIdx, *mCpNcclComm, stream));
NCCLCHECK(ncclRecv(gatherOutBuffer + cpIdx * elementNum, elementNum, (*getDtypeMap())[mType],
cpIdx, *mCpNcclComm, stream));
}
}
}
ncclGroupEnd();
#endif // ENABLE_MULTI_DEVICE
// transpose_1_reverse + view
if (mFP8ContextFMHA)
{
invokeCpTransposeToSeqMajor<__nv_fp8_e4m3>(reinterpret_cast<__nv_fp8_e4m3*>(params.context_buf),
reinterpret_cast<__nv_fp8_e4m3 const*>(gatherInBuffer),
reinterpret_cast<__nv_fp8_e4m3 const*>(gatherOutBuffer), partialTokenNum, mCpSize, partialHeads,
getHeadSize(), mCpRank, stream);
}
else
{
invokeCpTransposeToSeqMajor<T>((T*) params.context_buf, gatherInBuffer, gatherOutBuffer,
partialTokenNum, mCpSize, partialHeads, getHeadSize(), mCpRank, stream);
}
this->template ulyssesContextPostprocess<T>(gatherOutBuffer, reinterpret_cast<T*>(params.context_buf),
gatherInBuffer, params, cu_q_seqlens, cu_cp_partial_seqlens, stream);
sync_check_cuda_error(stream);
}
}
@ -1871,7 +1871,7 @@ int AttentionOp::enqueueGeneration(EnqueueGenerationParams<T> const& params, cud
KVCacheBuffer kv_cache_buffer;
auto const elemSize = mKVCacheQuantMode.hasKvCacheQuant() ? sizeof(int8_t) : sizeof(T);
auto const sizePerToken = mNumKVHeads / mCpSize * headSize * elemSize;
auto const sizePerToken = mNumAttnKVHeads * headSize * elemSize;
if (useKVCache())
{
if constexpr (std::is_same_v<KVCacheBuffer, KVBlockArray>)
@ -1936,7 +1936,12 @@ int AttentionOp::enqueueGeneration(EnqueueGenerationParams<T> const& params, cud
T* mhaInput = mhaOutput + cpMaxPaddedSequenceLength * (mNumHeads + 2 * mNumKVHeads) * mHeadSize;
T* attention_input = const_cast<T*>(params.attention_input);
this->template ulyssesGenerationPreprocess<T>(batch_beam, mhaInput, mhaOutput, attention_input, stream);
if (mCpSize > 1 && mAttnTpSize > 1 && mAttnCpSize == 1)
{
this->template ulyssesGenerationPreprocess<T>(attention_input, mhaInput, mhaOutput, batch_beam, stream);
attention_input = mhaInput;
sync_check_cuda_error(stream);
}
// Try XQA optimization first.
{
@ -1954,7 +1959,12 @@ int AttentionOp::enqueueGeneration(EnqueueGenerationParams<T> const& params, cud
xqaParams.qkv = attention_input;
}
mXqaDispatcher->run(xqaParams, kv_cache_buffer);
this->template ulyssesGenerationPostprocess<T>(batch_beam, mhaInput, mhaOutput, params.context_buf, stream);
if (mCpSize > 1 && mAttnTpSize > 1 && mAttnCpSize == 1)
{
this->template ulyssesGenerationPostprocess<T>(
mhaOutput, reinterpret_cast<T*>(params.context_buf), mhaInput, batch_beam, stream);
sync_check_cuda_error(stream);
}
return 0;
}
else if (mIsSpecDecodingEnabled && mUseSpecDecoding)
@ -2040,8 +2050,8 @@ int AttentionOp::enqueueGeneration(EnqueueGenerationParams<T> const& params, cud
dispatch_params.max_batch_size = batch_beam;
dispatch_params.inference_batch_size = batch_beam;
dispatch_params.beam_width = params.beam_width;
dispatch_params.head_num = mNumHeads / mCpSize;
dispatch_params.kv_head_num = mNumKVHeads / mCpSize;
dispatch_params.head_num = mNumAttnHeads;
dispatch_params.kv_head_num = mNumAttnKVHeads;
dispatch_params.size_per_head = getHeadSize();
dispatch_params.rotary_embedding_dim = mRotaryEmbeddingDim;
dispatch_params.position_embedding_type = mPositionEmbeddingType;
@ -2104,7 +2114,12 @@ int AttentionOp::enqueueGeneration(EnqueueGenerationParams<T> const& params, cud
fusedQKV_masked_attention_dispatch(mmhca_params, dispatch_params, stream);
}
this->template ulyssesGenerationPostprocess<T>(batch_beam, mhaInput, mhaOutput, params.context_buf, stream);
if (mCpSize > 1 && mAttnTpSize > 1 && mAttnCpSize == 1)
{
this->template ulyssesGenerationPostprocess<T>(
mhaOutput, reinterpret_cast<T*>(params.context_buf), mhaInput, batch_beam, stream);
sync_check_cuda_error(stream);
}
return 0;
}
@ -2164,6 +2179,21 @@ template void AttentionOp::prepareEnqueueGeneration<__nv_bfloat16, KVBlockArray>
int AttentionOp::initialize() noexcept
{
// use Ulysses for GPTAttentionPlugin
if (mAttnTpSize < 0 || mAttnCpSize < 0)
{
mAttnTpSize = mTpSize * mCpSize;
mAttnCpSize = 1;
}
mNumAttnHeads = mNumHeads * mTpSize / mAttnTpSize;
mNumAttnKVHeads = (mNumKVHeads * mTpSize + mAttnTpSize - 1) / mAttnTpSize;
if (mCpSize != mAttnCpSize)
{
// mqa broadcast
mUlyssesMQABroadcast = (mAttnTpSize + mNumKVHeadsOrigin - 1) / mNumKVHeadsOrigin;
}
// Pre-check whether FMHA is supported in order to save memory allocation.
if (mEnableContextFMHA)
{
@ -2306,8 +2336,8 @@ int AttentionOp::initialize() noexcept
fmhaParams.attentionMaskType = ContextAttentionMaskType::CUSTOM_MASK;
}
fmhaParams.isSPadded = !mRemovePadding;
fmhaParams.numQHeads = mNumHeads / mCpSize;
fmhaParams.numKvHeads = mNumKVHeads / mCpSize;
fmhaParams.numQHeads = mNumAttnHeads;
fmhaParams.numKvHeads = mNumAttnKVHeads;
fmhaParams.numTokensPerBlock = mTokensPerBlock;
fmhaParams.headSize = mHeadSize;
fmhaParams.headSizeV = mHeadSize;
@ -2326,16 +2356,6 @@ int AttentionOp::initialize() noexcept
fmhaParams.attnLogitSoftcappingScale = mAttnLogitSoftcappingScale;
fmhaParams.hasAlibi = isALiBi();
fmhaParams.scaleAlibi = isAliBiWithScale();
if (mTpSize > 1)
{
fmhaParams.tpSize = mTpSize;
fmhaParams.tpRank = mTpRank;
}
else if (mCpSize > 1)
{
fmhaParams.tpSize = mCpSize;
fmhaParams.tpRank = mCpRank;
}
// Load kernels from the pre-compiled cubins.
mFmhaDispatcher.reset(new FmhaDispatcher(fmhaParams));
@ -2479,8 +2499,8 @@ int AttentionOp::initialize() noexcept
TLLM_CHECK_WITH_INFO(mNumHeads % mNumKVHeads == 0, "mNumHeads should be multiples of mNumKVHeads.");
TLLM_CHECK_WITH_INFO(!mMultiBlockMode, "Medusa doesn't support multi-block mode.");
}
fixedParams.numQHeads = mNumHeads / mCpSize;
fixedParams.numKvHeads = mNumKVHeads / mCpSize;
fixedParams.numQHeads = mNumAttnHeads;
fixedParams.numKvHeads = mNumAttnKVHeads;
fixedParams.numTokensPerBlock = mTokensPerBlock;
fixedParams.headSize = mHeadSize;
fixedParams.qScaling = mQScaling;
@ -2555,6 +2575,7 @@ std::string AttentionOp::toString() const
ss << "gptAttentionCommon members ====================" << std::endl;
ss << "mNumHeads: " << mNumHeads << std::endl;
ss << "mNumKVHeads: " << mNumKVHeads << std::endl;
ss << "mNumKVHeadsOrigin: " << mNumKVHeadsOrigin << std::endl;
ss << "mHeadSize: " << mHeadSize << std::endl;
ss << "mUnidirectional: " << mUnidirectional << std::endl;
ss << "mQScaling: " << mQScaling << std::endl;

View File

@ -229,10 +229,18 @@ public:
EnqueueGenerationParams<T> const& generationsParams, bool forConfigurePlugin);
template <typename T>
int ulyssesGenerationPreprocess(int32_t batch_beam, T* mhaInput, T* mhaOutput, T*& input, cudaStream_t stream);
int ulyssesContextPreprocess(T const* input, T* output, T* buffer, EnqueueContextParams<T> const& params,
int const* cu_q_seqlens, int const* cu_cp_partial_seqlens, cudaStream_t stream);
template <typename T>
int ulyssesGenerationPostprocess(int32_t batch_beam, T* mhaInput, T* mhaOutput, void* output, cudaStream_t stream);
int ulyssesContextPostprocess(T* input, T* output, T* buffer, EnqueueContextParams<T> const& params,
int const* cu_q_seqlens, int const* cu_cp_partial_seqlens, cudaStream_t stream);
template <typename T>
int ulyssesGenerationPreprocess(T const* input, T* output, T* buffer, int32_t batch_beam, cudaStream_t stream);
template <typename T>
int ulyssesGenerationPostprocess(T* input, T* output, T* buffer, int32_t batch_beam, cudaStream_t stream);
[[nodiscard]] bool isRelativePosition() const
{
@ -374,6 +382,16 @@ public:
int mCpSize = 1;
int mCpRank = 0;
std::set<int32_t> mCpGroup = {};
// These parameters are used to specifically configure the attention attributes when cp/tp_size are different
// between Attention and FFN(such as Ulysses)
int mNumAttnHeads = -1;
int mNumAttnKVHeads = -1;
int mNumKVHeadsOrigin = -1;
int mAttnTpSize = -1;
int mAttnTpRank = 0;
int mAttnCpSize = -1;
int mAttnCpRank = 0;
int mUlyssesMQABroadcast = 1;
// fmha runner (enabled by default)
// flag: disabled = 0, enabled = 1, enabled with fp32 accumulation = 2
@ -403,8 +421,9 @@ public:
mPosShiftEnabled, mPagedContextFMHA, mFP8ContextFMHA, mDenseContextFMHA, mHasFullAttentionMask,
mIsSpecDecodingEnabled, mUseSpecDecoding, mSpecDecodingIsGenerationLengthVariable,
mSpecDecodingMaxGenerationLength, mIsMLAEnabled, mUseFlashMLA, mMLAParams.data(), mCpSize, mCpRank,
mCpGroup, mEnableContextFMHA, mFMHAForceFP32Acc, mMultiBlockMode, mEnableXQA, mUseKVCache, mSkipAttn,
mFuseFp4Quant, mNbMultiBlockSemaphores);
mCpGroup, mNumAttnHeads, mNumAttnKVHeads, mNumKVHeadsOrigin, mAttnTpSize, mAttnTpRank, mAttnCpSize,
mAttnCpRank, mUlyssesMQABroadcast, mEnableContextFMHA, mFMHAForceFP32Acc, mMultiBlockMode, mEnableXQA,
mUseKVCache, mSkipAttn, mFuseFp4Quant, mNbMultiBlockSemaphores);
};
private:

View File

@ -2139,7 +2139,7 @@ __global__ void convertData(Dst* dst, Src const* src, int64_t size, float const*
template <typename T>
__global__ void runCpTranspose(T* dst, T* dst2, T const* src, int64_t partialTokenNum, int64_t cpSize,
int64_t partialQHeads, int64_t partialKVHeads, int64_t headSize, int64_t rank)
int64_t partialQHeads, int64_t partialKVHeads, int64_t mqaBroadcast, int64_t headSize, int64_t rank)
{
// Do transpose from
// [partialTokenNum, mNumHeads + 2*mNumKVHeads, headSize]
@ -2164,12 +2164,12 @@ __global__ void runCpTranspose(T* dst, T* dst2, T const* src, int64_t partialTok
}
else if (headIdx < partialQHeads + partialKVHeads)
{
srcHeadIdx = cpSize * partialQHeads + cpIdx * partialKVHeads + (headIdx - partialQHeads);
srcHeadIdx = cpSize * partialQHeads + cpIdx / mqaBroadcast * partialKVHeads + (headIdx - partialQHeads);
}
else
{
srcHeadIdx = cpSize * partialQHeads + cpSize * partialKVHeads + cpIdx * partialKVHeads
+ (headIdx - partialQHeads - partialKVHeads);
srcHeadIdx = cpSize * partialQHeads + cpSize / mqaBroadcast * partialKVHeads
+ cpIdx / mqaBroadcast * partialKVHeads + (headIdx - partialQHeads - partialKVHeads);
}
if (cpIdx == rank)
@ -2181,7 +2181,7 @@ __global__ void runCpTranspose(T* dst, T* dst2, T const* src, int64_t partialTok
+ seqIdx * (partialQHeads + 2 * partialKVHeads) + headIdx)
* headSize);
VecType const* in = reinterpret_cast<VecType const*>(
src + (seqIdx * (partialQHeads + 2 * partialKVHeads) * cpSize + srcHeadIdx) * headSize);
src + (seqIdx * (partialQHeads * cpSize + 2 * partialKVHeads * cpSize / mqaBroadcast) + srcHeadIdx) * headSize);
for (int hiddenIdx = threadIdx.x; hiddenIdx < hiddenSize + hiddenRestSize; hiddenIdx += blockDim.x)
{
@ -2347,17 +2347,18 @@ INSTANTIATE_invokeConversion(__nv_fp8_e4m3, __nv_bfloat16);
template <typename T>
void invokeCpTranspose(T* dst, T* dst2, T const* src, int64_t partialTokenNum, int64_t cpSize, int64_t partialQHeads,
int64_t partialKVHeads, int64_t headSize, int64_t rank, cudaStream_t stream)
int64_t partialKVHeads, int64_t mqaBroadcast, int64_t headSize, int64_t rank, cudaStream_t stream)
{
dim3 grid(partialTokenNum, cpSize, partialQHeads + 2 * partialKVHeads);
dim3 block(128);
runCpTranspose<T><<<grid, block, 0, stream>>>(
dst, dst2, src, partialTokenNum, cpSize, partialQHeads, partialKVHeads, headSize, rank);
dst, dst2, src, partialTokenNum, cpSize, partialQHeads, partialKVHeads, mqaBroadcast, headSize, rank);
}
#define INSTANTIATE_invokeCpTranspose(T) \
template void invokeCpTranspose<T>(T * dst, T * dst2, T const* src, int64_t partialLength, int64_t cpSize, \
int64_t partialQHeads, int64_t partialKVHeads, int64_t headSize, int64_t rank, cudaStream_t stream)
int64_t partialQHeads, int64_t partialKVHeads, int64_t mqaBroadcast, int64_t headSize, int64_t rank, \
cudaStream_t stream)
INSTANTIATE_invokeCpTranspose(float);
INSTANTIATE_invokeCpTranspose(half);
INSTANTIATE_invokeCpTranspose(__nv_bfloat16);

View File

@ -407,7 +407,7 @@ void invokeConversion(Dst* dst, Src const* src, int64_t size, float const* __res
template <typename T>
void invokeCpTranspose(T* dst, T* dst2, T const* src, int64_t partialLength, int64_t cpSize, int64_t partialQHeads,
int64_t partialKVHeads, int64_t headSize, int64_t rank, cudaStream_t stream);
int64_t partialKVHeads, int64_t mqaBroadcast, int64_t headSize, int64_t rank, cudaStream_t stream);
template <typename T>
void invokeCpTransposeToSeqMajor(T* dst, T const* srcMyRank, T const* srcOtherRank, int64_t partialLength,

View File

@ -28,8 +28,8 @@ using tensorrt_llm::plugins::GPTAttentionPluginCreatorCommon;
using tensorrt_llm::plugins::GPTAttentionPluginCommon;
GPTAttentionPluginCommon::GPTAttentionPluginCommon(int layer_idx, int num_heads, int vision_start, int vision_length,
int num_kv_heads, int head_size, int unidirectional, float q_scaling, float attn_logit_softcapping_scale,
tensorrt_llm::kernels::PositionEmbeddingType position_embedding_type,
int num_kv_heads, int num_kv_heads_origin, int head_size, int unidirectional, float q_scaling,
float attn_logit_softcapping_scale, tensorrt_llm::kernels::PositionEmbeddingType position_embedding_type,
int rotary_embedding_dim, // for RoPE. Use 0 for non-RoPE
float rotary_embedding_base, tensorrt_llm::kernels::RotaryScalingType rotary_embedding_scale_type,
float rotary_embedding_scale, float rotary_embedding_short_m_scale, float rotary_embedding_long_m_scale,
@ -52,6 +52,7 @@ GPTAttentionPluginCommon::GPTAttentionPluginCommon(int layer_idx, int num_heads,
mVisionStart = vision_start;
mVisionLength = vision_length;
mNumKVHeads = num_kv_heads;
mNumKVHeadsOrigin = num_kv_heads_origin;
mHeadSize = head_size;
mUnidirectional = unidirectional;
mQScaling = q_scaling;
@ -114,6 +115,7 @@ GPTAttentionPluginCommon::GPTAttentionPluginCommon(void const* data, size_t leng
read(d, mVisionStart);
read(d, mVisionLength);
read(d, mNumKVHeads);
read(d, mNumKVHeadsOrigin);
read(d, mHeadSize);
read(d, mUnidirectional);
read(d, mQScaling);
@ -201,13 +203,13 @@ void GPTAttentionPluginCommon::destroy() noexcept
size_t GPTAttentionPluginCommon::getCommonSerializationSize() const noexcept
{
return sizeof(mLayerIdx) + sizeof(mNumHeads) + +sizeof(mVisionStart) + sizeof(mVisionLength) + sizeof(mNumKVHeads)
+ sizeof(mHeadSize) + sizeof(mUnidirectional) + sizeof(mQScaling) + sizeof(mAttnLogitSoftcappingScale)
+ sizeof(mPositionEmbeddingType) + sizeof(mRotaryEmbeddingDim) + sizeof(mRotaryEmbeddingBase)
+ sizeof(mRotaryEmbeddingScaleType) + sizeof(mRotaryEmbeddingScale) + sizeof(mRotaryEmbeddingShortMscale)
+ sizeof(mRotaryEmbeddingLongMscale) + sizeof(mRotaryEmbeddingMaxPositions)
+ sizeof(mRotaryEmbeddingOriginalMaxPositions) + sizeof(mTpSize) + sizeof(mTpRank) + sizeof(mEnableContextFMHA)
+ sizeof(mFMHAForceFP32Acc) + sizeof(mMultiBlockMode) + sizeof(mEnableXQA)
+ sizeof(unsigned int) // mKVCacheQuantMode
+ sizeof(mNumKVHeadsOrigin) + sizeof(mHeadSize) + sizeof(mUnidirectional) + sizeof(mQScaling)
+ sizeof(mAttnLogitSoftcappingScale) + sizeof(mPositionEmbeddingType) + sizeof(mRotaryEmbeddingDim)
+ sizeof(mRotaryEmbeddingBase) + sizeof(mRotaryEmbeddingScaleType) + sizeof(mRotaryEmbeddingScale)
+ sizeof(mRotaryEmbeddingShortMscale) + sizeof(mRotaryEmbeddingLongMscale)
+ sizeof(mRotaryEmbeddingMaxPositions) + sizeof(mRotaryEmbeddingOriginalMaxPositions) + sizeof(mTpSize)
+ sizeof(mTpRank) + sizeof(mEnableContextFMHA) + sizeof(mFMHAForceFP32Acc) + sizeof(mMultiBlockMode)
+ sizeof(mEnableXQA) + sizeof(unsigned int) // mKVCacheQuantMode
+ sizeof(mRemovePadding) + sizeof(mMaskType) + sizeof(mBlockSparseParams) + sizeof(mPagedKVCache)
+ sizeof(mTokensPerBlock) + sizeof(mType) + sizeof(mMaxContextLength) + sizeof(mQKVBiasEnabled)
+ sizeof(mCrossAttention) + sizeof(mMaxDistance) + sizeof(mPosShiftEnabled) + sizeof(mDenseContextFMHA)
@ -228,6 +230,7 @@ void GPTAttentionPluginCommon::serializeCommon(void* buffer) const noexcept
write(d, mVisionStart);
write(d, mVisionLength);
write(d, mNumKVHeads);
write(d, mNumKVHeadsOrigin);
write(d, mHeadSize);
write(d, mUnidirectional);
write(d, mQScaling);
@ -307,6 +310,7 @@ GPTAttentionPluginCreatorCommon::GPTAttentionPluginCreatorCommon()
mPluginAttributes.emplace_back(PluginField("vision_start", nullptr, PluginFieldType::kINT32));
mPluginAttributes.emplace_back(PluginField("vision_length", nullptr, PluginFieldType::kINT32));
mPluginAttributes.emplace_back(PluginField("num_kv_heads", nullptr, PluginFieldType::kINT32));
mPluginAttributes.emplace_back(PluginField("num_kv_heads_origin", nullptr, PluginFieldType::kINT32));
mPluginAttributes.emplace_back(PluginField("layer_idx_in_cache_pool", nullptr, PluginFieldType::kINT32));
mPluginAttributes.emplace_back(PluginField("head_size", nullptr, PluginFieldType::kINT32));
mPluginAttributes.emplace_back(PluginField("unidirectional", nullptr, PluginFieldType::kINT32));

View File

@ -35,7 +35,7 @@ public:
GPTAttentionPluginCommon() = delete;
GPTAttentionPluginCommon(int layer_idx, int num_heads, int vision_start, int vision_length, int num_kv_heads,
int head_size, int unidirectional, float q_scaling, float attn_logit_softcapping_scale,
int num_kv_heads_origin, int head_size, int unidirectional, float q_scaling, float attn_logit_softcapping_scale,
tensorrt_llm::kernels::PositionEmbeddingType position_embedding_type,
int rotary_embedding_dim, // for RoPE. Use 0 for non-RoPE
float rotary_embedding_base, tensorrt_llm::kernels::RotaryScalingType rotary_embedding_scale_type,

View File

@ -46,8 +46,8 @@ static char const* GPT_ATTENTION_PLUGIN_VERSION{"1"};
static char const* GPT_ATTENTION_PLUGIN_NAME{"GPTAttention"};
GPTAttentionPlugin::GPTAttentionPlugin(int layer_idx, int num_heads, int vision_start, int vision_length,
int num_kv_heads, int head_size, int unidirectional, float q_scaling, float attn_logit_softcapping_scale,
tensorrt_llm::kernels::PositionEmbeddingType position_embedding_type,
int num_kv_heads, int num_kv_heads_origin, int head_size, int unidirectional, float q_scaling,
float attn_logit_softcapping_scale, tensorrt_llm::kernels::PositionEmbeddingType position_embedding_type,
int rotary_embedding_dim, // for RoPE. 0 for non-RoPE
float rotary_embedding_base, tensorrt_llm::kernels::RotaryScalingType rotary_embedding_scale_type,
float rotary_embedding_scale, float rotary_embedding_short_m_scale,
@ -65,17 +65,17 @@ GPTAttentionPlugin::GPTAttentionPlugin(int layer_idx, int num_heads, int vision_
int spec_decoding_max_generation_length, bool is_mla_enabled, int q_lora_rank, int kv_lora_rank,
int qk_nope_head_dim, int qk_rope_head_dim, int v_head_dim, bool fuse_fp4_quant, bool skip_attn, int cp_size,
int cp_rank, std::set<int32_t> cp_group)
: GPTAttentionPluginCommon(layer_idx, num_heads, vision_start, vision_length, num_kv_heads, head_size,
unidirectional, q_scaling, attn_logit_softcapping_scale, position_embedding_type, rotary_embedding_dim,
rotary_embedding_base, rotary_embedding_scale_type, rotary_embedding_scale, rotary_embedding_short_m_scale,
rotary_embedding_long_m_scale, rotary_embedding_max_positions, rotary_embedding_original_max_positions, tp_size,
tp_rank, unfuse_qkv_gemm, use_logn_scaling, context_fmha_type, kv_cache_quant_mode, remove_input_padding,
mask_type, block_sparse_params, paged_kv_cache, tokens_per_block, type, max_context_length, qkv_bias_enabled,
cross_attention, max_distance, pos_shift_enabled, dense_context_fmha, use_paged_context_fmha,
use_fp8_context_fmha, has_full_attention_mask, use_cache, is_spec_decoding_enabled,
spec_decoding_is_generation_length_variable, spec_decoding_max_generation_length, is_mla_enabled, q_lora_rank,
kv_lora_rank, qk_nope_head_dim, qk_rope_head_dim, v_head_dim, fuse_fp4_quant, skip_attn, cp_size, cp_rank,
cp_group)
: GPTAttentionPluginCommon(layer_idx, num_heads, vision_start, vision_length, num_kv_heads, num_kv_heads_origin,
head_size, unidirectional, q_scaling, attn_logit_softcapping_scale, position_embedding_type,
rotary_embedding_dim, rotary_embedding_base, rotary_embedding_scale_type, rotary_embedding_scale,
rotary_embedding_short_m_scale, rotary_embedding_long_m_scale, rotary_embedding_max_positions,
rotary_embedding_original_max_positions, tp_size, tp_rank, unfuse_qkv_gemm, use_logn_scaling, context_fmha_type,
kv_cache_quant_mode, remove_input_padding, mask_type, block_sparse_params, paged_kv_cache, tokens_per_block,
type, max_context_length, qkv_bias_enabled, cross_attention, max_distance, pos_shift_enabled,
dense_context_fmha, use_paged_context_fmha, use_fp8_context_fmha, has_full_attention_mask, use_cache,
is_spec_decoding_enabled, spec_decoding_is_generation_length_variable, spec_decoding_max_generation_length,
is_mla_enabled, q_lora_rank, kv_lora_rank, qk_nope_head_dim, qk_rope_head_dim, v_head_dim, fuse_fp4_quant,
skip_attn, cp_size, cp_rank, cp_group)
{
TLLM_CHECK_WITH_INFO(
!is_mla_enabled, "GPTAttentionPlugin no longer supports MLA. Please use the PyTorch workflow instead.");
@ -890,7 +890,8 @@ int GPTAttentionPlugin::enqueueSome(int32_t seqIdxBeg, int32_t localNbSeq, int32
auto const cacheElemSize = (mKVCacheQuantMode.hasKvCacheQuant() ? 1 : sizeof(T));
auto const blockSize = mTokensPerBlock * mNumKVHeads / mCpSize * mHeadSize;
auto const kv_cache_head_num = (mNumKVHeads + mCpSize - 1) / mCpSize;
auto const blockSize = mTokensPerBlock * kv_cache_head_num * mHeadSize;
auto const bytesPerBlock = blockSize * cacheElemSize;
auto const layerOffset = layerIdxInCachePool * 2 * bytesPerBlock;
@ -1311,8 +1312,9 @@ IPluginV2* GPTAttentionPluginCreator::createPlugin(char const* name, PluginField
auto* obj = new GPTAttentionPlugin(p.getScalar<int32_t>("layer_idx").value(),
p.getScalar<int32_t>("num_heads").value(), p.getScalar<int32_t>("vision_start").value(),
p.getScalar<int32_t>("vision_length").value(), p.getScalar<int32_t>("num_kv_heads").value(),
p.getScalar<int32_t>("head_size").value(), p.getScalar<int32_t>("unidirectional").value(),
p.getScalar<float>("q_scaling").value(), p.getScalar<float>("attn_logit_softcapping_scale").value(),
p.getScalar<int32_t>("num_kv_heads_origin").value(), p.getScalar<int32_t>("head_size").value(),
p.getScalar<int32_t>("unidirectional").value(), p.getScalar<float>("q_scaling").value(),
p.getScalar<float>("attn_logit_softcapping_scale").value(),
static_cast<PositionEmbeddingType>(p.getScalar<int8_t>("position_embedding_type").value()),
p.getScalar<int32_t>("rotary_embedding_dim").value(), p.getScalar<float>("rotary_embedding_base").value(),
static_cast<RotaryScalingType>(p.getScalar<int8_t>("rotary_embedding_scale_type").value()),

View File

@ -101,7 +101,7 @@ class GPTAttentionPlugin : public GPTAttentionPluginCommon
{
public:
GPTAttentionPlugin(int layer_idx, int num_heads, int vision_start, int vision_length, int num_kv_heads,
int head_size, int unidirectional, float q_scaling, float attn_logit_softcapping_scale,
int num_kv_heads_origin, int head_size, int unidirectional, float q_scaling, float attn_logit_softcapping_scale,
tensorrt_llm::kernels::PositionEmbeddingType position_embedding_type,
int rotary_embedding_dim, // for RoPE. 0 for non-RoPE
float rotary_embedding_base, tensorrt_llm::kernels::RotaryScalingType rotary_embedding_scale_type,

View File

@ -236,6 +236,8 @@ def convert_and_save_hf(args):
load_model_on_cpu=args.load_model_on_cpu,
**override_fields)
glm.config.mapping.cp_size = args.cp_size
glm.config.mapping.attn_tp_size = -1
glm.config.mapping.attn_cp_size = -1
glm.config.mapping.world_size *= args.cp_size
glm.save_checkpoint(args.output_dir, save_config=(rank == 0))
del glm

View File

@ -393,6 +393,8 @@ def convert_and_save_meta(args, rank):
use_parallel_embedding=args.use_parallel_embedding,
embedding_sharding_dim=args.embedding_sharding_dim)
llama.config.mapping.cp_size = args.cp_size
llama.config.mapping.attn_tp_size = -1
llama.config.mapping.attn_cp_size = -1
llama.config.mapping.world_size *= args.cp_size
llama.save_checkpoint(args.output_dir, save_config=(rank == 0))
@ -511,6 +513,8 @@ def convert_and_save_hf(args):
f'Total time of reading and converting: {time.time()-tik:.3f} s'
)
llama.config.mapping.cp_size = args.cp_size
llama.config.mapping.attn_tp_size = -1
llama.config.mapping.attn_cp_size = -1
llama.config.mapping.world_size *= args.cp_size
tik = time.time()
llama.save_checkpoint(args.output_dir, save_config=(rank == 0))

View File

@ -281,6 +281,8 @@ def convert_and_save_hf(args):
quant_config=quant_config,
**override_fields)
qwen.config.mapping.cp_size = args.cp_size
qwen.config.mapping.attn_tp_size = -1
qwen.config.mapping.attn_cp_size = -1
qwen.config.mapping.world_size *= args.cp_size
qwen.save_checkpoint(args.output_dir, save_config=(rank == 0))
del qwen

View File

@ -5213,6 +5213,7 @@ def gpt_attention(
cp_group: List[int] = [0],
cp_size: int = 1,
cp_rank: int = 0,
num_kv_heads_origin: int = -1,
) -> Tuple[Tensor, Optional[Tensor]]:
'''
Add an operation that performs the multi-head attention in GPT-like models.
@ -5474,6 +5475,9 @@ def gpt_attention(
skip_attn: Tensor = None,
A bool tensor on CPU. If it is true, don't run attention plugin, returning directly.
num_kv_heads_origin: int
The origin number of KV heads, without the process of TP
Returns:
The tensor produced by that layer.
'''
@ -5505,6 +5509,9 @@ def gpt_attention(
else:
use_logn_scaling = 0
if num_kv_heads_origin < 1:
num_kv_heads_origin = num_kv_heads
unfuse_qkv_gemm = trt.PluginField(
"unfuse_qkv_gemm", np.array(np.int8(is_unfuse_qkv_gemm), dtype=np.int8),
trt.PluginFieldType.INT8)
@ -5523,6 +5530,9 @@ def gpt_attention(
num_kv_heads = trt.PluginField("num_kv_heads",
np.array(num_kv_heads, dtype=np.int32),
trt.PluginFieldType.INT32)
num_kv_heads_origin = trt.PluginField(
"num_kv_heads_origin", np.array(num_kv_heads_origin, dtype=np.int32),
trt.PluginFieldType.INT32)
head_size = trt.PluginField("head_size",
np.array(hidden_size_per_head, dtype=np.int32),
trt.PluginFieldType.INT32)
@ -5714,9 +5724,10 @@ def gpt_attention(
trt.PluginFieldType.INT8)
pfc = trt.PluginFieldCollection([
layer_idx, nheads, vision_start, vision_length, num_kv_heads, head_size,
unidirectional, q_scaling, attn_logit_softcapping_scale,
position_embedding_type, rotary_embedding_dim, rotary_embedding_base,
layer_idx, nheads, vision_start, vision_length, num_kv_heads,
num_kv_heads_origin, head_size, unidirectional, q_scaling,
attn_logit_softcapping_scale, position_embedding_type,
rotary_embedding_dim, rotary_embedding_base,
rotary_embedding_scale_type, rotary_embedding_scale,
rotary_embedding_short_m_scale, rotary_embedding_long_m_scale,
rotary_embedding_max_positions, rotary_embedding_original_max_positions,

View File

@ -395,13 +395,13 @@ class Attention(Module):
self.cross_attention = cross_attention
self.attention_mask_type = attention_mask_type
self.attention_head_size = hidden_size // num_attention_heads if attention_head_size is None else attention_head_size
self.num_kv_heads = num_kv_heads
assert num_attention_heads % tp_size == 0, \
"num_attention_heads must be divisible by tp_size"
self.num_attention_heads = num_attention_heads // tp_size
self.num_attention_kv_heads = (
num_kv_heads + tp_size - 1
) // tp_size if num_kv_heads is not None else self.num_attention_heads
self.num_kv_heads = num_kv_heads if num_kv_heads is not None else self.num_attention_heads
self.hidden_size = hidden_size
self.attention_hidden_size = self.attention_head_size * self.num_attention_heads
self.max_position_embeddings = max_position_embeddings
@ -1058,6 +1058,7 @@ class Attention(Module):
layer_idx=self.local_layer_idx,
num_heads=self.num_attention_heads,
num_kv_heads=self.num_attention_kv_heads,
num_kv_heads_origin=self.num_kv_heads,
hidden_size_per_head=self.attention_head_size,
q_scaling=self.q_scaling,
rotary_embedding_dim=self.rotary_embedding_dim,
@ -1866,6 +1867,7 @@ class CogVLMAttention(Attention):
layer_idx=self.local_layer_idx,
num_heads=self.num_attention_heads,
num_kv_heads=self.num_attention_kv_heads,
num_kv_heads_origin=self.num_kv_heads,
hidden_size_per_head=self.attention_head_size,
q_scaling=self.q_scaling,
position_embedding_type=self.position_embedding_type,
@ -2216,6 +2218,7 @@ class DeepseekV2Attention(Attention):
layer_idx=self.local_layer_idx,
num_heads=self.num_attention_heads,
num_kv_heads=1,
num_kv_heads_origin=1,
hidden_size_per_head=self.kv_lora_rank + self.qk_rope_head_dim,
q_scaling=self.q_scaling,
position_embedding_type=self.position_embedding_type,

View File

@ -123,6 +123,8 @@ class Mapping(object):
pp_size=1,
moe_tp_size=-1, # -1 means no moe
moe_ep_size=-1, # -1 means no moe
attn_tp_size=-1,
attn_cp_size=-1,
auto_parallel=False,
enable_attention_dp=False):
# set default values for non-moe cases
@ -137,6 +139,22 @@ class Mapping(object):
elif moe_ep_size == -1:
moe_ep_size = tp_size // moe_tp_size
if attn_tp_size == -1 and attn_cp_size == -1:
# fallback to ulysses
attn_tp_size = tp_size * cp_size
attn_cp_size = 1
elif attn_tp_size == -1:
attn_tp_size = cp_size * tp_size // attn_cp_size
elif attn_cp_size == -1:
attn_cp_size = cp_size * tp_size // attn_tp_size
if attn_cp_size != 1:
raise ValueError(
f"attn_cp_size must be 1 for now, but got {attn_tp_size}, {attn_cp_size}."
)
if auto_parallel:
if tp_size != 1 or pp_size != 1 or tp_size != 1:
raise ValueError(
@ -154,6 +172,12 @@ class Mapping(object):
f"tp_size must equal to moe_tp_size * moe_ep_size, but got {tp_size} != {moe_tp_size} * {moe_ep_size}"
)
attn_tp_cp_size = attn_tp_size * attn_cp_size
if attn_tp_cp_size != tp_size * cp_size:
raise ValueError(
f"tp_size * cp_size must equal to attn_tp_size * attn_cp_size, but got {tp_size} * {cp_size} != {attn_tp_size} * {attn_cp_size}"
)
if moe_ep_size != 1 and cp_size > 1:
raise NotImplementedError("CP don't support MoE tp/ep yet")
@ -163,6 +187,8 @@ class Mapping(object):
self.pp_size = pp_size
self.moe_tp_size = moe_tp_size
self.moe_ep_size = moe_ep_size
self.attn_tp_size = attn_tp_size
self.attn_cp_size = attn_cp_size
self.auto_parallel = auto_parallel
self.world_size = world_size
self.rank = rank
@ -218,6 +244,8 @@ class Mapping(object):
and self.pp_size == other.pp_size
and self.moe_tp_size == other.moe_tp_size
and self.moe_ep_size == other.moe_ep_size
and self.attn_tp_size == other.attn_tp_size
and self.attn_cp_size == other.attn_cp_size
and self.auto_parallel == other.auto_parallel)
def __hash__(self):
@ -225,6 +253,7 @@ class Mapping(object):
^ hash(self.gpus_per_node) ^ hash(self.cp_size)
^ hash(self.tp_size) ^ hash(self.pp_size)
^ hash(self.moe_tp_size) ^ hash(self.moe_ep_size)
^ hash(self.attn_tp_size) ^ hash(self.attn_cp_size)
^ hash(self.auto_parallel))
@property
@ -375,5 +404,7 @@ class Mapping(object):
'pp_size': self.pp_size,
'moe_tp_size': self.moe_tp_size,
'moe_ep_size': self.moe_ep_size,
'attn_tp_size': self.attn_tp_size,
'attn_cp_size': self.attn_cp_size,
'auto_parallel': self.auto_parallel,
}

View File

@ -1850,6 +1850,7 @@ class Fp8RowwiseAttention(Module):
self.attention_mask_type = attention_mask_type
self.attention_head_size = hidden_size // num_attention_heads if attention_head_size is None else attention_head_size
self.num_attention_heads = num_attention_heads // tp_size
self.num_kv_heads = num_kv_heads
self.num_attention_kv_heads = (
num_kv_heads + tp_size - 1
) // tp_size if num_kv_heads is not None else self.num_attention_heads
@ -2010,6 +2011,7 @@ class Fp8RowwiseAttention(Module):
layer_idx=self.local_layer_idx,
num_heads=self.num_attention_heads,
num_kv_heads=self.num_attention_kv_heads,
num_kv_heads_origin=self.num_kv_heads,
hidden_size_per_head=self.attention_head_size,
q_scaling=self.q_scaling,
rotary_embedding_dim=self.rotary_embedding_dim,
@ -2467,6 +2469,7 @@ class SmoothQuantAttention(Module):
self.attention_mask_type = attention_mask_type
self.attention_head_size = hidden_size // num_attention_heads if attention_head_size is None else attention_head_size
self.num_attention_heads = num_attention_heads // tp_size
self.num_kv_heads = num_kv_heads
self.num_attention_kv_heads = (
num_kv_heads + tp_size - 1
) // tp_size if num_kv_heads is not None else self.num_attention_heads
@ -2634,6 +2637,7 @@ class SmoothQuantAttention(Module):
layer_idx=self.local_layer_idx,
num_heads=self.num_attention_heads,
num_kv_heads=self.num_attention_kv_heads,
num_kv_heads_origin=self.num_kv_heads,
hidden_size_per_head=self.attention_head_size,
q_scaling=self.q_scaling,
rotary_embedding_dim=self.rotary_embedding_dim,
@ -2987,6 +2991,7 @@ class QServeAttention(Module):
self.attention_mask_type = attention_mask_type
self.attention_head_size = hidden_size // num_attention_heads if attention_head_size is None else attention_head_size
self.num_attention_heads = num_attention_heads // tp_size
self.num_kv_heads = num_kv_heads
self.num_attention_kv_heads = (
num_kv_heads + tp_size - 1
) // tp_size if num_kv_heads is not None else self.num_attention_heads
@ -3154,6 +3159,7 @@ class QServeAttention(Module):
layer_idx=self.local_layer_idx,
num_heads=self.num_attention_heads,
num_kv_heads=self.num_attention_kv_heads,
num_kv_heads_origin=self.num_kv_heads,
hidden_size_per_head=self.attention_head_size,
q_scaling=self.q_scaling,
rotary_embedding_dim=self.rotary_embedding_dim,

View File

@ -905,6 +905,8 @@ def quantize_and_export(*,
with open(f"{export_path}/config.json", "r") as f:
tensorrt_llm_config = json.load(f)
tensorrt_llm_config["mapping"]["cp_size"] = cp_size
tensorrt_llm_config["mapping"]["attn_tp_size"] = -1
tensorrt_llm_config["mapping"]["attn_cp_size"] = -1
tensorrt_llm_config["mapping"]["world_size"] *= cp_size
with open(f"{export_path}/config.json", "w") as f:
json.dump(tensorrt_llm_config, f, indent=4)