Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-13 22:18:36 +08:00)

Update TensorRT-LLM (#1918)

parent 9dbc5b38ba
commit a96cccafcf
@@ -172,16 +172,16 @@ struct BenchmarkParams
    std::optional<std::vector<std::vector<SizeType32>>> medusaChoices;
};

class InferenceRequestsSyncSend
class InferenceRequestsAsyncSend
{
public:
    InferenceRequestsSyncSend(std::shared_ptr<tensorrt_llm::mpi::MpiComm> comm,
    InferenceRequestsAsyncSend(std::shared_ptr<tensorrt_llm::mpi::MpiComm> comm,
        std::list<std::shared_ptr<InferenceRequest>> const& inferenceRequests, int const peer)
    {
        TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
        TLLM_LOG_DEBUG("start send requests to rank %d", peer);
        mNumNewWorkItems = static_cast<int64_t>(inferenceRequests.size());
        comm->send(&mNumNewWorkItems, 1, mpi::MpiType::kINT64, peer, 0);
        mRequest1 = comm->sendAsync(&mNumNewWorkItems, 1, mpi::MpiType::kINT64, peer, 0);
        if (mNumNewWorkItems > 0)
        {
            for (auto const& infReq : inferenceRequests)
@@ -191,16 +191,31 @@ public:
                mPacked.insert(mPacked.end(), std::move_iterator(vpacked.begin()), std::move_iterator(vpacked.end()));
            }
            mVecSize = static_cast<int64_t>(mPacked.size());
            comm->send(&mVecSize, 1, mpi::MpiType::kINT64, peer, 1);
            comm->send(mPacked.data(), mPacked.size(), mpi::MpiType::kINT64, peer, 2);
            mRequest2 = comm->sendAsync(&mVecSize, 1, mpi::MpiType::kINT64, peer, 1);
            mRequest3 = comm->sendAsync(mPacked.data(), mPacked.size(), mpi::MpiType::kINT64, peer, 2);
        }
        TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
    }

    ~InferenceRequestsAsyncSend()
    {
        TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
        mRequest1->wait();
        if (mRequest2)
            mRequest2->wait();
        if (mRequest3)
            mRequest3->wait();
        TLLM_LOG_DEBUG("end send requests");
        TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
    }

private:
    int64_t mNumNewWorkItems;
    int64_t mVecSize;
    std::vector<int64_t> mPacked;
    std::shared_ptr<tensorrt_llm::mpi::MpiRequest> mRequest1;
    std::shared_ptr<tensorrt_llm::mpi::MpiRequest> mRequest2;
    std::shared_ptr<tensorrt_llm::mpi::MpiRequest> mRequest3;
};
} // namespace

@@ -930,7 +945,6 @@ public:
        , mStaticEmulatedBatchSize(staticEmulatedBatchSize)
        , mBatchTimeout(batchTimeout.value_or(std::chrono::milliseconds{0}))
        , mActiveCount(0)
        , mInferReqSyncSndHdl(nullptr)
    {
        auto const jsonConfig = GptJsonConfig::parse(trtEnginePath / "config.json");
        mWorldConfig = WorldConfig::mpi(jsonConfig.getGpusPerNode(), jsonConfig.getTensorParallelism(),
@@ -966,6 +980,12 @@ public:

    ~GptServer()
    {
        if (mInferReqWaitThread)
        {
            mInferReqWaitThread->join();
            mInferReqWaitThread.reset(nullptr);
        }

        mWorkItemsQueue.clear();
    }

@@ -1031,7 +1051,11 @@ public:
    // Return up to max_num_requests inference requests.
    std::list<std::shared_ptr<InferenceRequest>> getInferenceRequests(int const max_num_requests)
    {
        mInferReqSyncSndHdl = nullptr;
        if (mInferReqWaitThread)
        {
            mInferReqWaitThread->join();
            mInferReqWaitThread.reset(nullptr);
        }
        std::list<std::shared_ptr<InferenceRequest>> inferenceRequests;
        auto& comm = COMM_SESSION;
        if (max_num_requests > 0)
@@ -1134,8 +1158,9 @@ public:
            if (!mWorldConfig.isLastPipelineParallelRank())
            {
                auto const peer = mWorldConfig.getPipelineParallelRank() + 1;
                mInferReqSyncSndHdl
                    = std::make_shared<InferenceRequestsSyncSend>(mCommPipelineParallel, inferenceRequests, peer);
                auto inferReqAsyncSndHdl
                    = std::make_unique<InferenceRequestsAsyncSend>(mCommPipelineParallel, inferenceRequests, peer);
                mInferReqWaitThread = std::make_unique<std::thread>([handle = std::move(inferReqAsyncSndHdl)]() {});
            }
        }
        return inferenceRequests;
@@ -1184,7 +1209,7 @@ private:
    WorldConfig mWorldConfig;
    std::shared_ptr<tensorrt_llm::mpi::MpiComm> mCommTensorParallel;
    std::shared_ptr<tensorrt_llm::mpi::MpiComm> mCommPipelineParallel;
    std::shared_ptr<InferenceRequestsSyncSend> mInferReqSyncSndHdl;
    std::unique_ptr<std::thread> mInferReqWaitThread;

}; // class GptServer

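The hunks above replace the blocking InferenceRequestsSyncSend with InferenceRequestsAsyncSend: the constructor now posts sendAsync calls, the destructor waits on the outstanding requests, and GptServer keeps the handle alive by moving it into a std::thread that it joins before the next getInferenceRequests call (and in its own destructor). The following stand-alone sketch illustrates that ownership pattern only; FakeAsyncSend is a hypothetical stand-in, not TensorRT-LLM API.

// Minimal sketch of the handle-owned-by-a-wait-thread pattern (C++17).
#include <chrono>
#include <cstdio>
#include <memory>
#include <thread>

struct FakeAsyncSend
{
    FakeAsyncSend() { std::puts("posting non-blocking sends"); }
    // The real class waits on mRequest1..mRequest3 here, so the sends are
    // guaranteed to have completed once the handle is destroyed.
    ~FakeAsyncSend()
    {
        std::this_thread::sleep_for(std::chrono::milliseconds(10)); // pretend to wait
        std::puts("all sends completed");
    }
};

int main()
{
    std::unique_ptr<std::thread> waitThread;

    // getInferenceRequests(): start the async send, then hand the handle to a
    // thread whose only job is to destroy it (and thus wait) off the critical path.
    {
        auto handle = std::make_unique<FakeAsyncSend>();
        waitThread = std::make_unique<std::thread>([h = std::move(handle)]() {});
    }

    // ... the caller keeps scheduling work while the sends drain ...

    // The next call (or the destructor) joins the thread before reusing buffers.
    if (waitThread)
    {
        waitThread->join();
        waitThread.reset(nullptr);
    }
    return 0;
}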
@@ -60,14 +60,20 @@ class BuildConfig:
    parallel_attention: bool = None
    new_decoder_architecture: bool = None
    state_size: int = 0
    state_dtype: Optional[str] = None
    state_dtype: Optional[str] = ""
    conv_kernel: int = 0
    layer_types: List[str] = field(default_factory=list)
    rnn_hidden_size: int = 0
    rnn_head_size: int = 0
    rnn_conv_dim_size: int = 0
    logits_soft_cap: float = 0.0
    opt_batch_size: Optional[int] = None
    opt_num_tokens: Optional[int] = None
    use_bias: bool = None
    mamba_version: str = 'Mamba1'
    ssm_rmsnorm: bool = True
    ngroups: int = 1
    chunk_size: int = 256


@dataclass
@ -1218,6 +1224,7 @@ _allowed_configs = {
|
||||
state_size=16,
|
||||
conv_kernel=4,
|
||||
rnn_hidden_size=5120,
|
||||
rnn_conv_dim_size=5120,
|
||||
layer_types=["recurrent"],
|
||||
use_bias=False,
|
||||
)),
|
||||
@ -1238,6 +1245,7 @@ _allowed_configs = {
|
||||
state_size=16,
|
||||
conv_kernel=4,
|
||||
rnn_hidden_size=4096,
|
||||
rnn_conv_dim_size=4096,
|
||||
layer_types=["recurrent"],
|
||||
use_bias=False,
|
||||
)),
|
||||
@ -1258,6 +1266,7 @@ _allowed_configs = {
|
||||
state_size=16,
|
||||
conv_kernel=4,
|
||||
rnn_hidden_size=3072,
|
||||
rnn_conv_dim_size=3072,
|
||||
layer_types=["recurrent"],
|
||||
use_bias=False,
|
||||
)),
|
||||
@ -1278,6 +1287,7 @@ _allowed_configs = {
|
||||
state_size=16,
|
||||
conv_kernel=4,
|
||||
rnn_hidden_size=2048,
|
||||
rnn_conv_dim_size=2048,
|
||||
layer_types=["recurrent"],
|
||||
use_bias=False,
|
||||
)),
|
||||
@ -1298,9 +1308,62 @@ _allowed_configs = {
|
||||
state_size=16,
|
||||
conv_kernel=4,
|
||||
rnn_hidden_size=1536,
|
||||
rnn_conv_dim_size=1536,
|
||||
layer_types=["recurrent"],
|
||||
use_bias=False,
|
||||
)),
|
||||
"mamba2_2.7b":
|
||||
ModelConfig(name="mamba2_2.7b",
|
||||
family="mamba",
|
||||
benchmark_type="gpt",
|
||||
build_config=BuildConfig(
|
||||
num_layers=64,
|
||||
num_heads=1,
|
||||
hidden_size=2560,
|
||||
vocab_size=50288,
|
||||
hidden_act="silu",
|
||||
n_positions=8192,
|
||||
max_batch_size=64,
|
||||
max_input_len=1024,
|
||||
max_seq_len=2048,
|
||||
state_size=128,
|
||||
conv_kernel=4,
|
||||
rnn_hidden_size=5120,
|
||||
rnn_conv_dim_size=5376,
|
||||
rnn_head_size=64,
|
||||
layer_types=["recurrent"],
|
||||
use_bias=False,
|
||||
mamba_version='Mamba2',
|
||||
ssm_rmsnorm=True,
|
||||
ngroups=1,
|
||||
chunk_size=256,
|
||||
)),
|
||||
"mamba2_130m":
|
||||
ModelConfig(name="mamba2_130m",
|
||||
family="mamba",
|
||||
benchmark_type="gpt",
|
||||
build_config=BuildConfig(
|
||||
num_layers=24,
|
||||
num_heads=1,
|
||||
hidden_size=768,
|
||||
vocab_size=50288,
|
||||
hidden_act="silu",
|
||||
n_positions=8192,
|
||||
max_batch_size=64,
|
||||
max_input_len=1024,
|
||||
max_seq_len=2048,
|
||||
state_size=128,
|
||||
conv_kernel=4,
|
||||
rnn_hidden_size=1536,
|
||||
rnn_conv_dim_size=1792,
|
||||
rnn_head_size=64,
|
||||
layer_types=["recurrent"],
|
||||
use_bias=False,
|
||||
mamba_version='Mamba2',
|
||||
ssm_rmsnorm=True,
|
||||
ngroups=1,
|
||||
chunk_size=256,
|
||||
)),
|
||||
"whisper_large_v3":
|
||||
ModelConfig(name="whisper_large_v3",
|
||||
family="whisper",
|
||||
@ -1344,6 +1407,7 @@ _allowed_configs = {
|
||||
state_size=1,
|
||||
layer_types=["recurrent", "recurrent", "attention"],
|
||||
rnn_hidden_size=2560,
|
||||
rnn_conv_dim_size=2560,
|
||||
logits_soft_cap=30.0,
|
||||
state_dtype="float32",
|
||||
)),
|
||||
|
||||
@@ -295,7 +295,8 @@ def build_gpt(args):
    builder_config_extra_kwargs = {}
    extra_items = [
        'layer_types', 'conv_kernel', 'rnn_hidden_size', 'logits_soft_cap',
        'state_size', 'use_bias'
        'state_size', 'use_bias', 'rnn_head_size', 'rnn_conv_dim_size',
        'mamba_version', 'ssm_rmsnorm', 'ngroups', 'chunk_size'
    ]
    for item in extra_items:
        if item in build_config:
@@ -876,10 +877,16 @@ def build_gpt(args):
        'state_size': build_config['state_size'],
        'conv_kernel': build_config['conv_kernel'],
        'rnn_hidden_size': build_config['rnn_hidden_size'],
        'rnn_head_size': build_config['rnn_head_size'],
        'rnn_conv_dim_size': build_config['rnn_conv_dim_size'],
        'rms_norm': True,
        'residual_in_fp32': True,
        'pad_vocab_size_multiple': 8,
        'use_bias': build_config['use_bias'],
        'mamba_version': build_config['mamba_version'],
        'ssm_rmsnorm': build_config['ssm_rmsnorm'],
        'ngroups': build_config['ngroups'],
        'chunk_size': build_config['chunk_size'],
    }
    config = PretrainedConfig.from_dict(config)
    tensorrt_llm_model = tensorrt_llm.models.MambaForCausalLM(config)
@@ -912,6 +919,8 @@ def build_gpt(args):
        'state_size': build_config['state_size'],
        'layer_types': build_config['layer_types'],
        'rnn_hidden_size': build_config['rnn_hidden_size'],
        'rnn_head_size': build_config['rnn_head_size'],
        'rnn_conv_dim_size': build_config['rnn_conv_dim_size'],
        'logits_soft_cap': build_config['logits_soft_cap'],
        'rotary_pct': build_config['rotary_pct'],
    }

@@ -126,7 +126,7 @@ class GPTBenchmark(BaseBenchmark):

    rnn_config_items = [
        'conv_kernel', 'layer_types', 'rnn_hidden_size', 'state_size',
        'state_dtype'
        'state_dtype', 'rnn_head_size', 'rnn_conv_dim_size'
    ]
    rnn_configs_kwargs = {}
    for item in rnn_config_items:

@@ -116,7 +116,7 @@ public:
        uint64_t requestId, std::optional<LogitsPostProcessor> logitsPostProcessor = std::nullopt)
        : mRequestId{requestId}
        , mIsStreaming{false}
        , mlogitsPostProcessor(logitsPostProcessor)
        , mLogitsPostProcessor(logitsPostProcessor)
    {
    }

@@ -125,7 +125,7 @@ public:
        : mRequestId{requestId}
        , mIsStreaming{false}
        , mInputTensors{std::move(tensorMap)}
        , mlogitsPostProcessor(logitsPostProcessor)
        , mLogitsPostProcessor(logitsPostProcessor)
    {
        for (auto const& [name, tensor] : mInputTensors)
        {
@@ -161,12 +161,12 @@ public:

    void setLogitsPostProcessor(std::optional<LogitsPostProcessor> cb)
    {
        mlogitsPostProcessor = cb;
        mLogitsPostProcessor = cb;
    }

    std::optional<LogitsPostProcessor> getLogitsPostProcessor()
    {
        return mlogitsPostProcessor;
        return mLogitsPostProcessor;
    }

    static std::array constexpr kTensorNames = {
@@ -280,7 +280,7 @@ protected:
    uint64_t mRequestId;
    bool mIsStreaming;
    TensorMap mInputTensors;
    std::optional<LogitsPostProcessor> mlogitsPostProcessor;
    std::optional<LogitsPostProcessor> mLogitsPostProcessor;
};

class InferenceRequest : public GenericInferenceRequest<tensorrt_llm::runtime::ITensor::SharedPtr, NamedTensor>

@ -248,16 +248,6 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
void setNumPrepopulatedTokens(std::vector<int> numPrepopulatedTokens)
|
||||
{
|
||||
mNumPrepopulatedTokens = std::move(numPrepopulatedTokens);
|
||||
}
|
||||
|
||||
[[nodiscard]] std::vector<int> const& getNumPrepopulatedTokens() const
|
||||
{
|
||||
return mNumPrepopulatedTokens;
|
||||
}
|
||||
|
||||
private:
|
||||
// Slot id of the sequence
|
||||
SizeType32 mSeqSlotIdx;
|
||||
@ -267,10 +257,6 @@ private:
|
||||
SizeType32 mBeamWidth;
|
||||
// List of blocks allocated for each beam of the sequence
|
||||
std::vector<std::vector<KVCacheBlock::IdType>> mCacheBlockIds;
|
||||
// Number of tokens already in kv cache before context phase.
|
||||
// A value > 0 indicates cached kv cache blocks were reused.
|
||||
// One value per beam.
|
||||
std::vector<int> mNumPrepopulatedTokens;
|
||||
};
|
||||
|
||||
// BlockManager manages overall metadata of KVCacheBlocks in a layer of the
|
||||
@ -400,7 +386,10 @@ public:
|
||||
|
||||
private:
|
||||
//! \brief Add single block to beam of sequence and mAllocatedBlocksPerSeq.
|
||||
void addBlockToBeam(BlockPtr& block, GenerationRequest& sequence, SizeType32 beamIdx, SizeType32 seqSlotIdx);
|
||||
void addBlockToBeam(BlockPtr& block, GenerationRequest& sequence, SizeType32 beamIdx);
|
||||
|
||||
//! \brief Add single block to all beams of sequence.
|
||||
void addBlockToAllBeams(BlockPtr& block, GenerationRequest& sequence);
|
||||
|
||||
//! \brief Store blocks in cached blocks.
|
||||
//! \param blockedTokens Tokens of each block.
|
||||
@ -410,11 +399,8 @@ private:
|
||||
//! \brief Try to load blocks from cache. Allocate new blocks if necessary.
|
||||
//! \param blockedTokens Tokens of each block.
|
||||
//! \param sequence Sequence to which blocks are assigned.
|
||||
//! \param beamIdx Beam of sequence to which blocks are assigned.
|
||||
//! \param seqSlotIdx Batch slot of sequence to which blocks are assigned.
|
||||
//! \return Number of matched tokens from loaded blocks.
|
||||
SizeType32 loadOrAllocateBlocks(std::list<VecTokens> const& blockedTokens, GenerationRequest& sequence,
|
||||
SizeType32 beamIdx, SizeType32 seqSlotIdx);
|
||||
SizeType32 loadOrAllocateBlocks(std::list<VecTokens> const& blockedTokens, GenerationRequest& sequence);
|
||||
|
||||
//! \brief Find best primary block to free.
|
||||
//! \details The best primary block to free is the primary block that appears first in the queue and have no primary
|
||||
@ -598,12 +584,6 @@ public:
|
||||
nvinfer1::DataType dtype, tensorrt_llm::runtime::ModelConfig const& modelConfig,
|
||||
tensorrt_llm::runtime::WorldConfig const& worldConfig, runtime::BufferManager const& bufferManager);
|
||||
|
||||
[[nodiscard]] SizeType32 getNumPrepopulatedTokens(SizeType32 batchSlotIdx, SizeType32 beamIdx) const
|
||||
{
|
||||
auto const& prepopulatedTokens = mSequences.at(batchSlotIdx)->getNumPrepopulatedTokens();
|
||||
return prepopulatedTokens.size() > 0 ? prepopulatedTokens.at(beamIdx) : 0;
|
||||
}
|
||||
|
||||
[[nodiscard]] bool isEnableBlockReuse() const
|
||||
{
|
||||
return mEnableBlockReuse;
|
||||
|
||||
@ -84,7 +84,6 @@ public:
|
||||
, mSamplingConfig(samplingConfig)
|
||||
, mState(REQUEST_STATE_CONTEXT_INIT)
|
||||
, mIsStreaming(isStreaming)
|
||||
, mReturnAllGeneratedTokens(isStreaming && (samplingConfig.beamWidth > 1))
|
||||
, mEndId(endId)
|
||||
, mPadId(padId)
|
||||
, mLogitsPostProcessor(logitsPostProcessor)
|
||||
@ -127,7 +126,6 @@ public:
|
||||
, mSamplingConfig(req.getSamplingConfig(), req.getExternalDraftTokensConfig())
|
||||
, mState(REQUEST_STATE_CONTEXT_INIT)
|
||||
, mIsStreaming(req.getStreaming())
|
||||
, mReturnAllGeneratedTokens(req.getReturnAllGeneratedTokens())
|
||||
, mEndId(req.getEndId())
|
||||
, mPadId(req.getPadId())
|
||||
, mOrigPromptLen(mPromptLen)
|
||||
@ -154,16 +152,6 @@ public:
|
||||
, mReturnEncoderOutput(req.getOutputConfig().returnEncoderOutput)
|
||||
, mDecodingIter(0)
|
||||
{
|
||||
if (mIsStreaming && mSamplingConfig.beamWidth > 1 && mReturnAllGeneratedTokens == false)
|
||||
{
|
||||
TLLM_LOG_WARNING(
|
||||
"Setting mReturnAllGeneratedTokens to True since streaming AND beam search are done simultaneously. "
|
||||
"Returning the full beams at each streaming step is needed because beam search + streaming can change "
|
||||
"previous outputs. Initialize request with mReturnAllGeneratedTokens = True to dismiss this error."
|
||||
"WARNING: using this option may increase network usage significantly (quadratically w.r.t output "
|
||||
"length).");
|
||||
mReturnAllGeneratedTokens = true;
|
||||
}
|
||||
if (req.getEncoderInputTokenIds())
|
||||
{
|
||||
mState = REQUEST_STATE_ENCODER_INIT;
|
||||
@ -575,6 +563,16 @@ public:
|
||||
return mOrigPromptLen;
|
||||
}
|
||||
|
||||
void setPrepopulatedPromptLen(SizeType32 prepopulatedPromptLen)
|
||||
{
|
||||
mPrepopulatedPromptLen = prepopulatedPromptLen;
|
||||
}
|
||||
|
||||
[[nodiscard]] SizeType32 getPrepopulatedPromptLen() const
|
||||
{
|
||||
return mPrepopulatedPromptLen;
|
||||
}
|
||||
|
||||
void setDraftTokens(std::shared_ptr<VecTokens> const& draftTokens)
|
||||
{
|
||||
mDraftTokens = draftTokens;
|
||||
@ -585,7 +583,7 @@ public:
|
||||
mDraftLogits = draftLogits;
|
||||
}
|
||||
|
||||
SizeType32 getNumDraftTokens() const
|
||||
[[nodiscard]] SizeType32 getNumDraftTokens() const
|
||||
{
|
||||
return mDraftTokens->size();
|
||||
}
|
||||
@ -604,7 +602,7 @@ public:
|
||||
mNumTokensPerIteration = numTokensPerIteration;
|
||||
}
|
||||
|
||||
SizeType32 getNumTokensPerIteration() const
|
||||
[[nodiscard]] SizeType32 getNumTokensPerIteration() const
|
||||
{
|
||||
return mNumTokensPerIteration;
|
||||
}
|
||||
@ -883,22 +881,16 @@ public:
|
||||
// FIXME(nkorobov): For streaming we do not allow beam search and
|
||||
// streaming index calculation here applies only for sampling
|
||||
// getNumTokensPerIteration takes accepted draft tokens into account
|
||||
auto nbTokensOut
|
||||
= (mReturnAllGeneratedTokens || !mIsStreaming) ? maxNbTokens : std::max(getNumTokensPerIteration(), 1);
|
||||
|
||||
int nbTokensOut = mIsStreaming ? std::max(getNumTokensPerIteration(), 1) : maxNbTokens;
|
||||
if (mExcludeInputFromOutput && !mIsStreaming)
|
||||
{
|
||||
nbTokensOut -= getOrigPromptLen();
|
||||
}
|
||||
|
||||
result.outputTokenIds.resize(nbBeams);
|
||||
SizeType32 tokenPos = maxNbTokens - nbTokensOut;
|
||||
|
||||
// in the case of streaming + beam search
|
||||
// we need to return the full beams at all iterations
|
||||
|
||||
SizeType32 tokenPos{maxNbTokens - nbTokensOut};
|
||||
auto const shouldSendResponse = isGenerationCompleteState()
|
||||
|| (mIsStreaming && tokenPos > getMaxSentTokenPos()) || mReturnAllGeneratedTokens;
|
||||
bool shouldSendResponse = isGenerationCompleteState() || (mIsStreaming && tokenPos > getMaxSentTokenPos());
|
||||
|
||||
if (!shouldSendResponse)
|
||||
{
|
||||
@ -909,8 +901,7 @@ public:
|
||||
for (SizeType32 beam = 0; beam < nbBeams; ++beam)
|
||||
{
|
||||
auto tokens = getTokens(beam);
|
||||
auto nbTokens = (mReturnAllGeneratedTokens || !mIsStreaming) ? tokens.size()
|
||||
: (tokenPos - getMaxSentTokenPos());
|
||||
auto nbTokens = mIsStreaming ? (tokenPos - getMaxSentTokenPos()) : tokens.size();
|
||||
|
||||
// Take accepted draft tokens into account when streaming
|
||||
auto const numAcceptedTokens = std::max(0, getNumTokensPerIteration() - 1);
|
||||
@ -982,8 +973,6 @@ public:
|
||||
runtime::SamplingConfig mSamplingConfig;
|
||||
LlmRequestState_t mState;
|
||||
bool mIsStreaming;
|
||||
// whether to return the full beams on each iteration. True when doing streaming + beamsearch
|
||||
bool mReturnAllGeneratedTokens;
|
||||
std::optional<TokenIdType> mEndId;
|
||||
std::optional<TokenIdType> mPadId;
|
||||
std::optional<SizeType32> mSeqSlot;
|
||||
@ -993,6 +982,10 @@ public:
|
||||
protected:
|
||||
BeamTokens mTokens;
|
||||
SizeType32 mOrigPromptLen;
|
||||
// Number of tokens already in KV cache before context phase.
|
||||
// A value > 0 indicates cached KV cache blocks were reused.
|
||||
// Up to inputLen - 1 tokens can be reused.
|
||||
SizeType32 mPrepopulatedPromptLen{0};
|
||||
SizeType32 mMaxSentTokenPos;
|
||||
|
||||
std::optional<TensorPtr> mEmbeddingBias;
|
||||
|
||||
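The createResponse hunks above change how many tokens are reported per step: with mReturnAllGeneratedTokens set, or when not streaming, the full window of maxNbTokens is returned, otherwise only the tokens produced in the current iteration (which includes accepted draft tokens). The sketch below restates that arithmetic with free-standing variables in place of the LlmRequest members; the concrete numbers are illustrative only.

// Free-standing restatement of the new token-window computation (C++17 sketch).
#include <algorithm>
#include <cstdio>

int main()
{
    int const maxNbTokens = 20;          // tokens available for this beam
    int const numTokensPerIteration = 3; // includes accepted draft tokens
    int const origPromptLen = 8;
    int const maxSentTokenPos = 11;
    bool const isStreaming = true;
    bool const returnAllGeneratedTokens = false;
    bool const excludeInputFromOutput = true;
    bool const generationComplete = false;

    int nbTokensOut = (returnAllGeneratedTokens || !isStreaming)
        ? maxNbTokens
        : std::max(numTokensPerIteration, 1);
    if (excludeInputFromOutput && !isStreaming)
    {
        nbTokensOut -= origPromptLen;
    }

    int const tokenPos = maxNbTokens - nbTokensOut;
    bool const shouldSendResponse
        = generationComplete || (isStreaming && tokenPos > maxSentTokenPos) || returnAllGeneratedTokens;

    std::printf("nbTokensOut=%d tokenPos=%d send=%d\n", nbTokensOut, tokenPos, shouldSendResponse);
    return 0;
}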
@ -385,7 +385,7 @@ private:
|
||||
bool mFreeComm;
|
||||
};
|
||||
|
||||
void initialize(MpiThreadSupport threadMode = MpiThreadSupport::THREAD_FUNNELED, bool forwardAbortToParent = false);
|
||||
void initialize(MpiThreadSupport threadMode = MpiThreadSupport::THREAD_MULTIPLE, bool forwardAbortToParent = false);
|
||||
|
||||
} // namespace tensorrt_llm::mpi
|
||||
|
||||
|
||||
@ -252,8 +252,6 @@ public:
|
||||
/// @param logitsPostProcessorName The logits postprocessor name. Must correspond to one of the logits postprocessor
|
||||
/// name provided to the ExecutorConfig.
|
||||
/// @param encoderInputTokenIds The encoder input token ids for encoder-decoder models, or encoder-only models
|
||||
/// @param returnAllGeneratedTokens Indicates whether to return the full beams or just the newly generated tokens
|
||||
/// after every streaming step.
|
||||
Request(VecTokens inputTokenIds, SizeType32 maxNewTokens, bool streaming = false,
|
||||
SamplingConfig const& samplingConfig = SamplingConfig(), OutputConfig const& outputConfig = OutputConfig(),
|
||||
std::optional<SizeType32> const& endId = std::nullopt, std::optional<SizeType32> const& padId = std::nullopt,
|
||||
@ -264,7 +262,7 @@ public:
|
||||
std::optional<PromptTuningConfig> pTuningConfig = std::nullopt,
|
||||
std::optional<LoraConfig> loraConfig = std::nullopt,
|
||||
std::optional<std::string> logitsPostProcessorName = std::nullopt,
|
||||
std::optional<VecTokens> encoderInputTokenIds = std::nullopt, bool returnAllGeneratedTokens = false);
|
||||
std::optional<VecTokens> encoderInputTokenIds = std::nullopt);
|
||||
|
||||
/// @brief This logits postprocessor name will dispatch to the batched logits postprocessor
|
||||
static auto constexpr kBatchedPostProcessorName = "batched";
|
||||
@ -290,7 +288,6 @@ public:
|
||||
[[nodiscard]] std::optional<LoraConfig> getLoraConfig() const;
|
||||
[[nodiscard]] std::optional<std::string> getLogitsPostProcessorName() const;
|
||||
[[nodiscard]] std::optional<VecTokens> getEncoderInputTokenIds() const;
|
||||
[[nodiscard]] bool getReturnAllGeneratedTokens() const;
|
||||
|
||||
void setStreaming(bool streaming);
|
||||
void setSamplingConfig(SamplingConfig const& config);
|
||||
@ -305,7 +302,6 @@ public:
|
||||
void setLoraConfig(LoraConfig const& loraConfig);
|
||||
void setLogitsPostProcessorName(std::string const& logitsPostProcessorName);
|
||||
void setEncoderInputTokenIds(VecTokens const& encoderInputTokenIds);
|
||||
void setReturnAllGeneratedTokens(bool returnAllGeneratedTokens);
|
||||
|
||||
private:
|
||||
friend class Serialization;
|
||||
|
||||
@ -175,7 +175,7 @@ public:
|
||||
{
|
||||
TLLM_CHECK(data.size() <= std::numeric_limits<DimType64>::max());
|
||||
}
|
||||
return of(data.data(), {static_cast<Shape::DimType64 const>(data.size())});
|
||||
return of(data.data(), {static_cast<Shape::DimType64>(data.size())});
|
||||
}
|
||||
|
||||
Tensor() noexcept = default;
|
||||
|
||||
@ -51,6 +51,8 @@ public:
|
||||
SizeType32 stateSize = 0;
|
||||
SizeType32 convKernel = 0;
|
||||
SizeType32 rnnHiddenSize = 0;
|
||||
SizeType32 rnnHeadSize = 0;
|
||||
SizeType32 rnnConvDimSize = 0;
|
||||
};
|
||||
|
||||
enum class LayerType : std::int32_t
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:33f2d6b3e871b0a0e651883607887777fe03d6822f06e4154ffc7e35a8d5cc70
|
||||
size 3938416
|
||||
oid sha256:5804fde474d6489db29204259b7e6c368117acadb7fb6dc807868ee0391c458b
|
||||
size 3953206
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:8412aa4ca15c232ced1cd4bdfcc54177c7b257aef493d50650c960e0fb527cfc
|
||||
size 4002178
|
||||
oid sha256:85802a0e66148acb17d017a64dd982287775ce7bf5aa4e8bb7e5466b3736c7ee
|
||||
size 4019734
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
d4aa7db860caf8feedb79a280aa70da3 libtensorrt_llm_batch_manager_static.a
|
||||
02b4363342ccea3e2abccc474f3506bb libtensorrt_llm_batch_manager_static.pre_cxx11.a
|
||||
0e1417f27d93de67940c1062cf230017cd8be5f1 commit
|
||||
00fb525bdf4ff217c16940540b2357c4 libtensorrt_llm_batch_manager_static.a
|
||||
97d2db7f62745001d871bc89fb38eed6 libtensorrt_llm_batch_manager_static.pre_cxx11.a
|
||||
d5f5542d2f1e10c4a6b60be56838ac79a9668665 commit
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:86f34c84883f1dfed04c6fb18811198da636e4457617a47db71f045cb3066eb4
|
||||
size 3825822
|
||||
oid sha256:33a724d7e9eabc358c0d674151d45cef8849ae702cc5f2f88b259299a8306574
|
||||
size 3842582
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:c07c30d986591bbe93bb30d67fc8ebbba3eb55c5875ce939c3265151747656ae
|
||||
size 3782506
|
||||
oid sha256:490a93ff13a67949a30e279fc3df27456c7f5d4084158c3089befccf78118b7f
|
||||
size 3799140
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:e0190b794e437fa6a0e2140e9446195413abde0dfbc5109423c790397fbb95a6
|
||||
size 22445474
|
||||
oid sha256:663a163c3177644ed86fa7a2145fe5e9dbf6f2f0ed06c96d367236da323a3432
|
||||
size 22523526
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:8729077e2bfb9cf3f647cc6ca9be42a8953c0ddf58426485ae3bded76dc9d5c3
|
||||
size 1403008
|
||||
oid sha256:497b00031131c1dc705e848e52f3d43148f55505e37bdad97f4933b2c074469d
|
||||
size 1400502
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:2b68c06565f1b3f795e070420c73d085c620b42c1c2131f9895d2687178a6b54
|
||||
size 1427780
|
||||
oid sha256:417978bdb5c19f97d9758475acacfa18a4038fc3c5a83f981b02ee220104e0c7
|
||||
size 1425792
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
db98ffd911c3c1dde3413e934ce8deb8 libtensorrt_llm_executor_static.a
|
||||
8dc57746aa2c29d8a2fa50196b552bc3 libtensorrt_llm_executor_static.pre_cxx11.a
|
||||
0e1417f27d93de67940c1062cf230017cd8be5f1 commit
|
||||
1df55ac2948ca7b7fe2d5e79934e660e libtensorrt_llm_executor_static.a
|
||||
ea1641928d184d117deec0696763b274 libtensorrt_llm_executor_static.pre_cxx11.a
|
||||
d5f5542d2f1e10c4a6b60be56838ac79a9668665 commit
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:cbc3a279681e877a982c0ebbdd0c13d7792d67a87bad0be125ec81bfe3f87399
|
||||
size 1454684
|
||||
oid sha256:d0441d473852d11f50bcf23f4934b38d7e4c6d4a42f057eb04beb8aea4211cac
|
||||
size 1451118
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:aa15303c38a748c4bf7b82e1f9c58cb63418efbd60bfede62820f5a62d65710a
|
||||
size 1381738
|
||||
oid sha256:dc8619f99cf5a2e04bdb1482f157a9852bd745e90cf9e03a7878f73ed07e5610
|
||||
size 1383936
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:1a3e774d6700444b7164e1b31e26936ea6fcddc73e3e17bba1d8492c65a57b78
|
||||
size 14036486
|
||||
oid sha256:772d1b83e739b926729b99999fbb81768569ffb172c2e120665b2d31b987bb47
|
||||
size 14071986
|
||||
|
||||
329 cpp/tensorrt_llm/kernels/chunkScan/Cn.h (new file)
@@ -0,0 +1,329 @@
|
||||
/*
|
||||
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <tuple>
|
||||
|
||||
#ifdef __CUDACC__ // for CUDA
|
||||
#define FT_DEV_CEXPR __device__ __host__ inline constexpr
|
||||
#else
|
||||
#define FT_DEV_CEXPR inline constexpr
|
||||
#endif
|
||||
|
||||
//----------------------------------------------------------------------------
|
||||
// Cn: constant integer
|
||||
//----------------------------------------------------------------------------
|
||||
|
||||
template <auto value_>
|
||||
struct Cn : public std::integral_constant<decltype(value_), value_>
|
||||
{
|
||||
};
|
||||
|
||||
template <auto value_>
|
||||
constexpr auto cn = Cn<value_>();
|
||||
|
||||
//----------------------------------------------------------------------------
|
||||
// Operators for Cn
|
||||
//----------------------------------------------------------------------------
|
||||
|
||||
template <auto value_>
|
||||
FT_DEV_CEXPR auto operator+(Cn<value_>)
|
||||
{
|
||||
return cn<+value_>;
|
||||
}
|
||||
|
||||
template <auto value_>
|
||||
FT_DEV_CEXPR auto operator-(Cn<value_>)
|
||||
{
|
||||
return cn<-value_>;
|
||||
}
|
||||
|
||||
template <auto value_>
|
||||
FT_DEV_CEXPR auto operator!(Cn<value_>)
|
||||
{
|
||||
return cn<!value_>;
|
||||
}
|
||||
|
||||
template <auto value_>
|
||||
FT_DEV_CEXPR auto operator~(Cn<value_>)
|
||||
{
|
||||
return cn<~value_>;
|
||||
}
|
||||
|
||||
template <auto a_, auto b_>
|
||||
FT_DEV_CEXPR auto operator+(Cn<a_>, Cn<b_>)
|
||||
{
|
||||
return cn<a_ + b_>;
|
||||
}
|
||||
|
||||
template <auto a_, auto b_>
|
||||
FT_DEV_CEXPR auto operator-(Cn<a_>, Cn<b_>)
|
||||
{
|
||||
return cn<a_ - b_>;
|
||||
}
|
||||
|
||||
template <auto a_, auto b_>
|
||||
FT_DEV_CEXPR auto operator*(Cn<a_>, Cn<b_>)
|
||||
{
|
||||
return cn<a_ * b_>;
|
||||
}
|
||||
|
||||
template <auto a_, auto b_>
|
||||
FT_DEV_CEXPR auto operator/(Cn<a_>, Cn<b_>)
|
||||
{
|
||||
return cn<a_ / b_>;
|
||||
}
|
||||
|
||||
template <auto a_, auto b_>
|
||||
FT_DEV_CEXPR auto operator%(Cn<a_>, Cn<b_>)
|
||||
{
|
||||
return cn<a_ % b_>;
|
||||
}
|
||||
|
||||
template <auto a_, auto b_>
|
||||
FT_DEV_CEXPR auto operator<<(Cn<a_>, Cn<b_>)
|
||||
{
|
||||
return cn<(a_ << b_)>;
|
||||
}
|
||||
|
||||
template <auto a_, auto b_>
|
||||
FT_DEV_CEXPR auto operator>>(Cn<a_>, Cn<b_>)
|
||||
{
|
||||
return cn<(a_ >> b_)>;
|
||||
}
|
||||
|
||||
template <auto a_, auto b_>
|
||||
FT_DEV_CEXPR auto operator<(Cn<a_>, Cn<b_>)
|
||||
{
|
||||
return cn<(a_ < b_)>;
|
||||
}
|
||||
|
||||
template <auto a_, auto b_>
|
||||
FT_DEV_CEXPR auto operator<=(Cn<a_>, Cn<b_>)
|
||||
{
|
||||
return cn<(a_ <= b_)>;
|
||||
}
|
||||
|
||||
template <auto a_, auto b_>
|
||||
FT_DEV_CEXPR auto operator>(Cn<a_>, Cn<b_>)
|
||||
{
|
||||
return cn<(a_ > b_)>;
|
||||
}
|
||||
|
||||
template <auto a_, auto b_>
|
||||
FT_DEV_CEXPR auto operator>=(Cn<a_>, Cn<b_>)
|
||||
{
|
||||
return cn<(a_ >= b_)>;
|
||||
}
|
||||
|
||||
template <auto a_, auto b_>
|
||||
FT_DEV_CEXPR auto operator==(Cn<a_>, Cn<b_>)
|
||||
{
|
||||
return cn<(a_ == b_)>;
|
||||
}
|
||||
|
||||
template <auto a_, auto b_>
|
||||
FT_DEV_CEXPR auto operator!=(Cn<a_>, Cn<b_>)
|
||||
{
|
||||
return cn<(a_ != b_)>;
|
||||
}
|
||||
|
||||
template <auto a_, auto b_>
|
||||
FT_DEV_CEXPR auto operator^(Cn<a_>, Cn<b_>)
|
||||
{
|
||||
return cn<a_ ^ b_>;
|
||||
}
|
||||
|
||||
template <auto a_, auto b_>
|
||||
FT_DEV_CEXPR auto operator&(Cn<a_>, Cn<b_>)
|
||||
{
|
||||
return cn<a_ & b_>;
|
||||
}
|
||||
|
||||
template <auto a_, auto b_>
|
||||
FT_DEV_CEXPR auto operator&&(Cn<a_>, Cn<b_>)
|
||||
{
|
||||
return cn < a_ && b_ > ;
|
||||
}
|
||||
|
||||
template <auto a_, auto b_>
|
||||
FT_DEV_CEXPR auto operator|(Cn<a_>, Cn<b_>)
|
||||
{
|
||||
return cn<a_ | b_>;
|
||||
}
|
||||
|
||||
template <auto a_, auto b_>
|
||||
FT_DEV_CEXPR auto operator||(Cn<a_>, Cn<b_>)
|
||||
{
|
||||
return cn < a_ || b_ > ;
|
||||
}
|
||||
|
||||
template <auto a_, class B_>
|
||||
FT_DEV_CEXPR std::enable_if_t<a_ == 0, Cn<a_>> operator*(Cn<a_>, B_)
|
||||
{
|
||||
return cn<a_>;
|
||||
}
|
||||
|
||||
template <auto a_, class B_>
|
||||
FT_DEV_CEXPR std::enable_if_t<a_ == 0, Cn<a_>> operator/(Cn<a_>, B_)
|
||||
{
|
||||
return cn<a_>;
|
||||
}
|
||||
|
||||
template <auto a_, class B_>
|
||||
FT_DEV_CEXPR std::enable_if_t<a_ == 0, Cn<a_>> operator%(Cn<a_>, B_)
|
||||
{
|
||||
return cn<a_>;
|
||||
}
|
||||
|
||||
template <auto a_, class B_>
|
||||
FT_DEV_CEXPR std::enable_if_t<a_ == 0, Cn<a_>> operator<<(Cn<a_>, B_)
|
||||
{
|
||||
return cn<a_>;
|
||||
}
|
||||
|
||||
template <auto a_, class B_>
|
||||
FT_DEV_CEXPR std::enable_if_t<a_ == 0, Cn<a_>> operator>>(Cn<a_>, B_)
|
||||
{
|
||||
return cn<a_>;
|
||||
}
|
||||
|
||||
template <auto a_, class B_>
|
||||
FT_DEV_CEXPR std::enable_if_t<a_ == 0, Cn<a_>> operator&(Cn<a_>, B_)
|
||||
{
|
||||
return cn<a_>;
|
||||
}
|
||||
|
||||
template <auto a_, class B_>
|
||||
FT_DEV_CEXPR std::enable_if_t<a_ == 0, Cn<a_>> operator&&(Cn<a_>, B_)
|
||||
{
|
||||
return cn<a_>;
|
||||
}
|
||||
|
||||
template <class A_, auto b_>
|
||||
FT_DEV_CEXPR std::enable_if_t<b_ == 0, Cn<b_>> operator*(A_, Cn<b_>)
|
||||
{
|
||||
return cn<b_>;
|
||||
}
|
||||
|
||||
template <class A_, auto b_>
|
||||
FT_DEV_CEXPR std::enable_if_t<b_ == +1, Cn<decltype(b_)(0)>> operator%(A_, Cn<b_>)
|
||||
{
|
||||
return cn<decltype(b_)(0)>;
|
||||
}
|
||||
|
||||
template <class A_, auto b_>
|
||||
FT_DEV_CEXPR std::enable_if_t<b_ == -1, Cn<decltype(b_)(0)>> operator%(A_, Cn<b_>)
|
||||
{
|
||||
return cn<decltype(b_)(0)>;
|
||||
}
|
||||
|
||||
template <class A_, auto b_>
|
||||
FT_DEV_CEXPR std::enable_if_t<b_ == 0, Cn<b_>> operator&(A_, Cn<b_>)
|
||||
{
|
||||
return cn<b_>;
|
||||
}
|
||||
|
||||
template <class A_, auto b_>
|
||||
FT_DEV_CEXPR std::enable_if_t<b_ == 0, Cn<b_>> operator&&(A_, Cn<b_>)
|
||||
{
|
||||
return cn<b_>;
|
||||
}
|
||||
|
||||
//----------------------------------------------------------------------------
|
||||
// div_up & round_up
|
||||
//----------------------------------------------------------------------------
|
||||
|
||||
template <class T_>
|
||||
FT_DEV_CEXPR auto cexpr_abs(T_ a_) // abs is not constexpr until C++20
|
||||
{
|
||||
return a_ >= cn<0> ? +a_ : -a_;
|
||||
}
|
||||
|
||||
template <class T_, class U_>
|
||||
FT_DEV_CEXPR auto div_up(T_ a_, U_ b_)
|
||||
{
|
||||
auto tmp = a_ >= cn<0> ? a_ + (cexpr_abs(b_) - cn<1>) : a_ - (cexpr_abs(b_) - cn<1>);
|
||||
|
||||
return tmp / b_;
|
||||
}
|
||||
|
||||
template <class T_, class U_>
|
||||
FT_DEV_CEXPR auto round_up(T_ a_, U_ b_)
|
||||
{
|
||||
auto tmp = a_ >= cn<0> ? a_ + (cexpr_abs(b_) - cn<1>) : a_ - (cexpr_abs(b_) - cn<1>);
|
||||
|
||||
return tmp - tmp % b_;
|
||||
}
|
||||
|
||||
template <auto a_, auto b_>
|
||||
FT_DEV_CEXPR auto div_up(Cn<a_>, Cn<b_>)
|
||||
{
|
||||
return cn<div_up(a_, b_)>;
|
||||
}
|
||||
|
||||
template <auto a_, auto b_>
|
||||
FT_DEV_CEXPR auto round_up(Cn<a_>, Cn<b_>)
|
||||
{
|
||||
return cn<round_up(a_, b_)>;
|
||||
}
|
||||
|
||||
template <auto a_, class B_>
|
||||
FT_DEV_CEXPR std::enable_if_t<a_ == 0, Cn<a_>> div_up(Cn<a_>, B_)
|
||||
{
|
||||
return cn<a_>;
|
||||
}
|
||||
|
||||
template <auto a_, class B_>
|
||||
FT_DEV_CEXPR std::enable_if_t<a_ == 0, Cn<a_>> round_up(Cn<a_>, B_)
|
||||
{
|
||||
return cn<a_>;
|
||||
}
|
||||
|
||||
//----------------------------------------------------------------------------
|
||||
// IsTuple: std::tuple, but not std::pair, std::array, etc.
|
||||
//----------------------------------------------------------------------------
|
||||
|
||||
template <class T_>
|
||||
struct IsTuple : public std::false_type
|
||||
{
|
||||
};
|
||||
|
||||
template <class... Ts_>
|
||||
struct IsTuple<std::tuple<Ts_...>> : public std::true_type
|
||||
{
|
||||
};
|
||||
|
||||
template <class T_>
|
||||
struct IsTuple<const T_> : public IsTuple<T_>
|
||||
{
|
||||
};
|
||||
|
||||
template <class T_>
|
||||
struct IsTuple<T_&> : public IsTuple<T_>
|
||||
{
|
||||
};
|
||||
|
||||
template <class T_>
|
||||
struct IsTuple<T_&&> : public IsTuple<T_>
|
||||
{
|
||||
};
|
||||
|
||||
template <class T_>
|
||||
constexpr bool IsTuple_v = IsTuple<T_>::value;
|
||||
|
||||
// vim: ts=2 sw=2 sts=2 et sta
|
||||
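The new Cn.h header above wraps std::integral_constant so that index arithmetic done on cn<...> values stays a compile-time constant, adds operators that fold multiplications and masks by a constant zero, and provides div_up/round_up helpers usable on both constants and runtime integers. A small host-only usage sketch follows; it assumes Cn.h is on the include path and is not taken from the repository's tests.

// Usage sketch for the cn<> constants and div_up/round_up helpers (C++17).
#include <cstdio>

#include "Cn.h" // assumed include path

int main()
{
    constexpr auto tileK = cn<32>;
    constexpr auto pipeS = cn<2>;

    // Arithmetic between cn<> values stays a compile-time constant,
    // so the result can feed template arguments or static_asserts.
    constexpr auto stageElems = tileK * pipeS;
    static_assert(stageElems == cn<64>);

    // Constant-only variants fold entirely at compile time.
    static_assert(div_up(cn<300>, cn<256>) == cn<2>);

    // The same helpers also accept plain runtime integers.
    int seqLen = 300;
    std::printf("chunks=%d padded=%d\n", div_up(seqLen, 256), round_up(seqLen, 256));
    return 0;
}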
166 cpp/tensorrt_llm/kernels/chunkScan/Common.h (new file)
@@ -0,0 +1,166 @@
|
||||
/*
|
||||
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
extern "C" __device__ unsigned __nvvm_get_smem_pointer(void* ptr);
|
||||
|
||||
template <int mode_, int line_, class T_>
|
||||
__device__ inline int swizzle(int x_)
|
||||
{
|
||||
return x_ ^ x_ / line_ % (mode_ / 16) * (16 / sizeof(T_));
|
||||
}
|
||||
|
||||
template <class T_>
|
||||
__device__ inline int swizzle(int x_, int y_)
|
||||
{
|
||||
return x_ ^ y_ * (16 / sizeof(T_));
|
||||
}
|
||||
|
||||
template <int size_>
|
||||
__device__ inline void cp_shared_global(unsigned s_ptr, void const* g_ptr)
|
||||
{
|
||||
static_assert(size_ == 4 || size_ == 8 || size_ == 16);
|
||||
|
||||
#if __CUDA_ARCH__ >= 800
|
||||
if constexpr (size_ == 16)
|
||||
asm volatile("cp.async.cg.shared.global [%0], [%1], %2;\n" ::"r"(s_ptr), "l"(g_ptr), "n"(size_));
|
||||
else if constexpr (size_ == 8)
|
||||
asm volatile("cp.async.ca.shared.global [%0], [%1], %2;\n" ::"r"(s_ptr), "l"(g_ptr), "n"(size_));
|
||||
else if constexpr (size_ == 4)
|
||||
asm volatile("cp.async.ca.shared.global [%0], [%1], %2;\n" ::"r"(s_ptr), "l"(g_ptr), "n"(size_));
|
||||
#else
|
||||
register unsigned tmp[size_ / 4];
|
||||
|
||||
if constexpr (size_ == 16)
|
||||
{
|
||||
asm volatile("ld.global.v4.b32 {%0, %1, %2, %3}, [%4];\n"
|
||||
: "=r"(tmp[0]), "=r"(tmp[1]), "=r"(tmp[2]), "=r"(tmp[3])
|
||||
: "l"(g_ptr));
|
||||
asm volatile("st.shared.v4.b32 [%0], {%1, %2, %3, %4};\n" ::"r"(s_ptr), "r"(tmp[0]), "r"(tmp[1]), "r"(tmp[2]),
|
||||
"r"(tmp[3]));
|
||||
}
|
||||
else if constexpr (size_ == 8)
|
||||
{
|
||||
asm volatile("ld.global.v2.b32 {%0, %1}, [%2];\n" : "=r"(tmp[0]), "=r"(tmp[1]) : "l"(g_ptr));
|
||||
asm volatile("st.shared.v2.b32 [%0], {%1, %2};\n" ::"r"(s_ptr), "r"(tmp[0]), "r"(tmp[1]));
|
||||
}
|
||||
else if constexpr (size_ == 4)
|
||||
{
|
||||
asm volatile("ld.global.b32 %0, [%1];\n" : "=r"(tmp[0]) : "l"(g_ptr));
|
||||
asm volatile("st.shared.b32 [%0], %1;\n" ::"r"(s_ptr), "r"(tmp[0]));
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
template <int size_>
|
||||
__device__ inline void cp_shared_global(unsigned s_ptr, void const* g_ptr, bool valid_)
|
||||
{
|
||||
static_assert(size_ == 4 || size_ == 8 || size_ == 16);
|
||||
|
||||
#if __CUDA_ARCH__ >= 800
|
||||
if constexpr (size_ == 16)
|
||||
asm volatile("cp.async.cg.shared.global [%0], [%1], %2, %3;\n" ::"r"(s_ptr), "l"(g_ptr), "n"(size_),
|
||||
"r"(valid_ ? size_ : 0));
|
||||
else if constexpr (size_ == 8)
|
||||
asm volatile("cp.async.ca.shared.global [%0], [%1], %2, %3;\n" ::"r"(s_ptr), "l"(g_ptr), "n"(size_),
|
||||
"r"(valid_ ? size_ : 0));
|
||||
else if constexpr (size_ == 4)
|
||||
asm volatile("cp.async.ca.shared.global [%0], [%1], %2, %3;\n" ::"r"(s_ptr), "l"(g_ptr), "n"(size_),
|
||||
"r"(valid_ ? size_ : 0));
|
||||
#else
|
||||
register unsigned tmp[size_ / 4];
|
||||
|
||||
if constexpr (size_ == 16)
|
||||
{
|
||||
if (valid_)
|
||||
{
|
||||
asm volatile("ld.global.v4.b32 {%0, %1, %2, %3}, [%4];\n"
|
||||
: "=r"(tmp[0]), "=r"(tmp[1]), "=r"(tmp[2]), "=r"(tmp[3])
|
||||
: "l"(g_ptr));
|
||||
asm volatile("st.shared.v4.b32 [%0], {%1, %2, %3, %4};\n" ::"r"(s_ptr), "r"(tmp[0]), "r"(tmp[1]),
|
||||
"r"(tmp[2]), "r"(tmp[3]));
|
||||
}
|
||||
else
|
||||
{
|
||||
asm volatile("st.shared.v4.b32 [%0], {%1, %2, %3, %4};\n" ::"r"(s_ptr), "n"(0), "n"(0), "n"(0), "n"(0));
|
||||
}
|
||||
}
|
||||
else if constexpr (size_ == 8)
|
||||
{
|
||||
if (valid_)
|
||||
{
|
||||
asm volatile("ld.global.v2.b32 {%0, %1}, [%2];\n" : "=r"(tmp[0]), "=r"(tmp[1]) : "l"(g_ptr));
|
||||
asm volatile("st.shared.v2.b32 [%0], {%1, %2};\n" ::"r"(s_ptr), "r"(tmp[0]), "r"(tmp[1]));
|
||||
}
|
||||
else
|
||||
{
|
||||
asm volatile("st.shared.v2.b32 [%0], {%1, %2};\n" ::"r"(s_ptr), "n"(0), "n"(0));
|
||||
}
|
||||
}
|
||||
else if constexpr (size_ == 4)
|
||||
{
|
||||
if (valid_)
|
||||
{
|
||||
asm volatile("ld.global.b32 %0, [%1];\n" : "=r"(tmp[0]) : "l"(g_ptr));
|
||||
asm volatile("st.shared.b32 [%0], %1;\n" ::"r"(s_ptr), "r"(tmp[0]));
|
||||
}
|
||||
else
|
||||
{
|
||||
asm volatile("st.shared.b32 [%0], %1;\n" ::"r"(s_ptr), "n"(0));
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
__device__ inline void cp_commit_group()
|
||||
{
|
||||
#if __CUDA_ARCH__ >= 800
|
||||
asm volatile("cp.async.commit_group;\n");
|
||||
#endif
|
||||
}
|
||||
|
||||
template <int remain_>
|
||||
__device__ inline void cp_wait_group()
|
||||
{
|
||||
#if __CUDA_ARCH__ >= 800
|
||||
asm volatile("cp.async.wait_group %0;\n" ::"n"(remain_));
|
||||
#endif
|
||||
}
|
||||
|
||||
template <bool trans_ = false>
|
||||
__device__ inline void ldmatrix(unsigned& r0_, unsigned& r1_, unsigned& r2_, unsigned& r3_, unsigned addr_)
|
||||
{
|
||||
if (trans_)
|
||||
asm volatile("ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {%0, %1, %2, %3}, [%4];\n"
|
||||
: "=r"(r0_), "=r"(r1_), "=r"(r2_), "=r"(r3_)
|
||||
: "r"(addr_));
|
||||
else
|
||||
asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];\n"
|
||||
: "=r"(r0_), "=r"(r1_), "=r"(r2_), "=r"(r3_)
|
||||
: "r"(addr_));
|
||||
}
|
||||
|
||||
typedef __nv_bfloat16 bf16;
|
||||
typedef __nv_bfloat162 bf162;
|
||||
|
||||
template <int mode_ = 128, int line_ = 64>
|
||||
__device__ int swz(int x_)
|
||||
{
|
||||
return x_ ^ x_ / line_ % (mode_ / 16) * 8;
|
||||
}
|
||||
|
||||
// vim: ts=2 sw=2 sts=2 et sta
|
||||
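The swizzle helpers in Common.h above XOR a shared-memory offset with a function of its row index, following offset ^ (offset / line % (mode / 16)) * (16 / sizeof(T)), so that the 16-byte vectors moved by cp.async and ldmatrix from consecutive rows do not collide in the same bank group. The host-side check below only reproduces that index arithmetic for 2-byte elements with the default mode=128, line=64 (matching swz<128, 64>); it is a sketch, not device code from the repository.

// Host-side check of the XOR swizzle used in Common.h (C++17).
#include <cstdio>

// Same formula as swizzle<mode_, line_, T_> with sizeof(T_) == 2 (half/bf16).
constexpr int swizzleIndex(int x, int mode = 128, int line = 64, int elemBytes = 2)
{
    return x ^ x / line % (mode / 16) * (16 / elemBytes);
}

int main()
{
    // The start of each 64-element row is shifted into a different 16-byte group.
    for (int row = 0; row < 8; ++row)
    {
        std::printf("row %d starts at swizzled element %d\n", row, swizzleIndex(row * 64));
    }
    return 0;
}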
1277 cpp/tensorrt_llm/kernels/chunkScan/Poly.h (new file)
File diff suppressed because it is too large

438 cpp/tensorrt_llm/kernels/chunkScan/bmmchunk.h (new file)
@@ -0,0 +1,438 @@
|
||||
/*
|
||||
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cuda_fp8.h>
|
||||
#include <mma.h>
|
||||
|
||||
#include "tensorrt_llm/common/cudaBf16Fallbacks.cuh"
|
||||
|
||||
#include "Common.h"
|
||||
#include "Poly.h"
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
namespace kernels
|
||||
{
|
||||
|
||||
typedef void (*BmmChunkKernelFuncFp16)(int B_, int L_, int G_, int N_,
|
||||
// const half *g_mxY_, // B*L*H*P
|
||||
// const half *g_mxOs_, // B*C*H*N*P
|
||||
// const half *g_mxFs_, // B *H*N*P
|
||||
// const float *g_mxSt_, // B*C*H*N*P
|
||||
// const float *g_mxdc_, // B*C*H*Q
|
||||
// const float *g_mxdA_, // B*C*H*Q
|
||||
// const half *g_mxdt_, // B*L*H
|
||||
// const float *g_mxdb_, // H
|
||||
// const float *g_mxA_, // H
|
||||
half* g_mxCB_, // B*C*G*Q*Q
|
||||
half const* g_mxBC_, // B*L*2*G*N
|
||||
// const float *g_mxD_, // H
|
||||
// const half *g_mxX_, // B*L*H*P
|
||||
// const half *g_mxZ_, // B*L*H*P
|
||||
bool removePadding_, int const* lastTokenIdsPtr_);
|
||||
|
||||
typedef void (*BmmChunkKernelFuncBf16)(int B_, int L_, int G_, int N_,
|
||||
// const bf16 *g_mxY_, // B*L*H*P
|
||||
// const bf16 *g_mxOs_, // B*C*H*N*P
|
||||
// const bf16 *g_mxFs_, // B *H*N*P
|
||||
// const float *g_mxSt_, // B*C*H*N*P
|
||||
// const float *g_mxdc_, // B*C*H*Q
|
||||
// const float *g_mxdA_, // B*C*H*Q
|
||||
// const bf16 *g_mxdt_, // B*L*H
|
||||
// const float *g_mxdb_, // H
|
||||
// const float *g_mxA_, // H
|
||||
bf16* g_mxCB_, // B*C*G*Q*Q
|
||||
bf16 const* g_mxBC_, // B*L*2*G*N
|
||||
// const float *g_mxD_, // H
|
||||
// const bf16 *g_mxX_, // B*L*H*P
|
||||
// const bf16 *g_mxZ_, // B*L*H*P
|
||||
bool removePadding_, int const* lastTokenIdsPtr_);
|
||||
|
||||
template <int Q_, int tileM_, int tileN_, int tileK_, // smem size, per sm
|
||||
int wmmaM_, int wmmaN_, int wmmaK_, // wmma size, per instruction
|
||||
int warpM_, int warpN_, // warp number
|
||||
int pipeS_, class Tp_>
|
||||
__global__ std::enable_if_t<std::is_same_v<Tp_, half> || std::is_same_v<Tp_, __nv_bfloat16>> bmm_chunk_kernel(int B_,
|
||||
int L_, int G_, int N_,
|
||||
// const Tp_ *g_mxY_, // B*L*H*P
|
||||
// const Tp_ *g_mxOs_, // B*C*H*N*P
|
||||
// const Tp_ *g_mxFs_, // B *H*N*P
|
||||
// const float *g_mxSt_, // B*C*H*N*P
|
||||
// const float *g_mxdc_, // B*C*H*Q
|
||||
// const float *g_mxdA_, // B*C*H*Q
|
||||
// const Tp_ *g_mxdt_, // B*L*H
|
||||
// const Wt_ *g_mxdb_, // H
|
||||
// const Wt_ *g_mxA_, // H
|
||||
Tp_* g_mxCB_, // B*C*G*Q*Q
|
||||
Tp_ const* g_mxBC_, // B*L*2*G*N
|
||||
// const Wt_ *g_mxD_, // H
|
||||
// const Tp_ *g_mxX_, // B*L*H*P
|
||||
// const Tp_ *g_mxZ_, // B*L*H*P
|
||||
bool removePadding_, int const* lastTokenIdsPtr_)
|
||||
{
|
||||
#if __CUDA_ARCH__ >= 800
|
||||
using namespace tensorrt_llm::common;
|
||||
|
||||
auto blockIdx_x = Rn<ID>{int(blockIdx.x)};
|
||||
auto blockIdx_y = Rn<ID>{int(blockIdx.y)};
|
||||
auto blockIdx_z = Rn<ID>{int(blockIdx.z)};
|
||||
|
||||
auto threadIdx_x = Rn<ID, 32>{int(threadIdx.x)};
|
||||
auto threadIdx_y = Rn<ID, warpN_>{int(threadIdx.y)};
|
||||
auto threadIdx_z = Rn<ID, warpM_>{int(threadIdx.z)};
|
||||
|
||||
// auto B = Rn<ID>{B_};
|
||||
auto L = Rn<ID>{L_};
|
||||
// auto H = Rn<ID>{H_};
|
||||
// auto P = Rn<ID>{P_};
|
||||
auto G = Rn<ID>{G_};
|
||||
auto N = Rn<ID>{N_};
|
||||
auto Q = cn<Q_>;
|
||||
auto C = Rn<ID>{div_up(L.var, Q_)};
|
||||
|
||||
auto aStart = blockIdx_z * L;
|
||||
auto cStart = blockIdx_z * C;
|
||||
|
||||
if (removePadding_)
|
||||
{
|
||||
aStart = Rn<ID>{int(blockIdx.z ? lastTokenIdsPtr_[blockIdx.z - 1] : 0)};
|
||||
cStart = Rn<ID>{int(blockIdx.z ? div_up(aStart.var, Q_) + blockIdx.z - 1 : 0)};
|
||||
L = Rn<ID>{lastTokenIdsPtr_[blockIdx.z] - aStart.var};
|
||||
C = Rn<ID>{div_up(L.var, Q_)};
|
||||
}
|
||||
else
|
||||
{
|
||||
L = Rn<ID>{lastTokenIdsPtr_[blockIdx.z]};
|
||||
C = Rn<ID>{div_up(L.var, Q_)};
|
||||
}
|
||||
|
||||
if (blockIdx_y * Q >= L)
|
||||
return;
|
||||
|
||||
auto gStart = blockIdx_x / (Q / cn<tileN_>) / (Q / cn<tileM_>);
|
||||
auto mStart = blockIdx_x / (Q / cn<tileN_>) % (Q / cn<tileM_>);
|
||||
auto nStart = blockIdx_x % (Q / cn<tileN_>);
|
||||
|
||||
extern __shared__ float smem[];
|
||||
|
||||
Tp_* s_mxC = (Tp_*) smem;
|
||||
Tp_* s_mxB = (Tp_*) smem + tileM_ * tileK_ * pipeS_;
|
||||
Tp_* s_mxCB = (Tp_*) smem;
|
||||
|
||||
unsigned b_base = __nvvm_get_smem_pointer(smem);
|
||||
|
||||
unsigned b_mxC = b_base;
|
||||
unsigned b_mxB = b_base + tileM_ * tileK_ * pipeS_ * sizeof(Tp_);
|
||||
unsigned b_mxCB = b_base;
|
||||
|
||||
using std::array;
|
||||
|
||||
register array<array<array<float, wmmaM_ * wmmaN_ / 32>, tileN_ / wmmaN_ / warpN_>, tileM_ / wmmaM_ / warpM_> r_mxCB
|
||||
= array<array<array<float, wmmaM_ * wmmaN_ / 32>, tileN_ / wmmaN_ / warpN_>, tileM_ / wmmaM_ / warpM_>();
|
||||
register array<array<unsigned, wmmaM_ * wmmaK_ / 64>, tileM_ / wmmaM_ / warpM_> r_mxC;
|
||||
register array<array<unsigned, wmmaK_ * wmmaN_ / 64>, tileN_ / wmmaN_ / warpN_> r_mxB;
|
||||
|
||||
constexpr int step = std::max(
|
||||
1, tileM_ / wmmaM_ / warpM_ * tileN_ / wmmaN_ / warpN_ / (tileM_ / wmmaM_ / warpM_ + tileN_ / wmmaN_ / warpN_));
|
||||
|
||||
auto baseC = [](auto iK) { return iK % cn<pipeS_> * cn<tileM_> * cn<tileK_>; };
|
||||
auto baseB = [](auto iK) { return iK % cn<pipeS_> * cn<tileN_> * cn<tileK_>; };
|
||||
|
||||
auto thread = [=](auto iStep)
|
||||
{
|
||||
return iStep * cn<warpM_ * warpN_ * 256> + threadIdx_z * cn<warpN_ * 256> + threadIdx_y * cn<256>
|
||||
+ threadIdx_x * cn<8>;
|
||||
};
|
||||
|
||||
#pragma unroll
|
||||
for (Rn<UNROLL, pipeS_> iK; iK.var < iK.size; iK.var++)
|
||||
{
|
||||
#pragma unroll
|
||||
for (Rn<UNROLL, div_up(tileM_ * tileK_, warpM_ * warpN_ * 256)> iStep; iStep.var < iStep.size; iStep.var++)
|
||||
if (thread(iStep) < cn<tileM_ * tileK_>
|
||||
&& thread(iStep) / cn<tileK_> < L - blockIdx_y * Q - mStart * cn<tileM_>)
|
||||
cp_shared_global<16>(b_mxC + swizzle<tileK_ * 2, tileK_ * 2>(thread(iStep) * cn<2>, baseC(iK) * cn<2>),
|
||||
g_mxBC_
|
||||
+ get(
|
||||
(aStart + blockIdx_y * Q + mStart * cn<tileM_> + thread(iStep) / cn<tileK_>) *cn<2> * G * N
|
||||
+ cn<1> * G * N + gStart * N + iK * cn<tileK_> + thread(iStep) % cn<tileK_>));
|
||||
else if (thread(iStep) < cn<tileM_ * tileK_>)
|
||||
*(int4*) ((char*) s_mxC + swizzle<tileK_ * 2, tileK_ * 2>(thread(iStep) * cn<2>, baseC(iK) * cn<2>))
|
||||
= int4{0, 0, 0, 0};
|
||||
|
||||
#pragma unroll
|
||||
for (Rn<UNROLL, div_up(tileN_ * tileK_, warpM_ * warpN_ * 256)> iStep; iStep.var < iStep.size; iStep.var++)
|
||||
if (thread(iStep) < cn<tileN_ * tileK_>
|
||||
&& thread(iStep) / cn<tileK_> < L - blockIdx_y * Q - nStart * cn<tileN_>)
|
||||
cp_shared_global<16>(b_mxB + swizzle<tileK_ * 2, tileK_ * 2>(thread(iStep) * cn<2>, baseB(iK) * cn<2>),
|
||||
g_mxBC_
|
||||
+ get(
|
||||
(aStart + blockIdx_y * Q + nStart * cn<tileN_> + thread(iStep) / cn<tileK_>) *cn<2> * G * N
|
||||
+ cn<0> * G * N + gStart * N + iK * cn<tileK_> + thread(iStep) % cn<tileK_>));
|
||||
else if (thread(iStep) < cn<tileN_ * tileK_>)
|
||||
*(int4*) ((char*) s_mxB + swizzle<tileK_ * 2, tileK_ * 2>(thread(iStep) * cn<2>, baseB(iK) * cn<2>))
|
||||
= int4{0, 0, 0, 0};
|
||||
|
||||
cp_commit_group();
|
||||
}
|
||||
|
||||
asm volatile("cp.async.wait_group %0;\n" ::"n"(pipeS_ - 1));
|
||||
|
||||
__syncthreads();
|
||||
|
||||
for (int iK = pipeS_; iK < N_ / tileK_ + pipeS_; iK++)
|
||||
{
|
||||
#pragma unroll
|
||||
for (int k = 0; k < tileK_ / wmmaK_; k++)
|
||||
{
|
||||
#pragma unroll
|
||||
for (int y = 0; y < tileM_ / wmmaM_ / warpM_; y++)
|
||||
#pragma unroll
|
||||
for (int x = 0; x < tileN_ / wmmaN_ / warpN_; x++)
|
||||
{
|
||||
if ((y * tileN_ / wmmaN_ / warpN_ + x) % step == 0)
|
||||
{
|
||||
int x1 = (y * tileN_ / wmmaN_ / warpN_ + x) / step;
|
||||
int y1 = x1 - tileN_ / wmmaN_ / warpN_
|
||||
+ (tileM_ / wmmaM_ / warpM_ == 1 || tileN_ / wmmaN_ / warpN_ == 1);
|
||||
|
||||
if (y1 >= 0 && y1 < tileM_ / wmmaM_ / warpM_)
|
||||
{
|
||||
if (wmmaK_ == 16)
|
||||
asm volatile(
|
||||
"ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];\n"
|
||||
: "=r"(r_mxC[y1][0]), "=r"(r_mxC[y1][1]), "=r"(r_mxC[y1][2]), "=r"(r_mxC[y1][3])
|
||||
: "r"(b_mxC + iK % pipeS_ * (tileM_ * tileK_ * 2)
|
||||
+ 2
|
||||
* swz<tileK_ * 2, tileK_>(y1 * warpM_ * wmmaM_ * tileK_ + k * wmmaK_
|
||||
+ threadIdx.z * wmmaM_ * tileK_ + threadIdx.x % 16 * tileK_
|
||||
+ threadIdx.x / 16 * 8)));
|
||||
}
|
||||
|
||||
if (x1 >= 0 && x1 < tileN_ / wmmaN_ / warpN_)
|
||||
{
|
||||
if (wmmaK_ == 16 && x1 % 2 == 0)
|
||||
asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];\n"
|
||||
: "=r"(r_mxB[x1][0]), "=r"(r_mxB[x1][1]), "=r"(r_mxB[x1 + 1][0]),
|
||||
"=r"(r_mxB[x1 + 1][1])
|
||||
: "r"(b_mxB + iK % pipeS_ * (tileK_ * tileN_ * 2)
|
||||
+ 2
|
||||
* swz<tileK_ * 2, tileK_>(x1 * warpN_ * wmmaN_ * tileK_
|
||||
+ k * wmmaK_ + threadIdx.y * wmmaN_ * tileK_
|
||||
+ threadIdx.x % 8 * tileK_ + threadIdx.x / 8 % 2 * 8
|
||||
+ threadIdx.x / wmmaK_ * warpN_ * wmmaN_ * tileK_)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int y = 0; y < tileM_ / wmmaM_ / warpM_; y++)
|
||||
#pragma unroll
|
||||
for (int x = 0; x < tileN_ / wmmaN_ / warpN_; x++)
|
||||
{
|
||||
if (wmmaK_ == 16)
|
||||
{
|
||||
if (std::is_same_v<Tp_, half>)
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 \n"
|
||||
" {%0, %1, %2, %3}, \n"
|
||||
" {%4, %5, %6, %7}, \n"
|
||||
" {%8, %9}, \n"
|
||||
" {%0, %1, %2, %3}; \n"
|
||||
: "+f"(r_mxCB[y][x][0]), "+f"(r_mxCB[y][x][1]), "+f"(r_mxCB[y][x][2]),
|
||||
"+f"(r_mxCB[y][x][3])
|
||||
: "r"(r_mxC[y][0]), "r"(r_mxC[y][1]), "r"(r_mxC[y][2]), "r"(r_mxC[y][3]),
|
||||
"r"(r_mxB[x][0]), "r"(r_mxB[x][1]));
|
||||
else
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 \n"
|
||||
" {%0, %1, %2, %3}, \n"
|
||||
" {%4, %5, %6, %7}, \n"
|
||||
" {%8, %9}, \n"
|
||||
" {%0, %1, %2, %3}; \n"
|
||||
: "+f"(r_mxCB[y][x][0]), "+f"(r_mxCB[y][x][1]), "+f"(r_mxCB[y][x][2]),
|
||||
"+f"(r_mxCB[y][x][3])
|
||||
: "r"(r_mxC[y][0]), "r"(r_mxC[y][1]), "r"(r_mxC[y][2]), "r"(r_mxC[y][3]),
|
||||
"r"(r_mxB[x][0]), "r"(r_mxB[x][1]));
                }
            }
        }

        __syncthreads();

        if (iK * tileK_ < N_)
        {

            auto jK = Rn<>{iK};
            #pragma unroll
            for (Rn<UNROLL, div_up(tileM_ * tileK_, warpM_ * warpN_ * 256)> iStep; iStep.var < iStep.size; iStep.var++)
                if (thread(iStep) < cn<tileM_ * tileK_>
                    && thread(iStep) / cn<tileK_> < L - blockIdx_y * Q - mStart * cn<tileM_>)
                    cp_shared_global<16>(
                        b_mxC + swizzle<tileK_ * 2, tileK_ * 2>(thread(iStep) * cn<2>, baseC(jK) * cn<2>),
                        g_mxBC_
                            + get((aStart + blockIdx_y * Q + mStart * cn<tileM_> + thread(iStep) / cn<tileK_>) *cn<2>
                                    * G * N
                                + cn<1> * G * N + gStart * N + jK * cn<tileK_> + thread(iStep) % cn<tileK_>));
                else if (thread(iStep) < cn<tileM_ * tileK_>)
                    *(int4*) ((char*) s_mxC + swizzle<tileK_ * 2, tileK_ * 2>(thread(iStep) * cn<2>, baseC(jK) * cn<2>))
                        = int4{0, 0, 0, 0};

            #pragma unroll
            for (Rn<UNROLL, div_up(tileN_ * tileK_, warpM_ * warpN_ * 256)> iStep; iStep.var < iStep.size; iStep.var++)
                if (thread(iStep) < cn<tileN_ * tileK_>
                    && thread(iStep) / cn<tileK_> < L - blockIdx_y * Q - nStart * cn<tileN_>)
                    cp_shared_global<16>(
                        b_mxB + swizzle<tileK_ * 2, tileK_ * 2>(thread(iStep) * cn<2>, baseB(jK) * cn<2>),
                        g_mxBC_
                            + get((aStart + blockIdx_y * Q + nStart * cn<tileN_> + thread(iStep) / cn<tileK_>) *cn<2>
                                    * G * N
                                + cn<0> * G * N + gStart * N + jK * cn<tileK_> + thread(iStep) % cn<tileK_>));
                else if (thread(iStep) < cn<tileN_ * tileK_>)
                    *(int4*) ((char*) s_mxB + swizzle<tileK_ * 2, tileK_ * 2>(thread(iStep) * cn<2>, baseB(jK) * cn<2>))
                        = int4{0, 0, 0, 0};
        }

        asm volatile("cp.async.commit_group;\n" ::);

        asm volatile("cp.async.wait_group %0;\n" ::"n"(pipeS_ - 1));

        __syncthreads();
    }

    #pragma unroll
    for (int y = 0; y < tileM_ / wmmaM_ / warpM_; y++)
        #pragma unroll
        for (int x = 0; x < tileN_ / wmmaN_ / warpN_; x++)
        {
            if (std::is_same_v<Tp_, half>)
            {
                *(half2*) &r_mxCB[y][x][0] = __floats2half2_rn(r_mxCB[y][x][0], r_mxCB[y][x][1]);
                *(half2*) &r_mxCB[y][x][2] = __floats2half2_rn(r_mxCB[y][x][2], r_mxCB[y][x][3]);
            }
            else
            {
                *(bf162*) &r_mxCB[y][x][0] = __floats2bfloat162_rn(r_mxCB[y][x][0], r_mxCB[y][x][1]);
                *(bf162*) &r_mxCB[y][x][2] = __floats2bfloat162_rn(r_mxCB[y][x][2], r_mxCB[y][x][3]);
            }
        }

    #pragma unroll
    for (int y = 0; y < tileM_ / wmmaM_ / warpM_; y++)
        #pragma unroll
        for (int x = 0; x < tileN_ / wmmaN_ / warpN_; x++)
        {
            asm volatile("st.shared.b32 [%0], %1;\n" ::"r"(b_mxCB
                             + 2
                                 * swz<tileN_ * 2, tileN_>(y * warpM_ * wmmaM_ * tileN_ + x * warpN_ * wmmaN_
                                     + (threadIdx.z * wmmaM_ + threadIdx.x / 4) * tileN_
                                     + (threadIdx.y * wmmaN_ + threadIdx.x % 4 * 2))),
                "r"(*(unsigned*) &r_mxCB[y][x][0]));
            asm volatile("st.shared.b32 [%0], %1;\n" ::"r"(b_mxCB
                             + 2
                                 * swz<tileN_ * 2, tileN_>(y * warpM_ * wmmaM_ * tileN_ + 8 * tileN_
                                     + x * warpN_ * wmmaN_ + (threadIdx.z * wmmaM_ + threadIdx.x / 4) * tileN_
                                     + (threadIdx.y * wmmaN_ + threadIdx.x % 4 * 2))),
                "r"(*(unsigned*) &r_mxCB[y][x][2]));
        }

    __syncthreads();

    #pragma unroll
    for (Rn<UNROLL, div_up(tileM_ * tileN_, warpM_ * warpN_ * 256)> iStep; iStep.var < iStep.size; iStep.var++)
        if (thread(iStep) < cn<tileM_ * tileN_>)
            *(int4*) (g_mxCB_
                + get(cStart * G * Q * Q + blockIdx_y * G * Q * Q + gStart * Q * Q
                    + (mStart * cn<tileM_> + thread(iStep) / cn<tileN_>) *Q + nStart * cn<tileN_>
                    + thread(iStep) % cn<tileN_>))
                = *(int4*) ((char*) s_mxCB + swizzle<tileN_ * 2, tileN_ * 2>(thread(iStep) * cn<2>));

    asm volatile("cp.async.wait_group %0;\n" ::"n"(0));
#endif
}

BmmChunkKernelFuncFp16 getBmmChunkKernelFp16(
    int B_, int L_, int G_, int N_, int Q_, dim3* blockDims_, dim3* threadDims_, int* sharedMem_)
{
    int B = B_;
    int L = L_;
    // int H = H_;
    // int P = P_;
    int G = G_;
    // int N = N_;
    int Q = Q_;
    int C = div_up(L, Q);

    int tileM = 128;
    int tileN = 64;
    int tileK = 32;
    int warpM = 2;
    int warpN = 1;
    int pipeS = 2;

    auto sharedMem = std::max((tileM * tileK + tileK * tileN) * pipeS * 2, (tileM * tileN) * 2);

    *blockDims_ = dim3(G * Q / tileN * Q / tileM, C, B);
    *threadDims_ = dim3(32, warpN, warpM);
    *sharedMem_ = sharedMem;

    if (Q_ == 128)
        return bmm_chunk_kernel<128, 128, 64, 32, 16, 8, 16, 2, 1, 2, half>;
    else if (Q_ == 256)
        return bmm_chunk_kernel<256, 128, 64, 32, 16, 8, 16, 2, 1, 2, half>;
    else
        return nullptr;
}

BmmChunkKernelFuncBf16 getBmmChunkKernelBf16(
    int B_, int L_, int G_, int N_, int Q_, dim3* blockDims_, dim3* threadDims_, int* sharedMem_)
{
    int B = B_;
    int L = L_;
    // int H = H_;
    // int P = P_;
    int G = G_;
    // int N = N_;
    int Q = Q_;
    int C = div_up(L, Q);

    int tileM = 128;
    int tileN = 64;
    int tileK = 32;
    int warpM = 2;
    int warpN = 1;
    int pipeS = 2;

    auto sharedMem = std::max((tileM * tileK + tileK * tileN) * pipeS * 2, (tileM * tileN) * 2);

    *blockDims_ = dim3(G * Q / tileN * Q / tileM, C, B);
    *threadDims_ = dim3(32, warpN, warpM);
    *sharedMem_ = sharedMem;

    if (Q_ == 128)
        return bmm_chunk_kernel<128, 128, 64, 32, 16, 8, 16, 2, 1, 2, bf16>;
    else if (Q_ == 256)
        return bmm_chunk_kernel<256, 128, 64, 32, 16, 8, 16, 2, 1, 2, bf16>;
    else
        return nullptr;
}

} // namespace kernels
} // namespace tensorrt_llm

// vim: ts=2 sw=2 sts=2 et sta
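As a quick sanity check of the launch geometry chosen by getBmmChunkKernelFp16 above, the following standalone host snippet recomputes the grid and shared-memory sizes; the concrete B, L, G, Q values are illustrative assumptions, not taken from the commit.

#include <algorithm>
#include <cassert>

int main()
{
    // Illustrative problem size (assumption): 2 sequences, 1000 tokens, 1 group, chunk size 256.
    int B = 2, L = 1000, G = 1, Q = 256;
    int tileM = 128, tileN = 64, tileK = 32, pipeS = 2;

    int C = (L + Q - 1) / Q;               // div_up(L, Q): number of chunks per sequence
    int gridX = G * Q / tileN * Q / tileM; // one block per tileM x tileN tile of each G*Q*Q chunk matrix
    int sharedMem = std::max((tileM * tileK + tileK * tileN) * pipeS * 2, (tileM * tileN) * 2);

    assert(C == 4 && gridX == 8); // 256/64 * 256/128 = 8 tiles per chunk matrix
    assert(sharedMem == 24576);   // the double-buffered input tiles dominate the epilogue buffer
    return 0;
}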
245 cpp/tensorrt_llm/kernels/chunkScan/chunkcumsum.h Normal file
@ -0,0 +1,245 @@
/*
 * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include <cuda_fp8.h>
#include <mma.h>

#include "tensorrt_llm/common/cudaBf16Fallbacks.cuh"

#include "Common.h"
#include "Poly.h"

namespace tensorrt_llm
{
namespace kernels
{

typedef void (*ChunkCumsumKernelFuncFp16)(int B_, int L_, int H_,
    // const half *g_mxY_, // B*L*H*P
    // const half *g_mxOs_, // B*C*H*N*P
    // const half *g_mxFs_, // B *H*N*P
    // const float *g_mxSt_, // B*C*H*N*P
    float* g_mxdc_, // B*C*H*Q
    float* g_mxdA_, // B*C*H*Q
    half const* g_mxdt_, // B*L*H
    float const* g_mxdb_, // H
    float const* g_mxA_, // H
    // const half *g_mxCB_, // B*C*G*Q*Q
    // const half *g_mxBC_, // B*L*2*G*N
    // const float *g_mxD_, // H
    // const half *g_mxX_, // B*L*H*P
    // const half *g_mxZ_, // B*L*H*P
    bool removePadding_, int const* lastTokenIdsPtr_);

typedef void (*ChunkCumsumKernelFuncBf16)(int B_, int L_, int H_,
    // const bf16 *g_mxY_, // B*L*H*P
    // const bf16 *g_mxOs_, // B*C*H*N*P
    // const bf16 *g_mxFs_, // B *H*N*P
    // const float *g_mxSt_, // B*C*H*N*P
    float* g_mxdc_, // B*C*H*Q
    float* g_mxdA_, // B*C*H*Q
    bf16 const* g_mxdt_, // B*L*H
    float const* g_mxdb_, // H
    float const* g_mxA_, // H
    // const bf16 *g_mxCB_, // B*C*G*Q*Q
    // const bf16 *g_mxBC_, // B*L*2*G*N
    // const float *g_mxD_, // H
    // const bf16 *g_mxX_, // B*L*H*P
    // const bf16 *g_mxZ_, // B*L*H*P
    bool removePadding_, int const* lastTokenIdsPtr_);

template <int Q_, int tileH_, int warpH_, bool dtSoftplus_, class Tp_, class Wt_ = float>
__global__ std::enable_if_t<std::is_same_v<Tp_, half> || std::is_same_v<Tp_, __nv_bfloat16>> chunk_cumsum_kernel(int B_,
    int L_, int H_,
    // const Tp_ *g_mxY_, // B*L*H*P
    // const Tp_ *g_mxOs_, // B*C*H*N*P
    // const Tp_ *g_mxFs_, // B *H*N*P
    // const float *g_mxSt_, // B*C*H*N*P
    float* g_mxdc_, // B*C*H*Q
    float* g_mxdA_, // B*C*H*Q
    Tp_ const* g_mxdt_, // B*L*H
    Wt_ const* g_mxdb_, // H
    Wt_ const* g_mxA_, // H
    // const Tp_ *g_mxCB_, // B*C*G*Q*Q
    // const Tp_ *g_mxBC_, // B*L*2*G*N
    // const Wt_ *g_mxD_, // H
    // const Tp_ *g_mxX_, // B*L*H*P
    // const Tp_ *g_mxZ_, // B*L*H*P
    bool removePadding_, int const* lastTokenIdsPtr_)
{
    using namespace tensorrt_llm::common;

    auto blockIdx_x = Rn<ID>{int(blockIdx.x)};
    auto blockIdx_y = Rn<ID>{int(blockIdx.y)};
    auto blockIdx_z = Rn<ID>{int(blockIdx.z)};

    auto threadIdx_x = Rn<ID, 32>{int(threadIdx.x)};
    auto threadIdx_y = Rn<ID, warpH_>{int(threadIdx.y)};

    // auto B = Rn<ID>{B_};
    auto L = Rn<ID>{L_};
    auto H = Rn<ID>{H_};
    // auto P = Rn<ID>{P_};
    // auto G = Rn<ID>{G_};
    // auto N = Rn<ID>{N_};
    auto Q = cn<Q_>;
    auto C = Rn<ID>{div_up(L.var, Q_)};

    auto aStart = blockIdx_z * L;
    auto cStart = blockIdx_z * C;

    if (removePadding_)
    {
        aStart = Rn<ID>{int(blockIdx.z ? lastTokenIdsPtr_[blockIdx.z - 1] : 0)};
        cStart = Rn<ID>{int(blockIdx.z ? div_up(aStart.var, Q_) + blockIdx.z - 1 : 0)};
        L = Rn<ID>{lastTokenIdsPtr_[blockIdx.z] - aStart.var};
        C = Rn<ID>{div_up(L.var, Q_)};
    }
    else
    {
        L = Rn<ID>{lastTokenIdsPtr_[blockIdx.z]};
        C = Rn<ID>{div_up(L.var, Q_)};
    }

    if (blockIdx_y * Q >= L)
        return;

    auto thread = [=](auto iStep) { return iStep * cn<warpH_ * 32> + threadIdx_y * cn<32> + threadIdx_x; };

    #pragma unroll
    for (Rn<UNROLL, div_up(tileH_, warpH_ * 32)> iStep; iStep.var < iStep.size; iStep.var++)
    {
        float r_A = 0.f, r_db = 0.f, sum = 0.f;

        if (thread(iStep) < cn<tileH_>)
            r_A = g_mxA_[get(blockIdx_x * cn<tileH_> + thread(iStep))];
        if (thread(iStep) < cn<tileH_> && g_mxdb_)
            r_db = g_mxdb_[get(blockIdx_x * cn<tileH_> + thread(iStep))];

        #pragma unroll
        for (Rn<UNROLL, Q_> iQ; iQ.var < iQ.size; iQ.var++)
        {
            float r_dt = 0.f;

            if (thread(iStep) < cn<tileH_> && blockIdx_y * Q + iQ < L)
            {
                r_dt = float(g_mxdt_[get((aStart + blockIdx_y * Q + iQ) * H + blockIdx_x * cn<tileH_> + thread(iStep))])
                    + r_db;

                if (dtSoftplus_)
                    r_dt = r_dt > 32.f ? r_dt : log1p(expf(r_dt));

                sum += r_dt;
            }

            if (thread(iStep) < cn<tileH_>)
            {
                g_mxdc_[get((cStart + blockIdx_y) * H * Q + (blockIdx_x * cn<tileH_> + thread(iStep)) * Q + iQ)] = r_dt;
                g_mxdA_[get((cStart + blockIdx_y) * H * Q + (blockIdx_x * cn<tileH_> + thread(iStep)) * Q + iQ)]
                    = sum * r_A;
            }
        }
    }
}

ChunkCumsumKernelFuncFp16 getChunkCumsumKernelFp16(
    int B_, int L_, int H_, int Q_, bool dtSoftPlus_, dim3* blockDims_, dim3* threadDims_, int* sharedMem_)
{
    int B = B_;
    int L = L_;
    int H = H_;
    // int P = P_;
    // int G = G_;
    // int N = N_;
    int Q = Q_;
    int C = div_up(L, Q);

    int tileH = 1;
    int warpH = 1;

    auto sharedMem = 0;

    *blockDims_ = dim3(H / tileH, C, B);
    *threadDims_ = dim3(32, warpH);
    *sharedMem_ = sharedMem;

    if (dtSoftPlus_)
    {
        if (Q_ == 128)
            return chunk_cumsum_kernel<128, 1, 1, true, half>;
        else if (Q_ == 256)
            return chunk_cumsum_kernel<256, 1, 1, true, half>;
        else
            return nullptr;
    }
    else
    {
        if (Q_ == 128)
            return chunk_cumsum_kernel<128, 1, 1, false, half>;
        else if (Q_ == 256)
            return chunk_cumsum_kernel<256, 1, 1, false, half>;
        else
            return nullptr;
    }
}

ChunkCumsumKernelFuncBf16 getChunkCumsumKernelBf16(
    int B_, int L_, int H_, int Q_, bool dtSoftPlus_, dim3* blockDims_, dim3* threadDims_, int* sharedMem_)
{
    int B = B_;
    int L = L_;
    int H = H_;
    // int P = P_;
    // int G = G_;
    // int N = N_;
    int Q = Q_;
    int C = div_up(L, Q);

    int tileH = 1;
    int warpH = 1;

    auto sharedMem = 0;

    *blockDims_ = dim3(H / tileH, C, B);
    *threadDims_ = dim3(32, warpH);
    *sharedMem_ = sharedMem;

    if (dtSoftPlus_)
    {
        if (Q_ == 128)
            return chunk_cumsum_kernel<128, 1, 1, true, bf16>;
        else if (Q_ == 256)
            return chunk_cumsum_kernel<256, 1, 1, true, bf16>;
        else
            return nullptr;
    }
    else
    {
        if (Q_ == 128)
            return chunk_cumsum_kernel<128, 1, 1, false, bf16>;
        else if (Q_ == 256)
            return chunk_cumsum_kernel<256, 1, 1, false, bf16>;
        else
            return nullptr;
    }
}

} // namespace kernels
} // namespace tensorrt_llm

// vim: ts=2 sw=2 sts=2 et sta
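For readability, here is a minimal host-side sketch of the per-(chunk, head) recurrence that chunk_cumsum_kernel above writes into g_mxdc_ and g_mxdA_ (padding handling omitted); the function name and flat layout are illustrative assumptions, not part of the commit.

#include <cmath>
#include <vector>

// Reference recurrence for one chunk of Q tokens and one head (illustrative sketch).
void chunkCumsumReference(std::vector<float> const& dt, float A, float db, bool dtSoftplus,
    std::vector<float>& dc, std::vector<float>& dA)
{
    float sum = 0.f;
    for (size_t q = 0; q < dt.size(); ++q)
    {
        float v = dt[q] + db; // dt bias is optional in the kernel (g_mxdb_ may be null)
        if (dtSoftplus)
            v = v > 32.f ? v : log1pf(expf(v)); // softplus, bypassed for large inputs
        sum += v;
        dc[q] = v;       // matches g_mxdc_: per-token dt after bias/softplus
        dA[q] = sum * A; // matches g_mxdA_: running sum scaled by the head's A
    }
}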
639 cpp/tensorrt_llm/kernels/chunkScan/chunkscan.h Normal file
@ -0,0 +1,639 @@
|
||||
/*
|
||||
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cuda_fp8.h>
|
||||
#include <mma.h>
|
||||
|
||||
#include "tensorrt_llm/common/cudaBf16Fallbacks.cuh"
|
||||
|
||||
#include "Common.h"
|
||||
#include "Poly.h"
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
namespace kernels
|
||||
{
|
||||
|
||||
typedef void (*ChunkScanKernelFuncFp16)(int B_, int L_, int H_, int P_, int G_, int N_,
|
||||
half* g_mxY_, // B*L*H*P
|
||||
half const* g_mxOs_, // B*C*H*N*P
|
||||
// const half *g_mxFs_, // B *H*N*P
|
||||
// const float *g_mxSt_, // B*C*H*N*P
|
||||
float const* g_mxdc_, // B*C*H*Q
|
||||
float const* g_mxdA_, // B*C*H*Q
|
||||
// const half *g_mxdt_, // B*L*H
|
||||
// const float *g_mxdb_, // H
|
||||
// const float *g_mxA_, // H
|
||||
half const* g_mxCB_, // B*C*G*Q*Q
|
||||
half const* g_mxBC_, // B*L*2*G*N
|
||||
float const* g_mxD_, // H
|
||||
half const* g_mxX_, // B*L*H*P
|
||||
half const* g_mxZ_, // B*L*H*P
|
||||
bool removePadding_, int const* lastTokenIdsPtr_);
|
||||
|
||||
typedef void (*ChunkScanKernelFuncBf16)(int B_, int L_, int H_, int P_, int G_, int N_,
|
||||
bf16* g_mxY_, // B*L*H*P
|
||||
bf16 const* g_mxOs_, // B*C*H*N*P
|
||||
// const bf16 *g_mxFs_, // B *H*N*P
|
||||
// const float *g_mxSt_, // B*C*H*N*P
|
||||
float const* g_mxdc_, // B*C*H*Q
|
||||
float const* g_mxdA_, // B*C*H*Q
|
||||
// const bf16 *g_mxdt_, // B*L*H
|
||||
// const float *g_mxdb_, // H
|
||||
// const float *g_mxA_, // H
|
||||
bf16 const* g_mxCB_, // B*C*G*Q*Q
|
||||
bf16 const* g_mxBC_, // B*L*2*G*N
|
||||
float const* g_mxD_, // H
|
||||
bf16 const* g_mxX_, // B*L*H*P
|
||||
bf16 const* g_mxZ_, // B*L*H*P
|
||||
bool removePadding_, int const* lastTokenIdsPtr_);
|
||||
|
||||
template <int Q_, int tileM_, int tileN_, int tileK_, // smem size, per sm
|
||||
int wmmaM_, int wmmaN_, int wmmaK_, // wmma size, per instruction
|
||||
int warpM_, int warpN_, // warp number
|
||||
int pipeS_, class Tp_, class Wt_ = float>
|
||||
__global__ std::enable_if_t<std::is_same_v<Tp_, half> || std::is_same_v<Tp_, __nv_bfloat16>> chunk_scan_kernel(int B_,
|
||||
int L_, int H_, int P_, int G_, int N_,
|
||||
Tp_* g_mxY_, // B*L*H*P
|
||||
Tp_ const* g_mxOs_, // B*C*H*N*P
|
||||
// const Tp_ *g_mxFs_, // B *H*N*P
|
||||
// const float *g_mxSt_, // B*C*H*N*P
|
||||
float const* g_mxdc_, // B*C*H*Q
|
||||
float const* g_mxdA_, // B*C*H*Q
|
||||
// const Tp_ *g_mxdt_, // B*L*H
|
||||
// const Wt_ *g_mxdb_, // H
|
||||
// const Wt_ *g_mxA_, // H
|
||||
Tp_ const* g_mxCB_, // B*C*G*Q*Q
|
||||
Tp_ const* g_mxBC_, // B*L*2*G*N
|
||||
Wt_ const* g_mxD_, // H
|
||||
Tp_ const* g_mxX_, // B*L*H*P
|
||||
Tp_ const* g_mxZ_, // B*L*H*P
|
||||
bool removePadding_, int const* lastTokenIdsPtr_)
|
||||
{
|
||||
#if __CUDA_ARCH__ >= 800
|
||||
using namespace tensorrt_llm::common;
|
||||
|
||||
auto blockIdx_x = Rn<ID>{int(blockIdx.x)};
|
||||
auto blockIdx_y = Rn<ID>{int(blockIdx.y)};
|
||||
auto blockIdx_z = Rn<ID>{int(blockIdx.z)};
|
||||
|
||||
auto threadIdx_x = Rn<ID, 32>{int(threadIdx.x)};
|
||||
auto threadIdx_y = Rn<ID, warpN_>{int(threadIdx.y)};
|
||||
auto threadIdx_z = Rn<ID, warpM_>{int(threadIdx.z)};
|
||||
|
||||
// auto B = Rn<ID>{B_};
|
||||
auto L = Rn<ID>{L_};
|
||||
auto H = Rn<ID>{H_};
|
||||
auto P = Rn<ID>{P_};
|
||||
auto G = Rn<ID>{G_};
|
||||
auto N = Rn<ID>{N_};
|
||||
auto Q = cn<Q_>;
|
||||
auto C = Rn<ID>{div_up(L.var, Q_)};
|
||||
|
||||
auto aStart = blockIdx_z * L;
|
||||
auto cStart = blockIdx_z * C;
|
||||
|
||||
if (removePadding_)
|
||||
{
|
||||
aStart = Rn<ID>{int(blockIdx.z ? lastTokenIdsPtr_[blockIdx.z - 1] : 0)};
|
||||
cStart = Rn<ID>{int(blockIdx.z ? div_up(aStart.var, Q_) + blockIdx.z - 1 : 0)};
|
||||
L = Rn<ID>{lastTokenIdsPtr_[blockIdx.z] - aStart.var};
|
||||
C = Rn<ID>{div_up(L.var, Q_)};
|
||||
}
|
||||
else
|
||||
{
|
||||
L = Rn<ID>{lastTokenIdsPtr_[blockIdx.z]};
|
||||
C = Rn<ID>{div_up(L.var, Q_)};
|
||||
}
|
||||
|
||||
if (blockIdx_y * Q >= L)
|
||||
return;
|
||||
|
||||
auto hStart = Rn<ID>{blockIdx_x.var / (P_ / cn<tileN_>) / (Q / cn<tileM_>) };
|
||||
auto mStart = Rn<ID>{blockIdx_x.var / (P_ / cn<tileN_>) % (Q / cn<tileM_>) };
|
||||
auto nStart = Rn<ID>{blockIdx_x.var % (P_ / cn<tileN_>) };
|
||||
auto gStart = Rn<ID>{hStart.var / (H_ / G_)};
|
||||
|
||||
extern __shared__ float smem[];
|
||||
|
||||
Tp_* s_mxC = (Tp_*) smem;
|
||||
Tp_* s_mxOs = (Tp_*) smem + tileM_ * tileK_ * pipeS_;
|
||||
Tp_* s_mxY = (Tp_*) smem;
|
||||
|
||||
float* s_mxdc = smem + (tileM_ + tileN_) * tileK_ * pipeS_ / 2;
|
||||
float* s_mxdA = smem + (tileM_ + tileN_) * tileK_ * pipeS_ / 2 + Q_;
|
||||
|
||||
unsigned b_base = __nvvm_get_smem_pointer(smem);
|
||||
|
||||
unsigned b_mxC = b_base;
|
||||
unsigned b_mxOs = b_base + tileM_ * tileK_ * pipeS_ * sizeof(Tp_);
|
||||
unsigned b_mxY = b_base;
|
||||
|
||||
using std::array;
|
||||
|
||||
register array<array<array<float, wmmaM_ * wmmaN_ / 32>, tileN_ / wmmaN_ / warpN_>, tileM_ / wmmaM_ / warpM_> r_mxY
|
||||
= array<array<array<float, wmmaM_ * wmmaN_ / 32>, tileN_ / wmmaN_ / warpN_>, tileM_ / wmmaM_ / warpM_>();
|
||||
register array<array<unsigned, wmmaM_ * wmmaK_ / 64>, tileM_ / wmmaM_ / warpM_> r_mxC;
|
||||
register array<array<unsigned, wmmaK_ * wmmaN_ / 64>, tileN_ / wmmaN_ / warpN_> r_mxOs;
|
||||
|
||||
constexpr int step = std::max(
|
||||
1, tileM_ / wmmaM_ / warpM_ * tileN_ / wmmaN_ / warpN_ / (tileM_ / wmmaM_ / warpM_ + tileN_ / wmmaN_ / warpN_));
|
||||
|
||||
auto baseC = [](auto iK) { return iK % cn<pipeS_> * cn<tileM_> * cn<tileK_>; };
|
||||
auto baseOs = [](auto iK) { return iK % cn<pipeS_> * cn<tileN_> * cn<tileK_>; };
|
||||
|
||||
auto thread = [=](auto iStep)
|
||||
{
|
||||
return iStep * cn<warpM_ * warpN_ * 256> + threadIdx_z * cn<warpN_ * 256> + threadIdx_y * cn<256>
|
||||
+ threadIdx_x * cn<8>;
|
||||
};
|
||||
|
||||
#pragma unroll
|
||||
for (Rn<UNROLL, div_up(Q_, warpM_ * warpN_ * 256)> iStep; iStep.var < iStep.size; iStep.var++)
|
||||
if (thread(iStep) < cn<Q_>)
|
||||
{
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 8; i += 4)
|
||||
{
|
||||
*(int4*) (s_mxdc + get(thread(iStep)) + i)
|
||||
= *(int4*) (g_mxdc_ + get((cStart + blockIdx_y) * H * Q + hStart * Q + thread(iStep)) + i);
|
||||
*(int4*) (s_mxdA + get(thread(iStep)) + i)
|
||||
= *(int4*) (g_mxdA_ + get((cStart + blockIdx_y) * H * Q + hStart * Q + thread(iStep)) + i);
|
||||
}
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (Rn<UNROLL, pipeS_> iK; iK.var < iK.size; iK.var++)
|
||||
{
|
||||
#pragma unroll
|
||||
for (Rn<UNROLL, div_up(tileM_ * tileK_, warpM_ * warpN_ * 256)> iStep; iStep.var < iStep.size; iStep.var++)
|
||||
if (thread(iStep) < cn<tileM_ * tileK_>
|
||||
&& thread(iStep) / cn<tileK_> < L - blockIdx_y * Q - mStart * cn<tileM_>)
|
||||
cp_shared_global<16>(b_mxC + swizzle<tileK_ * 2, tileK_ * 2>(thread(iStep) * cn<2>, baseC(iK) * cn<2>),
|
||||
g_mxBC_
|
||||
+ get(
|
||||
(aStart + blockIdx_y * Q + mStart * cn<tileM_> + thread(iStep) / cn<tileK_>) *cn<2> * G * N
|
||||
+ cn<1> * G * N + gStart * N + iK * cn<tileK_> + thread(iStep) % cn<tileK_>));
|
||||
else if (thread(iStep) < cn<tileM_ * tileK_>)
|
||||
*(int4*) ((char*) s_mxC + swizzle<tileK_ * 2, tileK_ * 2>(thread(iStep) * cn<2>, baseC(iK) * cn<2>))
|
||||
= int4{0, 0, 0, 0};
|
||||
|
||||
#pragma unroll
|
||||
for (Rn<UNROLL, div_up(tileN_ * tileK_, warpM_ * warpN_ * 256)> iStep; iStep.var < iStep.size; iStep.var++)
|
||||
if (thread(iStep) < cn<tileN_ * tileK_>)
|
||||
cp_shared_global<16>(
|
||||
b_mxOs + swizzle<tileN_ * 2, tileN_ * 2>(thread(iStep) * cn<2>, baseOs(iK) * cn<2>),
|
||||
g_mxOs_
|
||||
+ get((cStart + blockIdx_y) * H * N * P + hStart * N * P
|
||||
+ (iK * cn<tileK_> + thread(iStep) / cn<tileN_>) *P + nStart * cn<tileN_>
|
||||
+ thread(iStep) % cn<tileN_>));
|
||||
|
||||
cp_commit_group();
|
||||
}
|
||||
|
||||
asm volatile("cp.async.wait_group %0;\n" ::"n"(pipeS_ - 1));
|
||||
|
||||
__syncthreads();
|
||||
|
||||
for (int iK = pipeS_; iK < (N_ + Q_) / tileK_ + pipeS_; iK++)
|
||||
{
|
||||
auto jK = Rn<>{iK};
|
||||
if ((iK - pipeS_) * cn<tileK_> == N_)
|
||||
{
|
||||
|
||||
#pragma unroll
|
||||
for (int y = 0; y < tileM_ / wmmaM_ / warpM_; y++)
|
||||
#pragma unroll
|
||||
for (int x = 0; x < tileN_ / wmmaN_ / warpN_; x++)
|
||||
{
|
||||
float2 tmp2 = float2{expf(s_mxdA[get(mStart * cn<tileM_> + Rn<UNROLL>{y} * cn<warpM_ * wmmaM_>
|
||||
+ threadIdx_z * cn<wmmaM_> + threadIdx_x / cn<4>)]),
|
||||
expf(s_mxdA[get(mStart * cn<tileM_> + Rn<UNROLL>{y} * cn<warpM_ * wmmaM_> + cn<8>
|
||||
+ threadIdx_z * cn<wmmaM_> + threadIdx_x / cn<4>)])};
|
||||
|
||||
r_mxY[y][x][0] *= tmp2.x;
|
||||
r_mxY[y][x][1] *= tmp2.x;
|
||||
r_mxY[y][x][2] *= tmp2.y;
|
||||
r_mxY[y][x][3] *= tmp2.y;
|
||||
}
|
||||
}
|
||||
|
||||
if ((iK - pipeS_) * cn<tileK_> >= N_)
|
||||
{
|
||||
|
||||
#pragma unroll
|
||||
for (Rn<UNROLL, div_up(tileM_ * tileK_, warpM_ * warpN_ * 256)> iStep; iStep.var < iStep.size; iStep.var++)
|
||||
if (thread(iStep) < cn<tileM_ * tileK_>)
|
||||
{
|
||||
register Tp_ tmpCB[8];
|
||||
|
||||
*(int4*) &tmpCB[0] = *(int4*) ((char*) s_mxC
|
||||
+ swizzle<tileK_ * 2, tileK_ * 2>(thread(iStep) * cn<2>, baseC(jK) * cn<2>));
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 8; i += 2)
|
||||
{
|
||||
float2 tmp2 = std::is_same_v<Tp_, half> ? __half22float2(*(half2*) &tmpCB[i])
|
||||
: bf1622float2(*(bf162*) &tmpCB[i]);
|
||||
|
||||
int kStart = (iK - pipeS_) * cn<tileK_> - N_;
|
||||
|
||||
tmp2.x *= expf(s_mxdA[get(mStart * cn<tileM_> + thread(iStep) / cn<tileK_>)]
|
||||
- s_mxdA[kStart + get(thread(iStep) % cn<tileK_> + Rn<UNROLL>{i})])
|
||||
* s_mxdc[kStart + get(thread(iStep) % cn<tileK_> + Rn<UNROLL>{i})];
|
||||
tmp2.y *= expf(s_mxdA[get(mStart * cn<tileM_> + thread(iStep) / cn<tileK_>)]
|
||||
- s_mxdA[kStart + get(thread(iStep) % cn<tileK_> + Rn<UNROLL>{i + 1})])
|
||||
* s_mxdc[kStart + get(thread(iStep) % cn<tileK_> + Rn<UNROLL>{i + 1})];
|
||||
|
||||
if (get(mStart * cn<tileM_> + thread(iStep) / cn<tileK_>)
|
||||
< kStart + get(thread(iStep) % cn<tileK_> + Rn<UNROLL>{i}))
|
||||
tmp2.x = 0;
|
||||
if (get(mStart * cn<tileM_> + thread(iStep) / cn<tileK_>)
|
||||
< kStart + get(thread(iStep) % cn<tileK_> + Rn<UNROLL>{i + 1}))
|
||||
tmp2.y = 0;
|
||||
|
||||
if (std::is_same_v<Tp_, half>)
|
||||
*(half2*) &tmpCB[i] = __float22half2_rn(tmp2);
|
||||
else
|
||||
*(bf162*) &tmpCB[i] = __float22bfloat162_rn(tmp2);
|
||||
}
|
||||
|
||||
*(int4*) ((char*) s_mxC + swizzle<tileK_ * 2, tileK_ * 2>(thread(iStep) * cn<2>, baseC(jK) * cn<2>))
|
||||
= *(int4*) &tmpCB[0];
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int k = 0; k < tileK_ / wmmaK_; k++)
|
||||
{
|
||||
#pragma unroll
|
||||
for (int y = 0; y < tileM_ / wmmaM_ / warpM_; y++)
|
||||
#pragma unroll
|
||||
for (int x = 0; x < tileN_ / wmmaN_ / warpN_; x++)
|
||||
{
|
||||
if ((y * tileN_ / wmmaN_ / warpN_ + x) % step == 0)
|
||||
{
|
||||
int x1 = (y * tileN_ / wmmaN_ / warpN_ + x) / step;
|
||||
int y1 = x1 - tileN_ / wmmaN_ / warpN_
|
||||
+ (tileM_ / wmmaM_ / warpM_ == 1 || tileN_ / wmmaN_ / warpN_ == 1);
|
||||
|
||||
if (y1 >= 0 && y1 < tileM_ / wmmaM_ / warpM_)
|
||||
{
|
||||
if (wmmaK_ == 16)
|
||||
asm volatile(
|
||||
"ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];\n"
|
||||
: "=r"(r_mxC[y1][0]), "=r"(r_mxC[y1][1]), "=r"(r_mxC[y1][2]), "=r"(r_mxC[y1][3])
|
||||
: "r"(b_mxC + iK % pipeS_ * (tileM_ * tileK_ * 2)
|
||||
+ 2
|
||||
* swz<tileK_ * 2, tileK_>(y1 * warpM_ * wmmaM_ * tileK_ + k * wmmaK_
|
||||
+ threadIdx.z * wmmaM_ * tileK_ + threadIdx.x % 16 * tileK_
|
||||
+ threadIdx.x / 16 * 8)));
|
||||
}
|
||||
|
||||
if (x1 >= 0 && x1 < tileN_ / wmmaN_ / warpN_)
|
||||
{
|
||||
if (wmmaK_ == 16 && x1 % 2 == 0)
|
||||
asm volatile(
|
||||
"ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {%0, %1, %2, %3}, [%4];\n"
|
||||
: "=r"(r_mxOs[x1][0]), "=r"(r_mxOs[x1][1]), "=r"(r_mxOs[x1 + 1][0]),
|
||||
"=r"(r_mxOs[x1 + 1][1])
|
||||
: "r"(b_mxOs + iK % pipeS_ * (tileK_ * tileN_ * 2)
|
||||
+ 2
|
||||
* swz<tileN_ * 2, tileN_>(x1 * warpN_ * wmmaN_ + k * wmmaK_ * tileN_
|
||||
+ threadIdx.y * wmmaN_ + threadIdx.x % wmmaK_ * tileN_
|
||||
+ threadIdx.x / wmmaK_ * warpN_ * wmmaN_)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int y = 0; y < tileM_ / wmmaM_ / warpM_; y++)
|
||||
#pragma unroll
|
||||
for (int x = 0; x < tileN_ / wmmaN_ / warpN_; x++)
|
||||
{
|
||||
if (wmmaK_ == 16)
|
||||
{
|
||||
if (std::is_same_v<Tp_, half>)
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 \n"
|
||||
" {%0, %1, %2, %3}, \n"
|
||||
" {%4, %5, %6, %7}, \n"
|
||||
" {%8, %9}, \n"
|
||||
" {%0, %1, %2, %3}; \n"
|
||||
: "+f"(r_mxY[y][x][0]), "+f"(r_mxY[y][x][1]), "+f"(r_mxY[y][x][2]), "+f"(r_mxY[y][x][3])
|
||||
: "r"(r_mxC[y][0]), "r"(r_mxC[y][1]), "r"(r_mxC[y][2]), "r"(r_mxC[y][3]),
|
||||
"r"(r_mxOs[x][0]), "r"(r_mxOs[x][1]));
|
||||
else
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 \n"
|
||||
" {%0, %1, %2, %3}, \n"
|
||||
" {%4, %5, %6, %7}, \n"
|
||||
" {%8, %9}, \n"
|
||||
" {%0, %1, %2, %3}; \n"
|
||||
: "+f"(r_mxY[y][x][0]), "+f"(r_mxY[y][x][1]), "+f"(r_mxY[y][x][2]), "+f"(r_mxY[y][x][3])
|
||||
: "r"(r_mxC[y][0]), "r"(r_mxC[y][1]), "r"(r_mxC[y][2]), "r"(r_mxC[y][3]),
|
||||
"r"(r_mxOs[x][0]), "r"(r_mxOs[x][1]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
if (iK * cn<tileK_> < N_)
|
||||
{
|
||||
|
||||
#pragma unroll
|
||||
for (Rn<UNROLL, div_up(tileM_ * tileK_, warpM_ * warpN_ * 256)> iStep; iStep.var < iStep.size; iStep.var++)
|
||||
if (thread(iStep) < cn<tileM_ * tileK_>
|
||||
&& thread(iStep) / cn<tileK_> < L - blockIdx_y * Q - mStart * cn<tileM_>)
|
||||
cp_shared_global<16>(
|
||||
b_mxC + swizzle<tileK_ * 2, tileK_ * 2>(thread(iStep) * cn<2>, baseC(jK) * cn<2>),
|
||||
g_mxBC_
|
||||
+ get((aStart + blockIdx_y * Q + mStart * cn<tileM_> + thread(iStep) / cn<tileK_>) *cn<2>
|
||||
* G * N
|
||||
+ cn<1> * G * N + gStart * N + jK * cn<tileK_> + thread(iStep) % cn<tileK_>));
|
||||
else if (thread(iStep) < cn<tileM_ * tileK_>)
|
||||
*(int4*) ((char*) s_mxC + swizzle<tileK_ * 2, tileK_ * 2>(thread(iStep) * cn<2>, baseC(jK) * cn<2>))
|
||||
= int4{0, 0, 0, 0};
|
||||
|
||||
#pragma unroll
|
||||
for (Rn<UNROLL, div_up(tileN_ * tileK_, warpM_ * warpN_ * 256)> iStep; iStep.var < iStep.size; iStep.var++)
|
||||
if (thread(iStep) < cn<tileN_ * tileK_>)
|
||||
cp_shared_global<16>(
|
||||
b_mxOs + swizzle<tileN_ * 2, tileN_ * 2>(thread(iStep) * cn<2>, baseOs(jK) * cn<2>),
|
||||
g_mxOs_
|
||||
+ get((cStart + blockIdx_y) * H * N * P + hStart * N * P
|
||||
+ (jK * cn<tileK_> + thread(iStep) / cn<tileN_>) *P + nStart * cn<tileN_>
|
||||
+ thread(iStep) % cn<tileN_>));
|
||||
}
|
||||
else if (iK * cn<tileK_> < N_ + Q_)
|
||||
{
|
||||
|
||||
#pragma unroll
|
||||
for (Rn<UNROLL, div_up(tileM_ * tileK_, warpM_ * warpN_ * 256)> iStep; iStep.var < iStep.size; iStep.var++)
|
||||
if (thread(iStep) < cn<tileM_ * tileK_>)
|
||||
cp_shared_global<16>(
|
||||
b_mxC + swizzle<tileK_ * 2, tileK_ * 2>(thread(iStep) * cn<2>, baseC(jK) * cn<2>),
|
||||
g_mxCB_
|
||||
+ get((cStart + blockIdx_y) * G * Q * Q + gStart * Q * Q
|
||||
+ (mStart * cn<tileM_> + thread(iStep) / cn<tileK_>) *Q + jK * cn<tileK_>
|
||||
- N + thread(iStep) % cn<tileK_>));
|
||||
|
||||
#pragma unroll
|
||||
for (Rn<UNROLL, div_up(tileN_ * tileK_, warpM_ * warpN_ * 256)> iStep; iStep.var < iStep.size; iStep.var++)
|
||||
if (thread(iStep) < cn<tileN_ * tileK_>
|
||||
&& thread(iStep) / cn<tileN_> < L - blockIdx_y * Q - jK * cn<tileK_> + N)
|
||||
cp_shared_global<16>(
|
||||
b_mxOs + swizzle<tileN_ * 2, tileN_ * 2>(thread(iStep) * cn<2>, baseOs(jK) * cn<2>),
|
||||
g_mxX_
|
||||
+ get((aStart + blockIdx_y * Q + jK * cn<tileK_> - N + thread(iStep) / cn<tileN_>) *H * P
|
||||
+ hStart * P + nStart * cn<tileN_> + thread(iStep) % cn<tileN_>));
|
||||
else if (thread(iStep) < cn<tileN_ * tileK_>)
|
||||
*(int4*) ((char*) s_mxOs
|
||||
+ swizzle<tileN_ * 2, tileN_ * 2>(thread(iStep) * cn<2>, baseOs(jK) * cn<2>))
|
||||
= int4{0, 0, 0, 0};
|
||||
}
|
||||
|
||||
asm volatile("cp.async.commit_group;\n" ::);
|
||||
|
||||
asm volatile("cp.async.wait_group %0;\n" ::"n"(pipeS_ - 1));
|
||||
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
if (g_mxD_)
|
||||
{
|
||||
float r_D = g_mxD_[hStart.var];
|
||||
|
||||
#pragma unroll
|
||||
for (int y = 0; y < tileM_ / wmmaM_ / warpM_; y++)
|
||||
#pragma unroll
|
||||
for (int x = 0; x < tileN_ / wmmaN_ / warpN_; x++)
|
||||
{
|
||||
Tp_ tmp16[4] = {0};
|
||||
float tmp32[4] = {0};
|
||||
|
||||
if (blockIdx_y * Q + mStart * cn<tileM_> + Rn<UNROLL>{y} * cn<warpM_ * wmmaM_>
|
||||
+ threadIdx_z * cn<wmmaM_> + threadIdx_x / cn<4>
|
||||
< L)
|
||||
{
|
||||
*(int*) &tmp16[0] = *(int*) (g_mxX_
|
||||
+ get((aStart + blockIdx_y * Q + mStart * cn<tileM_> + Rn<UNROLL>{y} * cn<warpM_ * wmmaM_>
|
||||
+ threadIdx_z * cn<wmmaM_> + threadIdx_x / cn<4>) *H
|
||||
* P
|
||||
+ hStart * P + nStart * cn<tileN_> + Rn<UNROLL>{x} * cn<warpN_ * wmmaN_>
|
||||
+ threadIdx_y * cn<wmmaN_> + threadIdx_x % cn<4> * cn<2>));
|
||||
|
||||
*(float2*) &tmp32[0] = std::is_same_v<Tp_, half> ? __half22float2(*(half2*) &tmp16[0])
|
||||
: bf1622float2(*(bf162*) &tmp16[0]);
|
||||
|
||||
r_mxY[y][x][0] += r_D * tmp32[0];
|
||||
r_mxY[y][x][1] += r_D * tmp32[1];
|
||||
}
|
||||
|
||||
if (blockIdx_y * Q + mStart * cn<tileM_> + Rn<UNROLL>{y} * cn<warpM_ * wmmaM_> + cn<8>
|
||||
+ threadIdx_z * cn<wmmaM_> + threadIdx_x / cn<4>
|
||||
< L)
|
||||
{
|
||||
*(int*) &tmp16[2] = *(int*) (g_mxX_
|
||||
+ get((aStart + blockIdx_y * Q + mStart * cn<tileM_> + Rn<UNROLL>{y} * cn<warpM_ * wmmaM_>
|
||||
+ cn<8> + threadIdx_z * cn<wmmaM_> + threadIdx_x / cn<4>) *H
|
||||
* P
|
||||
+ hStart * P + nStart * cn<tileN_> + Rn<UNROLL>{x} * cn<warpN_ * wmmaN_>
|
||||
+ threadIdx_y * cn<wmmaN_> + threadIdx_x % cn<4> * cn<2>));
|
||||
|
||||
*(float2*) &tmp32[2] = std::is_same_v<Tp_, half> ? __half22float2(*(half2*) &tmp16[2])
|
||||
: bf1622float2(*(bf162*) &tmp16[2]);
|
||||
|
||||
r_mxY[y][x][2] += r_D * tmp32[2];
|
||||
r_mxY[y][x][3] += r_D * tmp32[3];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (g_mxZ_)
|
||||
{
|
||||
#pragma unroll
|
||||
for (int y = 0; y < tileM_ / wmmaM_ / warpM_; y++)
|
||||
#pragma unroll
|
||||
for (int x = 0; x < tileN_ / wmmaN_ / warpN_; x++)
|
||||
{
|
||||
Tp_ tmp16[4] = {0};
|
||||
float tmp32[4] = {0};
|
||||
|
||||
if (blockIdx_y * Q + mStart * cn<tileM_> + Rn<UNROLL>{y} * cn<warpM_ * wmmaM_>
|
||||
+ threadIdx_z * cn<wmmaM_> + threadIdx_x / cn<4>
|
||||
< L)
|
||||
{
|
||||
*(int*) &tmp16[0] = *(int*) (g_mxZ_
|
||||
+ get((aStart + blockIdx_y * Q + mStart * cn<tileM_> + Rn<UNROLL>{y} * cn<warpM_ * wmmaM_>
|
||||
+ threadIdx_z * cn<wmmaM_> + threadIdx_x / cn<4>) *H
|
||||
* P
|
||||
+ hStart * P + nStart * cn<tileN_> + Rn<UNROLL>{x} * cn<warpN_ * wmmaN_>
|
||||
+ threadIdx_y * cn<wmmaN_> + threadIdx_x % cn<4> * cn<2>));
|
||||
|
||||
*(float2*) &tmp32[0] = std::is_same_v<Tp_, half> ? __half22float2(*(half2*) &tmp16[0])
|
||||
: bf1622float2(*(bf162*) &tmp16[0]);
|
||||
|
||||
r_mxY[y][x][0] *= tmp32[0] > 32.f ? tmp32[0] : tmp32[0] / (1.f + expf(-tmp32[0]));
|
||||
r_mxY[y][x][1] *= tmp32[1] > 32.f ? tmp32[1] : tmp32[1] / (1.f + expf(-tmp32[1]));
|
||||
}
|
||||
|
||||
if (blockIdx_y * Q + mStart * cn<tileM_> + Rn<UNROLL>{y} * cn<warpM_ * wmmaM_> + cn<8>
|
||||
+ threadIdx_z * cn<wmmaM_> + threadIdx_x / cn<4>
|
||||
< L)
|
||||
{
|
||||
*(int*) &tmp16[2] = *(int*) (g_mxZ_
|
||||
+ get((aStart + blockIdx_y * Q + mStart * cn<tileM_> + Rn<UNROLL>{y} * cn<warpM_ * wmmaM_>
|
||||
+ cn<8> + threadIdx_z * cn<wmmaM_> + threadIdx_x / cn<4>) *H
|
||||
* P
|
||||
+ hStart * P + nStart * cn<tileN_> + Rn<UNROLL>{x} * cn<warpN_ * wmmaN_>
|
||||
+ threadIdx_y * cn<wmmaN_> + threadIdx_x % cn<4> * cn<2>));
|
||||
|
||||
*(float2*) &tmp32[2] = std::is_same_v<Tp_, half> ? __half22float2(*(half2*) &tmp16[2])
|
||||
: bf1622float2(*(bf162*) &tmp16[2]);
|
||||
|
||||
r_mxY[y][x][2] *= tmp32[2] > 32.f ? tmp32[2] : tmp32[2] / (1.f + expf(-tmp32[2]));
|
||||
r_mxY[y][x][3] *= tmp32[3] > 32.f ? tmp32[3] : tmp32[3] / (1.f + expf(-tmp32[3]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int y = 0; y < tileM_ / wmmaM_ / warpM_; y++)
|
||||
#pragma unroll
|
||||
for (int x = 0; x < tileN_ / wmmaN_ / warpN_; x++)
|
||||
{
|
||||
if (std::is_same_v<Tp_, half>)
|
||||
{
|
||||
*(half2*) &r_mxY[y][x][0] = __floats2half2_rn(r_mxY[y][x][0], r_mxY[y][x][1]);
|
||||
*(half2*) &r_mxY[y][x][2] = __floats2half2_rn(r_mxY[y][x][2], r_mxY[y][x][3]);
|
||||
}
|
||||
else
|
||||
{
|
||||
*(bf162*) &r_mxY[y][x][0] = __floats2bfloat162_rn(r_mxY[y][x][0], r_mxY[y][x][1]);
|
||||
*(bf162*) &r_mxY[y][x][2] = __floats2bfloat162_rn(r_mxY[y][x][2], r_mxY[y][x][3]);
|
||||
}
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int y = 0; y < tileM_ / wmmaM_ / warpM_; y++)
|
||||
#pragma unroll
|
||||
for (int x = 0; x < tileN_ / wmmaN_ / warpN_; x++)
|
||||
{
|
||||
asm volatile("st.shared.b32 [%0], %1;\n" ::"r"(b_mxY
|
||||
+ 2
|
||||
* swz<tileN_ * 2, tileN_>(y * warpM_ * wmmaM_ * tileN_ + x * warpN_ * wmmaN_
|
||||
+ (threadIdx.z * wmmaM_ + threadIdx.x / 4) * tileN_
|
||||
+ (threadIdx.y * wmmaN_ + threadIdx.x % 4 * 2))),
|
||||
"r"(*(unsigned*) &r_mxY[y][x][0]));
|
||||
asm volatile("st.shared.b32 [%0], %1;\n" ::"r"(b_mxY
|
||||
+ 2
|
||||
* swz<tileN_ * 2, tileN_>(y * warpM_ * wmmaM_ * tileN_ + 8 * tileN_
|
||||
+ x * warpN_ * wmmaN_ + (threadIdx.z * wmmaM_ + threadIdx.x / 4) * tileN_
|
||||
+ (threadIdx.y * wmmaN_ + threadIdx.x % 4 * 2))),
|
||||
"r"(*(unsigned*) &r_mxY[y][x][2]));
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
#pragma unroll
|
||||
for (Rn<UNROLL, div_up(tileM_ * tileN_, warpM_ * warpN_ * 256)> iStep; iStep.var < iStep.size; iStep.var++)
|
||||
if (thread(iStep) < cn<tileM_ * tileN_>
|
||||
&& thread(iStep) / cn<tileN_> < L - blockIdx_y * Q - mStart * cn<tileM_>)
|
||||
*(int4*) (g_mxY_
|
||||
+ get((aStart + blockIdx_y * Q + mStart * cn<tileM_> + thread(iStep) / cn<tileN_>) *H * P + hStart * P
|
||||
+ nStart * cn<tileN_> + thread(iStep) % cn<tileN_>))
|
||||
= *(int4*) ((char*) s_mxY + swizzle<tileN_ * 2, tileN_ * 2>(thread(iStep) * cn<2>));
|
||||
|
||||
asm volatile("cp.async.wait_group %0;\n" ::"n"(0));
|
||||
#endif
|
||||
}
|
||||
|
||||
ChunkScanKernelFuncFp16 getChunkScanKernelFp16(
|
||||
int B_, int L_, int H_, int P_, int G_, int N_, int Q_, dim3* blockDims_, dim3* threadDims_, int* sharedMem_)
|
||||
{
|
||||
int B = B_;
|
||||
int L = L_;
|
||||
int H = H_;
|
||||
int P = P_;
|
||||
// int G = G_;
|
||||
// int N = N_;
|
||||
int Q = Q_;
|
||||
int C = div_up(L, Q);
|
||||
|
||||
int tileM = 128;
|
||||
int tileN = 64;
|
||||
int tileK = 32;
|
||||
int warpM = 4;
|
||||
int warpN = 1;
|
||||
int pipeS = 2;
|
||||
|
||||
auto sharedMem = std::max((tileM * tileK + tileK * tileN) * pipeS * 2 + Q * 8, (tileM * tileN) * 2);
|
||||
|
||||
*blockDims_ = dim3(H * P / tileN * Q / tileM, C, B);
|
||||
*threadDims_ = dim3(32, warpN, warpM);
|
||||
*sharedMem_ = sharedMem;
|
||||
|
||||
if (Q_ == 128)
|
||||
return chunk_scan_kernel<128, 128, 64, 32, 16, 8, 16, 4, 1, 2, half>;
|
||||
else if (Q_ == 256)
|
||||
return chunk_scan_kernel<256, 128, 64, 32, 16, 8, 16, 4, 1, 2, half>;
|
||||
else
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
ChunkScanKernelFuncBf16 getChunkScanKernelBf16(
|
||||
int B_, int L_, int H_, int P_, int G_, int N_, int Q_, dim3* blockDims_, dim3* threadDims_, int* sharedMem_)
|
||||
{
|
||||
int B = B_;
|
||||
int L = L_;
|
||||
int H = H_;
|
||||
int P = P_;
|
||||
// int G = G_;
|
||||
// int N = N_;
|
||||
int Q = Q_;
|
||||
int C = div_up(L, Q);
|
||||
|
||||
int tileM = 128;
|
||||
int tileN = 64;
|
||||
int tileK = 32;
|
||||
int warpM = 4;
|
||||
int warpN = 1;
|
||||
int pipeS = 2;
|
||||
|
||||
auto sharedMem = std::max((tileM * tileK + tileK * tileN) * pipeS * 2 + Q * 8, (tileM * tileN) * 2);
|
||||
|
||||
*blockDims_ = dim3(H * P / tileN * Q / tileM, C, B);
|
||||
*threadDims_ = dim3(32, warpN, warpM);
|
||||
*sharedMem_ = sharedMem;
|
||||
|
||||
if (Q_ == 128)
|
||||
return chunk_scan_kernel<128, 128, 64, 32, 16, 8, 16, 4, 1, 2, bf16>;
|
||||
else if (Q_ == 256)
|
||||
return chunk_scan_kernel<256, 128, 64, 32, 16, 8, 16, 4, 1, 2, bf16>;
|
||||
else
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
} // namespace kernels
|
||||
} // namespace tensorrt_llm
|
||||
|
||||
// vim: ts=2 sw=2 sts=2 et sta
|
||||
453 cpp/tensorrt_llm/kernels/chunkScan/chunkstate.h Normal file
@ -0,0 +1,453 @@
|
||||
/*
|
||||
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cuda_fp8.h>
|
||||
#include <mma.h>
|
||||
|
||||
#include "tensorrt_llm/common/cudaBf16Fallbacks.cuh"
|
||||
|
||||
#include "Common.h"
|
||||
#include "Poly.h"
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
namespace kernels
|
||||
{
|
||||
|
||||
typedef void (*ChunkStateKernelFuncFp16)(int B_, int L_, int H_, int P_, int G_, int N_,
|
||||
// const half *g_mxY_, // B*L*H*P
|
||||
// const half *g_mxOs_, // B*C*H*N*P
|
||||
// const half *g_mxFs_, // B *H*N*P
|
||||
float* g_mxSt_, // B*C*H*N*P
|
||||
float const* g_mxdc_, // B*C*H*Q
|
||||
float const* g_mxdA_, // B*C*H*Q
|
||||
// const half *g_mxdt_, // B*L*H
|
||||
// const float *g_mxdb_, // H
|
||||
// const float *g_mxA_, // H
|
||||
// const half *g_mxCB_, // B*C*G*Q*Q
|
||||
half const* g_mxBC_, // B*L*2*G*N
|
||||
// const float *g_mxD_, // H
|
||||
half const* g_mxX_, // B*L*H*P
|
||||
// const half *g_mxZ_, // B*L*H*P
|
||||
bool removePadding_, int const* lastTokenIdsPtr_);
|
||||
|
||||
typedef void (*ChunkStateKernelFuncBf16)(int B_, int L_, int H_, int P_, int G_, int N_,
|
||||
// const bf16 *g_mxY_, // B*L*H*P
|
||||
// const bf16 *g_mxOs_, // B*C*H*N*P
|
||||
// const bf16 *g_mxFs_, // B *H*N*P
|
||||
float* g_mxSt_, // B*C*H*N*P
|
||||
float const* g_mxdc_, // B*C*H*Q
|
||||
float const* g_mxdA_, // B*C*H*Q
|
||||
// const bf16 *g_mxdt_, // B*L*H
|
||||
// const float *g_mxdb_, // H
|
||||
// const float *g_mxA_, // H
|
||||
// const bf16 *g_mxCB_, // B*C*G*Q*Q
|
||||
bf16 const* g_mxBC_, // B*L*2*G*N
|
||||
// const float *g_mxD_, // H
|
||||
bf16 const* g_mxX_, // B*L*H*P
|
||||
// const bf16 *g_mxZ_, // B*L*H*P
|
||||
bool removePadding_, int const* lastTokenIdsPtr_);
|
||||
|
||||
template <int Q_, int tileM_, int tileN_, int tileK_, // smem size, per sm
|
||||
int wmmaM_, int wmmaN_, int wmmaK_, // wmma size, per instruction
|
||||
int warpM_, int warpN_, // warp number
|
||||
int pipeS_, class Tp_>
|
||||
__global__ std::enable_if_t<std::is_same_v<Tp_, half> || std::is_same_v<Tp_, __nv_bfloat16>> chunk_state_kernel(int B_,
|
||||
int L_, int H_, int P_, int G_, int N_,
|
||||
// const Tp_ *g_mxY_, // B*L*H*P
|
||||
// const Tp_ *g_mxOs_, // B*C*H*N*P
|
||||
// const Tp_ *g_mxFs_, // B *H*N*P
|
||||
float* g_mxSt_, // B*C*H*N*P
|
||||
float const* g_mxdc_, // B*C*H*Q
|
||||
float const* g_mxdA_, // B*C*H*Q
|
||||
// const Tp_ *g_mxdt_, // B*L*H
|
||||
// const Wt_ *g_mxdb_, // H
|
||||
// const Wt_ *g_mxA_, // H
|
||||
// const Tp_ *g_mxCB_, // B*C*G*Q*Q
|
||||
Tp_ const* g_mxBC_, // B*L*2*G*N
|
||||
// const Wt_ *g_mxD_, // H
|
||||
Tp_ const* g_mxX_, // B*L*H*P
|
||||
// const Tp_ *g_mxZ_, // B*L*H*P
|
||||
bool removePadding_, int const* lastTokenIdsPtr_)
|
||||
{
|
||||
#if __CUDA_ARCH__ >= 800
|
||||
using namespace tensorrt_llm::common;
|
||||
|
||||
auto blockIdx_x = Rn<ID>{int(blockIdx.x)};
|
||||
auto blockIdx_y = Rn<ID>{int(blockIdx.y)};
|
||||
auto blockIdx_z = Rn<ID>{int(blockIdx.z)};
|
||||
|
||||
auto threadIdx_x = Rn<ID, 32>{int(threadIdx.x)};
|
||||
auto threadIdx_y = Rn<ID, warpN_>{int(threadIdx.y)};
|
||||
auto threadIdx_z = Rn<ID, warpM_>{int(threadIdx.z)};
|
||||
|
||||
// auto B = Rn<ID>{B_};
|
||||
auto L = Rn<ID>{L_};
|
||||
auto H = Rn<ID>{H_};
|
||||
auto P = Rn<ID>{P_};
|
||||
auto G = Rn<ID>{G_};
|
||||
auto N = Rn<ID>{N_};
|
||||
auto Q = cn<Q_>;
|
||||
auto C = Rn<ID>{div_up(L.var, Q_)};
|
||||
|
||||
auto aStart = blockIdx_z * L;
|
||||
auto cStart = blockIdx_z * C;
|
||||
|
||||
if (removePadding_)
|
||||
{
|
||||
aStart = Rn<ID>{int(blockIdx.z ? lastTokenIdsPtr_[blockIdx.z - 1] : 0)};
|
||||
cStart = Rn<ID>{int(blockIdx.z ? div_up(aStart.var, Q_) + blockIdx.z - 1 : 0)};
|
||||
L = Rn<ID>{lastTokenIdsPtr_[blockIdx.z] - aStart.var};
|
||||
C = Rn<ID>{div_up(L.var, Q_)};
|
||||
}
|
||||
else
|
||||
{
|
||||
L = Rn<ID>{lastTokenIdsPtr_[blockIdx.z]};
|
||||
C = Rn<ID>{div_up(L.var, Q_)};
|
||||
}
|
||||
|
||||
if (blockIdx_y * Q >= L)
|
||||
return;
|
||||
|
||||
auto hStart = Rn<ID>{blockIdx_x.var / (P_ / cn<tileN_>) / (N_ / cn<tileM_>) };
|
||||
auto mStart = Rn<ID>{blockIdx_x.var / (P_ / cn<tileN_>) % (N_ / cn<tileM_>) };
|
||||
auto nStart = Rn<ID>{blockIdx_x.var % (P_ / cn<tileN_>) };
|
||||
auto gStart = Rn<ID>{hStart.var / (H_ / G_)};
|
||||
|
||||
extern __shared__ float smem[];
|
||||
|
||||
Tp_* s_mxB = (Tp_*) smem;
|
||||
Tp_* s_mxX = (Tp_*) smem + tileM_ * tileK_ * pipeS_;
|
||||
|
||||
float* s_mxdc = smem + (tileM_ + tileN_) * tileK_ * pipeS_ / 2;
|
||||
float* s_mxdA = smem + (tileM_ + tileN_) * tileK_ * pipeS_ / 2 + Q_;
|
||||
|
||||
unsigned b_base = __nvvm_get_smem_pointer(smem);
|
||||
|
||||
unsigned b_mxB = b_base;
|
||||
unsigned b_mxX = b_base + tileM_ * tileK_ * pipeS_ * sizeof(Tp_);
|
||||
|
||||
using std::array;
|
||||
|
||||
register array<array<array<float, wmmaM_ * wmmaN_ / 32>, tileN_ / wmmaN_ / warpN_>, tileM_ / wmmaM_ / warpM_> r_mxSt
|
||||
= array<array<array<float, wmmaM_ * wmmaN_ / 32>, tileN_ / wmmaN_ / warpN_>, tileM_ / wmmaM_ / warpM_>();
|
||||
register array<array<unsigned, wmmaM_ * wmmaK_ / 64>, tileM_ / wmmaM_ / warpM_> r_mxB;
|
||||
register array<array<unsigned, wmmaK_ * wmmaN_ / 64>, tileN_ / wmmaN_ / warpN_> r_mxX;
|
||||
|
||||
constexpr int step = std::max(
|
||||
1, tileM_ / wmmaM_ / warpM_ * tileN_ / wmmaN_ / warpN_ / (tileM_ / wmmaM_ / warpM_ + tileN_ / wmmaN_ / warpN_));
|
||||
|
||||
auto baseB = [](auto iK) { return iK % cn<pipeS_> * cn<tileM_> * cn<tileK_>; };
|
||||
auto baseX = [](auto iK) { return iK % cn<pipeS_> * cn<tileN_> * cn<tileK_>; };
|
||||
|
||||
auto thread = [=](auto iStep)
|
||||
{
|
||||
return iStep * cn<warpM_ * warpN_ * 256> + threadIdx_z * cn<warpN_ * 256> + threadIdx_y * cn<256>
|
||||
+ threadIdx_x * cn<8>;
|
||||
};
|
||||
|
||||
#pragma unroll
|
||||
for (Rn<UNROLL, div_up(Q_, warpM_ * warpN_ * 256)> iStep; iStep.var < iStep.size; iStep.var++)
|
||||
if (thread(iStep) < cn<Q_>)
|
||||
{
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 8; i += 4)
|
||||
{
|
||||
*(int4*) (s_mxdc + get(thread(iStep)) + i)
|
||||
= *(int4*) (g_mxdc_ + get((cStart + blockIdx_y) * H * Q + hStart * Q + thread(iStep)) + i);
|
||||
*(int4*) (s_mxdA + get(thread(iStep)) + i)
|
||||
= *(int4*) (g_mxdA_ + get((cStart + blockIdx_y) * H * Q + hStart * Q + thread(iStep)) + i);
|
||||
}
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (Rn<UNROLL, pipeS_> iK; iK.var < iK.size; iK.var++)
|
||||
{
|
||||
#pragma unroll
|
||||
for (Rn<UNROLL, div_up(tileM_ * tileK_, warpM_ * warpN_ * 256)> iStep; iStep.var < iStep.size; iStep.var++)
|
||||
if (thread(iStep) < cn<tileM_ * tileK_>
|
||||
&& thread(iStep) / cn<tileM_> < L - blockIdx_y * Q - iK * cn<tileK_>)
|
||||
cp_shared_global<16>(b_mxB + swizzle<tileM_ * 2, tileM_ * 2>(thread(iStep) * cn<2>, baseB(iK) * cn<2>),
|
||||
g_mxBC_
|
||||
+ get((aStart + blockIdx_y * Q + iK * cn<tileK_> + thread(iStep) / cn<tileM_>) *cn<2> * G * N
|
||||
+ gStart * N + mStart * cn<tileM_> + thread(iStep) % cn<tileM_>));
|
||||
else if (thread(iStep) < cn<tileM_ * tileK_>)
|
||||
*(int4*) ((char*) s_mxB + swizzle<tileM_ * 2, tileM_ * 2>(thread(iStep) * cn<2>, baseB(iK) * cn<2>))
|
||||
= int4{0, 0, 0, 0};
|
||||
|
||||
#pragma unroll
|
||||
for (Rn<UNROLL, div_up(tileN_ * tileK_, warpM_ * warpN_ * 256)> iStep; iStep.var < iStep.size; iStep.var++)
|
||||
if (thread(iStep) < cn<tileN_ * tileK_>
|
||||
&& thread(iStep) / cn<tileN_> < L - blockIdx_y * Q - iK * cn<tileK_>)
|
||||
cp_shared_global<16>(b_mxX + swizzle<tileN_ * 2, tileN_ * 2>(thread(iStep) * cn<2>, baseX(iK) * cn<2>),
|
||||
g_mxX_
|
||||
+ get((aStart + blockIdx_y * Q + iK * cn<tileK_> + thread(iStep) / cn<tileN_>) *H * P
|
||||
+ hStart * P + nStart * cn<tileN_> + thread(iStep) % cn<tileN_>));
|
||||
else if (thread(iStep) < cn<tileN_ * tileK_>)
|
||||
*(int4*) ((char*) s_mxX + swizzle<tileN_ * 2, tileN_ * 2>(thread(iStep) * cn<2>, baseX(iK) * cn<2>))
|
||||
= int4{0, 0, 0, 0};
|
||||
|
||||
cp_commit_group();
|
||||
}
|
||||
|
||||
asm volatile("cp.async.wait_group %0;\n" ::"n"(pipeS_ - 1));
|
||||
|
||||
__syncthreads();
|
||||
|
||||
for (int iK = pipeS_; iK < Q_ / tileK_ + pipeS_; iK++)
|
||||
{
|
||||
auto jK = Rn<>{iK};
|
||||
#pragma unroll
|
||||
for (Rn<UNROLL, div_up(tileM_ * tileK_, warpM_ * warpN_ * 256)> iStep; iStep.var < iStep.size; iStep.var++)
|
||||
if (thread(iStep) < cn<tileM_ * tileK_>)
|
||||
{
|
||||
register Tp_ tmpB[8];
|
||||
|
||||
*(int4*) &tmpB[0] = *(
|
||||
int4*) ((char*) s_mxB + swizzle<tileM_ * 2, tileM_ * 2>(thread(iStep) * cn<2>, baseB(jK) * cn<2>));
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 8; i += 2)
|
||||
{
|
||||
float2 tmp2 = std::is_same_v<Tp_, half> ? __half22float2(*(half2*) &tmpB[i])
|
||||
: bf1622float2(*(bf162*) &tmpB[i]);
|
||||
|
||||
int kStart = (iK - pipeS_) * cn<tileK_>;
|
||||
|
||||
tmp2.x *= expf(s_mxdA[Q_ - 1] - s_mxdA[kStart + get(thread(iStep) / cn<tileM_>)])
|
||||
* s_mxdc[kStart + get(thread(iStep) / cn<tileM_>)];
|
||||
tmp2.y *= expf(s_mxdA[Q_ - 1] - s_mxdA[kStart + get(thread(iStep) / cn<tileM_>)])
|
||||
* s_mxdc[kStart + get(thread(iStep) / cn<tileM_>)];
|
||||
|
||||
if (std::is_same_v<Tp_, half>)
|
||||
*(half2*) &tmpB[i] = __float22half2_rn(tmp2);
|
||||
else
|
||||
*(bf162*) &tmpB[i] = __float22bfloat162_rn(tmp2);
|
||||
}
|
||||
|
||||
*(int4*) ((char*) s_mxB + swizzle<tileM_ * 2, tileM_ * 2>(thread(iStep) * cn<2>, baseB(jK) * cn<2>))
|
||||
= *(int4*) &tmpB[0];
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
#pragma unroll
|
||||
for (int k = 0; k < tileK_ / wmmaK_; k++)
|
||||
{
|
||||
#pragma unroll
|
||||
for (int y = 0; y < tileM_ / wmmaM_ / warpM_; y++)
|
||||
#pragma unroll
|
||||
for (int x = 0; x < tileN_ / wmmaN_ / warpN_; x++)
|
||||
{
|
||||
if ((y * tileN_ / wmmaN_ / warpN_ + x) % step == 0)
|
||||
{
|
||||
int x1 = (y * tileN_ / wmmaN_ / warpN_ + x) / step;
|
||||
int y1 = x1 - tileN_ / wmmaN_ / warpN_
|
||||
+ (tileM_ / wmmaM_ / warpM_ == 1 || tileN_ / wmmaN_ / warpN_ == 1);
|
||||
|
||||
if (y1 >= 0 && y1 < tileM_ / wmmaM_ / warpM_)
|
||||
{
|
||||
if (wmmaK_ == 16)
|
||||
asm volatile(
|
||||
"ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {%0, %1, %2, %3}, [%4];\n"
|
||||
: "=r"(r_mxB[y1][0]), "=r"(r_mxB[y1][1]), "=r"(r_mxB[y1][2]), "=r"(r_mxB[y1][3])
|
||||
: "r"(b_mxB + iK % pipeS_ * (tileM_ * tileK_ * 2)
|
||||
+ 2
|
||||
* swz<tileM_ * 2, tileM_>(y1 * warpM_ * wmmaM_ + k * wmmaK_ * tileM_
|
||||
+ threadIdx.z * wmmaM_ + threadIdx.x % 8 * tileM_
|
||||
+ threadIdx.x / 8 % 2 * 8 + threadIdx.x / wmmaK_ * 8 * tileM_)));
|
||||
}
|
||||
|
||||
if (x1 >= 0 && x1 < tileN_ / wmmaN_ / warpN_)
|
||||
{
|
||||
if (wmmaK_ == 16 && x1 % 2 == 0)
|
||||
asm volatile(
|
||||
"ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {%0, %1, %2, %3}, [%4];\n"
|
||||
: "=r"(r_mxX[x1][0]), "=r"(r_mxX[x1][1]), "=r"(r_mxX[x1 + 1][0]),
|
||||
"=r"(r_mxX[x1 + 1][1])
|
||||
: "r"(b_mxX + iK % pipeS_ * (tileK_ * tileN_ * 2)
|
||||
+ 2
|
||||
* swz<tileN_ * 2, tileN_>(x1 * warpN_ * wmmaN_ + k * wmmaK_ * tileN_
|
||||
+ threadIdx.y * wmmaN_ + threadIdx.x % wmmaK_ * tileN_
|
||||
+ threadIdx.x / wmmaK_ * warpN_ * wmmaN_)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int y = 0; y < tileM_ / wmmaM_ / warpM_; y++)
|
||||
#pragma unroll
|
||||
for (int x = 0; x < tileN_ / wmmaN_ / warpN_; x++)
|
||||
{
|
||||
if (wmmaK_ == 16)
|
||||
{
|
||||
if (std::is_same_v<Tp_, half>)
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 \n"
|
||||
" {%0, %1, %2, %3}, \n"
|
||||
" {%4, %5, %6, %7}, \n"
|
||||
" {%8, %9}, \n"
|
||||
" {%0, %1, %2, %3}; \n"
|
||||
: "+f"(r_mxSt[y][x][0]), "+f"(r_mxSt[y][x][1]), "+f"(r_mxSt[y][x][2]),
|
||||
"+f"(r_mxSt[y][x][3])
|
||||
: "r"(r_mxB[y][0]), "r"(r_mxB[y][1]), "r"(r_mxB[y][2]), "r"(r_mxB[y][3]),
|
||||
"r"(r_mxX[x][0]), "r"(r_mxX[x][1]));
|
||||
else
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 \n"
|
||||
" {%0, %1, %2, %3}, \n"
|
||||
" {%4, %5, %6, %7}, \n"
|
||||
" {%8, %9}, \n"
|
||||
" {%0, %1, %2, %3}; \n"
|
||||
: "+f"(r_mxSt[y][x][0]), "+f"(r_mxSt[y][x][1]), "+f"(r_mxSt[y][x][2]),
|
||||
"+f"(r_mxSt[y][x][3])
|
||||
: "r"(r_mxB[y][0]), "r"(r_mxB[y][1]), "r"(r_mxB[y][2]), "r"(r_mxB[y][3]),
|
||||
"r"(r_mxX[x][0]), "r"(r_mxX[x][1]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
#pragma unroll
|
||||
for (Rn<UNROLL, div_up(tileM_ * tileK_, warpM_ * warpN_ * 256)> iStep; iStep.var < iStep.size; iStep.var++)
|
||||
if (thread(iStep) < cn<tileM_ * tileK_> && thread(iStep) / cn<tileM_> < L - blockIdx_y * Q - jK * cn<tileK_>
|
||||
&& jK * cn<tileK_> < Q)
|
||||
cp_shared_global<16>(b_mxB + swizzle<tileM_ * 2, tileM_ * 2>(thread(iStep) * cn<2>, baseB(jK) * cn<2>),
|
||||
g_mxBC_
|
||||
+ get((aStart + blockIdx_y * Q + jK * cn<tileK_> + thread(iStep) / cn<tileM_>) *cn<2> * G * N
|
||||
+ gStart * N + mStart * cn<tileM_> + thread(iStep) % cn<tileM_>));
|
||||
else if (thread(iStep) < cn<tileM_ * tileK_> && jK * cn<tileK_> < Q)
|
||||
*(int4*) ((char*) s_mxB + swizzle<tileM_ * 2, tileM_ * 2>(thread(iStep) * cn<2>, baseB(jK) * cn<2>))
|
||||
= int4{0, 0, 0, 0};
|
||||
|
||||
#pragma unroll
|
||||
for (Rn<UNROLL, div_up(tileN_ * tileK_, warpM_ * warpN_ * 256)> iStep; iStep.var < iStep.size; iStep.var++)
|
||||
if (thread(iStep) < cn<tileN_ * tileK_> && thread(iStep) / cn<tileN_> < L - blockIdx_y * Q - jK * cn<tileK_>
|
||||
&& jK * cn<tileK_> < Q)
|
||||
cp_shared_global<16>(b_mxX + swizzle<tileN_ * 2, tileN_ * 2>(thread(iStep) * cn<2>, baseX(jK) * cn<2>),
|
||||
g_mxX_
|
||||
+ get((aStart + blockIdx_y * Q + jK * cn<tileK_> + thread(iStep) / cn<tileN_>) *H * P
|
||||
+ hStart * P + nStart * cn<tileN_> + thread(iStep) % cn<tileN_>));
|
||||
else if (thread(iStep) < cn<tileN_ * tileK_> && jK * cn<tileK_> < Q)
|
||||
*(int4*) ((char*) s_mxX + swizzle<tileN_ * 2, tileN_ * 2>(thread(iStep) * cn<2>, baseX(jK) * cn<2>))
|
||||
= int4{0, 0, 0, 0};
|
||||
|
||||
asm volatile("cp.async.commit_group;\n" ::);
|
||||
|
||||
asm volatile("cp.async.wait_group %0;\n" ::"n"(pipeS_ - 1));
|
||||
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int y = 0; y < tileM_ / wmmaM_ / warpM_; y++)
|
||||
#pragma unroll
|
||||
for (int x = 0; x < tileN_ / wmmaN_ / warpN_; x++)
|
||||
{
|
||||
*(float2*) (g_mxSt_
|
||||
+ get((cStart + blockIdx_y) * H * N * P + hStart * N * P
|
||||
+ (mStart * cn<tileM_> + Rn<UNROLL>{y} * cn<warpM_ * wmmaM_> + threadIdx_z * cn<wmmaM_>
|
||||
+ threadIdx_x / cn<4>) *P
|
||||
+ nStart * cn<tileN_> + Rn<UNROLL>{x} * cn<warpN_ * wmmaN_> + threadIdx_y * cn<wmmaN_>
|
||||
+ threadIdx_x % cn<4> * cn<2>))
|
||||
= *(float2*) &r_mxSt[y][x][0];
|
||||
|
||||
*(float2*) (g_mxSt_
|
||||
+ get((cStart + blockIdx_y) * H * N * P + hStart * N * P
|
||||
+ (mStart * cn<tileM_> + Rn<UNROLL>{y} * cn<warpM_ * wmmaM_> + cn<8> + threadIdx_z * cn<wmmaM_>
|
||||
+ threadIdx_x / cn<4>) *P
|
||||
+ nStart * cn<tileN_> + Rn<UNROLL>{x} * cn<warpN_ * wmmaN_> + threadIdx_y * cn<wmmaN_>
|
||||
+ threadIdx_x % cn<4> * cn<2>))
|
||||
= *(float2*) &r_mxSt[y][x][2];
|
||||
}
|
||||
|
||||
asm volatile("cp.async.wait_group %0;\n" ::"n"(0));
|
||||
#endif
|
||||
}
|
||||
|
||||
ChunkStateKernelFuncFp16 getChunkStateKernelFp16(
|
||||
int B_, int L_, int H_, int P_, int G_, int N_, int Q_, dim3* blockDims_, dim3* threadDims_, int* sharedMem_)
|
||||
{
|
||||
int B = B_;
|
||||
int L = L_;
|
||||
int H = H_;
|
||||
int P = P_;
|
||||
// int G = G_;
|
||||
int N = N_;
|
||||
int Q = Q_;
|
||||
int C = div_up(L, Q);
|
||||
|
||||
int tileM = 64;
|
||||
int tileN = 64;
|
||||
int tileK = 32;
|
||||
int warpM = 1;
|
||||
int warpN = 2;
|
||||
int pipeS = 3;
|
||||
|
||||
auto sharedMem = (tileM * tileK + tileK * tileN) * pipeS * 2 + Q * 8;
|
||||
|
||||
*blockDims_ = dim3(H * P / tileN * N / tileM, C, B);
|
||||
*threadDims_ = dim3(32, warpN, warpM);
|
||||
*sharedMem_ = sharedMem;
|
||||
|
||||
if (Q_ == 128)
|
||||
return chunk_state_kernel<128, 64, 64, 32, 16, 8, 16, 1, 2, 3, half>;
|
||||
else if (Q_ == 256)
|
||||
return chunk_state_kernel<256, 64, 64, 32, 16, 8, 16, 1, 2, 3, half>;
|
||||
else
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
ChunkStateKernelFuncBf16 getChunkStateKernelBf16(
|
||||
int B_, int L_, int H_, int P_, int G_, int N_, int Q_, dim3* blockDims_, dim3* threadDims_, int* sharedMem_)
|
||||
{
|
||||
int B = B_;
|
||||
int L = L_;
|
||||
int H = H_;
|
||||
int P = P_;
|
||||
// int G = G_;
|
||||
int N = N_;
|
||||
int Q = Q_;
|
||||
int C = div_up(L, Q);
|
||||
|
||||
int tileM = 64;
|
||||
int tileN = 64;
|
||||
int tileK = 32;
|
||||
int warpM = 1;
|
||||
int warpN = 2;
|
||||
int pipeS = 3;
|
||||
|
||||
auto sharedMem = (tileM * tileK + tileK * tileN) * pipeS * 2 + Q * 8;
|
||||
|
||||
*blockDims_ = dim3(H * P / tileN * N / tileM, C, B);
|
||||
*threadDims_ = dim3(32, warpN, warpM);
|
||||
*sharedMem_ = sharedMem;
|
||||
|
||||
if (Q_ == 128)
|
||||
return chunk_state_kernel<128, 64, 64, 32, 16, 8, 16, 1, 2, 3, bf16>;
|
||||
else if (Q_ == 256)
|
||||
return chunk_state_kernel<256, 64, 64, 32, 16, 8, 16, 1, 2, 3, bf16>;
|
||||
else
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
} // namespace kernels
|
||||
} // namespace tensorrt_llm
|
||||
|
||||
// vim: ts=2 sw=2 sts=2 et sta
|
||||
241 cpp/tensorrt_llm/kernels/chunkScan/statepassing.h Normal file
@ -0,0 +1,241 @@
|
||||
/*
|
||||
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cuda_fp8.h>
|
||||
#include <mma.h>
|
||||
|
||||
#include "tensorrt_llm/common/cudaBf16Fallbacks.cuh"
|
||||
|
||||
#include "Common.h"
|
||||
#include "Poly.h"
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
namespace kernels
|
||||
{
|
||||
|
||||
typedef void (*StatePassingKernelFuncFp16)(int B_, int L_, int H_, int P_, int N_,
|
||||
// const half *g_mxY_, // B*L*H*P
|
||||
half* g_mxOs_, // B*C*H*N*P
|
||||
half* g_mxFs_, // B *H*N*P
|
||||
float const* g_mxSt_, // B*C*H*N*P
|
||||
// const float *g_mxdc_, // B*C*H*Q
|
||||
float const* g_mxdA_, // B*C*H*Q
|
||||
// const half *g_mxdt_, // B*L*H
|
||||
// const float *g_mxdb_, // H
|
||||
// const float *g_mxA_, // H
|
||||
// const half *g_mxCB_, // B*C*G*Q*Q
|
||||
// const half *g_mxBC_, // B*L*2*G*N
|
||||
// const float *g_mxD_, // H
|
||||
// const half *g_mxX_, // B*L*H*P
|
||||
// const half *g_mxZ_, // B*L*H*P
|
||||
bool removePadding_, int const* lastTokenIdsPtr_, int const* stateSlotMappingPtr_);
|
||||
|
||||
typedef void (*StatePassingKernelFuncBf16)(int B_, int L_, int H_, int P_, int N_,
|
||||
// const bf16 *g_mxY_, // B*L*H*P
|
||||
bf16* g_mxOs_, // B*C*H*N*P
|
||||
bf16* g_mxFs_, // B *H*N*P
|
||||
float const* g_mxSt_, // B*C*H*N*P
|
||||
// const float *g_mxdc_, // B*C*H*Q
|
||||
float const* g_mxdA_, // B*C*H*Q
|
||||
// const bf16 *g_mxdt_, // B*L*H
|
||||
// const float *g_mxdb_, // H
|
||||
// const float *g_mxA_, // H
|
||||
// const bf16 *g_mxCB_, // B*C*G*Q*Q
|
||||
// const bf16 *g_mxBC_, // B*L*2*G*N
|
||||
// const float *g_mxD_, // H
|
||||
// const bf16 *g_mxX_, // B*L*H*P
|
||||
// const bf16 *g_mxZ_, // B*L*H*P
|
||||
bool removePadding_, int const* lastTokenIdsPtr_, int const* stateSlotMappingPtr_);
|
||||
|
||||
template <int Q_, int tileH_, int warpH_, class Tp_>
|
||||
__global__ std::enable_if_t<std::is_same_v<Tp_, half> || std::is_same_v<Tp_, __nv_bfloat16>> state_passing_kernel(
|
||||
int B_, int L_, int H_, int P_, int N_,
|
||||
// const Tp_ *g_mxY_, // B*L*H*P
|
||||
Tp_* g_mxOs_, // B*C*H*N*P
|
||||
Tp_* g_mxFs_, // B *H*N*P
|
||||
float const* g_mxSt_, // B*C*H*N*P
|
||||
// const float *g_mxdc_, // B*C*H*Q
|
||||
float const* g_mxdA_, // B*C*H*Q
|
||||
// const Tp_ *g_mxdt_, // B*L*H
|
||||
// const Wt_ *g_mxdb_, // H
|
||||
// const Wt_ *g_mxA_, // H
|
||||
// const Tp_ *g_mxCB_, // B*C*G*Q*Q
|
||||
// const Tp_ *g_mxBC_, // B*L*2*G*N
|
||||
// const Wt_ *g_mxD_, // H
|
||||
// const Tp_ *g_mxX_, // B*L*H*P
|
||||
// const Tp_ *g_mxZ_, // B*L*H*P
|
||||
bool removePadding_, int const* lastTokenIdsPtr_, int const* stateSlotMappingPtr_)
|
||||
{
|
||||
using namespace tensorrt_llm::common;
|
||||
|
||||
auto blockIdx_x = Rn<ID>{int(blockIdx.x)};
|
||||
auto blockIdx_y = Rn<ID>{int(blockIdx.y)};
|
||||
auto blockIdx_z = Rn<ID>{int(blockIdx.z)};
|
||||
|
||||
auto threadIdx_x = Rn<ID, 32>{int(threadIdx.x)};
|
||||
auto threadIdx_y = Rn<ID, warpH_>{int(threadIdx.y)};
|
||||
|
||||
// auto B = Rn<ID>{B_};
|
||||
auto L = Rn<ID>{L_};
|
||||
auto H = Rn<ID>{H_};
|
||||
auto P = Rn<ID>{P_};
|
||||
// auto G = Rn<ID>{G_};
|
||||
auto N = Rn<ID>{N_};
|
||||
auto Q = cn<Q_>;
|
||||
auto C = Rn<ID>{div_up(L.var, Q_)};
|
||||
|
||||
auto aStart = blockIdx_z * L;
|
||||
auto cStart = blockIdx_z * C;
|
||||
|
||||
if (removePadding_)
|
||||
{
|
||||
aStart = Rn<ID>{int(blockIdx.z ? lastTokenIdsPtr_[blockIdx.z - 1] : 0)};
|
||||
cStart = Rn<ID>{int(blockIdx.z ? div_up(aStart.var, Q_) + blockIdx.z - 1 : 0)};
|
||||
L = Rn<ID>{lastTokenIdsPtr_[blockIdx.z] - aStart.var};
|
||||
C = Rn<ID>{div_up(L.var, Q_)};
|
||||
}
|
||||
else
|
||||
{
|
||||
L = Rn<ID>{lastTokenIdsPtr_[blockIdx.z]};
|
||||
C = Rn<ID>{div_up(L.var, Q_)};
|
||||
}
|
||||
|
||||
if (stateSlotMappingPtr_)
|
||||
{
|
||||
g_mxFs_ += stateSlotMappingPtr_[blockIdx.z] * H_ * N_ * P_;
|
||||
}
|
||||
else
|
||||
{
|
||||
g_mxFs_ += blockIdx.z * H_ * N_ * P_;
|
||||
}
|
||||
|
||||
auto hStart = Rn<ID>{blockIdx_x.var * tileH_ / N_ / P_};
|
||||
|
||||
register Tp_ r_mxOs[tileH_ / (warpH_ * 32)] = {0};
|
||||
register float r_mxSt[tileH_ / (warpH_ * 32)] = {0};
|
||||
|
||||
for (int iC = 0; iC < C.var; iC++)
|
||||
{
|
||||
if (std::is_same_v<Tp_, half>)
|
||||
#pragma unroll
|
||||
for (int i = 0; i < tileH_ / (warpH_ * 32); i += 2)
|
||||
*(half2*) &r_mxOs[i] = __float22half2_rn(*(float2*) &r_mxSt[i]);
|
||||
else
|
||||
#pragma unroll
|
||||
for (int i = 0; i < tileH_ / (warpH_ * 32); i += 2)
|
||||
*(bf162*) &r_mxOs[i] = __float22bfloat162_rn(*(float2*) &r_mxSt[i]);
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < tileH_ / (warpH_ * 32); i += 2)
|
||||
*(int*) (g_mxOs_
|
||||
+ get((cStart + Rn<>{iC}) * H * N * P + blockIdx_x * cn<tileH_>
|
||||
+ (threadIdx_y * cn<32> + threadIdx_x) * cn<tileH_ / (warpH_ * 32)> + Rn<UNROLL>{i}))
|
||||
= *(int*) &r_mxOs[i];
|
||||
|
||||
float scale = expf(g_mxdA_[get((cStart + Rn<>{iC}) * H * Q + hStart * Q + Q - cn<1>)]);
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < tileH_ / (warpH_ * 32); i++)
|
||||
{
|
||||
float tmp = g_mxSt_[get((cStart + Rn<>{iC}) * H * N * P + blockIdx_x * cn<tileH_>
|
||||
+ (threadIdx_y * cn<32> + threadIdx_x) * cn<tileH_ / (warpH_ * 32)> + Rn<UNROLL>{i})];
|
||||
|
||||
r_mxSt[i] = scale * r_mxSt[i] + tmp;
|
||||
}
|
||||
}

    if (std::is_same_v<Tp_, half>)
#pragma unroll
        for (int i = 0; i < tileH_ / (warpH_ * 32); i += 2)
            *(half2*) &r_mxOs[i] = __float22half2_rn(*(float2*) &r_mxSt[i]);
    else
#pragma unroll
        for (int i = 0; i < tileH_ / (warpH_ * 32); i += 2)
            *(bf162*) &r_mxOs[i] = __float22bfloat162_rn(*(float2*) &r_mxSt[i]);

#pragma unroll
    for (int i = 0; i < tileH_ / (warpH_ * 32); i += 8)
        *(int4*) (g_mxFs_
            + get(blockIdx_x * cn<tileH_> + (threadIdx_y * cn<32> + threadIdx_x) * cn<tileH_ / (warpH_ * 32)>
                + Rn<UNROLL>{i}))
            = *(int4*) &r_mxOs[i];
}
|
||||
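Read as math, the chunk loop above implements a first-order recurrence over chunks (a sketch of the intended semantics, inferred from the code rather than stated anywhere in the change): with S_0 = 0,

\[ \mathrm{Os}_c = S_c, \qquad S_{c+1} = \exp\!\big(dA_c[Q-1]\big)\, S_c + \mathrm{St}_c, \qquad \mathrm{Fs} = S_C , \]

where St_c is the partial state produced for chunk c by the chunk_state kernel and dA_c[Q-1] is the cumulative decay at the last token of the chunk, so every chunk is handed the state accumulated over all preceding chunks before the final state is written out.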
|
||||
StatePassingKernelFuncFp16 getStatePassingKernelFp16(
|
||||
int B_, int L_, int H_, int P_, int N_, int Q_, dim3* blockDims_, dim3* threadDims_, int* sharedMem_)
|
||||
{
|
||||
int B = B_;
|
||||
int L = L_;
|
||||
int H = H_;
|
||||
int P = P_;
|
||||
// int G = G_;
|
||||
int N = N_;
|
||||
int Q = Q_;
|
||||
int C = div_up(L, Q);
|
||||
|
||||
int tileH = 1024;
|
||||
int warpH = 8;
|
||||
|
||||
auto sharedMem = 0;
|
||||
|
||||
*blockDims_ = dim3(H * N * P / tileH, 1, B);
|
||||
*threadDims_ = dim3(32, warpH);
|
||||
*sharedMem_ = sharedMem;
|
||||
|
||||
if (Q_ == 128)
|
||||
return state_passing_kernel<128, 1024, 8, half>;
|
||||
else if (Q_ == 256)
|
||||
return state_passing_kernel<256, 1024, 8, half>;
|
||||
else
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
StatePassingKernelFuncBf16 getStatePassingKernelBf16(
|
||||
int B_, int L_, int H_, int P_, int N_, int Q_, dim3* blockDims_, dim3* threadDims_, int* sharedMem_)
|
||||
{
|
||||
int B = B_;
|
||||
int L = L_;
|
||||
int H = H_;
|
||||
int P = P_;
|
||||
// int G = G_;
|
||||
int N = N_;
|
||||
int Q = Q_;
|
||||
int C = div_up(L, Q);
|
||||
|
||||
int tileH = 1024;
|
||||
int warpH = 8;
|
||||
|
||||
auto sharedMem = 0;
|
||||
|
||||
*blockDims_ = dim3(H * N * P / tileH, 1, B);
|
||||
*threadDims_ = dim3(32, warpH);
|
||||
*sharedMem_ = sharedMem;
|
||||
|
||||
if (Q_ == 128)
|
||||
return state_passing_kernel<128, 1024, 8, bf16>;
|
||||
else if (Q_ == 256)
|
||||
return state_passing_kernel<256, 1024, 8, bf16>;
|
||||
else
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
} // namespace kernels
|
||||
} // namespace tensorrt_llm
|
||||
|
||||
// vim: ts=2 sw=2 sts=2 et sta
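A rough sketch of how these selector functions are meant to be driven from the host side (a hypothetical call site: B, L, H, P, N, the mx* device pointers, removePadding, lastTokenIdsPtr, stateSlotMappingPtr and stream are assumed to be set up by the caller, as the invokeChunkScan change further below does):

    dim3 blockDims, threadDims;
    int sharedMem = 0;
    StatePassingKernelFuncFp16 statePassing
        = getStatePassingKernelFp16(B, L, H, P, N, /*Q_=*/128, &blockDims, &threadDims, &sharedMem);
    // Only Q_ == 128 and Q_ == 256 are instantiated; anything else returns nullptr.
    TLLM_CHECK_WITH_INFO(statePassing != nullptr, "unsupported chunk size");
    cudaFuncSetAttribute(statePassing, cudaFuncAttributeMaxDynamicSharedMemorySize, sharedMem);
    statePassing<<<blockDims, threadDims, sharedMem, stream>>>(
        B, L, H, P, N, mxOs, mxFs, mxSt, mxdA, removePadding, lastTokenIdsPtr, stateSlotMappingPtr);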

@ -61,9 +61,19 @@ bool supportConfigCommon(XQAParams const& xqaParams, bool forConfigurePlugin)
    {
        return false;
    }
    if (xqaParams.num_kv_heads == 0 || xqaParams.num_q_heads == xqaParams.num_kv_heads)
    if (xqaParams.num_kv_heads != 0 && xqaParams.num_q_heads % xqaParams.num_kv_heads != 0)
    {
        // Do not use XQA kernel for MHA.
        return false;
    }
    bool is_vanilla_mha = xqaParams.num_kv_heads == 0 || xqaParams.num_q_heads == xqaParams.num_kv_heads;
    if (is_vanilla_mha && xqaParams.beam_width == 1)
    {
        // Do not use XQA kernel for vanilla MHA case for performance reasons.
        return false;
    }
    if (is_vanilla_mha && xqaParams.head_size <= 128)
    {
        // TODO(yaoy): remove this when the kernel bug for num_kv_heads <= 128 gets fixed.
        return false;
    }
    if (xqaParams.multi_block_mode)
|
||||
@ -108,11 +118,7 @@ bool supportConfigQGMMA(XQAParams const& xqaParams, int SM, bool forConfigurePlu
|
||||
{
|
||||
return false;
|
||||
}
|
||||
if (xqaParams.num_kv_heads == 0 || xqaParams.num_q_heads % xqaParams.num_kv_heads != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
int32_t head_grp_size = xqaParams.num_q_heads / xqaParams.num_kv_heads;
|
||||
int32_t head_grp_size = xqaParams.num_kv_heads == 0 ? 1 : xqaParams.num_q_heads / xqaParams.num_kv_heads;
|
||||
if (head_grp_size * xqaParams.beam_width > 32)
|
||||
{
|
||||
return false;
|
||||
@ -150,11 +156,7 @@ bool supportConfigHMMA(XQAParams const& xqaParams, int SM, bool forConfigurePlug
|
||||
{
|
||||
return false;
|
||||
}
|
||||
if (xqaParams.num_kv_heads == 0 || xqaParams.num_q_heads % xqaParams.num_kv_heads != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
int32_t head_grp_size = xqaParams.num_q_heads / xqaParams.num_kv_heads;
|
||||
int32_t head_grp_size = xqaParams.num_kv_heads == 0 ? 1 : xqaParams.num_q_heads / xqaParams.num_kv_heads;
|
||||
if (head_grp_size * xqaParams.beam_width > 32)
|
||||
{
|
||||
return false;
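The group-size check above is easiest to see with concrete numbers (a hypothetical GQA configuration, not taken from the change):

    int32_t num_q_heads = 32, num_kv_heads = 4;                                  // 8 query heads share each KV head
    int32_t head_grp_size = num_kv_heads == 0 ? 1 : num_q_heads / num_kv_heads;  // 8
    // beam_width up to 4 keeps head_grp_size * beam_width <= 32, so the QGMMA/HMMA paths stay eligible;
    // beam_width 8 gives 64 > 32 and supportConfigQGMMA/supportConfigHMMA reject the request.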
|
||||
|
||||
@ -1,2 +1,2 @@
|
||||
8b0f8deb35940359b39f876fc5e94e4f libtensorrt_llm_nvrtc_wrapper.so
|
||||
0e1417f27d93de67940c1062cf230017cd8be5f1 commit
|
||||
d5f5542d2f1e10c4a6b60be56838ac79a9668665 commit
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:53746a0351295accb650f9e509303914ae8d8dc3c2605baf680f30cfc40d96f6
|
||||
oid sha256:78209a1351f9f21f635bf9f763f4947031ea12b7526c5782094e9869b667a23f
|
||||
size 1091072
|
||||
|
||||
@ -482,12 +482,11 @@ __global__ void finalizeKernel(BeamHypotheses bh)

void invokeFinalize(BeamHypotheses& bh, cudaStream_t stream)
{
    TLLM_LOG_TRACE("%s %s start", __FILE__, __PRETTY_FUNCTION__);
    TLLM_LOG_DEBUG("%s %s start", __FILE__, __PRETTY_FUNCTION__);

    int const nBM = bh.nBeamWidth;
    size_t const smem_size = sizeof(int) * nBM * 2 + sizeof(float) * nBM * 2;
    finalizeKernel<<<bh.nBatchSize, roundUp(nBM * 2, 32), smem_size, stream>>>(bh);
    TLLM_LOG_TRACE("%s %s stop", __FILE__, __PRETTY_FUNCTION__);
}

__global__ void initializeOutput(TokenIdType* finalOutputIds, TokenIdType const* endIds, SizeType32 const nMaxSeqLen)
|
||||
|
||||
@ -28,6 +28,12 @@
|
||||
|
||||
#include "selectiveScan.h"
|
||||
|
||||
#include "chunkScan/bmmchunk.h"
|
||||
#include "chunkScan/chunkcumsum.h"
|
||||
#include "chunkScan/chunkscan.h"
|
||||
#include "chunkScan/chunkstate.h"
|
||||
#include "chunkScan/statepassing.h"
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
namespace kernels
|
||||
@ -319,8 +325,6 @@ void invokeSelectiveScan(SSMParamsBase& params, cudaStream_t stream)
|
||||
int samples = params.batch;
|
||||
int channels = params.dim;
|
||||
|
||||
TLLM_CHECK(params.is_variable_B);
|
||||
TLLM_CHECK(params.is_variable_C);
|
||||
TLLM_CHECK(params.dstate == 16);
|
||||
|
||||
int const threads = 128;
|
||||
@ -331,6 +335,107 @@ void invokeSelectiveScan(SSMParamsBase& params, cudaStream_t stream)
|
||||
selective_scan_loop_kernel<input_t, weight_t><<<grid, block, 0, stream>>>(params);
|
||||
}
|
||||
|
||||
template <typename input_t, typename weight_t>
|
||||
void invokeChunkScan(SSMParamsBase& params, cudaStream_t stream)
|
||||
{
|
||||
int B = params.batch;
|
||||
int L = params.max_seqlen;
|
||||
int H = params.nheads;
|
||||
int P = params.dim / H;
|
||||
int G = params.ngroups;
|
||||
int N = params.dstate;
|
||||
int Q = params.chunk_size;
|
||||
|
||||
bool dtsp = params.delta_softplus;
|
||||
|
||||
if constexpr (std::is_same_v<input_t, half>)
|
||||
{
|
||||
dim3 bds[5], tds[5];
|
||||
int shms[5];
|
||||
|
||||
ChunkCumsumKernelFuncFp16 chunk_cumsum = getChunkCumsumKernelFp16(B, L, H, Q, dtsp, &bds[0], &tds[0], &shms[0]);
|
||||
ChunkStateKernelFuncFp16 chunk_state = getChunkStateKernelFp16(B, L, H, P, G, N, Q, &bds[1], &tds[1], &shms[1]);
|
||||
StatePassingKernelFuncFp16 state_passing
|
||||
= getStatePassingKernelFp16(B, L, H, P, N, Q, &bds[2], &tds[2], &shms[2]);
|
||||
BmmChunkKernelFuncFp16 bmm_chunk = getBmmChunkKernelFp16(B, L, G, N, Q, &bds[3], &tds[3], &shms[3]);
|
||||
ChunkScanKernelFuncFp16 chunk_scan = getChunkScanKernelFp16(B, L, H, P, G, N, Q, &bds[4], &tds[4], &shms[4]);
|
||||
|
||||
half* mxY = (half*) params.out_ptr;
|
||||
half* mxOs = (half*) params.Os_ptr;
|
||||
half* mxFs = (half*) params.x_ptr;
|
||||
float* mxSt = (float*) params.St_ptr;
|
||||
float* mxdc = (float*) params.dc_ptr;
|
||||
float* mxdA = (float*) params.dA_ptr;
|
||||
half const* mxdt = (half const*) params.delta_ptr;
|
||||
float const* mxdb = (float const*) params.delta_bias_ptr;
|
||||
float const* mxA = (float const*) params.A_ptr;
|
||||
half* mxCB = (half*) params.CB_ptr;
|
||||
half const* mxBC = (half const*) params.BC_ptr;
|
||||
float const* mxD = (float const*) params.D_ptr;
|
||||
half const* mxX = (half const*) params.u_ptr;
|
||||
half const* mxZ = (half const*) params.z_ptr;
|
||||
|
||||
auto rp = params.remove_padding;
|
||||
auto ltip = params.last_token_ids_ptr;
|
||||
auto ssmp = params.slot_mapping_ptr;
|
||||
|
||||
cudaFuncSetAttribute(chunk_cumsum, cudaFuncAttributeMaxDynamicSharedMemorySize, shms[0]);
|
||||
chunk_cumsum<<<bds[0], tds[0], shms[0], stream>>>(B, L, H, mxdc, mxdA, mxdt, mxdb, mxA, rp, ltip);
|
||||
cudaFuncSetAttribute(chunk_state, cudaFuncAttributeMaxDynamicSharedMemorySize, shms[1]);
|
||||
chunk_state<<<bds[1], tds[1], shms[1], stream>>>(B, L, H, P, G, N, mxSt, mxdc, mxdA, mxBC, mxX, rp, ltip);
|
||||
cudaFuncSetAttribute(state_passing, cudaFuncAttributeMaxDynamicSharedMemorySize, shms[2]);
|
||||
state_passing<<<bds[2], tds[2], shms[2], stream>>>(B, L, H, P, N, mxOs, mxFs, mxSt, mxdA, rp, ltip, ssmp);
|
||||
cudaFuncSetAttribute(bmm_chunk, cudaFuncAttributeMaxDynamicSharedMemorySize, shms[3]);
|
||||
bmm_chunk<<<bds[3], tds[3], shms[3], stream>>>(B, L, G, N, mxCB, mxBC, rp, ltip);
|
||||
cudaFuncSetAttribute(chunk_scan, cudaFuncAttributeMaxDynamicSharedMemorySize, shms[4]);
|
||||
chunk_scan<<<bds[4], tds[4], shms[4], stream>>>(
|
||||
B, L, H, P, G, N, mxY, mxOs, mxdc, mxdA, mxCB, mxBC, mxD, mxX, mxZ, rp, ltip);
|
||||
}
|
||||
else if constexpr (std::is_same_v<input_t, __nv_bfloat16>)
|
||||
{
|
||||
dim3 bds[5], tds[5];
|
||||
int shms[5];
|
||||
|
||||
ChunkCumsumKernelFuncBf16 chunk_cumsum = getChunkCumsumKernelBf16(B, L, H, Q, dtsp, &bds[0], &tds[0], &shms[0]);
|
||||
ChunkStateKernelFuncBf16 chunk_state = getChunkStateKernelBf16(B, L, H, P, G, N, Q, &bds[1], &tds[1], &shms[1]);
|
||||
StatePassingKernelFuncBf16 state_passing
|
||||
= getStatePassingKernelBf16(B, L, H, P, N, Q, &bds[2], &tds[2], &shms[2]);
|
||||
BmmChunkKernelFuncBf16 bmm_chunk = getBmmChunkKernelBf16(B, L, G, N, Q, &bds[3], &tds[3], &shms[3]);
|
||||
ChunkScanKernelFuncBf16 chunk_scan = getChunkScanKernelBf16(B, L, H, P, G, N, Q, &bds[4], &tds[4], &shms[4]);
|
||||
|
||||
__nv_bfloat16* mxY = (__nv_bfloat16*) params.out_ptr;
|
||||
__nv_bfloat16* mxOs = (__nv_bfloat16*) params.Os_ptr;
|
||||
__nv_bfloat16* mxFs = (__nv_bfloat16*) params.x_ptr;
|
||||
float* mxSt = (float*) params.St_ptr;
|
||||
float* mxdc = (float*) params.dc_ptr;
|
||||
float* mxdA = (float*) params.dA_ptr;
|
||||
__nv_bfloat16 const* mxdt = (__nv_bfloat16 const*) params.delta_ptr;
|
||||
float const* mxdb = (float const*) params.delta_bias_ptr;
|
||||
float const* mxA = (float const*) params.A_ptr;
|
||||
__nv_bfloat16* mxCB = (__nv_bfloat16*) params.CB_ptr;
|
||||
__nv_bfloat16 const* mxBC = (__nv_bfloat16 const*) params.BC_ptr;
|
||||
float const* mxD = (float const*) params.D_ptr;
|
||||
__nv_bfloat16 const* mxX = (__nv_bfloat16 const*) params.u_ptr;
|
||||
__nv_bfloat16 const* mxZ = (__nv_bfloat16 const*) params.z_ptr;
|
||||
|
||||
auto rp = params.remove_padding;
|
||||
auto ltip = params.last_token_ids_ptr;
|
||||
auto ssmp = params.slot_mapping_ptr;
|
||||
|
||||
cudaFuncSetAttribute(chunk_cumsum, cudaFuncAttributeMaxDynamicSharedMemorySize, shms[0]);
|
||||
chunk_cumsum<<<bds[0], tds[0], shms[0], stream>>>(B, L, H, mxdc, mxdA, mxdt, mxdb, mxA, rp, ltip);
|
||||
cudaFuncSetAttribute(chunk_state, cudaFuncAttributeMaxDynamicSharedMemorySize, shms[1]);
|
||||
chunk_state<<<bds[1], tds[1], shms[1], stream>>>(B, L, H, P, G, N, mxSt, mxdc, mxdA, mxBC, mxX, rp, ltip);
|
||||
cudaFuncSetAttribute(state_passing, cudaFuncAttributeMaxDynamicSharedMemorySize, shms[2]);
|
||||
state_passing<<<bds[2], tds[2], shms[2], stream>>>(B, L, H, P, N, mxOs, mxFs, mxSt, mxdA, rp, ltip, ssmp);
|
||||
cudaFuncSetAttribute(bmm_chunk, cudaFuncAttributeMaxDynamicSharedMemorySize, shms[3]);
|
||||
bmm_chunk<<<bds[3], tds[3], shms[3], stream>>>(B, L, G, N, mxCB, mxBC, rp, ltip);
|
||||
cudaFuncSetAttribute(chunk_scan, cudaFuncAttributeMaxDynamicSharedMemorySize, shms[4]);
|
||||
chunk_scan<<<bds[4], tds[4], shms[4], stream>>>(
|
||||
B, L, H, P, G, N, mxY, mxOs, mxdc, mxdA, mxCB, mxBC, mxD, mxX, mxZ, rp, ltip);
|
||||
}
|
||||
}
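Taken together, the five launches above implement a chunked evaluation of the selective-scan recurrence; roughly (this decomposition is inferred from the argument lists, it is not spelled out in the change):

\[ h_t = e^{\,dA_t}\, h_{t-1} + dB_t\, x_t, \qquad y_t = C_t\, h_t + D\, x_t , \]

split into chunks of Q tokens: chunk_cumsum produces the per-token decays dc/dA, chunk_state reduces each chunk to a partial state St, state_passing carries state across chunks (writing the per-chunk initial states Os and the final state Fs), bmm_chunk forms the intra-chunk C·B products CB, and chunk_scan combines the carried-in state with the intra-chunk contribution to produce Y.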
|
||||
|
||||
#define INSTANTIATE_SELECTIVE_SCAN_DATA_TYPE(input_t, weight_t) \
|
||||
template void invokeSelectiveScan<input_t, weight_t>(SSMParamsBase & params, cudaStream_t stream);
|
||||
|
||||
@ -341,9 +446,19 @@ INSTANTIATE_SELECTIVE_SCAN_DATA_TYPE(__nv_bfloat16, float);
|
||||
#endif
|
||||
#undef INSTANTIATE_SELECTIVE_SCAN_DATA_TYPE
|
||||
|
||||
#define INSTANTIATE_CHUNK_SCAN_DATA_TYPE(input_t, weight_t) \
|
||||
template void invokeChunkScan<input_t, weight_t>(SSMParamsBase & params, cudaStream_t stream);
|
||||
|
||||
INSTANTIATE_CHUNK_SCAN_DATA_TYPE(float, float);
|
||||
INSTANTIATE_CHUNK_SCAN_DATA_TYPE(half, float);
|
||||
#ifdef ENABLE_BF16
|
||||
INSTANTIATE_CHUNK_SCAN_DATA_TYPE(__nv_bfloat16, float);
|
||||
#endif
|
||||
#undef INSTANTIATE_CHUNK_SCAN_DATA_TYPE
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename input_t, typename weight_t, int DSTATE = 16, int CHANNELS_PER_BLOCK = 128>
|
||||
template <typename input_t, typename weight_t, int DSTATE = 16, int CHANNELS_PER_BLOCK = 128, bool MAMBA_V1 = true>
|
||||
__launch_bounds__(128, 2) __global__ void selective_scan_update_kernel(SSMParamsBase params)
|
||||
{
|
||||
|
||||
@ -359,15 +474,21 @@ __launch_bounds__(128, 2) __global__ void selective_scan_update_kernel(SSMParams
|
||||
weight_t* dt_bias = reinterpret_cast<weight_t*>(params.delta_bias_ptr);
|
||||
bool dt_softplus = params.delta_softplus;
|
||||
int num_channels = params.dim;
|
||||
int nheads = params.nheads;
|
||||
int ngroups = params.ngroups;
|
||||
|
||||
int const channel = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (channel >= num_channels)
|
||||
return;
|
||||
int const sample = blockIdx.y;
|
||||
int const head_dim = num_channels / nheads;
|
||||
int const head = channel / head_dim;
|
||||
int const head_chl = channel % head_dim;
|
||||
int const group = head / (nheads / ngroups);
|
||||
int const slot_idx = params.slot_mapping_ptr == nullptr ? sample : params.slot_mapping_ptr[sample];
|
||||
int const bc_cols = DSTATE * 2 + params.dt_rank;
|
||||
int const b_offset = params.dt_rank;
|
||||
int const c_offset = params.dt_rank + DSTATE;
|
||||
int const bc_offset = MAMBA_V1 ? sample * (DSTATE * 2 + params.dt_rank) : sample * DSTATE * ngroups * 2;
|
||||
int const b_offset = MAMBA_V1 ? params.dt_rank : DSTATE * group;
|
||||
int const c_offset = MAMBA_V1 ? params.dt_rank + DSTATE : DSTATE * (ngroups + group);
|
||||
|
||||
input_t* my_state = &state[slot_idx * num_channels * DSTATE];
|
||||
input_t* my_output = &output[sample * num_channels];
|
||||
@ -375,30 +496,45 @@ __launch_bounds__(128, 2) __global__ void selective_scan_update_kernel(SSMParams
|
||||
float rA[DSTATE];
|
||||
float rB[DSTATE];
|
||||
float rC[DSTATE];
|
||||
|
||||
float rState[DSTATE];
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < DSTATE; i++)
|
||||
{
|
||||
rA[i] = toFloat(A[i * num_channels + channel]);
|
||||
rB[i] = toFloat(B[sample * bc_cols + b_offset + i]);
|
||||
rC[i] = toFloat(C[sample * bc_cols + c_offset + i]);
|
||||
rState[i] = toFloat(my_state[i * num_channels + channel]);
|
||||
}
|
||||
|
||||
float my_x, my_dt, my_z, my_dt_bias, my_D;
|
||||
my_x = toFloat(x[sample * num_channels + channel]);
|
||||
my_dt = toFloat(dt[sample * num_channels + channel]);
|
||||
my_z = z ? toFloat(z[sample * num_channels + channel]) : 0.f;
|
||||
my_dt_bias = dt_bias ? toFloat(dt_bias[channel]) : 0.f;
|
||||
my_D = D ? toFloat(D[channel]) : 0.f;
|
||||
|
||||
if (MAMBA_V1)
|
||||
{
|
||||
#pragma unroll
|
||||
for (int i = 0; i < DSTATE; i++)
|
||||
{
|
||||
rA[i] = toFloat(A[i * num_channels + channel]);
|
||||
rB[i] = toFloat(B[bc_offset + b_offset + i]);
|
||||
rC[i] = toFloat(C[bc_offset + c_offset + i]);
|
||||
rState[i] = toFloat(my_state[i * num_channels + channel]);
|
||||
}
|
||||
my_dt = toFloat(dt[sample * num_channels + channel]);
|
||||
my_dt_bias = dt_bias ? toFloat(dt_bias[channel]) : 0.f;
|
||||
my_D = D ? toFloat(D[channel]) : 0.f;
|
||||
}
|
||||
else
|
||||
{
|
||||
float A_tmp = toFloat(A[head]);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < DSTATE; i++)
|
||||
{
|
||||
rA[i] = A_tmp;
|
||||
rB[i] = toFloat(B[bc_offset + b_offset + i]);
|
||||
rC[i] = toFloat(C[bc_offset + c_offset + i]);
|
||||
rState[i] = toFloat(my_state[(head * DSTATE + i) * head_dim + head_chl]);
|
||||
}
|
||||
my_dt = toFloat(dt[sample * nheads + head]);
|
||||
my_dt_bias = dt_bias ? toFloat(dt_bias[head]) : 0.f;
|
||||
my_D = D ? toFloat(D[head]) : 0.f;
|
||||
}
|
||||
|
||||
float dt_b = my_dt + my_dt_bias;
|
||||
float dt_b_sp;
|
||||
if (dt_softplus)
|
||||
{
|
||||
// dt_b_sp = dt_b <= 20.f ? logf(1.f + expf(dt_b)) : dt_b; // softplus
|
||||
dt_b_sp = dt_b <= 20.f ? __logf(1.f + __expf(dt_b)) : dt_b; // softplus
|
||||
}
|
||||
|
||||
@ -407,19 +543,21 @@ __launch_bounds__(128, 2) __global__ void selective_scan_update_kernel(SSMParams
|
||||
#pragma unroll
|
||||
for (int i = 0; i < DSTATE; i++)
|
||||
{
|
||||
// float dA = expf(rA[i] * dt_b_sp);
|
||||
float dA = __expf(rA[i] * dt_b_sp);
|
||||
float dB = rB[i] * dt_b_sp;
|
||||
float sdA = rState[i] * dA;
|
||||
float dBx = dB * my_x;
|
||||
float newState = sdA + dBx;
|
||||
convertAndStore(&my_state[i * num_channels + channel], newState); // Write the new state back out to the cache
|
||||
// Write the new state back out to the cache
|
||||
if (MAMBA_V1)
|
||||
convertAndStore(&my_state[i * num_channels + channel], newState);
|
||||
else
|
||||
convertAndStore(&my_state[(head * DSTATE + i) * head_dim + head_chl], newState);
|
||||
out += newState * rC[i];
|
||||
}
|
||||
|
||||
if (z)
|
||||
{
|
||||
// float sig_z = 1.0 / (1.0 + exp(0.f - my_z));
|
||||
float sig_z = __fdividef(1.f, (1.f + __expf(0.f - my_z)));
|
||||
float silu_z = my_z * sig_z;
|
||||
out *= silu_z;
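In equation form, one decode step of this kernel updates each channel's DSTATE-element state and output roughly as follows (the D·x skip term uses my_D loaded above and is applied where out is initialised, outside the lines shown):

\[ \Delta = \mathrm{softplus}(dt + dt_{\mathrm{bias}}), \qquad h_i \leftarrow e^{A_i \Delta}\, h_i + (B_i \Delta)\, x, \qquad y = \sum_i C_i\, h_i + D\, x, \qquad y \leftarrow y \cdot z\,\sigma(z) , \]

with the softplus clamped to the identity for arguments above 20 and the final SiLU gate applied only when a z tensor is provided.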
|
||||
@ -433,16 +571,25 @@ void invokeSelectiveScanUpdate(SSMParamsBase& params, cudaStream_t stream)
|
||||
{
|
||||
int samples = params.batch;
|
||||
int channels = params.dim;
|
||||
int nheads = params.nheads;
|
||||
int ngroups = params.ngroups;
|
||||
|
||||
int const threads = 128;
|
||||
int const blocks = (channels + threads - 1) / threads;
|
||||
dim3 block(threads, 1);
|
||||
dim3 grid(blocks, samples);
|
||||
|
||||
TLLM_CHECK(params.is_variable_B);
|
||||
TLLM_CHECK(params.is_variable_C);
|
||||
TLLM_CHECK(params.dstate == 16);
|
||||
selective_scan_update_kernel<input_t, weight_t><<<grid, block, 0, stream>>>(params);
|
||||
TLLM_CHECK_WITH_INFO(nheads % ngroups == 0, "nheads must be divisible by ngroups");
|
||||
if (params.is_mamab2)
|
||||
{
|
||||
TLLM_CHECK(params.dstate == 128);
|
||||
selective_scan_update_kernel<input_t, weight_t, 128, 128, false><<<grid, block, 0, stream>>>(params);
|
||||
}
|
||||
else
|
||||
{
|
||||
TLLM_CHECK(params.dstate == 16);
|
||||
selective_scan_update_kernel<input_t, weight_t, 16, 128, true><<<grid, block, 0, stream>>>(params);
|
||||
}
|
||||
}
|
||||
|
||||
#define INSTANTIATE_SELECTIVE_SCAN_UPDATE_DATA_TYPE(input_t, weight_t) \
|
||||
|
||||
@ -40,13 +40,11 @@ namespace kernels
|
||||
|
||||
struct SSMParamsBase
|
||||
{
|
||||
int batch, dim, dstate, dt_rank;
|
||||
int batch, dim, dstate, dt_rank, nheads, ngroups, chunk_size;
|
||||
int max_seqlen; // only valid for padded input.
|
||||
bool remove_padding;
|
||||
bool is_variable_B;
|
||||
bool is_variable_C;
|
||||
|
||||
bool delta_softplus;
|
||||
bool is_mamab2;
|
||||
|
||||
// Common data pointers.
|
||||
void* __restrict__ A_ptr;
|
||||
@ -58,6 +56,12 @@ struct SSMParamsBase
|
||||
void* __restrict__ out_ptr;
|
||||
void* __restrict__ x_ptr;
|
||||
void* __restrict__ z_ptr;
|
||||
// Workspace data pointers.
|
||||
void* __restrict__ Os_ptr;
|
||||
void* __restrict__ St_ptr;
|
||||
void* __restrict__ dc_ptr;
|
||||
void* __restrict__ dA_ptr;
|
||||
void* __restrict__ CB_ptr;
|
||||
int const* __restrict__ last_token_ids_ptr;
|
||||
int const* __restrict__ slot_mapping_ptr;
|
||||
};
|
||||
@ -67,6 +71,9 @@ struct SSMParamsBase
|
||||
template <typename input_t, typename weight_t>
|
||||
void invokeSelectiveScan(SSMParamsBase& params, cudaStream_t stream);
|
||||
|
||||
template <typename input_t, typename weight_t>
|
||||
void invokeChunkScan(SSMParamsBase& params, cudaStream_t stream);
|
||||
|
||||
template <typename input_t, typename weight_t>
|
||||
void invokeSelectiveScanUpdate(SSMParamsBase& params, cudaStream_t stream);
|
||||
} // namespace kernels
|
||||
|
||||
@ -158,7 +158,7 @@ void GemmPluginProfiler<Config, RunnerPtr, GemmIdType, GemmIdHashType>::profileT
    common::check_cuda_error(cudaStreamCreate(&mStream));

    int const startMinMRounded = nextPowerOfTwo(dims.minM);
    for (int m = startMinMRounded; m < maxM; m *= 2)
    for (int m = std::max(1, startMinMRounded); m < maxM; m *= 2)
    {
        profileTactics(m, dims.n, dims.k);
    }
@ -184,7 +184,7 @@ std::optional<Config> GemmPluginProfiler<Config, RunnerPtr, GemmIdType, GemmIdHa
        return std::nullopt;
    }

    int const mRounded = std::min(nextPowerOfTwo(m), getMaxProfileM());
    int const mRounded = std::min(std::max(1, nextPowerOfTwo(m)), getMaxProfileM());
    fflush(stdout);
    return mMNKProfileMap->getMProfileMap(gemmId)->at(mRounded);
}
|
||||
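The added std::max(1, ...) guards the degenerate case where the rounded value is 0, which is presumably possible when dims.minM or m is 0. A minimal illustration, assuming a conventional nextPowerOfTwo for which an input of 0 rounds to 0 (maxM, dims and profileTactics as in the surrounding code):

    // Hypothetical rounding helper: an input of 0 (or less) stays 0.
    int nextPowerOfTwo(int v)
    {
        if (v <= 0)
            return 0;
        int p = 1;
        while (p < v)
            p <<= 1;
        return p;
    }

    // Without the guard, m would start at 0 and m *= 2 would keep it at 0, so the
    // profiling loop would never terminate; getBestConfig() would likewise look up
    // an entry at m == 0 that profileTactics() never inserted.
    for (int m = std::max(1, nextPowerOfTwo(/*dims.minM=*/0)); m < maxM; m *= 2)
        profileTactics(m, dims.n, dims.k);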
|
||||
@ -275,8 +275,23 @@ void GPTAttentionPlugin::configurePluginImpl(nvinfer1::DynamicPluginTensorDesc c
{
    TLLM_CHECK(mHeadSize > 0);

    int const beamWidth
        = isCrossAttention() ? 1 : (useKVCache() ? in[getIdx(IdxEntry::CACHE_INDIR)].desc.dims.d[1] : 1);
    int beamWidth = -1;
    if (!isCrossAttention() && useKVCache())
    {
        // desc_val == -1 means beam_width is not static, we should look at min/max/opt.
        //
        // In prepareEnqueueGeneration, we'll prepare for all cases where beam_width doesn't exceed max.
        // TODO(minwei): pass min AND max to prepareEnqueueGeneration instead of max only.
        int desc_val = in[getIdx(IdxEntry::CACHE_INDIR)].desc.dims.d[1];
        int max_val = in[getIdx(IdxEntry::CACHE_INDIR)].max.d[1];
        beamWidth = desc_val == -1 ? max_val : desc_val;
    }
    else
    {
        beamWidth = 1;
    }
    TLLM_CHECK(beamWidth != -1);

    // Commonly, cyclic_attention_window_size, and max_attention_window_size will be the same
    // unless each layer has different attention window sizes.
    // the kv_cache capacity.
|
||||
|
||||
@ -29,18 +29,22 @@ static char const* SELECTIVE_SCAN_PLUGIN_NAME{"SelectiveScan"};
|
||||
PluginFieldCollection SelectiveScanPluginCreator::mFC{};
|
||||
std::vector<nvinfer1::PluginField> SelectiveScanPluginCreator::mPluginAttributes;
|
||||
|
||||
SelectiveScanPlugin::SelectiveScanPlugin(int dim, int dstate, int dt_rank, bool isVariableB, bool isVariableC,
|
||||
bool deltaSoftplus, nvinfer1::DataType type, bool removePadding, bool pagedState)
|
||||
SelectiveScanPlugin::SelectiveScanPlugin(int dim, int dstate, int dtRank, int nHeads, int nGroups, int chunkSize,
|
||||
bool deltaSoftplus, nvinfer1::DataType type, bool removePadding, bool pagedState, bool zEnabled, bool isMamba2)
|
||||
: mDim(dim)
|
||||
, mDState(dstate)
|
||||
, mDtRank(dt_rank)
|
||||
, mIsVariableB(isVariableB)
|
||||
, mIsVariableC(isVariableC)
|
||||
, mDtRank(dtRank)
|
||||
, mNHeads(nHeads)
|
||||
, mNGroups(nGroups)
|
||||
, mChunkSize(chunkSize)
|
||||
, mDeltaSoftplus(deltaSoftplus)
|
||||
, mType(type)
|
||||
, mRemovePadding(removePadding)
|
||||
, mPagedState(pagedState)
|
||||
, mZEnabled(zEnabled)
|
||||
, mIsMamba2(isMamba2)
|
||||
{
|
||||
TLLM_CHECK_WITH_INFO((getSMVersion() >= 80) || (!mIsMamba2), "Pre SM 80 GPUs do not support Mamba2");
|
||||
TLLM_CHECK_WITH_INFO((getSMVersion() >= 80) || (mType != DataType::kBF16),
|
||||
"Unsupported data type, pre SM 80 GPUs do not support bfloat16");
|
||||
TLLM_CHECK_WITH_INFO((mType == DataType::kBF16) || (mType == DataType::kFLOAT) || (mType == DataType::kHALF),
|
||||
@ -54,12 +58,15 @@ SelectiveScanPlugin::SelectiveScanPlugin(void const* data, size_t length)
|
||||
read(d, mDim);
|
||||
read(d, mDState);
|
||||
read(d, mDtRank);
|
||||
read(d, mIsVariableB);
|
||||
read(d, mIsVariableC);
|
||||
read(d, mNHeads);
|
||||
read(d, mNGroups);
|
||||
read(d, mChunkSize);
|
||||
read(d, mDeltaSoftplus);
|
||||
read(d, mType);
|
||||
read(d, mRemovePadding);
|
||||
read(d, mPagedState);
|
||||
read(d, mZEnabled);
|
||||
read(d, mIsMamba2);
|
||||
TLLM_CHECK(d == a + length);
|
||||
TLLM_CHECK_WITH_INFO((getSMVersion() >= 80) || (mType != DataType::kBF16), "Unsupported data type");
|
||||
TLLM_CHECK_WITH_INFO((mType == DataType::kBF16) || (mType == DataType::kFLOAT) || (mType == DataType::kHALF),
|
||||
@ -69,8 +76,8 @@ SelectiveScanPlugin::SelectiveScanPlugin(void const* data, size_t length)
|
||||
// IPluginV2DynamicExt Methods
|
||||
nvinfer1::IPluginV2DynamicExt* SelectiveScanPlugin::clone() const noexcept
|
||||
{
|
||||
auto* plugin = new SelectiveScanPlugin(
|
||||
mDim, mDState, mDtRank, mIsVariableB, mIsVariableC, mDeltaSoftplus, mType, mRemovePadding, mPagedState);
|
||||
auto* plugin = new SelectiveScanPlugin(mDim, mDState, mDtRank, mNHeads, mNGroups, mChunkSize, mDeltaSoftplus, mType,
|
||||
mRemovePadding, mPagedState, mZEnabled, mIsMamba2);
|
||||
plugin->setPluginNamespace(mNamespace.c_str());
|
||||
return plugin;
|
||||
}
|
||||
@ -91,7 +98,8 @@ nvinfer1::DimsExprs SelectiveScanPlugin::getOutputDimensions(
|
||||
bool SelectiveScanPlugin::supportsFormatCombination(
|
||||
int pos, nvinfer1::PluginTensorDesc const* inOut, int nbInputs, int nbOutputs) noexcept
|
||||
{
|
||||
if (pos == getHostRequestTypesIdx() || pos == getLastTokenIdsIdx() || (mPagedState && pos == getSlotMappingIdx()))
|
||||
if (pos == getHostRequestTypesIdx() || pos == getLastTokenIdsIdx()
|
||||
|| (mRemovePadding && pos == getHostContextLengthIdx()) || (mPagedState && pos == getSlotMappingIdx()))
|
||||
{
|
||||
return inOut[pos].type == nvinfer1::DataType::kINT32;
|
||||
}
|
||||
@ -117,14 +125,56 @@ void SelectiveScanPlugin::configurePlugin(nvinfer1::DynamicPluginTensorDesc cons
|
||||
size_t SelectiveScanPlugin::getWorkspaceSize(nvinfer1::PluginTensorDesc const* inputs, int nbInputs,
|
||||
nvinfer1::PluginTensorDesc const* outputs, int nbOutputs) const noexcept
|
||||
{
|
||||
return 0;
|
||||
if (!mIsMamba2)
|
||||
return 0;
|
||||
|
||||
int const NUM_BUFFERS = 5;
|
||||
size_t workspaces[NUM_BUFFERS];
|
||||
|
||||
if (mRemovePadding)
|
||||
{
|
||||
int B = inputs[getLastTokenIdsIdx()].dims.d[0];
|
||||
int BxL = inputs[getInputTensorIdx()].dims.d[0]; // num_tokens
|
||||
int H = mNHeads;
|
||||
int P = inputs[getInputTensorIdx()].dims.d[1] / H;
|
||||
int G = mNGroups;
|
||||
int N = inputs[getBCIdx()].dims.d[1] / G / 2;
|
||||
int Q = mChunkSize;
|
||||
int BxC = (BxL + Q - 1) / Q + B;
|
||||
|
||||
workspaces[0] = BxC * H * N * P * 2; // g_mxOs_
|
||||
workspaces[1] = BxC * H * N * P * 4; // g_mxSt_ in float
|
||||
workspaces[2] = BxC * H * Q * 4; // g_mxdc_ in float
|
||||
workspaces[3] = BxC * H * Q * 4; // g_mxdA_ in float
|
||||
workspaces[4] = BxC * G * Q * Q * 2; // g_mxCB_
|
||||
}
|
||||
else
|
||||
{
|
||||
int B = inputs[getInputTensorIdx()].dims.d[0];
|
||||
int L = inputs[getInputTensorIdx()].dims.d[1];
|
||||
int H = mNHeads;
|
||||
int P = inputs[getInputTensorIdx()].dims.d[2] / H;
|
||||
int G = mNGroups;
|
||||
int N = inputs[getBCIdx()].dims.d[2] / G / 2;
|
||||
int Q = mChunkSize;
|
||||
int C = (L + Q - 1) / Q;
|
||||
|
||||
workspaces[0] = B * C * H * N * P * 2; // g_mxOs_
|
||||
workspaces[1] = B * C * H * N * P * 4; // g_mxSt_ in float
|
||||
workspaces[2] = B * C * H * Q * 4; // g_mxdc_ in float
|
||||
workspaces[3] = B * C * H * Q * 4; // g_mxdA_ in float
|
||||
workspaces[4] = B * C * G * Q * Q * 2; // g_mxCB_
|
||||
}
|
||||
|
||||
return calculateTotalWorkspaceSize(workspaces, NUM_BUFFERS);
|
||||
}
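To get a feel for the buffer sizes, here is the padded-input branch evaluated for one made-up Mamba-2 configuration: B = 2, L = 1024, H = 8, P = 64, G = 1, N = 128, Q = 128, hence C = 8 chunks (all figures in bytes):

    // workspaces[0] = B*C*H*N*P*2  = 2*8*8*128*64*2  = 2,097,152   // g_mxOs_ (fp16/bf16)
    // workspaces[1] = B*C*H*N*P*4  = 2*8*8*128*64*4  = 4,194,304   // g_mxSt_ (float)
    // workspaces[2] = B*C*H*Q*4    = 2*8*8*128*4     =     65,536  // g_mxdc_ (float)
    // workspaces[3] = B*C*H*Q*4    = 2*8*8*128*4     =     65,536  // g_mxdA_ (float)
    // workspaces[4] = B*C*G*Q*Q*2  = 2*8*1*128*128*2 =    524,288  // g_mxCB_
    // calculateTotalWorkspaceSize() then combines these, including any alignment padding, into one allocation.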
|
||||
|
||||
void SelectiveScanPlugin::setSSMParams(SSMParamsBase& params, const size_t batch, const size_t dim,
|
||||
const size_t maxSeqLen, const size_t dstate, const size_t dtRank, bool const isVariableB, bool const isVariableC,
|
||||
void* statePtr, void const* x, void const* delta, void const* deltaBias, void const* A, void const* BC,
|
||||
void const* D, void const* z, int const* lastTokenIds, int const* slotMapping, void* out, bool deltaSoftplus,
|
||||
bool removePadding)
|
||||
const size_t maxSeqLen, const size_t dstate, const size_t dtRank, const size_t nHeads, const size_t nGroups,
|
||||
const size_t chunkSize, void* statePtr, void const* x, void const* delta, void const* deltaBias, void const* A,
|
||||
void const* BC, void const* D, void const* z, void const* osPtr, void const* stPtr, void const* dcPtr,
|
||||
void const* dAPtr, void const* cbPtr, int const* lastTokenIds, int const* slotMapping, void* out,
|
||||
bool deltaSoftplus, bool removePadding)
|
||||
{
|
||||
// Reset the parameters
|
||||
memset(&params, 0, sizeof(params));
|
||||
@ -134,12 +184,13 @@ void SelectiveScanPlugin::setSSMParams(SSMParamsBase& params, const size_t batch
|
||||
params.max_seqlen = maxSeqLen;
|
||||
params.dstate = dstate;
|
||||
params.dt_rank = dtRank;
|
||||
params.nheads = nHeads;
|
||||
params.ngroups = nGroups;
|
||||
params.chunk_size = chunkSize;
|
||||
|
||||
params.delta_softplus = deltaSoftplus;
|
||||
params.remove_padding = removePadding;
|
||||
|
||||
params.is_variable_B = isVariableB;
|
||||
params.is_variable_C = isVariableC;
|
||||
params.is_mamab2 = mIsMamba2;
|
||||
|
||||
// Set the pointers and strides.
|
||||
params.u_ptr = const_cast<void*>(x);
|
||||
@ -151,6 +202,11 @@ void SelectiveScanPlugin::setSSMParams(SSMParamsBase& params, const size_t batch
|
||||
params.out_ptr = out;
|
||||
params.x_ptr = statePtr;
|
||||
params.z_ptr = const_cast<void*>(z);
|
||||
params.Os_ptr = const_cast<void*>(osPtr);
|
||||
params.St_ptr = const_cast<void*>(stPtr);
|
||||
params.dc_ptr = const_cast<void*>(dcPtr);
|
||||
params.dA_ptr = const_cast<void*>(dAPtr);
|
||||
params.CB_ptr = const_cast<void*>(cbPtr);
|
||||
params.last_token_ids_ptr = lastTokenIds;
|
||||
params.slot_mapping_ptr = slotMapping;
|
||||
}
|
||||
@ -162,24 +218,30 @@ int SelectiveScanPlugin::enqueueImpl(nvinfer1::PluginTensorDesc const* inputDesc
|
||||
{
|
||||
// inputs
|
||||
// 0. input_tensor [batch_size, max_seq_len, dim] or [num_tokens, dim]
|
||||
// 1. state [batch_size, dstate, dim] or host [1] containing only pointer for paged_state
|
||||
// 2. delta [batch_size, max_seq_len, dim] or [num_tokens, dim]
|
||||
// 3. delta_bias [dim]
|
||||
// 4. A [dstate, dim]
|
||||
// 5. BC [batch_size, max_seq_len, dt_rank + dstate * 2] or [num_tokens, dt_rank + dstate * 2]
|
||||
// 6. D [dim]
|
||||
// 7. z [batch_size, max_seq_len, dim] or [num_tokens, dim]
|
||||
// 8. host_request_types [batch_size] int32. 0: context; 1: generation.
|
||||
// 9. last_token_ids [batch_size] int32
|
||||
// 1. state mamba: [batch_size, dstate, dim] or host [1] containing only pointer for paged_state
|
||||
// mamba2: [batch_size, nheads, dstate, dim] or host [1] containing only pointer for paged_state
|
||||
// 2. delta, mamba: [batch_size, seq_len, dim] or [num_tokens, dim] for remove_input_padding
|
||||
// mamba2: [batch_size, seq_len, nheads] or [num_tokens, nheads] for remove_input_padding
|
||||
// 3. delta_bias, [dim] for mamba, [nheads] for mamba2
|
||||
// 4. A, [dstate, dim] for mamba, [nheads] for mamba2
|
||||
// 5. BC, mamba: [batch_size, seq_len, dstate * 2] or [num_tokens, dstate * 2] for remove_input_padding
|
||||
// mamba2: [batch_size, seq_len, ngroups * dstate * 2] or [num_tokens, ngroups * dstate * 2] for
|
||||
// remove_input_padding
|
||||
// 6. D, [dim] for mamba, [nheads] for mamba2
|
||||
// 7. host_request_types [batch_size] int32. 0: context; 1: generation.
|
||||
// 8. last_token_ids [batch_size] int32
|
||||
// 9. host_context_lengths [batch_size] int32, optional for remove_input_padding
|
||||
// 10. state_slot_mapping [batch_size] int32, optional for paged state
|
||||
// 11. z [batch_size, max_seq_len, dim] or [num_tokens, dim]
|
||||
// outputs
|
||||
// 0. output_tensor [batch_size, max_seq_len, dim] or [num_tokens, dim]
|
||||
// 1. state [batch_size, dstate, dim]
|
||||
// 1. state, [batch_size, dstate, dim] for mamba, [batch_size, nheads, dstate, dim] for mamba2
|
||||
auto const batch_size = inputDesc[getHostRequestTypesIdx()].dims.d[0];
|
||||
int max_seq_len;
|
||||
if (mRemovePadding)
|
||||
{
|
||||
max_seq_len = -1;
|
||||
int const* host_context_length = static_cast<int const*>(inputs[getHostContextLengthIdx()]);
|
||||
max_seq_len = *std::max_element(host_context_length, host_context_length + batch_size);
|
||||
}
|
||||
else
|
||||
{
|
||||
@ -192,17 +254,72 @@ int SelectiveScanPlugin::enqueueImpl(nvinfer1::PluginTensorDesc const* inputDesc
|
||||
SSMParamsBase ssm_params;
|
||||
|
||||
int const* slotMapping = mPagedState ? static_cast<int const*>(inputs[getSlotMappingIdx()]) : nullptr;
|
||||
void const* z = mZEnabled ? inputs[getZIdx()] : nullptr;
|
||||
|
||||
void* statePtr = mPagedState ? *reinterpret_cast<void**>(const_cast<void*>(inputs[getStateIdx()])) : outputs[1];
|
||||
|
||||
setSSMParams(ssm_params, batch_size, mDim, max_seq_len, mDState, mDtRank, mIsVariableB, mIsVariableC, statePtr,
|
||||
// Workspace pointer shift
|
||||
int8_t* workspace_byte_ptr = reinterpret_cast<int8_t*>(workspace);
|
||||
size_t offset = 0;
|
||||
|
||||
T* mxOs = nullptr;
|
||||
float* mxSt = nullptr;
|
||||
float* mxdc = nullptr;
|
||||
float* mxdA = nullptr;
|
||||
T* mxCB = nullptr;
|
||||
|
||||
if (!mIsMamba2) /* no workspace needed */
|
||||
;
|
||||
else if (mRemovePadding)
|
||||
{
|
||||
int B = inputDesc[getLastTokenIdsIdx()].dims.d[0];
|
||||
int BxL = inputDesc[getInputTensorIdx()].dims.d[0]; // num_tokens
|
||||
int H = mNHeads;
|
||||
int P = inputDesc[getInputTensorIdx()].dims.d[1] / H;
|
||||
int G = mNGroups;
|
||||
int N = inputDesc[getBCIdx()].dims.d[1] / G / 2;
|
||||
int Q = mChunkSize;
|
||||
int BxC = (BxL + Q - 1) / Q + B;
|
||||
|
||||
mxOs = reinterpret_cast<T*>(nextWorkspacePtr(workspace_byte_ptr, offset, BxC * H * N * P * 2));
|
||||
mxSt = reinterpret_cast<float*>(nextWorkspacePtr(workspace_byte_ptr, offset, BxC * H * N * P * 4));
|
||||
mxdc = reinterpret_cast<float*>(nextWorkspacePtr(workspace_byte_ptr, offset, BxC * H * Q * 4));
|
||||
mxdA = reinterpret_cast<float*>(nextWorkspacePtr(workspace_byte_ptr, offset, BxC * H * Q * 4));
|
||||
mxCB = reinterpret_cast<T*>(nextWorkspacePtr(workspace_byte_ptr, offset, BxC * G * Q * Q * 2));
|
||||
}
|
||||
else
|
||||
{
|
||||
int B = inputDesc[getInputTensorIdx()].dims.d[0];
|
||||
int L = inputDesc[getInputTensorIdx()].dims.d[1];
|
||||
int H = mNHeads;
|
||||
int P = inputDesc[getInputTensorIdx()].dims.d[2] / H;
|
||||
int G = mNGroups;
|
||||
int N = inputDesc[getBCIdx()].dims.d[2] / G / 2;
|
||||
int Q = mChunkSize;
|
||||
int C = (L + Q - 1) / Q;
|
||||
|
||||
mxOs = reinterpret_cast<T*>(nextWorkspacePtr(workspace_byte_ptr, offset, B * C * H * N * P * 2));
|
||||
mxSt = reinterpret_cast<float*>(nextWorkspacePtr(workspace_byte_ptr, offset, B * C * H * N * P * 4));
|
||||
mxdc = reinterpret_cast<float*>(nextWorkspacePtr(workspace_byte_ptr, offset, B * C * H * Q * 4));
|
||||
mxdA = reinterpret_cast<float*>(nextWorkspacePtr(workspace_byte_ptr, offset, B * C * H * Q * 4));
|
||||
mxCB = reinterpret_cast<T*>(nextWorkspacePtr(workspace_byte_ptr, offset, B * C * G * Q * Q * 2));
|
||||
}
|
||||
|
||||
setSSMParams(ssm_params, batch_size, mDim, max_seq_len, mDState, mDtRank, mNHeads, mNGroups, mChunkSize, statePtr,
|
||||
inputs[getInputTensorIdx()], inputs[getDeltaIdx()], inputs[getDeltaBiasIdx()], inputs[getAIdx()],
|
||||
inputs[getBCIdx()], inputs[getDIdx()], inputs[getZIdx()], static_cast<int const*>(inputs[getLastTokenIdsIdx()]),
|
||||
slotMapping, outputs[0], mDeltaSoftplus, mRemovePadding);
|
||||
inputs[getBCIdx()], inputs[getDIdx()], z, mxOs, mxSt, mxdc, mxdA, mxCB,
|
||||
static_cast<int const*>(inputs[getLastTokenIdsIdx()]), slotMapping, outputs[0], mDeltaSoftplus, mRemovePadding);
|
||||
|
||||
if (reqTypes[0] == RequestType::kCONTEXT)
|
||||
{
|
||||
invokeSelectiveScan<T, float>(ssm_params, stream);
|
||||
if (mIsMamba2)
|
||||
{
|
||||
invokeChunkScan<T, float>(ssm_params, stream);
|
||||
}
|
||||
else
|
||||
{
|
||||
invokeSelectiveScan<T, float>(ssm_params, stream);
|
||||
}
|
||||
}
|
||||
else if (reqTypes[0] == RequestType::kGENERATION)
|
||||
{
|
||||
@ -276,8 +393,9 @@ void SelectiveScanPlugin::terminate() noexcept {}
|
||||
|
||||
size_t SelectiveScanPlugin::getSerializationSize() const noexcept
|
||||
{
|
||||
return sizeof(mDim) + sizeof(mDState) + sizeof(mDtRank) + sizeof(mIsVariableB) + sizeof(mIsVariableC)
|
||||
+ sizeof(mDeltaSoftplus) + sizeof(mType) + sizeof(mRemovePadding) + sizeof(mPagedState);
|
||||
return sizeof(mDim) + sizeof(mDState) + sizeof(mDtRank) + sizeof(mNHeads) + sizeof(mNGroups) + sizeof(mChunkSize)
|
||||
+ sizeof(mDeltaSoftplus) + sizeof(mType) + sizeof(mRemovePadding) + sizeof(mPagedState) + sizeof(mZEnabled)
|
||||
+ sizeof(mIsMamba2);
|
||||
}
|
||||
|
||||
void SelectiveScanPlugin::serialize(void* buffer) const noexcept
|
||||
@ -286,12 +404,15 @@ void SelectiveScanPlugin::serialize(void* buffer) const noexcept
|
||||
write(d, mDim);
|
||||
write(d, mDState);
|
||||
write(d, mDtRank);
|
||||
write(d, mIsVariableB);
|
||||
write(d, mIsVariableC);
|
||||
write(d, mNHeads);
|
||||
write(d, mNGroups);
|
||||
write(d, mChunkSize);
|
||||
write(d, mDeltaSoftplus);
|
||||
write(d, mType);
|
||||
write(d, mRemovePadding);
|
||||
write(d, mPagedState);
|
||||
write(d, mZEnabled);
|
||||
write(d, mIsMamba2);
|
||||
assert(d == a + getSerializationSize());
|
||||
}
|
||||
|
||||
@ -306,15 +427,18 @@ SelectiveScanPluginCreator::SelectiveScanPluginCreator()
|
||||
{
|
||||
// Fill PluginFieldCollection with PluginField arguments metadata
|
||||
mPluginAttributes.clear();
|
||||
mPluginAttributes.emplace_back(PluginField("dim", nullptr, PluginFieldType::kINT32, 16));
|
||||
mPluginAttributes.emplace_back(PluginField("dstate", nullptr, PluginFieldType::kINT32, 16));
|
||||
mPluginAttributes.emplace_back(PluginField("dt_rank", nullptr, PluginFieldType::kINT32, 16));
|
||||
mPluginAttributes.emplace_back(PluginField("is_variable_B", nullptr, PluginFieldType::kINT8, 1));
|
||||
mPluginAttributes.emplace_back(PluginField("is_variable_C", nullptr, PluginFieldType::kINT8, 1));
|
||||
mPluginAttributes.emplace_back(PluginField("dim", nullptr, PluginFieldType::kINT32, 1));
|
||||
mPluginAttributes.emplace_back(PluginField("dstate", nullptr, PluginFieldType::kINT32, 1));
|
||||
mPluginAttributes.emplace_back(PluginField("dt_rank", nullptr, PluginFieldType::kINT32, 1));
|
||||
mPluginAttributes.emplace_back(PluginField("nheads", nullptr, PluginFieldType::kINT32, 1));
|
||||
mPluginAttributes.emplace_back(PluginField("ngroups", nullptr, PluginFieldType::kINT32, 1));
|
||||
mPluginAttributes.emplace_back(PluginField("chunk_size", nullptr, PluginFieldType::kINT32, 1));
|
||||
mPluginAttributes.emplace_back(PluginField("delta_softplus", nullptr, PluginFieldType::kINT8, 1));
|
||||
mPluginAttributes.emplace_back(PluginField("type_id", nullptr, PluginFieldType::kINT32, 1));
|
||||
mPluginAttributes.emplace_back(PluginField("remove_input_padding", nullptr, PluginFieldType::kINT8, 0));
|
||||
mPluginAttributes.emplace_back(PluginField("paged_state", nullptr, PluginFieldType::kINT8, 0));
|
||||
mPluginAttributes.emplace_back(PluginField("remove_input_padding", nullptr, PluginFieldType::kINT8, 1));
|
||||
mPluginAttributes.emplace_back(PluginField("paged_state", nullptr, PluginFieldType::kINT8, 1));
|
||||
mPluginAttributes.emplace_back(PluginField("z_enabled", nullptr, PluginFieldType::kINT8, 1));
|
||||
mPluginAttributes.emplace_back(PluginField("is_mamba2", nullptr, PluginFieldType::kINT8, 1));
|
||||
mFC.nbFields = mPluginAttributes.size();
|
||||
mFC.fields = mPluginAttributes.data();
|
||||
}
|
||||
@ -337,8 +461,8 @@ PluginFieldCollection const* SelectiveScanPluginCreator::getFieldNames() noexcep
|
||||
IPluginV2* SelectiveScanPluginCreator::createPlugin(char const* name, PluginFieldCollection const* fc) noexcept
|
||||
{
|
||||
PluginField const* fields = fc->fields;
|
||||
int dim, dstate, dtRank;
|
||||
bool isVariableB, isVariableC, deltaSoftplus, removePadding, pagedState;
|
||||
int dim, dstate, dtRank, nHeads, nGroups, chunkSize;
|
||||
bool deltaSoftplus, removePadding, pagedState, zEnabled, isMamab2;
|
||||
nvinfer1::DataType type;
|
||||
// Read configurations from each fields
|
||||
for (int i = 0; i < fc->nbFields; ++i)
|
||||
@ -359,15 +483,20 @@ IPluginV2* SelectiveScanPluginCreator::createPlugin(char const* name, PluginFiel
|
||||
TLLM_CHECK(fields[i].type == PluginFieldType::kINT32);
|
||||
dtRank = static_cast<int>(*(static_cast<int const*>(fields[i].data)));
|
||||
}
|
||||
else if (!strcmp(attrName, "is_variable_B"))
|
||||
else if (!strcmp(attrName, "nheads"))
|
||||
{
|
||||
TLLM_CHECK(fields[i].type == PluginFieldType::kINT8);
|
||||
isVariableB = static_cast<bool>(*(static_cast<bool const*>(fields[i].data)));
|
||||
TLLM_CHECK(fields[i].type == PluginFieldType::kINT32);
|
||||
nHeads = static_cast<int>(*(static_cast<int const*>(fields[i].data)));
|
||||
}
|
||||
else if (!strcmp(attrName, "is_variable_C"))
|
||||
else if (!strcmp(attrName, "ngroups"))
|
||||
{
|
||||
TLLM_CHECK(fields[i].type == PluginFieldType::kINT8);
|
||||
isVariableC = static_cast<bool>(*(static_cast<bool const*>(fields[i].data)));
|
||||
TLLM_CHECK(fields[i].type == PluginFieldType::kINT32);
|
||||
nGroups = static_cast<int>(*(static_cast<int const*>(fields[i].data)));
|
||||
}
|
||||
else if (!strcmp(attrName, "chunk_size"))
|
||||
{
|
||||
TLLM_CHECK(fields[i].type == PluginFieldType::kINT32);
|
||||
chunkSize = static_cast<int>(*(static_cast<int const*>(fields[i].data)));
|
||||
}
|
||||
else if (!strcmp(attrName, "delta_softplus"))
|
||||
{
|
||||
@ -389,11 +518,21 @@ IPluginV2* SelectiveScanPluginCreator::createPlugin(char const* name, PluginFiel
|
||||
TLLM_CHECK(fields[i].type == PluginFieldType::kINT8);
|
||||
pagedState = static_cast<bool>(*(static_cast<bool const*>(fields[i].data)));
|
||||
}
|
||||
else if (!strcmp(attrName, "z_enabled"))
|
||||
{
|
||||
TLLM_CHECK(fields[i].type == PluginFieldType::kINT8);
|
||||
zEnabled = static_cast<bool>(*(static_cast<bool const*>(fields[i].data)));
|
||||
}
|
||||
else if (!strcmp(attrName, "is_mamba2"))
|
||||
{
|
||||
TLLM_CHECK(fields[i].type == PluginFieldType::kINT8);
|
||||
isMamab2 = static_cast<bool>(*(static_cast<bool const*>(fields[i].data)));
|
||||
}
|
||||
}
|
||||
try
|
||||
{
|
||||
auto* obj = new SelectiveScanPlugin(
|
||||
dim, dstate, dtRank, isVariableB, isVariableC, deltaSoftplus, type, removePadding, pagedState);
|
||||
auto* obj = new SelectiveScanPlugin(dim, dstate, dtRank, nHeads, nGroups, chunkSize, deltaSoftplus, type,
|
||||
removePadding, pagedState, zEnabled, isMamab2);
|
||||
obj->setPluginNamespace(mNamespace.c_str());
|
||||
return obj;
|
||||
}
|
||||
|
||||
@ -30,25 +30,30 @@ namespace tensorrt_llm::plugins
|
||||
|
||||
// inputs
|
||||
// 0. input_tensor [batch_size, seq_len, dim] or [num_tokens, dim] for remove_input_padding
|
||||
// 1. state [batch_size, dstate, dim] or host [1] containing only pointer for paged_state
|
||||
// 2. delta [batch_size, seq_len, dim] or [num_tokens, dim] for remove_input_padding
|
||||
// 3. delta_bias [dim]
|
||||
// 4. A [dstate, dim]
|
||||
// 5. BC [batch_size, seq_len, dstate * 2] or [num_tokens, dstate * 2] for remove_input_padding
|
||||
// 6. D [dim]
|
||||
// 7. z [batch_size, seq_len, dim] or [num_tokens, dim] for remove_input_padding
|
||||
// 8. host_request_types [batch_size] int32. 0: context; 1: generation; 2: none.
|
||||
// 9. last_token_ids [batch_size] int32
|
||||
// 1. state, mamba: [batch_size, dstate, dim] or host [1] containing only pointer for paged_state
|
||||
// mamba2: [batch_size, nheads, dstate, dim] or host [1] containing only pointer for paged_state
|
||||
// 2. delta, mamba: [batch_size, seq_len, dim] or [num_tokens, dim] for remove_input_padding
|
||||
// mamba2: [batch_size, seq_len, nheads] or [num_tokens, nheads] for remove_input_padding
|
||||
// 3. delta_bias, [dim] for mamba, [nheads] for mamba2
|
||||
// 4. A, [dstate, dim] for mamba, [nheads] for mamba2
|
||||
// 5. BC, mamba: [batch_size, seq_len, dstate * 2] or [num_tokens, dstate * 2] for remove_input_padding
|
||||
// mamba2: [batch_size, seq_len, ngroups * dstate * 2] or [num_tokens, ngroups * dstate * 2] for
|
||||
// remove_input_padding
|
||||
// 6. D, [dim] for mamba, [nheads] for mamba2
|
||||
// 7. host_request_types [batch_size] int32. 0: context; 1: generation; 2: none.
|
||||
// 8. last_token_ids [batch_size] int32
|
||||
// 9. host_context_lengths [batch_size] int32, optional for remove_input_padding
|
||||
// 10. state_slot_mapping [batch_size] int32, optional for paged state
|
||||
// 11. z [batch_size, seq_len, dim] or [num_tokens, dim] for remove_input_padding
|
||||
// outputs
|
||||
// 0. output_tensor [batch_size, seq_len, dim] or [num_tokens, dim] for remove_input_padding
|
||||
// 1. state [batch_size, dstate, dim]
|
||||
// 1. state, [batch_size, dstate, dim] for mamba, [batch_size, nheads, dstate, dim] for mamba2
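As a concrete, made-up example of the mamba2 layout without remove_input_padding, take batch_size = 2, seq_len = 16, nheads = 8, dim = 512 (so 64 channels per head, which is what "dim" appears to mean in the state shape), ngroups = 1, dstate = 128:

    // input_tensor  : [2, 16, 512]       delta : [2, 16, 8]       delta_bias, A, D : [8]
    // BC            : [2, 16, 1*128*2]   state : [2, 8, 128, 64]
    // output_tensor : [2, 16, 512]       output state : [2, 8, 128, 64]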
|
||||
|
||||
class SelectiveScanPlugin : public BasePlugin
|
||||
{
|
||||
public:
|
||||
SelectiveScanPlugin(int dim, int dstate, int dt_rank, bool isVariableB, bool isVariableC, bool deltaSoftplus,
|
||||
nvinfer1::DataType type, bool removePadding, bool pagedState);
|
||||
SelectiveScanPlugin(int dim, int dstate, int dtRank, int nHeads, int nGroups, int chunkSize, bool deltaSoftplus,
|
||||
nvinfer1::DataType type, bool removePadding, bool pagedState, bool zEnabled, bool isMamba2);
|
||||
|
||||
SelectiveScanPlugin(void const* data, size_t length);
|
||||
|
||||
@ -128,45 +133,63 @@ private:
|
||||
return 6;
|
||||
};
|
||||
|
||||
IndexType getZIdx() const
|
||||
IndexType getHostRequestTypesIdx() const
|
||||
{
|
||||
return 7;
|
||||
};
|
||||
|
||||
IndexType getHostRequestTypesIdx() const
|
||||
IndexType getLastTokenIdsIdx() const
|
||||
{
|
||||
return 8;
|
||||
};
|
||||
|
||||
IndexType getLastTokenIdsIdx() const
|
||||
IndexType getHostContextLengthIdx() const
|
||||
{
|
||||
return 9;
|
||||
if (mRemovePadding)
|
||||
return 9;
|
||||
else
|
||||
return 8;
|
||||
};
|
||||
|
||||
IndexType getSlotMappingIdx() const
|
||||
{
|
||||
return 10;
|
||||
if (mPagedState)
|
||||
return getHostContextLengthIdx() + 1;
|
||||
else
|
||||
return getHostContextLengthIdx();
|
||||
};
|
||||
|
||||
IndexType getZIdx() const
|
||||
{
|
||||
if (mZEnabled)
|
||||
return getSlotMappingIdx() + 1;
|
||||
else
|
||||
return getSlotMappingIdx();
|
||||
};
|
||||
|
||||
void setSSMParams(tensorrt_llm::kernels::SSMParamsBase& params,
|
||||
// sizes
|
||||
const size_t batch, const size_t dim, const size_t maxSeqLen, const size_t dstate, const size_t dtRank,
|
||||
bool const isVariableB, bool const isVariableC,
|
||||
const size_t nHeads, const size_t nGroups, const size_t chunkSize,
|
||||
// device pointers
|
||||
void* statePtr, void const* x, void const* delta, void const* deltaBias, void const* A, void const* BC,
|
||||
void const* D, void const* z, int const* lastTokenIds, int const* slotMapping, void* out, bool deltaSoftplus,
|
||||
void const* D, void const* z, void const* osPtr, void const* stPtr, void const* dcPtr, void const* dAPtr,
|
||||
void const* cbPtr, int const* lastTokenIds, int const* slotMapping, void* out, bool deltaSoftplus,
|
||||
bool removePadding);
|
||||
|
||||
private:
|
||||
int mDim;
|
||||
int mDState;
|
||||
int mDtRank;
|
||||
bool mIsVariableB;
|
||||
bool mIsVariableC;
|
||||
int mNHeads;
|
||||
int mNGroups;
|
||||
int mChunkSize;
|
||||
bool mDeltaSoftplus;
|
||||
nvinfer1::DataType mType;
|
||||
bool mRemovePadding = false;
|
||||
bool mPagedState = false;
|
||||
bool mZEnabled = true;
|
||||
bool mIsMamba2 = false;
|
||||
};
|
||||
|
||||
class SelectiveScanPluginCreator : public BaseCreator
|
||||
|
||||
@ -300,6 +300,9 @@ int WeightOnlyQuantMatmulPlugin::enqueue(nvinfer1::PluginTensorDesc const* input
|
||||
int const n = TLLM_INT32_CAST(inputDesc[1].dims.d[1]);
|
||||
int const k = TLLM_INT32_CAST(inputDesc[0].dims.d[inputDesc[0].dims.nbDims - 1]);
|
||||
|
||||
if (m == 0)
|
||||
return 0;
|
||||
|
||||
bool const use_cuda_kernel = m < SMALL_M_FAST_PATH && mCudaKernelEnabled;
|
||||
#if defined(ENABLE_BF16)
|
||||
TLLM_CHECK_WITH_INFO(mType == nvinfer1::DataType::kHALF || mType == nvinfer1::DataType::kBF16,
|
||||
|
||||
@ -69,9 +69,9 @@ std::shared_ptr<tb::InferenceRequest> InferenceRequest::toTrtLlm() const
|
||||
auto inferenceRequest = std::make_shared<tb::InferenceRequest>(std::move(tensorMap), mRequestId);
|
||||
inferenceRequest->setIsStreaming(isStreaming());
|
||||
|
||||
if (mlogitsPostProcessor)
|
||||
if (mLogitsPostProcessor)
|
||||
{
|
||||
inferenceRequest->setLogitsPostProcessor(LlmRequest::callbackAdapter(mlogitsPostProcessor));
|
||||
inferenceRequest->setLogitsPostProcessor(LlmRequest::callbackAdapter(mLogitsPostProcessor));
|
||||
}
|
||||
|
||||
return inferenceRequest;
|
||||
@ -79,7 +79,7 @@ std::shared_ptr<tb::InferenceRequest> InferenceRequest::toTrtLlm() const
|
||||
|
||||
std::string InferenceRequest::serialize() const
|
||||
{
|
||||
TLLM_CHECK_WITH_INFO(mlogitsPostProcessor == std::nullopt,
|
||||
TLLM_CHECK_WITH_INFO(mLogitsPostProcessor == std::nullopt,
|
||||
"Serializing InferenceRequest with logitsPostProcessor set is not supported."
|
||||
"Please set the callback after de-serialization");
|
||||
std::vector<std::int64_t> serialized{toTrtLlm()->serialize()};
|
||||
|
||||
@ -228,15 +228,14 @@ void InitBindings(pybind11::module_& m)
|
||||
std::optional<SizeType32> const&, std::optional<SizeType32> const&,
|
||||
std::optional<std::list<VecTokens>>, std::optional<std::list<VecTokens>>, std::optional<Tensor>,
|
||||
std::optional<tle::ExternalDraftTokensConfig>, std::optional<tle::PromptTuningConfig>,
|
||||
std::optional<tle::LoraConfig>, std::optional<std::string>, std::optional<VecTokens>, bool>(),
|
||||
std::optional<tle::LoraConfig>, std::optional<std::string>, std::optional<VecTokens>>(),
|
||||
py::arg("input_token_ids"), py::arg("max_new_tokens"), py::arg("streaming") = false,
|
||||
py::arg_v("sampling_config", tle::SamplingConfig(), "SamplingConfig()"),
|
||||
py::arg_v("output_config", tle::OutputConfig(), "OutputConfig()"), py::arg("end_id") = py::none(),
|
||||
py::arg("pad_id") = py::none(), py::arg("bad_words") = py::none(), py::arg("stop_words") = py::none(),
|
||||
py::arg("embedding_bias") = py::none(), py::arg("external_draft_tokens_config") = py::none(),
|
||||
py::arg("prompt_tuning_config") = py::none(), py::arg("lora_config") = py::none(),
|
||||
py::arg("logits_post_processor_name") = py::none(), py::arg("encoder_input_token_ids") = py::none(),
|
||||
py::arg("return_all_generated_tokens") = false)
|
||||
py::arg("logits_post_processor_name") = py::none(), py::arg("encoder_input_token_ids") = py::none())
|
||||
.def_property_readonly("input_token_ids", &tle::Request::getInputTokenIds)
|
||||
.def_property_readonly("max_new_tokens", &tle::Request::getMaxNewTokens)
|
||||
.def_property("streaming", &tle::Request::getStreaming, &tle::Request::setStreaming)
|
||||
@ -255,9 +254,7 @@ void InitBindings(pybind11::module_& m)
|
||||
.def_property("logits_post_processor_name", &tle::Request::getLogitsPostProcessorName,
|
||||
&tle::Request::setLogitsPostProcessorName)
|
||||
.def_property(
|
||||
"encoder_input_token_ids", &tle::Request::getEncoderInputTokenIds, &tle::Request::setEncoderInputTokenIds)
|
||||
.def_property("return_all_generated_tokens", &tle::Request::getReturnAllGeneratedTokens,
|
||||
&tle::Request::setReturnAllGeneratedTokens);
|
||||
"encoder_input_token_ids", &tle::Request::getEncoderInputTokenIds, &tle::Request::setEncoderInputTokenIds);
|
||||
request.attr("BATCHED_POST_PROCESSOR_NAME") = tle::Request::kBatchedPostProcessorName;
|
||||
|
||||
py::class_<tle::Result>(m, "Result")
|
||||
|
||||
@ -46,8 +46,8 @@ FieldType parseJsonFieldOr(Json const& json, std::string_view name, FieldType de
|
||||
}
|
||||
catch (nlohmann::json::out_of_range& e)
|
||||
{
|
||||
TLLM_LOG_INFO("Parameter %s cannot be read from json:", std::string(name).c_str());
|
||||
TLLM_LOG_INFO(e.what());
|
||||
TLLM_LOG_DEBUG("Parameter %s cannot be read from json:", std::string(name).c_str());
|
||||
TLLM_LOG_DEBUG(e.what());
|
||||
}
|
||||
return value;
|
||||
}
|
||||
@ -62,13 +62,13 @@ std::optional<FieldType> parseJsonFieldOptional(Json const& json, std::string_vi
|
||||
}
|
||||
catch (nlohmann::json::out_of_range const& e)
|
||||
{
|
||||
TLLM_LOG_INFO(e.what());
|
||||
TLLM_LOG_INFO("Optional value for parameter %s will not be set.", std::string(name).c_str());
|
||||
TLLM_LOG_DEBUG(e.what());
|
||||
TLLM_LOG_DEBUG("Optional value for parameter %s will not be set.", std::string(name).c_str());
|
||||
}
|
||||
catch (nlohmann::json::type_error const& e)
|
||||
{
|
||||
TLLM_LOG_INFO(e.what());
|
||||
TLLM_LOG_INFO("Optional value for parameter %s will not be set.", std::string(name).c_str());
|
||||
TLLM_LOG_DEBUG(e.what());
|
||||
TLLM_LOG_DEBUG("Optional value for parameter %s will not be set.", std::string(name).c_str());
|
||||
}
|
||||
return value;
|
||||
}
|
||||
@ -427,10 +427,17 @@ GptJsonConfig parseJson(InputType&& input)
|
||||
auto const& stateSize = pretrainedConfig.at("state_size").template get<SizeType32>();
|
||||
auto const& convKernel = pretrainedConfig.at("conv_kernel").template get<SizeType32>();
|
||||
auto const& rnnHiddenSize = pretrainedConfig.at("rnn_hidden_size").template get<SizeType32>();
|
||||
auto const& rnnConvDimSize = pretrainedConfig.at("rnn_conv_dim_size").template get<SizeType32>();
|
||||
ModelConfig::RnnConfig rnnConfig{};
|
||||
rnnConfig.stateSize = stateSize;
|
||||
rnnConfig.convKernel = convKernel;
|
||||
rnnConfig.rnnHiddenSize = rnnHiddenSize;
|
||||
rnnConfig.rnnConvDimSize = rnnConvDimSize;
|
||||
if (pretrainedConfig.contains("rnn_head_size"))
|
||||
{
|
||||
auto const& rnnHeadSize = pretrainedConfig.at("rnn_head_size").template get<SizeType32>();
|
||||
rnnConfig.rnnHeadSize = rnnHeadSize;
|
||||
}
|
||||
modelConfig.setRnnConfig(rnnConfig);
|
||||
}
|
||||
}
|
||||
@ -449,10 +456,17 @@ GptJsonConfig parseJson(InputType&& input)
|
||||
auto const& stateSize = builderConfig.at("state_size").template get<SizeType32>();
|
||||
auto const& convKernel = builderConfig.at("conv_kernel").template get<SizeType32>();
|
||||
auto const& rnnHiddenSize = builderConfig.at("rnn_hidden_size").template get<SizeType32>();
|
||||
auto const& rnnConvDimSize = builderConfig.at("rnn_conv_dim_size").template get<SizeType32>();
|
||||
ModelConfig::RnnConfig rnnConfig{};
|
||||
rnnConfig.stateSize = stateSize;
|
||||
rnnConfig.convKernel = convKernel;
|
||||
rnnConfig.rnnHiddenSize = rnnHiddenSize;
|
||||
rnnConfig.rnnConvDimSize = rnnConvDimSize;
|
||||
if (builderConfig.contains("rnn_head_size"))
|
||||
{
|
||||
auto const& rnnHeadSize = builderConfig.at("rnn_head_size").template get<SizeType32>();
|
||||
rnnConfig.rnnHeadSize = rnnHeadSize;
|
||||
}
|
||||
modelConfig.setRnnConfig(rnnConfig);
|
||||
}
|
||||
}
|
||||
|
||||
@ -37,7 +37,7 @@ RnnStateBuffers::RnnStateBuffers(
|
||||
{
|
||||
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
|
||||
TLLM_CHECK(modelConfig.isRnnBased());
|
||||
TLLM_CHECK_WITH_INFO(modelConfig.hasRnnConfig(), "RNN only support Mamba now.");
|
||||
TLLM_CHECK_WITH_INFO(modelConfig.hasRnnConfig(), "RNN only support Mamba1/Mamba2/RecurrentGemma now.");
|
||||
auto maxBatchSize = modelConfig.getMaxBatchSize();
|
||||
auto maxBeamWidth = modelConfig.getMaxBeamWidth();
|
||||
auto maxBatchBeam = maxBatchSize * maxBeamWidth;
|
||||
@ -46,23 +46,37 @@ RnnStateBuffers::RnnStateBuffers(
|
||||
mConvKernel = rnnConfig->convKernel;
|
||||
mStateSize = rnnConfig->stateSize;
|
||||
mRnnHiddenSize = rnnConfig->rnnHiddenSize;
|
||||
mRnnHeadSize = rnnConfig->rnnHeadSize;
|
||||
mRnnConvDimSize = rnnConfig->rnnConvDimSize;
|
||||
auto dType = modelConfig.getDataType();
|
||||
auto const localNbLayers = modelConfig.getNbRnnLayers(worldConfig.getPipelineParallelism());
|
||||
mLocalNbLayers = localNbLayers;
|
||||
mMaxBeamWidth = maxBeamWidth;
|
||||
mUseMambaConv1dPlugin = modelConfig.useMambaConv1dPlugin();
|
||||
auto rnnStatesShape = ITensor::makeShape({localNbLayers * maxBatchBeam, mStateSize, mRnnHiddenSize});
|
||||
auto const rnnStatesShape = [&]()
|
||||
{
|
||||
if (mRnnHeadSize > 0)
|
||||
{
|
||||
return tensorrt_llm::runtime::ITensor::makeShape(
|
||||
{localNbLayers * maxBatchBeam, mRnnHiddenSize / mRnnHeadSize, mStateSize, mRnnHeadSize});
|
||||
}
|
||||
else
|
||||
{
|
||||
return tensorrt_llm::runtime::ITensor::makeShape(
|
||||
{localNbLayers * maxBatchBeam, mStateSize, mRnnHiddenSize});
|
||||
}
|
||||
}();
|
||||
auto const convStatesShape = [&]()
|
||||
{
|
||||
if (mUseMambaConv1dPlugin)
|
||||
{
|
||||
return tensorrt_llm::runtime::ITensor::makeShape(
|
||||
{localNbLayers * maxBatchBeam, mConvKernel - 1, mRnnHiddenSize});
|
||||
{localNbLayers * maxBatchBeam, mConvKernel - 1, mRnnConvDimSize});
|
||||
}
|
||||
else
|
||||
{
|
||||
return tensorrt_llm::runtime::ITensor::makeShape(
|
||||
{localNbLayers * maxBatchBeam, mRnnHiddenSize, mConvKernel - 1});
|
||||
{localNbLayers * maxBatchBeam, mRnnConvDimSize, mConvKernel - 1});
|
||||
}
|
||||
}();
|
||||
auto& bufferManager = runtime.getBufferManager();
|
||||
@ -96,18 +110,30 @@ RnnStateBuffers::RnnStateBuffers(
|
||||
void RnnStateBuffers::reshape(SizeType32 batchSize)
|
||||
{
|
||||
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
|
||||
auto rnnStatesShape = ITensor::makeShape({mLocalNbLayers * batchSize * mMaxBeamWidth, mStateSize, mRnnHiddenSize});
|
||||
auto const rnnStatesShape = [&]()
|
||||
{
|
||||
if (mRnnHeadSize > 0)
|
||||
{
|
||||
return tensorrt_llm::runtime::ITensor::makeShape(
|
||||
{mLocalNbLayers * batchSize * mMaxBeamWidth, mRnnHiddenSize / mRnnHeadSize, mStateSize, mRnnHeadSize});
|
||||
}
|
||||
else
|
||||
{
|
||||
return tensorrt_llm::runtime::ITensor::makeShape(
|
||||
{mLocalNbLayers * batchSize * mMaxBeamWidth, mStateSize, mRnnHiddenSize});
|
||||
}
|
||||
}();
|
||||
auto const convStatesShape = [&]()
|
||||
{
|
||||
if (mUseMambaConv1dPlugin)
|
||||
{
|
||||
return tensorrt_llm::runtime::ITensor::makeShape(
|
||||
{mLocalNbLayers * batchSize * mMaxBeamWidth, mConvKernel - 1, mRnnHiddenSize});
|
||||
{mLocalNbLayers * batchSize * mMaxBeamWidth, mConvKernel - 1, mRnnConvDimSize});
|
||||
}
|
||||
else
|
||||
{
|
||||
return tensorrt_llm::runtime::ITensor::makeShape(
|
||||
{mLocalNbLayers * batchSize * mMaxBeamWidth, mRnnHiddenSize, mConvKernel - 1});
|
||||
{mLocalNbLayers * batchSize * mMaxBeamWidth, mRnnConvDimSize, mConvKernel - 1});
|
||||
}
|
||||
}();
|
||||
rnnStates->reshape(rnnStatesShape);
|
||||
|
||||
@ -39,7 +39,8 @@ public:
|
||||
TensorPtr convStates; // [layer_count * batch_beam, conv_kernel - 1, rnn_hidden_size]
|
||||
TensorPtr convStatesAlt; // [layer_count * batch_beam, conv_kernel - 1, rnn_hidden_size]
|
||||
|
||||
std::vector<TensorPtr> rnnState; // [batch_beam, state_size, rnn_hidden_size]
|
||||
std::vector<TensorPtr> rnnState; // [batch_beam, state_size, rnn_hidden_size] or
|
||||
// [batch_beam, num_heads, rnn_hidden_size, rnn_head_size]
|
||||
std::vector<TensorPtr> convState; // [batch_beam, conv_kernel - 1, rnn_hidden_size]
|
||||
std::vector<TensorPtr> convStateAlt; // [batch_beam, conv_kernel - 1, rnn_hidden_size]
|
||||
|
||||
@ -83,6 +84,8 @@ private:
|
||||
SizeType32 mConvKernel = 0;
|
||||
SizeType32 mStateSize = 0;
|
||||
SizeType32 mRnnHiddenSize = 0;
|
||||
SizeType32 mRnnHeadSize = 0;
|
||||
SizeType32 mRnnConvDimSize = 0;
|
||||
|
||||
int mLocalNbLayers = 0;
|
||||
int mMaxBeamWidth = 0;
|
||||
|
||||
@ -120,13 +120,27 @@ def main():
|
||||
parser.add_argument('--tp-size', type=int, default=1)
|
||||
parser.add_argument('--out-dir', type=Path, required=True)
|
||||
parser.add_argument('--num-loras', type=int, default=1)
|
||||
parser.add_argument('--num-layers', type=int, default=2)
|
||||
parser.add_argument('--adapter-size', type=int, default=8)
|
||||
parser.add_argument('--hidden-size', type=int, default=16)
|
||||
parser.add_argument('--mlp-hidden-size', type=int, default=32)
|
||||
parser.add_argument('--no-generate-cache-pages',
|
||||
action='store_true',
|
||||
default=False)
|
||||
parser.add_argument(
|
||||
'--config-ids-filter',
|
||||
type=str,
|
||||
default=None,
|
||||
help=
|
||||
"Comma separated list of ids to include. For example, use --config-ids-filter=0 for attn_qkv only."
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
num_layers = 2
|
||||
adapter_size = 8
|
||||
hidden_size = 16
|
||||
mlp_hidden_size = 32
|
||||
num_layers = args.num_layers
|
||||
adapter_size = args.adapter_size
|
||||
hidden_size = args.hidden_size
|
||||
mlp_hidden_size = args.mlp_hidden_size
|
||||
configs = [
|
||||
(0, num_layers, adapter_size, hidden_size, 3 * hidden_size), # attn_qkv
|
||||
(1, num_layers, adapter_size // 2, hidden_size, hidden_size), # attn_q
|
||||
@ -149,6 +163,9 @@ def main():
|
||||
(12, num_layers, adapter_size, hidden_size,
|
||||
hidden_size), # cross_attn_dense
|
||||
]
|
||||
if args.config_ids_filter:
|
||||
config_ids_filter = [int(x) for x in args.config_ids_filter.split(",")]
|
||||
configs = [c for c in configs if c[0] in config_ids_filter]
|
||||
|
||||
for lora_idx in range(args.num_loras):
|
||||
all_source = []
|
||||
@ -178,19 +195,20 @@ def main():
|
||||
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
# copy weights into cache pages
|
||||
for rank in range(args.tp_size):
|
||||
page_block = torch.zeros((8, 18, 128),
|
||||
dtype=torch.float32,
|
||||
device='cpu')
|
||||
copy_to_cache_pages(all_source,
|
||||
all_config,
|
||||
page_block,
|
||||
configs,
|
||||
tp_rank=rank,
|
||||
tp_size=args.tp_size)
|
||||
if not args.no_generate_cache_pages:
|
||||
for rank in range(args.tp_size):
|
||||
page_block = torch.zeros((8, 18, 128),
|
||||
dtype=torch.float32,
|
||||
device='cpu')
|
||||
copy_to_cache_pages(all_source,
|
||||
all_config,
|
||||
page_block,
|
||||
configs,
|
||||
tp_rank=rank,
|
||||
tp_size=args.tp_size)
|
||||
|
||||
out_path = output_dir / f'cache_pages_rank{rank}.npy'
|
||||
np.save(out_path, page_block)
|
||||
out_path = output_dir / f'cache_pages_rank{rank}.npy'
|
||||
np.save(out_path, page_block)
|
||||
|
||||
source_out_path = output_dir / 'source.npy'
|
||||
config_out_path = output_dir / 'config.npy'
|
||||
|
||||
@ -15,6 +15,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import argparse as _arg
|
||||
import copy
|
||||
import logging as _log
|
||||
import os as _os
|
||||
import pathlib as _pl
|
||||
@ -135,9 +136,18 @@ def run_tests(build_dir: _pl.Path,
|
||||
"--num-loras=128",
|
||||
]
|
||||
|
||||
generate_gpt2_lora_data_args_tp1 = [
|
||||
python_exe,
|
||||
str(resources_dir / "scripts" / "generate_test_lora_weights.py"),
|
||||
"--out-dir=cpp/tests/resources/data/lora-test-weights-gpt2-tp1",
|
||||
"--tp-size=1", "--hidden-size=768", "--num-layers=12",
|
||||
"--config-ids-filter=0", "--no-generate-cache-pages"
|
||||
]
|
||||
|
||||
run_command(generate_lora_data_args_tp1, cwd=root_dir, timeout=100)
|
||||
run_command(generate_lora_data_args_tp2, cwd=root_dir, timeout=100)
|
||||
run_command(generate_multi_lora_tp2_args, cwd=root_dir, timeout=100)
|
||||
run_command(generate_gpt2_lora_data_args_tp1, cwd=root_dir, timeout=100)
|
||||
|
||||
if not skip_unit_tests:
|
||||
run_unit_tests(build_dir=build_dir, timeout=test_timeout)
|
||||
@ -484,9 +494,15 @@ def run_multi_gpu_tests(build_dir: _pl.Path, timeout=1500):
|
||||
]
|
||||
run_command(trt_model_test, cwd=tests_dir, env=cpp_env,
|
||||
timeout=timeout) # expecting ~ 1200s
|
||||
cpp_blocking_env = copy.copy(cpp_env)
|
||||
cpp_blocking_env["CUDA_LAUNCH_BLOCKING"] = '1'
|
||||
run_command(trt_model_test,
|
||||
cwd=tests_dir,
|
||||
env=cpp_blocking_env,
|
||||
timeout=timeout) # expecting ~ 1200s
|
||||
|
||||
# Executor test in leader mode
|
||||
new_env = cpp_env
|
||||
new_env = copy.copy(cpp_env)
|
||||
xml_output_file = build_dir / "results-multi-gpu-llama-exec-leader-mode.xml"
|
||||
new_env["RUN_LLAMA_MULTI_GPU"] = "true"
|
||||
trt_model_test = [
|
||||
@ -507,7 +523,7 @@ def run_multi_gpu_tests(build_dir: _pl.Path, timeout=1500):
|
||||
run_command(trt_model_test, cwd=tests_dir, env=new_env, timeout=1500)
|
||||
|
||||
# EncDec test in leader mode
|
||||
new_env = cpp_env
|
||||
new_env = copy.copy(cpp_env)
|
||||
xml_output_file = build_dir / "results-multi-gpu-t5-exec-leader-mode.xml"
|
||||
trt_model_test = [
|
||||
"mpirun", "-n", "4", "--allow-run-as-root", "executor/executorTest",
|
||||
|
||||
@ -52,7 +52,7 @@ The `awaitResponses` method of the `Executor` class returns a vector of response
|
||||
|
||||
### The Result Class
|
||||
|
||||
The `Result` class holds the result for a given request. It contains a Boolean parameter called `isFinal` that indicates if this is the last `Result` that will be returned for the given request id. It also contains the generated tokens. If the request is configured with `streaming = false`, the `isFinal` Boolean will be set to `true` and all generated tokens will be included in the `outputTokenIds`. If `streaming = false` is used, a `Result` will only include 1 token and the `isFinal` flag will be set to `true` for the last result associated with this request.
|
||||
The `Result` class holds the result for a given request. It contains a Boolean parameter called `isFinal` that indicates if this is the last `Result` that will be returned for the given request id. It also contains the generated tokens. If the request is configured with `streaming = false`, the `isFinal` Boolean will be set to `true` and all generated tokens will be included in the `outputTokenIds`. If `streaming = true` is used, a `Result` will only include 1 token and the `isFinal` flag will be set to `true` for the last result associated with this request.
|
||||
|
||||
## C++ Executor API Example
|
||||
|
||||
|
||||
@ -23,7 +23,7 @@ git clone https://huggingface.co/meta-llama/Llama-2-7b-chat-hf
|
||||
(quick-start-guide-compile)=
|
||||
## Compile the Model into a TensorRT Engine
|
||||
|
||||
Use the included [Llama model definition](https://nvidia.github.io/TensorRT-LLM/_modules/tensorrt_llm/models/llama/model.html#LLaMAModel). This is a minimal example that includes some of the optimizations available in TensorRT-LLM.
|
||||
Use the included [Llama model definition](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/llama). This is a minimal example that includes some of the optimizations available in TensorRT-LLM.
|
||||
|
||||
```bash
|
||||
# Launch the TensorRT-LLM container
|
||||
@ -138,3 +138,7 @@ In this Quick Start Guide, you:
|
||||
For more examples, refer to:
|
||||
|
||||
- [examples/](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples) for examples of how to run quick benchmarks on the latest LLMs.
|
||||
|
||||
## Links
|
||||
- [Best Practices Guide](https://github.com/NVIDIA/TensorRT-LLM/blob/main/docs/source/performance/perf-best-practices.md)
|
||||
- [Support Matrix](https://nvidia.github.io/TensorRT-LLM/reference/support-matrix.html)
|
||||
|
||||
@ -83,7 +83,7 @@ The following table shows the supported software for TensorRT-LLM.
|
||||
- [mT5](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/enc_dec)
|
||||
- [OPT](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/opt)
|
||||
- [Phi-1.5/Phi-2/Phi-3](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/phi)
|
||||
- [Qwen](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/qwen)
|
||||
- [Qwen/Qwen1.5/Qwen2](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/qwen)
|
||||
- [Qwen-VL](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/qwenvl)
|
||||
- [RecurrentGemma](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/recurrentgemma)
|
||||
- [Replit Code](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/mpt)
|
||||
@ -103,6 +103,7 @@ The following table shows the supported software for TensorRT-LLM.
|
||||
- [Fuyu](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/multimodal)
|
||||
- [Kosmos](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/multimodal)
|
||||
- [LLaVA-v1.5](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/multimodal)
|
||||
- [LLaVa-Next](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/multimodal)
|
||||
- [NeVA](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/multimodal)
|
||||
- [Nougat](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/multimodal)
|
||||
- [Phi-3-vision](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/multimodal)
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
--extra-index-url https://pypi.nvidia.com
|
||||
tensorrt_llm==0.12.0.dev2024070200
|
||||
tensorrt_llm==0.12.0.dev2024070900
|
||||
datasets~=2.15.0
|
||||
evaluate~=0.4.1
|
||||
rouge_score~=0.1.2
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
--extra-index-url https://pypi.nvidia.com
|
||||
tensorrt_llm==0.12.0.dev2024070200
|
||||
tensorrt_llm==0.12.0.dev2024070900
|
||||
datasets~=2.14.5
|
||||
evaluate~=0.4.1
|
||||
rouge_score~=0.1.2
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
--extra-index-url https://pypi.nvidia.com
|
||||
tensorrt_llm==0.12.0.dev2024070200
|
||||
tensorrt_llm==0.12.0.dev2024070900
|
||||
datasets~=2.14.5
|
||||
evaluate~=0.4.1
|
||||
protobuf
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
--extra-index-url https://pypi.nvidia.com
|
||||
tensorrt_llm==0.12.0.dev2024070200
|
||||
tensorrt_llm==0.12.0.dev2024070900
|
||||
datasets~=2.14.5
|
||||
evaluate~=0.4.1
|
||||
rouge_score~=0.1.2
|
||||
|
||||
@ -29,8 +29,10 @@ from transformers import (AutoConfig, AutoModelForSeq2SeqLM, AutoTokenizer,
|
||||
|
||||
import tensorrt_llm
|
||||
from tensorrt_llm import logger
|
||||
from tensorrt_llm._ipc_utils import set_peer_access
|
||||
from tensorrt_llm._utils import torch_to_numpy, trt_dtype_to_torch
|
||||
from tensorrt_llm.lora_manager import LoraManager
|
||||
from tensorrt_llm.plugin.plugin import CustomAllReduceHelper
|
||||
from tensorrt_llm.runtime import ModelConfig, SamplingConfig
|
||||
|
||||
|
||||
@ -387,19 +389,27 @@ class TRTLLMEncDecModel:
|
||||
(max_input_length, ),
|
||||
dtype=hidden_states_dtype('max_input_length'),
|
||||
device=self.device).contiguous()
|
||||
batch_size = input_lengths.size(0)
|
||||
inputs['host_request_types'] = torch.IntTensor([0] *
|
||||
batch_size).to('cpu')
|
||||
if self.encoder_model_config.remove_input_padding:
|
||||
inputs['host_context_lengths'] = input_lengths.to('cpu')
|
||||
|
||||
if self.encoder_model_config.lora_plugin and self.encoder_lora_manager is not None:
|
||||
if self.encoder_model_config.use_custom_all_reduce and self.encoder_runtime_mapping.tp_size > 1:
|
||||
set_peer_access(self.encoder_runtime_mapping)
|
||||
ipc_buffers, all_reduce_workspace = CustomAllReduceHelper.allocate_workspace(
|
||||
self.encoder_runtime_mapping,
|
||||
CustomAllReduceHelper.max_workspace_size_auto(
|
||||
self.encoder_runtime_mapping.tp_size))
|
||||
inputs['all_reduce_workspace'] = all_reduce_workspace
|
||||
|
||||
if self.encoder_model_config.lora_plugin:
|
||||
inputs.update(
|
||||
self.encoder_lora_manager.input_buffers(
|
||||
self.lora_task_uids,
|
||||
self.encoder_runtime_mapping,
|
||||
self.encoder_model_config.num_layers,
|
||||
))
|
||||
batch_size = input_lengths.size(0)
|
||||
inputs['host_request_types'] = torch.IntTensor([0] *
|
||||
batch_size).to('cpu')
|
||||
if self.encoder_model_config.remove_input_padding:
|
||||
inputs['host_context_lengths'] = input_lengths.to('cpu')
|
||||
|
||||
# Note: runtime.Session's run() method will set the input/output tensor addresses; here we only need to provide the tensor shapes
|
||||
self.encoder_session.set_shapes(inputs)
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
--extra-index-url https://pypi.nvidia.com
|
||||
tensorrt_llm==0.12.0.dev2024070200
|
||||
tensorrt_llm==0.12.0.dev2024070900
|
||||
transformers>=4.31.0
|
||||
datasets~=2.14.5
|
||||
evaluate~=0.4.1
|
||||
|
||||
@ -3,7 +3,7 @@
|
||||
# WAR the new posting of "nvidia-cudnn-cu12~=9.0".
|
||||
# "jax[cuda12_pip]~=0.4.19" specifies "nvidia-cudnn-cu12>=8.9" but actually requires "nvidia-cudnn-cu12~=8.9".
|
||||
nvidia-cudnn-cu12~=8.9; platform_machine == "x86_64"
|
||||
tensorrt_llm==0.12.0.dev2024070200
|
||||
tensorrt_llm==0.12.0.dev2024070900
|
||||
flax~=0.8.0
|
||||
# jax[cuda12_pip]~=0.4.19; platform_system != "Windows"
|
||||
jax~=0.4.19; platform_system == "Windows"
|
||||
|
||||
@ -414,6 +414,21 @@ trtllm-build --checkpoint_dir gpt2/trt_ckpt/int8-sq-ptpc/1-gpu \
|
||||
|
||||
Note that the GPT attention plugin must be enabled for SmoothQuant for now.
|
||||
|
||||
You can also use `ModelOpt` to do INT8 quantization, especially for the GPT variant Starcoder2.
|
||||
```bash
|
||||
python3 example/quantization/quantize.py --model_dir starcoder2 \
|
||||
--dtype float16 \
|
||||
--qformat int8_sq \
|
||||
--output_dir starcoder2/trt_ckpt/int8-sq/
|
||||
```
|
||||
Then, use `trtllm-build` to build engine(s).
|
||||
|
||||
```bash
|
||||
trtllm-build --checkpoint_dir starcoder2/trt_ckpt/int8-sq/ \
|
||||
--output_dir starcoder2/trt_engine/int8-sq/ \
|
||||
--builder_opt 4
|
||||
```
|
||||
|
||||
|
||||
### INT8 KV Cache
|
||||
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
--extra-index-url https://pypi.nvidia.com
|
||||
tensorrt_llm==0.12.0.dev2024070200
|
||||
tensorrt_llm==0.12.0.dev2024070900
|
||||
datasets~=2.14.5
|
||||
evaluate~=0.4.1
|
||||
rouge_score~=0.1.2
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
--extra-index-url https://pypi.nvidia.com
|
||||
tensorrt_llm==0.12.0.dev2024070200
|
||||
tensorrt_llm==0.12.0.dev2024070900
|
||||
datasets~=2.14.5
|
||||
evaluate~=0.4.1
|
||||
rouge_score~=0.1.2
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
--extra-index-url https://pypi.nvidia.com
|
||||
tensorrt_llm==0.12.0.dev2024070200
|
||||
tensorrt_llm==0.12.0.dev2024070900
|
||||
datasets~=2.14.5
|
||||
rouge_score~=0.1.2
|
||||
evaluate~=0.4.1
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
-f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
|
||||
--extra-index-url https://pypi.nvidia.com
|
||||
tensorrt_llm==0.12.0.dev2024070200
|
||||
tensorrt_llm==0.12.0.dev2024070900
|
||||
datasets==2.14.6
|
||||
evaluate~=0.4.1
|
||||
rouge_score~=0.1.2
|
||||
|
||||
@ -4,12 +4,49 @@ Here we show you a preview of how it works and how to use it.
|
||||
|
||||
Note that the APIs are not stable and only support the LLaMA model. We appreciate your patience and understanding as we improve this API.
|
||||
|
||||
## Quick start
|
||||
|
||||
Please install the required packages first:
|
||||
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
Here is a simple example to show how to use the HLAPI:
|
||||
|
||||
First, import `LLM` and `SamplingParams` from the `tensorrt_llm` package and create an `LLM` object directly from a HuggingFace (HF) model. Here we use the TinyLlama model as an example; `LLM` will download the model from the HuggingFace model hub automatically. You can also specify local models, either in HF format, TensorRT-LLM engine format, or TensorRT-LLM checkpoint format.
|
||||
|
||||
```python
|
||||
from tensorrt_llm import LLM, SamplingParams
|
||||
|
||||
llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0")
|
||||
```
|
||||
|
||||
Second, generate text with the `generate` method of the `LLM` object, passing a batch of prompts directly. The `sampling_params` argument is optional, and you can use it to customize the sampling strategy.
|
||||
|
||||
```python
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
"The president of the United States is",
|
||||
"The capital of France is",
|
||||
"The future of AI is",
|
||||
]
|
||||
|
||||
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
|
||||
|
||||
outputs = llm.generate(prompts, sampling_params)
|
||||
|
||||
# Print the outputs.
|
||||
for output in outputs:
|
||||
prompt = output.prompt
|
||||
generated_text = output.outputs[0].text
|
||||
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
||||
```
|
||||
|
||||
Please refer to the [LLM quickstart](./quickstart_example.py) for the complete example.
|
||||
|
||||
## Examples
|
||||
|
||||
You can refer to [llm_examples.py](llm_examples.py) for all of the examples and run them with the [run_examples.py](./run_examples.py) script; the command is as follows:
|
||||
|
||||
```sh
|
||||
@ -34,15 +71,16 @@ python3 llm_examples.py --task run_llm_on_tensor_parallel \
|
||||
```
|
||||
|
||||
## Model preparation
|
||||
The HLAPI supports three kinds of model formats:
|
||||
The `LLM` class supports four kinds of model inputs:
|
||||
|
||||
1. HuggingFace models
|
||||
2. TensorRT-LLM engine built by trtllm-build tool or saved by the HLAPI
|
||||
3. TensorRT-LLM checkpoints, converted by `convert_checkpoint.py` in examples
|
||||
1. **HuggingFace model name**: triggers a download from the HuggingFace model hub, e.g. `TinyLlama/TinyLlama-1.1B-Chat-v1.0` in the quickstart.
|
||||
2. **Local HuggingFace model**: uses a locally stored HuggingFace model.
|
||||
3. **Local TensorRT-LLM engine**: built by the `trtllm-build` tool or saved by the HLAPI.
|
||||
4. **Local TensorRT-LLM checkpoints**: converted by the `convert_checkpoint.py` script in the examples.
|
||||
|
||||
All kinds of models could be used directly by the HLAPI, and the `LLM(model=<any-model-path>)` could accept any kind of them.
|
||||
All of these model inputs can be used seamlessly with the HLAPI, and the `LLM(model=<any-model-path>)` constructor can accommodate models in any of the above formats.
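For instance, here is a minimal sketch of passing each kind of input to the same constructor; each call below is an alternative way to construct the same `LLM` object, and the local paths are hypothetical placeholders rather than paths shipped with the examples:

```python
from tensorrt_llm import LLM

# 1. HuggingFace model name: downloaded from the HF hub automatically.
llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0")

# 2. Local HuggingFace model directory (placeholder path).
llm = LLM(model="./models/TinyLlama-1.1B-Chat-v1.0")

# 3. Local TensorRT-LLM engine directory, e.g. built by trtllm-build or saved by the HLAPI (placeholder path).
llm = LLM(model="./engines/tinyllama/fp16/1-gpu")

# 4. Local TensorRT-LLM checkpoint directory produced by convert_checkpoint.py (placeholder path).
llm = LLM(model="./ckpts/tinyllama/fp16/1-gpu")
```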
|
||||
|
||||
Let's elaborate on the preparation of the three kinds of model formats.
|
||||
Let's delve into the preparation of the three kinds of local model formats.
|
||||
|
||||
### Option 1: From HuggingFace models
|
||||
|
||||
@ -143,16 +181,16 @@ It is easy to enable Tensor Parallelism in the HLAPI. For example, setting `para
|
||||
```python
|
||||
from tensorrt_llm.hlapi import LLM
|
||||
|
||||
llm = LLM(<llama_model_path>, tensor_parallel_size=2)
|
||||
llm = LLM(<llama_model_path>,
|
||||
tensor_parallel_size=2)
|
||||
```
|
||||
|
||||
### Pipeline Parallelism
|
||||
Similar to Tensor Parallelism, you can enable Pipeline Parallelism in the HLAPI with the following code:
|
||||
|
||||
```python
|
||||
config.parallel_config.pp_size = 4
|
||||
# you can also mix TP and PP
|
||||
# config.parallel_config.tp_size = 2
|
||||
llm = LLM(<llama_model_path>,
|
||||
pipeline_parallel_size=4)
|
||||
```
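As the comment about mixing TP and PP above suggests, the two parallelism arguments are expected to be combinable. Here is a hedged sketch of a 2x2 hybrid layout on 4 GPUs, assuming both keyword arguments can be passed together and using a placeholder model path:

```python
from tensorrt_llm.hlapi import LLM

# Hypothetical 4-GPU setup: 2-way tensor parallelism x 2-way pipeline parallelism.
llm = LLM("<llama_model_path>",        # placeholder model path
          tensor_parallel_size=2,
          pipeline_parallel_size=2)
```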
|
||||
|
||||
### Automatic Parallelism (in preview)
|
||||
@ -266,17 +304,28 @@ Please refer to these classes for more details.
|
||||
|
||||
## LLM pipeline configuration
|
||||
|
||||
### Runtime customization
|
||||
### Build configuration
|
||||
Apart from the arguments mentioned above, you can also customize the build configuration by passing a `BuildConfig` instance via the `build_config` argument, together with other arguments borrowed from the lower-level APIs. For example:
|
||||
|
||||
```python
|
||||
llm = LLM(<model-path>,
|
||||
build_config=BuildConfig(
|
||||
max_new_tokens=4096,
|
||||
max_batch_size=128,
|
||||
max_beam_width=4))
|
||||
```
|
||||
|
||||
### Runtime customization
|
||||
Similar to `build_config`, you can also customize the runtime configuration with the `runtime_config`, `peft_cache_config` or other arguments borrowed from the lower-level APIs. For example:
|
||||
|
||||
For the `kv_cache_config` and `streaming_llm` features, please refer to LLaMA's [README](../llama/README.md) for more details; the high-level API supports these features as well by setting the corresponding fields in the `LLM()` constructor.
|
||||
|
||||
```python
|
||||
from tensorrt_llm.hlapi import LLM, KvCacheConfig
|
||||
|
||||
llm = LLM(<llama_model_path>,
|
||||
kv_cache_config=KvCacheConfig(
|
||||
max_new_tokens=128,
|
||||
free_gpu_memory_fraction=0.8))
|
||||
max_new_tokens=128,
|
||||
free_gpu_memory_fraction=0.8))
|
||||
```
|
||||
|
||||
### Tokenizer customization
|
||||
@ -313,3 +362,13 @@ RequestOutput(request_id=1, prompt=None, prompt_token_ids=[1, 15043, 29892, 590,
|
||||
```
|
||||
|
||||
Note that the `text` field in `CompletionOutput` is empty since the tokenizer is deactivated.
|
||||
|
||||
### Build caching
|
||||
Although the HLAPI builds the engine in the background, you can also cache the built engine to disk and reload it on the next run to save engine-building time.
|
||||
|
||||
There are two ways to enable the build cache:
|
||||
|
||||
1. Use the environment variable: `export TLLM_HLAPI_BUILD_CACHE=1` to enable the build cache globally, and optionally export `TLLM_HLAPI_BUILD_CACHE_ROOT` to specify the cache root directory.
|
||||
2. Pass the `build_cache_config` to the `LLM` constructor
|
||||
|
||||
The build cache will reuse the built engine if all of the build settings are identical; otherwise, the engine is rebuilt.
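Here is a minimal sketch of the first option; setting the variables from Python stands in for exporting them in the shell, the cache root is a placeholder path, and the second construction only hits the cache when the build settings are identical:

```python
import os
from tensorrt_llm import LLM

# Enable the build cache globally before constructing any LLM objects.
os.environ["TLLM_HLAPI_BUILD_CACHE"] = "1"
os.environ["TLLM_HLAPI_BUILD_CACHE_ROOT"] = "/tmp/tllm_build_cache"  # optional, placeholder path

# First run: builds the engine and stores it in the cache.
llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0")

# A later run with identical build settings should reuse the cached engine instead of rebuilding it.
llm_again = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0")
```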
|
||||
|
||||
@ -1,2 +1,2 @@
|
||||
--extra-index-url https://pypi.nvidia.com
|
||||
tensorrt_llm==0.12.0.dev2024070200
|
||||
tensorrt_llm==0.12.0.dev2024070900
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
--extra-index-url https://pypi.nvidia.com
|
||||
tensorrt_llm==0.12.0.dev2024070200
|
||||
tensorrt_llm==0.12.0.dev2024070900
|
||||
datasets==2.14.5
|
||||
rouge_score~=0.1.2
|
||||
sentencepiece~=0.1.99
|
||||
|
||||
@ -440,11 +440,17 @@ expected results:
|
||||
|
||||
#### 1M long context test case
|
||||
|
||||
- Prepare 1M needle-in-a-haystack datasets
|
||||
|
||||
```bash
|
||||
python examples/infinitebench/construct_synthetic_dataset.py --test_case build_passkey --test_level 7
|
||||
```
|
||||
|
||||
- Llama-3-8B example
|
||||
|
||||
```bash
|
||||
git-lfs clone https://huggingface.co/gradientai/Llama-3-8B-Instruct-Gradient-1048k/
|
||||
|
||||
python examples/infinitebench/construct_synthetic_dataset.py --test_case build_passkey --test_level 7
|
||||
|
||||
python examples/llama/convert_checkpoint.py --model_dir ./Llama-3-8B-Instruct-Gradient-1048k/ \
|
||||
--output_dir /tmp/llama-3-8B-1048k/trt_ckpts \
|
||||
--dtype float16 \
|
||||
@ -454,8 +460,8 @@ python -m tensorrt_llm.commands.build --checkpoint_dir /tmp/llama-3-8B-1048k/trt
|
||||
--output_dir /tmp/llama-3-8B-1048k/trt_engines \
|
||||
--gemm_plugin float16 \
|
||||
--max_num_tokens 4096 \
|
||||
--max_input_len 1048576 \
|
||||
--max_output_len 10 \
|
||||
--max_input_len 1048566 \
|
||||
--max_seq_len 1048576 \
|
||||
--use_paged_context_fmha enable \
|
||||
--workers 4
|
||||
|
||||
@ -463,7 +469,37 @@ mpirun -n 4 --allow-run-as-root python examples/eval_long_context.py --task pas
|
||||
--engine_dir /tmp/llama-3-8B-1048k/trt_engines \
|
||||
--tokenizer_dir ./Llama-3-8B-Instruct-Gradient-1048k/ \
|
||||
--stop_idx 1 \
|
||||
--max_input_length 1048576 \
|
||||
--max_input_length 1048566 \
|
||||
--enable_chunked_context \
|
||||
--max_tokens_in_paged_kv_cache 1100000
|
||||
```
|
||||
|
||||
- Llama-3-70B example
|
||||
|
||||
For the 70B model, at least 8 A100 80GB GPUs are required.
|
||||
|
||||
```bash
|
||||
git-lfs clone https://huggingface.co/gradientai/Llama-3-70B-Instruct-Gradient-1048k/
|
||||
|
||||
python examples/llama/convert_checkpoint.py --model_dir ./Llama-3-70B-Instruct-Gradient-1048k/ \
|
||||
--output_dir /tmp/llama-3-70B-1048k/trt_ckpts \
|
||||
--dtype float16 \
|
||||
--tp_size 8
|
||||
|
||||
python -m tensorrt_llm.commands.build --checkpoint_dir /tmp/llama-3-70B-1048k/trt_ckpts \
|
||||
--output_dir /tmp/llama-3-70B-1048k/trt_engines \
|
||||
--gemm_plugin float16 \
|
||||
--max_num_tokens 4096 \
|
||||
--max_input_len 1048566 \
|
||||
--max_seq_len 1048576 \
|
||||
--use_paged_context_fmha enable \
|
||||
--workers 8
|
||||
|
||||
mpirun -n 8 --allow-run-as-root python examples/eval_long_context.py --task passkey \
|
||||
--engine_dir /tmp/llama-3-70B-1048k/trt_engines \
|
||||
--tokenizer_dir ./Llama-3-70B-Instruct-Gradient-1048k/ \
|
||||
--stop_idx 1 \
|
||||
--max_input_length 1048566 \
|
||||
--enable_chunked_context \
|
||||
--max_tokens_in_paged_kv_cache 1100000
|
||||
```
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
--extra-index-url https://pypi.nvidia.com
|
||||
tensorrt_llm==0.12.0.dev2024070200
|
||||
tensorrt_llm==0.12.0.dev2024070900
|
||||
datasets==2.14.6
|
||||
evaluate~=0.4.1
|
||||
rouge_score~=0.1.2
|
||||
|
||||
@ -2,6 +2,15 @@
|
||||
|
||||
This document shows how to build and run a [Mamba](https://github.com/state-spaces/mamba) model in TensorRT-LLM on a single GPU.
|
||||
|
||||
- [Mamba](#mamba)
|
||||
- [Overview](#overview)
|
||||
- [Support Matrix](#support-matrix)
|
||||
- [Usage](#usage)
|
||||
- [1. Download weights from HuggingFace Transformers](#1-download-weights-from-huggingface-transformers)
|
||||
- [2. Convert weights from HF Transformers to TensorRT-LLM format](#2-convert-dweights-from-hf-transformers-to-tensorrt-llm-format)
|
||||
- [3. Build TensorRT engine(s)](#3-build-tensorrt-engines)
|
||||
- [4. Run summarization task with the TensorRT engine(s)](#4-run-summarization-task-with-the-tensorrt-engines)
|
||||
|
||||
## Overview
|
||||
|
||||
The TensorRT-LLM Mamba implementation can be found in [`tensorrt_llm/models/mamba/model.py`](../../tensorrt_llm/models/mamba/model.py). The TensorRT-LLM Mamba example code is located in [`examples/mamba`](./). There is one main file:
|
||||
@ -15,8 +24,13 @@ In addition, there are two shared files in the parent folder [`examples`](../) f
|
||||
|
||||
|
||||
## Support Matrix
|
||||
* FP16
|
||||
* BF16
|
||||
|
||||
| Model Name | FP16 | BF16 |
|
||||
| :--------------: | :---: | :---: |
|
||||
| Mamba1 | Y | Y |
|
||||
| Mamba2 | Y | Y |
|
||||
|
||||
* Mamba2: TensorRT-LLM currently supports only the pure Mamba model; hybrid models will be supported later.
|
||||
|
||||
## Usage
|
||||
|
||||
@ -32,23 +46,20 @@ pip install -r requirements.txt
|
||||
git lfs install
|
||||
```
|
||||
|
||||
There are five HF checkpoints available. Use one of the following commands to fetch the checkpoint you are interested in.
|
||||
There are different HF checkpoints available. For Mamba1, TensorRT-LLM supports the Transformers-compatible models. Here are some example commands to fetch the checkpoints.
|
||||
|
||||
```bash
|
||||
# mamba-2.8b
|
||||
git clone https://huggingface.co/state-spaces/mamba-2.8b-hf ./mamba_model/mamba-2.8b
|
||||
|
||||
# mamba-1.4b
|
||||
git clone https://huggingface.co/state-spaces/mamba-1.4b-hf ./mamba_model/mamba-1.4b
|
||||
|
||||
# mamba-790m
|
||||
git clone https://huggingface.co/state-spaces/mamba-790m-hf ./mamba_model/mamba-790m
|
||||
|
||||
# mamba-370m
|
||||
git clone https://huggingface.co/state-spaces/mamba-370m-hf ./mamba_model/mamba-370m
|
||||
|
||||
# mamba-130m
|
||||
git clone https://huggingface.co/state-spaces/mamba-130m-hf ./mamba_model/mamba-130m
|
||||
|
||||
# mamba2-2.7b
|
||||
git clone https://huggingface.co/state-spaces/mamba2-2.7b ./mamba_model/mamba2-2.7b
|
||||
|
||||
# mamba2-130m
|
||||
git clone https://huggingface.co/state-spaces/mamba2-130m ./mamba_model/mamba2-130m
|
||||
```
|
||||
|
||||
Since Mamba models use the tokenizer from the gpt-neox-20b model, use the following command to fetch the gpt-neox-20b checkpoint.
|
||||
@ -67,25 +78,20 @@ python convert_checkpoint.py --model_dir ./mamba_model/mamba-2.8b/ \
|
||||
--dtype bfloat16 \
|
||||
--output_dir ./mamba_model/mamba-2.8b/trt_ckpt/bf16/1-gpu/
|
||||
|
||||
# mamba-1.4b
|
||||
python convert_checkpoint.py --model_dir ./mamba_model/mamba-1.4b/ \
|
||||
--dtype float16 \
|
||||
--output_dir ./mamba_model/mamba-1.4b/trt_ckpt/fp16/1-gpu/
|
||||
|
||||
# mamba-790m
|
||||
python convert_checkpoint.py --model_dir ./mamba_model/mamba-790m/ \
|
||||
--dtype float16 \
|
||||
--output_dir ./mamba_model/mamba-790m/trt_ckpt/fp16/1-gpu/
|
||||
|
||||
# mamba-370m
|
||||
python convert_checkpoint.py --model_dir ./mamba_model/mamba-370m/ \
|
||||
--dtype float16 \
|
||||
--output_dir ./mamba_model/mamba-370m/trt_ckpt/fp16/1-gpu/
|
||||
|
||||
# mamba-130m
|
||||
python convert_checkpoint.py --model_dir ./mamba_model/mamba-130m/ \
|
||||
--dtype float16 \
|
||||
--output_dir ./mamba_model/mamba-130m/trt_ckpt/fp16/1-gpu/
|
||||
|
||||
# mamba2-2.7b
|
||||
python convert_checkpoint.py --model_dir ./mamba_model/mamba2-2.7b/ \
|
||||
--dtype float16 \
|
||||
--output_dir ./mamba_model/mamba2-2.7b/trt_ckpt/fp16/1-gpu/
|
||||
|
||||
# mamba2-130m
|
||||
python convert_checkpoint.py --model_dir ./mamba_model/mamba2-130m/ \
|
||||
--dtype float16 \
|
||||
--output_dir ./mamba_model/mamba2-130m/trt_ckpt/fp16/1-gpu/
|
||||
```
|
||||
|
||||
### 3. Build TensorRT engine(s)
|
||||
@ -101,33 +107,6 @@ trtllm-build --checkpoint_dir ./mamba_model/mamba-2.8b/trt_ckpt/bf16/1-gpu/ \
|
||||
--max_seq_len 1024 \
|
||||
--output_dir ./mamba_model/mamba-2.8b/trt_engines/bf16/1-gpu/
|
||||
|
||||
# mamba-1.4b
|
||||
trtllm-build --checkpoint_dir ./mamba_model/mamba-1.4b/trt_ckpt/fp16/1-gpu/ \
|
||||
--paged_kv_cache disable \
|
||||
--gemm_plugin auto \
|
||||
--max_batch_size 8 \
|
||||
--max_input_len 924 \
|
||||
--max_seq_len 1024 \
|
||||
--output_dir ./mamba_model/mamba-1.4b/trt_engines/fp16/1-gpu/
|
||||
|
||||
# mamba-790m
|
||||
trtllm-build --checkpoint_dir ./mamba_model/mamba-790m/trt_ckpt/fp16/1-gpu/ \
|
||||
--paged_kv_cache disable \
|
||||
--gemm_plugin auto \
|
||||
--max_batch_size 8 \
|
||||
--max_input_len 924 \
|
||||
--max_seq_len 1024 \
|
||||
--output_dir ./mamba_model/mamba-790m/trt_engines/fp16/1-gpu/
|
||||
|
||||
# mamba-370m
|
||||
trtllm-build --checkpoint_dir ./mamba_model/mamba-370m/trt_ckpt/fp16/1-gpu/ \
|
||||
--paged_kv_cache disable \
|
||||
--gemm_plugin auto \
|
||||
--max_batch_size 8 \
|
||||
--max_input_len 924 \
|
||||
--max_seq_len 1024 \
|
||||
--output_dir ./mamba_model/mamba-370m/trt_engines/fp16/1-gpu/
|
||||
|
||||
# mamba-130m
|
||||
trtllm-build --checkpoint_dir ./mamba_model/mamba-130m/trt_ckpt/fp16/1-gpu/ \
|
||||
--paged_kv_cache disable \
|
||||
@ -136,6 +115,24 @@ trtllm-build --checkpoint_dir ./mamba_model/mamba-130m/trt_ckpt/fp16/1-gpu/ \
|
||||
--max_input_len 924 \
|
||||
--max_seq_len 1024 \
|
||||
--output_dir ./mamba_model/mamba-130m/trt_engines/fp16/1-gpu/
|
||||
|
||||
# mamba2-2.7b
|
||||
trtllm-build --checkpoint_dir ./mamba_model/mamba2-2.7b/trt_ckpt/fp16/1-gpu/ \
|
||||
--paged_kv_cache disable \
|
||||
--gemm_plugin auto \
|
||||
--max_batch_size 8 \
|
||||
--max_input_len 924 \
|
||||
--max_seq_len 1024 \
|
||||
--output_dir ./mamba_model/mamba2-2.7b/trt_engines/fp16/1-gpu/
|
||||
|
||||
# mamba2-130m
|
||||
trtllm-build --checkpoint_dir ./mamba_model/mamba2-130m/trt_ckpt/fp16/1-gpu/ \
|
||||
--paged_kv_cache disable \
|
||||
--gemm_plugin auto \
|
||||
--max_batch_size 8 \
|
||||
--max_input_len 924 \
|
||||
--max_seq_len 1024 \
|
||||
--output_dir ./mamba_model/mamba2-130m/trt_engines/fp16/1-gpu/
|
||||
```
|
||||
|
||||
Note that when building Mamba models, you need to disable the `paged_kv_cache` as it is used for
|
||||
@ -148,7 +145,6 @@ The following section describes how to run a TensorRT-LLM Mamba model to summari
|
||||
[cnn_dailymail](https://huggingface.co/datasets/cnn_dailymail) dataset. For each summary, the script can compute the
|
||||
[ROUGE](https://en.wikipedia.org/wiki/ROUGE_(metric)) scores and use the `ROUGE-1` score to validate the implementation.
|
||||
|
||||
### Run
|
||||
```bash
|
||||
# mamba-2.8b
|
||||
python ../summarize.py --test_trt_llm \
|
||||
@ -157,31 +153,24 @@ python ../summarize.py --test_trt_llm \
|
||||
--data_type bf16 \
|
||||
--engine_dir ./mamba_model/mamba-2.8b/trt_engines/bf16/1-gpu/
|
||||
|
||||
# mamba-1.4b
|
||||
python ../summarize.py --test_trt_llm \
|
||||
--hf_model_dir ./mamba_model/mamba-1.4b/ \
|
||||
--tokenizer_dir ./mamba_model/gpt-neox-20b/ \
|
||||
--data_type fp16 \
|
||||
--engine_dir ./mamba_model/mamba-1.4b/trt_engines/fp16/1-gpu/
|
||||
|
||||
# mamba-790m
|
||||
python ../summarize.py --test_trt_llm \
|
||||
--hf_model_dir ./mamba_model/mamba-790m/ \
|
||||
--tokenizer_dir ./mamba_model/gpt-neox-20b/ \
|
||||
--data_type fp16 \
|
||||
--engine_dir ./mamba_model/mamba-790m/trt_engines/fp16/1-gpu/
|
||||
|
||||
# mamba-370m
|
||||
python ../summarize.py --test_trt_llm \
|
||||
--hf_model_dir ./mamba_model/mamba-370m/ \
|
||||
--tokenizer_dir ./mamba_model/gpt-neox-20b/ \
|
||||
--data_type fp16 \
|
||||
--engine_dir ./mamba_model/mamba-370m/trt_engines/fp16/1-gpu/
|
||||
|
||||
# mamba-130m
|
||||
python ../summarize.py --test_trt_llm \
|
||||
--hf_model_dir ./mamba_model/mamba-130m/ \
|
||||
--tokenizer_dir ./mamba_model/gpt-neox-20b/ \
|
||||
--data_type fp16 \
|
||||
--engine_dir ./mamba_model/mamba-130m/trt_engines/fp16/1-gpu/
|
||||
|
||||
# mamba2-2.7b
|
||||
python ../summarize.py --test_trt_llm \
|
||||
--hf_model_dir ./mamba_model/mamba2-2.7b/ \
|
||||
--tokenizer_dir ./mamba_model/gpt-neox-20b/ \
|
||||
--data_type fp16 \
|
||||
--engine_dir ./mamba_model/mamba2-2.7b/trt_engines/fp16/1-gpu/
|
||||
|
||||
# mamba2-130m
|
||||
python ../summarize.py --test_trt_llm \
|
||||
--hf_model_dir ./mamba_model/mamba2-130m/ \
|
||||
--tokenizer_dir ./mamba_model/gpt-neox-20b/ \
|
||||
--data_type fp16 \
|
||||
--engine_dir ./mamba_model/mamba2-130m/trt_engines/fp16/1-gpu/
|
||||
```
|
||||
|
||||
@ -1,13 +1,17 @@
|
||||
import argparse
|
||||
import copy
|
||||
import json
|
||||
import re
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
import safetensors.torch
|
||||
import torch
|
||||
from transformers import AutoConfig, AutoModelForCausalLM
|
||||
from transformers.utils import CONFIG_NAME
|
||||
from transformers.utils.hub import cached_file
|
||||
|
||||
import tensorrt_llm
|
||||
from tensorrt_llm import logger
|
||||
@ -55,7 +59,10 @@ def get_tllm_linear_weight(weight, prefix, bias=None):
|
||||
return results
|
||||
|
||||
|
||||
def convert_hf_mamba(hf_mamba, rank=0, dtype='float32'):
|
||||
def convert_hf_mamba(hf_mamba,
|
||||
rank=0,
|
||||
dtype='float32',
|
||||
mamba_version: str = 'Mamba1'):
|
||||
weights = {}
|
||||
tik = time.time()
|
||||
|
||||
@ -130,9 +137,12 @@ def rename_hf_to_tllm(name: str):
|
||||
# change layer name
|
||||
if 'embeddings.' in name:
|
||||
name = name.replace('embeddings', 'vocab_embedding')
|
||||
elif 'embedding.' in name:
|
||||
name = name.replace('embedding', 'vocab_embedding')
|
||||
norm_pattern = r'\d\.norm\.'
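# only matches per-layer norms such as 'backbone.layers.3.norm.' (a digit followed by '.norm.')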
|
||||
if 'mixer.' in name:
|
||||
name = name.replace('mixer.', 'ssm.')
|
||||
elif 'norm.' in name:
|
||||
elif re.search(norm_pattern, name):
|
||||
name = name.replace('norm.', 'input_layernorm.')
|
||||
elif 'norm_f.' in name:
|
||||
name = name.replace('norm_f.', 'ln_f.')
|
||||
@ -147,7 +157,8 @@ def rename_hf_to_tllm(name: str):
|
||||
|
||||
def convert_from_hf_checkpoint(model_dir: Union[str, Path],
|
||||
rank=0,
|
||||
dtype: Union[str, torch.dtype] = torch.float32):
|
||||
dtype: Union[str, torch.dtype] = torch.float32,
|
||||
mamba_version: str = 'Mamba1'):
|
||||
logger.info('Loading weights from HF Mamba...')
|
||||
tik = time.time()
|
||||
|
||||
@ -164,15 +175,19 @@ def convert_from_hf_checkpoint(model_dir: Union[str, Path],
|
||||
param = param.detach().cpu()
|
||||
if 'A_log' in name:
|
||||
param = -torch.exp(param.float())
|
||||
param = param.permute(1, 0).contiguous()
|
||||
if mamba_version == 'Mamba1':
|
||||
param = param.permute(1, 0).contiguous()
|
||||
elif 'D' in name:
|
||||
param = param.float()
|
||||
elif 'dt_proj.bias' in name:
|
||||
param = param.float()
|
||||
elif 'dt_bias' in name:
|
||||
param = param.float()
|
||||
elif 'conv1d.weight' in name:
|
||||
param = param.unsqueeze(3)
|
||||
|
||||
if 'in_proj' in name:
|
||||
# split in_proj in Mamba1
|
||||
if 'in_proj' in name and mamba_version == 'Mamba1':
|
||||
in_proj_params = torch.split(param, param.size(0) // 2, dim=0)
|
||||
weights[tllm_name.replace('proj', 'proj_x')] = in_proj_params[0]
|
||||
weights[tllm_name.replace('proj', 'proj_z')] = in_proj_params[1]
|
||||
@ -181,9 +196,10 @@ def convert_from_hf_checkpoint(model_dir: Union[str, Path],
|
||||
del model_params
|
||||
|
||||
# lm_head
|
||||
if 'lm_head.weight' not in weights:
|
||||
weights['lm_head.weight'] = copy.deepcopy(
|
||||
weights['backbone.vocab_embedding.weight'])
|
||||
emb = weights['backbone.vocab_embedding.weight']
|
||||
if 'lm_head.weight' not in weights or weights['lm_head.weight'].data_ptr(
|
||||
) == emb.data_ptr():
|
||||
weights['lm_head.weight'] = copy.deepcopy(emb)
|
||||
|
||||
tok = time.time()
|
||||
t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
|
||||
@ -208,6 +224,72 @@ def convert(worker_rank, args, convert_args):
|
||||
args.output_dir / f'rank{rank}.safetensors')
|
||||
|
||||
|
||||
@dataclass
|
||||
class MambaConfig:
|
||||
|
||||
d_model: int = 2560
|
||||
d_intermediate: int = 0
|
||||
n_layer: int = 64
|
||||
vocab_size: int = 50277
|
||||
ssm_cfg: dict = field(default_factory=dict)
|
||||
attn_layer_idx: list = field(default_factory=list)
|
||||
attn_cfg: dict = field(default_factory=dict)
|
||||
rms_norm: bool = True
|
||||
residual_in_fp32: bool = True
|
||||
fused_add_norm: bool = True
|
||||
pad_vocab_size_multiple: int = 8
|
||||
tie_embeddings: bool = True
|
||||
hidden_size: int = 2560
|
||||
num_hidden_layers: int = 64
|
||||
intermediate_size: int = 0
|
||||
state_size: int = 128
|
||||
conv_kernel: int = 4
|
||||
use_bias: bool = False
|
||||
headdim: int = 64
|
||||
ngroups: int = 1
|
||||
chunk_size: int = 256
|
||||
ssm_rmsnorm: bool = True
|
||||
|
||||
def update(self, data_dict):
|
||||
self.__dict__.update(data_dict)
|
||||
|
||||
|
||||
def load_config_hf(model_name):
|
||||
resolved_archive_file = cached_file(
|
||||
model_name, CONFIG_NAME, _raise_exceptions_for_missing_entries=False)
|
||||
config = json.load(open(resolved_archive_file))
|
||||
if 'transformers_version' in config: # transformer compatible models
|
||||
hf_config = AutoConfig.from_pretrained(model_name,
|
||||
trust_remote_code=True)
|
||||
# TODO: change mamba_version when transformers can support Mamba2 models
|
||||
mamba_version = 'Mamba1'
|
||||
else: # state-spaces/mamba models
|
||||
hf_config = MambaConfig(**config)
|
||||
hf_config.hidden_size = hf_config.d_model
|
||||
hf_config.num_hidden_layers = hf_config.n_layer
|
||||
if 'expand' in hf_config.ssm_cfg:
|
||||
expand = hf_config.ssm_cfg['expand']
|
||||
hf_config.intermediate_size = expand * hf_config.d_model
|
||||
else:
|
||||
hf_config.intermediate_size = 2 * hf_config.d_model
|
||||
ssm_cfg_to_hf_cfg = {
|
||||
'd_state': 'state_size',
|
||||
'd_conv': 'conv_kernel',
|
||||
'bias': 'use_bias',
|
||||
'headdim': 'headdim',
|
||||
'ngroups': 'ngroups',
|
||||
'chunk_size': 'chunk_size',
|
||||
'rmsnorm': 'ssm_rmsnorm',
|
||||
}
|
||||
cfg_dict = {}
|
||||
for k, v in hf_config.ssm_cfg.items():
|
||||
if k in ssm_cfg_to_hf_cfg:
|
||||
cfg_dict[ssm_cfg_to_hf_cfg[k]] = v
|
||||
hf_config.update(cfg_dict)
|
||||
mamba_version = hf_config.ssm_cfg.pop("layer", "Mamba1")
|
||||
return hf_config, mamba_version
|
||||
|
||||
|
||||
def main():
|
||||
print(tensorrt_llm.__version__)
|
||||
|
||||
@ -217,8 +299,8 @@ def main():
|
||||
|
||||
args.output_dir.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
hf_config = AutoConfig.from_pretrained(args.model_dir,
|
||||
trust_remote_code=True)
|
||||
hf_config, mamba_version = load_config_hf(args.model_dir)
|
||||
|
||||
vocab_size = hf_config.vocab_size
|
||||
pad_vocab_size_multiple = hf_config.pad_vocab_size_multiple
|
||||
if vocab_size % pad_vocab_size_multiple != 0:
|
||||
@ -239,15 +321,29 @@ def main():
|
||||
'hidden_act': 'silu',
|
||||
'num_attention_heads': 1,
|
||||
'rnn_hidden_size': hf_config.intermediate_size,
|
||||
'rnn_conv_dim_size': hf_config.intermediate_size,
|
||||
'state_size': hf_config.state_size,
|
||||
'conv_kernel': hf_config.conv_kernel,
|
||||
'use_bias': hf_config.use_bias,
|
||||
'mamba_version': mamba_version,
|
||||
}
|
||||
if mamba_version == 'Mamba2':
|
||||
conv_dim = hf_config.intermediate_size + 2 * hf_config.ngroups * hf_config.state_size
|
||||
mamba2_cfg = {
|
||||
'rnn_head_size': hf_config.headdim,
|
||||
'rnn_conv_dim_size': conv_dim,
|
||||
'ngroups': hf_config.ngroups,
|
||||
'chunk_size': hf_config.chunk_size,
|
||||
'ssm_rmsnorm': hf_config.ssm_rmsnorm,
|
||||
}
|
||||
config.update(mamba2_cfg)
|
||||
|
||||
with (args.output_dir / 'config.json').open('w') as f:
|
||||
json.dump(config, f, indent=4)
|
||||
|
||||
convert_from_ckpt = do_convert_from_ckpt(args)
|
||||
# TODO: Add convert_hf_mamba support for Mamba2 when transformers can support Mamba2 models
|
||||
assert convert_from_ckpt or mamba_version == 'Mamba2', "Mamba2 can only support convert from checkpoints."
|
||||
if not convert_from_ckpt:
|
||||
logger.info(f'Convert by using model')
|
||||
hf_mamba = AutoModelForCausalLM.from_pretrained(args.model_dir,
|
||||
@ -264,6 +360,7 @@ def main():
|
||||
convert_args['model_dir'] = args.model_dir
|
||||
else:
|
||||
convert_args['hf_mamba'] = hf_mamba
|
||||
convert_args['mamba_version'] = mamba_version
|
||||
|
||||
convert(0, args, convert_args)
|
||||
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
--extra-index-url https://pypi.nvidia.com
|
||||
tensorrt_llm==0.12.0.dev2024070200
|
||||
tensorrt_llm==0.12.0.dev2024070900
|
||||
transformers>=4.39.0
|
||||
datasets~=2.14.5
|
||||
evaluate
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
--extra-index-url https://pypi.nvidia.com
|
||||
tensorrt_llm==0.12.0.dev2024070200
|
||||
tensorrt_llm==0.12.0.dev2024070900
|
||||
datasets~=2.14.5
|
||||
rouge_score~=0.1.2
|
||||
sentencepiece~=0.1.99
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
--extra-index-url https://pypi.nvidia.com
|
||||
tensorrt_llm==0.12.0.dev2024070200
|
||||
tensorrt_llm==0.12.0.dev2024070900
|
||||
transformers==4.38.2
|
||||
accelerate==0.25.0
|
||||
|
||||
@ -255,6 +255,7 @@ class Pipeline:
|
||||
self.pad_id = pad_id
|
||||
self.end_id = end_id
|
||||
self.max_attention_window_size = max_attention_window_size
|
||||
self.output_len = 2
|
||||
|
||||
def __call__(self, prompt):
|
||||
rank = tensorrt_llm.mpi_rank()
|
||||
@ -263,7 +264,7 @@ class Pipeline:
|
||||
batch_input_ids = [inputs]
|
||||
|
||||
# For multi-choice tasks like MMLU, we don't need to adjust following parameters
|
||||
output_len = 2
|
||||
output_len = self.output_len
|
||||
top_k = 1
|
||||
top_p = 0.0
|
||||
|
||||
@ -313,7 +314,8 @@ class Pipeline:
|
||||
def check_valid_length(self, prompt):
|
||||
if isinstance(self.model, nn.Module):
|
||||
return True
|
||||
return len(self.tokenizer.encode(prompt)) <= self.model.max_input_len
|
||||
input_len = len(self.tokenizer.encode(prompt))
|
||||
return input_len <= self.model.max_input_len and input_len + self.output_len <= self.model.max_seq_len
|
||||
|
||||
|
||||
def parse_args():
|
||||
@ -391,7 +393,7 @@ def main():
|
||||
model = auto_model_cls.from_pretrained(
|
||||
args.hf_model_dir,
|
||||
trust_remote_code=True,
|
||||
torch_dtype=DTYPE_STR_MAPPING[args.data_type],
|
||||
torch_dtype=DTYPE_STR_MAPPING[args.hf_data_type],
|
||||
device_map="auto" if args.hf_device_map_auto else None,
|
||||
)
|
||||
if not args.hf_device_map_auto:
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
--extra-index-url https://pypi.nvidia.com
|
||||
tensorrt_llm==0.12.0.dev2024070200
|
||||
tensorrt_llm==0.12.0.dev2024070900
|
||||
datasets~=2.14.5
|
||||
evaluate~=0.4.1
|
||||
rouge_score~=0.1.2
|
||||
|
||||
@ -12,7 +12,7 @@ We first describe how to run each model on a single GPU. We then provide general
|
||||
- [Deplot](#deplot)
|
||||
- [Fuyu](#fuyu)
|
||||
- [Kosmos-2](#kosmos-2)
|
||||
- [LLaVA and VILA](#llava-and-vila)
|
||||
- [LLaVA, LLaVa-NeXT and VILA](#llava-llava-next-and-vila)
|
||||
- [NeVA](#neva)
|
||||
- [Nougat](#nougat)
|
||||
- [Phi-3-vision](#phi-3-vision)
|
||||
@ -361,9 +361,9 @@ Currently, CogVLM only support bfloat16 precision and doesn't support `remove_in
|
||||
--llm_engine_dir trt_engines/${MODEL_NAME}/fp16/1-gpu
|
||||
```
|
||||
|
||||
## LLaVA and VILA
|
||||
## LLaVA, LLaVa-NeXT and VILA
|
||||
|
||||
[LLaVA](https://github.com/haotian-liu/LLaVA) and [VILA](https://github.com/Efficient-Large-Model/VILA) are both visual language models (VLM) that can be deployed in TensorRT-LLM with many quantization options.
|
||||
[LLaVA](https://github.com/haotian-liu/LLaVA) and [VILA](https://github.com/Efficient-Large-Model/VILA) are both visual language models (VLM) that can be deployed in TensorRT-LLM with many quantization options. [LLaVA-NeXT](https://huggingface.co/collections/llava-hf/llava-next-65f75c4afac77fd37dbbe6cf) is an extension of LLaVA. TRT-LLM currently supports the [Mistral-7b](https://huggingface.co/llava-hf/llava-v1.6-mistral-7b-hf) and [Nous-Hermes-2-Yi-34B](https://huggingface.co/llava-hf/llava-v1.6-34b-hf) variants of LLaVA-NeXT.
|
||||
|
||||
1. Download Huggingface model weights. These models have both visual and LLM components
|
||||
unlike BLIP2 example which downloads only LLM components from Huggingface.
|
||||
@ -374,6 +374,12 @@ Currently, CogVLM only support bfloat16 precision and doesn't support `remove_in
|
||||
export MODEL_NAME="llava-1.5-7b-hf" # also llava-1.5-13b-hf
|
||||
git clone https://huggingface.co/llava-hf/${MODEL_NAME} tmp/hf_models/${MODEL_NAME}
|
||||
```
|
||||
For LLaVA-NeXT,
|
||||
|
||||
```bash
|
||||
export MODEL_NAME="llava-v1.6-mistral-7b-hf" #for 34b variant "llava-v1.6-34b-hf"
|
||||
git clone https://huggingface.co/llava-hf/${MODEL_NAME} tmp/hf_models/${MODEL_NAME}
|
||||
```
|
||||
|
||||
For VILA, we need a few more steps until it is added to HF model zoo
|
||||
|
||||
@ -408,6 +414,18 @@ Currently, CogVLM only support bfloat16 precision and doesn't support `remove_in
|
||||
--max_seq_len 2560 \
|
||||
--max_multimodal_len 576 # 1 (max_batch_size) * 576 (num_visual_features) for LLaVA
|
||||
|
||||
trtllm-build \
|
||||
--checkpoint_dir tmp/trt_models/${MODEL_NAME}/fp16/1-gpu \
|
||||
--output_dir trt_engines/${MODEL_NAME}/fp16/1-gpu \
|
||||
--gpt_attention_plugin float16 \
|
||||
--gemm_plugin float16 \
|
||||
--max_batch_size 1 \
|
||||
--max_input_len 4096 \
|
||||
--max_seq_len 5120 \
|
||||
--max_num_tokens 4096 \ # 1 (max_batch_size) * 4096 (max_input_len)
|
||||
--max_multimodal_len 4096 \ # 1 (max_batch_size) * 4096 (max_input_len)
|
||||
--use_fused_mlp
|
||||
|
||||
trtllm-build \
|
||||
--checkpoint_dir tmp/trt_models/${MODEL_NAME}/fp16/1-gpu \
|
||||
--output_dir trt_engines/${MODEL_NAME}/fp16/1-gpu \
|
||||
@ -426,6 +444,8 @@ Currently, CogVLM only support bfloat16 precision and doesn't support `remove_in
|
||||
```bash
|
||||
python build_visual_engine.py --model_path tmp/hf_models/${MODEL_NAME} --model_type llava # for LLaVA
|
||||
|
||||
python build_visual_engine.py --model_path tmp/hf_models/${MODEL_NAME} --model_type llava_next --max_batch_size 5 # 1 (max_batch_size) * 5 (because LLaVA-NeXT visual encoder can have at most 5 patches) # for LLaVA-NeXT
|
||||
|
||||
python build_visual_engine.py --model_path tmp/hf_models/${MODEL_NAME} --model_type vila --vila_path ${VILA_PATH} # for VILA
|
||||
```
|
||||
|
||||
@ -435,7 +455,7 @@ Currently, CogVLM only support bfloat16 precision and doesn't support `remove_in
|
||||
--hf_model_dir tmp/hf_models/${MODEL_NAME} \
|
||||
--visual_engine_dir visual_engines/${MODEL_NAME} \
|
||||
--llm_engine_dir trt_engines/${MODEL_NAME}/fp16/1-gpu \
|
||||
--input_text "Question: which city is this? Answer:" # for LLaVA
|
||||
--input_text "Question: which city is this? Answer:" # for LLaVA and for LLaVA-NeXT
|
||||
```
|
||||
|
||||
For VILA, you can use either local file or web url as input images.
|
||||
|
||||
@ -11,13 +11,6 @@ import yaml
import torch
import tensorrt as trt
from tensorrt_llm.builder import Builder
# isort: on
import json
import math

import torch.nn.functional as F
from PIL import Image
from safetensors.torch import save_file
from transformers import (AutoConfig, AutoModel, AutoModelForCausalLM,
AutoModelForVision2Seq, AutoProcessor,
Blip2ForConditionalGeneration, Blip2Processor,
@ -25,6 +18,13 @@ from transformers import (AutoConfig, AutoModel, AutoModelForCausalLM,
LlavaForConditionalGeneration, NougatProcessor,
Pix2StructForConditionalGeneration,
VisionEncoderDecoderModel)
# isort: on
import json
import math

import torch.nn.functional as F
from PIL import Image
from safetensors.torch import save_file


def parse_arguments():
@ -33,8 +33,8 @@ def parse_arguments():
type=str,
default=None,
choices=[
'blip2', 'llava', 'vila', 'nougat', 'cogvlm',
'fuyu', 'pix2struct', 'neva', 'kosmos-2',
'blip2', 'llava', 'llava_next', 'vila', 'nougat',
'cogvlm', 'fuyu', 'pix2struct', 'neva', 'kosmos-2',
'video-neva', 'phi-3-vision'
],
help="Model type")
@ -80,7 +80,7 @@ class VisionEngineBuilder:
build_blip2_engine(args)
elif args.model_type == 'pix2struct':
build_pix2struct_engine(args)
elif args.model_type == 'llava':
elif 'llava' in args.model_type:
build_llava_engine(args)
elif args.model_type == 'vila':
assert args.vila_path is not None, "Please clone and provide VILA source code path"
@ -305,30 +305,59 @@ def build_pix2struct_engine(args):

def build_llava_engine(args):
processor = AutoProcessor.from_pretrained(args.model_path)
raw_image = Image.new('RGB', [10, 10]) # dummy image
image = processor(text="dummy", images=raw_image,
return_tensors="pt")['pixel_values'].to(
args.device, torch.float16)
if args.model_type == "llava":
raw_image = Image.new('RGB', [10, 10]) # dummy image
image = processor(text="dummy", images=raw_image,
return_tensors="pt")['pixel_values'].to(
args.device, torch.float16)

class LlavaVisionWrapper(torch.nn.Module):
class LlavaVisionWrapper(torch.nn.Module):

def __init__(self, tower, projector, feature_layer):
super().__init__()
self.tower = tower
self.projector = projector
self.feature_layer = feature_layer
def __init__(self, tower, projector, feature_layer):
super().__init__()
self.tower = tower
self.projector = projector
self.feature_layer = feature_layer

def forward(self, image):
all_hidden_states = self.tower(
image, output_hidden_states=True).hidden_states
features = all_hidden_states[self.feature_layer][:, 1:]
return self.projector(features)
def forward(self, image):
all_hidden_states = self.tower(
image, output_hidden_states=True).hidden_states
features = all_hidden_states[self.feature_layer][:, 1:]
return self.projector(features)

model = LlavaForConditionalGeneration.from_pretrained(
args.model_path, torch_dtype=torch.float16)
wrapper = LlavaVisionWrapper(model.vision_tower.to(args.device),
model.multi_modal_projector.to(args.device),
model.config.vision_feature_layer)
model = LlavaForConditionalGeneration.from_pretrained(
args.model_path, torch_dtype=torch.float16)
wrapper = LlavaVisionWrapper(
model.vision_tower.to(args.device),
model.multi_modal_projector.to(args.device),
model.config.vision_feature_layer)
elif args.model_type == "llava_next":
from transformers import LlavaNextForConditionalGeneration
raw_image = Image.new('RGB', [512, 512])
image = processor(text="dummy", images=raw_image,
return_tensors="pt")['pixel_values'].to(
args.device, torch.float16)[0]

class LlavaNextVisionWrapper(torch.nn.Module):

def __init__(self, vision_tower, projector):
super().__init__()
self.vision_tower = vision_tower
self.projector = projector

def forward(self, pixel_values):
image_features = self.vision_tower(pixel_values,
output_hidden_states=True)
selected_image_feature = image_features.hidden_states[-2][:, 1:]
image_features = self.projector(selected_image_feature)
return image_features # (bs, 576, c)

model = LlavaNextForConditionalGeneration.from_pretrained(
args.model_path, torch_dtype=torch.float16)
wrapper = LlavaNextVisionWrapper(
model.vision_tower.vision_model.to(args.device),
model.multi_modal_projector.to(args.device),
)

export_visual_wrapper_onnx(wrapper, image, args.output_dir)
build_trt_engine(
@ -336,6 +365,11 @@ def build_llava_engine(args):
[image.shape[1], image.shape[2], image.shape[3]], # [3, H, W]
args.output_dir,
args.max_batch_size)
if args.model_type == "llava_next":
image_newline = model.image_newline.data
tensor_img_newline = {"image_newline": image_newline}
save_file(tensor_img_newline,
os.path.join(args.output_dir, "image_newline.safetensors"))


def build_vila_engine(args):
@ -517,7 +551,12 @@ def build_neva_engine(args):
vision_x = self.connector(vision_x)
return vision_x

encoder = AutoModel.from_pretrained(vision_config["from_pretrained"],
vision_path = vision_config["from_pretrained"]
joined_path = os.path.join(os.path.dirname(args.model_path),
os.path.basename(vision_path))
if os.path.isdir(joined_path):
vision_path = joined_path
encoder = AutoModel.from_pretrained(vision_path,
torch_dtype=torch.bfloat16,
trust_remote_code=True)
vision_encoder = encoder.vision_model

@ -95,6 +95,130 @@ def trt_dtype_to_torch(dtype):
raise TypeError("%s is not supported" % dtype)


class LlavaNextUtils:
# https://github.com/haotian-liu/LLaVA/blob/main/llava/mm_utils.py

@staticmethod
def select_best_resolution(original_size, possible_resolutions):
"""
Selects the best resolution from a list of possible resolutions based on the original size.

Args:
original_size (tuple): The original size of the image in the format (width, height).
possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...].

Returns:
tuple: The best fit resolution in the format (width, height).
"""
original_width, original_height = original_size
best_fit = None
max_effective_resolution = 0
min_wasted_resolution = float('inf')

for width, height in possible_resolutions:
scale = min(width / original_width, height / original_height)
downscaled_width, downscaled_height = int(
original_width * scale), int(original_height * scale)
effective_resolution = min(downscaled_width * downscaled_height,
original_width * original_height)
wasted_resolution = (width * height) - effective_resolution

if effective_resolution > max_effective_resolution or (
effective_resolution == max_effective_resolution
and wasted_resolution < min_wasted_resolution):
max_effective_resolution = effective_resolution
min_wasted_resolution = wasted_resolution
best_fit = (width, height)

return best_fit

@staticmethod
def get_anyres_image_grid_shape(image_size, patch_size):
"""
Calculate the shape of the image patch grid after the preprocessing for images of any resolution.

Args:
image_size (tuple): The size of the input image in the format (width, height).
patch_size (int): The size of each image patch.

Returns:
tuple: The shape of the image patch grid in the format (width, height).
"""
IMAGE_GRID_PINPOINTS = [[336, 672], [672, 336], [672, 672], [1008, 336],
[336, 1008]]
width, height = LlavaNextUtils.select_best_resolution(
image_size, IMAGE_GRID_PINPOINTS)
return width // patch_size, height // patch_size

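A quick sanity check can make the selection rule above more concrete. The snippet below is illustrative only and is not part of the diff; it assumes the `LlavaNextUtils` class shown above is importable and reuses the same pinpoint list.

```python
# Worked example (illustrative): for an 800x600 input, the 672x672 pinpoint
# keeps the largest effective area, so the image is tiled into a 2x2 grid of
# 336-pixel patches.
pinpoints = [[336, 672], [672, 336], [672, 672], [1008, 336], [336, 1008]]
assert LlavaNextUtils.select_best_resolution((800, 600), pinpoints) == (672, 672)
assert LlavaNextUtils.get_anyres_image_grid_shape((800, 600), 336) == (2, 2)
```
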
@staticmethod
def unpad_image(tensor, original_size):
"""
Unpads a PyTorch tensor of a padded and resized image.

Args:
tensor (torch.Tensor): The image tensor, assumed to be in CxHxW format.
original_size (tuple): The original size of the image (width, height).

Returns:
torch.Tensor: The unpadded image tensor.
"""
original_width, original_height = original_size
current_height, current_width = tensor.shape[1:]

original_aspect_ratio = original_width / original_height
current_aspect_ratio = current_width / current_height

if original_aspect_ratio > current_aspect_ratio:
scale_factor = current_width / original_width
new_height = int(original_height * scale_factor)
padding = (current_height - new_height) // 2
unpadded_tensor = tensor[:, padding:current_height - padding, :]
else:
scale_factor = current_height / original_height
new_width = int(original_width * scale_factor)
padding = (current_width - new_width) // 2
unpadded_tensor = tensor[:, :, padding:current_width - padding]

return unpadded_tensor

@staticmethod
def rearrange_image_features(image_feature, image_newline, image_size):
"""
Combine PyTorch feature grids from image patches.

Args:
image_feature (torch.Tensor): The feature grids, assumed to be in NxCxHxW format.
image_newline (torch.Tensor): The newline embedding.
image_size (tuple): Size of the original image (width, height).
"""
CLIP_IMAGE_SIZE = 336
CLIP_PATCH_SIZE = 14
NUM_PATCHES_PER_SIDE = CLIP_IMAGE_SIZE // CLIP_PATCH_SIZE
if image_feature.shape[0] == 1:
return torch.cat((image_feature, image_newline[None]), dim=0)

base_image_feature = image_feature[0]
image_feature = image_feature[1:]
height = width = NUM_PATCHES_PER_SIDE
assert height * width == base_image_feature.shape[0]

num_patch_width, num_patch_height = LlavaNextUtils.get_anyres_image_grid_shape(
image_size, CLIP_IMAGE_SIZE)
image_feature = image_feature.view(num_patch_height, num_patch_width,
height, width, -1)

image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
image_feature = image_feature.flatten(1, 2).flatten(2, 3)
image_feature = LlavaNextUtils.unpad_image(image_feature, image_size)
image_feature = torch.cat(
(image_feature, image_newline[:, None, None].expand(
*image_feature.shape[:-1], 1)),
dim=-1)
image_feature = image_feature.flatten(1, 2).transpose(0, 1)
image_feature = torch.cat((base_image_feature, image_feature), dim=0)
return image_feature


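The rearrangement above is easiest to follow with concrete shapes. The sketch below is not part of the diff; it assumes the `LlavaNextUtils` class above plus `torch`, and the hidden size of 4096 is just an illustrative choice.

```python
import torch

# Illustrative shapes for an 800x600 image: the 672x672 pinpoint yields a 2x2
# patch grid, so the vision encoder emits 1 base grid + 4 patch grids of
# 24x24 = 576 tokens each.
features = torch.randn(5, 576, 4096)   # [base + patches, tokens, hidden]
newline = torch.randn(4096)            # the saved image_newline embedding

merged = LlavaNextUtils.rearrange_image_features(features, newline, (800, 600))
# The patch grids are stitched into 48x48 tokens, unpadded to 36x48 for the
# 4:3 image, and a newline column is appended: 576 + 36 * 49 tokens total.
print(merged.shape)                    # torch.Size([2340, 4096])
```
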
class MultimodalModelRunner:

def __init__(self, args):
@ -123,7 +247,9 @@ class MultimodalModelRunner:

if self.model_type == 'video-neva':
self.num_frames = config['builder_config'].get('num_frames', None)

if self.model_type == "llava_next":
self.llm_name = AutoConfig.from_pretrained(
args.hf_model_dir).text_config._name_or_path
self.profiling_iterations = 20

self.init_image_encoder()
@ -203,6 +329,14 @@ class MultimodalModelRunner:
device="cuda") as f:
for k in f.keys():
self.image_newlines[k] = f.get_tensor(k)
if self.model_type == "llava_next":
self.image_newlines = {}
image_newlines_path = os.path.join(self.args.visual_engine_dir,
'image_newline.safetensors')
with safe_open(image_newlines_path, framework="pt",
device="cuda") as f:
for k in f.keys():
self.image_newlines[k] = f.get_tensor(k)

def init_llm(self):
if self.decoder_llm:
@ -276,6 +410,12 @@ class MultimodalModelRunner:
image = input['pixel_values']
bs = image.shape[0]
image = image.flatten(0, 1)
elif self.model_type == 'llava_next':
input = image
image = input['pixel_values']
bs = image.shape[0]
image = image[0]
image_size = input['image_sizes'][0].cpu()

if not warmup:
profiler.start("Vision")
@ -366,6 +506,13 @@ class MultimodalModelRunner:
input_ids = self.ptuning_setup_phi3(visual_features, input_ids,
num_img_tokens)
length = input_ids.shape[1]
elif self.model_type == 'llava_next':
visual_features = LlavaNextUtils.rearrange_image_features(
visual_features, self.image_newlines["image_newline"],
image_size)
input_ids = self.ptuning_setup_llava_next(visual_features,
pre_prompt, post_prompt)
length = input_ids.shape[1]
else:
pre_input_ids = self.tokenizer(pre_prompt,
return_tensors="pt",
@ -387,7 +534,9 @@ class MultimodalModelRunner:
input_lengths = torch.IntTensor([length] * args.batch_size).to(
torch.int32)

if self.model_type in ['fuyu', 'kosmos-2', 'phi-3-vision']:
if self.model_type in [
'fuyu', 'kosmos-2', 'phi-3-vision', 'llava_next'
]:
return input_ids, input_lengths, [visual_features], visual_features

input_ids, ptuning_args = self.setup_fake_prompts(
@ -667,6 +816,19 @@ class MultimodalModelRunner:
res_input_ids.append(cur_input_ids)
return res_input_ids

def ptuning_setup_llava_next(self, visual_features, pre_prompt,
post_prompt):
input_ids = []
fake_prompt_ids = list(
range(self.model_config.vocab_size,
self.model_config.vocab_size + visual_features.shape[0]))
input_ids = self.tokenizer.encode(
pre_prompt[0]) + fake_prompt_ids + self.tokenizer.encode(
post_prompt[0])[self.tokenizer.add_bos_token:]
input_ids = [input_ids] * len(pre_prompt)
input_ids = torch.tensor(input_ids)
return input_ids

def ptuning_setup_phi3(self, visual_features, input_ids, num_img_tokens):
fake_prompt_id = torch.arange(
self.model_config.vocab_size,
@ -869,6 +1031,32 @@ class MultimodalModelRunner:
pre_prompt = """<extra_id_0>System\nA chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n\n<extra_id_1>User"""
post_prompt = f"\n{input_text}\n<extra_id_1>Assistant\n<extra_id_2>quality:4,toxicity:0,humor:0,creativity:0,helpfulness:4,correctness:4,coherence:4,complexity:4,verbosity:4\n" ""

elif self.model_type == "llava_next":
if self.llm_name == "mistralai/Mistral-7B-Instruct-v0.2":
pre_prompt = "[INST] "
if input_text is None:
input_text = "Question: which city is this? Answer:"
post_prompt = f"\n{input_text} [/INST]"
prompt = pre_prompt + post_prompt

elif self.llm_name == "NousResearch/Nous-Hermes-2-Yi-34B":
pre_prompt = "<|im_start|>system\nAnswer the questions.<|im_end|><|im_start|>user\n"
if input_text is None:
input_text = "Question: which city is this? Answer:"
post_prompt = f"\n{input_text}<|im_end|><|im_start|>assistant\n"
prompt = pre_prompt + post_prompt

else:
raise Exception(
f"Prompt template for {self.llm_name} is not included currently"
)

processor = AutoProcessor.from_pretrained(args.hf_model_dir,
trust_remote_code=True)
image = processor(text=prompt,
images=raw_image,
return_tensors="pt")

elif self.model_type in ['llava', 'vila', 'fuyu', 'kosmos-2']:
# LLaVA and VILA
if self.model_type == "llava":
@ -924,7 +1112,8 @@ class MultimodalModelRunner:
pre_prompt = [pre_prompt] * self.args.batch_size
post_prompt = [post_prompt] * self.args.batch_size
if self.model_type not in [
'fuyu', 'pix2struct', 'kosmos-2', 'vila', 'phi-3-vision'
'fuyu', 'pix2struct', 'kosmos-2', 'vila', 'phi-3-vision',
'llava_next'
]:
if image.dim() == 5:
image = image.expand(args.batch_size, -1, -1, -1,
@ -932,7 +1121,6 @@ class MultimodalModelRunner:
else:
image = image.expand(args.batch_size, -1, -1, -1).contiguous()
image = image.to(self.device)

# Generate decoder_input_ids for enc-dec models
# Custom prompts can be added as:
# decoder_input_ids = model.tokenizer(decoder_prompt).input_ids
@ -955,7 +1143,6 @@ class MultimodalModelRunner:
def run(self, input_text, input_image, max_new_tokens):
input_text, pre_prompt, post_prompt, processed_image, decoder_input_ids, attention_mask = model.setup_inputs(
input_text, input_image)

model.generate(pre_prompt,
post_prompt,
processed_image,
@ -999,7 +1186,9 @@ class MultimodalModelRunner:
elif self.model_type == "pix2struct":
assert "characteristic | cat food, day | cat food, wet | cat treats" in output_text[
0][0].lower()
elif self.model_type in ['blip2', 'neva', 'phi-3-vision']:
elif self.model_type in [
'blip2', 'neva', 'phi-3-vision', 'llava_next'
]:
assert 'singapore' in output_text[0][0].lower()
elif self.model_type == 'video-neva':
assert 'robot' in output_text[0][0].lower()

@ -1,5 +1,5 @@
--extra-index-url https://pypi.nvidia.com
tensorrt_llm==0.12.0.dev2024070200
tensorrt_llm==0.12.0.dev2024070900
transformers==4.40.2
datasets~=2.14.5
evaluate~=0.4.1

@ -1,5 +1,5 @@
--extra-index-url https://pypi.nvidia.com
tensorrt_llm==0.12.0.dev2024070200
tensorrt_llm==0.12.0.dev2024070900
datasets~=2.14.5
evaluate~=0.4.1
rouge_score~=0.1.2

@ -1,5 +1,5 @@
--extra-index-url https://pypi.nvidia.com
tensorrt_llm==0.12.0.dev2024070200
tensorrt_llm==0.12.0.dev2024070900
datasets~=2.14.5
evaluate~=0.4.1
rouge_score~=0.1.2

@ -1,5 +1,5 @@
--extra-index-url https://pypi.nvidia.com
tensorrt_llm==0.12.0.dev2024070200
tensorrt_llm==0.12.0.dev2024070900
datasets>=2.14.4
nemo-toolkit[all]<=1.20.0,>=1.18.0
rouge_score~=0.1.2

@ -1,5 +1,5 @@
--extra-index-url https://pypi.nvidia.com
tensorrt_llm==0.12.0.dev2024070200
tensorrt_llm==0.12.0.dev2024070900
datasets~=2.16.0
evaluate~=0.4.1
rouge_score~=0.1.2

@ -1,5 +1,5 @@
--extra-index-url https://pypi.nvidia.com
tensorrt_llm==0.12.0.dev2024070200
tensorrt_llm==0.12.0.dev2024070900
datasets~=2.16.0
evaluate~=0.4.1
rouge_score~=0.1.2

@ -496,6 +496,7 @@ def main():
rnn_hidden_size=ckpt_config["lru_width"],
logits_soft_cap=ckpt_config["logits_soft_cap"],
emb_scale_by_sqrt_dim=ckpt_config["embeddings_scale_by_sqrt_dim"],
rnn_conv_dim_size=ckpt_config["lru_width"],
)

trt_llm_config_dict = trt_llm_config.to_dict()

@ -1,5 +1,5 @@
--extra-index-url https://pypi.nvidia.com
tensorrt_llm==0.12.0.dev2024070200
tensorrt_llm==0.12.0.dev2024070900
git+https://github.com/google-deepmind/recurrentgemma.git
flax>=0.8.2
jax~=0.4.23

@ -36,7 +36,6 @@ if PYTHON_BINDINGS:


def parse_arguments(args=None):
# see `add_common_args` for extended list of arguments
parser = argparse.ArgumentParser()
parser.add_argument('--max_input_length', type=int, default=923)
parser.add_argument('--max_output_len', type=int, required=True)
@ -319,18 +318,6 @@ def main(args):
"Debug mode is not supported in C++ session for now, fallback to Python session."
)
args.use_py_session = True
if args.return_all_generated_tokens and args.use_py_session:
raise ValueError(
"Returning all the generated tokens at each step is not supported in the Python session, use C++ session instead."
)
if (not args.return_all_generated_tokens) and args.streaming and (
args.num_beams > 1):
logger.warning(
"Setting return_all_generated_tokens to True since streaming AND beam search are done simultaneously. "
"Returning the full beams at each streaming step is needed because beam search + streaming can change previous outputs. "
"WARNING: using this option may increase network usage significantly (quadratically w.r.t output length)."
)
args.return_all_generated_tokens = True
runner_cls = ModelRunner if args.use_py_session else ModelRunnerCpp
runner_kwargs = dict(
engine_dir=args.engine_dir,
@ -360,7 +347,8 @@ def main(args):
kv_cache_enable_block_reuse=args.kv_cache_enable_block_reuse,
kv_cache_free_gpu_memory_fraction=args.
kv_cache_free_gpu_memory_fraction,
enable_chunked_context=args.enable_chunked_context)
enable_chunked_context=args.enable_chunked_context,
)
runner = runner_cls.from_dir(**runner_kwargs)

with torch.no_grad():
@ -394,8 +382,7 @@ def main(args):
output_sequence_lengths=True,
no_repeat_ngram_size=args.no_repeat_ngram_size,
return_dict=True,
medusa_choices=args.medusa_choices,
return_all_generated_tokens=args.return_all_generated_tokens)
medusa_choices=args.medusa_choices)
torch.cuda.synchronize()

if args.streaming:
@ -482,9 +469,7 @@ def main(args):
prompt_tasks=args.prompt_tasks,
streaming=args.streaming,
output_sequence_lengths=True,
return_dict=True,
return_all_generated_tokens=args.return_all_generated_tokens
)
return_dict=True)
torch.cuda.synchronize()

tensorrt_llm.profiler.start("tmp")
@ -516,9 +501,7 @@ def main(args):
prompt_tasks=args.prompt_tasks,
streaming=args.streaming,
output_sequence_lengths=True,
return_dict=True,
return_all_generated_tokens=args.return_all_generated_tokens
)
return_dict=True)
torch.cuda.synchronize()
tensorrt_llm.profiler.stop("tmp")


@ -1,5 +1,5 @@
--extra-index-url https://pypi.nvidia.com
tensorrt_llm==0.12.0.dev2024070200
tensorrt_llm==0.12.0.dev2024070900
datasets~=2.16.1
evaluate~=0.4.1
rouge_score~=0.1.2

@ -1,5 +1,5 @@
--extra-index-url https://pypi.nvidia.com
tensorrt_llm==0.12.0.dev2024070200
tensorrt_llm==0.12.0.dev2024070900
datasets==2.14.6
evaluate~=0.4.1
rouge_score~=0.1.2

@ -415,10 +415,6 @@ def main(args):
"Python bindings of C++ session is unavailable, fallback to Python session."
)
args.use_py_session = True
if args.return_all_generated_tokens:
raise ValueError(
"Returning all the generated tokens at each step is not supported in summarize.py"
)
runner_cls = ModelRunner if args.use_py_session else ModelRunnerCpp
runner_kwargs = dict(engine_dir=args.engine_dir,
rank=runtime_rank,

@ -329,16 +329,4 @@ def add_common_args(parser):
action='store_true',
help="Use device map 'auto' to load a pretrained HF model. This may "
"help to test a large model that cannot fit into a single GPU.")

parser.add_argument(
"--return_all_generated_tokens",
default=False,
action="store_true",
help="if false, return only generated tokens at each streaming step."
"If true, return the full beams/outputs at each step"
"Overwritten to True if num_beams>1 and streaming"
"(only available with cpp session). "
"WARNING: using this option may increase network usage significantly (quadratically w.r.t output length)."
)

return parser

@ -1,5 +1,5 @@
--extra-index-url https://pypi.nvidia.com
tensorrt_llm==0.12.0.dev2024070200
tensorrt_llm==0.12.0.dev2024070900
tiktoken
datasets
kaldialign

@ -20,7 +20,7 @@ tokenizers>=0.14
# Default torch is CPU-only on Windows, so need to specify a torch version with GPU support
torch @ https://download.pytorch.org/whl/cu121/torch-2.3.1%2Bcu121-cp310-cp310-win_amd64.whl
nvidia-modelopt~=0.13,<0.14
transformers==4.38.2
transformers>=4.38.2
wheel
optimum
evaluate

@ -14,6 +14,7 @@
# limitations under the License.
import array
import struct
import sys
from contextlib import contextmanager
from typing import List, Tuple

@ -83,7 +84,7 @@ class IpcMemory():
self.local_ptr = 0

def __del__(self):
if self.open_ipc:
if not sys.is_finalizing() and self.open_ipc:
IpcMemory.close_ipc_memory(self.mapping, self.peer_ptrs)

def serialize(self) -> List[int]:

@ -69,23 +69,6 @@ _bandwidths = {
"PCIe-5": 64,
}

_templates = {
"H100-SXM":
dict(
inter_node_bw_per_device=50,
intra_node_bw_per_device=450,
intra_node_sharp=True,
memory_bw=3350,
math_throughput=MathThroughput(
int8=1979,
fp8=1979,
float16=989,
bfloat16=989,
float32=495,
),
),
}

cluster_infos = {
# from https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/a100/pdf/nvidia-a100-datasheet-us-nvidia-1758950-r4-web.pdf
"A100-SXM-80GB":
@ -119,18 +102,18 @@ cluster_infos = {
# from https://resources.nvidia.com/en-us-tensor-core/nvidia-tensor-core-gpu-datasheet
"H100-SXM":
ClusterInfo(
**_templates["H100-SXM"],
inter_node_bw_per_device=50,
intra_node_bw_per_device=450,
intra_node_sharp=True,
memory_bw=3350,
memory_budget_per_device=80,
),
"H100-SXM-64G":
ClusterInfo(
**_templates["H100-SXM"],
memory_budget_per_device=64,
),
"H100-SXM-94G":
ClusterInfo(
**_templates["H100-SXM"],
memory_budget_per_device=94,
math_throughput=MathThroughput(
int8=1979,
fp8=1979,
float16=989,
bfloat16=989,
float32=495,
),
),
"H100-PCIe":
ClusterInfo(
@ -369,12 +352,6 @@ def infer_cluster_key() -> str:
return "H100-SXM"
else:
return "H100-PCIe"
elif match("H100XS", device_name):
return "H100-SXM-64G"
elif match("H100XM", device_name):
return "H100-SXM"
elif match("H100XL", device_name):
return "H100-SXM-94G"
elif match("L40S", device_name):
return "L40S"
elif match("L40", device_name):

@ -27,11 +27,11 @@ import torch
from ..auto_parallel import infer_cluster_config
from ..auto_parallel.cluster_info import cluster_infos
from ..builder import BuildConfig, Engine, build
from ..functional import PositionEmbeddingType
from ..logger import logger
from ..lora_manager import LoraConfig, LoraManager
from ..models import MODEL_MAP, PretrainedConfig
from ..models.modeling_utils import (WEIGHT_LOADER_MODELS,
SpeculativeDecodingMode)
from ..models.modeling_utils import SpeculativeDecodingMode
from ..plugin import PluginConfig, add_plugin_argument


@ -248,13 +248,6 @@ def parse_arguments():
return args


def preprocess_model_config(model_config, **kwargs):
if model_config.architecture in WEIGHT_LOADER_MODELS:
model_config.mapping.tp_size = kwargs['tp_size']
model_config.mapping.pp_size = kwargs['pp_size']
model_config.mapping.world_size = kwargs['tp_size'] * kwargs['pp_size']


def build_model(
build_config: BuildConfig,
rank: int = 0,
@ -428,7 +421,6 @@ def main():
ckpt_dir = ckpt_dir_or_model_config

model_config = PretrainedConfig.from_json_file(config_path)
preprocess_model_config(model_config, **kwargs)

if args.build_config is None:
if args.multiple_profiles == "enable" and args.opt_num_tokens is not None:
@ -472,7 +464,6 @@ def main():
deduced_max_seq_len = model_config.max_position_embeddings

# Step 2: Scale max_seq_len with rotary scaling
rotary_scaling = getattr(model_config, "rotary_scaling", None)
if rotary_factor != 1:
deduced_max_seq_len *= rotary_factor
logger.warning(
@ -485,8 +476,18 @@ def main():
f'max_seq_len is not specified, using value {deduced_max_seq_len}'
)
else:
if not plugin_config.streamingllm and model_config.max_position_embeddings is not None:
assert args.max_seq_len <= model_config.max_position_embeddings * rotary_factor, f'max_seq_len {args.max_seq_len} can\'t be larger than max_position_embeddings {model_config.max_position_embeddings} * rotary scaling {rotary_factor}'
if not plugin_config.streamingllm and model_config.max_position_embeddings is not None \
and model_config.position_embedding_type != PositionEmbeddingType.relative:
if args.max_seq_len > model_config.max_position_embeddings * rotary_factor:
logger.warning(
f'max_seq_len {args.max_seq_len} is larger than max_position_embeddings {model_config.max_position_embeddings} * rotary scaling {rotary_factor}, '
'the model accuracy might be affected')

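The relaxed check above is easier to see with numbers. This is an illustrative calculation, not part of the diff, and the values are made up.

```python
# Hypothetical values: a model with max_position_embeddings = 4096 and a
# rotary scaling factor of 4 now only logs the warning above when
# max_seq_len exceeds 16384, instead of failing the old assertion.
max_position_embeddings = 4096
rotary_factor = 4
print(max_position_embeddings * rotary_factor)  # 16384
```
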
if args.max_input_len > args.max_seq_len:
logger.warning(
f'max_input_len {args.max_input_len} is larger than max_seq_len {args.max_seq_len}, clipping it to max_seq_len'
)
args.max_input_len = args.max_seq_len

build_config = BuildConfig.from_dict(
{

@ -76,9 +76,9 @@ class GenerationRequest:
# The following options in the Executor API are not yet exposed by the HLAPI:
# https://jirasw.nvidia.com/browse/TRTLLM-489
"bad_words":
self.sampling_params.bad_words or [],
self.sampling_params._get_bad_words(),
"stop_words":
self.sampling_params.stop_words or [],
self.sampling_params._get_stop_words(),
"embedding_bias":
self.sampling_params.embedding_bias,
"external_draft_tokens_config":
@ -182,6 +182,15 @@ class GenerationResult:
self.outputs[i].generation_logits = generation_logits[
i, :self.outputs[i].length]

if self.finished and not self._generation_request.sampling_params.include_stop_str_in_output:
for beam_output in self.outputs:
for stop_ids in self._generation_request.sampling_params._get_stop_words(
):
if beam_output.token_ids[-len(stop_ids):] == stop_ids:
beam_output.token_ids = beam_output.token_ids[:-len(
stop_ids)]
break

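A toy run of the stop-word trimming just added may help. The snippet below is illustrative only and uses made-up token ids; it mirrors the slicing in the loop above rather than calling the HLAPI.

```python
# Hypothetical ids: a finished beam ends with the stop sequence [13, 2] and
# include_stop_str_in_output is False, so the stop ids are sliced off.
token_ids = [5, 9, 7, 13, 2]
stop_ids = [13, 2]
if token_ids[-len(stop_ids):] == stop_ids:
    token_ids = token_ids[:-len(stop_ids)]
print(token_ids)  # [5, 9, 7]
```
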
if context_logits is not None:
self.context_logits = context_logits


@ -5637,17 +5637,20 @@ def selective_scan(input: Tensor,
A: Tensor,
BC: Tensor,
D: Tensor,
z: Tensor,
host_request_types: Tensor,
last_token_ids: Tensor,
dim: int,
dstate: int,
dt_rank: int,
is_variable_B: bool,
is_variable_C: bool,
delta_softplus: bool,
dtype: str,
slot_mapping: Optional[Tensor] = None):
z: Optional[Tensor] = None,
host_context_lengths: Optional[Tensor] = None,
slot_mapping: Optional[Tensor] = None,
nheads: int = 1,
ngroups: int = 1,
chunk_size: int = 256,
mamba_version: str = 'Mamba1'):
'''
Parameters:
input : Tensor (On GPU)
@ -5658,27 +5661,34 @@ def selective_scan(input: Tensor,
Or the CPU tensor of shape [1] for the pointer of paged states.

delta : Tensor (On GPU)
The delta tensor. Its shape is [batch_size, seq_len, dim] or [num_tokens, dim] for remove_input_padding
The delta tensor.
mamba: Its shape is [batch_size, seq_len, dim] or [num_tokens, dim] for remove_input_padding
mamba2: Its shape is [batch_size, seq_len, nheads] or [num_tokens, nheads] for remove_input_padding

delta_bias : Tensor (On GPU)
The delta bias tensor. Its shape is [dim]
The delta bias tensor.
mamba: Its shape is [dim]
mamba2: Its shape is [nheads]

A : Tensor (On GPU)
A matrix. Its shape is [dstate, dim]
A matrix.
mamba: Its shape is [dstate, dim]
mamba2: Its shape is [nheads]

BC : Tensor (On GPU)
B matrix. Its shape is [batch_size, seq_len, dstate * 2] or [num_tokens, dstate * 2] for remove_input_padding
B and C matrix.
mamba: Its shape is [batch_size, seq_len, dstate * 2] or [num_tokens, dstate * 2] for remove_input_padding
mamba2: Its shape is [batch_size, seq_len, ngroups * dstate * 2] or [num_tokens, ngroups * dstate * 2] for remove_input_padding

D : Tensor (On GPU)
D matrix. Its shape is [dim]

z : Tensor (On GPU)
The z tensor. Its shape is [batch_size, seq_len, dim] or [num_tokens, dim] for remove_input_padding
D matrix.
mamba: Its shape is [dim]
mamba2: Its shape is [nheads]

host_request_types : Tensor (On CPU)
The tensor on the host that indicates if a request is in context or
generation phase. Its shape is [batch_size]. See Inflight Batching
in docs/gpt_attention.md,
in docs/gpt_attention.md

last_token_ids : Tensor (On GPU)
The inclusive prefix-sum of the lengths or the lengths of the
@ -5693,22 +5703,32 @@ def selective_scan(input: Tensor,
dt_rank: int
The rank dimension of dt_proj

is_variable_B : bool
Is the matrix B a variable? Set to 'True' if B is a dynamic matrix
during inference, 'False' otherwise

is_variable_C : bool
Is the matrix C a variable? Set to 'True' if C is a dynamic matrix
during inference, 'False' otherwise

delta_softplus : bool
Do we apply softplus to the delta.

dtype: str
data type

z : Tensor (On GPU) (Optional)
The z tensor. Its shape is [batch_size, seq_len, dim] or [num_tokens, dim] for remove_input_padding

host_context_lengths: Tensor (On CPU) (Optional)
A host tensor that contains the lengths of the different inputs

slot_mapping: Tensor (On GPU) (Optional)
Real page index in state. Its shape is [dim], used for paged state, each page shape is [dstate, dim]

nheads: int (Optional)
The number of heads.

ngroups: int (Optional)
The number of groups.

chunk_size: int (Optional)
The chunk_size is used for the chunk_scan kernel.

mamba_version: str (Optional)
The Mamba version, supporting 'Mamba1' as the default.
'''
assert host_request_types is not None
selective_scan_plg_creator = trt.get_plugin_registry().get_plugin_creator(
@ -5721,12 +5741,13 @@ def selective_scan(input: Tensor,
trt.PluginFieldType.INT32)
dt_rank = trt.PluginField("dt_rank", np.array(dt_rank, dtype=np.int32),
trt.PluginFieldType.INT32)
is_variable_B = trt.PluginField(
"is_variable_B", np.array(np.int8(is_variable_B), dtype=np.int8),
trt.PluginFieldType.INT8)
is_variable_C = trt.PluginField(
"is_variable_C", np.array(np.int8(is_variable_C), dtype=np.int8),
trt.PluginFieldType.INT8)
nheads = trt.PluginField("nheads", np.array(nheads, dtype=np.int32),
trt.PluginFieldType.INT32)
ngroups = trt.PluginField("ngroups", np.array(ngroups, dtype=np.int32),
trt.PluginFieldType.INT32)
chunk_size = trt.PluginField("chunk_size",
np.array(chunk_size, dtype=np.int32),
trt.PluginFieldType.INT32)
delta_softplus = trt.PluginField(
"delta_softplus", np.array(np.int8(delta_softplus), dtype=np.int8),
trt.PluginFieldType.INT8)
@ -5741,20 +5762,34 @@ def selective_scan(input: Tensor,
"paged_state",
np.array(np.int8(default_net().plugin_config.paged_state),
dtype=np.int8), trt.PluginFieldType.INT8)
if z is None:
z_enabled = trt.PluginField("z_enabled", np.array(0, dtype=np.int8),
trt.PluginFieldType.INT8)
else:
z_enabled = trt.PluginField("z_enabled", np.array(1, dtype=np.int8),
trt.PluginFieldType.INT8)
is_mamba2 = trt.PluginField(
"is_mamba2",
np.array(1 if mamba_version == 'Mamba2' else 0, dtype=np.int8),
trt.PluginFieldType.INT8)

pfc = trt.PluginFieldCollection([
dim, dstate, dt_rank, is_variable_B, is_variable_C, delta_softplus,
pf_type, remove_input_padding, paged_state
dim, dstate, dt_rank, nheads, ngroups, chunk_size, delta_softplus,
pf_type, remove_input_padding, paged_state, z_enabled, is_mamba2
])
selective_scan_plug = selective_scan_plg_creator.create_plugin(
"selective_scan", pfc)

plug_inputs = [
input, state_or_ptr, delta, delta_bias, A, BC, D, z, host_request_types,
input, state_or_ptr, delta, delta_bias, A, BC, D, host_request_types,
last_token_ids
]
if default_net().plugin_config.remove_input_padding:
plug_inputs += [host_context_lengths]
if default_net().plugin_config.paged_state:
plug_inputs += [slot_mapping]
if z is not None:
plug_inputs += [z]
plug_inputs = [i.trt_tensor for i in plug_inputs]

layer = default_trtnet().add_plugin_v2(plug_inputs, selective_scan_plug)

@ -27,32 +27,59 @@ def get_build_cache_config_from_env() -> tuple[bool, str]:
return build_cache_enabled, build_cache_root


class BuildCache:
'''
The BuildCache class is a class that manages the intermediate products from the build steps.
class BuildCacheConfig:
"""
Configuration for the build cache.

NOTE: currently, only engine-building is supported
TODO[chunweiy]: add support for other build steps, such as quantization, convert_checkpoint, etc.
'''
# The version of the cache, will be used to determine if the cache is compatible
CACHE_VERSION = 0
Attributes:
cache_root (str): The root directory for the build cache.
max_records (int): The maximum number of records to store in the cache.
max_cache_storage_gb (float): The maximum amount of storage (in GB) to use for the cache.
"""

def __init__(self,
cache_root: Optional[Path] = None,
max_records: int = 10,
max_cache_storage_gb: int = 256):
'''
Args:
cache_root (Path): The root directory of the cache
max_records (int): The maximum number of records to keep in the cache
max_cache_storage_gb (int): The maximum storage size of the cache
'''
_, default_cache_root = get_build_cache_config_from_env()
self.cache_root = cache_root or Path(default_cache_root)
self.max_records = max_records
self.max_cache_storage_gb = max_cache_storage_gb
max_cache_storage_gb: float = 256):
self._cache_root = cache_root
self._max_records = max_records
self._max_cache_storage_gb = max_cache_storage_gb

if max_records < 1:
@property
def cache_root(self) -> Path:
_build_cache_enabled, _build_cache_root = get_build_cache_config_from_env(
)
return self._cache_root or Path(_build_cache_root)

@property
def max_records(self) -> int:
return self._max_records

@property
def max_cache_storage_gb(self) -> float:
return self._max_cache_storage_gb


class BuildCache:
"""
The BuildCache class is a class that manages the intermediate products from the build steps.

NOTE: currently, only engine-building is supported
TODO[chunweiy]: add support for other build steps, such as quantization, convert_checkpoint, etc.
"""
# The version of the cache, will be used to determine if the cache is compatible
CACHE_VERSION = 0

def __init__(self, config: Optional[BuildCacheConfig] = None):

_, default_cache_root = get_build_cache_config_from_env()
config = config or BuildCacheConfig()

self.cache_root = config.cache_root or Path(default_cache_root)
self.max_records = config.max_records
self.max_cache_storage_gb = config.max_cache_storage_gb

if config.max_records < 1:
raise ValueError("max_records should be greater than 0")

def get_engine_building_cache_stage(self,

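The diff above replaces the constructor arguments of BuildCache with a BuildCacheConfig object. A small usage sketch, based only on the signatures shown in this diff (the path is a made-up example), might look like this:

```python
from pathlib import Path

# BuildCacheConfig and BuildCache are the classes shown above.
config = BuildCacheConfig(cache_root=Path("/tmp/trtllm_build_cache"),
                          max_records=10,
                          max_cache_storage_gb=256)
cache = BuildCache(config)  # with config=None it falls back to BuildCacheConfig()
```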