mirror of https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00

fix: fix for cp > kvHeadNum (#3002)

* fix for cp > kvHeadNum

  Signed-off-by: Dylan Chen <191843203+DylanChen-NV@users.noreply.github.com>

* fix for None kv_head_num

  Signed-off-by: Dylan Chen <191843203+DylanChen-NV@users.noreply.github.com>

---------

Signed-off-by: Dylan Chen <191843203+DylanChen-NV@users.noreply.github.com>

This commit is contained in:
parent 25f2434495
commit 1ac0566a93
@@ -166,9 +166,8 @@ bool AttentionOp::convertMMHAParamsToXQAParams(tensorrt_llm::kernels::XQAParams&
     xqaParams = {};
     xqaParams.data_type = ConvertMMHAToXQAParamsHelper<T, KVCacheBuffer>::data_type;
 
-    // TODO(ziqingc): A better description for these parameters affected by CP size
-    xqaParams.num_q_heads = mNumHeads / mCpSize;    // when we use CP, the MHA part is spilt like TP
-    xqaParams.num_kv_heads = mNumKVHeads / mCpSize; // when we use CP, the MHA part is spilt like TP
+    xqaParams.num_q_heads = mNumAttnHeads;
+    xqaParams.num_kv_heads = mNumAttnKVHeads;
     xqaParams.head_size = mHeadSize;
     xqaParams.unidirectional = mUnidirectional;
     xqaParams.q_scaling = mQScaling;
@@ -265,15 +264,151 @@ bool AttentionOp::convertMMHAParamsToXQAParams(tensorrt_llm::kernels::XQAParams&
     return true;
 }
 
+template <typename T>
+int AttentionOp::ulyssesContextPreprocess(T const* input, T* output, T* buffer, EnqueueContextParams<T> const& params,
+    int const* cu_q_seqlens, int const* cu_cp_partial_seqlens, cudaStream_t stream)
+{
+    int32_t partialTokenNum = 0;
+    int32_t maxPartialLength = 0;
+    for (int32_t batchIdx = 0; batchIdx < params.batch_size; ++batchIdx)
+    {
+        int32_t partialLength = (params.host_context_lengths[batchIdx] + mCpSize - 1) / mCpSize;
+        maxPartialLength = std::max(maxPartialLength, partialLength);
+        partialTokenNum += partialLength;
+    }
+    auto const partialHeads = mNumAttnHeads + 2 * mNumAttnKVHeads;
+
+    // full request: [bs, seqlen, head, headSize]
+    //
+    // input of cp: [bs, partialLength, head, headSize]
+    // view_1 as [bs, partialLength, cpSize_Head, partialHead, headSize]
+    // transpose_1 as [cpSize_Head, bs, partialLenth, partialHead, headSize]
+    // all-to-all to get [cpSize_Length, bs, partialLength, partialHead, headSize]
+    // transpose_2 to [bs, cpSize_Length, partialLength, partialHead, headSize]
+    // view_2 as [bs, totalLength, partialHead, headSize]
+    // and this is same to the input under TP.
+    //
+    // when we use remove_input_padding, bs and length are fused into numTokens. So, we need to
+    // insert the cpSize_Length dimension of transpose_2 into numTokens directly like
+    // input of cp: [partialNumTokens, head, headSize]
+    // view_1 as [partialNumTokens, cpSize_Head, partialHead, headSize]
+    // transpose_1 as [cpSize_Head, partialNumTokens, partialHead, headSize]
+    // all-to-all to get [cpSize_Length, partialNumTokens, partialHead, headSize]
+    // transpose_2 as [NumTokens, partialHead, headSize]
+    // and this is same to the input under TP.
+
+    // view_1 + transpose_1
+    invokeCpTranspose(output, buffer, input, partialTokenNum, mCpSize, mNumAttnHeads, mNumAttnKVHeads,
+        mUlyssesMQABroadcast, getHeadSize(), mCpRank, stream);
+    sync_check_cuda_error(stream);
+
+    // Do all to all
+#if ENABLE_MULTI_DEVICE
+    ncclGroupStart();
+    for (int cpIdx = 0; cpIdx < mCpSize; cpIdx++)
+    {
+        if (cpIdx != mCpRank)
+        {
+            NCCLCHECK(ncclSend(output + cpIdx * (partialTokenNum * getHeadSize() * partialHeads),
+                (partialTokenNum * getHeadSize() * partialHeads), (*getDtypeMap())[mType], cpIdx, *mCpNcclComm,
+                stream));
+            NCCLCHECK(ncclRecv(buffer + cpIdx * (partialTokenNum * getHeadSize() * partialHeads),
+                (partialTokenNum * getHeadSize() * partialHeads), (*getDtypeMap())[mType], cpIdx, *mCpNcclComm,
+                stream));
+        }
+    }
+    ncclGroupEnd();
+    sync_check_cuda_error(stream);
+#endif // ENABLE_MULTI_DEVICE
+
+    // transpose_2 + view_2
+    invokeCpTranspose2(output, buffer, params.context_lengths, cu_q_seqlens, cu_cp_partial_seqlens, mCpSize,
+        maxPartialLength, params.batch_size, partialHeads, getHeadSize(), stream);
+
+    return 0;
+}
+
+template <typename T>
+int AttentionOp::ulyssesContextPostprocess(T* input, T* output, T* buffer, EnqueueContextParams<T> const& params,
+    int const* cu_q_seqlens, int const* cu_cp_partial_seqlens, cudaStream_t stream)
+{
+    // After FMHA, we get result [numTokens(bs, cp, paritalLength), partialHead, headSize]
+    // transpose_2_reverse: [cpSize_Length, partialTokens(bs, partialLength), partialHead, headSize]
+    // all-to-all: [cpSize_Head, partialTokens, partialHead, headSize]
+    // transpose_1_reverse: [partialTokens, cpSize_Head, partialHead, headSize]
+    // view: [partialTokens, head, headSize]
+
+    int32_t maxPartialLength = 0;
+    int32_t partialTokenNum = 0;
+    for (int32_t batchIdx = 0; batchIdx < params.batch_size; ++batchIdx)
+    {
+        int32_t partialLength = (params.host_context_lengths[batchIdx] + mCpSize - 1) / mCpSize;
+        maxPartialLength = std::max(maxPartialLength, partialLength);
+        partialTokenNum += partialLength;
+    }
+
+    // transpose_2_reverse
+    if (mFP8ContextFMHA)
+    {
+        invokeCpTransposeToSeqMajor2(reinterpret_cast<__nv_fp8_e4m3*>(buffer),
+            reinterpret_cast<__nv_fp8_e4m3 const*>(input), params.context_lengths, cu_q_seqlens, cu_cp_partial_seqlens,
+            mCpSize, maxPartialLength, params.batch_size, mNumAttnHeads, getHeadSize(), stream);
+    }
+    else
+    {
+        invokeCpTransposeToSeqMajor2(buffer, input, params.context_lengths, cu_q_seqlens, cu_cp_partial_seqlens,
+            mCpSize, maxPartialLength, params.batch_size, mNumAttnHeads, getHeadSize(), stream);
+    }
+
+    // all-to-all
+#if ENABLE_MULTI_DEVICE
+    size_t const elementNum = partialTokenNum * getHeadSize() * mNumAttnHeads;
+    ncclGroupStart();
+    for (int cpIdx = 0; cpIdx < mCpSize; cpIdx++)
+    {
+        if (cpIdx != mCpRank)
+        {
+            if (mFP8ContextFMHA)
+            {
+                NCCLCHECK(ncclSend(reinterpret_cast<__nv_fp8_e4m3*>(buffer) + cpIdx * elementNum, elementNum, ncclInt8,
+                    cpIdx, *mCpNcclComm, stream));
+                NCCLCHECK(ncclRecv(reinterpret_cast<__nv_fp8_e4m3*>(input) + cpIdx * elementNum, elementNum, ncclInt8,
+                    cpIdx, *mCpNcclComm, stream));
+            }
+            else
+            {
+                NCCLCHECK(ncclSend(
+                    buffer + cpIdx * elementNum, elementNum, (*getDtypeMap())[mType], cpIdx, *mCpNcclComm, stream));
+                NCCLCHECK(ncclRecv(
+                    input + cpIdx * elementNum, elementNum, (*getDtypeMap())[mType], cpIdx, *mCpNcclComm, stream));
+            }
+        }
+    }
+    ncclGroupEnd();
+#endif // ENABLE_MULTI_DEVICE
+
+    // transpose_1_reverse + view
+    if (mFP8ContextFMHA)
+    {
+        invokeCpTransposeToSeqMajor<__nv_fp8_e4m3>(reinterpret_cast<__nv_fp8_e4m3*>(output),
+            reinterpret_cast<__nv_fp8_e4m3 const*>(buffer), reinterpret_cast<__nv_fp8_e4m3 const*>(input),
+            partialTokenNum, mCpSize, mNumAttnHeads, getHeadSize(), mCpRank, stream);
+    }
+    else
+    {
+        invokeCpTransposeToSeqMajor<T>(
+            (T*) output, buffer, input, partialTokenNum, mCpSize, mNumAttnHeads, getHeadSize(), mCpRank, stream);
+    }
+    return 0;
+}
+
 template <typename T>
 int AttentionOp::ulyssesGenerationPreprocess(
-    int32_t batch_beam, T* mhaInput, T* mhaOutput, T*& input, cudaStream_t stream)
+    T const* input, T* output, T* buffer, int32_t batch_beam, cudaStream_t stream)
 {
     if (mCpSize <= 1)
         return 0;
 
-    auto const partialQHeads = mNumHeads / mCpSize;
-    auto const partialKVHeads = mNumKVHeads / mCpSize;
     auto const partialTokenNum = (batch_beam + mCpSize - 1) / mCpSize;
 
     // attention_input shape: [partialTokenNum, numHeads, headSize]
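
Aside (not part of the commit): the comment block inside the new ulyssesContextPreprocess describes the view_1 / transpose_1 / all-to-all / transpose_2 / view_2 sequence purely in terms of tensor shapes. The standalone C++ sketch below restates only the host-side shape bookkeeping under simple assumptions (ceil-division split of every sequence across the CP ranks); values such as cpSize = 4 and the two example sequence lengths are illustrative and do not come from the commit.

// Minimal host-side sketch of the shape bookkeeping done by ulyssesContextPreprocess.
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

int main()
{
    // Hypothetical example values, not taken from the commit.
    int const cpSize = 4;
    int const numAttnHeads = 8;   // query heads after the attention-TP split
    int const numAttnKVHeads = 1; // KV heads after the attention-TP split (MQA-like case)
    int const headSize = 128;
    std::vector<int32_t> hostContextLengths = {30, 17};

    int32_t partialTokenNum = 0;
    int32_t maxPartialLength = 0;
    for (int32_t length : hostContextLengths)
    {
        int32_t partialLength = (length + cpSize - 1) / cpSize; // ceil division: tokens held per CP rank
        maxPartialLength = std::max(maxPartialLength, partialLength);
        partialTokenNum += partialLength;
    }

    // Each rank exchanges a [partialTokenNum, partialHeads, headSize] chunk with every peer.
    auto const partialHeads = numAttnHeads + 2 * numAttnKVHeads;
    size_t const elementsPerPeer = static_cast<size_t>(partialTokenNum) * partialHeads * headSize;
    std::printf("maxPartialLength=%d partialTokenNum=%d elementsPerPeer=%zu\n",
        maxPartialLength, partialTokenNum, elementsPerPeer);
    return 0;
}

The elementsPerPeer count corresponds to the per-peer chunk size used in the grouped ncclSend/ncclRecv calls of the function above.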
@@ -284,27 +419,26 @@ int AttentionOp::ulyssesGenerationPreprocess(
 
     // do transpose_1
     // [1, mNumHeads + 2*mNumKVHeads, headSize]
-    // -> (view) [1, cpSize * partialQHeads + cpSize * partialKVHeads + cpSize * partilKVHeads,
+    // -> (view) [1, cpSize * mNumAttnHeads + cpSize * mNumAttnKVHeads + cpSize * partilKVHeads,
     // headSize]
-    // -> (transpose) [cpSize, 1, partialQHeads + partialKvHeads + partialKVHeads, headSize]
-    invokeCpTranspose(mhaOutput, mhaInput, input, partialTokenNum, mCpSize, partialQHeads, partialKVHeads, mHeadSize,
-        mCpRank, stream);
+    // -> (transpose) [cpSize, 1, mNumAttnHeads + mNumAttnKVHeads + mNumAttnKVHeads, headSize]
+    invokeCpTranspose(buffer, output, input, partialTokenNum, mCpSize, mNumAttnHeads, mNumAttnKVHeads,
+        mUlyssesMQABroadcast, mHeadSize, mCpRank, stream);
     sync_check_cuda_error(stream);
 
     // Do all to all
 #if ENABLE_MULTI_DEVICE
-    auto const totalHeads = mNumHeads + 2 * mNumKVHeads;
-    auto const partialHeads = totalHeads / mCpSize;
+    auto const partialHeads = mNumAttnHeads + 2 * mNumAttnKVHeads;
 
     ncclGroupStart();
     for (int cpIdx = 0; cpIdx < mCpSize; cpIdx++)
     {
         if (cpIdx != mCpRank)
         {
-            NCCLCHECK(ncclSend(mhaOutput + cpIdx * (partialTokenNum * getHeadSize() * partialHeads),
+            NCCLCHECK(ncclSend(buffer + cpIdx * (partialTokenNum * getHeadSize() * partialHeads),
                 (partialTokenNum * getHeadSize() * partialHeads), (*getDtypeMap())[mType], cpIdx, *mCpNcclComm,
                 stream));
-            NCCLCHECK(ncclRecv(mhaInput + cpIdx * (partialTokenNum * getHeadSize() * partialHeads),
+            NCCLCHECK(ncclRecv(output + cpIdx * (partialTokenNum * getHeadSize() * partialHeads),
                 (partialTokenNum * getHeadSize() * partialHeads), (*getDtypeMap())[mType], cpIdx, *mCpNcclComm,
                 stream));
         }
@@ -312,14 +446,11 @@ int AttentionOp::ulyssesGenerationPreprocess(
     ncclGroupEnd();
     sync_check_cuda_error(stream);
 #endif // ENABLE_MULTI_DEVICE
 
-    input = mhaInput;
     return 0;
 }
 
 template <typename T>
-int AttentionOp::ulyssesGenerationPostprocess(
-    int32_t batch_beam, T* mhaInput, T* mhaOutput, void* output, cudaStream_t stream)
+int AttentionOp::ulyssesGenerationPostprocess(T* input, T* output, T* buffer, int32_t batch_beam, cudaStream_t stream)
 {
     if (mCpSize <= 1)
         return 0;
@@ -330,12 +461,11 @@ int AttentionOp::ulyssesGenerationPostprocess(
     // transpose_1_reverse: [partialTokens, cpSize_Head, partialHead, headSize]
     // view: [partialTokens, head, headSize]
 
-    auto partialHeads = mNumHeads / mCpSize;
     auto const partialTokenNum = (batch_beam + mCpSize - 1) / mCpSize;
 
     // do all-to-all
 #if ENABLE_MULTI_DEVICE
-    size_t const elementNum = partialTokenNum * getHeadSize() * partialHeads;
+    size_t const elementNum = partialTokenNum * getHeadSize() * mNumAttnHeads;
     ncclGroupStart();
     for (int cpIdx = 0; cpIdx < mCpSize; cpIdx++)
     {
@@ -343,17 +473,17 @@ int AttentionOp::ulyssesGenerationPostprocess(
         {
             if (mFP8ContextFMHA)
             {
-                NCCLCHECK(ncclSend(reinterpret_cast<__nv_fp8_e4m3*>(mhaOutput) + cpIdx * elementNum, elementNum,
-                    ncclInt8, cpIdx, *mCpNcclComm, stream));
-                NCCLCHECK(ncclRecv(reinterpret_cast<__nv_fp8_e4m3*>(mhaInput) + cpIdx * elementNum, elementNum,
-                    ncclInt8, cpIdx, *mCpNcclComm, stream));
+                NCCLCHECK(ncclSend(reinterpret_cast<__nv_fp8_e4m3*>(input) + cpIdx * elementNum, elementNum, ncclInt8,
+                    cpIdx, *mCpNcclComm, stream));
+                NCCLCHECK(ncclRecv(reinterpret_cast<__nv_fp8_e4m3*>(buffer) + cpIdx * elementNum, elementNum, ncclInt8,
+                    cpIdx, *mCpNcclComm, stream));
             }
             else
            {
                 NCCLCHECK(ncclSend(
-                    mhaOutput + cpIdx * elementNum, elementNum, (*getDtypeMap())[mType], cpIdx, *mCpNcclComm, stream));
+                    input + cpIdx * elementNum, elementNum, (*getDtypeMap())[mType], cpIdx, *mCpNcclComm, stream));
                 NCCLCHECK(ncclRecv(
-                    mhaInput + cpIdx * elementNum, elementNum, (*getDtypeMap())[mType], cpIdx, *mCpNcclComm, stream));
+                    buffer + cpIdx * elementNum, elementNum, (*getDtypeMap())[mType], cpIdx, *mCpNcclComm, stream));
             }
         }
     }
@@ -364,15 +494,14 @@ int AttentionOp::ulyssesGenerationPostprocess(
     if (mFP8ContextFMHA)
     {
         invokeCpTransposeToSeqMajor<__nv_fp8_e4m3>(reinterpret_cast<__nv_fp8_e4m3*>(output),
-            reinterpret_cast<__nv_fp8_e4m3 const*>(mhaOutput), reinterpret_cast<__nv_fp8_e4m3 const*>(mhaInput),
-            partialTokenNum, mCpSize, partialHeads, getHeadSize(), mCpRank, stream);
+            reinterpret_cast<__nv_fp8_e4m3 const*>(input), reinterpret_cast<__nv_fp8_e4m3 const*>(buffer),
+            partialTokenNum, mCpSize, mNumAttnHeads, getHeadSize(), mCpRank, stream);
     }
     else
     {
         invokeCpTransposeToSeqMajor<T>(
-            (T*) output, mhaOutput, mhaInput, partialTokenNum, mCpSize, partialHeads, getHeadSize(), mCpRank, stream);
+            (T*) output, input, buffer, partialTokenNum, mCpSize, mNumAttnHeads, getHeadSize(), mCpRank, stream);
     }
-    sync_check_cuda_error(stream);
     return 0;
 }
 
@@ -532,8 +661,8 @@ int AttentionOp::getHeadSize(bool checkInit) const
 size_t AttentionOp::getWorkspaceSizeForContext(nvinfer1::DataType type, int32_t max_num_seq, int32_t input_seq_length,
     int32_t cross_kv_length, int32_t max_num_tokens) const noexcept
 {
-    int const local_hidden_units_qo = mNumHeads / mCpSize * getHeadSize();
-    int const local_hidden_units_kv = mNumKVHeads / mCpSize * getHeadSize();
+    int const local_hidden_units_qo = mNumAttnHeads * getHeadSize();
+    int const local_hidden_units_kv = mNumAttnKVHeads * getHeadSize();
 
     auto const size = tensorrt_llm::runtime::BufferDataType(type).getSize();
 
@@ -815,9 +944,8 @@ int AttentionOp::mlaGeneration(
     tllmRunnerParams.mHeadDimQk = mMLAParams.kv_lora_rank + mMLAParams.qk_rope_head_dim;
     tllmRunnerParams.mHeadDimV = mMLAParams.kv_lora_rank;
 
-    auto const num_q_heads = mNumHeads / mCpSize;
+    auto const num_q_heads = mNumAttnHeads;
     tllmRunnerParams.mNumHeadsQ = num_q_heads;
-    // const auto num_kv_heads_tllm_runner_params = mNumKVHeads / mCpSize;
     tllmRunnerParams.mNumHeadsKv = num_kv_heads;
     tllmRunnerParams.mNumHeadsQPerKv = num_q_heads / num_kv_heads;
 
@@ -1007,13 +1135,13 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
     int const headSize = getHeadSize();
 
     int const local_hidden_units_qo = mNumHeads * headSize;
-    int const local_hidden_units_kv = mNumKVHeads / mCpSize * headSize;
+    int const local_hidden_units_kv = mNumAttnKVHeads * headSize;
     PositionEmbeddingType const position_embedding_type = mPositionEmbeddingType;
     float const q_scaling = mQScaling;
 
     KVCacheBuffer kv_cache_buffer;
     auto const elemSize = mKVCacheQuantMode.hasKvCacheQuant() ? sizeof(int8_t) : sizeof(T);
-    auto sizePerToken = mNumKVHeads / mCpSize * headSize * elemSize;
+    auto sizePerToken = mNumAttnKVHeads * headSize * elemSize;
     if (useKVCache())
     {
         if constexpr (std::is_same_v<KVCacheBuffer, KVBlockArray>)
@@ -1252,72 +1380,13 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
     // do all-to-all for params.attention_input, need to split on kv head
     // [token_num // cp_size, kv_heads, head_size] -> [token_num, kv_heads // cp_size, head_size]
     T* attention_input = const_cast<T*>(params.attention_input);
-    if (mCpSize > 1)
+    if (mCpSize > 1 && mAttnTpSize > 1 && mAttnCpSize == 1)
     {
-        int32_t partialTokenNum = 0;
-        int32_t maxPartialLength = 0;
-        for (int32_t batchIdx = 0; batchIdx < params.batch_size; ++batchIdx)
-        {
-            int32_t partialLength = (params.host_context_lengths[batchIdx] + mCpSize - 1) / mCpSize;
-            maxPartialLength = std::max(maxPartialLength, partialLength);
-            partialTokenNum += partialLength;
-        }
-        auto const totalHeads = mNumHeads + 2 * mNumKVHeads;
-        auto const partialHeads = totalHeads / mCpSize;
-        auto const partialQHeads = mNumHeads / mCpSize;
-        auto const partialKVHeads = mNumKVHeads / mCpSize;
-
-        // full request: [bs, seqlen, head, headSize]
-        //
-        // input of cp: [bs, partialLength, head, headSize]
-        // view_1 as [bs, partialLength, cpSize_Head, partialHead, headSize]
-        // transpose_1 as [cpSize_Head, bs, partialLenth, partialHead, headSize]
-        // all-to-all to get [cpSize_Length, bs, partialLength, partialHead, headSize]
-        // transpose_2 to [bs, cpSize_Length, partialLength, partialHead, headSize]
-        // view_2 as [bs, totalLength, partialHead, headSize]
-        // and this is same to the input under TP.
-        //
-        // when we use remove_input_padding, bs and length are fused into numTokens. So, we need to
-        // insert the cpSize_Length dimension of transpose_2 into numTokens directly like
-        // input of cp: [partialNumTokens, head, headSize]
-        // view_1 as [partialNumTokens, cpSize_Head, partialHead, headSize]
-        // transpose_1 as [cpSize_Head, partialNumTokens, partialHead, headSize]
-        // all-to-all to get [cpSize_Length, partialNumTokens, partialHead, headSize]
-        // transpose_2 as [NumTokens, partialHead, headSize]
-        // and this is same to the input under TP.
-
-        // view_1 + transpose_1
-        invokeCpTranspose(gatherInBuffer, gatherOutBuffer, params.attention_input, partialTokenNum, mCpSize,
-            partialQHeads, partialKVHeads, getHeadSize(), mCpRank, stream);
-        sync_check_cuda_error(stream);
-
-        // Do all to all
-#if ENABLE_MULTI_DEVICE
-        ncclGroupStart();
-        for (int cpIdx = 0; cpIdx < mCpSize; cpIdx++)
-        {
-            if (cpIdx != mCpRank)
-            {
-                NCCLCHECK(ncclSend(gatherInBuffer + cpIdx * (partialTokenNum * getHeadSize() * partialHeads),
-                    (partialTokenNum * getHeadSize() * partialHeads), (*getDtypeMap())[mType], cpIdx, *mCpNcclComm,
-                    stream));
-                NCCLCHECK(ncclRecv(gatherOutBuffer + cpIdx * (partialTokenNum * getHeadSize() * partialHeads),
-                    (partialTokenNum * getHeadSize() * partialHeads), (*getDtypeMap())[mType], cpIdx, *mCpNcclComm,
-                    stream));
-            }
-        }
-        ncclGroupEnd();
-        sync_check_cuda_error(stream);
-#endif // ENABLE_MULTI_DEVICE
-
-        // transpose_2 + view_2
-        invokeCpTranspose2(gatherInBuffer, gatherOutBuffer, params.context_lengths, cu_q_seqlens,
-            cu_cp_partial_seqlens, mCpSize, maxPartialLength, params.batch_size, partialHeads, getHeadSize(),
-            stream);
-
+        this->template ulyssesContextPreprocess<T>(
+            attention_input, gatherInBuffer, gatherOutBuffer, params, cu_q_seqlens, cu_cp_partial_seqlens, stream);
         attention_input = gatherInBuffer;
+        sync_check_cuda_error(stream);
     }
-    sync_check_cuda_error(stream);
 
     bool const enablePagedKVContextFMHA = mPagedKVCache && mPagedContextFMHA;
     TLLM_CHECK_WITH_INFO(!(mKVCacheQuantMode.hasInt8KvCache() && enablePagedKVContextFMHA),
@@ -1363,9 +1432,9 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
     preprocessingParams.token_num = params.num_tokens;
     preprocessingParams.remove_padding = mRemovePadding;
     preprocessingParams.cross_attention = isCrossAttention();
-    preprocessingParams.head_num = mNumHeads / mCpSize;
-    preprocessingParams.kv_head_num = mNumKVHeads / mCpSize;
-    preprocessingParams.qheads_per_kv_head = mNumHeads / mNumKVHeads;
+    preprocessingParams.head_num = mNumAttnHeads;
+    preprocessingParams.kv_head_num = mNumAttnKVHeads;
+    preprocessingParams.qheads_per_kv_head = mNumAttnHeads / mNumAttnKVHeads;
     preprocessingParams.size_per_head = getHeadSize();
     preprocessingParams.rotary_embedding_dim = mRotaryEmbeddingDim;
     preprocessingParams.rotary_embedding_base = mRotaryEmbeddingBase;
@@ -1479,79 +1548,10 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
     invokeKvCachePostprocessing(preprocessingParams, stream);
     sync_check_cuda_error(stream);
 
-    if (mCpSize > 1)
+    if (mCpSize > 1 && mAttnTpSize > 1 && mAttnCpSize == 1)
     {
-        // After FMHA, we get result [numTokens(bs, cp, paritalLength), partialHead, headSize]
-        // transpose_2_reverse: [cpSize_Length, partialTokens(bs, partialLength), partialHead, headSize]
-        // all-to-all: [cpSize_Head, partialTokens, partialHead, headSize]
-        // transpose_1_reverse: [partialTokens, cpSize_Head, partialHead, headSize]
-        // view: [partialTokens, head, headSize]
-
-        int32_t maxPartialLength = 0;
-        int32_t partialTokenNum = 0;
-        for (int32_t batchIdx = 0; batchIdx < params.batch_size; ++batchIdx)
-        {
-            int32_t partialLength = (params.host_context_lengths[batchIdx] + mCpSize - 1) / mCpSize;
-            maxPartialLength = std::max(maxPartialLength, partialLength);
-            partialTokenNum += partialLength;
-        }
-        auto partialHeads = mNumHeads / mCpSize;
-
-        // transpose_2_reverse
-        if (mFP8ContextFMHA)
-        {
-            invokeCpTransposeToSeqMajor2(reinterpret_cast<__nv_fp8_e4m3*>(gatherInBuffer),
-                reinterpret_cast<__nv_fp8_e4m3 const*>(gatherOutBuffer), params.context_lengths, cu_q_seqlens,
-                cu_cp_partial_seqlens, mCpSize, maxPartialLength, params.batch_size, partialHeads, getHeadSize(),
-                stream);
-        }
-        else
-        {
-            invokeCpTransposeToSeqMajor2(gatherInBuffer, gatherOutBuffer, params.context_lengths, cu_q_seqlens,
-                cu_cp_partial_seqlens, mCpSize, maxPartialLength, params.batch_size, partialHeads, getHeadSize(),
-                stream);
-        }
-
-        // all-to-all
-#if ENABLE_MULTI_DEVICE
-        size_t const elementNum = partialTokenNum * getHeadSize() * partialHeads;
-        ncclGroupStart();
-        for (int cpIdx = 0; cpIdx < mCpSize; cpIdx++)
-        {
-            if (cpIdx != mCpRank)
-            {
-                if (mFP8ContextFMHA)
-                {
-                    NCCLCHECK(ncclSend(reinterpret_cast<__nv_fp8_e4m3*>(gatherInBuffer) + cpIdx * elementNum,
-                        elementNum, ncclInt8, cpIdx, *mCpNcclComm, stream));
-                    NCCLCHECK(ncclRecv(reinterpret_cast<__nv_fp8_e4m3*>(gatherOutBuffer) + cpIdx * elementNum,
-                        elementNum, ncclInt8, cpIdx, *mCpNcclComm, stream));
-                }
-                else
-                {
-                    NCCLCHECK(ncclSend(gatherInBuffer + cpIdx * elementNum, elementNum, (*getDtypeMap())[mType],
-                        cpIdx, *mCpNcclComm, stream));
-                    NCCLCHECK(ncclRecv(gatherOutBuffer + cpIdx * elementNum, elementNum, (*getDtypeMap())[mType],
-                        cpIdx, *mCpNcclComm, stream));
-                }
-            }
-        }
-        ncclGroupEnd();
-#endif // ENABLE_MULTI_DEVICE
-
-        // transpose_1_reverse + view
-        if (mFP8ContextFMHA)
-        {
-            invokeCpTransposeToSeqMajor<__nv_fp8_e4m3>(reinterpret_cast<__nv_fp8_e4m3*>(params.context_buf),
-                reinterpret_cast<__nv_fp8_e4m3 const*>(gatherInBuffer),
-                reinterpret_cast<__nv_fp8_e4m3 const*>(gatherOutBuffer), partialTokenNum, mCpSize, partialHeads,
-                getHeadSize(), mCpRank, stream);
-        }
-        else
-        {
-            invokeCpTransposeToSeqMajor<T>((T*) params.context_buf, gatherInBuffer, gatherOutBuffer,
-                partialTokenNum, mCpSize, partialHeads, getHeadSize(), mCpRank, stream);
-        }
+        this->template ulyssesContextPostprocess<T>(gatherOutBuffer, reinterpret_cast<T*>(params.context_buf),
+            gatherInBuffer, params, cu_q_seqlens, cu_cp_partial_seqlens, stream);
         sync_check_cuda_error(stream);
     }
 }
@@ -1871,7 +1871,7 @@ int AttentionOp::enqueueGeneration(EnqueueGenerationParams<T> const& params, cud
 
     KVCacheBuffer kv_cache_buffer;
     auto const elemSize = mKVCacheQuantMode.hasKvCacheQuant() ? sizeof(int8_t) : sizeof(T);
-    auto const sizePerToken = mNumKVHeads / mCpSize * headSize * elemSize;
+    auto const sizePerToken = mNumAttnKVHeads * headSize * elemSize;
     if (useKVCache())
     {
         if constexpr (std::is_same_v<KVCacheBuffer, KVBlockArray>)
@@ -1936,7 +1936,12 @@ int AttentionOp::enqueueGeneration(EnqueueGenerationParams<T> const& params, cud
     T* mhaInput = mhaOutput + cpMaxPaddedSequenceLength * (mNumHeads + 2 * mNumKVHeads) * mHeadSize;
 
     T* attention_input = const_cast<T*>(params.attention_input);
-    this->template ulyssesGenerationPreprocess<T>(batch_beam, mhaInput, mhaOutput, attention_input, stream);
+    if (mCpSize > 1 && mAttnTpSize > 1 && mAttnCpSize == 1)
+    {
+        this->template ulyssesGenerationPreprocess<T>(attention_input, mhaInput, mhaOutput, batch_beam, stream);
+        attention_input = mhaInput;
+        sync_check_cuda_error(stream);
+    }
 
     // Try XQA optimization first.
     {
@@ -1954,7 +1959,12 @@ int AttentionOp::enqueueGeneration(EnqueueGenerationParams<T> const& params, cud
             xqaParams.qkv = attention_input;
         }
         mXqaDispatcher->run(xqaParams, kv_cache_buffer);
-        this->template ulyssesGenerationPostprocess<T>(batch_beam, mhaInput, mhaOutput, params.context_buf, stream);
+        if (mCpSize > 1 && mAttnTpSize > 1 && mAttnCpSize == 1)
+        {
+            this->template ulyssesGenerationPostprocess<T>(
+                mhaOutput, reinterpret_cast<T*>(params.context_buf), mhaInput, batch_beam, stream);
+            sync_check_cuda_error(stream);
+        }
         return 0;
     }
     else if (mIsSpecDecodingEnabled && mUseSpecDecoding)
@@ -2040,8 +2050,8 @@ int AttentionOp::enqueueGeneration(EnqueueGenerationParams<T> const& params, cud
     dispatch_params.max_batch_size = batch_beam;
     dispatch_params.inference_batch_size = batch_beam;
     dispatch_params.beam_width = params.beam_width;
-    dispatch_params.head_num = mNumHeads / mCpSize;
-    dispatch_params.kv_head_num = mNumKVHeads / mCpSize;
+    dispatch_params.head_num = mNumAttnHeads;
+    dispatch_params.kv_head_num = mNumAttnKVHeads;
     dispatch_params.size_per_head = getHeadSize();
     dispatch_params.rotary_embedding_dim = mRotaryEmbeddingDim;
     dispatch_params.position_embedding_type = mPositionEmbeddingType;
@@ -2104,7 +2114,12 @@ int AttentionOp::enqueueGeneration(EnqueueGenerationParams<T> const& params, cud
         fusedQKV_masked_attention_dispatch(mmhca_params, dispatch_params, stream);
     }
 
-    this->template ulyssesGenerationPostprocess<T>(batch_beam, mhaInput, mhaOutput, params.context_buf, stream);
+    if (mCpSize > 1 && mAttnTpSize > 1 && mAttnCpSize == 1)
+    {
+        this->template ulyssesGenerationPostprocess<T>(
+            mhaOutput, reinterpret_cast<T*>(params.context_buf), mhaInput, batch_beam, stream);
+        sync_check_cuda_error(stream);
+    }
     return 0;
 }
 
@@ -2164,6 +2179,21 @@ template void AttentionOp::prepareEnqueueGeneration<__nv_bfloat16, KVBlockArray>
 
 int AttentionOp::initialize() noexcept
 {
+    // use Ulysses for GPTAttentionPlugin
+    if (mAttnTpSize < 0 || mAttnCpSize < 0)
+    {
+        mAttnTpSize = mTpSize * mCpSize;
+        mAttnCpSize = 1;
+    }
+    mNumAttnHeads = mNumHeads * mTpSize / mAttnTpSize;
+    mNumAttnKVHeads = (mNumKVHeads * mTpSize + mAttnTpSize - 1) / mAttnTpSize;
+
+    if (mCpSize != mAttnCpSize)
+    {
+        // mqa broadcast
+        mUlyssesMQABroadcast = (mAttnTpSize + mNumKVHeadsOrigin - 1) / mNumKVHeadsOrigin;
+    }
+
     // Pre-check whether FMHA is supported in order to save memory allocation.
     if (mEnableContextFMHA)
     {
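
Aside (not part of the commit): the head-count arithmetic added to initialize() is the core of the cp > kvHeadNum fix, so a small standalone restatement may help. The configuration below (tpSize = 1, cpSize = 4, 2 KV heads) is an assumed example, not taken from the commit, and the production code only recomputes the broadcast factor when mCpSize differs from mAttnCpSize.

// Hedged sketch of the Ulysses head-count derivation added in AttentionOp::initialize().
#include <cstdio>

int main()
{
    int const tpSize = 1, cpSize = 4;        // assumed parallel layout
    int const numHeads = 32, numKVHeads = 2; // per-rank head counts before the Ulysses split
    int const numKVHeadsOrigin = 2;          // model-level KV head count

    int attnTpSize = tpSize * cpSize; // default when the plugin does not set attn_tp/attn_cp sizes
    int attnCpSize = 1;

    int numAttnHeads = numHeads * tpSize / attnTpSize;
    // Ceil division keeps at least one KV head per rank even when cpSize > numKVHeads,
    // which is the "cp > kvHeadNum" case this commit fixes.
    int numAttnKVHeads = (numKVHeads * tpSize + attnTpSize - 1) / attnTpSize;
    // How many CP ranks share (broadcast) each original KV head; the real code guards this
    // with cpSize != attnCpSize.
    int ulyssesMQABroadcast = (attnTpSize + numKVHeadsOrigin - 1) / numKVHeadsOrigin;

    std::printf("numAttnHeads=%d numAttnKVHeads=%d mqaBroadcast=%d attnCpSize=%d\n",
        numAttnHeads, numAttnKVHeads, ulyssesMQABroadcast, attnCpSize);
    // With these assumed numbers: numAttnHeads=8, numAttnKVHeads=1, mqaBroadcast=2.
    return 0;
}

The ceil division is the behavioural change: with cpSize = 4 and only 2 KV heads, the old mNumKVHeads / mCpSize evaluated to 0, while the new expression keeps one KV head per rank and records that two ranks share each original head.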
@@ -2306,8 +2336,8 @@ int AttentionOp::initialize() noexcept
         fmhaParams.attentionMaskType = ContextAttentionMaskType::CUSTOM_MASK;
     }
     fmhaParams.isSPadded = !mRemovePadding;
-    fmhaParams.numQHeads = mNumHeads / mCpSize;
-    fmhaParams.numKvHeads = mNumKVHeads / mCpSize;
+    fmhaParams.numQHeads = mNumAttnHeads;
+    fmhaParams.numKvHeads = mNumAttnKVHeads;
     fmhaParams.numTokensPerBlock = mTokensPerBlock;
     fmhaParams.headSize = mHeadSize;
     fmhaParams.headSizeV = mHeadSize;
@@ -2326,16 +2356,6 @@ int AttentionOp::initialize() noexcept
     fmhaParams.attnLogitSoftcappingScale = mAttnLogitSoftcappingScale;
     fmhaParams.hasAlibi = isALiBi();
     fmhaParams.scaleAlibi = isAliBiWithScale();
-    if (mTpSize > 1)
-    {
-        fmhaParams.tpSize = mTpSize;
-        fmhaParams.tpRank = mTpRank;
-    }
-    else if (mCpSize > 1)
-    {
-        fmhaParams.tpSize = mCpSize;
-        fmhaParams.tpRank = mCpRank;
-    }
 
     // Load kernels from the pre-compiled cubins.
     mFmhaDispatcher.reset(new FmhaDispatcher(fmhaParams));
@@ -2479,8 +2499,8 @@ int AttentionOp::initialize() noexcept
         TLLM_CHECK_WITH_INFO(mNumHeads % mNumKVHeads == 0, "mNumHeads should be multiples of mNumKVHeads.");
         TLLM_CHECK_WITH_INFO(!mMultiBlockMode, "Medusa doesn't support multi-block mode.");
     }
-    fixedParams.numQHeads = mNumHeads / mCpSize;
-    fixedParams.numKvHeads = mNumKVHeads / mCpSize;
+    fixedParams.numQHeads = mNumAttnHeads;
+    fixedParams.numKvHeads = mNumAttnKVHeads;
     fixedParams.numTokensPerBlock = mTokensPerBlock;
     fixedParams.headSize = mHeadSize;
     fixedParams.qScaling = mQScaling;
@@ -2555,6 +2575,7 @@ std::string AttentionOp::toString() const
     ss << "gptAttentionCommon members ====================" << std::endl;
     ss << "mNumHeads: " << mNumHeads << std::endl;
     ss << "mNumKVHeads: " << mNumKVHeads << std::endl;
+    ss << "mNumKVHeadsOrigin: " << mNumKVHeadsOrigin << std::endl;
     ss << "mHeadSize: " << mHeadSize << std::endl;
     ss << "mUnidirectional: " << mUnidirectional << std::endl;
     ss << "mQScaling: " << mQScaling << std::endl;
@@ -229,10 +229,18 @@ public:
         EnqueueGenerationParams<T> const& generationsParams, bool forConfigurePlugin);
 
     template <typename T>
-    int ulyssesGenerationPreprocess(int32_t batch_beam, T* mhaInput, T* mhaOutput, T*& input, cudaStream_t stream);
+    int ulyssesContextPreprocess(T const* input, T* output, T* buffer, EnqueueContextParams<T> const& params,
+        int const* cu_q_seqlens, int const* cu_cp_partial_seqlens, cudaStream_t stream);
 
     template <typename T>
-    int ulyssesGenerationPostprocess(int32_t batch_beam, T* mhaInput, T* mhaOutput, void* output, cudaStream_t stream);
+    int ulyssesContextPostprocess(T* input, T* output, T* buffer, EnqueueContextParams<T> const& params,
+        int const* cu_q_seqlens, int const* cu_cp_partial_seqlens, cudaStream_t stream);
+
+    template <typename T>
+    int ulyssesGenerationPreprocess(T const* input, T* output, T* buffer, int32_t batch_beam, cudaStream_t stream);
+
+    template <typename T>
+    int ulyssesGenerationPostprocess(T* input, T* output, T* buffer, int32_t batch_beam, cudaStream_t stream);
 
     [[nodiscard]] bool isRelativePosition() const
     {
@@ -374,6 +382,16 @@ public:
     int mCpSize = 1;
     int mCpRank = 0;
     std::set<int32_t> mCpGroup = {};
+    // These parameters are used to specifically configure the attention attributes when cp/tp_size are different
+    // between Attention and FFN(such as Ulysses)
+    int mNumAttnHeads = -1;
+    int mNumAttnKVHeads = -1;
+    int mNumKVHeadsOrigin = -1;
+    int mAttnTpSize = -1;
+    int mAttnTpRank = 0;
+    int mAttnCpSize = -1;
+    int mAttnCpRank = 0;
+    int mUlyssesMQABroadcast = 1;
 
     // fmha runner (enabled by default)
     // flag: disabled = 0, enabled = 1, enabled with fp32 accumulation = 2
@@ -403,8 +421,9 @@ public:
             mPosShiftEnabled, mPagedContextFMHA, mFP8ContextFMHA, mDenseContextFMHA, mHasFullAttentionMask,
             mIsSpecDecodingEnabled, mUseSpecDecoding, mSpecDecodingIsGenerationLengthVariable,
             mSpecDecodingMaxGenerationLength, mIsMLAEnabled, mUseFlashMLA, mMLAParams.data(), mCpSize, mCpRank,
-            mCpGroup, mEnableContextFMHA, mFMHAForceFP32Acc, mMultiBlockMode, mEnableXQA, mUseKVCache, mSkipAttn,
-            mFuseFp4Quant, mNbMultiBlockSemaphores);
+            mCpGroup, mNumAttnHeads, mNumAttnKVHeads, mNumKVHeadsOrigin, mAttnTpSize, mAttnTpRank, mAttnCpSize,
+            mAttnCpRank, mUlyssesMQABroadcast, mEnableContextFMHA, mFMHAForceFP32Acc, mMultiBlockMode, mEnableXQA,
+            mUseKVCache, mSkipAttn, mFuseFp4Quant, mNbMultiBlockSemaphores);
     };
 
 private:
@@ -2139,7 +2139,7 @@ __global__ void convertData(Dst* dst, Src const* src, int64_t size, float const*
 
 template <typename T>
 __global__ void runCpTranspose(T* dst, T* dst2, T const* src, int64_t partialTokenNum, int64_t cpSize,
-    int64_t partialQHeads, int64_t partialKVHeads, int64_t headSize, int64_t rank)
+    int64_t partialQHeads, int64_t partialKVHeads, int64_t mqaBroadcast, int64_t headSize, int64_t rank)
 {
     // Do transpose from
     // [partialTokenNum, mNumHeads + 2*mNumKVHeads, headSize]
@@ -2164,12 +2164,12 @@ __global__ void runCpTranspose(T* dst, T* dst2, T const* src, int64_t partialTok
         }
         else if (headIdx < partialQHeads + partialKVHeads)
         {
-            srcHeadIdx = cpSize * partialQHeads + cpIdx * partialKVHeads + (headIdx - partialQHeads);
+            srcHeadIdx = cpSize * partialQHeads + cpIdx / mqaBroadcast * partialKVHeads + (headIdx - partialQHeads);
         }
         else
         {
-            srcHeadIdx = cpSize * partialQHeads + cpSize * partialKVHeads + cpIdx * partialKVHeads
-                + (headIdx - partialQHeads - partialKVHeads);
+            srcHeadIdx = cpSize * partialQHeads + cpSize / mqaBroadcast * partialKVHeads
+                + cpIdx / mqaBroadcast * partialKVHeads + (headIdx - partialQHeads - partialKVHeads);
         }
 
         if (cpIdx == rank)
@@ -2181,7 +2181,7 @@ __global__ void runCpTranspose(T* dst, T* dst2, T const* src, int64_t partialTok
                 + seqIdx * (partialQHeads + 2 * partialKVHeads) + headIdx)
             * headSize);
         VecType const* in = reinterpret_cast<VecType const*>(
-            src + (seqIdx * (partialQHeads + 2 * partialKVHeads) * cpSize + srcHeadIdx) * headSize);
+            src + (seqIdx * (partialQHeads * cpSize + 2 * partialKVHeads * cpSize / mqaBroadcast) + srcHeadIdx) * headSize);
 
         for (int hiddenIdx = threadIdx.x; hiddenIdx < hiddenSize + hiddenRestSize; hiddenIdx += blockDim.x)
         {
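
Aside (not part of the commit): the two srcHeadIdx branches above are easier to follow with the index math written out on the host. The sketch below restates only the K and V branches (the Q branch lies outside the hunk and is omitted); cpSize = 4, partialQHeads = 8, partialKVHeads = 1 and mqaBroadcast = 2 are assumed example values.

// Host-side restatement, for illustration only, of the updated srcHeadIdx math in runCpTranspose.
#include <cstdint>
#include <cstdio>

int64_t kvSrcHeadIdx(int64_t cpIdx, int64_t headIdx, int64_t cpSize, int64_t partialQHeads,
    int64_t partialKVHeads, int64_t mqaBroadcast, bool isV)
{
    // Ranks are grouped in blocks of mqaBroadcast; every rank in a group reads the same source KV head.
    int64_t const kvGroup = cpIdx / mqaBroadcast;
    if (!isV)
    {
        return cpSize * partialQHeads + kvGroup * partialKVHeads + (headIdx - partialQHeads);
    }
    return cpSize * partialQHeads + cpSize / mqaBroadcast * partialKVHeads + kvGroup * partialKVHeads
        + (headIdx - partialQHeads - partialKVHeads);
}

int main()
{
    int64_t const cpSize = 4, partialQHeads = 8, partialKVHeads = 1, mqaBroadcast = 2;
    for (int64_t cpIdx = 0; cpIdx < cpSize; ++cpIdx)
    {
        std::printf("cpIdx=%lld K srcHeadIdx=%lld V srcHeadIdx=%lld\n", (long long) cpIdx,
            (long long) kvSrcHeadIdx(cpIdx, partialQHeads, cpSize, partialQHeads, partialKVHeads, mqaBroadcast, false),
            (long long) kvSrcHeadIdx(cpIdx, partialQHeads + partialKVHeads, cpSize, partialQHeads, partialKVHeads,
                mqaBroadcast, true));
    }
    return 0;
}

Running it shows ranks 0 and 1 resolving to the same source K/V head and ranks 2 and 3 to the next one, which is the broadcast behaviour the mqaBroadcast divisor introduces when cpSize exceeds the KV head count.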
@@ -2347,17 +2347,18 @@ INSTANTIATE_invokeConversion(__nv_fp8_e4m3, __nv_bfloat16);
 
 template <typename T>
 void invokeCpTranspose(T* dst, T* dst2, T const* src, int64_t partialTokenNum, int64_t cpSize, int64_t partialQHeads,
-    int64_t partialKVHeads, int64_t headSize, int64_t rank, cudaStream_t stream)
+    int64_t partialKVHeads, int64_t mqaBroadcast, int64_t headSize, int64_t rank, cudaStream_t stream)
 {
     dim3 grid(partialTokenNum, cpSize, partialQHeads + 2 * partialKVHeads);
     dim3 block(128);
     runCpTranspose<T><<<grid, block, 0, stream>>>(
-        dst, dst2, src, partialTokenNum, cpSize, partialQHeads, partialKVHeads, headSize, rank);
+        dst, dst2, src, partialTokenNum, cpSize, partialQHeads, partialKVHeads, mqaBroadcast, headSize, rank);
 }
 
 #define INSTANTIATE_invokeCpTranspose(T)                                                                             \
     template void invokeCpTranspose<T>(T * dst, T * dst2, T const* src, int64_t partialLength, int64_t cpSize,       \
-        int64_t partialQHeads, int64_t partialKVHeads, int64_t headSize, int64_t rank, cudaStream_t stream)
+        int64_t partialQHeads, int64_t partialKVHeads, int64_t mqaBroadcast, int64_t headSize, int64_t rank,         \
+        cudaStream_t stream)
 INSTANTIATE_invokeCpTranspose(float);
 INSTANTIATE_invokeCpTranspose(half);
 INSTANTIATE_invokeCpTranspose(__nv_bfloat16);
@@ -407,7 +407,7 @@ void invokeConversion(Dst* dst, Src const* src, int64_t size, float const* __res
 
 template <typename T>
 void invokeCpTranspose(T* dst, T* dst2, T const* src, int64_t partialLength, int64_t cpSize, int64_t partialQHeads,
-    int64_t partialKVHeads, int64_t headSize, int64_t rank, cudaStream_t stream);
+    int64_t partialKVHeads, int64_t mqaBroadcast, int64_t headSize, int64_t rank, cudaStream_t stream);
 
 template <typename T>
 void invokeCpTransposeToSeqMajor(T* dst, T const* srcMyRank, T const* srcOtherRank, int64_t partialLength,
@@ -28,8 +28,8 @@ using tensorrt_llm::plugins::GPTAttentionPluginCreatorCommon;
 using tensorrt_llm::plugins::GPTAttentionPluginCommon;
 
 GPTAttentionPluginCommon::GPTAttentionPluginCommon(int layer_idx, int num_heads, int vision_start, int vision_length,
-    int num_kv_heads, int head_size, int unidirectional, float q_scaling, float attn_logit_softcapping_scale,
-    tensorrt_llm::kernels::PositionEmbeddingType position_embedding_type,
+    int num_kv_heads, int num_kv_heads_origin, int head_size, int unidirectional, float q_scaling,
+    float attn_logit_softcapping_scale, tensorrt_llm::kernels::PositionEmbeddingType position_embedding_type,
     int rotary_embedding_dim, // for RoPE. Use 0 for non-RoPE
     float rotary_embedding_base, tensorrt_llm::kernels::RotaryScalingType rotary_embedding_scale_type,
     float rotary_embedding_scale, float rotary_embedding_short_m_scale, float rotary_embedding_long_m_scale,
@@ -52,6 +52,7 @@ GPTAttentionPluginCommon::GPTAttentionPluginCommon(int layer_idx, int num_heads,
     mVisionStart = vision_start;
     mVisionLength = vision_length;
     mNumKVHeads = num_kv_heads;
+    mNumKVHeadsOrigin = num_kv_heads_origin;
     mHeadSize = head_size;
     mUnidirectional = unidirectional;
     mQScaling = q_scaling;
@@ -114,6 +115,7 @@ GPTAttentionPluginCommon::GPTAttentionPluginCommon(void const* data, size_t leng
     read(d, mVisionStart);
     read(d, mVisionLength);
     read(d, mNumKVHeads);
+    read(d, mNumKVHeadsOrigin);
     read(d, mHeadSize);
     read(d, mUnidirectional);
     read(d, mQScaling);
@@ -201,13 +203,13 @@ void GPTAttentionPluginCommon::destroy() noexcept
 size_t GPTAttentionPluginCommon::getCommonSerializationSize() const noexcept
 {
     return sizeof(mLayerIdx) + sizeof(mNumHeads) + +sizeof(mVisionStart) + sizeof(mVisionLength) + sizeof(mNumKVHeads)
-        + sizeof(mHeadSize) + sizeof(mUnidirectional) + sizeof(mQScaling) + sizeof(mAttnLogitSoftcappingScale)
-        + sizeof(mPositionEmbeddingType) + sizeof(mRotaryEmbeddingDim) + sizeof(mRotaryEmbeddingBase)
-        + sizeof(mRotaryEmbeddingScaleType) + sizeof(mRotaryEmbeddingScale) + sizeof(mRotaryEmbeddingShortMscale)
-        + sizeof(mRotaryEmbeddingLongMscale) + sizeof(mRotaryEmbeddingMaxPositions)
-        + sizeof(mRotaryEmbeddingOriginalMaxPositions) + sizeof(mTpSize) + sizeof(mTpRank) + sizeof(mEnableContextFMHA)
-        + sizeof(mFMHAForceFP32Acc) + sizeof(mMultiBlockMode) + sizeof(mEnableXQA)
-        + sizeof(unsigned int) // mKVCacheQuantMode
+        + sizeof(mNumKVHeadsOrigin) + sizeof(mHeadSize) + sizeof(mUnidirectional) + sizeof(mQScaling)
+        + sizeof(mAttnLogitSoftcappingScale) + sizeof(mPositionEmbeddingType) + sizeof(mRotaryEmbeddingDim)
+        + sizeof(mRotaryEmbeddingBase) + sizeof(mRotaryEmbeddingScaleType) + sizeof(mRotaryEmbeddingScale)
+        + sizeof(mRotaryEmbeddingShortMscale) + sizeof(mRotaryEmbeddingLongMscale)
+        + sizeof(mRotaryEmbeddingMaxPositions) + sizeof(mRotaryEmbeddingOriginalMaxPositions) + sizeof(mTpSize)
+        + sizeof(mTpRank) + sizeof(mEnableContextFMHA) + sizeof(mFMHAForceFP32Acc) + sizeof(mMultiBlockMode)
+        + sizeof(mEnableXQA) + sizeof(unsigned int) // mKVCacheQuantMode
         + sizeof(mRemovePadding) + sizeof(mMaskType) + sizeof(mBlockSparseParams) + sizeof(mPagedKVCache)
         + sizeof(mTokensPerBlock) + sizeof(mType) + sizeof(mMaxContextLength) + sizeof(mQKVBiasEnabled)
         + sizeof(mCrossAttention) + sizeof(mMaxDistance) + sizeof(mPosShiftEnabled) + sizeof(mDenseContextFMHA)
@@ -228,6 +230,7 @@ void GPTAttentionPluginCommon::serializeCommon(void* buffer) const noexcept
     write(d, mVisionStart);
     write(d, mVisionLength);
     write(d, mNumKVHeads);
+    write(d, mNumKVHeadsOrigin);
     write(d, mHeadSize);
     write(d, mUnidirectional);
     write(d, mQScaling);
@@ -307,6 +310,7 @@ GPTAttentionPluginCreatorCommon::GPTAttentionPluginCreatorCommon()
     mPluginAttributes.emplace_back(PluginField("vision_start", nullptr, PluginFieldType::kINT32));
     mPluginAttributes.emplace_back(PluginField("vision_length", nullptr, PluginFieldType::kINT32));
     mPluginAttributes.emplace_back(PluginField("num_kv_heads", nullptr, PluginFieldType::kINT32));
+    mPluginAttributes.emplace_back(PluginField("num_kv_heads_origin", nullptr, PluginFieldType::kINT32));
     mPluginAttributes.emplace_back(PluginField("layer_idx_in_cache_pool", nullptr, PluginFieldType::kINT32));
     mPluginAttributes.emplace_back(PluginField("head_size", nullptr, PluginFieldType::kINT32));
     mPluginAttributes.emplace_back(PluginField("unidirectional", nullptr, PluginFieldType::kINT32));
@@ -35,7 +35,7 @@ public:
     GPTAttentionPluginCommon() = delete;
 
     GPTAttentionPluginCommon(int layer_idx, int num_heads, int vision_start, int vision_length, int num_kv_heads,
-        int head_size, int unidirectional, float q_scaling, float attn_logit_softcapping_scale,
+        int num_kv_heads_origin, int head_size, int unidirectional, float q_scaling, float attn_logit_softcapping_scale,
         tensorrt_llm::kernels::PositionEmbeddingType position_embedding_type,
         int rotary_embedding_dim, // for RoPE. Use 0 for non-RoPE
         float rotary_embedding_base, tensorrt_llm::kernels::RotaryScalingType rotary_embedding_scale_type,
@@ -46,8 +46,8 @@ static char const* GPT_ATTENTION_PLUGIN_VERSION{"1"};
 static char const* GPT_ATTENTION_PLUGIN_NAME{"GPTAttention"};
 
 GPTAttentionPlugin::GPTAttentionPlugin(int layer_idx, int num_heads, int vision_start, int vision_length,
-    int num_kv_heads, int head_size, int unidirectional, float q_scaling, float attn_logit_softcapping_scale,
-    tensorrt_llm::kernels::PositionEmbeddingType position_embedding_type,
+    int num_kv_heads, int num_kv_heads_origin, int head_size, int unidirectional, float q_scaling,
+    float attn_logit_softcapping_scale, tensorrt_llm::kernels::PositionEmbeddingType position_embedding_type,
     int rotary_embedding_dim, // for RoPE. 0 for non-RoPE
     float rotary_embedding_base, tensorrt_llm::kernels::RotaryScalingType rotary_embedding_scale_type,
     float rotary_embedding_scale, float rotary_embedding_short_m_scale,
@ -65,17 +65,17 @@ GPTAttentionPlugin::GPTAttentionPlugin(int layer_idx, int num_heads, int vision_
|
|||||||
int spec_decoding_max_generation_length, bool is_mla_enabled, int q_lora_rank, int kv_lora_rank,
|
int spec_decoding_max_generation_length, bool is_mla_enabled, int q_lora_rank, int kv_lora_rank,
|
||||||
int qk_nope_head_dim, int qk_rope_head_dim, int v_head_dim, bool fuse_fp4_quant, bool skip_attn, int cp_size,
|
int qk_nope_head_dim, int qk_rope_head_dim, int v_head_dim, bool fuse_fp4_quant, bool skip_attn, int cp_size,
|
||||||
int cp_rank, std::set<int32_t> cp_group)
|
int cp_rank, std::set<int32_t> cp_group)
|
||||||
: GPTAttentionPluginCommon(layer_idx, num_heads, vision_start, vision_length, num_kv_heads, head_size,
|
: GPTAttentionPluginCommon(layer_idx, num_heads, vision_start, vision_length, num_kv_heads, num_kv_heads_origin,
|
||||||
unidirectional, q_scaling, attn_logit_softcapping_scale, position_embedding_type, rotary_embedding_dim,
|
head_size, unidirectional, q_scaling, attn_logit_softcapping_scale, position_embedding_type,
|
||||||
rotary_embedding_base, rotary_embedding_scale_type, rotary_embedding_scale, rotary_embedding_short_m_scale,
|
rotary_embedding_dim, rotary_embedding_base, rotary_embedding_scale_type, rotary_embedding_scale,
|
||||||
rotary_embedding_long_m_scale, rotary_embedding_max_positions, rotary_embedding_original_max_positions, tp_size,
|
rotary_embedding_short_m_scale, rotary_embedding_long_m_scale, rotary_embedding_max_positions,
|
||||||
tp_rank, unfuse_qkv_gemm, use_logn_scaling, context_fmha_type, kv_cache_quant_mode, remove_input_padding,
|
rotary_embedding_original_max_positions, tp_size, tp_rank, unfuse_qkv_gemm, use_logn_scaling, context_fmha_type,
|
||||||
mask_type, block_sparse_params, paged_kv_cache, tokens_per_block, type, max_context_length, qkv_bias_enabled,
|
kv_cache_quant_mode, remove_input_padding, mask_type, block_sparse_params, paged_kv_cache, tokens_per_block,
|
||||||
cross_attention, max_distance, pos_shift_enabled, dense_context_fmha, use_paged_context_fmha,
|
type, max_context_length, qkv_bias_enabled, cross_attention, max_distance, pos_shift_enabled,
|
||||||
use_fp8_context_fmha, has_full_attention_mask, use_cache, is_spec_decoding_enabled,
|
dense_context_fmha, use_paged_context_fmha, use_fp8_context_fmha, has_full_attention_mask, use_cache,
|
||||||
spec_decoding_is_generation_length_variable, spec_decoding_max_generation_length, is_mla_enabled, q_lora_rank,
|
is_spec_decoding_enabled, spec_decoding_is_generation_length_variable, spec_decoding_max_generation_length,
|
||||||
kv_lora_rank, qk_nope_head_dim, qk_rope_head_dim, v_head_dim, fuse_fp4_quant, skip_attn, cp_size, cp_rank,
|
is_mla_enabled, q_lora_rank, kv_lora_rank, qk_nope_head_dim, qk_rope_head_dim, v_head_dim, fuse_fp4_quant,
|
||||||
cp_group)
|
skip_attn, cp_size, cp_rank, cp_group)
|
||||||
{
|
{
|
||||||
TLLM_CHECK_WITH_INFO(
|
TLLM_CHECK_WITH_INFO(
|
||||||
!is_mla_enabled, "GPTAttentionPlugin no longer supports MLA. Please use the PyTorch workflow instead.");
|
!is_mla_enabled, "GPTAttentionPlugin no longer supports MLA. Please use the PyTorch workflow instead.");
|
||||||
@@ -890,7 +890,8 @@ int GPTAttentionPlugin::enqueueSome(int32_t seqIdxBeg, int32_t localNbSeq, int32
 
     auto const cacheElemSize = (mKVCacheQuantMode.hasKvCacheQuant() ? 1 : sizeof(T));
 
-    auto const blockSize = mTokensPerBlock * mNumKVHeads / mCpSize * mHeadSize;
+    auto const kv_cache_head_num = (mNumKVHeads + mCpSize - 1) / mCpSize;
+    auto const blockSize = mTokensPerBlock * kv_cache_head_num * mHeadSize;
     auto const bytesPerBlock = blockSize * cacheElemSize;
     auto const layerOffset = layerIdxInCachePool * 2 * bytesPerBlock;
 
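The new kv_cache_head_num is the per-rank KV-head count under context parallelism: the ceiling division guarantees at least one KV head per CP rank, so KV-cache blocks are no longer undersized when cp_size does not divide, or exceeds, the number of KV heads. A minimal Python sketch of the block-size arithmetic (the function name and values are illustrative only, not part of the codebase):

    def kv_cache_block_size(tokens_per_block: int, num_kv_heads: int, cp_size: int, head_size: int) -> int:
        # Per-rank KV heads via ceiling division: every CP rank stores at least
        # one full KV head, even when cp_size > num_kv_heads.
        kv_cache_head_num = (num_kv_heads + cp_size - 1) // cp_size
        return tokens_per_block * kv_cache_head_num * head_size

    # 2 KV heads on 4 CP ranks: each rank actually holds 1 full head, so a block
    # needs 64 * 1 * 128 elements; the old expression sized it as 64 * 2 // 4 * 128.
    assert kv_cache_block_size(64, 2, 4, 128) == 64 * 1 * 128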
@@ -1311,8 +1312,9 @@ IPluginV2* GPTAttentionPluginCreator::createPlugin(char const* name, PluginField
     auto* obj = new GPTAttentionPlugin(p.getScalar<int32_t>("layer_idx").value(),
         p.getScalar<int32_t>("num_heads").value(), p.getScalar<int32_t>("vision_start").value(),
         p.getScalar<int32_t>("vision_length").value(), p.getScalar<int32_t>("num_kv_heads").value(),
-        p.getScalar<int32_t>("head_size").value(), p.getScalar<int32_t>("unidirectional").value(),
-        p.getScalar<float>("q_scaling").value(), p.getScalar<float>("attn_logit_softcapping_scale").value(),
+        p.getScalar<int32_t>("num_kv_heads_origin").value(), p.getScalar<int32_t>("head_size").value(),
+        p.getScalar<int32_t>("unidirectional").value(), p.getScalar<float>("q_scaling").value(),
+        p.getScalar<float>("attn_logit_softcapping_scale").value(),
         static_cast<PositionEmbeddingType>(p.getScalar<int8_t>("position_embedding_type").value()),
         p.getScalar<int32_t>("rotary_embedding_dim").value(), p.getScalar<float>("rotary_embedding_base").value(),
         static_cast<RotaryScalingType>(p.getScalar<int8_t>("rotary_embedding_scale_type").value()),
@@ -101,7 +101,7 @@ class GPTAttentionPlugin : public GPTAttentionPluginCommon
 {
 public:
     GPTAttentionPlugin(int layer_idx, int num_heads, int vision_start, int vision_length, int num_kv_heads,
-        int head_size, int unidirectional, float q_scaling, float attn_logit_softcapping_scale,
+        int num_kv_heads_origin, int head_size, int unidirectional, float q_scaling, float attn_logit_softcapping_scale,
         tensorrt_llm::kernels::PositionEmbeddingType position_embedding_type,
         int rotary_embedding_dim, // for RoPE. 0 for non-RoPE
         float rotary_embedding_base, tensorrt_llm::kernels::RotaryScalingType rotary_embedding_scale_type,
@@ -236,6 +236,8 @@ def convert_and_save_hf(args):
             load_model_on_cpu=args.load_model_on_cpu,
             **override_fields)
         glm.config.mapping.cp_size = args.cp_size
+        glm.config.mapping.attn_tp_size = -1
+        glm.config.mapping.attn_cp_size = -1
         glm.config.mapping.world_size *= args.cp_size
         glm.save_checkpoint(args.output_dir, save_config=(rank == 0))
         del glm
@@ -393,6 +393,8 @@ def convert_and_save_meta(args, rank):
             use_parallel_embedding=args.use_parallel_embedding,
             embedding_sharding_dim=args.embedding_sharding_dim)
         llama.config.mapping.cp_size = args.cp_size
+        llama.config.mapping.attn_tp_size = -1
+        llama.config.mapping.attn_cp_size = -1
         llama.config.mapping.world_size *= args.cp_size
         llama.save_checkpoint(args.output_dir, save_config=(rank == 0))
 
@@ -511,6 +513,8 @@ def convert_and_save_hf(args):
             f'Total time of reading and converting: {time.time()-tik:.3f} s'
         )
         llama.config.mapping.cp_size = args.cp_size
+        llama.config.mapping.attn_tp_size = -1
+        llama.config.mapping.attn_cp_size = -1
         llama.config.mapping.world_size *= args.cp_size
         tik = time.time()
         llama.save_checkpoint(args.output_dir, save_config=(rank == 0))
@@ -281,6 +281,8 @@ def convert_and_save_hf(args):
             quant_config=quant_config,
             **override_fields)
         qwen.config.mapping.cp_size = args.cp_size
+        qwen.config.mapping.attn_tp_size = -1
+        qwen.config.mapping.attn_cp_size = -1
         qwen.config.mapping.world_size *= args.cp_size
         qwen.save_checkpoint(args.output_dir, save_config=(rank == 0))
         del qwen
@@ -5213,6 +5213,7 @@ def gpt_attention(
         cp_group: List[int] = [0],
         cp_size: int = 1,
         cp_rank: int = 0,
+        num_kv_heads_origin: int = -1,
 ) -> Tuple[Tensor, Optional[Tensor]]:
     '''
     Add an operation that performs the multi-head attention in GPT-like models.
@@ -5474,6 +5475,9 @@ def gpt_attention(
         skip_attn: Tensor = None,
             A bool tensor on CPU. If it is true, don't run attention plugin, returning directly.
 
+        num_kv_heads_origin: int
+            The origin number of KV heads, without the process of TP
+
     Returns:
         The tensor produced by that layer.
     '''
@@ -5505,6 +5509,9 @@ def gpt_attention(
     else:
         use_logn_scaling = 0
 
+    if num_kv_heads_origin < 1:
+        num_kv_heads_origin = num_kv_heads
+
     unfuse_qkv_gemm = trt.PluginField(
         "unfuse_qkv_gemm", np.array(np.int8(is_unfuse_qkv_gemm), dtype=np.int8),
         trt.PluginFieldType.INT8)
@@ -5523,6 +5530,9 @@ def gpt_attention(
     num_kv_heads = trt.PluginField("num_kv_heads",
                                    np.array(num_kv_heads, dtype=np.int32),
                                    trt.PluginFieldType.INT32)
+    num_kv_heads_origin = trt.PluginField(
+        "num_kv_heads_origin", np.array(num_kv_heads_origin, dtype=np.int32),
+        trt.PluginFieldType.INT32)
     head_size = trt.PluginField("head_size",
                                 np.array(hidden_size_per_head, dtype=np.int32),
                                 trt.PluginFieldType.INT32)
@@ -5714,9 +5724,10 @@ def gpt_attention(
         trt.PluginFieldType.INT8)
 
     pfc = trt.PluginFieldCollection([
-        layer_idx, nheads, vision_start, vision_length, num_kv_heads, head_size,
-        unidirectional, q_scaling, attn_logit_softcapping_scale,
-        position_embedding_type, rotary_embedding_dim, rotary_embedding_base,
+        layer_idx, nheads, vision_start, vision_length, num_kv_heads,
+        num_kv_heads_origin, head_size, unidirectional, q_scaling,
+        attn_logit_softcapping_scale, position_embedding_type,
+        rotary_embedding_dim, rotary_embedding_base,
         rotary_embedding_scale_type, rotary_embedding_scale,
         rotary_embedding_short_m_scale, rotary_embedding_long_m_scale,
         rotary_embedding_max_positions, rotary_embedding_original_max_positions,
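In functional.py the new argument is plumbed end to end: gpt_attention() defaults num_kv_heads_origin to num_kv_heads when the caller leaves it at -1, wraps it in an extra INT32 PluginField, and inserts it into the PluginFieldCollection right after num_kv_heads. A small sketch of the default-resolution step only (the helper name is made up for illustration; it is not part of the API):

    def resolve_num_kv_heads_origin(num_kv_heads: int, num_kv_heads_origin: int = -1) -> int:
        # Mirrors the default added to gpt_attention(): any value below 1 means
        # "not provided", so fall back to the (possibly already TP-split) num_kv_heads.
        if num_kv_heads_origin < 1:
            num_kv_heads_origin = num_kv_heads
        return num_kv_heads_origin

    # Older call sites keep working because the fallback reproduces the previous behaviour;
    # callers that know the pre-TP head count pass it explicitly.
    assert resolve_num_kv_heads_origin(num_kv_heads=2) == 2
    assert resolve_num_kv_heads_origin(num_kv_heads=2, num_kv_heads_origin=8) == 8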
@@ -395,13 +395,13 @@ class Attention(Module):
         self.cross_attention = cross_attention
         self.attention_mask_type = attention_mask_type
         self.attention_head_size = hidden_size // num_attention_heads if attention_head_size is None else attention_head_size
-        self.num_kv_heads = num_kv_heads
         assert num_attention_heads % tp_size == 0, \
             "num_attention_heads must be divisible by tp_size"
         self.num_attention_heads = num_attention_heads // tp_size
         self.num_attention_kv_heads = (
             num_kv_heads + tp_size - 1
         ) // tp_size if num_kv_heads is not None else self.num_attention_heads
+        self.num_kv_heads = num_kv_heads if num_kv_heads is not None else self.num_attention_heads
         self.hidden_size = hidden_size
         self.attention_hidden_size = self.attention_head_size * self.num_attention_heads
         self.max_position_embeddings = max_position_embeddings
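With this change Attention keeps two head counts: self.num_attention_kv_heads stays the per-rank, TP-split value (ceiling-divided so every rank gets at least one KV head), while self.num_kv_heads now records the unsplit count that is forwarded to the plugin as num_kv_heads_origin. A short sketch of the arithmetic, assuming nothing beyond what the constructor above does (the helper is illustrative):

    def split_heads(num_attention_heads: int, num_kv_heads, tp_size: int):
        # Per-rank query heads under tensor parallelism.
        heads_per_rank = num_attention_heads // tp_size
        # Per-rank KV heads, ceiling-divided for GQA models with few KV heads.
        kv_heads_per_rank = ((num_kv_heads + tp_size - 1) // tp_size
                             if num_kv_heads is not None else heads_per_rank)
        # Unsplit KV-head count, later passed to gpt_attention(num_kv_heads_origin=...).
        kv_heads_origin = num_kv_heads if num_kv_heads is not None else heads_per_rank
        return heads_per_rank, kv_heads_per_rank, kv_heads_origin

    # 32 query heads and 8 KV heads on 4-way TP: 8 query heads and 2 KV heads per rank,
    # while the original 8 KV heads are still reported to the attention plugin.
    assert split_heads(32, 8, 4) == (8, 2, 8)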
@@ -1058,6 +1058,7 @@ class Attention(Module):
             layer_idx=self.local_layer_idx,
             num_heads=self.num_attention_heads,
             num_kv_heads=self.num_attention_kv_heads,
+            num_kv_heads_origin=self.num_kv_heads,
             hidden_size_per_head=self.attention_head_size,
             q_scaling=self.q_scaling,
             rotary_embedding_dim=self.rotary_embedding_dim,
@@ -1866,6 +1867,7 @@ class CogVLMAttention(Attention):
             layer_idx=self.local_layer_idx,
             num_heads=self.num_attention_heads,
             num_kv_heads=self.num_attention_kv_heads,
+            num_kv_heads_origin=self.num_kv_heads,
             hidden_size_per_head=self.attention_head_size,
             q_scaling=self.q_scaling,
             position_embedding_type=self.position_embedding_type,
@@ -2216,6 +2218,7 @@ class DeepseekV2Attention(Attention):
             layer_idx=self.local_layer_idx,
             num_heads=self.num_attention_heads,
             num_kv_heads=1,
+            num_kv_heads_origin=1,
             hidden_size_per_head=self.kv_lora_rank + self.qk_rope_head_dim,
             q_scaling=self.q_scaling,
             position_embedding_type=self.position_embedding_type,
@@ -123,6 +123,8 @@ class Mapping(object):
                 pp_size=1,
                 moe_tp_size=-1, # -1 means no moe
                 moe_ep_size=-1, # -1 means no moe
+                attn_tp_size=-1,
+                attn_cp_size=-1,
                 auto_parallel=False,
                 enable_attention_dp=False):
         # set default values for non-moe cases
@@ -137,6 +139,22 @@ class Mapping(object):
         elif moe_ep_size == -1:
             moe_ep_size = tp_size // moe_tp_size
 
+        if attn_tp_size == -1 and attn_cp_size == -1:
+            # fallback to ulysses
+            attn_tp_size = tp_size * cp_size
+            attn_cp_size = 1
+
+        elif attn_tp_size == -1:
+            attn_tp_size = cp_size * tp_size // attn_cp_size
+
+        elif attn_cp_size == -1:
+            attn_cp_size = cp_size * tp_size // attn_tp_size
+
+        if attn_cp_size != 1:
+            raise ValueError(
+                f"attn_cp_size must be 1 for now, but got {attn_tp_size}, {attn_cp_size}."
+            )
+
         if auto_parallel:
             if tp_size != 1 or pp_size != 1 or tp_size != 1:
                 raise ValueError(
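The Mapping additions resolve attention-specific parallelism from the global tp_size and cp_size: if neither attention size is given, attention falls back to Ulysses-style parallelism (attn_tp_size = tp_size * cp_size, attn_cp_size = 1); if only one is given, the other is derived so the product stays tp_size * cp_size; and any attn_cp_size other than 1 is rejected for now. A standalone sketch of that resolution (the helper is illustrative, not the actual Mapping API):

    def resolve_attn_parallelism(tp_size: int, cp_size: int,
                                 attn_tp_size: int = -1, attn_cp_size: int = -1):
        if attn_tp_size == -1 and attn_cp_size == -1:
            # Fall back to Ulysses: attention treats all tp*cp ranks as TP ranks.
            attn_tp_size, attn_cp_size = tp_size * cp_size, 1
        elif attn_tp_size == -1:
            attn_tp_size = cp_size * tp_size // attn_cp_size
        elif attn_cp_size == -1:
            attn_cp_size = cp_size * tp_size // attn_tp_size
        if attn_cp_size != 1:
            raise ValueError("attn_cp_size must be 1 for now")
        if attn_tp_size * attn_cp_size != tp_size * cp_size:
            raise ValueError("tp_size * cp_size must equal attn_tp_size * attn_cp_size")
        return attn_tp_size, attn_cp_size

    # tp_size=2, cp_size=2: attention runs as 4-way TP with no attention-level CP.
    assert resolve_attn_parallelism(2, 2) == (4, 1)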
@@ -154,6 +172,12 @@ class Mapping(object):
                     f"tp_size must equal to moe_tp_size * moe_ep_size, but got {tp_size} != {moe_tp_size} * {moe_ep_size}"
                 )
 
+        attn_tp_cp_size = attn_tp_size * attn_cp_size
+        if attn_tp_cp_size != tp_size * cp_size:
+            raise ValueError(
+                f"tp_size * cp_size must equal to attn_tp_size * attn_cp_size, but got {tp_size} * {cp_size} != {attn_tp_size} * {attn_cp_size}"
+            )
+
         if moe_ep_size != 1 and cp_size > 1:
             raise NotImplementedError("CP don't support MoE tp/ep yet")
 
@@ -163,6 +187,8 @@ class Mapping(object):
         self.pp_size = pp_size
         self.moe_tp_size = moe_tp_size
         self.moe_ep_size = moe_ep_size
+        self.attn_tp_size = attn_tp_size
+        self.attn_cp_size = attn_cp_size
         self.auto_parallel = auto_parallel
         self.world_size = world_size
         self.rank = rank
@@ -218,6 +244,8 @@ class Mapping(object):
                 and self.pp_size == other.pp_size
                 and self.moe_tp_size == other.moe_tp_size
                 and self.moe_ep_size == other.moe_ep_size
+                and self.attn_tp_size == other.attn_tp_size
+                and self.attn_cp_size == other.attn_cp_size
                 and self.auto_parallel == other.auto_parallel)
 
     def __hash__(self):
@@ -225,6 +253,7 @@ class Mapping(object):
                 ^ hash(self.gpus_per_node) ^ hash(self.cp_size)
                 ^ hash(self.tp_size) ^ hash(self.pp_size)
                 ^ hash(self.moe_tp_size) ^ hash(self.moe_ep_size)
+                ^ hash(self.attn_tp_size) ^ hash(self.attn_cp_size)
                 ^ hash(self.auto_parallel))
 
     @property
@@ -375,5 +404,7 @@ class Mapping(object):
             'pp_size': self.pp_size,
             'moe_tp_size': self.moe_tp_size,
             'moe_ep_size': self.moe_ep_size,
+            'attn_tp_size': self.attn_tp_size,
+            'attn_cp_size': self.attn_cp_size,
             'auto_parallel': self.auto_parallel,
         }
@@ -1850,6 +1850,7 @@ class Fp8RowwiseAttention(Module):
         self.attention_mask_type = attention_mask_type
         self.attention_head_size = hidden_size // num_attention_heads if attention_head_size is None else attention_head_size
         self.num_attention_heads = num_attention_heads // tp_size
+        self.num_kv_heads = num_kv_heads
         self.num_attention_kv_heads = (
             num_kv_heads + tp_size - 1
         ) // tp_size if num_kv_heads is not None else self.num_attention_heads
@@ -2010,6 +2011,7 @@ class Fp8RowwiseAttention(Module):
             layer_idx=self.local_layer_idx,
             num_heads=self.num_attention_heads,
             num_kv_heads=self.num_attention_kv_heads,
+            num_kv_heads_origin=self.num_kv_heads,
             hidden_size_per_head=self.attention_head_size,
             q_scaling=self.q_scaling,
             rotary_embedding_dim=self.rotary_embedding_dim,
@@ -2467,6 +2469,7 @@ class SmoothQuantAttention(Module):
         self.attention_mask_type = attention_mask_type
         self.attention_head_size = hidden_size // num_attention_heads if attention_head_size is None else attention_head_size
         self.num_attention_heads = num_attention_heads // tp_size
+        self.num_kv_heads = num_kv_heads
         self.num_attention_kv_heads = (
             num_kv_heads + tp_size - 1
         ) // tp_size if num_kv_heads is not None else self.num_attention_heads
@@ -2634,6 +2637,7 @@ class SmoothQuantAttention(Module):
             layer_idx=self.local_layer_idx,
             num_heads=self.num_attention_heads,
             num_kv_heads=self.num_attention_kv_heads,
+            num_kv_heads_origin=self.num_kv_heads,
             hidden_size_per_head=self.attention_head_size,
             q_scaling=self.q_scaling,
             rotary_embedding_dim=self.rotary_embedding_dim,
@@ -2987,6 +2991,7 @@ class QServeAttention(Module):
         self.attention_mask_type = attention_mask_type
         self.attention_head_size = hidden_size // num_attention_heads if attention_head_size is None else attention_head_size
         self.num_attention_heads = num_attention_heads // tp_size
+        self.num_kv_heads = num_kv_heads
         self.num_attention_kv_heads = (
             num_kv_heads + tp_size - 1
         ) // tp_size if num_kv_heads is not None else self.num_attention_heads
@@ -3154,6 +3159,7 @@ class QServeAttention(Module):
             layer_idx=self.local_layer_idx,
             num_heads=self.num_attention_heads,
             num_kv_heads=self.num_attention_kv_heads,
+            num_kv_heads_origin=self.num_kv_heads,
             hidden_size_per_head=self.attention_head_size,
             q_scaling=self.q_scaling,
             rotary_embedding_dim=self.rotary_embedding_dim,
@@ -905,6 +905,8 @@ def quantize_and_export(*,
         with open(f"{export_path}/config.json", "r") as f:
             tensorrt_llm_config = json.load(f)
         tensorrt_llm_config["mapping"]["cp_size"] = cp_size
+        tensorrt_llm_config["mapping"]["attn_tp_size"] = -1
+        tensorrt_llm_config["mapping"]["attn_cp_size"] = -1
         tensorrt_llm_config["mapping"]["world_size"] *= cp_size
         with open(f"{export_path}/config.json", "w") as f:
             json.dump(tensorrt_llm_config, f, indent=4)
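Like the convert scripts, the ModelOpt export path only records cp_size in the checkpoint's config.json and leaves attn_tp_size/attn_cp_size at -1, deferring the actual split to Mapping when the config is loaded. A hedged sketch of the same patch as a standalone helper (export_path and the helper name are placeholders for illustration):

    import json

    def patch_exported_mapping(export_path: str, cp_size: int) -> None:
        # Mirrors the quantize_and_export() change: record cp_size, keep the
        # attention split unresolved (-1), and scale world_size accordingly.
        config_file = f"{export_path}/config.json"
        with open(config_file, "r") as f:
            config = json.load(f)
        config["mapping"]["cp_size"] = cp_size
        config["mapping"]["attn_tp_size"] = -1
        config["mapping"]["attn_cp_size"] = -1
        config["mapping"]["world_size"] *= cp_size
        with open(config_file, "w") as f:
            json.dump(config, f, indent=4)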