mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
357 lines
12 KiB
C++
357 lines
12 KiB
C++
/*
|
|
* SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION &
|
|
* AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
|
|
#include <cstdio>
|
|
|
|
#include "cpSplitPlugin.h"
|
|
|
|
using namespace nvinfer1;
|
|
using namespace tensorrt_llm::common;
|
|
using tensorrt_llm::plugins::CpSplitPluginCreator;
|
|
using tensorrt_llm::plugins::CpSplitPlugin;
|
|
|
|
static char const* CPSPLIT_PLUGIN_VERSION{"1"};
|
|
static char const* CPSPLIT_PLUGIN_NAME{"CpSplit"};
|
|
PluginFieldCollection CpSplitPluginCreator::mFC{};
|
|
std::vector<nvinfer1::PluginField> CpSplitPluginCreator::mPluginAttributes;
|
|
|
|
CpSplitPlugin::CpSplitPlugin()
|
|
{
|
|
initFieldsToSerialize();
|
|
}
|
|
|
|
CpSplitPlugin::CpSplitPlugin(int cpSize, int cpRank)
|
|
: mCpSize(cpSize)
|
|
, mCpRank(cpRank)
|
|
{
|
|
initFieldsToSerialize();
|
|
}
|
|
|
|
void CpSplitPlugin::initFieldsToSerialize()
|
|
{
|
|
mDataToSerialize.clear();
|
|
mDataToSerialize.emplace_back(PluginField("cp_size", &mCpSize, PluginFieldType::kINT32, 1));
|
|
mDataToSerialize.emplace_back(PluginField("cp_rank", &mCpRank, PluginFieldType::kINT32, 1));
|
|
mFCToSerialize.nbFields = mDataToSerialize.size();
|
|
mFCToSerialize.fields = mDataToSerialize.data();
|
|
}
|
|
|
|
// IPluginV3 methods
|
|
nvinfer1::IPluginCapability* CpSplitPlugin::getCapabilityInterface(nvinfer1::PluginCapabilityType type) noexcept
|
|
{
|
|
switch (type)
|
|
{
|
|
case PluginCapabilityType::kBUILD: return static_cast<IPluginV3OneBuild*>(this);
|
|
case PluginCapabilityType::kRUNTIME: return static_cast<IPluginV3OneRuntime*>(this);
|
|
case PluginCapabilityType::kCORE: return static_cast<IPluginV3OneCore*>(this);
|
|
}
|
|
return nullptr;
|
|
}
|
|
|
|
nvinfer1::IPluginV3* CpSplitPlugin::clone() noexcept
|
|
{
|
|
std::unique_ptr<CpSplitPlugin> plugin{std::make_unique<CpSplitPlugin>(*this)};
|
|
plugin->setPluginNamespace(mNamespace.c_str());
|
|
plugin->initFieldsToSerialize();
|
|
return plugin.release();
|
|
}
|
|
|
|
// IPluginV3OneCore methods
|
|
char const* CpSplitPlugin::getPluginName() const noexcept
|
|
{
|
|
return CPSPLIT_PLUGIN_NAME;
|
|
}
|
|
|
|
char const* CpSplitPlugin::getPluginVersion() const noexcept
|
|
{
|
|
return CPSPLIT_PLUGIN_VERSION;
|
|
}
|
|
|
|
// IPluginV3OneBuild methods
|
|
int32_t CpSplitPlugin::configurePlugin(nvinfer1::DynamicPluginTensorDesc const* in, int32_t nbInputs,
|
|
nvinfer1::DynamicPluginTensorDesc const* out, int32_t nbOutputs) noexcept
|
|
{
|
|
return 0;
|
|
}
|
|
|
|
int32_t CpSplitPlugin::getOutputDataTypes(
|
|
DataType* outputTypes, int32_t nbOutputs, DataType const* inputTypes, int32_t nbInputs) const noexcept
|
|
{
|
|
outputTypes[0] = inputTypes[0];
|
|
outputTypes[1] = DataType::kINT32;
|
|
outputTypes[2] = DataType::kINT32;
|
|
return 0;
|
|
}
|
|
|
|
int32_t CpSplitPlugin::getOutputShapes(DimsExprs const* inputs, int32_t nbInputs, DimsExprs const* shapeInputs,
|
|
int32_t nbShapeInputs, DimsExprs* outputs, int32_t nbOutputs, IExprBuilder& exprBuilder) noexcept
|
|
{
|
|
outputs[0].nbDims = 1;
|
|
|
|
auto cpSize = exprBuilder.constant(mCpSize);
|
|
auto upper = inputs[0].d[0];
|
|
auto opt = exprBuilder.operation(DimensionOperation::kCEIL_DIV, *upper, *cpSize);
|
|
outputs[0].d[0] = exprBuilder.declareSizeTensor(1, *opt, *upper);
|
|
|
|
// We must have such an output size tensor (with dim == 0) to notify the shape of output tensor above
|
|
outputs[1].nbDims = 0;
|
|
outputs[2].nbDims = 1;
|
|
outputs[2].d[0] = upper;
|
|
return 0;
|
|
}
|
|
|
|
bool CpSplitPlugin::supportsFormatCombination(
|
|
int32_t pos, nvinfer1::DynamicPluginTensorDesc const* inOut, int32_t nbInputs, int32_t nbOutputs) noexcept
|
|
{
|
|
if (pos == IdxEntry::INPUT_IDS)
|
|
{
|
|
return ((inOut[pos].desc.type == DataType::kINT32) && (inOut[pos].desc.format == TensorFormat::kLINEAR));
|
|
}
|
|
else if (pos == IdxEntry::REQUEST_TYPES || pos == IdxEntry::HOST_CONTEXT_LENGTH)
|
|
{
|
|
return inOut[pos].desc.type == DataType::kINT32;
|
|
}
|
|
else
|
|
{
|
|
return ((inOut[pos].desc.type == DataType::kINT32) && (inOut[pos].desc.format == TensorFormat::kLINEAR));
|
|
}
|
|
return false;
|
|
}
|
|
|
|
int32_t CpSplitPlugin::getNbOutputs() const noexcept
|
|
{
|
|
return 3;
|
|
}
|
|
|
|
size_t CpSplitPlugin::getWorkspaceSize(nvinfer1::DynamicPluginTensorDesc const* inputs, int32_t nbInputs,
|
|
nvinfer1::DynamicPluginTensorDesc const* outputs, int32_t nbOutputs) const noexcept
|
|
{
|
|
return 0;
|
|
}
|
|
|
|
int32_t CpSplitPlugin::getValidTactics(int32_t* tactics, int32_t nbTactics) noexcept
|
|
{
|
|
return 0;
|
|
}
|
|
|
|
int32_t CpSplitPlugin::getNbTactics() noexcept
|
|
{
|
|
return 0;
|
|
}
|
|
|
|
char const* CpSplitPlugin::getTimingCacheID() noexcept
|
|
{
|
|
return nullptr;
|
|
}
|
|
|
|
int32_t CpSplitPlugin::getFormatCombinationLimit() noexcept
|
|
{
|
|
return 1;
|
|
}
|
|
|
|
char const* CpSplitPlugin::getMetadataString() noexcept
|
|
{
|
|
return nullptr;
|
|
}
|
|
|
|
// IPluginV3OneRuntime methods
|
|
int32_t CpSplitPlugin::setTactic(int32_t tactic) noexcept
|
|
{
|
|
return 0;
|
|
}
|
|
|
|
int32_t CpSplitPlugin::onShapeChange(nvinfer1::PluginTensorDesc const* in, int32_t nbInputs,
|
|
nvinfer1::PluginTensorDesc const* out, int32_t nbOutputs) noexcept
|
|
{
|
|
return 0;
|
|
}
|
|
|
|
int32_t CpSplitPlugin::enqueue(nvinfer1::PluginTensorDesc const* inputDesc,
|
|
nvinfer1::PluginTensorDesc const* outputDesc, void const* const* inputs, void* const* outputs, void* workspace,
|
|
cudaStream_t stream) noexcept
|
|
{
|
|
// inputs
|
|
// @param inputIds [tokenNum]
|
|
// @param host_request_types [batchSize]: Tensor = None (On CPU)
|
|
// The tensor on the host that indicates if a request is in context or
|
|
// generation phase. Its shape is [batch_size]. See Inflight Batching
|
|
// in docs/gpt_attention.md,
|
|
// @param host_context_lengths [batchSize]: Tensor = None (On CPU)
|
|
// A host tensor that contains the lengths of the different inputs
|
|
// outputs
|
|
// @param outputIds [tokenNum spiltted by cp]
|
|
// @param outputLength scalar
|
|
// @param joinIdx [tokenNum]
|
|
|
|
int64_t tokenNum = 1;
|
|
for (int i = 0; i < inputDesc[0].dims.nbDims; ++i)
|
|
{
|
|
tokenNum *= inputDesc[0].dims.d[i];
|
|
}
|
|
|
|
RequestType const* reqTypes = static_cast<RequestType const*>(inputs[IdxEntry::REQUEST_TYPES]);
|
|
int32_t const* hContextLengths = static_cast<int32_t const*>(inputs[IdxEntry::HOST_CONTEXT_LENGTH]);
|
|
int const* inputIds = reinterpret_cast<int const*>(inputs[IdxEntry::INPUT_IDS]);
|
|
int* outputIds = reinterpret_cast<int*>(outputs[0]);
|
|
int32_t* outputLength = reinterpret_cast<int32_t*>(outputs[1]);
|
|
int32_t* outputJoinIdx = reinterpret_cast<int32_t*>(outputs[2]);
|
|
|
|
int32_t const nbSeq = inputDesc[IdxEntry::HOST_CONTEXT_LENGTH].dims.d[0];
|
|
|
|
int32_t* hInputs = new int[inputDesc[IdxEntry::INPUT_IDS].dims.d[0]];
|
|
int32_t* hOutputs = new int[inputDesc[IdxEntry::INPUT_IDS].dims.d[0]];
|
|
int32_t* hOutputJoinIdx = new int[inputDesc[IdxEntry::INPUT_IDS].dims.d[0]];
|
|
cudaMemcpyAsync(
|
|
hInputs, inputIds, sizeof(int32_t) * inputDesc[IdxEntry::INPUT_IDS].dims.d[0], cudaMemcpyDeviceToHost, stream);
|
|
sync_check_cuda_error(stream);
|
|
|
|
int32_t inputIdx = 0;
|
|
int32_t outputIdx = 0;
|
|
for (int32_t seqIdx = 0; seqIdx < nbSeq; ++seqIdx)
|
|
{
|
|
if (reqTypes[seqIdx] == RequestType::kCONTEXT)
|
|
{
|
|
auto const& ctxLength = hContextLengths[seqIdx];
|
|
int32_t partialAverageLength = (ctxLength + mCpSize - 1) / mCpSize;
|
|
int32_t partialLength
|
|
= mCpRank == mCpSize - 1 ? ctxLength - partialAverageLength * (mCpSize - 1) : partialAverageLength;
|
|
for (int i = 0; i < partialLength; i++)
|
|
{
|
|
hOutputs[outputIdx + i] = hInputs[inputIdx + partialAverageLength * mCpRank + i];
|
|
}
|
|
inputIdx += ctxLength;
|
|
outputIdx += partialAverageLength;
|
|
}
|
|
else if (reqTypes[seqIdx] == RequestType::kGENERATION)
|
|
{
|
|
auto const& genLength = nbSeq - seqIdx;
|
|
int32_t partialAverageLength = (genLength + mCpSize - 1) / mCpSize;
|
|
int32_t partialLength
|
|
= mCpRank == mCpSize - 1 ? genLength - partialAverageLength * (mCpSize - 1) : partialAverageLength;
|
|
for (int i = 0; i < partialLength; i++)
|
|
{
|
|
hOutputs[outputIdx + i] = hInputs[inputIdx + partialAverageLength * mCpRank + i];
|
|
}
|
|
outputIdx += partialAverageLength;
|
|
break;
|
|
}
|
|
}
|
|
int32_t hOutputLength = outputIdx;
|
|
inputIdx = 0;
|
|
outputIdx = 0;
|
|
for (int32_t seqIdx = 0; seqIdx < nbSeq; ++seqIdx)
|
|
{
|
|
if (reqTypes[seqIdx] == RequestType::kCONTEXT)
|
|
{
|
|
auto const& ctxLength = hContextLengths[seqIdx];
|
|
int32_t partialAverageLength = (ctxLength + mCpSize - 1) / mCpSize;
|
|
for (int32_t idx = 0; idx < ctxLength; ++idx)
|
|
{
|
|
hOutputJoinIdx[inputIdx + idx]
|
|
= idx % partialAverageLength + idx / partialAverageLength * hOutputLength + outputIdx;
|
|
}
|
|
inputIdx += ctxLength;
|
|
outputIdx += partialAverageLength;
|
|
}
|
|
else if (reqTypes[seqIdx] == RequestType::kGENERATION)
|
|
{
|
|
auto const& genLength = nbSeq - seqIdx;
|
|
int32_t partialAverageLength = (genLength + mCpSize - 1) / mCpSize;
|
|
for (int32_t idx = 0; idx < genLength; ++idx)
|
|
{
|
|
hOutputJoinIdx[inputIdx + idx]
|
|
= idx % partialAverageLength + idx / partialAverageLength * hOutputLength + outputIdx;
|
|
}
|
|
break;
|
|
}
|
|
}
|
|
cudaMemcpyAsync(outputIds, hOutputs, sizeof(int32_t) * hOutputLength, cudaMemcpyHostToDevice, stream);
|
|
cudaMemcpyAsync(outputLength, &hOutputLength, sizeof(int32_t), cudaMemcpyHostToDevice, stream);
|
|
cudaMemcpyAsync(outputJoinIdx, hOutputJoinIdx, sizeof(int32_t) * tokenNum, cudaMemcpyHostToDevice, stream);
|
|
sync_check_cuda_error(stream);
|
|
return 0;
|
|
}
|
|
|
|
nvinfer1::IPluginV3* CpSplitPlugin::attachToContext(nvinfer1::IPluginResourceContext* context) noexcept
|
|
{
|
|
return clone();
|
|
}
|
|
|
|
nvinfer1::PluginFieldCollection const* CpSplitPlugin::getFieldsToSerialize() noexcept
|
|
{
|
|
return &mFCToSerialize;
|
|
}
|
|
|
|
CpSplitPluginCreator::CpSplitPluginCreator()
|
|
{
|
|
// Fill PluginFieldCollection with PluginField arguments metadata
|
|
mPluginAttributes.clear();
|
|
mPluginAttributes.emplace_back(PluginField("cp_size", nullptr, PluginFieldType::kINT32));
|
|
mPluginAttributes.emplace_back(PluginField("cp_rank", nullptr, PluginFieldType::kINT32));
|
|
mFC.nbFields = mPluginAttributes.size();
|
|
mFC.fields = mPluginAttributes.data();
|
|
}
|
|
|
|
char const* CpSplitPluginCreator::getPluginName() const noexcept
|
|
{
|
|
return CPSPLIT_PLUGIN_NAME;
|
|
}
|
|
|
|
char const* CpSplitPluginCreator::getPluginVersion() const noexcept
|
|
{
|
|
return CPSPLIT_PLUGIN_VERSION;
|
|
}
|
|
|
|
nvinfer1::PluginFieldCollection const* CpSplitPluginCreator::getFieldNames() noexcept
|
|
{
|
|
return &mFC;
|
|
}
|
|
|
|
nvinfer1::IPluginV3* CpSplitPluginCreator::createPlugin(
|
|
char const* name, nvinfer1::PluginFieldCollection const* fc, nvinfer1::TensorRTPhase phase) noexcept
|
|
{
|
|
PluginField const* fields = fc->fields;
|
|
int cp_size{};
|
|
int cp_rank{};
|
|
// Read configurations from each fields
|
|
for (int i = 0; i < fc->nbFields; ++i)
|
|
{
|
|
char const* attrName = fields[i].name;
|
|
if (!strcmp(attrName, "cp_size"))
|
|
{
|
|
TLLM_CHECK(fields[i].type == PluginFieldType::kINT32);
|
|
cp_size = static_cast<int>(*(static_cast<int const*>(fields[i].data)));
|
|
}
|
|
else if (!strcmp(attrName, "cp_rank"))
|
|
{
|
|
TLLM_CHECK(fields[i].type == PluginFieldType::kINT32);
|
|
cp_rank = static_cast<int>(*(static_cast<int const*>(fields[i].data)));
|
|
}
|
|
}
|
|
try
|
|
{
|
|
auto* obj = new CpSplitPlugin(cp_size, cp_rank);
|
|
obj->setPluginNamespace(mNamespace.c_str());
|
|
return obj;
|
|
}
|
|
catch (std::exception const& e)
|
|
{
|
|
caughtError(e);
|
|
}
|
|
return nullptr;
|
|
}
|