fix: fix for cp > kvHeadNum (#3002)

* fix for cp > kvHeadNum

Signed-off-by: Dylan Chen <191843203+DylanChen-NV@users.noreply.github.com>

* fix for None kv_head_num

Signed-off-by: Dylan Chen <191843203+DylanChen-NV@users.noreply.github.com>

---------

Signed-off-by: Dylan Chen <191843203+DylanChen-NV@users.noreply.github.com>
DylanChen-NV 2025-03-26 12:39:02 +08:00 committed by GitHub
parent 25f2434495
commit 1ac0566a93
16 changed files with 346 additions and 238 deletions

View File

@@ -166,9 +166,8 @@ bool AttentionOp::convertMMHAParamsToXQAParams(tensorrt_llm::kernels::XQAParams&
xqaParams = {};
xqaParams.data_type = ConvertMMHAToXQAParamsHelper<T, KVCacheBuffer>::data_type;
- // TODO(ziqingc): A better description for these parameters affected by CP size
- xqaParams.num_q_heads = mNumHeads / mCpSize; // when we use CP, the MHA part is spilt like TP
- xqaParams.num_kv_heads = mNumKVHeads / mCpSize; // when we use CP, the MHA part is spilt like TP
+ xqaParams.num_q_heads = mNumAttnHeads;
+ xqaParams.num_kv_heads = mNumAttnKVHeads;
xqaParams.head_size = mHeadSize;
xqaParams.unidirectional = mUnidirectional;
xqaParams.q_scaling = mQScaling;
@@ -265,15 +264,151 @@ bool AttentionOp::convertMMHAParamsToXQAParams(tensorrt_llm::kernels::XQAParams&
return true;
}
template <typename T>
int AttentionOp::ulyssesContextPreprocess(T const* input, T* output, T* buffer, EnqueueContextParams<T> const& params,
int const* cu_q_seqlens, int const* cu_cp_partial_seqlens, cudaStream_t stream)
{
int32_t partialTokenNum = 0;
int32_t maxPartialLength = 0;
for (int32_t batchIdx = 0; batchIdx < params.batch_size; ++batchIdx)
{
int32_t partialLength = (params.host_context_lengths[batchIdx] + mCpSize - 1) / mCpSize;
maxPartialLength = std::max(maxPartialLength, partialLength);
partialTokenNum += partialLength;
}
auto const partialHeads = mNumAttnHeads + 2 * mNumAttnKVHeads;
// full request: [bs, seqlen, head, headSize]
//
// input of cp: [bs, partialLength, head, headSize]
// view_1 as [bs, partialLength, cpSize_Head, partialHead, headSize]
// transpose_1 as [cpSize_Head, bs, partialLength, partialHead, headSize]
// all-to-all to get [cpSize_Length, bs, partialLength, partialHead, headSize]
// transpose_2 to [bs, cpSize_Length, partialLength, partialHead, headSize]
// view_2 as [bs, totalLength, partialHead, headSize]
// and this is the same as the input under TP.
//
// when we use remove_input_padding, bs and length are fused into numTokens. So, we need to
// insert the cpSize_Length dimension of transpose_2 into numTokens directly like
// input of cp: [partialNumTokens, head, headSize]
// view_1 as [partialNumTokens, cpSize_Head, partialHead, headSize]
// transpose_1 as [cpSize_Head, partialNumTokens, partialHead, headSize]
// all-to-all to get [cpSize_Length, partialNumTokens, partialHead, headSize]
// transpose_2 as [NumTokens, partialHead, headSize]
// and this is the same as the input under TP.
// view_1 + transpose_1
invokeCpTranspose(output, buffer, input, partialTokenNum, mCpSize, mNumAttnHeads, mNumAttnKVHeads,
mUlyssesMQABroadcast, getHeadSize(), mCpRank, stream);
sync_check_cuda_error(stream);
// Do all to all
#if ENABLE_MULTI_DEVICE
ncclGroupStart();
for (int cpIdx = 0; cpIdx < mCpSize; cpIdx++)
{
if (cpIdx != mCpRank)
{
NCCLCHECK(ncclSend(output + cpIdx * (partialTokenNum * getHeadSize() * partialHeads),
(partialTokenNum * getHeadSize() * partialHeads), (*getDtypeMap())[mType], cpIdx, *mCpNcclComm,
stream));
NCCLCHECK(ncclRecv(buffer + cpIdx * (partialTokenNum * getHeadSize() * partialHeads),
(partialTokenNum * getHeadSize() * partialHeads), (*getDtypeMap())[mType], cpIdx, *mCpNcclComm,
stream));
}
}
ncclGroupEnd();
sync_check_cuda_error(stream);
#endif // ENABLE_MULTI_DEVICE
// transpose_2 + view_2
invokeCpTranspose2(output, buffer, params.context_lengths, cu_q_seqlens, cu_cp_partial_seqlens, mCpSize,
maxPartialLength, params.batch_size, partialHeads, getHeadSize(), stream);
return 0;
}
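
A note for readers of this new function: every CP rank exchanges one contiguous chunk per peer, and the chunk size follows directly from the ceil-divided per-rank sequence lengths computed above. The standalone sketch below uses hypothetical sizes (not taken from the commit) and simply reproduces that sizing arithmetic.

// Hypothetical, standalone sketch of the buffer sizing used by the all-to-all above.
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

int main()
{
    // Assumed example configuration (not from the commit): cpSize = 4, 128-dim heads,
    // 8 attention Q heads and 1 attention KV head per rank after the Ulysses split.
    int64_t const cpSize = 4;
    int64_t const numAttnHeads = 8;
    int64_t const numAttnKVHeads = 1;
    int64_t const headSize = 128;
    std::vector<int32_t> const hostContextLengths = {33, 64};

    int32_t partialTokenNum = 0;
    int32_t maxPartialLength = 0;
    for (int32_t len : hostContextLengths)
    {
        // Same per-sequence ceil-div as in ulyssesContextPreprocess.
        int32_t const partialLength = (len + cpSize - 1) / cpSize;
        maxPartialLength = std::max(maxPartialLength, partialLength);
        partialTokenNum += partialLength;
    }
    assert(partialTokenNum == 25 && maxPartialLength == 16);

    // Each rank sends/receives exactly this many elements to/from every peer.
    int64_t const partialHeads = numAttnHeads + 2 * numAttnKVHeads;
    int64_t const chunkElems = partialTokenNum * headSize * partialHeads;
    assert(chunkElems == 32000);
    return 0;
}
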
template <typename T>
int AttentionOp::ulyssesContextPostprocess(T* input, T* output, T* buffer, EnqueueContextParams<T> const& params,
int const* cu_q_seqlens, int const* cu_cp_partial_seqlens, cudaStream_t stream)
{
// After FMHA, we get result [numTokens(bs, cp, partialLength), partialHead, headSize]
// transpose_2_reverse: [cpSize_Length, partialTokens(bs, partialLength), partialHead, headSize]
// all-to-all: [cpSize_Head, partialTokens, partialHead, headSize]
// transpose_1_reverse: [partialTokens, cpSize_Head, partialHead, headSize]
// view: [partialTokens, head, headSize]
int32_t maxPartialLength = 0;
int32_t partialTokenNum = 0;
for (int32_t batchIdx = 0; batchIdx < params.batch_size; ++batchIdx)
{
int32_t partialLength = (params.host_context_lengths[batchIdx] + mCpSize - 1) / mCpSize;
maxPartialLength = std::max(maxPartialLength, partialLength);
partialTokenNum += partialLength;
}
// transpose_2_reverse
if (mFP8ContextFMHA)
{
invokeCpTransposeToSeqMajor2(reinterpret_cast<__nv_fp8_e4m3*>(buffer),
reinterpret_cast<__nv_fp8_e4m3 const*>(input), params.context_lengths, cu_q_seqlens, cu_cp_partial_seqlens,
mCpSize, maxPartialLength, params.batch_size, mNumAttnHeads, getHeadSize(), stream);
}
else
{
invokeCpTransposeToSeqMajor2(buffer, input, params.context_lengths, cu_q_seqlens, cu_cp_partial_seqlens,
mCpSize, maxPartialLength, params.batch_size, mNumAttnHeads, getHeadSize(), stream);
}
// all-to-all
#if ENABLE_MULTI_DEVICE
size_t const elementNum = partialTokenNum * getHeadSize() * mNumAttnHeads;
ncclGroupStart();
for (int cpIdx = 0; cpIdx < mCpSize; cpIdx++)
{
if (cpIdx != mCpRank)
{
if (mFP8ContextFMHA)
{
NCCLCHECK(ncclSend(reinterpret_cast<__nv_fp8_e4m3*>(buffer) + cpIdx * elementNum, elementNum, ncclInt8,
cpIdx, *mCpNcclComm, stream));
NCCLCHECK(ncclRecv(reinterpret_cast<__nv_fp8_e4m3*>(input) + cpIdx * elementNum, elementNum, ncclInt8,
cpIdx, *mCpNcclComm, stream));
}
else
{
NCCLCHECK(ncclSend(
buffer + cpIdx * elementNum, elementNum, (*getDtypeMap())[mType], cpIdx, *mCpNcclComm, stream));
NCCLCHECK(ncclRecv(
input + cpIdx * elementNum, elementNum, (*getDtypeMap())[mType], cpIdx, *mCpNcclComm, stream));
}
}
}
ncclGroupEnd();
#endif // ENABLE_MULTI_DEVICE
// transpose_1_reverse + view
if (mFP8ContextFMHA)
{
invokeCpTransposeToSeqMajor<__nv_fp8_e4m3>(reinterpret_cast<__nv_fp8_e4m3*>(output),
reinterpret_cast<__nv_fp8_e4m3 const*>(buffer), reinterpret_cast<__nv_fp8_e4m3 const*>(input),
partialTokenNum, mCpSize, mNumAttnHeads, getHeadSize(), mCpRank, stream);
}
else
{
invokeCpTransposeToSeqMajor<T>(
(T*) output, buffer, input, partialTokenNum, mCpSize, mNumAttnHeads, getHeadSize(), mCpRank, stream);
}
return 0;
}
template <typename T>
int AttentionOp::ulyssesGenerationPreprocess(
- int32_t batch_beam, T* mhaInput, T* mhaOutput, T*& input, cudaStream_t stream)
+ T const* input, T* output, T* buffer, int32_t batch_beam, cudaStream_t stream)
{
if (mCpSize <= 1)
return 0;
- auto const partialQHeads = mNumHeads / mCpSize;
- auto const partialKVHeads = mNumKVHeads / mCpSize;
auto const partialTokenNum = (batch_beam + mCpSize - 1) / mCpSize;
// attention_input shape: [partialTokenNum, numHeads, headSize]
@@ -284,27 +419,26 @@ int AttentionOp::ulyssesGenerationPreprocess(
// do transpose_1
// [1, mNumHeads + 2*mNumKVHeads, headSize]
- // -> (view) [1, cpSize * partialQHeads + cpSize * partialKVHeads + cpSize * partilKVHeads,
+ // -> (view) [1, cpSize * mNumAttnHeads + cpSize * mNumAttnKVHeads + cpSize * mNumAttnKVHeads,
// headSize]
- // -> (transpose) [cpSize, 1, partialQHeads + partialKvHeads + partialKVHeads, headSize]
+ // -> (transpose) [cpSize, 1, mNumAttnHeads + mNumAttnKVHeads + mNumAttnKVHeads, headSize]
- invokeCpTranspose(mhaOutput, mhaInput, input, partialTokenNum, mCpSize, partialQHeads, partialKVHeads, mHeadSize,
- mCpRank, stream);
+ invokeCpTranspose(buffer, output, input, partialTokenNum, mCpSize, mNumAttnHeads, mNumAttnKVHeads,
+ mUlyssesMQABroadcast, mHeadSize, mCpRank, stream);
sync_check_cuda_error(stream);
// Do all to all
#if ENABLE_MULTI_DEVICE
- auto const totalHeads = mNumHeads + 2 * mNumKVHeads;
- auto const partialHeads = totalHeads / mCpSize;
+ auto const partialHeads = mNumAttnHeads + 2 * mNumAttnKVHeads;
ncclGroupStart();
for (int cpIdx = 0; cpIdx < mCpSize; cpIdx++)
{
if (cpIdx != mCpRank)
{
- NCCLCHECK(ncclSend(mhaOutput + cpIdx * (partialTokenNum * getHeadSize() * partialHeads),
+ NCCLCHECK(ncclSend(buffer + cpIdx * (partialTokenNum * getHeadSize() * partialHeads),
(partialTokenNum * getHeadSize() * partialHeads), (*getDtypeMap())[mType], cpIdx, *mCpNcclComm,
stream));
- NCCLCHECK(ncclRecv(mhaInput + cpIdx * (partialTokenNum * getHeadSize() * partialHeads),
+ NCCLCHECK(ncclRecv(output + cpIdx * (partialTokenNum * getHeadSize() * partialHeads),
(partialTokenNum * getHeadSize() * partialHeads), (*getDtypeMap())[mType], cpIdx, *mCpNcclComm,
stream));
}
@@ -312,14 +446,11 @@ int AttentionOp::ulyssesGenerationPreprocess(
ncclGroupEnd();
sync_check_cuda_error(stream);
#endif // ENABLE_MULTI_DEVICE
- input = mhaInput;
return 0;
}
template <typename T>
- int AttentionOp::ulyssesGenerationPostprocess(
- int32_t batch_beam, T* mhaInput, T* mhaOutput, void* output, cudaStream_t stream)
+ int AttentionOp::ulyssesGenerationPostprocess(T* input, T* output, T* buffer, int32_t batch_beam, cudaStream_t stream)
{
if (mCpSize <= 1)
return 0;
@@ -330,12 +461,11 @@ int AttentionOp::ulyssesGenerationPostprocess(
// transpose_1_reverse: [partialTokens, cpSize_Head, partialHead, headSize]
// view: [partialTokens, head, headSize]
- auto partialHeads = mNumHeads / mCpSize;
auto const partialTokenNum = (batch_beam + mCpSize - 1) / mCpSize;
// do all-to-all
#if ENABLE_MULTI_DEVICE
- size_t const elementNum = partialTokenNum * getHeadSize() * partialHeads;
+ size_t const elementNum = partialTokenNum * getHeadSize() * mNumAttnHeads;
ncclGroupStart();
for (int cpIdx = 0; cpIdx < mCpSize; cpIdx++)
{
@@ -343,17 +473,17 @@ int AttentionOp::ulyssesGenerationPostprocess(
{
if (mFP8ContextFMHA)
{
- NCCLCHECK(ncclSend(reinterpret_cast<__nv_fp8_e4m3*>(mhaOutput) + cpIdx * elementNum, elementNum,
- ncclInt8, cpIdx, *mCpNcclComm, stream));
- NCCLCHECK(ncclRecv(reinterpret_cast<__nv_fp8_e4m3*>(mhaInput) + cpIdx * elementNum, elementNum,
- ncclInt8, cpIdx, *mCpNcclComm, stream));
+ NCCLCHECK(ncclSend(reinterpret_cast<__nv_fp8_e4m3*>(input) + cpIdx * elementNum, elementNum, ncclInt8,
+ cpIdx, *mCpNcclComm, stream));
+ NCCLCHECK(ncclRecv(reinterpret_cast<__nv_fp8_e4m3*>(buffer) + cpIdx * elementNum, elementNum, ncclInt8,
+ cpIdx, *mCpNcclComm, stream));
}
else
{
NCCLCHECK(ncclSend(
- mhaOutput + cpIdx * elementNum, elementNum, (*getDtypeMap())[mType], cpIdx, *mCpNcclComm, stream));
+ input + cpIdx * elementNum, elementNum, (*getDtypeMap())[mType], cpIdx, *mCpNcclComm, stream));
NCCLCHECK(ncclRecv(
- mhaInput + cpIdx * elementNum, elementNum, (*getDtypeMap())[mType], cpIdx, *mCpNcclComm, stream));
+ buffer + cpIdx * elementNum, elementNum, (*getDtypeMap())[mType], cpIdx, *mCpNcclComm, stream));
}
}
}
@@ -364,15 +494,14 @@ int AttentionOp::ulyssesGenerationPostprocess(
if (mFP8ContextFMHA)
{
invokeCpTransposeToSeqMajor<__nv_fp8_e4m3>(reinterpret_cast<__nv_fp8_e4m3*>(output),
- reinterpret_cast<__nv_fp8_e4m3 const*>(mhaOutput), reinterpret_cast<__nv_fp8_e4m3 const*>(mhaInput),
- partialTokenNum, mCpSize, partialHeads, getHeadSize(), mCpRank, stream);
+ reinterpret_cast<__nv_fp8_e4m3 const*>(input), reinterpret_cast<__nv_fp8_e4m3 const*>(buffer),
+ partialTokenNum, mCpSize, mNumAttnHeads, getHeadSize(), mCpRank, stream);
}
else
{
invokeCpTransposeToSeqMajor<T>(
- (T*) output, mhaOutput, mhaInput, partialTokenNum, mCpSize, partialHeads, getHeadSize(), mCpRank, stream);
+ (T*) output, input, buffer, partialTokenNum, mCpSize, mNumAttnHeads, getHeadSize(), mCpRank, stream);
}
- sync_check_cuda_error(stream);
return 0;
}
@@ -532,8 +661,8 @@ int AttentionOp::getHeadSize(bool checkInit) const
size_t AttentionOp::getWorkspaceSizeForContext(nvinfer1::DataType type, int32_t max_num_seq, int32_t input_seq_length,
int32_t cross_kv_length, int32_t max_num_tokens) const noexcept
{
- int const local_hidden_units_qo = mNumHeads / mCpSize * getHeadSize();
- int const local_hidden_units_kv = mNumKVHeads / mCpSize * getHeadSize();
+ int const local_hidden_units_qo = mNumAttnHeads * getHeadSize();
+ int const local_hidden_units_kv = mNumAttnKVHeads * getHeadSize();
auto const size = tensorrt_llm::runtime::BufferDataType(type).getSize();
@@ -815,9 +944,8 @@ int AttentionOp::mlaGeneration(
tllmRunnerParams.mHeadDimQk = mMLAParams.kv_lora_rank + mMLAParams.qk_rope_head_dim;
tllmRunnerParams.mHeadDimV = mMLAParams.kv_lora_rank;
- auto const num_q_heads = mNumHeads / mCpSize;
+ auto const num_q_heads = mNumAttnHeads;
tllmRunnerParams.mNumHeadsQ = num_q_heads;
- // const auto num_kv_heads_tllm_runner_params = mNumKVHeads / mCpSize;
tllmRunnerParams.mNumHeadsKv = num_kv_heads;
tllmRunnerParams.mNumHeadsQPerKv = num_q_heads / num_kv_heads;
@@ -1007,13 +1135,13 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
int const headSize = getHeadSize();
int const local_hidden_units_qo = mNumHeads * headSize;
- int const local_hidden_units_kv = mNumKVHeads / mCpSize * headSize;
+ int const local_hidden_units_kv = mNumAttnKVHeads * headSize;
PositionEmbeddingType const position_embedding_type = mPositionEmbeddingType;
float const q_scaling = mQScaling;
KVCacheBuffer kv_cache_buffer;
auto const elemSize = mKVCacheQuantMode.hasKvCacheQuant() ? sizeof(int8_t) : sizeof(T);
- auto sizePerToken = mNumKVHeads / mCpSize * headSize * elemSize;
+ auto sizePerToken = mNumAttnKVHeads * headSize * elemSize;
if (useKVCache())
{
if constexpr (std::is_same_v<KVCacheBuffer, KVBlockArray>)
@@ -1252,72 +1380,13 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
// do all-to-all for params.attention_input, need to split on kv head
// [token_num // cp_size, kv_heads, head_size] -> [token_num, kv_heads // cp_size, head_size]
T* attention_input = const_cast<T*>(params.attention_input);
- if (mCpSize > 1)
+ if (mCpSize > 1 && mAttnTpSize > 1 && mAttnCpSize == 1)
{
- int32_t partialTokenNum = 0;
- int32_t maxPartialLength = 0;
+ this->template ulyssesContextPreprocess<T>(
+ attention_input, gatherInBuffer, gatherOutBuffer, params, cu_q_seqlens, cu_cp_partial_seqlens, stream);
for (int32_t batchIdx = 0; batchIdx < params.batch_size; ++batchIdx)
{
int32_t partialLength = (params.host_context_lengths[batchIdx] + mCpSize - 1) / mCpSize;
maxPartialLength = std::max(maxPartialLength, partialLength);
partialTokenNum += partialLength;
}
auto const totalHeads = mNumHeads + 2 * mNumKVHeads;
auto const partialHeads = totalHeads / mCpSize;
auto const partialQHeads = mNumHeads / mCpSize;
auto const partialKVHeads = mNumKVHeads / mCpSize;
// full request: [bs, seqlen, head, headSize]
//
// input of cp: [bs, partialLength, head, headSize]
// view_1 as [bs, partialLength, cpSize_Head, partialHead, headSize]
// transpose_1 as [cpSize_Head, bs, partialLenth, partialHead, headSize]
// all-to-all to get [cpSize_Length, bs, partialLength, partialHead, headSize]
// transpose_2 to [bs, cpSize_Length, partialLength, partialHead, headSize]
// view_2 as [bs, totalLength, partialHead, headSize]
// and this is same to the input under TP.
//
// when we use remove_input_padding, bs and length are fused into numTokens. So, we need to
// insert the cpSize_Length dimension of transpose_2 into numTokens directly like
// input of cp: [partialNumTokens, head, headSize]
// view_1 as [partialNumTokens, cpSize_Head, partialHead, headSize]
// transpose_1 as [cpSize_Head, partialNumTokens, partialHead, headSize]
// all-to-all to get [cpSize_Length, partialNumTokens, partialHead, headSize]
// transpose_2 as [NumTokens, partialHead, headSize]
// and this is same to the input under TP.
// view_1 + transpose_1
invokeCpTranspose(gatherInBuffer, gatherOutBuffer, params.attention_input, partialTokenNum, mCpSize,
partialQHeads, partialKVHeads, getHeadSize(), mCpRank, stream);
sync_check_cuda_error(stream);
// Do all to all
#if ENABLE_MULTI_DEVICE
ncclGroupStart();
for (int cpIdx = 0; cpIdx < mCpSize; cpIdx++)
{
if (cpIdx != mCpRank)
{
NCCLCHECK(ncclSend(gatherInBuffer + cpIdx * (partialTokenNum * getHeadSize() * partialHeads),
(partialTokenNum * getHeadSize() * partialHeads), (*getDtypeMap())[mType], cpIdx, *mCpNcclComm,
stream));
NCCLCHECK(ncclRecv(gatherOutBuffer + cpIdx * (partialTokenNum * getHeadSize() * partialHeads),
(partialTokenNum * getHeadSize() * partialHeads), (*getDtypeMap())[mType], cpIdx, *mCpNcclComm,
stream));
}
}
ncclGroupEnd();
sync_check_cuda_error(stream);
#endif // ENABLE_MULTI_DEVICE
// transpose_2 + view_2
invokeCpTranspose2(gatherInBuffer, gatherOutBuffer, params.context_lengths, cu_q_seqlens,
cu_cp_partial_seqlens, mCpSize, maxPartialLength, params.batch_size, partialHeads, getHeadSize(),
stream);
attention_input = gatherInBuffer;
sync_check_cuda_error(stream);
}
sync_check_cuda_error(stream);
bool const enablePagedKVContextFMHA = mPagedKVCache && mPagedContextFMHA;
TLLM_CHECK_WITH_INFO(!(mKVCacheQuantMode.hasInt8KvCache() && enablePagedKVContextFMHA),
@@ -1363,9 +1432,9 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
preprocessingParams.token_num = params.num_tokens;
preprocessingParams.remove_padding = mRemovePadding;
preprocessingParams.cross_attention = isCrossAttention();
- preprocessingParams.head_num = mNumHeads / mCpSize;
- preprocessingParams.kv_head_num = mNumKVHeads / mCpSize;
- preprocessingParams.qheads_per_kv_head = mNumHeads / mNumKVHeads;
+ preprocessingParams.head_num = mNumAttnHeads;
+ preprocessingParams.kv_head_num = mNumAttnKVHeads;
+ preprocessingParams.qheads_per_kv_head = mNumAttnHeads / mNumAttnKVHeads;
preprocessingParams.size_per_head = getHeadSize();
preprocessingParams.rotary_embedding_dim = mRotaryEmbeddingDim;
preprocessingParams.rotary_embedding_base = mRotaryEmbeddingBase;
@@ -1479,79 +1548,10 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
invokeKvCachePostprocessing(preprocessingParams, stream);
sync_check_cuda_error(stream);
- if (mCpSize > 1)
+ if (mCpSize > 1 && mAttnTpSize > 1 && mAttnCpSize == 1)
{
- // After FMHA, we get result [numTokens(bs, cp, paritalLength), partialHead, headSize]
- // transpose_2_reverse: [cpSize_Length, partialTokens(bs, partialLength), partialHead, headSize]
+ this->template ulyssesContextPostprocess<T>(gatherOutBuffer, reinterpret_cast<T*>(params.context_buf),
+ gatherInBuffer, params, cu_q_seqlens, cu_cp_partial_seqlens, stream);
// all-to-all: [cpSize_Head, partialTokens, partialHead, headSize]
// transpose_1_reverse: [partialTokens, cpSize_Head, partialHead, headSize]
// view: [partialTokens, head, headSize]
int32_t maxPartialLength = 0;
int32_t partialTokenNum = 0;
for (int32_t batchIdx = 0; batchIdx < params.batch_size; ++batchIdx)
{
int32_t partialLength = (params.host_context_lengths[batchIdx] + mCpSize - 1) / mCpSize;
maxPartialLength = std::max(maxPartialLength, partialLength);
partialTokenNum += partialLength;
}
auto partialHeads = mNumHeads / mCpSize;
// transpose_2_reverse
if (mFP8ContextFMHA)
{
invokeCpTransposeToSeqMajor2(reinterpret_cast<__nv_fp8_e4m3*>(gatherInBuffer),
reinterpret_cast<__nv_fp8_e4m3 const*>(gatherOutBuffer), params.context_lengths, cu_q_seqlens,
cu_cp_partial_seqlens, mCpSize, maxPartialLength, params.batch_size, partialHeads, getHeadSize(),
stream);
}
else
{
invokeCpTransposeToSeqMajor2(gatherInBuffer, gatherOutBuffer, params.context_lengths, cu_q_seqlens,
cu_cp_partial_seqlens, mCpSize, maxPartialLength, params.batch_size, partialHeads, getHeadSize(),
stream);
}
// all-to-all
#if ENABLE_MULTI_DEVICE
size_t const elementNum = partialTokenNum * getHeadSize() * partialHeads;
ncclGroupStart();
for (int cpIdx = 0; cpIdx < mCpSize; cpIdx++)
{
if (cpIdx != mCpRank)
{
if (mFP8ContextFMHA)
{
NCCLCHECK(ncclSend(reinterpret_cast<__nv_fp8_e4m3*>(gatherInBuffer) + cpIdx * elementNum,
elementNum, ncclInt8, cpIdx, *mCpNcclComm, stream));
NCCLCHECK(ncclRecv(reinterpret_cast<__nv_fp8_e4m3*>(gatherOutBuffer) + cpIdx * elementNum,
elementNum, ncclInt8, cpIdx, *mCpNcclComm, stream));
}
else
{
NCCLCHECK(ncclSend(gatherInBuffer + cpIdx * elementNum, elementNum, (*getDtypeMap())[mType],
cpIdx, *mCpNcclComm, stream));
NCCLCHECK(ncclRecv(gatherOutBuffer + cpIdx * elementNum, elementNum, (*getDtypeMap())[mType],
cpIdx, *mCpNcclComm, stream));
}
}
}
ncclGroupEnd();
#endif // ENABLE_MULTI_DEVICE
// transpose_1_reverse + view
if (mFP8ContextFMHA)
{
invokeCpTransposeToSeqMajor<__nv_fp8_e4m3>(reinterpret_cast<__nv_fp8_e4m3*>(params.context_buf),
reinterpret_cast<__nv_fp8_e4m3 const*>(gatherInBuffer),
reinterpret_cast<__nv_fp8_e4m3 const*>(gatherOutBuffer), partialTokenNum, mCpSize, partialHeads,
getHeadSize(), mCpRank, stream);
}
else
{
invokeCpTransposeToSeqMajor<T>((T*) params.context_buf, gatherInBuffer, gatherOutBuffer,
partialTokenNum, mCpSize, partialHeads, getHeadSize(), mCpRank, stream);
}
sync_check_cuda_error(stream);
}
}
@@ -1871,7 +1871,7 @@ int AttentionOp::enqueueGeneration(EnqueueGenerationParams<T> const& params, cud
KVCacheBuffer kv_cache_buffer;
auto const elemSize = mKVCacheQuantMode.hasKvCacheQuant() ? sizeof(int8_t) : sizeof(T);
- auto const sizePerToken = mNumKVHeads / mCpSize * headSize * elemSize;
+ auto const sizePerToken = mNumAttnKVHeads * headSize * elemSize;
if (useKVCache())
{
if constexpr (std::is_same_v<KVCacheBuffer, KVBlockArray>)
@@ -1936,7 +1936,12 @@ int AttentionOp::enqueueGeneration(EnqueueGenerationParams<T> const& params, cud
T* mhaInput = mhaOutput + cpMaxPaddedSequenceLength * (mNumHeads + 2 * mNumKVHeads) * mHeadSize;
T* attention_input = const_cast<T*>(params.attention_input);
- this->template ulyssesGenerationPreprocess<T>(batch_beam, mhaInput, mhaOutput, attention_input, stream);
+ if (mCpSize > 1 && mAttnTpSize > 1 && mAttnCpSize == 1)
+ {
+ this->template ulyssesGenerationPreprocess<T>(attention_input, mhaInput, mhaOutput, batch_beam, stream);
+ attention_input = mhaInput;
+ sync_check_cuda_error(stream);
+ }
// Try XQA optimization first.
{
@@ -1954,7 +1959,12 @@ int AttentionOp::enqueueGeneration(EnqueueGenerationParams<T> const& params, cud
xqaParams.qkv = attention_input;
}
mXqaDispatcher->run(xqaParams, kv_cache_buffer);
- this->template ulyssesGenerationPostprocess<T>(batch_beam, mhaInput, mhaOutput, params.context_buf, stream);
+ if (mCpSize > 1 && mAttnTpSize > 1 && mAttnCpSize == 1)
+ {
+ this->template ulyssesGenerationPostprocess<T>(
+ mhaOutput, reinterpret_cast<T*>(params.context_buf), mhaInput, batch_beam, stream);
+ sync_check_cuda_error(stream);
+ }
return 0;
}
else if (mIsSpecDecodingEnabled && mUseSpecDecoding)
@@ -2040,8 +2050,8 @@ int AttentionOp::enqueueGeneration(EnqueueGenerationParams<T> const& params, cud
dispatch_params.max_batch_size = batch_beam;
dispatch_params.inference_batch_size = batch_beam;
dispatch_params.beam_width = params.beam_width;
- dispatch_params.head_num = mNumHeads / mCpSize;
- dispatch_params.kv_head_num = mNumKVHeads / mCpSize;
+ dispatch_params.head_num = mNumAttnHeads;
+ dispatch_params.kv_head_num = mNumAttnKVHeads;
dispatch_params.size_per_head = getHeadSize();
dispatch_params.rotary_embedding_dim = mRotaryEmbeddingDim;
dispatch_params.position_embedding_type = mPositionEmbeddingType;
@@ -2104,7 +2114,12 @@ int AttentionOp::enqueueGeneration(EnqueueGenerationParams<T> const& params, cud
fusedQKV_masked_attention_dispatch(mmhca_params, dispatch_params, stream);
}
- this->template ulyssesGenerationPostprocess<T>(batch_beam, mhaInput, mhaOutput, params.context_buf, stream);
+ if (mCpSize > 1 && mAttnTpSize > 1 && mAttnCpSize == 1)
+ {
+ this->template ulyssesGenerationPostprocess<T>(
+ mhaOutput, reinterpret_cast<T*>(params.context_buf), mhaInput, batch_beam, stream);
+ sync_check_cuda_error(stream);
+ }
return 0;
}
@@ -2164,6 +2179,21 @@ template void AttentionOp::prepareEnqueueGeneration<__nv_bfloat16, KVBlockArray>
int AttentionOp::initialize() noexcept
{
// use Ulysses for GPTAttentionPlugin
if (mAttnTpSize < 0 || mAttnCpSize < 0)
{
mAttnTpSize = mTpSize * mCpSize;
mAttnCpSize = 1;
}
mNumAttnHeads = mNumHeads * mTpSize / mAttnTpSize;
mNumAttnKVHeads = (mNumKVHeads * mTpSize + mAttnTpSize - 1) / mAttnTpSize;
if (mCpSize != mAttnCpSize)
{
// mqa broadcast
mUlyssesMQABroadcast = (mAttnTpSize + mNumKVHeadsOrigin - 1) / mNumKVHeadsOrigin;
}
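
For readers of this hunk, here is a worked example of how these defaults behave when cp_size exceeds the KV head count, which is the case this commit fixes. The numbers and names below are assumed for illustration and are not taken from the commit.

// Hypothetical walkthrough of the head-count arithmetic above.
#include <cassert>

int main()
{
    // Assumed model/parallelism: 32 query heads, 2 KV heads, tp_size = 1, cp_size = 4.
    int const numHeads = 32, numKVHeads = 2, numKVHeadsOrigin = 2;
    int const tpSize = 1, cpSize = 4;

    // attn_tp/attn_cp left at -1 by the checkpoint converters -> use the Ulysses default.
    int const attnTpSize = tpSize * cpSize; // 4
    int const attnCpSize = 1;

    int const numAttnHeads = numHeads * tpSize / attnTpSize;                        // 8
    int const numAttnKVHeads = (numKVHeads * tpSize + attnTpSize - 1) / attnTpSize; // ceil(2/4) = 1
    int const ulyssesMQABroadcast
        = (cpSize != attnCpSize) ? (attnTpSize + numKVHeadsOrigin - 1) / numKVHeadsOrigin : 1; // 2

    // The old code used numKVHeads / cpSize, which is 0 here and breaks kernels and buffer sizing.
    assert(numKVHeads / cpSize == 0);
    assert(numAttnHeads == 8 && numAttnKVHeads == 1 && ulyssesMQABroadcast == 2);
    return 0;
}
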
// Pre-check whether FMHA is supported in order to save memory allocation.
if (mEnableContextFMHA)
{
@@ -2306,8 +2336,8 @@ int AttentionOp::initialize() noexcept
fmhaParams.attentionMaskType = ContextAttentionMaskType::CUSTOM_MASK;
}
fmhaParams.isSPadded = !mRemovePadding;
- fmhaParams.numQHeads = mNumHeads / mCpSize;
- fmhaParams.numKvHeads = mNumKVHeads / mCpSize;
+ fmhaParams.numQHeads = mNumAttnHeads;
+ fmhaParams.numKvHeads = mNumAttnKVHeads;
fmhaParams.numTokensPerBlock = mTokensPerBlock;
fmhaParams.headSize = mHeadSize;
fmhaParams.headSizeV = mHeadSize;
@@ -2326,16 +2356,6 @@ int AttentionOp::initialize() noexcept
fmhaParams.attnLogitSoftcappingScale = mAttnLogitSoftcappingScale;
fmhaParams.hasAlibi = isALiBi();
fmhaParams.scaleAlibi = isAliBiWithScale();
- if (mTpSize > 1)
- {
- fmhaParams.tpSize = mTpSize;
- fmhaParams.tpRank = mTpRank;
- }
- else if (mCpSize > 1)
- {
- fmhaParams.tpSize = mCpSize;
- fmhaParams.tpRank = mCpRank;
- }
// Load kernels from the pre-compiled cubins.
mFmhaDispatcher.reset(new FmhaDispatcher(fmhaParams));
@@ -2479,8 +2499,8 @@ int AttentionOp::initialize() noexcept
TLLM_CHECK_WITH_INFO(mNumHeads % mNumKVHeads == 0, "mNumHeads should be multiples of mNumKVHeads.");
TLLM_CHECK_WITH_INFO(!mMultiBlockMode, "Medusa doesn't support multi-block mode.");
}
- fixedParams.numQHeads = mNumHeads / mCpSize;
- fixedParams.numKvHeads = mNumKVHeads / mCpSize;
+ fixedParams.numQHeads = mNumAttnHeads;
+ fixedParams.numKvHeads = mNumAttnKVHeads;
fixedParams.numTokensPerBlock = mTokensPerBlock;
fixedParams.headSize = mHeadSize;
fixedParams.qScaling = mQScaling;
@@ -2555,6 +2575,7 @@ std::string AttentionOp::toString() const
ss << "gptAttentionCommon members ====================" << std::endl;
ss << "mNumHeads: " << mNumHeads << std::endl;
ss << "mNumKVHeads: " << mNumKVHeads << std::endl;
+ ss << "mNumKVHeadsOrigin: " << mNumKVHeadsOrigin << std::endl;
ss << "mHeadSize: " << mHeadSize << std::endl;
ss << "mUnidirectional: " << mUnidirectional << std::endl;
ss << "mQScaling: " << mQScaling << std::endl;

View File

@@ -229,10 +229,18 @@
EnqueueGenerationParams<T> const& generationsParams, bool forConfigurePlugin);
template <typename T>
- int ulyssesGenerationPreprocess(int32_t batch_beam, T* mhaInput, T* mhaOutput, T*& input, cudaStream_t stream);
+ int ulyssesContextPreprocess(T const* input, T* output, T* buffer, EnqueueContextParams<T> const& params,
+ int const* cu_q_seqlens, int const* cu_cp_partial_seqlens, cudaStream_t stream);
template <typename T>
- int ulyssesGenerationPostprocess(int32_t batch_beam, T* mhaInput, T* mhaOutput, void* output, cudaStream_t stream);
+ int ulyssesContextPostprocess(T* input, T* output, T* buffer, EnqueueContextParams<T> const& params,
+ int const* cu_q_seqlens, int const* cu_cp_partial_seqlens, cudaStream_t stream);
+ template <typename T>
+ int ulyssesGenerationPreprocess(T const* input, T* output, T* buffer, int32_t batch_beam, cudaStream_t stream);
+ template <typename T>
+ int ulyssesGenerationPostprocess(T* input, T* output, T* buffer, int32_t batch_beam, cudaStream_t stream);
[[nodiscard]] bool isRelativePosition() const
{
@@ -374,6 +382,16 @@ public:
int mCpSize = 1;
int mCpRank = 0;
std::set<int32_t> mCpGroup = {};
// These parameters configure the attention attributes when the cp/tp sizes differ
// between the attention and FFN parts (such as with Ulysses).
int mNumAttnHeads = -1;
int mNumAttnKVHeads = -1;
int mNumKVHeadsOrigin = -1;
int mAttnTpSize = -1;
int mAttnTpRank = 0;
int mAttnCpSize = -1;
int mAttnCpRank = 0;
int mUlyssesMQABroadcast = 1;
// fmha runner (enabled by default)
// flag: disabled = 0, enabled = 1, enabled with fp32 accumulation = 2
@@ -403,8 +421,9 @@ public:
mPosShiftEnabled, mPagedContextFMHA, mFP8ContextFMHA, mDenseContextFMHA, mHasFullAttentionMask,
mIsSpecDecodingEnabled, mUseSpecDecoding, mSpecDecodingIsGenerationLengthVariable,
mSpecDecodingMaxGenerationLength, mIsMLAEnabled, mUseFlashMLA, mMLAParams.data(), mCpSize, mCpRank,
- mCpGroup, mEnableContextFMHA, mFMHAForceFP32Acc, mMultiBlockMode, mEnableXQA, mUseKVCache, mSkipAttn,
- mFuseFp4Quant, mNbMultiBlockSemaphores);
+ mCpGroup, mNumAttnHeads, mNumAttnKVHeads, mNumKVHeadsOrigin, mAttnTpSize, mAttnTpRank, mAttnCpSize,
+ mAttnCpRank, mUlyssesMQABroadcast, mEnableContextFMHA, mFMHAForceFP32Acc, mMultiBlockMode, mEnableXQA,
+ mUseKVCache, mSkipAttn, mFuseFp4Quant, mNbMultiBlockSemaphores);
};
private:

View File

@@ -2139,7 +2139,7 @@ __global__ void convertData(Dst* dst, Src const* src, int64_t size, float const*
template <typename T>
__global__ void runCpTranspose(T* dst, T* dst2, T const* src, int64_t partialTokenNum, int64_t cpSize,
- int64_t partialQHeads, int64_t partialKVHeads, int64_t headSize, int64_t rank)
+ int64_t partialQHeads, int64_t partialKVHeads, int64_t mqaBroadcast, int64_t headSize, int64_t rank)
{
// Do transpose from
// [partialTokenNum, mNumHeads + 2*mNumKVHeads, headSize]
@@ -2164,12 +2164,12 @@ __global__ void runCpTranspose(T* dst, T* dst2, T const* src, int64_t partialTok
}
else if (headIdx < partialQHeads + partialKVHeads)
{
- srcHeadIdx = cpSize * partialQHeads + cpIdx * partialKVHeads + (headIdx - partialQHeads);
+ srcHeadIdx = cpSize * partialQHeads + cpIdx / mqaBroadcast * partialKVHeads + (headIdx - partialQHeads);
}
else
{
- srcHeadIdx = cpSize * partialQHeads + cpSize * partialKVHeads + cpIdx * partialKVHeads
- + (headIdx - partialQHeads - partialKVHeads);
+ srcHeadIdx = cpSize * partialQHeads + cpSize / mqaBroadcast * partialKVHeads
+ + cpIdx / mqaBroadcast * partialKVHeads + (headIdx - partialQHeads - partialKVHeads);
}
if (cpIdx == rank)
@@ -2181,7 +2181,7 @@ __global__ void runCpTranspose(T* dst, T* dst2, T const* src, int64_t partialTok
+ seqIdx * (partialQHeads + 2 * partialKVHeads) + headIdx)
* headSize);
VecType const* in = reinterpret_cast<VecType const*>(
- src + (seqIdx * (partialQHeads + 2 * partialKVHeads) * cpSize + srcHeadIdx) * headSize);
+ src + (seqIdx * (partialQHeads * cpSize + 2 * partialKVHeads * cpSize / mqaBroadcast) + srcHeadIdx) * headSize);
for (int hiddenIdx = threadIdx.x; hiddenIdx < hiddenSize + hiddenRestSize; hiddenIdx += blockDim.x)
{
@@ -2347,17 +2347,18 @@ INSTANTIATE_invokeConversion(__nv_fp8_e4m3, __nv_bfloat16);
template <typename T>
void invokeCpTranspose(T* dst, T* dst2, T const* src, int64_t partialTokenNum, int64_t cpSize, int64_t partialQHeads,
- int64_t partialKVHeads, int64_t headSize, int64_t rank, cudaStream_t stream)
+ int64_t partialKVHeads, int64_t mqaBroadcast, int64_t headSize, int64_t rank, cudaStream_t stream)
{
dim3 grid(partialTokenNum, cpSize, partialQHeads + 2 * partialKVHeads);
dim3 block(128);
runCpTranspose<T><<<grid, block, 0, stream>>>(
- dst, dst2, src, partialTokenNum, cpSize, partialQHeads, partialKVHeads, headSize, rank);
+ dst, dst2, src, partialTokenNum, cpSize, partialQHeads, partialKVHeads, mqaBroadcast, headSize, rank);
}
#define INSTANTIATE_invokeCpTranspose(T) \
template void invokeCpTranspose<T>(T * dst, T * dst2, T const* src, int64_t partialLength, int64_t cpSize, \
- int64_t partialQHeads, int64_t partialKVHeads, int64_t headSize, int64_t rank, cudaStream_t stream)
+ int64_t partialQHeads, int64_t partialKVHeads, int64_t mqaBroadcast, int64_t headSize, int64_t rank, \
+ cudaStream_t stream)
INSTANTIATE_invokeCpTranspose(float);
INSTANTIATE_invokeCpTranspose(half);
INSTANTIATE_invokeCpTranspose(__nv_bfloat16);
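
A small host-side illustration of what the new mqaBroadcast argument does in the K-head branch of runCpTranspose: with a broadcast factor above one, groups of CP ranks read the same source KV head, which is what allows cp_size to exceed the KV head count. The sizes below are hypothetical and only replicate the index formula shown in this hunk.

// Hypothetical check of the K-head source index formula from runCpTranspose.
#include <cassert>
#include <cstdint>

int main()
{
    // Assumed sizes: cpSize = 4, 2 partial Q heads, 1 partial KV head, mqaBroadcast = 2.
    int64_t const cpSize = 4, partialQHeads = 2, partialKVHeads = 1, mqaBroadcast = 2;
    int64_t const headIdx = partialQHeads; // first K head in the per-rank layout
    for (int64_t cpIdx = 0; cpIdx < cpSize; ++cpIdx)
    {
        int64_t const srcHeadIdx
            = cpSize * partialQHeads + cpIdx / mqaBroadcast * partialKVHeads + (headIdx - partialQHeads);
        // Ranks {0,1} map to source head 8 and ranks {2,3} to source head 9,
        // i.e. each physical KV head is replicated across mqaBroadcast ranks.
        assert(srcHeadIdx == cpSize * partialQHeads + cpIdx / mqaBroadcast);
    }
    return 0;
}
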

View File

@@ -407,7 +407,7 @@ void invokeConversion(Dst* dst, Src const* src, int64_t size, float const* __res
template <typename T>
void invokeCpTranspose(T* dst, T* dst2, T const* src, int64_t partialLength, int64_t cpSize, int64_t partialQHeads,
- int64_t partialKVHeads, int64_t headSize, int64_t rank, cudaStream_t stream);
+ int64_t partialKVHeads, int64_t mqaBroadcast, int64_t headSize, int64_t rank, cudaStream_t stream);
template <typename T>
void invokeCpTransposeToSeqMajor(T* dst, T const* srcMyRank, T const* srcOtherRank, int64_t partialLength,

View File

@@ -28,8 +28,8 @@ using tensorrt_llm::plugins::GPTAttentionPluginCreatorCommon;
using tensorrt_llm::plugins::GPTAttentionPluginCommon;
GPTAttentionPluginCommon::GPTAttentionPluginCommon(int layer_idx, int num_heads, int vision_start, int vision_length,
- int num_kv_heads, int head_size, int unidirectional, float q_scaling, float attn_logit_softcapping_scale,
- tensorrt_llm::kernels::PositionEmbeddingType position_embedding_type,
+ int num_kv_heads, int num_kv_heads_origin, int head_size, int unidirectional, float q_scaling,
+ float attn_logit_softcapping_scale, tensorrt_llm::kernels::PositionEmbeddingType position_embedding_type,
int rotary_embedding_dim, // for RoPE. Use 0 for non-RoPE
float rotary_embedding_base, tensorrt_llm::kernels::RotaryScalingType rotary_embedding_scale_type,
float rotary_embedding_scale, float rotary_embedding_short_m_scale, float rotary_embedding_long_m_scale,
@@ -52,6 +52,7 @@ GPTAttentionPluginCommon::GPTAttentionPluginCommon(int layer_idx, int num_heads,
mVisionStart = vision_start;
mVisionLength = vision_length;
mNumKVHeads = num_kv_heads;
+ mNumKVHeadsOrigin = num_kv_heads_origin;
mHeadSize = head_size;
mUnidirectional = unidirectional;
mQScaling = q_scaling;
@@ -114,6 +115,7 @@ GPTAttentionPluginCommon::GPTAttentionPluginCommon(void const* data, size_t leng
read(d, mVisionStart);
read(d, mVisionLength);
read(d, mNumKVHeads);
+ read(d, mNumKVHeadsOrigin);
read(d, mHeadSize);
read(d, mUnidirectional);
read(d, mQScaling);
@@ -201,13 +203,13 @@ void GPTAttentionPluginCommon::destroy() noexcept
size_t GPTAttentionPluginCommon::getCommonSerializationSize() const noexcept
{
return sizeof(mLayerIdx) + sizeof(mNumHeads) + +sizeof(mVisionStart) + sizeof(mVisionLength) + sizeof(mNumKVHeads)
- + sizeof(mHeadSize) + sizeof(mUnidirectional) + sizeof(mQScaling) + sizeof(mAttnLogitSoftcappingScale)
- + sizeof(mPositionEmbeddingType) + sizeof(mRotaryEmbeddingDim) + sizeof(mRotaryEmbeddingBase)
- + sizeof(mRotaryEmbeddingScaleType) + sizeof(mRotaryEmbeddingScale) + sizeof(mRotaryEmbeddingShortMscale)
- + sizeof(mRotaryEmbeddingLongMscale) + sizeof(mRotaryEmbeddingMaxPositions)
- + sizeof(mRotaryEmbeddingOriginalMaxPositions) + sizeof(mTpSize) + sizeof(mTpRank) + sizeof(mEnableContextFMHA)
- + sizeof(mFMHAForceFP32Acc) + sizeof(mMultiBlockMode) + sizeof(mEnableXQA)
- + sizeof(unsigned int) // mKVCacheQuantMode
+ + sizeof(mNumKVHeadsOrigin) + sizeof(mHeadSize) + sizeof(mUnidirectional) + sizeof(mQScaling)
+ + sizeof(mAttnLogitSoftcappingScale) + sizeof(mPositionEmbeddingType) + sizeof(mRotaryEmbeddingDim)
+ + sizeof(mRotaryEmbeddingBase) + sizeof(mRotaryEmbeddingScaleType) + sizeof(mRotaryEmbeddingScale)
+ + sizeof(mRotaryEmbeddingShortMscale) + sizeof(mRotaryEmbeddingLongMscale)
+ + sizeof(mRotaryEmbeddingMaxPositions) + sizeof(mRotaryEmbeddingOriginalMaxPositions) + sizeof(mTpSize)
+ + sizeof(mTpRank) + sizeof(mEnableContextFMHA) + sizeof(mFMHAForceFP32Acc) + sizeof(mMultiBlockMode)
+ + sizeof(mEnableXQA) + sizeof(unsigned int) // mKVCacheQuantMode
+ sizeof(mRemovePadding) + sizeof(mMaskType) + sizeof(mBlockSparseParams) + sizeof(mPagedKVCache)
+ sizeof(mTokensPerBlock) + sizeof(mType) + sizeof(mMaxContextLength) + sizeof(mQKVBiasEnabled)
+ sizeof(mCrossAttention) + sizeof(mMaxDistance) + sizeof(mPosShiftEnabled) + sizeof(mDenseContextFMHA)
@@ -228,6 +230,7 @@ void GPTAttentionPluginCommon::serializeCommon(void* buffer) const noexcept
write(d, mVisionStart);
write(d, mVisionLength);
write(d, mNumKVHeads);
+ write(d, mNumKVHeadsOrigin);
write(d, mHeadSize);
write(d, mUnidirectional);
write(d, mQScaling);
@@ -307,6 +310,7 @@ GPTAttentionPluginCreatorCommon::GPTAttentionPluginCreatorCommon()
mPluginAttributes.emplace_back(PluginField("vision_start", nullptr, PluginFieldType::kINT32));
mPluginAttributes.emplace_back(PluginField("vision_length", nullptr, PluginFieldType::kINT32));
mPluginAttributes.emplace_back(PluginField("num_kv_heads", nullptr, PluginFieldType::kINT32));
+ mPluginAttributes.emplace_back(PluginField("num_kv_heads_origin", nullptr, PluginFieldType::kINT32));
mPluginAttributes.emplace_back(PluginField("layer_idx_in_cache_pool", nullptr, PluginFieldType::kINT32));
mPluginAttributes.emplace_back(PluginField("head_size", nullptr, PluginFieldType::kINT32));
mPluginAttributes.emplace_back(PluginField("unidirectional", nullptr, PluginFieldType::kINT32));

View File

@@ -35,7 +35,7 @@ public:
GPTAttentionPluginCommon() = delete;
GPTAttentionPluginCommon(int layer_idx, int num_heads, int vision_start, int vision_length, int num_kv_heads,
- int head_size, int unidirectional, float q_scaling, float attn_logit_softcapping_scale,
+ int num_kv_heads_origin, int head_size, int unidirectional, float q_scaling, float attn_logit_softcapping_scale,
tensorrt_llm::kernels::PositionEmbeddingType position_embedding_type,
int rotary_embedding_dim, // for RoPE. Use 0 for non-RoPE
float rotary_embedding_base, tensorrt_llm::kernels::RotaryScalingType rotary_embedding_scale_type,

View File

@@ -46,8 +46,8 @@ static char const* GPT_ATTENTION_PLUGIN_VERSION{"1"};
static char const* GPT_ATTENTION_PLUGIN_NAME{"GPTAttention"};
GPTAttentionPlugin::GPTAttentionPlugin(int layer_idx, int num_heads, int vision_start, int vision_length,
- int num_kv_heads, int head_size, int unidirectional, float q_scaling, float attn_logit_softcapping_scale,
- tensorrt_llm::kernels::PositionEmbeddingType position_embedding_type,
+ int num_kv_heads, int num_kv_heads_origin, int head_size, int unidirectional, float q_scaling,
+ float attn_logit_softcapping_scale, tensorrt_llm::kernels::PositionEmbeddingType position_embedding_type,
int rotary_embedding_dim, // for RoPE. 0 for non-RoPE
float rotary_embedding_base, tensorrt_llm::kernels::RotaryScalingType rotary_embedding_scale_type,
float rotary_embedding_scale, float rotary_embedding_short_m_scale,
@@ -65,17 +65,17 @@ GPTAttentionPlugin::GPTAttentionPlugin(int layer_idx, int num_heads, int vision_
int spec_decoding_max_generation_length, bool is_mla_enabled, int q_lora_rank, int kv_lora_rank,
int qk_nope_head_dim, int qk_rope_head_dim, int v_head_dim, bool fuse_fp4_quant, bool skip_attn, int cp_size,
int cp_rank, std::set<int32_t> cp_group)
- : GPTAttentionPluginCommon(layer_idx, num_heads, vision_start, vision_length, num_kv_heads, head_size,
- unidirectional, q_scaling, attn_logit_softcapping_scale, position_embedding_type, rotary_embedding_dim,
- rotary_embedding_base, rotary_embedding_scale_type, rotary_embedding_scale, rotary_embedding_short_m_scale,
- rotary_embedding_long_m_scale, rotary_embedding_max_positions, rotary_embedding_original_max_positions, tp_size,
- tp_rank, unfuse_qkv_gemm, use_logn_scaling, context_fmha_type, kv_cache_quant_mode, remove_input_padding,
- mask_type, block_sparse_params, paged_kv_cache, tokens_per_block, type, max_context_length, qkv_bias_enabled,
- cross_attention, max_distance, pos_shift_enabled, dense_context_fmha, use_paged_context_fmha,
- use_fp8_context_fmha, has_full_attention_mask, use_cache, is_spec_decoding_enabled,
- spec_decoding_is_generation_length_variable, spec_decoding_max_generation_length, is_mla_enabled, q_lora_rank,
- kv_lora_rank, qk_nope_head_dim, qk_rope_head_dim, v_head_dim, fuse_fp4_quant, skip_attn, cp_size, cp_rank,
- cp_group)
+ : GPTAttentionPluginCommon(layer_idx, num_heads, vision_start, vision_length, num_kv_heads, num_kv_heads_origin,
+ head_size, unidirectional, q_scaling, attn_logit_softcapping_scale, position_embedding_type,
+ rotary_embedding_dim, rotary_embedding_base, rotary_embedding_scale_type, rotary_embedding_scale,
+ rotary_embedding_short_m_scale, rotary_embedding_long_m_scale, rotary_embedding_max_positions,
+ rotary_embedding_original_max_positions, tp_size, tp_rank, unfuse_qkv_gemm, use_logn_scaling, context_fmha_type,
+ kv_cache_quant_mode, remove_input_padding, mask_type, block_sparse_params, paged_kv_cache, tokens_per_block,
+ type, max_context_length, qkv_bias_enabled, cross_attention, max_distance, pos_shift_enabled,
+ dense_context_fmha, use_paged_context_fmha, use_fp8_context_fmha, has_full_attention_mask, use_cache,
+ is_spec_decoding_enabled, spec_decoding_is_generation_length_variable, spec_decoding_max_generation_length,
+ is_mla_enabled, q_lora_rank, kv_lora_rank, qk_nope_head_dim, qk_rope_head_dim, v_head_dim, fuse_fp4_quant,
+ skip_attn, cp_size, cp_rank, cp_group)
{
TLLM_CHECK_WITH_INFO(
!is_mla_enabled, "GPTAttentionPlugin no longer supports MLA. Please use the PyTorch workflow instead.");
@@ -890,7 +890,8 @@ int GPTAttentionPlugin::enqueueSome(int32_t seqIdxBeg, int32_t localNbSeq, int32
auto const cacheElemSize = (mKVCacheQuantMode.hasKvCacheQuant() ? 1 : sizeof(T));
- auto const blockSize = mTokensPerBlock * mNumKVHeads / mCpSize * mHeadSize;
+ auto const kv_cache_head_num = (mNumKVHeads + mCpSize - 1) / mCpSize;
+ auto const blockSize = mTokensPerBlock * kv_cache_head_num * mHeadSize;
auto const bytesPerBlock = blockSize * cacheElemSize;
auto const layerOffset = layerIdxInCachePool * 2 * bytesPerBlock;
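
This ceil-div is the heart of the cp > kvHeadNum fix. A hypothetical check (numbers assumed, not taken from the commit) of why the old integer division broke:

// Hypothetical illustration of the per-block KV cache sizing with cp_size > KV head count.
#include <cassert>

int main()
{
    int const numKVHeads = 2, cpSize = 4, tokensPerBlock = 64, headSize = 128;

    int const oldHeadNum = numKVHeads / cpSize;                     // 0: zero-sized KV cache blocks
    int const kvCacheHeadNum = (numKVHeads + cpSize - 1) / cpSize;  // 1: at least one KV head per rank

    assert(oldHeadNum == 0);
    assert(tokensPerBlock * kvCacheHeadNum * headSize == 64 * 128);
    return 0;
}
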
@@ -1311,8 +1312,9 @@ IPluginV2* GPTAttentionPluginCreator::createPlugin(char const* name, PluginField
auto* obj = new GPTAttentionPlugin(p.getScalar<int32_t>("layer_idx").value(),
p.getScalar<int32_t>("num_heads").value(), p.getScalar<int32_t>("vision_start").value(),
p.getScalar<int32_t>("vision_length").value(), p.getScalar<int32_t>("num_kv_heads").value(),
- p.getScalar<int32_t>("head_size").value(), p.getScalar<int32_t>("unidirectional").value(),
- p.getScalar<float>("q_scaling").value(), p.getScalar<float>("attn_logit_softcapping_scale").value(),
+ p.getScalar<int32_t>("num_kv_heads_origin").value(), p.getScalar<int32_t>("head_size").value(),
+ p.getScalar<int32_t>("unidirectional").value(), p.getScalar<float>("q_scaling").value(),
+ p.getScalar<float>("attn_logit_softcapping_scale").value(),
static_cast<PositionEmbeddingType>(p.getScalar<int8_t>("position_embedding_type").value()),
p.getScalar<int32_t>("rotary_embedding_dim").value(), p.getScalar<float>("rotary_embedding_base").value(),
static_cast<RotaryScalingType>(p.getScalar<int8_t>("rotary_embedding_scale_type").value()),

View File

@@ -101,7 +101,7 @@ class GPTAttentionPlugin : public GPTAttentionPluginCommon
{
public:
GPTAttentionPlugin(int layer_idx, int num_heads, int vision_start, int vision_length, int num_kv_heads,
- int head_size, int unidirectional, float q_scaling, float attn_logit_softcapping_scale,
+ int num_kv_heads_origin, int head_size, int unidirectional, float q_scaling, float attn_logit_softcapping_scale,
tensorrt_llm::kernels::PositionEmbeddingType position_embedding_type,
int rotary_embedding_dim, // for RoPE. 0 for non-RoPE
float rotary_embedding_base, tensorrt_llm::kernels::RotaryScalingType rotary_embedding_scale_type,

View File

@@ -236,6 +236,8 @@ def convert_and_save_hf(args):
load_model_on_cpu=args.load_model_on_cpu,
**override_fields)
glm.config.mapping.cp_size = args.cp_size
+ glm.config.mapping.attn_tp_size = -1
+ glm.config.mapping.attn_cp_size = -1
glm.config.mapping.world_size *= args.cp_size
glm.save_checkpoint(args.output_dir, save_config=(rank == 0))
del glm

View File

@@ -393,6 +393,8 @@ def convert_and_save_meta(args, rank):
use_parallel_embedding=args.use_parallel_embedding,
embedding_sharding_dim=args.embedding_sharding_dim)
llama.config.mapping.cp_size = args.cp_size
+ llama.config.mapping.attn_tp_size = -1
+ llama.config.mapping.attn_cp_size = -1
llama.config.mapping.world_size *= args.cp_size
llama.save_checkpoint(args.output_dir, save_config=(rank == 0))
@@ -511,6 +513,8 @@ def convert_and_save_hf(args):
f'Total time of reading and converting: {time.time()-tik:.3f} s'
)
llama.config.mapping.cp_size = args.cp_size
+ llama.config.mapping.attn_tp_size = -1
+ llama.config.mapping.attn_cp_size = -1
llama.config.mapping.world_size *= args.cp_size
tik = time.time()
llama.save_checkpoint(args.output_dir, save_config=(rank == 0))

View File

@@ -281,6 +281,8 @@ def convert_and_save_hf(args):
quant_config=quant_config,
**override_fields)
qwen.config.mapping.cp_size = args.cp_size
+ qwen.config.mapping.attn_tp_size = -1
+ qwen.config.mapping.attn_cp_size = -1
qwen.config.mapping.world_size *= args.cp_size
qwen.save_checkpoint(args.output_dir, save_config=(rank == 0))
del qwen


@@ -5213,6 +5213,7 @@ def gpt_attention(
     cp_group: List[int] = [0],
     cp_size: int = 1,
     cp_rank: int = 0,
+    num_kv_heads_origin: int = -1,
 ) -> Tuple[Tensor, Optional[Tensor]]:
     '''
     Add an operation that performs the multi-head attention in GPT-like models.
@@ -5474,6 +5475,9 @@
         skip_attn: Tensor = None,
             A bool tensor on CPU. If it is true, don't run attention plugin, returning directly.
+        num_kv_heads_origin: int
+            The origin number of KV heads, without the process of TP
     Returns:
         The tensor produced by that layer.
     '''
@@ -5505,6 +5509,9 @@
     else:
         use_logn_scaling = 0
+    if num_kv_heads_origin < 1:
+        num_kv_heads_origin = num_kv_heads
     unfuse_qkv_gemm = trt.PluginField(
         "unfuse_qkv_gemm", np.array(np.int8(is_unfuse_qkv_gemm), dtype=np.int8),
         trt.PluginFieldType.INT8)
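For illustration only, the default handling added to gpt_attention boils down to the following standalone sketch (resolve_num_kv_heads_origin is a made-up name, not plugin code):

# Standalone sketch of the num_kv_heads_origin default above: a value below 1
# means "not provided" and mirrors the (possibly already TP-split) num_kv_heads.
def resolve_num_kv_heads_origin(num_kv_heads, num_kv_heads_origin=-1):
    if num_kv_heads_origin < 1:
        num_kv_heads_origin = num_kv_heads
    return num_kv_heads_origin

assert resolve_num_kv_heads_origin(4) == 4        # default mirrors num_kv_heads
assert resolve_num_kv_heads_origin(4, 32) == 32   # explicit pre-TP count wins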
@@ -5523,6 +5530,9 @@
     num_kv_heads = trt.PluginField("num_kv_heads",
                                    np.array(num_kv_heads, dtype=np.int32),
                                    trt.PluginFieldType.INT32)
+    num_kv_heads_origin = trt.PluginField(
+        "num_kv_heads_origin", np.array(num_kv_heads_origin, dtype=np.int32),
+        trt.PluginFieldType.INT32)
     head_size = trt.PluginField("head_size",
                                 np.array(hidden_size_per_head, dtype=np.int32),
                                 trt.PluginFieldType.INT32)
@@ -5714,9 +5724,10 @@
         trt.PluginFieldType.INT8)
     pfc = trt.PluginFieldCollection([
-        layer_idx, nheads, vision_start, vision_length, num_kv_heads, head_size,
-        unidirectional, q_scaling, attn_logit_softcapping_scale,
-        position_embedding_type, rotary_embedding_dim, rotary_embedding_base,
+        layer_idx, nheads, vision_start, vision_length, num_kv_heads,
+        num_kv_heads_origin, head_size, unidirectional, q_scaling,
+        attn_logit_softcapping_scale, position_embedding_type,
+        rotary_embedding_dim, rotary_embedding_base,
         rotary_embedding_scale_type, rotary_embedding_scale,
         rotary_embedding_short_m_scale, rotary_embedding_long_m_scale,
         rotary_embedding_max_positions, rotary_embedding_original_max_positions,


@@ -395,13 +395,13 @@ class Attention(Module):
         self.cross_attention = cross_attention
         self.attention_mask_type = attention_mask_type
         self.attention_head_size = hidden_size // num_attention_heads if attention_head_size is None else attention_head_size
-        self.num_kv_heads = num_kv_heads
         assert num_attention_heads % tp_size == 0, \
             "num_attention_heads must be divisible by tp_size"
         self.num_attention_heads = num_attention_heads // tp_size
         self.num_attention_kv_heads = (
             num_kv_heads + tp_size - 1
         ) // tp_size if num_kv_heads is not None else self.num_attention_heads
+        self.num_kv_heads = num_kv_heads if num_kv_heads is not None else self.num_attention_heads
         self.hidden_size = hidden_size
         self.attention_hidden_size = self.attention_head_size * self.num_attention_heads
         self.max_position_embeddings = max_position_embeddings
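The head bookkeeping above can be summarized by a standalone sketch (split_attention_heads is a hypothetical helper mirroring the assignments in Attention.__init__, not code from this change):

# Per-rank Q heads, per-rank KV heads, and the retained "origin" KV head count.
# When num_kv_heads is None (plain MHA) the fix falls back to the per-rank head
# count instead of leaving None behind.
def split_attention_heads(num_attention_heads, num_kv_heads, tp_size):
    assert num_attention_heads % tp_size == 0, \
        "num_attention_heads must be divisible by tp_size"
    heads_per_rank = num_attention_heads // tp_size
    kv_heads_per_rank = ((num_kv_heads + tp_size - 1) // tp_size
                         if num_kv_heads is not None else heads_per_rank)
    kv_heads_origin = (num_kv_heads
                       if num_kv_heads is not None else heads_per_rank)
    return heads_per_rank, kv_heads_per_rank, kv_heads_origin

print(split_attention_heads(32, None, 4))  # (8, 8, 8)   MHA fallback
print(split_attention_heads(32, 2, 4))     # (8, 1, 2)   GQA with 2 KV heads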
@@ -1058,6 +1058,7 @@ class Attention(Module):
             layer_idx=self.local_layer_idx,
             num_heads=self.num_attention_heads,
             num_kv_heads=self.num_attention_kv_heads,
+            num_kv_heads_origin=self.num_kv_heads,
             hidden_size_per_head=self.attention_head_size,
             q_scaling=self.q_scaling,
             rotary_embedding_dim=self.rotary_embedding_dim,
@@ -1866,6 +1867,7 @@ class CogVLMAttention(Attention):
             layer_idx=self.local_layer_idx,
             num_heads=self.num_attention_heads,
             num_kv_heads=self.num_attention_kv_heads,
+            num_kv_heads_origin=self.num_kv_heads,
             hidden_size_per_head=self.attention_head_size,
             q_scaling=self.q_scaling,
             position_embedding_type=self.position_embedding_type,
@@ -2216,6 +2218,7 @@ class DeepseekV2Attention(Attention):
             layer_idx=self.local_layer_idx,
             num_heads=self.num_attention_heads,
             num_kv_heads=1,
+            num_kv_heads_origin=1,
             hidden_size_per_head=self.kv_lora_rank + self.qk_rope_head_dim,
             q_scaling=self.q_scaling,
             position_embedding_type=self.position_embedding_type,


@@ -123,6 +123,8 @@ class Mapping(object):
                  pp_size=1,
                  moe_tp_size=-1,  # -1 means no moe
                  moe_ep_size=-1,  # -1 means no moe
+                 attn_tp_size=-1,
+                 attn_cp_size=-1,
                  auto_parallel=False,
                  enable_attention_dp=False):
         # set default values for non-moe cases
@@ -137,6 +139,22 @@
         elif moe_ep_size == -1:
             moe_ep_size = tp_size // moe_tp_size
+        if attn_tp_size == -1 and attn_cp_size == -1:
+            # fallback to ulysses
+            attn_tp_size = tp_size * cp_size
+            attn_cp_size = 1
+        elif attn_tp_size == -1:
+            attn_tp_size = cp_size * tp_size // attn_cp_size
+        elif attn_cp_size == -1:
+            attn_cp_size = cp_size * tp_size // attn_tp_size
+        if attn_cp_size != 1:
+            raise ValueError(
+                f"attn_cp_size must be 1 for now, but got {attn_tp_size}, {attn_cp_size}."
+            )
         if auto_parallel:
             if tp_size != 1 or pp_size != 1 or tp_size != 1:
                 raise ValueError(
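Taken out of the class, the fallback above resolves the attention split like this (illustrative sketch with the same arithmetic; resolve_attn_sizes is a made-up name):

# Resolve the attention parallel split from the model-parallel sizes.
def resolve_attn_sizes(tp_size, cp_size, attn_tp_size=-1, attn_cp_size=-1):
    if attn_tp_size == -1 and attn_cp_size == -1:
        # Ulysses-style fallback: attention treats all tp*cp ranks as TP ranks.
        attn_tp_size, attn_cp_size = tp_size * cp_size, 1
    elif attn_tp_size == -1:
        attn_tp_size = cp_size * tp_size // attn_cp_size
    elif attn_cp_size == -1:
        attn_cp_size = cp_size * tp_size // attn_tp_size
    if attn_cp_size != 1:
        raise ValueError("attn_cp_size must be 1 for now")
    return attn_tp_size, attn_cp_size

print(resolve_attn_sizes(tp_size=2, cp_size=4))   # (8, 1)
print(resolve_attn_sizes(2, 4, attn_tp_size=8))   # (8, 1)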
@@ -154,6 +172,12 @@
                     f"tp_size must equal to moe_tp_size * moe_ep_size, but got {tp_size} != {moe_tp_size} * {moe_ep_size}"
                 )
+        attn_tp_cp_size = attn_tp_size * attn_cp_size
+        if attn_tp_cp_size != tp_size * cp_size:
+            raise ValueError(
+                f"tp_size * cp_size must equal to attn_tp_size * attn_cp_size, but got {tp_size} * {cp_size} != {attn_tp_size} * {attn_cp_size}"
+            )
         if moe_ep_size != 1 and cp_size > 1:
             raise NotImplementedError("CP don't support MoE tp/ep yet")
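The new consistency check simply requires the attention split to cover exactly the tp*cp ranks; with illustrative values:

# Example values only: the attention split must multiply out to tp_size * cp_size.
tp_size, cp_size = 2, 4
attn_tp_size, attn_cp_size = 8, 1
assert attn_tp_size * attn_cp_size == tp_size * cp_size   # 8 == 8, accepted
# attn_tp_size = 4 with attn_cp_size = 1 would be rejected: 4 != 8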
@@ -163,6 +187,8 @@
         self.pp_size = pp_size
         self.moe_tp_size = moe_tp_size
         self.moe_ep_size = moe_ep_size
+        self.attn_tp_size = attn_tp_size
+        self.attn_cp_size = attn_cp_size
         self.auto_parallel = auto_parallel
         self.world_size = world_size
         self.rank = rank
@@ -218,6 +244,8 @@
                 and self.pp_size == other.pp_size
                 and self.moe_tp_size == other.moe_tp_size
                 and self.moe_ep_size == other.moe_ep_size
+                and self.attn_tp_size == other.attn_tp_size
+                and self.attn_cp_size == other.attn_cp_size
                 and self.auto_parallel == other.auto_parallel)
     def __hash__(self):
@@ -225,6 +253,7 @@
                 ^ hash(self.gpus_per_node) ^ hash(self.cp_size)
                 ^ hash(self.tp_size) ^ hash(self.pp_size)
                 ^ hash(self.moe_tp_size) ^ hash(self.moe_ep_size)
+                ^ hash(self.attn_tp_size) ^ hash(self.attn_cp_size)
                 ^ hash(self.auto_parallel))
     @property
@@ -375,5 +404,7 @@
             'pp_size': self.pp_size,
             'moe_tp_size': self.moe_tp_size,
             'moe_ep_size': self.moe_ep_size,
+            'attn_tp_size': self.attn_tp_size,
+            'attn_cp_size': self.attn_cp_size,
             'auto_parallel': self.auto_parallel,
         }


@@ -1850,6 +1850,7 @@ class Fp8RowwiseAttention(Module):
         self.attention_mask_type = attention_mask_type
         self.attention_head_size = hidden_size // num_attention_heads if attention_head_size is None else attention_head_size
         self.num_attention_heads = num_attention_heads // tp_size
+        self.num_kv_heads = num_kv_heads
         self.num_attention_kv_heads = (
             num_kv_heads + tp_size - 1
         ) // tp_size if num_kv_heads is not None else self.num_attention_heads
@@ -2010,6 +2011,7 @@
             layer_idx=self.local_layer_idx,
             num_heads=self.num_attention_heads,
             num_kv_heads=self.num_attention_kv_heads,
+            num_kv_heads_origin=self.num_kv_heads,
             hidden_size_per_head=self.attention_head_size,
             q_scaling=self.q_scaling,
             rotary_embedding_dim=self.rotary_embedding_dim,
@@ -2467,6 +2469,7 @@ class SmoothQuantAttention(Module):
         self.attention_mask_type = attention_mask_type
         self.attention_head_size = hidden_size // num_attention_heads if attention_head_size is None else attention_head_size
         self.num_attention_heads = num_attention_heads // tp_size
+        self.num_kv_heads = num_kv_heads
         self.num_attention_kv_heads = (
             num_kv_heads + tp_size - 1
         ) // tp_size if num_kv_heads is not None else self.num_attention_heads
@@ -2634,6 +2637,7 @@
             layer_idx=self.local_layer_idx,
             num_heads=self.num_attention_heads,
             num_kv_heads=self.num_attention_kv_heads,
+            num_kv_heads_origin=self.num_kv_heads,
             hidden_size_per_head=self.attention_head_size,
             q_scaling=self.q_scaling,
             rotary_embedding_dim=self.rotary_embedding_dim,
@@ -2987,6 +2991,7 @@ class QServeAttention(Module):
         self.attention_mask_type = attention_mask_type
         self.attention_head_size = hidden_size // num_attention_heads if attention_head_size is None else attention_head_size
         self.num_attention_heads = num_attention_heads // tp_size
+        self.num_kv_heads = num_kv_heads
         self.num_attention_kv_heads = (
             num_kv_heads + tp_size - 1
         ) // tp_size if num_kv_heads is not None else self.num_attention_heads
@@ -3154,6 +3159,7 @@
             layer_idx=self.local_layer_idx,
             num_heads=self.num_attention_heads,
             num_kv_heads=self.num_attention_kv_heads,
+            num_kv_heads_origin=self.num_kv_heads,
             hidden_size_per_head=self.attention_head_size,
             q_scaling=self.q_scaling,
             rotary_embedding_dim=self.rotary_embedding_dim,


@@ -905,6 +905,8 @@ def quantize_and_export(*,
         with open(f"{export_path}/config.json", "r") as f:
             tensorrt_llm_config = json.load(f)
         tensorrt_llm_config["mapping"]["cp_size"] = cp_size
+        tensorrt_llm_config["mapping"]["attn_tp_size"] = -1
+        tensorrt_llm_config["mapping"]["attn_cp_size"] = -1
         tensorrt_llm_config["mapping"]["world_size"] *= cp_size
         with open(f"{export_path}/config.json", "w") as f:
             json.dump(tensorrt_llm_config, f, indent=4)
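Assuming tp_size=2 and cp_size=2, the patched export would leave a mapping section roughly along these lines (an illustrative subset of keys and example values, not output from a real run; the exact contents come from Mapping serialization):

import json

# Illustrative "mapping" block after quantize_and_export patches config.json.
mapping = {
    "world_size": 4,        # multiplied by cp_size by the code above
    "tp_size": 2,
    "pp_size": 1,
    "cp_size": 2,
    "moe_tp_size": 2,
    "moe_ep_size": 1,
    "attn_tp_size": -1,     # -1: let Mapping fall back to tp_size * cp_size
    "attn_cp_size": -1,
    "auto_parallel": False,
}
print(json.dumps({"mapping": mapping}, indent=4))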