mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-13 22:18:36 +08:00
fix: fix for cp > kvHeadNum (#3002)
* fix for cp > kvHeadNum Signed-off-by: Dylan Chen <191843203+DylanChen-NV@users.noreply.github.com> * fix for None kv_head_num Signed-off-by: Dylan Chen <191843203+DylanChen-NV@users.noreply.github.com> --------- Signed-off-by: Dylan Chen <191843203+DylanChen-NV@users.noreply.github.com>
This commit is contained in:
parent
25f2434495
commit
1ac0566a93
@ -166,9 +166,8 @@ bool AttentionOp::convertMMHAParamsToXQAParams(tensorrt_llm::kernels::XQAParams&
|
||||
xqaParams = {};
|
||||
xqaParams.data_type = ConvertMMHAToXQAParamsHelper<T, KVCacheBuffer>::data_type;
|
||||
|
||||
// TODO(ziqingc): A better description for these parameters affected by CP size
|
||||
xqaParams.num_q_heads = mNumHeads / mCpSize; // when we use CP, the MHA part is spilt like TP
|
||||
xqaParams.num_kv_heads = mNumKVHeads / mCpSize; // when we use CP, the MHA part is spilt like TP
|
||||
xqaParams.num_q_heads = mNumAttnHeads;
|
||||
xqaParams.num_kv_heads = mNumAttnKVHeads;
|
||||
xqaParams.head_size = mHeadSize;
|
||||
xqaParams.unidirectional = mUnidirectional;
|
||||
xqaParams.q_scaling = mQScaling;
|
||||
@ -265,15 +264,151 @@ bool AttentionOp::convertMMHAParamsToXQAParams(tensorrt_llm::kernels::XQAParams&
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
int AttentionOp::ulyssesContextPreprocess(T const* input, T* output, T* buffer, EnqueueContextParams<T> const& params,
|
||||
int const* cu_q_seqlens, int const* cu_cp_partial_seqlens, cudaStream_t stream)
|
||||
{
|
||||
int32_t partialTokenNum = 0;
|
||||
int32_t maxPartialLength = 0;
|
||||
for (int32_t batchIdx = 0; batchIdx < params.batch_size; ++batchIdx)
|
||||
{
|
||||
int32_t partialLength = (params.host_context_lengths[batchIdx] + mCpSize - 1) / mCpSize;
|
||||
maxPartialLength = std::max(maxPartialLength, partialLength);
|
||||
partialTokenNum += partialLength;
|
||||
}
|
||||
auto const partialHeads = mNumAttnHeads + 2 * mNumAttnKVHeads;
|
||||
|
||||
// full request: [bs, seqlen, head, headSize]
|
||||
//
|
||||
// input of cp: [bs, partialLength, head, headSize]
|
||||
// view_1 as [bs, partialLength, cpSize_Head, partialHead, headSize]
|
||||
// transpose_1 as [cpSize_Head, bs, partialLenth, partialHead, headSize]
|
||||
// all-to-all to get [cpSize_Length, bs, partialLength, partialHead, headSize]
|
||||
// transpose_2 to [bs, cpSize_Length, partialLength, partialHead, headSize]
|
||||
// view_2 as [bs, totalLength, partialHead, headSize]
|
||||
// and this is same to the input under TP.
|
||||
//
|
||||
// when we use remove_input_padding, bs and length are fused into numTokens. So, we need to
|
||||
// insert the cpSize_Length dimension of transpose_2 into numTokens directly like
|
||||
// input of cp: [partialNumTokens, head, headSize]
|
||||
// view_1 as [partialNumTokens, cpSize_Head, partialHead, headSize]
|
||||
// transpose_1 as [cpSize_Head, partialNumTokens, partialHead, headSize]
|
||||
// all-to-all to get [cpSize_Length, partialNumTokens, partialHead, headSize]
|
||||
// transpose_2 as [NumTokens, partialHead, headSize]
|
||||
// and this is same to the input under TP.
|
||||
|
||||
// view_1 + transpose_1
|
||||
invokeCpTranspose(output, buffer, input, partialTokenNum, mCpSize, mNumAttnHeads, mNumAttnKVHeads,
|
||||
mUlyssesMQABroadcast, getHeadSize(), mCpRank, stream);
|
||||
sync_check_cuda_error(stream);
|
||||
|
||||
// Do all to all
|
||||
#if ENABLE_MULTI_DEVICE
|
||||
ncclGroupStart();
|
||||
for (int cpIdx = 0; cpIdx < mCpSize; cpIdx++)
|
||||
{
|
||||
if (cpIdx != mCpRank)
|
||||
{
|
||||
NCCLCHECK(ncclSend(output + cpIdx * (partialTokenNum * getHeadSize() * partialHeads),
|
||||
(partialTokenNum * getHeadSize() * partialHeads), (*getDtypeMap())[mType], cpIdx, *mCpNcclComm,
|
||||
stream));
|
||||
NCCLCHECK(ncclRecv(buffer + cpIdx * (partialTokenNum * getHeadSize() * partialHeads),
|
||||
(partialTokenNum * getHeadSize() * partialHeads), (*getDtypeMap())[mType], cpIdx, *mCpNcclComm,
|
||||
stream));
|
||||
}
|
||||
}
|
||||
ncclGroupEnd();
|
||||
sync_check_cuda_error(stream);
|
||||
#endif // ENABLE_MULTI_DEVICE
|
||||
|
||||
// transpose_2 + view_2
|
||||
invokeCpTranspose2(output, buffer, params.context_lengths, cu_q_seqlens, cu_cp_partial_seqlens, mCpSize,
|
||||
maxPartialLength, params.batch_size, partialHeads, getHeadSize(), stream);
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
int AttentionOp::ulyssesContextPostprocess(T* input, T* output, T* buffer, EnqueueContextParams<T> const& params,
|
||||
int const* cu_q_seqlens, int const* cu_cp_partial_seqlens, cudaStream_t stream)
|
||||
{
|
||||
// After FMHA, we get result [numTokens(bs, cp, paritalLength), partialHead, headSize]
|
||||
// transpose_2_reverse: [cpSize_Length, partialTokens(bs, partialLength), partialHead, headSize]
|
||||
// all-to-all: [cpSize_Head, partialTokens, partialHead, headSize]
|
||||
// transpose_1_reverse: [partialTokens, cpSize_Head, partialHead, headSize]
|
||||
// view: [partialTokens, head, headSize]
|
||||
|
||||
int32_t maxPartialLength = 0;
|
||||
int32_t partialTokenNum = 0;
|
||||
for (int32_t batchIdx = 0; batchIdx < params.batch_size; ++batchIdx)
|
||||
{
|
||||
int32_t partialLength = (params.host_context_lengths[batchIdx] + mCpSize - 1) / mCpSize;
|
||||
maxPartialLength = std::max(maxPartialLength, partialLength);
|
||||
partialTokenNum += partialLength;
|
||||
}
|
||||
|
||||
// transpose_2_reverse
|
||||
if (mFP8ContextFMHA)
|
||||
{
|
||||
invokeCpTransposeToSeqMajor2(reinterpret_cast<__nv_fp8_e4m3*>(buffer),
|
||||
reinterpret_cast<__nv_fp8_e4m3 const*>(input), params.context_lengths, cu_q_seqlens, cu_cp_partial_seqlens,
|
||||
mCpSize, maxPartialLength, params.batch_size, mNumAttnHeads, getHeadSize(), stream);
|
||||
}
|
||||
else
|
||||
{
|
||||
invokeCpTransposeToSeqMajor2(buffer, input, params.context_lengths, cu_q_seqlens, cu_cp_partial_seqlens,
|
||||
mCpSize, maxPartialLength, params.batch_size, mNumAttnHeads, getHeadSize(), stream);
|
||||
}
|
||||
|
||||
// all-to-all
|
||||
#if ENABLE_MULTI_DEVICE
|
||||
size_t const elementNum = partialTokenNum * getHeadSize() * mNumAttnHeads;
|
||||
ncclGroupStart();
|
||||
for (int cpIdx = 0; cpIdx < mCpSize; cpIdx++)
|
||||
{
|
||||
if (cpIdx != mCpRank)
|
||||
{
|
||||
if (mFP8ContextFMHA)
|
||||
{
|
||||
NCCLCHECK(ncclSend(reinterpret_cast<__nv_fp8_e4m3*>(buffer) + cpIdx * elementNum, elementNum, ncclInt8,
|
||||
cpIdx, *mCpNcclComm, stream));
|
||||
NCCLCHECK(ncclRecv(reinterpret_cast<__nv_fp8_e4m3*>(input) + cpIdx * elementNum, elementNum, ncclInt8,
|
||||
cpIdx, *mCpNcclComm, stream));
|
||||
}
|
||||
else
|
||||
{
|
||||
NCCLCHECK(ncclSend(
|
||||
buffer + cpIdx * elementNum, elementNum, (*getDtypeMap())[mType], cpIdx, *mCpNcclComm, stream));
|
||||
NCCLCHECK(ncclRecv(
|
||||
input + cpIdx * elementNum, elementNum, (*getDtypeMap())[mType], cpIdx, *mCpNcclComm, stream));
|
||||
}
|
||||
}
|
||||
}
|
||||
ncclGroupEnd();
|
||||
#endif // ENABLE_MULTI_DEVICE
|
||||
|
||||
// transpose_1_reverse + view
|
||||
if (mFP8ContextFMHA)
|
||||
{
|
||||
invokeCpTransposeToSeqMajor<__nv_fp8_e4m3>(reinterpret_cast<__nv_fp8_e4m3*>(output),
|
||||
reinterpret_cast<__nv_fp8_e4m3 const*>(buffer), reinterpret_cast<__nv_fp8_e4m3 const*>(input),
|
||||
partialTokenNum, mCpSize, mNumAttnHeads, getHeadSize(), mCpRank, stream);
|
||||
}
|
||||
else
|
||||
{
|
||||
invokeCpTransposeToSeqMajor<T>(
|
||||
(T*) output, buffer, input, partialTokenNum, mCpSize, mNumAttnHeads, getHeadSize(), mCpRank, stream);
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
int AttentionOp::ulyssesGenerationPreprocess(
|
||||
int32_t batch_beam, T* mhaInput, T* mhaOutput, T*& input, cudaStream_t stream)
|
||||
T const* input, T* output, T* buffer, int32_t batch_beam, cudaStream_t stream)
|
||||
{
|
||||
if (mCpSize <= 1)
|
||||
return 0;
|
||||
|
||||
auto const partialQHeads = mNumHeads / mCpSize;
|
||||
auto const partialKVHeads = mNumKVHeads / mCpSize;
|
||||
auto const partialTokenNum = (batch_beam + mCpSize - 1) / mCpSize;
|
||||
|
||||
// attention_input shape: [partialTokenNum, numHeads, headSize]
|
||||
@ -284,27 +419,26 @@ int AttentionOp::ulyssesGenerationPreprocess(
|
||||
|
||||
// do transpose_1
|
||||
// [1, mNumHeads + 2*mNumKVHeads, headSize]
|
||||
// -> (view) [1, cpSize * partialQHeads + cpSize * partialKVHeads + cpSize * partilKVHeads,
|
||||
// -> (view) [1, cpSize * mNumAttnHeads + cpSize * mNumAttnKVHeads + cpSize * partilKVHeads,
|
||||
// headSize]
|
||||
// -> (transpose) [cpSize, 1, partialQHeads + partialKvHeads + partialKVHeads, headSize]
|
||||
invokeCpTranspose(mhaOutput, mhaInput, input, partialTokenNum, mCpSize, partialQHeads, partialKVHeads, mHeadSize,
|
||||
mCpRank, stream);
|
||||
// -> (transpose) [cpSize, 1, mNumAttnHeads + mNumAttnKVHeads + mNumAttnKVHeads, headSize]
|
||||
invokeCpTranspose(buffer, output, input, partialTokenNum, mCpSize, mNumAttnHeads, mNumAttnKVHeads,
|
||||
mUlyssesMQABroadcast, mHeadSize, mCpRank, stream);
|
||||
sync_check_cuda_error(stream);
|
||||
|
||||
// Do all to all
|
||||
#if ENABLE_MULTI_DEVICE
|
||||
auto const totalHeads = mNumHeads + 2 * mNumKVHeads;
|
||||
auto const partialHeads = totalHeads / mCpSize;
|
||||
auto const partialHeads = mNumAttnHeads + 2 * mNumAttnKVHeads;
|
||||
|
||||
ncclGroupStart();
|
||||
for (int cpIdx = 0; cpIdx < mCpSize; cpIdx++)
|
||||
{
|
||||
if (cpIdx != mCpRank)
|
||||
{
|
||||
NCCLCHECK(ncclSend(mhaOutput + cpIdx * (partialTokenNum * getHeadSize() * partialHeads),
|
||||
NCCLCHECK(ncclSend(buffer + cpIdx * (partialTokenNum * getHeadSize() * partialHeads),
|
||||
(partialTokenNum * getHeadSize() * partialHeads), (*getDtypeMap())[mType], cpIdx, *mCpNcclComm,
|
||||
stream));
|
||||
NCCLCHECK(ncclRecv(mhaInput + cpIdx * (partialTokenNum * getHeadSize() * partialHeads),
|
||||
NCCLCHECK(ncclRecv(output + cpIdx * (partialTokenNum * getHeadSize() * partialHeads),
|
||||
(partialTokenNum * getHeadSize() * partialHeads), (*getDtypeMap())[mType], cpIdx, *mCpNcclComm,
|
||||
stream));
|
||||
}
|
||||
@ -312,14 +446,11 @@ int AttentionOp::ulyssesGenerationPreprocess(
|
||||
ncclGroupEnd();
|
||||
sync_check_cuda_error(stream);
|
||||
#endif // ENABLE_MULTI_DEVICE
|
||||
|
||||
input = mhaInput;
|
||||
return 0;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
int AttentionOp::ulyssesGenerationPostprocess(
|
||||
int32_t batch_beam, T* mhaInput, T* mhaOutput, void* output, cudaStream_t stream)
|
||||
int AttentionOp::ulyssesGenerationPostprocess(T* input, T* output, T* buffer, int32_t batch_beam, cudaStream_t stream)
|
||||
{
|
||||
if (mCpSize <= 1)
|
||||
return 0;
|
||||
@ -330,12 +461,11 @@ int AttentionOp::ulyssesGenerationPostprocess(
|
||||
// transpose_1_reverse: [partialTokens, cpSize_Head, partialHead, headSize]
|
||||
// view: [partialTokens, head, headSize]
|
||||
|
||||
auto partialHeads = mNumHeads / mCpSize;
|
||||
auto const partialTokenNum = (batch_beam + mCpSize - 1) / mCpSize;
|
||||
|
||||
// do all-to-all
|
||||
#if ENABLE_MULTI_DEVICE
|
||||
size_t const elementNum = partialTokenNum * getHeadSize() * partialHeads;
|
||||
size_t const elementNum = partialTokenNum * getHeadSize() * mNumAttnHeads;
|
||||
ncclGroupStart();
|
||||
for (int cpIdx = 0; cpIdx < mCpSize; cpIdx++)
|
||||
{
|
||||
@ -343,17 +473,17 @@ int AttentionOp::ulyssesGenerationPostprocess(
|
||||
{
|
||||
if (mFP8ContextFMHA)
|
||||
{
|
||||
NCCLCHECK(ncclSend(reinterpret_cast<__nv_fp8_e4m3*>(mhaOutput) + cpIdx * elementNum, elementNum,
|
||||
ncclInt8, cpIdx, *mCpNcclComm, stream));
|
||||
NCCLCHECK(ncclRecv(reinterpret_cast<__nv_fp8_e4m3*>(mhaInput) + cpIdx * elementNum, elementNum,
|
||||
ncclInt8, cpIdx, *mCpNcclComm, stream));
|
||||
NCCLCHECK(ncclSend(reinterpret_cast<__nv_fp8_e4m3*>(input) + cpIdx * elementNum, elementNum, ncclInt8,
|
||||
cpIdx, *mCpNcclComm, stream));
|
||||
NCCLCHECK(ncclRecv(reinterpret_cast<__nv_fp8_e4m3*>(buffer) + cpIdx * elementNum, elementNum, ncclInt8,
|
||||
cpIdx, *mCpNcclComm, stream));
|
||||
}
|
||||
else
|
||||
{
|
||||
NCCLCHECK(ncclSend(
|
||||
mhaOutput + cpIdx * elementNum, elementNum, (*getDtypeMap())[mType], cpIdx, *mCpNcclComm, stream));
|
||||
input + cpIdx * elementNum, elementNum, (*getDtypeMap())[mType], cpIdx, *mCpNcclComm, stream));
|
||||
NCCLCHECK(ncclRecv(
|
||||
mhaInput + cpIdx * elementNum, elementNum, (*getDtypeMap())[mType], cpIdx, *mCpNcclComm, stream));
|
||||
buffer + cpIdx * elementNum, elementNum, (*getDtypeMap())[mType], cpIdx, *mCpNcclComm, stream));
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -364,15 +494,14 @@ int AttentionOp::ulyssesGenerationPostprocess(
|
||||
if (mFP8ContextFMHA)
|
||||
{
|
||||
invokeCpTransposeToSeqMajor<__nv_fp8_e4m3>(reinterpret_cast<__nv_fp8_e4m3*>(output),
|
||||
reinterpret_cast<__nv_fp8_e4m3 const*>(mhaOutput), reinterpret_cast<__nv_fp8_e4m3 const*>(mhaInput),
|
||||
partialTokenNum, mCpSize, partialHeads, getHeadSize(), mCpRank, stream);
|
||||
reinterpret_cast<__nv_fp8_e4m3 const*>(input), reinterpret_cast<__nv_fp8_e4m3 const*>(buffer),
|
||||
partialTokenNum, mCpSize, mNumAttnHeads, getHeadSize(), mCpRank, stream);
|
||||
}
|
||||
else
|
||||
{
|
||||
invokeCpTransposeToSeqMajor<T>(
|
||||
(T*) output, mhaOutput, mhaInput, partialTokenNum, mCpSize, partialHeads, getHeadSize(), mCpRank, stream);
|
||||
(T*) output, input, buffer, partialTokenNum, mCpSize, mNumAttnHeads, getHeadSize(), mCpRank, stream);
|
||||
}
|
||||
sync_check_cuda_error(stream);
|
||||
return 0;
|
||||
}
|
||||
|
||||
@ -532,8 +661,8 @@ int AttentionOp::getHeadSize(bool checkInit) const
|
||||
size_t AttentionOp::getWorkspaceSizeForContext(nvinfer1::DataType type, int32_t max_num_seq, int32_t input_seq_length,
|
||||
int32_t cross_kv_length, int32_t max_num_tokens) const noexcept
|
||||
{
|
||||
int const local_hidden_units_qo = mNumHeads / mCpSize * getHeadSize();
|
||||
int const local_hidden_units_kv = mNumKVHeads / mCpSize * getHeadSize();
|
||||
int const local_hidden_units_qo = mNumAttnHeads * getHeadSize();
|
||||
int const local_hidden_units_kv = mNumAttnKVHeads * getHeadSize();
|
||||
|
||||
auto const size = tensorrt_llm::runtime::BufferDataType(type).getSize();
|
||||
|
||||
@ -815,9 +944,8 @@ int AttentionOp::mlaGeneration(
|
||||
tllmRunnerParams.mHeadDimQk = mMLAParams.kv_lora_rank + mMLAParams.qk_rope_head_dim;
|
||||
tllmRunnerParams.mHeadDimV = mMLAParams.kv_lora_rank;
|
||||
|
||||
auto const num_q_heads = mNumHeads / mCpSize;
|
||||
auto const num_q_heads = mNumAttnHeads;
|
||||
tllmRunnerParams.mNumHeadsQ = num_q_heads;
|
||||
// const auto num_kv_heads_tllm_runner_params = mNumKVHeads / mCpSize;
|
||||
tllmRunnerParams.mNumHeadsKv = num_kv_heads;
|
||||
tllmRunnerParams.mNumHeadsQPerKv = num_q_heads / num_kv_heads;
|
||||
|
||||
@ -1007,13 +1135,13 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
|
||||
int const headSize = getHeadSize();
|
||||
|
||||
int const local_hidden_units_qo = mNumHeads * headSize;
|
||||
int const local_hidden_units_kv = mNumKVHeads / mCpSize * headSize;
|
||||
int const local_hidden_units_kv = mNumAttnKVHeads * headSize;
|
||||
PositionEmbeddingType const position_embedding_type = mPositionEmbeddingType;
|
||||
float const q_scaling = mQScaling;
|
||||
|
||||
KVCacheBuffer kv_cache_buffer;
|
||||
auto const elemSize = mKVCacheQuantMode.hasKvCacheQuant() ? sizeof(int8_t) : sizeof(T);
|
||||
auto sizePerToken = mNumKVHeads / mCpSize * headSize * elemSize;
|
||||
auto sizePerToken = mNumAttnKVHeads * headSize * elemSize;
|
||||
if (useKVCache())
|
||||
{
|
||||
if constexpr (std::is_same_v<KVCacheBuffer, KVBlockArray>)
|
||||
@ -1252,72 +1380,13 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
|
||||
// do all-to-all for params.attention_input, need to split on kv head
|
||||
// [token_num // cp_size, kv_heads, head_size] -> [token_num, kv_heads // cp_size, head_size]
|
||||
T* attention_input = const_cast<T*>(params.attention_input);
|
||||
if (mCpSize > 1)
|
||||
if (mCpSize > 1 && mAttnTpSize > 1 && mAttnCpSize == 1)
|
||||
{
|
||||
int32_t partialTokenNum = 0;
|
||||
int32_t maxPartialLength = 0;
|
||||
for (int32_t batchIdx = 0; batchIdx < params.batch_size; ++batchIdx)
|
||||
{
|
||||
int32_t partialLength = (params.host_context_lengths[batchIdx] + mCpSize - 1) / mCpSize;
|
||||
maxPartialLength = std::max(maxPartialLength, partialLength);
|
||||
partialTokenNum += partialLength;
|
||||
}
|
||||
auto const totalHeads = mNumHeads + 2 * mNumKVHeads;
|
||||
auto const partialHeads = totalHeads / mCpSize;
|
||||
auto const partialQHeads = mNumHeads / mCpSize;
|
||||
auto const partialKVHeads = mNumKVHeads / mCpSize;
|
||||
|
||||
// full request: [bs, seqlen, head, headSize]
|
||||
//
|
||||
// input of cp: [bs, partialLength, head, headSize]
|
||||
// view_1 as [bs, partialLength, cpSize_Head, partialHead, headSize]
|
||||
// transpose_1 as [cpSize_Head, bs, partialLenth, partialHead, headSize]
|
||||
// all-to-all to get [cpSize_Length, bs, partialLength, partialHead, headSize]
|
||||
// transpose_2 to [bs, cpSize_Length, partialLength, partialHead, headSize]
|
||||
// view_2 as [bs, totalLength, partialHead, headSize]
|
||||
// and this is same to the input under TP.
|
||||
//
|
||||
// when we use remove_input_padding, bs and length are fused into numTokens. So, we need to
|
||||
// insert the cpSize_Length dimension of transpose_2 into numTokens directly like
|
||||
// input of cp: [partialNumTokens, head, headSize]
|
||||
// view_1 as [partialNumTokens, cpSize_Head, partialHead, headSize]
|
||||
// transpose_1 as [cpSize_Head, partialNumTokens, partialHead, headSize]
|
||||
// all-to-all to get [cpSize_Length, partialNumTokens, partialHead, headSize]
|
||||
// transpose_2 as [NumTokens, partialHead, headSize]
|
||||
// and this is same to the input under TP.
|
||||
|
||||
// view_1 + transpose_1
|
||||
invokeCpTranspose(gatherInBuffer, gatherOutBuffer, params.attention_input, partialTokenNum, mCpSize,
|
||||
partialQHeads, partialKVHeads, getHeadSize(), mCpRank, stream);
|
||||
sync_check_cuda_error(stream);
|
||||
|
||||
// Do all to all
|
||||
#if ENABLE_MULTI_DEVICE
|
||||
ncclGroupStart();
|
||||
for (int cpIdx = 0; cpIdx < mCpSize; cpIdx++)
|
||||
{
|
||||
if (cpIdx != mCpRank)
|
||||
{
|
||||
NCCLCHECK(ncclSend(gatherInBuffer + cpIdx * (partialTokenNum * getHeadSize() * partialHeads),
|
||||
(partialTokenNum * getHeadSize() * partialHeads), (*getDtypeMap())[mType], cpIdx, *mCpNcclComm,
|
||||
stream));
|
||||
NCCLCHECK(ncclRecv(gatherOutBuffer + cpIdx * (partialTokenNum * getHeadSize() * partialHeads),
|
||||
(partialTokenNum * getHeadSize() * partialHeads), (*getDtypeMap())[mType], cpIdx, *mCpNcclComm,
|
||||
stream));
|
||||
}
|
||||
}
|
||||
ncclGroupEnd();
|
||||
sync_check_cuda_error(stream);
|
||||
#endif // ENABLE_MULTI_DEVICE
|
||||
|
||||
// transpose_2 + view_2
|
||||
invokeCpTranspose2(gatherInBuffer, gatherOutBuffer, params.context_lengths, cu_q_seqlens,
|
||||
cu_cp_partial_seqlens, mCpSize, maxPartialLength, params.batch_size, partialHeads, getHeadSize(),
|
||||
stream);
|
||||
|
||||
this->template ulyssesContextPreprocess<T>(
|
||||
attention_input, gatherInBuffer, gatherOutBuffer, params, cu_q_seqlens, cu_cp_partial_seqlens, stream);
|
||||
attention_input = gatherInBuffer;
|
||||
sync_check_cuda_error(stream);
|
||||
}
|
||||
sync_check_cuda_error(stream);
|
||||
|
||||
bool const enablePagedKVContextFMHA = mPagedKVCache && mPagedContextFMHA;
|
||||
TLLM_CHECK_WITH_INFO(!(mKVCacheQuantMode.hasInt8KvCache() && enablePagedKVContextFMHA),
|
||||
@ -1363,9 +1432,9 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
|
||||
preprocessingParams.token_num = params.num_tokens;
|
||||
preprocessingParams.remove_padding = mRemovePadding;
|
||||
preprocessingParams.cross_attention = isCrossAttention();
|
||||
preprocessingParams.head_num = mNumHeads / mCpSize;
|
||||
preprocessingParams.kv_head_num = mNumKVHeads / mCpSize;
|
||||
preprocessingParams.qheads_per_kv_head = mNumHeads / mNumKVHeads;
|
||||
preprocessingParams.head_num = mNumAttnHeads;
|
||||
preprocessingParams.kv_head_num = mNumAttnKVHeads;
|
||||
preprocessingParams.qheads_per_kv_head = mNumAttnHeads / mNumAttnKVHeads;
|
||||
preprocessingParams.size_per_head = getHeadSize();
|
||||
preprocessingParams.rotary_embedding_dim = mRotaryEmbeddingDim;
|
||||
preprocessingParams.rotary_embedding_base = mRotaryEmbeddingBase;
|
||||
@ -1479,79 +1548,10 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
|
||||
invokeKvCachePostprocessing(preprocessingParams, stream);
|
||||
sync_check_cuda_error(stream);
|
||||
|
||||
if (mCpSize > 1)
|
||||
if (mCpSize > 1 && mAttnTpSize > 1 && mAttnCpSize == 1)
|
||||
{
|
||||
// After FMHA, we get result [numTokens(bs, cp, paritalLength), partialHead, headSize]
|
||||
// transpose_2_reverse: [cpSize_Length, partialTokens(bs, partialLength), partialHead, headSize]
|
||||
// all-to-all: [cpSize_Head, partialTokens, partialHead, headSize]
|
||||
// transpose_1_reverse: [partialTokens, cpSize_Head, partialHead, headSize]
|
||||
// view: [partialTokens, head, headSize]
|
||||
|
||||
int32_t maxPartialLength = 0;
|
||||
int32_t partialTokenNum = 0;
|
||||
for (int32_t batchIdx = 0; batchIdx < params.batch_size; ++batchIdx)
|
||||
{
|
||||
int32_t partialLength = (params.host_context_lengths[batchIdx] + mCpSize - 1) / mCpSize;
|
||||
maxPartialLength = std::max(maxPartialLength, partialLength);
|
||||
partialTokenNum += partialLength;
|
||||
}
|
||||
auto partialHeads = mNumHeads / mCpSize;
|
||||
|
||||
// transpose_2_reverse
|
||||
if (mFP8ContextFMHA)
|
||||
{
|
||||
invokeCpTransposeToSeqMajor2(reinterpret_cast<__nv_fp8_e4m3*>(gatherInBuffer),
|
||||
reinterpret_cast<__nv_fp8_e4m3 const*>(gatherOutBuffer), params.context_lengths, cu_q_seqlens,
|
||||
cu_cp_partial_seqlens, mCpSize, maxPartialLength, params.batch_size, partialHeads, getHeadSize(),
|
||||
stream);
|
||||
}
|
||||
else
|
||||
{
|
||||
invokeCpTransposeToSeqMajor2(gatherInBuffer, gatherOutBuffer, params.context_lengths, cu_q_seqlens,
|
||||
cu_cp_partial_seqlens, mCpSize, maxPartialLength, params.batch_size, partialHeads, getHeadSize(),
|
||||
stream);
|
||||
}
|
||||
|
||||
// all-to-all
|
||||
#if ENABLE_MULTI_DEVICE
|
||||
size_t const elementNum = partialTokenNum * getHeadSize() * partialHeads;
|
||||
ncclGroupStart();
|
||||
for (int cpIdx = 0; cpIdx < mCpSize; cpIdx++)
|
||||
{
|
||||
if (cpIdx != mCpRank)
|
||||
{
|
||||
if (mFP8ContextFMHA)
|
||||
{
|
||||
NCCLCHECK(ncclSend(reinterpret_cast<__nv_fp8_e4m3*>(gatherInBuffer) + cpIdx * elementNum,
|
||||
elementNum, ncclInt8, cpIdx, *mCpNcclComm, stream));
|
||||
NCCLCHECK(ncclRecv(reinterpret_cast<__nv_fp8_e4m3*>(gatherOutBuffer) + cpIdx * elementNum,
|
||||
elementNum, ncclInt8, cpIdx, *mCpNcclComm, stream));
|
||||
}
|
||||
else
|
||||
{
|
||||
NCCLCHECK(ncclSend(gatherInBuffer + cpIdx * elementNum, elementNum, (*getDtypeMap())[mType],
|
||||
cpIdx, *mCpNcclComm, stream));
|
||||
NCCLCHECK(ncclRecv(gatherOutBuffer + cpIdx * elementNum, elementNum, (*getDtypeMap())[mType],
|
||||
cpIdx, *mCpNcclComm, stream));
|
||||
}
|
||||
}
|
||||
}
|
||||
ncclGroupEnd();
|
||||
#endif // ENABLE_MULTI_DEVICE
|
||||
|
||||
// transpose_1_reverse + view
|
||||
if (mFP8ContextFMHA)
|
||||
{
|
||||
invokeCpTransposeToSeqMajor<__nv_fp8_e4m3>(reinterpret_cast<__nv_fp8_e4m3*>(params.context_buf),
|
||||
reinterpret_cast<__nv_fp8_e4m3 const*>(gatherInBuffer),
|
||||
reinterpret_cast<__nv_fp8_e4m3 const*>(gatherOutBuffer), partialTokenNum, mCpSize, partialHeads,
|
||||
getHeadSize(), mCpRank, stream);
|
||||
}
|
||||
else
|
||||
{
|
||||
invokeCpTransposeToSeqMajor<T>((T*) params.context_buf, gatherInBuffer, gatherOutBuffer,
|
||||
partialTokenNum, mCpSize, partialHeads, getHeadSize(), mCpRank, stream);
|
||||
}
|
||||
this->template ulyssesContextPostprocess<T>(gatherOutBuffer, reinterpret_cast<T*>(params.context_buf),
|
||||
gatherInBuffer, params, cu_q_seqlens, cu_cp_partial_seqlens, stream);
|
||||
sync_check_cuda_error(stream);
|
||||
}
|
||||
}
|
||||
@ -1871,7 +1871,7 @@ int AttentionOp::enqueueGeneration(EnqueueGenerationParams<T> const& params, cud
|
||||
|
||||
KVCacheBuffer kv_cache_buffer;
|
||||
auto const elemSize = mKVCacheQuantMode.hasKvCacheQuant() ? sizeof(int8_t) : sizeof(T);
|
||||
auto const sizePerToken = mNumKVHeads / mCpSize * headSize * elemSize;
|
||||
auto const sizePerToken = mNumAttnKVHeads * headSize * elemSize;
|
||||
if (useKVCache())
|
||||
{
|
||||
if constexpr (std::is_same_v<KVCacheBuffer, KVBlockArray>)
|
||||
@ -1936,7 +1936,12 @@ int AttentionOp::enqueueGeneration(EnqueueGenerationParams<T> const& params, cud
|
||||
T* mhaInput = mhaOutput + cpMaxPaddedSequenceLength * (mNumHeads + 2 * mNumKVHeads) * mHeadSize;
|
||||
|
||||
T* attention_input = const_cast<T*>(params.attention_input);
|
||||
this->template ulyssesGenerationPreprocess<T>(batch_beam, mhaInput, mhaOutput, attention_input, stream);
|
||||
if (mCpSize > 1 && mAttnTpSize > 1 && mAttnCpSize == 1)
|
||||
{
|
||||
this->template ulyssesGenerationPreprocess<T>(attention_input, mhaInput, mhaOutput, batch_beam, stream);
|
||||
attention_input = mhaInput;
|
||||
sync_check_cuda_error(stream);
|
||||
}
|
||||
|
||||
// Try XQA optimization first.
|
||||
{
|
||||
@ -1954,7 +1959,12 @@ int AttentionOp::enqueueGeneration(EnqueueGenerationParams<T> const& params, cud
|
||||
xqaParams.qkv = attention_input;
|
||||
}
|
||||
mXqaDispatcher->run(xqaParams, kv_cache_buffer);
|
||||
this->template ulyssesGenerationPostprocess<T>(batch_beam, mhaInput, mhaOutput, params.context_buf, stream);
|
||||
if (mCpSize > 1 && mAttnTpSize > 1 && mAttnCpSize == 1)
|
||||
{
|
||||
this->template ulyssesGenerationPostprocess<T>(
|
||||
mhaOutput, reinterpret_cast<T*>(params.context_buf), mhaInput, batch_beam, stream);
|
||||
sync_check_cuda_error(stream);
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
else if (mIsSpecDecodingEnabled && mUseSpecDecoding)
|
||||
@ -2040,8 +2050,8 @@ int AttentionOp::enqueueGeneration(EnqueueGenerationParams<T> const& params, cud
|
||||
dispatch_params.max_batch_size = batch_beam;
|
||||
dispatch_params.inference_batch_size = batch_beam;
|
||||
dispatch_params.beam_width = params.beam_width;
|
||||
dispatch_params.head_num = mNumHeads / mCpSize;
|
||||
dispatch_params.kv_head_num = mNumKVHeads / mCpSize;
|
||||
dispatch_params.head_num = mNumAttnHeads;
|
||||
dispatch_params.kv_head_num = mNumAttnKVHeads;
|
||||
dispatch_params.size_per_head = getHeadSize();
|
||||
dispatch_params.rotary_embedding_dim = mRotaryEmbeddingDim;
|
||||
dispatch_params.position_embedding_type = mPositionEmbeddingType;
|
||||
@ -2104,7 +2114,12 @@ int AttentionOp::enqueueGeneration(EnqueueGenerationParams<T> const& params, cud
|
||||
fusedQKV_masked_attention_dispatch(mmhca_params, dispatch_params, stream);
|
||||
}
|
||||
|
||||
this->template ulyssesGenerationPostprocess<T>(batch_beam, mhaInput, mhaOutput, params.context_buf, stream);
|
||||
if (mCpSize > 1 && mAttnTpSize > 1 && mAttnCpSize == 1)
|
||||
{
|
||||
this->template ulyssesGenerationPostprocess<T>(
|
||||
mhaOutput, reinterpret_cast<T*>(params.context_buf), mhaInput, batch_beam, stream);
|
||||
sync_check_cuda_error(stream);
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
@ -2164,6 +2179,21 @@ template void AttentionOp::prepareEnqueueGeneration<__nv_bfloat16, KVBlockArray>
|
||||
|
||||
int AttentionOp::initialize() noexcept
|
||||
{
|
||||
// use Ulysses for GPTAttentionPlugin
|
||||
if (mAttnTpSize < 0 || mAttnCpSize < 0)
|
||||
{
|
||||
mAttnTpSize = mTpSize * mCpSize;
|
||||
mAttnCpSize = 1;
|
||||
}
|
||||
mNumAttnHeads = mNumHeads * mTpSize / mAttnTpSize;
|
||||
mNumAttnKVHeads = (mNumKVHeads * mTpSize + mAttnTpSize - 1) / mAttnTpSize;
|
||||
|
||||
if (mCpSize != mAttnCpSize)
|
||||
{
|
||||
// mqa broadcast
|
||||
mUlyssesMQABroadcast = (mAttnTpSize + mNumKVHeadsOrigin - 1) / mNumKVHeadsOrigin;
|
||||
}
|
||||
|
||||
// Pre-check whether FMHA is supported in order to save memory allocation.
|
||||
if (mEnableContextFMHA)
|
||||
{
|
||||
@ -2306,8 +2336,8 @@ int AttentionOp::initialize() noexcept
|
||||
fmhaParams.attentionMaskType = ContextAttentionMaskType::CUSTOM_MASK;
|
||||
}
|
||||
fmhaParams.isSPadded = !mRemovePadding;
|
||||
fmhaParams.numQHeads = mNumHeads / mCpSize;
|
||||
fmhaParams.numKvHeads = mNumKVHeads / mCpSize;
|
||||
fmhaParams.numQHeads = mNumAttnHeads;
|
||||
fmhaParams.numKvHeads = mNumAttnKVHeads;
|
||||
fmhaParams.numTokensPerBlock = mTokensPerBlock;
|
||||
fmhaParams.headSize = mHeadSize;
|
||||
fmhaParams.headSizeV = mHeadSize;
|
||||
@ -2326,16 +2356,6 @@ int AttentionOp::initialize() noexcept
|
||||
fmhaParams.attnLogitSoftcappingScale = mAttnLogitSoftcappingScale;
|
||||
fmhaParams.hasAlibi = isALiBi();
|
||||
fmhaParams.scaleAlibi = isAliBiWithScale();
|
||||
if (mTpSize > 1)
|
||||
{
|
||||
fmhaParams.tpSize = mTpSize;
|
||||
fmhaParams.tpRank = mTpRank;
|
||||
}
|
||||
else if (mCpSize > 1)
|
||||
{
|
||||
fmhaParams.tpSize = mCpSize;
|
||||
fmhaParams.tpRank = mCpRank;
|
||||
}
|
||||
|
||||
// Load kernels from the pre-compiled cubins.
|
||||
mFmhaDispatcher.reset(new FmhaDispatcher(fmhaParams));
|
||||
@ -2479,8 +2499,8 @@ int AttentionOp::initialize() noexcept
|
||||
TLLM_CHECK_WITH_INFO(mNumHeads % mNumKVHeads == 0, "mNumHeads should be multiples of mNumKVHeads.");
|
||||
TLLM_CHECK_WITH_INFO(!mMultiBlockMode, "Medusa doesn't support multi-block mode.");
|
||||
}
|
||||
fixedParams.numQHeads = mNumHeads / mCpSize;
|
||||
fixedParams.numKvHeads = mNumKVHeads / mCpSize;
|
||||
fixedParams.numQHeads = mNumAttnHeads;
|
||||
fixedParams.numKvHeads = mNumAttnKVHeads;
|
||||
fixedParams.numTokensPerBlock = mTokensPerBlock;
|
||||
fixedParams.headSize = mHeadSize;
|
||||
fixedParams.qScaling = mQScaling;
|
||||
@ -2555,6 +2575,7 @@ std::string AttentionOp::toString() const
|
||||
ss << "gptAttentionCommon members ====================" << std::endl;
|
||||
ss << "mNumHeads: " << mNumHeads << std::endl;
|
||||
ss << "mNumKVHeads: " << mNumKVHeads << std::endl;
|
||||
ss << "mNumKVHeadsOrigin: " << mNumKVHeadsOrigin << std::endl;
|
||||
ss << "mHeadSize: " << mHeadSize << std::endl;
|
||||
ss << "mUnidirectional: " << mUnidirectional << std::endl;
|
||||
ss << "mQScaling: " << mQScaling << std::endl;
|
||||
|
||||
@ -229,10 +229,18 @@ public:
|
||||
EnqueueGenerationParams<T> const& generationsParams, bool forConfigurePlugin);
|
||||
|
||||
template <typename T>
|
||||
int ulyssesGenerationPreprocess(int32_t batch_beam, T* mhaInput, T* mhaOutput, T*& input, cudaStream_t stream);
|
||||
int ulyssesContextPreprocess(T const* input, T* output, T* buffer, EnqueueContextParams<T> const& params,
|
||||
int const* cu_q_seqlens, int const* cu_cp_partial_seqlens, cudaStream_t stream);
|
||||
|
||||
template <typename T>
|
||||
int ulyssesGenerationPostprocess(int32_t batch_beam, T* mhaInput, T* mhaOutput, void* output, cudaStream_t stream);
|
||||
int ulyssesContextPostprocess(T* input, T* output, T* buffer, EnqueueContextParams<T> const& params,
|
||||
int const* cu_q_seqlens, int const* cu_cp_partial_seqlens, cudaStream_t stream);
|
||||
|
||||
template <typename T>
|
||||
int ulyssesGenerationPreprocess(T const* input, T* output, T* buffer, int32_t batch_beam, cudaStream_t stream);
|
||||
|
||||
template <typename T>
|
||||
int ulyssesGenerationPostprocess(T* input, T* output, T* buffer, int32_t batch_beam, cudaStream_t stream);
|
||||
|
||||
[[nodiscard]] bool isRelativePosition() const
|
||||
{
|
||||
@ -374,6 +382,16 @@ public:
|
||||
int mCpSize = 1;
|
||||
int mCpRank = 0;
|
||||
std::set<int32_t> mCpGroup = {};
|
||||
// These parameters are used to specifically configure the attention attributes when cp/tp_size are different
|
||||
// between Attention and FFN(such as Ulysses)
|
||||
int mNumAttnHeads = -1;
|
||||
int mNumAttnKVHeads = -1;
|
||||
int mNumKVHeadsOrigin = -1;
|
||||
int mAttnTpSize = -1;
|
||||
int mAttnTpRank = 0;
|
||||
int mAttnCpSize = -1;
|
||||
int mAttnCpRank = 0;
|
||||
int mUlyssesMQABroadcast = 1;
|
||||
|
||||
// fmha runner (enabled by default)
|
||||
// flag: disabled = 0, enabled = 1, enabled with fp32 accumulation = 2
|
||||
@ -403,8 +421,9 @@ public:
|
||||
mPosShiftEnabled, mPagedContextFMHA, mFP8ContextFMHA, mDenseContextFMHA, mHasFullAttentionMask,
|
||||
mIsSpecDecodingEnabled, mUseSpecDecoding, mSpecDecodingIsGenerationLengthVariable,
|
||||
mSpecDecodingMaxGenerationLength, mIsMLAEnabled, mUseFlashMLA, mMLAParams.data(), mCpSize, mCpRank,
|
||||
mCpGroup, mEnableContextFMHA, mFMHAForceFP32Acc, mMultiBlockMode, mEnableXQA, mUseKVCache, mSkipAttn,
|
||||
mFuseFp4Quant, mNbMultiBlockSemaphores);
|
||||
mCpGroup, mNumAttnHeads, mNumAttnKVHeads, mNumKVHeadsOrigin, mAttnTpSize, mAttnTpRank, mAttnCpSize,
|
||||
mAttnCpRank, mUlyssesMQABroadcast, mEnableContextFMHA, mFMHAForceFP32Acc, mMultiBlockMode, mEnableXQA,
|
||||
mUseKVCache, mSkipAttn, mFuseFp4Quant, mNbMultiBlockSemaphores);
|
||||
};
|
||||
|
||||
private:
|
||||
|
||||
@ -2139,7 +2139,7 @@ __global__ void convertData(Dst* dst, Src const* src, int64_t size, float const*
|
||||
|
||||
template <typename T>
|
||||
__global__ void runCpTranspose(T* dst, T* dst2, T const* src, int64_t partialTokenNum, int64_t cpSize,
|
||||
int64_t partialQHeads, int64_t partialKVHeads, int64_t headSize, int64_t rank)
|
||||
int64_t partialQHeads, int64_t partialKVHeads, int64_t mqaBroadcast, int64_t headSize, int64_t rank)
|
||||
{
|
||||
// Do transpose from
|
||||
// [partialTokenNum, mNumHeads + 2*mNumKVHeads, headSize]
|
||||
@ -2164,12 +2164,12 @@ __global__ void runCpTranspose(T* dst, T* dst2, T const* src, int64_t partialTok
|
||||
}
|
||||
else if (headIdx < partialQHeads + partialKVHeads)
|
||||
{
|
||||
srcHeadIdx = cpSize * partialQHeads + cpIdx * partialKVHeads + (headIdx - partialQHeads);
|
||||
srcHeadIdx = cpSize * partialQHeads + cpIdx / mqaBroadcast * partialKVHeads + (headIdx - partialQHeads);
|
||||
}
|
||||
else
|
||||
{
|
||||
srcHeadIdx = cpSize * partialQHeads + cpSize * partialKVHeads + cpIdx * partialKVHeads
|
||||
+ (headIdx - partialQHeads - partialKVHeads);
|
||||
srcHeadIdx = cpSize * partialQHeads + cpSize / mqaBroadcast * partialKVHeads
|
||||
+ cpIdx / mqaBroadcast * partialKVHeads + (headIdx - partialQHeads - partialKVHeads);
|
||||
}
|
||||
|
||||
if (cpIdx == rank)
|
||||
@ -2181,7 +2181,7 @@ __global__ void runCpTranspose(T* dst, T* dst2, T const* src, int64_t partialTok
|
||||
+ seqIdx * (partialQHeads + 2 * partialKVHeads) + headIdx)
|
||||
* headSize);
|
||||
VecType const* in = reinterpret_cast<VecType const*>(
|
||||
src + (seqIdx * (partialQHeads + 2 * partialKVHeads) * cpSize + srcHeadIdx) * headSize);
|
||||
src + (seqIdx * (partialQHeads * cpSize + 2 * partialKVHeads * cpSize / mqaBroadcast) + srcHeadIdx) * headSize);
|
||||
|
||||
for (int hiddenIdx = threadIdx.x; hiddenIdx < hiddenSize + hiddenRestSize; hiddenIdx += blockDim.x)
|
||||
{
|
||||
@ -2347,17 +2347,18 @@ INSTANTIATE_invokeConversion(__nv_fp8_e4m3, __nv_bfloat16);
|
||||
|
||||
template <typename T>
|
||||
void invokeCpTranspose(T* dst, T* dst2, T const* src, int64_t partialTokenNum, int64_t cpSize, int64_t partialQHeads,
|
||||
int64_t partialKVHeads, int64_t headSize, int64_t rank, cudaStream_t stream)
|
||||
int64_t partialKVHeads, int64_t mqaBroadcast, int64_t headSize, int64_t rank, cudaStream_t stream)
|
||||
{
|
||||
dim3 grid(partialTokenNum, cpSize, partialQHeads + 2 * partialKVHeads);
|
||||
dim3 block(128);
|
||||
runCpTranspose<T><<<grid, block, 0, stream>>>(
|
||||
dst, dst2, src, partialTokenNum, cpSize, partialQHeads, partialKVHeads, headSize, rank);
|
||||
dst, dst2, src, partialTokenNum, cpSize, partialQHeads, partialKVHeads, mqaBroadcast, headSize, rank);
|
||||
}
|
||||
|
||||
#define INSTANTIATE_invokeCpTranspose(T) \
|
||||
template void invokeCpTranspose<T>(T * dst, T * dst2, T const* src, int64_t partialLength, int64_t cpSize, \
|
||||
int64_t partialQHeads, int64_t partialKVHeads, int64_t headSize, int64_t rank, cudaStream_t stream)
|
||||
int64_t partialQHeads, int64_t partialKVHeads, int64_t mqaBroadcast, int64_t headSize, int64_t rank, \
|
||||
cudaStream_t stream)
|
||||
INSTANTIATE_invokeCpTranspose(float);
|
||||
INSTANTIATE_invokeCpTranspose(half);
|
||||
INSTANTIATE_invokeCpTranspose(__nv_bfloat16);
|
||||
|
||||
@ -407,7 +407,7 @@ void invokeConversion(Dst* dst, Src const* src, int64_t size, float const* __res
|
||||
|
||||
template <typename T>
|
||||
void invokeCpTranspose(T* dst, T* dst2, T const* src, int64_t partialLength, int64_t cpSize, int64_t partialQHeads,
|
||||
int64_t partialKVHeads, int64_t headSize, int64_t rank, cudaStream_t stream);
|
||||
int64_t partialKVHeads, int64_t mqaBroadcast, int64_t headSize, int64_t rank, cudaStream_t stream);
|
||||
|
||||
template <typename T>
|
||||
void invokeCpTransposeToSeqMajor(T* dst, T const* srcMyRank, T const* srcOtherRank, int64_t partialLength,
|
||||
|
||||
@ -28,8 +28,8 @@ using tensorrt_llm::plugins::GPTAttentionPluginCreatorCommon;
|
||||
using tensorrt_llm::plugins::GPTAttentionPluginCommon;
|
||||
|
||||
GPTAttentionPluginCommon::GPTAttentionPluginCommon(int layer_idx, int num_heads, int vision_start, int vision_length,
|
||||
int num_kv_heads, int head_size, int unidirectional, float q_scaling, float attn_logit_softcapping_scale,
|
||||
tensorrt_llm::kernels::PositionEmbeddingType position_embedding_type,
|
||||
int num_kv_heads, int num_kv_heads_origin, int head_size, int unidirectional, float q_scaling,
|
||||
float attn_logit_softcapping_scale, tensorrt_llm::kernels::PositionEmbeddingType position_embedding_type,
|
||||
int rotary_embedding_dim, // for RoPE. Use 0 for non-RoPE
|
||||
float rotary_embedding_base, tensorrt_llm::kernels::RotaryScalingType rotary_embedding_scale_type,
|
||||
float rotary_embedding_scale, float rotary_embedding_short_m_scale, float rotary_embedding_long_m_scale,
|
||||
@ -52,6 +52,7 @@ GPTAttentionPluginCommon::GPTAttentionPluginCommon(int layer_idx, int num_heads,
|
||||
mVisionStart = vision_start;
|
||||
mVisionLength = vision_length;
|
||||
mNumKVHeads = num_kv_heads;
|
||||
mNumKVHeadsOrigin = num_kv_heads_origin;
|
||||
mHeadSize = head_size;
|
||||
mUnidirectional = unidirectional;
|
||||
mQScaling = q_scaling;
|
||||
@ -114,6 +115,7 @@ GPTAttentionPluginCommon::GPTAttentionPluginCommon(void const* data, size_t leng
|
||||
read(d, mVisionStart);
|
||||
read(d, mVisionLength);
|
||||
read(d, mNumKVHeads);
|
||||
read(d, mNumKVHeadsOrigin);
|
||||
read(d, mHeadSize);
|
||||
read(d, mUnidirectional);
|
||||
read(d, mQScaling);
|
||||
@ -201,13 +203,13 @@ void GPTAttentionPluginCommon::destroy() noexcept
|
||||
size_t GPTAttentionPluginCommon::getCommonSerializationSize() const noexcept
|
||||
{
|
||||
return sizeof(mLayerIdx) + sizeof(mNumHeads) + +sizeof(mVisionStart) + sizeof(mVisionLength) + sizeof(mNumKVHeads)
|
||||
+ sizeof(mHeadSize) + sizeof(mUnidirectional) + sizeof(mQScaling) + sizeof(mAttnLogitSoftcappingScale)
|
||||
+ sizeof(mPositionEmbeddingType) + sizeof(mRotaryEmbeddingDim) + sizeof(mRotaryEmbeddingBase)
|
||||
+ sizeof(mRotaryEmbeddingScaleType) + sizeof(mRotaryEmbeddingScale) + sizeof(mRotaryEmbeddingShortMscale)
|
||||
+ sizeof(mRotaryEmbeddingLongMscale) + sizeof(mRotaryEmbeddingMaxPositions)
|
||||
+ sizeof(mRotaryEmbeddingOriginalMaxPositions) + sizeof(mTpSize) + sizeof(mTpRank) + sizeof(mEnableContextFMHA)
|
||||
+ sizeof(mFMHAForceFP32Acc) + sizeof(mMultiBlockMode) + sizeof(mEnableXQA)
|
||||
+ sizeof(unsigned int) // mKVCacheQuantMode
|
||||
+ sizeof(mNumKVHeadsOrigin) + sizeof(mHeadSize) + sizeof(mUnidirectional) + sizeof(mQScaling)
|
||||
+ sizeof(mAttnLogitSoftcappingScale) + sizeof(mPositionEmbeddingType) + sizeof(mRotaryEmbeddingDim)
|
||||
+ sizeof(mRotaryEmbeddingBase) + sizeof(mRotaryEmbeddingScaleType) + sizeof(mRotaryEmbeddingScale)
|
||||
+ sizeof(mRotaryEmbeddingShortMscale) + sizeof(mRotaryEmbeddingLongMscale)
|
||||
+ sizeof(mRotaryEmbeddingMaxPositions) + sizeof(mRotaryEmbeddingOriginalMaxPositions) + sizeof(mTpSize)
|
||||
+ sizeof(mTpRank) + sizeof(mEnableContextFMHA) + sizeof(mFMHAForceFP32Acc) + sizeof(mMultiBlockMode)
|
||||
+ sizeof(mEnableXQA) + sizeof(unsigned int) // mKVCacheQuantMode
|
||||
+ sizeof(mRemovePadding) + sizeof(mMaskType) + sizeof(mBlockSparseParams) + sizeof(mPagedKVCache)
|
||||
+ sizeof(mTokensPerBlock) + sizeof(mType) + sizeof(mMaxContextLength) + sizeof(mQKVBiasEnabled)
|
||||
+ sizeof(mCrossAttention) + sizeof(mMaxDistance) + sizeof(mPosShiftEnabled) + sizeof(mDenseContextFMHA)
|
||||
@ -228,6 +230,7 @@ void GPTAttentionPluginCommon::serializeCommon(void* buffer) const noexcept
|
||||
write(d, mVisionStart);
|
||||
write(d, mVisionLength);
|
||||
write(d, mNumKVHeads);
|
||||
write(d, mNumKVHeadsOrigin);
|
||||
write(d, mHeadSize);
|
||||
write(d, mUnidirectional);
|
||||
write(d, mQScaling);
|
||||
@ -307,6 +310,7 @@ GPTAttentionPluginCreatorCommon::GPTAttentionPluginCreatorCommon()
|
||||
mPluginAttributes.emplace_back(PluginField("vision_start", nullptr, PluginFieldType::kINT32));
|
||||
mPluginAttributes.emplace_back(PluginField("vision_length", nullptr, PluginFieldType::kINT32));
|
||||
mPluginAttributes.emplace_back(PluginField("num_kv_heads", nullptr, PluginFieldType::kINT32));
|
||||
mPluginAttributes.emplace_back(PluginField("num_kv_heads_origin", nullptr, PluginFieldType::kINT32));
|
||||
mPluginAttributes.emplace_back(PluginField("layer_idx_in_cache_pool", nullptr, PluginFieldType::kINT32));
|
||||
mPluginAttributes.emplace_back(PluginField("head_size", nullptr, PluginFieldType::kINT32));
|
||||
mPluginAttributes.emplace_back(PluginField("unidirectional", nullptr, PluginFieldType::kINT32));
|
||||
|
||||
@ -35,7 +35,7 @@ public:
|
||||
GPTAttentionPluginCommon() = delete;
|
||||
|
||||
GPTAttentionPluginCommon(int layer_idx, int num_heads, int vision_start, int vision_length, int num_kv_heads,
|
||||
int head_size, int unidirectional, float q_scaling, float attn_logit_softcapping_scale,
|
||||
int num_kv_heads_origin, int head_size, int unidirectional, float q_scaling, float attn_logit_softcapping_scale,
|
||||
tensorrt_llm::kernels::PositionEmbeddingType position_embedding_type,
|
||||
int rotary_embedding_dim, // for RoPE. Use 0 for non-RoPE
|
||||
float rotary_embedding_base, tensorrt_llm::kernels::RotaryScalingType rotary_embedding_scale_type,
|
||||
|
||||
@ -46,8 +46,8 @@ static char const* GPT_ATTENTION_PLUGIN_VERSION{"1"};
|
||||
static char const* GPT_ATTENTION_PLUGIN_NAME{"GPTAttention"};
|
||||
|
||||
GPTAttentionPlugin::GPTAttentionPlugin(int layer_idx, int num_heads, int vision_start, int vision_length,
|
||||
int num_kv_heads, int head_size, int unidirectional, float q_scaling, float attn_logit_softcapping_scale,
|
||||
tensorrt_llm::kernels::PositionEmbeddingType position_embedding_type,
|
||||
int num_kv_heads, int num_kv_heads_origin, int head_size, int unidirectional, float q_scaling,
|
||||
float attn_logit_softcapping_scale, tensorrt_llm::kernels::PositionEmbeddingType position_embedding_type,
|
||||
int rotary_embedding_dim, // for RoPE. 0 for non-RoPE
|
||||
float rotary_embedding_base, tensorrt_llm::kernels::RotaryScalingType rotary_embedding_scale_type,
|
||||
float rotary_embedding_scale, float rotary_embedding_short_m_scale,
|
||||
@ -65,17 +65,17 @@ GPTAttentionPlugin::GPTAttentionPlugin(int layer_idx, int num_heads, int vision_
|
||||
int spec_decoding_max_generation_length, bool is_mla_enabled, int q_lora_rank, int kv_lora_rank,
|
||||
int qk_nope_head_dim, int qk_rope_head_dim, int v_head_dim, bool fuse_fp4_quant, bool skip_attn, int cp_size,
|
||||
int cp_rank, std::set<int32_t> cp_group)
|
||||
: GPTAttentionPluginCommon(layer_idx, num_heads, vision_start, vision_length, num_kv_heads, head_size,
|
||||
unidirectional, q_scaling, attn_logit_softcapping_scale, position_embedding_type, rotary_embedding_dim,
|
||||
rotary_embedding_base, rotary_embedding_scale_type, rotary_embedding_scale, rotary_embedding_short_m_scale,
|
||||
rotary_embedding_long_m_scale, rotary_embedding_max_positions, rotary_embedding_original_max_positions, tp_size,
|
||||
tp_rank, unfuse_qkv_gemm, use_logn_scaling, context_fmha_type, kv_cache_quant_mode, remove_input_padding,
|
||||
mask_type, block_sparse_params, paged_kv_cache, tokens_per_block, type, max_context_length, qkv_bias_enabled,
|
||||
cross_attention, max_distance, pos_shift_enabled, dense_context_fmha, use_paged_context_fmha,
|
||||
use_fp8_context_fmha, has_full_attention_mask, use_cache, is_spec_decoding_enabled,
|
||||
spec_decoding_is_generation_length_variable, spec_decoding_max_generation_length, is_mla_enabled, q_lora_rank,
|
||||
kv_lora_rank, qk_nope_head_dim, qk_rope_head_dim, v_head_dim, fuse_fp4_quant, skip_attn, cp_size, cp_rank,
|
||||
cp_group)
|
||||
: GPTAttentionPluginCommon(layer_idx, num_heads, vision_start, vision_length, num_kv_heads, num_kv_heads_origin,
|
||||
head_size, unidirectional, q_scaling, attn_logit_softcapping_scale, position_embedding_type,
|
||||
rotary_embedding_dim, rotary_embedding_base, rotary_embedding_scale_type, rotary_embedding_scale,
|
||||
rotary_embedding_short_m_scale, rotary_embedding_long_m_scale, rotary_embedding_max_positions,
|
||||
rotary_embedding_original_max_positions, tp_size, tp_rank, unfuse_qkv_gemm, use_logn_scaling, context_fmha_type,
|
||||
kv_cache_quant_mode, remove_input_padding, mask_type, block_sparse_params, paged_kv_cache, tokens_per_block,
|
||||
type, max_context_length, qkv_bias_enabled, cross_attention, max_distance, pos_shift_enabled,
|
||||
dense_context_fmha, use_paged_context_fmha, use_fp8_context_fmha, has_full_attention_mask, use_cache,
|
||||
is_spec_decoding_enabled, spec_decoding_is_generation_length_variable, spec_decoding_max_generation_length,
|
||||
is_mla_enabled, q_lora_rank, kv_lora_rank, qk_nope_head_dim, qk_rope_head_dim, v_head_dim, fuse_fp4_quant,
|
||||
skip_attn, cp_size, cp_rank, cp_group)
|
||||
{
|
||||
TLLM_CHECK_WITH_INFO(
|
||||
!is_mla_enabled, "GPTAttentionPlugin no longer supports MLA. Please use the PyTorch workflow instead.");
|
||||
@ -890,7 +890,8 @@ int GPTAttentionPlugin::enqueueSome(int32_t seqIdxBeg, int32_t localNbSeq, int32
|
||||
|
||||
auto const cacheElemSize = (mKVCacheQuantMode.hasKvCacheQuant() ? 1 : sizeof(T));
|
||||
|
||||
auto const blockSize = mTokensPerBlock * mNumKVHeads / mCpSize * mHeadSize;
|
||||
auto const kv_cache_head_num = (mNumKVHeads + mCpSize - 1) / mCpSize;
|
||||
auto const blockSize = mTokensPerBlock * kv_cache_head_num * mHeadSize;
|
||||
auto const bytesPerBlock = blockSize * cacheElemSize;
|
||||
auto const layerOffset = layerIdxInCachePool * 2 * bytesPerBlock;
|
||||
|
||||
@ -1311,8 +1312,9 @@ IPluginV2* GPTAttentionPluginCreator::createPlugin(char const* name, PluginField
|
||||
auto* obj = new GPTAttentionPlugin(p.getScalar<int32_t>("layer_idx").value(),
|
||||
p.getScalar<int32_t>("num_heads").value(), p.getScalar<int32_t>("vision_start").value(),
|
||||
p.getScalar<int32_t>("vision_length").value(), p.getScalar<int32_t>("num_kv_heads").value(),
|
||||
p.getScalar<int32_t>("head_size").value(), p.getScalar<int32_t>("unidirectional").value(),
|
||||
p.getScalar<float>("q_scaling").value(), p.getScalar<float>("attn_logit_softcapping_scale").value(),
|
||||
p.getScalar<int32_t>("num_kv_heads_origin").value(), p.getScalar<int32_t>("head_size").value(),
|
||||
p.getScalar<int32_t>("unidirectional").value(), p.getScalar<float>("q_scaling").value(),
|
||||
p.getScalar<float>("attn_logit_softcapping_scale").value(),
|
||||
static_cast<PositionEmbeddingType>(p.getScalar<int8_t>("position_embedding_type").value()),
|
||||
p.getScalar<int32_t>("rotary_embedding_dim").value(), p.getScalar<float>("rotary_embedding_base").value(),
|
||||
static_cast<RotaryScalingType>(p.getScalar<int8_t>("rotary_embedding_scale_type").value()),
|
||||
|
||||
@ -101,7 +101,7 @@ class GPTAttentionPlugin : public GPTAttentionPluginCommon
|
||||
{
|
||||
public:
|
||||
GPTAttentionPlugin(int layer_idx, int num_heads, int vision_start, int vision_length, int num_kv_heads,
|
||||
int head_size, int unidirectional, float q_scaling, float attn_logit_softcapping_scale,
|
||||
int num_kv_heads_origin, int head_size, int unidirectional, float q_scaling, float attn_logit_softcapping_scale,
|
||||
tensorrt_llm::kernels::PositionEmbeddingType position_embedding_type,
|
||||
int rotary_embedding_dim, // for RoPE. 0 for non-RoPE
|
||||
float rotary_embedding_base, tensorrt_llm::kernels::RotaryScalingType rotary_embedding_scale_type,
|
||||
|
||||
@ -236,6 +236,8 @@ def convert_and_save_hf(args):
|
||||
load_model_on_cpu=args.load_model_on_cpu,
|
||||
**override_fields)
|
||||
glm.config.mapping.cp_size = args.cp_size
|
||||
glm.config.mapping.attn_tp_size = -1
|
||||
glm.config.mapping.attn_cp_size = -1
|
||||
glm.config.mapping.world_size *= args.cp_size
|
||||
glm.save_checkpoint(args.output_dir, save_config=(rank == 0))
|
||||
del glm
|
||||
|
||||
@ -393,6 +393,8 @@ def convert_and_save_meta(args, rank):
|
||||
use_parallel_embedding=args.use_parallel_embedding,
|
||||
embedding_sharding_dim=args.embedding_sharding_dim)
|
||||
llama.config.mapping.cp_size = args.cp_size
|
||||
llama.config.mapping.attn_tp_size = -1
|
||||
llama.config.mapping.attn_cp_size = -1
|
||||
llama.config.mapping.world_size *= args.cp_size
|
||||
llama.save_checkpoint(args.output_dir, save_config=(rank == 0))
|
||||
|
||||
@ -511,6 +513,8 @@ def convert_and_save_hf(args):
|
||||
f'Total time of reading and converting: {time.time()-tik:.3f} s'
|
||||
)
|
||||
llama.config.mapping.cp_size = args.cp_size
|
||||
llama.config.mapping.attn_tp_size = -1
|
||||
llama.config.mapping.attn_cp_size = -1
|
||||
llama.config.mapping.world_size *= args.cp_size
|
||||
tik = time.time()
|
||||
llama.save_checkpoint(args.output_dir, save_config=(rank == 0))
|
||||
|
||||
@ -281,6 +281,8 @@ def convert_and_save_hf(args):
|
||||
quant_config=quant_config,
|
||||
**override_fields)
|
||||
qwen.config.mapping.cp_size = args.cp_size
|
||||
qwen.config.mapping.attn_tp_size = -1
|
||||
qwen.config.mapping.attn_cp_size = -1
|
||||
qwen.config.mapping.world_size *= args.cp_size
|
||||
qwen.save_checkpoint(args.output_dir, save_config=(rank == 0))
|
||||
del qwen
|
||||
|
||||
@ -5213,6 +5213,7 @@ def gpt_attention(
|
||||
cp_group: List[int] = [0],
|
||||
cp_size: int = 1,
|
||||
cp_rank: int = 0,
|
||||
num_kv_heads_origin: int = -1,
|
||||
) -> Tuple[Tensor, Optional[Tensor]]:
|
||||
'''
|
||||
Add an operation that performs the multi-head attention in GPT-like models.
|
||||
@ -5474,6 +5475,9 @@ def gpt_attention(
|
||||
skip_attn: Tensor = None,
|
||||
A bool tensor on CPU. If it is true, don't run attention plugin, returning directly.
|
||||
|
||||
num_kv_heads_origin: int
|
||||
The origin number of KV heads, without the process of TP
|
||||
|
||||
Returns:
|
||||
The tensor produced by that layer.
|
||||
'''
|
||||
@ -5505,6 +5509,9 @@ def gpt_attention(
|
||||
else:
|
||||
use_logn_scaling = 0
|
||||
|
||||
if num_kv_heads_origin < 1:
|
||||
num_kv_heads_origin = num_kv_heads
|
||||
|
||||
unfuse_qkv_gemm = trt.PluginField(
|
||||
"unfuse_qkv_gemm", np.array(np.int8(is_unfuse_qkv_gemm), dtype=np.int8),
|
||||
trt.PluginFieldType.INT8)
|
||||
@ -5523,6 +5530,9 @@ def gpt_attention(
|
||||
num_kv_heads = trt.PluginField("num_kv_heads",
|
||||
np.array(num_kv_heads, dtype=np.int32),
|
||||
trt.PluginFieldType.INT32)
|
||||
num_kv_heads_origin = trt.PluginField(
|
||||
"num_kv_heads_origin", np.array(num_kv_heads_origin, dtype=np.int32),
|
||||
trt.PluginFieldType.INT32)
|
||||
head_size = trt.PluginField("head_size",
|
||||
np.array(hidden_size_per_head, dtype=np.int32),
|
||||
trt.PluginFieldType.INT32)
|
||||
@ -5714,9 +5724,10 @@ def gpt_attention(
|
||||
trt.PluginFieldType.INT8)
|
||||
|
||||
pfc = trt.PluginFieldCollection([
|
||||
layer_idx, nheads, vision_start, vision_length, num_kv_heads, head_size,
|
||||
unidirectional, q_scaling, attn_logit_softcapping_scale,
|
||||
position_embedding_type, rotary_embedding_dim, rotary_embedding_base,
|
||||
layer_idx, nheads, vision_start, vision_length, num_kv_heads,
|
||||
num_kv_heads_origin, head_size, unidirectional, q_scaling,
|
||||
attn_logit_softcapping_scale, position_embedding_type,
|
||||
rotary_embedding_dim, rotary_embedding_base,
|
||||
rotary_embedding_scale_type, rotary_embedding_scale,
|
||||
rotary_embedding_short_m_scale, rotary_embedding_long_m_scale,
|
||||
rotary_embedding_max_positions, rotary_embedding_original_max_positions,
|
||||
|
||||
@ -395,13 +395,13 @@ class Attention(Module):
|
||||
self.cross_attention = cross_attention
|
||||
self.attention_mask_type = attention_mask_type
|
||||
self.attention_head_size = hidden_size // num_attention_heads if attention_head_size is None else attention_head_size
|
||||
self.num_kv_heads = num_kv_heads
|
||||
assert num_attention_heads % tp_size == 0, \
|
||||
"num_attention_heads must be divisible by tp_size"
|
||||
self.num_attention_heads = num_attention_heads // tp_size
|
||||
self.num_attention_kv_heads = (
|
||||
num_kv_heads + tp_size - 1
|
||||
) // tp_size if num_kv_heads is not None else self.num_attention_heads
|
||||
self.num_kv_heads = num_kv_heads if num_kv_heads is not None else self.num_attention_heads
|
||||
self.hidden_size = hidden_size
|
||||
self.attention_hidden_size = self.attention_head_size * self.num_attention_heads
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
@ -1058,6 +1058,7 @@ class Attention(Module):
|
||||
layer_idx=self.local_layer_idx,
|
||||
num_heads=self.num_attention_heads,
|
||||
num_kv_heads=self.num_attention_kv_heads,
|
||||
num_kv_heads_origin=self.num_kv_heads,
|
||||
hidden_size_per_head=self.attention_head_size,
|
||||
q_scaling=self.q_scaling,
|
||||
rotary_embedding_dim=self.rotary_embedding_dim,
|
||||
@ -1866,6 +1867,7 @@ class CogVLMAttention(Attention):
|
||||
layer_idx=self.local_layer_idx,
|
||||
num_heads=self.num_attention_heads,
|
||||
num_kv_heads=self.num_attention_kv_heads,
|
||||
num_kv_heads_origin=self.num_kv_heads,
|
||||
hidden_size_per_head=self.attention_head_size,
|
||||
q_scaling=self.q_scaling,
|
||||
position_embedding_type=self.position_embedding_type,
|
||||
@ -2216,6 +2218,7 @@ class DeepseekV2Attention(Attention):
|
||||
layer_idx=self.local_layer_idx,
|
||||
num_heads=self.num_attention_heads,
|
||||
num_kv_heads=1,
|
||||
num_kv_heads_origin=1,
|
||||
hidden_size_per_head=self.kv_lora_rank + self.qk_rope_head_dim,
|
||||
q_scaling=self.q_scaling,
|
||||
position_embedding_type=self.position_embedding_type,
|
||||
|
||||
@ -123,6 +123,8 @@ class Mapping(object):
|
||||
pp_size=1,
|
||||
moe_tp_size=-1, # -1 means no moe
|
||||
moe_ep_size=-1, # -1 means no moe
|
||||
attn_tp_size=-1,
|
||||
attn_cp_size=-1,
|
||||
auto_parallel=False,
|
||||
enable_attention_dp=False):
|
||||
# set default values for non-moe cases
|
||||
@ -137,6 +139,22 @@ class Mapping(object):
|
||||
elif moe_ep_size == -1:
|
||||
moe_ep_size = tp_size // moe_tp_size
|
||||
|
||||
if attn_tp_size == -1 and attn_cp_size == -1:
|
||||
# fallback to ulysses
|
||||
attn_tp_size = tp_size * cp_size
|
||||
attn_cp_size = 1
|
||||
|
||||
elif attn_tp_size == -1:
|
||||
attn_tp_size = cp_size * tp_size // attn_cp_size
|
||||
|
||||
elif attn_cp_size == -1:
|
||||
attn_cp_size = cp_size * tp_size // attn_tp_size
|
||||
|
||||
if attn_cp_size != 1:
|
||||
raise ValueError(
|
||||
f"attn_cp_size must be 1 for now, but got {attn_tp_size}, {attn_cp_size}."
|
||||
)
|
||||
|
||||
if auto_parallel:
|
||||
if tp_size != 1 or pp_size != 1 or tp_size != 1:
|
||||
raise ValueError(
|
||||
@ -154,6 +172,12 @@ class Mapping(object):
|
||||
f"tp_size must equal to moe_tp_size * moe_ep_size, but got {tp_size} != {moe_tp_size} * {moe_ep_size}"
|
||||
)
|
||||
|
||||
attn_tp_cp_size = attn_tp_size * attn_cp_size
|
||||
if attn_tp_cp_size != tp_size * cp_size:
|
||||
raise ValueError(
|
||||
f"tp_size * cp_size must equal to attn_tp_size * attn_cp_size, but got {tp_size} * {cp_size} != {attn_tp_size} * {attn_cp_size}"
|
||||
)
|
||||
|
||||
if moe_ep_size != 1 and cp_size > 1:
|
||||
raise NotImplementedError("CP don't support MoE tp/ep yet")
|
||||
|
||||
@ -163,6 +187,8 @@ class Mapping(object):
|
||||
self.pp_size = pp_size
|
||||
self.moe_tp_size = moe_tp_size
|
||||
self.moe_ep_size = moe_ep_size
|
||||
self.attn_tp_size = attn_tp_size
|
||||
self.attn_cp_size = attn_cp_size
|
||||
self.auto_parallel = auto_parallel
|
||||
self.world_size = world_size
|
||||
self.rank = rank
|
||||
@ -218,6 +244,8 @@ class Mapping(object):
|
||||
and self.pp_size == other.pp_size
|
||||
and self.moe_tp_size == other.moe_tp_size
|
||||
and self.moe_ep_size == other.moe_ep_size
|
||||
and self.attn_tp_size == other.attn_tp_size
|
||||
and self.attn_cp_size == other.attn_cp_size
|
||||
and self.auto_parallel == other.auto_parallel)
|
||||
|
||||
def __hash__(self):
|
||||
@ -225,6 +253,7 @@ class Mapping(object):
|
||||
^ hash(self.gpus_per_node) ^ hash(self.cp_size)
|
||||
^ hash(self.tp_size) ^ hash(self.pp_size)
|
||||
^ hash(self.moe_tp_size) ^ hash(self.moe_ep_size)
|
||||
^ hash(self.attn_tp_size) ^ hash(self.attn_cp_size)
|
||||
^ hash(self.auto_parallel))
|
||||
|
||||
@property
|
||||
@ -375,5 +404,7 @@ class Mapping(object):
|
||||
'pp_size': self.pp_size,
|
||||
'moe_tp_size': self.moe_tp_size,
|
||||
'moe_ep_size': self.moe_ep_size,
|
||||
'attn_tp_size': self.attn_tp_size,
|
||||
'attn_cp_size': self.attn_cp_size,
|
||||
'auto_parallel': self.auto_parallel,
|
||||
}
|
||||
|
||||
@ -1850,6 +1850,7 @@ class Fp8RowwiseAttention(Module):
|
||||
self.attention_mask_type = attention_mask_type
|
||||
self.attention_head_size = hidden_size // num_attention_heads if attention_head_size is None else attention_head_size
|
||||
self.num_attention_heads = num_attention_heads // tp_size
|
||||
self.num_kv_heads = num_kv_heads
|
||||
self.num_attention_kv_heads = (
|
||||
num_kv_heads + tp_size - 1
|
||||
) // tp_size if num_kv_heads is not None else self.num_attention_heads
|
||||
@ -2010,6 +2011,7 @@ class Fp8RowwiseAttention(Module):
|
||||
layer_idx=self.local_layer_idx,
|
||||
num_heads=self.num_attention_heads,
|
||||
num_kv_heads=self.num_attention_kv_heads,
|
||||
num_kv_heads_origin=self.num_kv_heads,
|
||||
hidden_size_per_head=self.attention_head_size,
|
||||
q_scaling=self.q_scaling,
|
||||
rotary_embedding_dim=self.rotary_embedding_dim,
|
||||
@ -2467,6 +2469,7 @@ class SmoothQuantAttention(Module):
|
||||
self.attention_mask_type = attention_mask_type
|
||||
self.attention_head_size = hidden_size // num_attention_heads if attention_head_size is None else attention_head_size
|
||||
self.num_attention_heads = num_attention_heads // tp_size
|
||||
self.num_kv_heads = num_kv_heads
|
||||
self.num_attention_kv_heads = (
|
||||
num_kv_heads + tp_size - 1
|
||||
) // tp_size if num_kv_heads is not None else self.num_attention_heads
|
||||
@ -2634,6 +2637,7 @@ class SmoothQuantAttention(Module):
|
||||
layer_idx=self.local_layer_idx,
|
||||
num_heads=self.num_attention_heads,
|
||||
num_kv_heads=self.num_attention_kv_heads,
|
||||
num_kv_heads_origin=self.num_kv_heads,
|
||||
hidden_size_per_head=self.attention_head_size,
|
||||
q_scaling=self.q_scaling,
|
||||
rotary_embedding_dim=self.rotary_embedding_dim,
|
||||
@ -2987,6 +2991,7 @@ class QServeAttention(Module):
|
||||
self.attention_mask_type = attention_mask_type
|
||||
self.attention_head_size = hidden_size // num_attention_heads if attention_head_size is None else attention_head_size
|
||||
self.num_attention_heads = num_attention_heads // tp_size
|
||||
self.num_kv_heads = num_kv_heads
|
||||
self.num_attention_kv_heads = (
|
||||
num_kv_heads + tp_size - 1
|
||||
) // tp_size if num_kv_heads is not None else self.num_attention_heads
|
||||
@ -3154,6 +3159,7 @@ class QServeAttention(Module):
|
||||
layer_idx=self.local_layer_idx,
|
||||
num_heads=self.num_attention_heads,
|
||||
num_kv_heads=self.num_attention_kv_heads,
|
||||
num_kv_heads_origin=self.num_kv_heads,
|
||||
hidden_size_per_head=self.attention_head_size,
|
||||
q_scaling=self.q_scaling,
|
||||
rotary_embedding_dim=self.rotary_embedding_dim,
|
||||
|
||||
@ -905,6 +905,8 @@ def quantize_and_export(*,
|
||||
with open(f"{export_path}/config.json", "r") as f:
|
||||
tensorrt_llm_config = json.load(f)
|
||||
tensorrt_llm_config["mapping"]["cp_size"] = cp_size
|
||||
tensorrt_llm_config["mapping"]["attn_tp_size"] = -1
|
||||
tensorrt_llm_config["mapping"]["attn_cp_size"] = -1
|
||||
tensorrt_llm_config["mapping"]["world_size"] *= cp_size
|
||||
with open(f"{export_path}/config.json", "w") as f:
|
||||
json.dump(tensorrt_llm_config, f, indent=4)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user