Update TensorRT-LLM (#1918)

Kaiyu Xie 2024-07-09 14:42:22 +08:00 committed by GitHub
parent 9dbc5b38ba
commit a96cccafcf
133 changed files with 6966 additions and 879 deletions

View File

@ -172,16 +172,16 @@ struct BenchmarkParams
std::optional<std::vector<std::vector<SizeType32>>> medusaChoices;
};
class InferenceRequestsSyncSend
class InferenceRequestsAsyncSend
{
public:
InferenceRequestsSyncSend(std::shared_ptr<tensorrt_llm::mpi::MpiComm> comm,
InferenceRequestsAsyncSend(std::shared_ptr<tensorrt_llm::mpi::MpiComm> comm,
std::list<std::shared_ptr<InferenceRequest>> const& inferenceRequests, int const peer)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
TLLM_LOG_DEBUG("start send requests to rank %d", peer);
mNumNewWorkItems = static_cast<int64_t>(inferenceRequests.size());
comm->send(&mNumNewWorkItems, 1, mpi::MpiType::kINT64, peer, 0);
mRequest1 = comm->sendAsync(&mNumNewWorkItems, 1, mpi::MpiType::kINT64, peer, 0);
if (mNumNewWorkItems > 0)
{
for (auto const& infReq : inferenceRequests)
@ -191,16 +191,31 @@ public:
mPacked.insert(mPacked.end(), std::move_iterator(vpacked.begin()), std::move_iterator(vpacked.end()));
}
mVecSize = static_cast<int64_t>(mPacked.size());
comm->send(&mVecSize, 1, mpi::MpiType::kINT64, peer, 1);
comm->send(mPacked.data(), mPacked.size(), mpi::MpiType::kINT64, peer, 2);
mRequest2 = comm->sendAsync(&mVecSize, 1, mpi::MpiType::kINT64, peer, 1);
mRequest3 = comm->sendAsync(mPacked.data(), mPacked.size(), mpi::MpiType::kINT64, peer, 2);
}
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
~InferenceRequestsAsyncSend()
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
mRequest1->wait();
if (mRequest2)
mRequest2->wait();
if (mRequest3)
mRequest3->wait();
TLLM_LOG_DEBUG("end send requests");
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
private:
int64_t mNumNewWorkItems;
int64_t mVecSize;
std::vector<int64_t> mPacked;
std::shared_ptr<tensorrt_llm::mpi::MpiRequest> mRequest1;
std::shared_ptr<tensorrt_llm::mpi::MpiRequest> mRequest2;
std::shared_ptr<tensorrt_llm::mpi::MpiRequest> mRequest3;
};
} // namespace
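Note on the change above: the blocking InferenceRequestsSyncSend becomes InferenceRequestsAsyncSend; the constructor now posts non-blocking sends (sendAsync) and the destructor waits on the outstanding requests, so the caller can overlap other work with the pipeline-parallel hand-off. A minimal, self-contained sketch of the same RAII pattern, written against raw MPI rather than the tensorrt_llm::mpi wrappers (illustration only, two sends instead of three):

// Sketch: RAII wrapper that starts non-blocking sends in the constructor and
// completes them in the destructor, mirroring InferenceRequestsAsyncSend above.
#include <mpi.h>
#include <cstdint>
#include <vector>

class AsyncSend
{
public:
    AsyncSend(std::vector<int64_t> payload, int peer, MPI_Comm comm)
        : mPayload(std::move(payload))
        , mSize(static_cast<int64_t>(mPayload.size()))
    {
        // Post both sends without blocking; the buffers must stay alive until the waits.
        MPI_Isend(&mSize, 1, MPI_INT64_T, peer, /*tag=*/0, comm, &mRequests[0]);
        MPI_Isend(mPayload.data(), static_cast<int>(mPayload.size()), MPI_INT64_T, peer, /*tag=*/1, comm, &mRequests[1]);
    }

    ~AsyncSend()
    {
        // Completing the sends here lets the caller overlap other work with communication.
        MPI_Waitall(2, mRequests, MPI_STATUSES_IGNORE);
    }

private:
    std::vector<int64_t> mPayload;
    int64_t mSize;
    MPI_Request mRequests[2];
};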
@ -930,7 +945,6 @@ public:
, mStaticEmulatedBatchSize(staticEmulatedBatchSize)
, mBatchTimeout(batchTimeout.value_or(std::chrono::milliseconds{0}))
, mActiveCount(0)
, mInferReqSyncSndHdl(nullptr)
{
auto const jsonConfig = GptJsonConfig::parse(trtEnginePath / "config.json");
mWorldConfig = WorldConfig::mpi(jsonConfig.getGpusPerNode(), jsonConfig.getTensorParallelism(),
@ -966,6 +980,12 @@ public:
~GptServer()
{
if (mInferReqWaitThread)
{
mInferReqWaitThread->join();
mInferReqWaitThread.reset(nullptr);
}
mWorkItemsQueue.clear();
}
@ -1031,7 +1051,11 @@ public:
// Return up to max_num_requests inference requests.
std::list<std::shared_ptr<InferenceRequest>> getInferenceRequests(int const max_num_requests)
{
mInferReqSyncSndHdl = nullptr;
if (mInferReqWaitThread)
{
mInferReqWaitThread->join();
mInferReqWaitThread.reset(nullptr);
}
std::list<std::shared_ptr<InferenceRequest>> inferenceRequests;
auto& comm = COMM_SESSION;
if (max_num_requests > 0)
@ -1134,8 +1158,9 @@ public:
if (!mWorldConfig.isLastPipelineParallelRank())
{
auto const peer = mWorldConfig.getPipelineParallelRank() + 1;
mInferReqSyncSndHdl
= std::make_shared<InferenceRequestsSyncSend>(mCommPipelineParallel, inferenceRequests, peer);
auto inferReqAsyncSndHdl
= std::make_unique<InferenceRequestsAsyncSend>(mCommPipelineParallel, inferenceRequests, peer);
mInferReqWaitThread = std::make_unique<std::thread>([handle = std::move(inferReqAsyncSndHdl)]() {});
}
}
return inferenceRequests;
@ -1184,7 +1209,7 @@ private:
WorldConfig mWorldConfig;
std::shared_ptr<tensorrt_llm::mpi::MpiComm> mCommTensorParallel;
std::shared_ptr<tensorrt_llm::mpi::MpiComm> mCommPipelineParallel;
std::shared_ptr<InferenceRequestsSyncSend> mInferReqSyncSndHdl;
std::unique_ptr<std::thread> mInferReqWaitThread;
}; // class GptServer
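The wait itself is moved off the critical path: the unique_ptr handle is move-captured into a lambda with an empty body, the std::thread owns the decay-copied lambda, and the handle's destructor (which waits on the MPI requests) therefore runs on that worker thread once the lambda has been invoked. getInferenceRequests() and ~GptServer() join the thread before touching the send state again. A short sketch of the idiom with placeholder types:

// Sketch: running an object's (blocking) destructor on a worker thread by moving
// ownership into the thread's callable. Types and names are placeholders.
#include <memory>
#include <thread>

struct WaitOnDestroy
{
    ~WaitOnDestroy()
    {
        // e.g. wait() on each outstanding MPI request
    }
};

int main()
{
    auto handle = std::make_unique<WaitOnDestroy>();
    // The thread owns the lambda (and thus the handle); its copy of the callable is
    // destroyed on the worker thread after it runs, so the waits happen there.
    auto waiter = std::make_unique<std::thread>([h = std::move(handle)]() {});
    // ... overlap other work here ...
    waiter->join(); // join before reusing any buffers the handle refers to
    return 0;
}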

View File

@ -60,14 +60,20 @@ class BuildConfig:
parallel_attention: bool = None
new_decoder_architecture: bool = None
state_size: int = 0
state_dtype: Optional[str] = None
state_dtype: Optional[str] = ""
conv_kernel: int = 0
layer_types: List[str] = field(default_factory=list)
rnn_hidden_size: int = 0
rnn_head_size: int = 0
rnn_conv_dim_size: int = 0
logits_soft_cap: float = 0.0
opt_batch_size: Optional[int] = None
opt_num_tokens: Optional[int] = None
use_bias: bool = None
mamba_version: str = 'Mamba1'
ssm_rmsnorm: bool = True
ngroups: int = 1
chunk_size: int = 256
@dataclass
@ -1218,6 +1224,7 @@ _allowed_configs = {
state_size=16,
conv_kernel=4,
rnn_hidden_size=5120,
rnn_conv_dim_size=5120,
layer_types=["recurrent"],
use_bias=False,
)),
@ -1238,6 +1245,7 @@ _allowed_configs = {
state_size=16,
conv_kernel=4,
rnn_hidden_size=4096,
rnn_conv_dim_size=4096,
layer_types=["recurrent"],
use_bias=False,
)),
@ -1258,6 +1266,7 @@ _allowed_configs = {
state_size=16,
conv_kernel=4,
rnn_hidden_size=3072,
rnn_conv_dim_size=3072,
layer_types=["recurrent"],
use_bias=False,
)),
@ -1278,6 +1287,7 @@ _allowed_configs = {
state_size=16,
conv_kernel=4,
rnn_hidden_size=2048,
rnn_conv_dim_size=2048,
layer_types=["recurrent"],
use_bias=False,
)),
@ -1298,9 +1308,62 @@ _allowed_configs = {
state_size=16,
conv_kernel=4,
rnn_hidden_size=1536,
rnn_conv_dim_size=1536,
layer_types=["recurrent"],
use_bias=False,
)),
"mamba2_2.7b":
ModelConfig(name="mamba2_2.7b",
family="mamba",
benchmark_type="gpt",
build_config=BuildConfig(
num_layers=64,
num_heads=1,
hidden_size=2560,
vocab_size=50288,
hidden_act="silu",
n_positions=8192,
max_batch_size=64,
max_input_len=1024,
max_seq_len=2048,
state_size=128,
conv_kernel=4,
rnn_hidden_size=5120,
rnn_conv_dim_size=5376,
rnn_head_size=64,
layer_types=["recurrent"],
use_bias=False,
mamba_version='Mamba2',
ssm_rmsnorm=True,
ngroups=1,
chunk_size=256,
)),
"mamba2_130m":
ModelConfig(name="mamba2_130m",
family="mamba",
benchmark_type="gpt",
build_config=BuildConfig(
num_layers=24,
num_heads=1,
hidden_size=768,
vocab_size=50288,
hidden_act="silu",
n_positions=8192,
max_batch_size=64,
max_input_len=1024,
max_seq_len=2048,
state_size=128,
conv_kernel=4,
rnn_hidden_size=1536,
rnn_conv_dim_size=1792,
rnn_head_size=64,
layer_types=["recurrent"],
use_bias=False,
mamba_version='Mamba2',
ssm_rmsnorm=True,
ngroups=1,
chunk_size=256,
)),
"whisper_large_v3":
ModelConfig(name="whisper_large_v3",
family="whisper",
@ -1344,6 +1407,7 @@ _allowed_configs = {
state_size=1,
layer_types=["recurrent", "recurrent", "attention"],
rnn_hidden_size=2560,
rnn_conv_dim_size=2560,
logits_soft_cap=30.0,
state_dtype="float32",
)),
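Note how the new Mamba2 entries set rnn_conv_dim_size above rnn_hidden_size: the values shown are consistent with rnn_conv_dim_size = rnn_hidden_size + 2 * ngroups * state_size (5120 + 2*1*128 = 5376 for mamba2_2.7b, 1536 + 2*1*128 = 1792 for mamba2_130m), i.e. the conv1d also spans the grouped B and C projections in Mamba2, while the Mamba1 and recurrent entries keep the two sizes equal. A tiny sketch of that relationship (the helper name is made up and is not part of the benchmark scripts):

// Hypothetical helper mirroring the relationship observed in the configs above.
constexpr int mamba2ConvDimSize(int rnnHiddenSize, int ngroups, int stateSize)
{
    // conv1d input covers x plus the grouped B and C projections
    return rnnHiddenSize + 2 * ngroups * stateSize;
}

static_assert(mamba2ConvDimSize(5120, 1, 128) == 5376); // mamba2_2.7b
static_assert(mamba2ConvDimSize(1536, 1, 128) == 1792); // mamba2_130m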

View File

@ -295,7 +295,8 @@ def build_gpt(args):
builder_config_extra_kwargs = {}
extra_items = [
'layer_types', 'conv_kernel', 'rnn_hidden_size', 'logits_soft_cap',
'state_size', 'use_bias'
'state_size', 'use_bias', 'rnn_head_size', 'rnn_conv_dim_size',
'mamba_version', 'ssm_rmsnorm', 'ngroups', 'chunk_size'
]
for item in extra_items:
if item in build_config:
@ -876,10 +877,16 @@ def build_gpt(args):
'state_size': build_config['state_size'],
'conv_kernel': build_config['conv_kernel'],
'rnn_hidden_size': build_config['rnn_hidden_size'],
'rnn_head_size': build_config['rnn_head_size'],
'rnn_conv_dim_size': build_config['rnn_conv_dim_size'],
'rms_norm': True,
'residual_in_fp32': True,
'pad_vocab_size_multiple': 8,
'use_bias': build_config['use_bias'],
'mamba_version': build_config['mamba_version'],
'ssm_rmsnorm': build_config['ssm_rmsnorm'],
'ngroups': build_config['ngroups'],
'chunk_size': build_config['chunk_size'],
}
config = PretrainedConfig.from_dict(config)
tensorrt_llm_model = tensorrt_llm.models.MambaForCausalLM(config)
@ -912,6 +919,8 @@ def build_gpt(args):
'state_size': build_config['state_size'],
'layer_types': build_config['layer_types'],
'rnn_hidden_size': build_config['rnn_hidden_size'],
'rnn_head_size': build_config['rnn_head_size'],
'rnn_conv_dim_size': build_config['rnn_conv_dim_size'],
'logits_soft_cap': build_config['logits_soft_cap'],
'rotary_pct': build_config['rotary_pct'],
}

View File

@ -126,7 +126,7 @@ class GPTBenchmark(BaseBenchmark):
rnn_config_items = [
'conv_kernel', 'layer_types', 'rnn_hidden_size', 'state_size',
'state_dtype'
'state_dtype', 'rnn_head_size', 'rnn_conv_dim_size'
]
rnn_configs_kwargs = {}
for item in rnn_config_items:

View File

@ -116,7 +116,7 @@ public:
uint64_t requestId, std::optional<LogitsPostProcessor> logitsPostProcessor = std::nullopt)
: mRequestId{requestId}
, mIsStreaming{false}
, mlogitsPostProcessor(logitsPostProcessor)
, mLogitsPostProcessor(logitsPostProcessor)
{
}
@ -125,7 +125,7 @@ public:
: mRequestId{requestId}
, mIsStreaming{false}
, mInputTensors{std::move(tensorMap)}
, mlogitsPostProcessor(logitsPostProcessor)
, mLogitsPostProcessor(logitsPostProcessor)
{
for (auto const& [name, tensor] : mInputTensors)
{
@ -161,12 +161,12 @@ public:
void setLogitsPostProcessor(std::optional<LogitsPostProcessor> cb)
{
mlogitsPostProcessor = cb;
mLogitsPostProcessor = cb;
}
std::optional<LogitsPostProcessor> getLogitsPostProcessor()
{
return mlogitsPostProcessor;
return mLogitsPostProcessor;
}
static std::array constexpr kTensorNames = {
@ -280,7 +280,7 @@ protected:
uint64_t mRequestId;
bool mIsStreaming;
TensorMap mInputTensors;
std::optional<LogitsPostProcessor> mlogitsPostProcessor;
std::optional<LogitsPostProcessor> mLogitsPostProcessor;
};
class InferenceRequest : public GenericInferenceRequest<tensorrt_llm::runtime::ITensor::SharedPtr, NamedTensor>

View File

@ -248,16 +248,6 @@ public:
}
}
void setNumPrepopulatedTokens(std::vector<int> numPrepopulatedTokens)
{
mNumPrepopulatedTokens = std::move(numPrepopulatedTokens);
}
[[nodiscard]] std::vector<int> const& getNumPrepopulatedTokens() const
{
return mNumPrepopulatedTokens;
}
private:
// Slot id of the sequence
SizeType32 mSeqSlotIdx;
@ -267,10 +257,6 @@ private:
SizeType32 mBeamWidth;
// List of blocks allocated for each beam of the sequence
std::vector<std::vector<KVCacheBlock::IdType>> mCacheBlockIds;
// Number of tokens already in kv cache before context phase.
// A value > 0 indicates cached kv cache blocks were reused.
// One value per beam.
std::vector<int> mNumPrepopulatedTokens;
};
// BlockManager manages overall metadata of KVCacheBlocks in a layer of the
@ -400,7 +386,10 @@ public:
private:
//! \brief Add single block to beam of sequence and mAllocatedBlocksPerSeq.
void addBlockToBeam(BlockPtr& block, GenerationRequest& sequence, SizeType32 beamIdx, SizeType32 seqSlotIdx);
void addBlockToBeam(BlockPtr& block, GenerationRequest& sequence, SizeType32 beamIdx);
//! \brief Add single block to all beams of sequence.
void addBlockToAllBeams(BlockPtr& block, GenerationRequest& sequence);
//! \brief Store blocks in cached blocks.
//! \param blockedTokens Tokens of each block.
@ -410,11 +399,8 @@ private:
//! \brief Try to load blocks from cache. Allocate new blocks if necessary.
//! \param blockedTokens Tokens of each block.
//! \param sequence Sequence to which blocks are assigned.
//! \param beamIdx Beam of sequence to which blocks are assigned.
//! \param seqSlotIdx Batch slot of sequence to which blocks are assigned.
//! \return Number of matched tokens from loaded blocks.
SizeType32 loadOrAllocateBlocks(std::list<VecTokens> const& blockedTokens, GenerationRequest& sequence,
SizeType32 beamIdx, SizeType32 seqSlotIdx);
SizeType32 loadOrAllocateBlocks(std::list<VecTokens> const& blockedTokens, GenerationRequest& sequence);
//! \brief Find best primary block to free.
//! \details The best primary block to free is the primary block that appears first in the queue and has no primary
@ -598,12 +584,6 @@ public:
nvinfer1::DataType dtype, tensorrt_llm::runtime::ModelConfig const& modelConfig,
tensorrt_llm::runtime::WorldConfig const& worldConfig, runtime::BufferManager const& bufferManager);
[[nodiscard]] SizeType32 getNumPrepopulatedTokens(SizeType32 batchSlotIdx, SizeType32 beamIdx) const
{
auto const& prepopulatedTokens = mSequences.at(batchSlotIdx)->getNumPrepopulatedTokens();
return prepopulatedTokens.size() > 0 ? prepopulatedTokens.at(beamIdx) : 0;
}
[[nodiscard]] bool isEnableBlockReuse() const
{
return mEnableBlockReuse;

View File

@ -84,7 +84,6 @@ public:
, mSamplingConfig(samplingConfig)
, mState(REQUEST_STATE_CONTEXT_INIT)
, mIsStreaming(isStreaming)
, mReturnAllGeneratedTokens(isStreaming && (samplingConfig.beamWidth > 1))
, mEndId(endId)
, mPadId(padId)
, mLogitsPostProcessor(logitsPostProcessor)
@ -127,7 +126,6 @@ public:
, mSamplingConfig(req.getSamplingConfig(), req.getExternalDraftTokensConfig())
, mState(REQUEST_STATE_CONTEXT_INIT)
, mIsStreaming(req.getStreaming())
, mReturnAllGeneratedTokens(req.getReturnAllGeneratedTokens())
, mEndId(req.getEndId())
, mPadId(req.getPadId())
, mOrigPromptLen(mPromptLen)
@ -154,16 +152,6 @@ public:
, mReturnEncoderOutput(req.getOutputConfig().returnEncoderOutput)
, mDecodingIter(0)
{
if (mIsStreaming && mSamplingConfig.beamWidth > 1 && mReturnAllGeneratedTokens == false)
{
TLLM_LOG_WARNING(
"Setting mReturnAllGeneratedTokens to True since streaming AND beam search are done simultaneously. "
"Returning the full beams at each streaming step is needed because beam search + streaming can change "
"previous outputs. Initialize request with mReturnAllGeneratedTokens = True to dismiss this error."
"WARNING: using this option may increase network usage significantly (quadratically w.r.t output "
"length).");
mReturnAllGeneratedTokens = true;
}
if (req.getEncoderInputTokenIds())
{
mState = REQUEST_STATE_ENCODER_INIT;
@ -575,6 +563,16 @@ public:
return mOrigPromptLen;
}
void setPrepopulatedPromptLen(SizeType32 prepopulatedPromptLen)
{
mPrepopulatedPromptLen = prepopulatedPromptLen;
}
[[nodiscard]] SizeType32 getPrepopulatedPromptLen() const
{
return mPrepopulatedPromptLen;
}
void setDraftTokens(std::shared_ptr<VecTokens> const& draftTokens)
{
mDraftTokens = draftTokens;
@ -585,7 +583,7 @@ public:
mDraftLogits = draftLogits;
}
SizeType32 getNumDraftTokens() const
[[nodiscard]] SizeType32 getNumDraftTokens() const
{
return mDraftTokens->size();
}
@ -604,7 +602,7 @@ public:
mNumTokensPerIteration = numTokensPerIteration;
}
SizeType32 getNumTokensPerIteration() const
[[nodiscard]] SizeType32 getNumTokensPerIteration() const
{
return mNumTokensPerIteration;
}
@ -883,22 +881,16 @@ public:
// FIXME(nkorobov): For streaming we do not allow beam search and
// streaming index calculation here applies only for sampling
// getNumTokensPerIteration takes accepted draft tokens into account
auto nbTokensOut
= (mReturnAllGeneratedTokens || !mIsStreaming) ? maxNbTokens : std::max(getNumTokensPerIteration(), 1);
int nbTokensOut = mIsStreaming ? std::max(getNumTokensPerIteration(), 1) : maxNbTokens;
if (mExcludeInputFromOutput && !mIsStreaming)
{
nbTokensOut -= getOrigPromptLen();
}
result.outputTokenIds.resize(nbBeams);
SizeType32 tokenPos = maxNbTokens - nbTokensOut;
// in the case of streaming + beam search
// we need to return the full beams at all iterations
SizeType32 tokenPos{maxNbTokens - nbTokensOut};
auto const shouldSendResponse = isGenerationCompleteState()
|| (mIsStreaming && tokenPos > getMaxSentTokenPos()) || mReturnAllGeneratedTokens;
bool shouldSendResponse = isGenerationCompleteState() || (mIsStreaming && tokenPos > getMaxSentTokenPos());
if (!shouldSendResponse)
{
@ -909,8 +901,7 @@ public:
for (SizeType32 beam = 0; beam < nbBeams; ++beam)
{
auto tokens = getTokens(beam);
auto nbTokens = (mReturnAllGeneratedTokens || !mIsStreaming) ? tokens.size()
: (tokenPos - getMaxSentTokenPos());
auto nbTokens = mIsStreaming ? (tokenPos - getMaxSentTokenPos()) : tokens.size();
// Take accepted draft tokens into account when streaming
auto const numAcceptedTokens = std::max(0, getNumTokensPerIteration() - 1);
@ -982,8 +973,6 @@ public:
runtime::SamplingConfig mSamplingConfig;
LlmRequestState_t mState;
bool mIsStreaming;
// whether to return the full beams on each iteration. True when doing streaming + beamsearch
bool mReturnAllGeneratedTokens;
std::optional<TokenIdType> mEndId;
std::optional<TokenIdType> mPadId;
std::optional<SizeType32> mSeqSlot;
@ -993,6 +982,10 @@ public:
protected:
BeamTokens mTokens;
SizeType32 mOrigPromptLen;
// Number of tokens already in KV cache before context phase.
// A value > 0 indicates cached KV cache blocks were reused.
// Up to inputLen - 1 tokens can be reused.
SizeType32 mPrepopulatedPromptLen{0};
SizeType32 mMaxSentTokenPos;
std::optional<TensorPtr> mEmbeddingBias;
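With returnAllGeneratedTokens gone, the response window above reduces to: streaming requests report only the tokens produced since the last response (at least one, including accepted draft tokens), while non-streaming requests report the full sequence, minus the original prompt when mExcludeInputFromOutput is set. A stand-alone sketch of that computation (names are illustrative, not the class members):

// Sketch of the token window chosen in createResponse() above.
#include <algorithm>

struct Window
{
    int nbTokensOut; // how many tokens go into this response
    int tokenPos;    // index of the first token to send
};

Window responseWindow(bool isStreaming, bool excludeInput, int maxNbTokens,
                      int numTokensPerIteration, int origPromptLen)
{
    int nbTokensOut = isStreaming ? std::max(numTokensPerIteration, 1) : maxNbTokens;
    if (excludeInput && !isStreaming)
        nbTokensOut -= origPromptLen; // drop the prompt from the output
    int tokenPos = maxNbTokens - nbTokensOut;
    return {nbTokensOut, tokenPos};
}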

View File

@ -385,7 +385,7 @@ private:
bool mFreeComm;
};
void initialize(MpiThreadSupport threadMode = MpiThreadSupport::THREAD_FUNNELED, bool forwardAbortToParent = false);
void initialize(MpiThreadSupport threadMode = MpiThreadSupport::THREAD_MULTIPLE, bool forwardAbortToParent = false);
} // namespace tensorrt_llm::mpi
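The default thread support is raised from THREAD_FUNNELED to THREAD_MULTIPLE because MPI calls now also happen from the async-send wait thread, not only from the main thread. Illustrative initialization with raw MPI (not the tensorrt_llm::mpi wrapper):

// Sketch: requesting MPI_THREAD_MULTIPLE so several threads may call MPI concurrently.
#include <mpi.h>
#include <cstdio>

int main(int argc, char** argv)
{
    int provided = 0;
    MPI_Init_thread(&argc, &argv, MPI_THREAD_MULTIPLE, &provided);
    if (provided < MPI_THREAD_MULTIPLE)
        std::fprintf(stderr, "MPI library does not support MPI_THREAD_MULTIPLE\n");
    // ... communicate from multiple threads ...
    MPI_Finalize();
    return 0;
}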

View File

@ -252,8 +252,6 @@ public:
/// @param logitsPostProcessorName The logits postprocessor name. Must correspond to one of the logits postprocessor
/// name provided to the ExecutorConfig.
/// @param encoderInputTokenIds The encoder input token ids for encoder-decoder models, or encoder-only models
/// @param returnAllGeneratedTokens Indicates whether to return the full beams or just the newly generated tokens
/// after every streaming step.
Request(VecTokens inputTokenIds, SizeType32 maxNewTokens, bool streaming = false,
SamplingConfig const& samplingConfig = SamplingConfig(), OutputConfig const& outputConfig = OutputConfig(),
std::optional<SizeType32> const& endId = std::nullopt, std::optional<SizeType32> const& padId = std::nullopt,
@ -264,7 +262,7 @@ public:
std::optional<PromptTuningConfig> pTuningConfig = std::nullopt,
std::optional<LoraConfig> loraConfig = std::nullopt,
std::optional<std::string> logitsPostProcessorName = std::nullopt,
std::optional<VecTokens> encoderInputTokenIds = std::nullopt, bool returnAllGeneratedTokens = false);
std::optional<VecTokens> encoderInputTokenIds = std::nullopt);
/// @brief This logits postprocessor name will dispatch to the batched logits postprocessor
static auto constexpr kBatchedPostProcessorName = "batched";
@ -290,7 +288,6 @@ public:
[[nodiscard]] std::optional<LoraConfig> getLoraConfig() const;
[[nodiscard]] std::optional<std::string> getLogitsPostProcessorName() const;
[[nodiscard]] std::optional<VecTokens> getEncoderInputTokenIds() const;
[[nodiscard]] bool getReturnAllGeneratedTokens() const;
void setStreaming(bool streaming);
void setSamplingConfig(SamplingConfig const& config);
@ -305,7 +302,6 @@ public:
void setLoraConfig(LoraConfig const& loraConfig);
void setLogitsPostProcessorName(std::string const& logitsPostProcessorName);
void setEncoderInputTokenIds(VecTokens const& encoderInputTokenIds);
void setReturnAllGeneratedTokens(bool returnAllGeneratedTokens);
private:
friend class Serialization;

View File

@ -175,7 +175,7 @@ public:
{
TLLM_CHECK(data.size() <= std::numeric_limits<DimType64>::max());
}
return of(data.data(), {static_cast<Shape::DimType64 const>(data.size())});
return of(data.data(), {static_cast<Shape::DimType64>(data.size())});
}
Tensor() noexcept = default;

View File

@ -51,6 +51,8 @@ public:
SizeType32 stateSize = 0;
SizeType32 convKernel = 0;
SizeType32 rnnHiddenSize = 0;
SizeType32 rnnHeadSize = 0;
SizeType32 rnnConvDimSize = 0;
};
enum class LayerType : std::int32_t

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:33f2d6b3e871b0a0e651883607887777fe03d6822f06e4154ffc7e35a8d5cc70
size 3938416
oid sha256:5804fde474d6489db29204259b7e6c368117acadb7fb6dc807868ee0391c458b
size 3953206

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8412aa4ca15c232ced1cd4bdfcc54177c7b257aef493d50650c960e0fb527cfc
size 4002178
oid sha256:85802a0e66148acb17d017a64dd982287775ce7bf5aa4e8bb7e5466b3736c7ee
size 4019734

View File

@ -1,3 +1,3 @@
d4aa7db860caf8feedb79a280aa70da3 libtensorrt_llm_batch_manager_static.a
02b4363342ccea3e2abccc474f3506bb libtensorrt_llm_batch_manager_static.pre_cxx11.a
0e1417f27d93de67940c1062cf230017cd8be5f1 commit
00fb525bdf4ff217c16940540b2357c4 libtensorrt_llm_batch_manager_static.a
97d2db7f62745001d871bc89fb38eed6 libtensorrt_llm_batch_manager_static.pre_cxx11.a
d5f5542d2f1e10c4a6b60be56838ac79a9668665 commit

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:86f34c84883f1dfed04c6fb18811198da636e4457617a47db71f045cb3066eb4
size 3825822
oid sha256:33a724d7e9eabc358c0d674151d45cef8849ae702cc5f2f88b259299a8306574
size 3842582

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c07c30d986591bbe93bb30d67fc8ebbba3eb55c5875ce939c3265151747656ae
size 3782506
oid sha256:490a93ff13a67949a30e279fc3df27456c7f5d4084158c3089befccf78118b7f
size 3799140

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e0190b794e437fa6a0e2140e9446195413abde0dfbc5109423c790397fbb95a6
size 22445474
oid sha256:663a163c3177644ed86fa7a2145fe5e9dbf6f2f0ed06c96d367236da323a3432
size 22523526

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8729077e2bfb9cf3f647cc6ca9be42a8953c0ddf58426485ae3bded76dc9d5c3
size 1403008
oid sha256:497b00031131c1dc705e848e52f3d43148f55505e37bdad97f4933b2c074469d
size 1400502

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2b68c06565f1b3f795e070420c73d085c620b42c1c2131f9895d2687178a6b54
size 1427780
oid sha256:417978bdb5c19f97d9758475acacfa18a4038fc3c5a83f981b02ee220104e0c7
size 1425792

View File

@ -1,3 +1,3 @@
db98ffd911c3c1dde3413e934ce8deb8 libtensorrt_llm_executor_static.a
8dc57746aa2c29d8a2fa50196b552bc3 libtensorrt_llm_executor_static.pre_cxx11.a
0e1417f27d93de67940c1062cf230017cd8be5f1 commit
1df55ac2948ca7b7fe2d5e79934e660e libtensorrt_llm_executor_static.a
ea1641928d184d117deec0696763b274 libtensorrt_llm_executor_static.pre_cxx11.a
d5f5542d2f1e10c4a6b60be56838ac79a9668665 commit

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:cbc3a279681e877a982c0ebbdd0c13d7792d67a87bad0be125ec81bfe3f87399
size 1454684
oid sha256:d0441d473852d11f50bcf23f4934b38d7e4c6d4a42f057eb04beb8aea4211cac
size 1451118

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:aa15303c38a748c4bf7b82e1f9c58cb63418efbd60bfede62820f5a62d65710a
size 1381738
oid sha256:dc8619f99cf5a2e04bdb1482f157a9852bd745e90cf9e03a7878f73ed07e5610
size 1383936

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1a3e774d6700444b7164e1b31e26936ea6fcddc73e3e17bba1d8492c65a57b78
size 14036486
oid sha256:772d1b83e739b926729b99999fbb81768569ffb172c2e120665b2d31b987bb47
size 14071986

View File

@ -0,0 +1,329 @@
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <tuple>
#ifdef __CUDACC__ // for CUDA
#define FT_DEV_CEXPR __device__ __host__ inline constexpr
#else
#define FT_DEV_CEXPR inline constexpr
#endif
//----------------------------------------------------------------------------
// Cn: constant integer
//----------------------------------------------------------------------------
template <auto value_>
struct Cn : public std::integral_constant<decltype(value_), value_>
{
};
template <auto value_>
constexpr auto cn = Cn<value_>();
//----------------------------------------------------------------------------
// Operators for Cn
//----------------------------------------------------------------------------
template <auto value_>
FT_DEV_CEXPR auto operator+(Cn<value_>)
{
return cn<+value_>;
}
template <auto value_>
FT_DEV_CEXPR auto operator-(Cn<value_>)
{
return cn<-value_>;
}
template <auto value_>
FT_DEV_CEXPR auto operator!(Cn<value_>)
{
return cn<!value_>;
}
template <auto value_>
FT_DEV_CEXPR auto operator~(Cn<value_>)
{
return cn<~value_>;
}
template <auto a_, auto b_>
FT_DEV_CEXPR auto operator+(Cn<a_>, Cn<b_>)
{
return cn<a_ + b_>;
}
template <auto a_, auto b_>
FT_DEV_CEXPR auto operator-(Cn<a_>, Cn<b_>)
{
return cn<a_ - b_>;
}
template <auto a_, auto b_>
FT_DEV_CEXPR auto operator*(Cn<a_>, Cn<b_>)
{
return cn<a_ * b_>;
}
template <auto a_, auto b_>
FT_DEV_CEXPR auto operator/(Cn<a_>, Cn<b_>)
{
return cn<a_ / b_>;
}
template <auto a_, auto b_>
FT_DEV_CEXPR auto operator%(Cn<a_>, Cn<b_>)
{
return cn<a_ % b_>;
}
template <auto a_, auto b_>
FT_DEV_CEXPR auto operator<<(Cn<a_>, Cn<b_>)
{
return cn<(a_ << b_)>;
}
template <auto a_, auto b_>
FT_DEV_CEXPR auto operator>>(Cn<a_>, Cn<b_>)
{
return cn<(a_ >> b_)>;
}
template <auto a_, auto b_>
FT_DEV_CEXPR auto operator<(Cn<a_>, Cn<b_>)
{
return cn<(a_ < b_)>;
}
template <auto a_, auto b_>
FT_DEV_CEXPR auto operator<=(Cn<a_>, Cn<b_>)
{
return cn<(a_ <= b_)>;
}
template <auto a_, auto b_>
FT_DEV_CEXPR auto operator>(Cn<a_>, Cn<b_>)
{
return cn<(a_ > b_)>;
}
template <auto a_, auto b_>
FT_DEV_CEXPR auto operator>=(Cn<a_>, Cn<b_>)
{
return cn<(a_ >= b_)>;
}
template <auto a_, auto b_>
FT_DEV_CEXPR auto operator==(Cn<a_>, Cn<b_>)
{
return cn<(a_ == b_)>;
}
template <auto a_, auto b_>
FT_DEV_CEXPR auto operator!=(Cn<a_>, Cn<b_>)
{
return cn<(a_ != b_)>;
}
template <auto a_, auto b_>
FT_DEV_CEXPR auto operator^(Cn<a_>, Cn<b_>)
{
return cn<a_ ^ b_>;
}
template <auto a_, auto b_>
FT_DEV_CEXPR auto operator&(Cn<a_>, Cn<b_>)
{
return cn<a_ & b_>;
}
template <auto a_, auto b_>
FT_DEV_CEXPR auto operator&&(Cn<a_>, Cn<b_>)
{
return cn < a_ && b_ > ;
}
template <auto a_, auto b_>
FT_DEV_CEXPR auto operator|(Cn<a_>, Cn<b_>)
{
return cn<a_ | b_>;
}
template <auto a_, auto b_>
FT_DEV_CEXPR auto operator||(Cn<a_>, Cn<b_>)
{
return cn < a_ || b_ > ;
}
template <auto a_, class B_>
FT_DEV_CEXPR std::enable_if_t<a_ == 0, Cn<a_>> operator*(Cn<a_>, B_)
{
return cn<a_>;
}
template <auto a_, class B_>
FT_DEV_CEXPR std::enable_if_t<a_ == 0, Cn<a_>> operator/(Cn<a_>, B_)
{
return cn<a_>;
}
template <auto a_, class B_>
FT_DEV_CEXPR std::enable_if_t<a_ == 0, Cn<a_>> operator%(Cn<a_>, B_)
{
return cn<a_>;
}
template <auto a_, class B_>
FT_DEV_CEXPR std::enable_if_t<a_ == 0, Cn<a_>> operator<<(Cn<a_>, B_)
{
return cn<a_>;
}
template <auto a_, class B_>
FT_DEV_CEXPR std::enable_if_t<a_ == 0, Cn<a_>> operator>>(Cn<a_>, B_)
{
return cn<a_>;
}
template <auto a_, class B_>
FT_DEV_CEXPR std::enable_if_t<a_ == 0, Cn<a_>> operator&(Cn<a_>, B_)
{
return cn<a_>;
}
template <auto a_, class B_>
FT_DEV_CEXPR std::enable_if_t<a_ == 0, Cn<a_>> operator&&(Cn<a_>, B_)
{
return cn<a_>;
}
template <class A_, auto b_>
FT_DEV_CEXPR std::enable_if_t<b_ == 0, Cn<b_>> operator*(A_, Cn<b_>)
{
return cn<b_>;
}
template <class A_, auto b_>
FT_DEV_CEXPR std::enable_if_t<b_ == +1, Cn<decltype(b_)(0)>> operator%(A_, Cn<b_>)
{
return cn<decltype(b_)(0)>;
}
template <class A_, auto b_>
FT_DEV_CEXPR std::enable_if_t<b_ == -1, Cn<decltype(b_)(0)>> operator%(A_, Cn<b_>)
{
return cn<decltype(b_)(0)>;
}
template <class A_, auto b_>
FT_DEV_CEXPR std::enable_if_t<b_ == 0, Cn<b_>> operator&(A_, Cn<b_>)
{
return cn<b_>;
}
template <class A_, auto b_>
FT_DEV_CEXPR std::enable_if_t<b_ == 0, Cn<b_>> operator&&(A_, Cn<b_>)
{
return cn<b_>;
}
//----------------------------------------------------------------------------
// div_up & round_up
//----------------------------------------------------------------------------
template <class T_>
FT_DEV_CEXPR auto cexpr_abs(T_ a_) // abs is not constexpr until C++20
{
return a_ >= cn<0> ? +a_ : -a_;
}
template <class T_, class U_>
FT_DEV_CEXPR auto div_up(T_ a_, U_ b_)
{
auto tmp = a_ >= cn<0> ? a_ + (cexpr_abs(b_) - cn<1>) : a_ - (cexpr_abs(b_) - cn<1>);
return tmp / b_;
}
template <class T_, class U_>
FT_DEV_CEXPR auto round_up(T_ a_, U_ b_)
{
auto tmp = a_ >= cn<0> ? a_ + (cexpr_abs(b_) - cn<1>) : a_ - (cexpr_abs(b_) - cn<1>);
return tmp - tmp % b_;
}
template <auto a_, auto b_>
FT_DEV_CEXPR auto div_up(Cn<a_>, Cn<b_>)
{
return cn<div_up(a_, b_)>;
}
template <auto a_, auto b_>
FT_DEV_CEXPR auto round_up(Cn<a_>, Cn<b_>)
{
return cn<round_up(a_, b_)>;
}
template <auto a_, class B_>
FT_DEV_CEXPR std::enable_if_t<a_ == 0, Cn<a_>> div_up(Cn<a_>, B_)
{
return cn<a_>;
}
template <auto a_, class B_>
FT_DEV_CEXPR std::enable_if_t<a_ == 0, Cn<a_>> round_up(Cn<a_>, B_)
{
return cn<a_>;
}
//----------------------------------------------------------------------------
// IsTuple: std::tuple, but not std::pair, std::array, etc.
//----------------------------------------------------------------------------
template <class T_>
struct IsTuple : public std::false_type
{
};
template <class... Ts_>
struct IsTuple<std::tuple<Ts_...>> : public std::true_type
{
};
template <class T_>
struct IsTuple<const T_> : public IsTuple<T_>
{
};
template <class T_>
struct IsTuple<T_&> : public IsTuple<T_>
{
};
template <class T_>
struct IsTuple<T_&&> : public IsTuple<T_>
{
};
template <class T_>
constexpr bool IsTuple_v = IsTuple<T_>::value;
// vim: ts=2 sw=2 sts=2 et sta
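This header keeps small integers in the type system: Cn<v> wraps a constant, the operators propagate constants through arithmetic, and the zero/one overloads let expressions such as cn<0> * x fold away without using x's value. div_up and round_up then work uniformly on runtime integers and Cn constants. A brief usage sketch, assuming the file above is saved as Common.h:

// Sketch: compile-time integer arithmetic with the Cn<> helpers above.
#include "Common.h"

static_assert(cn<3> + cn<5> == cn<8>);
static_assert(div_up(cn<130>, cn<64>) == cn<3>);
static_assert(round_up(cn<130>, cn<64>) == cn<192>);

inline int scaled(int x)
{
    // The zero overload collapses cn<0> * x to the compile-time constant 0,
    // regardless of x; cn<0> then converts to a plain int in the addition.
    return cn<0> * x + 5;
}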

View File

@ -0,0 +1,166 @@
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
extern "C" __device__ unsigned __nvvm_get_smem_pointer(void* ptr);
template <int mode_, int line_, class T_>
__device__ inline int swizzle(int x_)
{
return x_ ^ x_ / line_ % (mode_ / 16) * (16 / sizeof(T_));
}
template <class T_>
__device__ inline int swizzle(int x_, int y_)
{
return x_ ^ y_ * (16 / sizeof(T_));
}
template <int size_>
__device__ inline void cp_shared_global(unsigned s_ptr, void const* g_ptr)
{
static_assert(size_ == 4 || size_ == 8 || size_ == 16);
#if __CUDA_ARCH__ >= 800
if constexpr (size_ == 16)
asm volatile("cp.async.cg.shared.global [%0], [%1], %2;\n" ::"r"(s_ptr), "l"(g_ptr), "n"(size_));
else if constexpr (size_ == 8)
asm volatile("cp.async.ca.shared.global [%0], [%1], %2;\n" ::"r"(s_ptr), "l"(g_ptr), "n"(size_));
else if constexpr (size_ == 4)
asm volatile("cp.async.ca.shared.global [%0], [%1], %2;\n" ::"r"(s_ptr), "l"(g_ptr), "n"(size_));
#else
register unsigned tmp[size_ / 4];
if constexpr (size_ == 16)
{
asm volatile("ld.global.v4.b32 {%0, %1, %2, %3}, [%4];\n"
: "=r"(tmp[0]), "=r"(tmp[1]), "=r"(tmp[2]), "=r"(tmp[3])
: "l"(g_ptr));
asm volatile("st.shared.v4.b32 [%0], {%1, %2, %3, %4};\n" ::"r"(s_ptr), "r"(tmp[0]), "r"(tmp[1]), "r"(tmp[2]),
"r"(tmp[3]));
}
else if constexpr (size_ == 8)
{
asm volatile("ld.global.v2.b32 {%0, %1}, [%2];\n" : "=r"(tmp[0]), "=r"(tmp[1]) : "l"(g_ptr));
asm volatile("st.shared.v2.b32 [%0], {%1, %2};\n" ::"r"(s_ptr), "r"(tmp[0]), "r"(tmp[1]));
}
else if constexpr (size_ == 4)
{
asm volatile("ld.global.b32 %0, [%1];\n" : "=r"(tmp[0]) : "l"(g_ptr));
asm volatile("st.shared.b32 [%0], %1;\n" ::"r"(s_ptr), "r"(tmp[0]));
}
#endif
}
template <int size_>
__device__ inline void cp_shared_global(unsigned s_ptr, void const* g_ptr, bool valid_)
{
static_assert(size_ == 4 || size_ == 8 || size_ == 16);
#if __CUDA_ARCH__ >= 800
if constexpr (size_ == 16)
asm volatile("cp.async.cg.shared.global [%0], [%1], %2, %3;\n" ::"r"(s_ptr), "l"(g_ptr), "n"(size_),
"r"(valid_ ? size_ : 0));
else if constexpr (size_ == 8)
asm volatile("cp.async.ca.shared.global [%0], [%1], %2, %3;\n" ::"r"(s_ptr), "l"(g_ptr), "n"(size_),
"r"(valid_ ? size_ : 0));
else if constexpr (size_ == 4)
asm volatile("cp.async.ca.shared.global [%0], [%1], %2, %3;\n" ::"r"(s_ptr), "l"(g_ptr), "n"(size_),
"r"(valid_ ? size_ : 0));
#else
register unsigned tmp[size_ / 4];
if constexpr (size_ == 16)
{
if (valid_)
{
asm volatile("ld.global.v4.b32 {%0, %1, %2, %3}, [%4];\n"
: "=r"(tmp[0]), "=r"(tmp[1]), "=r"(tmp[2]), "=r"(tmp[3])
: "l"(g_ptr));
asm volatile("st.shared.v4.b32 [%0], {%1, %2, %3, %4};\n" ::"r"(s_ptr), "r"(tmp[0]), "r"(tmp[1]),
"r"(tmp[2]), "r"(tmp[3]));
}
else
{
asm volatile("st.shared.v4.b32 [%0], {%1, %2, %3, %4};\n" ::"r"(s_ptr), "n"(0), "n"(0), "n"(0), "n"(0));
}
}
else if constexpr (size_ == 8)
{
if (valid_)
{
asm volatile("ld.global.v2.b32 {%0, %1}, [%2];\n" : "=r"(tmp[0]), "=r"(tmp[1]) : "l"(g_ptr));
asm volatile("st.shared.v2.b32 [%0], {%1, %2};\n" ::"r"(s_ptr), "r"(tmp[0]), "r"(tmp[1]));
}
else
{
asm volatile("st.shared.v2.b32 [%0], {%1, %2};\n" ::"r"(s_ptr), "n"(0), "n"(0));
}
}
else if constexpr (size_ == 4)
{
if (valid_)
{
asm volatile("ld.global.b32 %0, [%1];\n" : "=r"(tmp[0]) : "l"(g_ptr));
asm volatile("st.shared.b32 [%0], %1;\n" ::"r"(s_ptr), "r"(tmp[0]));
}
else
{
asm volatile("st.shared.b32 [%0], %1;\n" ::"r"(s_ptr), "n"(0));
}
}
#endif
}
__device__ inline void cp_commit_group()
{
#if __CUDA_ARCH__ >= 800
asm volatile("cp.async.commit_group;\n");
#endif
}
template <int remain_>
__device__ inline void cp_wait_group()
{
#if __CUDA_ARCH__ >= 800
asm volatile("cp.async.wait_group %0;\n" ::"n"(remain_));
#endif
}
template <bool trans_ = false>
__device__ inline void ldmatrix(unsigned& r0_, unsigned& r1_, unsigned& r2_, unsigned& r3_, unsigned addr_)
{
if (trans_)
asm volatile("ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {%0, %1, %2, %3}, [%4];\n"
: "=r"(r0_), "=r"(r1_), "=r"(r2_), "=r"(r3_)
: "r"(addr_));
else
asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];\n"
: "=r"(r0_), "=r"(r1_), "=r"(r2_), "=r"(r3_)
: "r"(addr_));
}
typedef __nv_bfloat16 bf16;
typedef __nv_bfloat162 bf162;
template <int mode_ = 128, int line_ = 64>
__device__ int swz(int x_)
{
return x_ ^ x_ / line_ % (mode_ / 16) * 8;
}
// vim: ts=2 sw=2 sts=2 et sta
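These helpers wrap the Ampere cp.async path (with a plain ld/st fallback before sm_80) and the ldmatrix loads used by the Mamba2 chunk kernels added below. A minimal CUDA sketch of the staging pattern, assuming the helpers above are saved under a header name of your choosing (CudaAsyncCopy.cuh here is illustrative) and that the input pointer is 16-byte aligned:

// Sketch: stage one 128-float tile per block into shared memory with cp_shared_global,
// then write it back out. Launch with 32 threads per block.
#include <cuda_bf16.h>       // __nv_bfloat16, needed by the bf16 typedef in the header
#include "CudaAsyncCopy.cuh" // hypothetical file name for the helpers above

__global__ void copyTile(float const* in, float* out)
{
    __align__(16) __shared__ float tile[128];
    unsigned sBase = __nvvm_get_smem_pointer(tile);
    int t = threadIdx.x;
    // Each of the 32 threads copies 16 bytes (4 floats), asynchronously on sm_80+.
    cp_shared_global<16>(sBase + t * 16, in + blockIdx.x * 128 + t * 4);
    cp_commit_group();
    cp_wait_group<0>();
    __syncthreads();
    for (int i = t; i < 128; i += 32)
        out[blockIdx.x * 128 + i] = tile[i];
}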

File diff suppressed because it is too large

View File

@ -0,0 +1,438 @@
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <cuda_fp8.h>
#include <mma.h>
#include "tensorrt_llm/common/cudaBf16Fallbacks.cuh"
#include "Common.h"
#include "Poly.h"
namespace tensorrt_llm
{
namespace kernels
{
typedef void (*BmmChunkKernelFuncFp16)(int B_, int L_, int G_, int N_,
// const half *g_mxY_, // B*L*H*P
// const half *g_mxOs_, // B*C*H*N*P
// const half *g_mxFs_, // B *H*N*P
// const float *g_mxSt_, // B*C*H*N*P
// const float *g_mxdc_, // B*C*H*Q
// const float *g_mxdA_, // B*C*H*Q
// const half *g_mxdt_, // B*L*H
// const float *g_mxdb_, // H
// const float *g_mxA_, // H
half* g_mxCB_, // B*C*G*Q*Q
half const* g_mxBC_, // B*L*2*G*N
// const float *g_mxD_, // H
// const half *g_mxX_, // B*L*H*P
// const half *g_mxZ_, // B*L*H*P
bool removePadding_, int const* lastTokenIdsPtr_);
typedef void (*BmmChunkKernelFuncBf16)(int B_, int L_, int G_, int N_,
// const bf16 *g_mxY_, // B*L*H*P
// const bf16 *g_mxOs_, // B*C*H*N*P
// const bf16 *g_mxFs_, // B *H*N*P
// const float *g_mxSt_, // B*C*H*N*P
// const float *g_mxdc_, // B*C*H*Q
// const float *g_mxdA_, // B*C*H*Q
// const bf16 *g_mxdt_, // B*L*H
// const float *g_mxdb_, // H
// const float *g_mxA_, // H
bf16* g_mxCB_, // B*C*G*Q*Q
bf16 const* g_mxBC_, // B*L*2*G*N
// const float *g_mxD_, // H
// const bf16 *g_mxX_, // B*L*H*P
// const bf16 *g_mxZ_, // B*L*H*P
bool removePadding_, int const* lastTokenIdsPtr_);
template <int Q_, int tileM_, int tileN_, int tileK_, // smem size, per sm
int wmmaM_, int wmmaN_, int wmmaK_, // wmma size, per instruction
int warpM_, int warpN_, // warp number
int pipeS_, class Tp_>
__global__ std::enable_if_t<std::is_same_v<Tp_, half> || std::is_same_v<Tp_, __nv_bfloat16>> bmm_chunk_kernel(int B_,
int L_, int G_, int N_,
// const Tp_ *g_mxY_, // B*L*H*P
// const Tp_ *g_mxOs_, // B*C*H*N*P
// const Tp_ *g_mxFs_, // B *H*N*P
// const float *g_mxSt_, // B*C*H*N*P
// const float *g_mxdc_, // B*C*H*Q
// const float *g_mxdA_, // B*C*H*Q
// const Tp_ *g_mxdt_, // B*L*H
// const Wt_ *g_mxdb_, // H
// const Wt_ *g_mxA_, // H
Tp_* g_mxCB_, // B*C*G*Q*Q
Tp_ const* g_mxBC_, // B*L*2*G*N
// const Wt_ *g_mxD_, // H
// const Tp_ *g_mxX_, // B*L*H*P
// const Tp_ *g_mxZ_, // B*L*H*P
bool removePadding_, int const* lastTokenIdsPtr_)
{
#if __CUDA_ARCH__ >= 800
using namespace tensorrt_llm::common;
auto blockIdx_x = Rn<ID>{int(blockIdx.x)};
auto blockIdx_y = Rn<ID>{int(blockIdx.y)};
auto blockIdx_z = Rn<ID>{int(blockIdx.z)};
auto threadIdx_x = Rn<ID, 32>{int(threadIdx.x)};
auto threadIdx_y = Rn<ID, warpN_>{int(threadIdx.y)};
auto threadIdx_z = Rn<ID, warpM_>{int(threadIdx.z)};
// auto B = Rn<ID>{B_};
auto L = Rn<ID>{L_};
// auto H = Rn<ID>{H_};
// auto P = Rn<ID>{P_};
auto G = Rn<ID>{G_};
auto N = Rn<ID>{N_};
auto Q = cn<Q_>;
auto C = Rn<ID>{div_up(L.var, Q_)};
auto aStart = blockIdx_z * L;
auto cStart = blockIdx_z * C;
if (removePadding_)
{
aStart = Rn<ID>{int(blockIdx.z ? lastTokenIdsPtr_[blockIdx.z - 1] : 0)};
cStart = Rn<ID>{int(blockIdx.z ? div_up(aStart.var, Q_) + blockIdx.z - 1 : 0)};
L = Rn<ID>{lastTokenIdsPtr_[blockIdx.z] - aStart.var};
C = Rn<ID>{div_up(L.var, Q_)};
}
else
{
L = Rn<ID>{lastTokenIdsPtr_[blockIdx.z]};
C = Rn<ID>{div_up(L.var, Q_)};
}
if (blockIdx_y * Q >= L)
return;
auto gStart = blockIdx_x / (Q / cn<tileN_>) / (Q / cn<tileM_>);
auto mStart = blockIdx_x / (Q / cn<tileN_>) % (Q / cn<tileM_>);
auto nStart = blockIdx_x % (Q / cn<tileN_>);
extern __shared__ float smem[];
Tp_* s_mxC = (Tp_*) smem;
Tp_* s_mxB = (Tp_*) smem + tileM_ * tileK_ * pipeS_;
Tp_* s_mxCB = (Tp_*) smem;
unsigned b_base = __nvvm_get_smem_pointer(smem);
unsigned b_mxC = b_base;
unsigned b_mxB = b_base + tileM_ * tileK_ * pipeS_ * sizeof(Tp_);
unsigned b_mxCB = b_base;
using std::array;
register array<array<array<float, wmmaM_ * wmmaN_ / 32>, tileN_ / wmmaN_ / warpN_>, tileM_ / wmmaM_ / warpM_> r_mxCB
= array<array<array<float, wmmaM_ * wmmaN_ / 32>, tileN_ / wmmaN_ / warpN_>, tileM_ / wmmaM_ / warpM_>();
register array<array<unsigned, wmmaM_ * wmmaK_ / 64>, tileM_ / wmmaM_ / warpM_> r_mxC;
register array<array<unsigned, wmmaK_ * wmmaN_ / 64>, tileN_ / wmmaN_ / warpN_> r_mxB;
constexpr int step = std::max(
1, tileM_ / wmmaM_ / warpM_ * tileN_ / wmmaN_ / warpN_ / (tileM_ / wmmaM_ / warpM_ + tileN_ / wmmaN_ / warpN_));
auto baseC = [](auto iK) { return iK % cn<pipeS_> * cn<tileM_> * cn<tileK_>; };
auto baseB = [](auto iK) { return iK % cn<pipeS_> * cn<tileN_> * cn<tileK_>; };
auto thread = [=](auto iStep)
{
return iStep * cn<warpM_ * warpN_ * 256> + threadIdx_z * cn<warpN_ * 256> + threadIdx_y * cn<256>
+ threadIdx_x * cn<8>;
};
#pragma unroll
for (Rn<UNROLL, pipeS_> iK; iK.var < iK.size; iK.var++)
{
#pragma unroll
for (Rn<UNROLL, div_up(tileM_ * tileK_, warpM_ * warpN_ * 256)> iStep; iStep.var < iStep.size; iStep.var++)
if (thread(iStep) < cn<tileM_ * tileK_>
&& thread(iStep) / cn<tileK_> < L - blockIdx_y * Q - mStart * cn<tileM_>)
cp_shared_global<16>(b_mxC + swizzle<tileK_ * 2, tileK_ * 2>(thread(iStep) * cn<2>, baseC(iK) * cn<2>),
g_mxBC_
+ get(
(aStart + blockIdx_y * Q + mStart * cn<tileM_> + thread(iStep) / cn<tileK_>) *cn<2> * G * N
+ cn<1> * G * N + gStart * N + iK * cn<tileK_> + thread(iStep) % cn<tileK_>));
else if (thread(iStep) < cn<tileM_ * tileK_>)
*(int4*) ((char*) s_mxC + swizzle<tileK_ * 2, tileK_ * 2>(thread(iStep) * cn<2>, baseC(iK) * cn<2>))
= int4{0, 0, 0, 0};
#pragma unroll
for (Rn<UNROLL, div_up(tileN_ * tileK_, warpM_ * warpN_ * 256)> iStep; iStep.var < iStep.size; iStep.var++)
if (thread(iStep) < cn<tileN_ * tileK_>
&& thread(iStep) / cn<tileK_> < L - blockIdx_y * Q - nStart * cn<tileN_>)
cp_shared_global<16>(b_mxB + swizzle<tileK_ * 2, tileK_ * 2>(thread(iStep) * cn<2>, baseB(iK) * cn<2>),
g_mxBC_
+ get(
(aStart + blockIdx_y * Q + nStart * cn<tileN_> + thread(iStep) / cn<tileK_>) *cn<2> * G * N
+ cn<0> * G * N + gStart * N + iK * cn<tileK_> + thread(iStep) % cn<tileK_>));
else if (thread(iStep) < cn<tileN_ * tileK_>)
*(int4*) ((char*) s_mxB + swizzle<tileK_ * 2, tileK_ * 2>(thread(iStep) * cn<2>, baseB(iK) * cn<2>))
= int4{0, 0, 0, 0};
cp_commit_group();
}
asm volatile("cp.async.wait_group %0;\n" ::"n"(pipeS_ - 1));
__syncthreads();
for (int iK = pipeS_; iK < N_ / tileK_ + pipeS_; iK++)
{
#pragma unroll
for (int k = 0; k < tileK_ / wmmaK_; k++)
{
#pragma unroll
for (int y = 0; y < tileM_ / wmmaM_ / warpM_; y++)
#pragma unroll
for (int x = 0; x < tileN_ / wmmaN_ / warpN_; x++)
{
if ((y * tileN_ / wmmaN_ / warpN_ + x) % step == 0)
{
int x1 = (y * tileN_ / wmmaN_ / warpN_ + x) / step;
int y1 = x1 - tileN_ / wmmaN_ / warpN_
+ (tileM_ / wmmaM_ / warpM_ == 1 || tileN_ / wmmaN_ / warpN_ == 1);
if (y1 >= 0 && y1 < tileM_ / wmmaM_ / warpM_)
{
if (wmmaK_ == 16)
asm volatile(
"ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];\n"
: "=r"(r_mxC[y1][0]), "=r"(r_mxC[y1][1]), "=r"(r_mxC[y1][2]), "=r"(r_mxC[y1][3])
: "r"(b_mxC + iK % pipeS_ * (tileM_ * tileK_ * 2)
+ 2
* swz<tileK_ * 2, tileK_>(y1 * warpM_ * wmmaM_ * tileK_ + k * wmmaK_
+ threadIdx.z * wmmaM_ * tileK_ + threadIdx.x % 16 * tileK_
+ threadIdx.x / 16 * 8)));
}
if (x1 >= 0 && x1 < tileN_ / wmmaN_ / warpN_)
{
if (wmmaK_ == 16 && x1 % 2 == 0)
asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];\n"
: "=r"(r_mxB[x1][0]), "=r"(r_mxB[x1][1]), "=r"(r_mxB[x1 + 1][0]),
"=r"(r_mxB[x1 + 1][1])
: "r"(b_mxB + iK % pipeS_ * (tileK_ * tileN_ * 2)
+ 2
* swz<tileK_ * 2, tileK_>(x1 * warpN_ * wmmaN_ * tileK_
+ k * wmmaK_ + threadIdx.y * wmmaN_ * tileK_
+ threadIdx.x % 8 * tileK_ + threadIdx.x / 8 % 2 * 8
+ threadIdx.x / wmmaK_ * warpN_ * wmmaN_ * tileK_)));
}
}
}
#pragma unroll
for (int y = 0; y < tileM_ / wmmaM_ / warpM_; y++)
#pragma unroll
for (int x = 0; x < tileN_ / wmmaN_ / warpN_; x++)
{
if (wmmaK_ == 16)
{
if (std::is_same_v<Tp_, half>)
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 \n"
" {%0, %1, %2, %3}, \n"
" {%4, %5, %6, %7}, \n"
" {%8, %9}, \n"
" {%0, %1, %2, %3}; \n"
: "+f"(r_mxCB[y][x][0]), "+f"(r_mxCB[y][x][1]), "+f"(r_mxCB[y][x][2]),
"+f"(r_mxCB[y][x][3])
: "r"(r_mxC[y][0]), "r"(r_mxC[y][1]), "r"(r_mxC[y][2]), "r"(r_mxC[y][3]),
"r"(r_mxB[x][0]), "r"(r_mxB[x][1]));
else
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 \n"
" {%0, %1, %2, %3}, \n"
" {%4, %5, %6, %7}, \n"
" {%8, %9}, \n"
" {%0, %1, %2, %3}; \n"
: "+f"(r_mxCB[y][x][0]), "+f"(r_mxCB[y][x][1]), "+f"(r_mxCB[y][x][2]),
"+f"(r_mxCB[y][x][3])
: "r"(r_mxC[y][0]), "r"(r_mxC[y][1]), "r"(r_mxC[y][2]), "r"(r_mxC[y][3]),
"r"(r_mxB[x][0]), "r"(r_mxB[x][1]));
}
}
}
__syncthreads();
if (iK * tileK_ < N_)
{
auto jK = Rn<>{iK};
#pragma unroll
for (Rn<UNROLL, div_up(tileM_ * tileK_, warpM_ * warpN_ * 256)> iStep; iStep.var < iStep.size; iStep.var++)
if (thread(iStep) < cn<tileM_ * tileK_>
&& thread(iStep) / cn<tileK_> < L - blockIdx_y * Q - mStart * cn<tileM_>)
cp_shared_global<16>(
b_mxC + swizzle<tileK_ * 2, tileK_ * 2>(thread(iStep) * cn<2>, baseC(jK) * cn<2>),
g_mxBC_
+ get((aStart + blockIdx_y * Q + mStart * cn<tileM_> + thread(iStep) / cn<tileK_>) *cn<2>
* G * N
+ cn<1> * G * N + gStart * N + jK * cn<tileK_> + thread(iStep) % cn<tileK_>));
else if (thread(iStep) < cn<tileM_ * tileK_>)
*(int4*) ((char*) s_mxC + swizzle<tileK_ * 2, tileK_ * 2>(thread(iStep) * cn<2>, baseC(jK) * cn<2>))
= int4{0, 0, 0, 0};
#pragma unroll
for (Rn<UNROLL, div_up(tileN_ * tileK_, warpM_ * warpN_ * 256)> iStep; iStep.var < iStep.size; iStep.var++)
if (thread(iStep) < cn<tileN_ * tileK_>
&& thread(iStep) / cn<tileK_> < L - blockIdx_y * Q - nStart * cn<tileN_>)
cp_shared_global<16>(
b_mxB + swizzle<tileK_ * 2, tileK_ * 2>(thread(iStep) * cn<2>, baseB(jK) * cn<2>),
g_mxBC_
+ get((aStart + blockIdx_y * Q + nStart * cn<tileN_> + thread(iStep) / cn<tileK_>) *cn<2>
* G * N
+ cn<0> * G * N + gStart * N + jK * cn<tileK_> + thread(iStep) % cn<tileK_>));
else if (thread(iStep) < cn<tileN_ * tileK_>)
*(int4*) ((char*) s_mxB + swizzle<tileK_ * 2, tileK_ * 2>(thread(iStep) * cn<2>, baseB(jK) * cn<2>))
= int4{0, 0, 0, 0};
}
asm volatile("cp.async.commit_group;\n" ::);
asm volatile("cp.async.wait_group %0;\n" ::"n"(pipeS_ - 1));
__syncthreads();
}
#pragma unroll
for (int y = 0; y < tileM_ / wmmaM_ / warpM_; y++)
#pragma unroll
for (int x = 0; x < tileN_ / wmmaN_ / warpN_; x++)
{
if (std::is_same_v<Tp_, half>)
{
*(half2*) &r_mxCB[y][x][0] = __floats2half2_rn(r_mxCB[y][x][0], r_mxCB[y][x][1]);
*(half2*) &r_mxCB[y][x][2] = __floats2half2_rn(r_mxCB[y][x][2], r_mxCB[y][x][3]);
}
else
{
*(bf162*) &r_mxCB[y][x][0] = __floats2bfloat162_rn(r_mxCB[y][x][0], r_mxCB[y][x][1]);
*(bf162*) &r_mxCB[y][x][2] = __floats2bfloat162_rn(r_mxCB[y][x][2], r_mxCB[y][x][3]);
}
}
#pragma unroll
for (int y = 0; y < tileM_ / wmmaM_ / warpM_; y++)
#pragma unroll
for (int x = 0; x < tileN_ / wmmaN_ / warpN_; x++)
{
asm volatile("st.shared.b32 [%0], %1;\n" ::"r"(b_mxCB
+ 2
* swz<tileN_ * 2, tileN_>(y * warpM_ * wmmaM_ * tileN_ + x * warpN_ * wmmaN_
+ (threadIdx.z * wmmaM_ + threadIdx.x / 4) * tileN_
+ (threadIdx.y * wmmaN_ + threadIdx.x % 4 * 2))),
"r"(*(unsigned*) &r_mxCB[y][x][0]));
asm volatile("st.shared.b32 [%0], %1;\n" ::"r"(b_mxCB
+ 2
* swz<tileN_ * 2, tileN_>(y * warpM_ * wmmaM_ * tileN_ + 8 * tileN_
+ x * warpN_ * wmmaN_ + (threadIdx.z * wmmaM_ + threadIdx.x / 4) * tileN_
+ (threadIdx.y * wmmaN_ + threadIdx.x % 4 * 2))),
"r"(*(unsigned*) &r_mxCB[y][x][2]));
}
__syncthreads();
#pragma unroll
for (Rn<UNROLL, div_up(tileM_ * tileN_, warpM_ * warpN_ * 256)> iStep; iStep.var < iStep.size; iStep.var++)
if (thread(iStep) < cn<tileM_ * tileN_>)
*(int4*) (g_mxCB_
+ get(cStart * G * Q * Q + blockIdx_y * G * Q * Q + gStart * Q * Q
+ (mStart * cn<tileM_> + thread(iStep) / cn<tileN_>) *Q + nStart * cn<tileN_>
+ thread(iStep) % cn<tileN_>))
= *(int4*) ((char*) s_mxCB + swizzle<tileN_ * 2, tileN_ * 2>(thread(iStep) * cn<2>));
asm volatile("cp.async.wait_group %0;\n" ::"n"(0));
#endif
}
BmmChunkKernelFuncFp16 getBmmChunkKernelFp16(
int B_, int L_, int G_, int N_, int Q_, dim3* blockDims_, dim3* threadDims_, int* sharedMem_)
{
int B = B_;
int L = L_;
// int H = H_;
// int P = P_;
int G = G_;
// int N = N_;
int Q = Q_;
int C = div_up(L, Q);
int tileM = 128;
int tileN = 64;
int tileK = 32;
int warpM = 2;
int warpN = 1;
int pipeS = 2;
auto sharedMem = std::max((tileM * tileK + tileK * tileN) * pipeS * 2, (tileM * tileN) * 2);
*blockDims_ = dim3(G * Q / tileN * Q / tileM, C, B);
*threadDims_ = dim3(32, warpN, warpM);
*sharedMem_ = sharedMem;
if (Q_ == 128)
return bmm_chunk_kernel<128, 128, 64, 32, 16, 8, 16, 2, 1, 2, half>;
else if (Q_ == 256)
return bmm_chunk_kernel<256, 128, 64, 32, 16, 8, 16, 2, 1, 2, half>;
else
return nullptr;
}
BmmChunkKernelFuncBf16 getBmmChunkKernelBf16(
int B_, int L_, int G_, int N_, int Q_, dim3* blockDims_, dim3* threadDims_, int* sharedMem_)
{
int B = B_;
int L = L_;
// int H = H_;
// int P = P_;
int G = G_;
// int N = N_;
int Q = Q_;
int C = div_up(L, Q);
int tileM = 128;
int tileN = 64;
int tileK = 32;
int warpM = 2;
int warpN = 1;
int pipeS = 2;
auto sharedMem = std::max((tileM * tileK + tileK * tileN) * pipeS * 2, (tileM * tileN) * 2);
*blockDims_ = dim3(G * Q / tileN * Q / tileM, C, B);
*threadDims_ = dim3(32, warpN, warpM);
*sharedMem_ = sharedMem;
if (Q_ == 128)
return bmm_chunk_kernel<128, 128, 64, 32, 16, 8, 16, 2, 1, 2, bf16>;
else if (Q_ == 256)
return bmm_chunk_kernel<256, 128, 64, 32, 16, 8, 16, 2, 1, 2, bf16>;
else
return nullptr;
}
} // namespace kernels
} // namespace tensorrt_llm
// vim: ts=2 sw=2 sts=2 et sta
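getBmmChunkKernelFp16/Bf16 return a kernel pointer plus launch geometry instead of launching directly, so the caller can cache the selection per shape. A hedged sketch of how such a launcher is typically consumed; the call-site name, header name, and attribute call are assumptions, not taken from this diff:

// Sketch only: consuming the launcher above. Buffers are assumed to be allocated
// and filled elsewhere; error handling is omitted for brevity.
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include "bmmChunk.h" // hypothetical name for the header above

void runBmmChunkFp16(int B, int L, int G, int N, int Q, half* cb, half const* bc,
                     bool removePadding, int const* lastTokenIds, cudaStream_t stream)
{
    dim3 blocks, threads;
    int sharedMem = 0;
    auto kernel = tensorrt_llm::kernels::getBmmChunkKernelFp16(B, L, G, N, Q, &blocks, &threads, &sharedMem);
    if (kernel == nullptr)
        return; // unsupported chunk size Q
    // Opt in to a larger dynamic shared-memory carve-out in case the tile setup needs it.
    cudaFuncSetAttribute((void const*) kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, sharedMem);
    kernel<<<blocks, threads, sharedMem, stream>>>(B, L, G, N, cb, bc, removePadding, lastTokenIds);
}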

View File

@ -0,0 +1,245 @@
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <cuda_fp8.h>
#include <mma.h>
#include "tensorrt_llm/common/cudaBf16Fallbacks.cuh"
#include "Common.h"
#include "Poly.h"
namespace tensorrt_llm
{
namespace kernels
{
typedef void (*ChunkCumsumKernelFuncFp16)(int B_, int L_, int H_,
// const half *g_mxY_, // B*L*H*P
// const half *g_mxOs_, // B*C*H*N*P
// const half *g_mxFs_, // B *H*N*P
// const float *g_mxSt_, // B*C*H*N*P
float* g_mxdc_, // B*C*H*Q
float* g_mxdA_, // B*C*H*Q
half const* g_mxdt_, // B*L*H
float const* g_mxdb_, // H
float const* g_mxA_, // H
// const half *g_mxCB_, // B*C*G*Q*Q
// const half *g_mxBC_, // B*L*2*G*N
// const float *g_mxD_, // H
// const half *g_mxX_, // B*L*H*P
// const half *g_mxZ_, // B*L*H*P
bool removePadding_, int const* lastTokenIdsPtr_);
typedef void (*ChunkCumsumKernelFuncBf16)(int B_, int L_, int H_,
// const bf16 *g_mxY_, // B*L*H*P
// const bf16 *g_mxOs_, // B*C*H*N*P
// const bf16 *g_mxFs_, // B *H*N*P
// const float *g_mxSt_, // B*C*H*N*P
float* g_mxdc_, // B*C*H*Q
float* g_mxdA_, // B*C*H*Q
bf16 const* g_mxdt_, // B*L*H
float const* g_mxdb_, // H
float const* g_mxA_, // H
// const bf16 *g_mxCB_, // B*C*G*Q*Q
// const bf16 *g_mxBC_, // B*L*2*G*N
// const float *g_mxD_, // H
// const bf16 *g_mxX_, // B*L*H*P
// const bf16 *g_mxZ_, // B*L*H*P
bool removePadding_, int const* lastTokenIdsPtr_);
template <int Q_, int tileH_, int warpH_, bool dtSoftplus_, class Tp_, class Wt_ = float>
__global__ std::enable_if_t<std::is_same_v<Tp_, half> || std::is_same_v<Tp_, __nv_bfloat16>> chunk_cumsum_kernel(int B_,
int L_, int H_,
// const Tp_ *g_mxY_, // B*L*H*P
// const Tp_ *g_mxOs_, // B*C*H*N*P
// const Tp_ *g_mxFs_, // B *H*N*P
// const float *g_mxSt_, // B*C*H*N*P
float* g_mxdc_, // B*C*H*Q
float* g_mxdA_, // B*C*H*Q
Tp_ const* g_mxdt_, // B*L*H
Wt_ const* g_mxdb_, // H
Wt_ const* g_mxA_, // H
// const Tp_ *g_mxCB_, // B*C*G*Q*Q
// const Tp_ *g_mxBC_, // B*L*2*G*N
// const Wt_ *g_mxD_, // H
// const Tp_ *g_mxX_, // B*L*H*P
// const Tp_ *g_mxZ_, // B*L*H*P
bool removePadding_, int const* lastTokenIdsPtr_)
{
using namespace tensorrt_llm::common;
auto blockIdx_x = Rn<ID>{int(blockIdx.x)};
auto blockIdx_y = Rn<ID>{int(blockIdx.y)};
auto blockIdx_z = Rn<ID>{int(blockIdx.z)};
auto threadIdx_x = Rn<ID, 32>{int(threadIdx.x)};
auto threadIdx_y = Rn<ID, warpH_>{int(threadIdx.y)};
// auto B = Rn<ID>{B_};
auto L = Rn<ID>{L_};
auto H = Rn<ID>{H_};
// auto P = Rn<ID>{P_};
// auto G = Rn<ID>{G_};
// auto N = Rn<ID>{N_};
auto Q = cn<Q_>;
auto C = Rn<ID>{div_up(L.var, Q_)};
auto aStart = blockIdx_z * L;
auto cStart = blockIdx_z * C;
if (removePadding_)
{
aStart = Rn<ID>{int(blockIdx.z ? lastTokenIdsPtr_[blockIdx.z - 1] : 0)};
cStart = Rn<ID>{int(blockIdx.z ? div_up(aStart.var, Q_) + blockIdx.z - 1 : 0)};
L = Rn<ID>{lastTokenIdsPtr_[blockIdx.z] - aStart.var};
C = Rn<ID>{div_up(L.var, Q_)};
}
else
{
L = Rn<ID>{lastTokenIdsPtr_[blockIdx.z]};
C = Rn<ID>{div_up(L.var, Q_)};
}
if (blockIdx_y * Q >= L)
return;
auto thread = [=](auto iStep) { return iStep * cn<warpH_ * 32> + threadIdx_y * cn<32> + threadIdx_x; };
#pragma unroll
for (Rn<UNROLL, div_up(tileH_, warpH_ * 32)> iStep; iStep.var < iStep.size; iStep.var++)
{
float r_A = 0.f, r_db = 0.f, sum = 0.f;
if (thread(iStep) < cn<tileH_>)
r_A = g_mxA_[get(blockIdx_x * cn<tileH_> + thread(iStep))];
if (thread(iStep) < cn<tileH_> && g_mxdb_)
r_db = g_mxdb_[get(blockIdx_x * cn<tileH_> + thread(iStep))];
#pragma unroll
for (Rn<UNROLL, Q_> iQ; iQ.var < iQ.size; iQ.var++)
{
float r_dt = 0.f;
if (thread(iStep) < cn<tileH_> && blockIdx_y * Q + iQ < L)
{
r_dt = float(g_mxdt_[get((aStart + blockIdx_y * Q + iQ) * H + blockIdx_x * cn<tileH_> + thread(iStep))])
+ r_db;
if (dtSoftplus_)
r_dt = r_dt > 32.f ? r_dt : log1p(expf(r_dt));
sum += r_dt;
}
if (thread(iStep) < cn<tileH_>)
{
g_mxdc_[get((cStart + blockIdx_y) * H * Q + (blockIdx_x * cn<tileH_> + thread(iStep)) * Q + iQ)] = r_dt;
g_mxdA_[get((cStart + blockIdx_y) * H * Q + (blockIdx_x * cn<tileH_> + thread(iStep)) * Q + iQ)]
= sum * r_A;
}
}
}
}
ChunkCumsumKernelFuncFp16 getChunkCumsumKernelFp16(
int B_, int L_, int H_, int Q_, bool dtSoftPlus_, dim3* blockDims_, dim3* threadDims_, int* sharedMem_)
{
int B = B_;
int L = L_;
int H = H_;
// int P = P_;
// int G = G_;
// int N = N_;
int Q = Q_;
int C = div_up(L, Q);
int tileH = 1;
int warpH = 1;
auto sharedMem = 0;
*blockDims_ = dim3(H / tileH, C, B);
*threadDims_ = dim3(32, warpH);
*sharedMem_ = sharedMem;
if (dtSoftPlus_)
{
if (Q_ == 128)
return chunk_cumsum_kernel<128, 1, 1, true, half>;
else if (Q_ == 256)
return chunk_cumsum_kernel<256, 1, 1, true, half>;
else
return nullptr;
}
else
{
if (Q_ == 128)
return chunk_cumsum_kernel<128, 1, 1, false, half>;
else if (Q_ == 256)
return chunk_cumsum_kernel<256, 1, 1, false, half>;
else
return nullptr;
}
}
ChunkCumsumKernelFuncBf16 getChunkCumsumKernelBf16(
int B_, int L_, int H_, int Q_, bool dtSoftPlus_, dim3* blockDims_, dim3* threadDims_, int* sharedMem_)
{
int B = B_;
int L = L_;
int H = H_;
// int P = P_;
// int G = G_;
// int N = N_;
int Q = Q_;
int C = div_up(L, Q);
int tileH = 1;
int warpH = 1;
auto sharedMem = 0;
*blockDims_ = dim3(H / tileH, C, B);
*threadDims_ = dim3(32, warpH);
*sharedMem_ = sharedMem;
if (dtSoftPlus_)
{
if (Q_ == 128)
return chunk_cumsum_kernel<128, 1, 1, true, bf16>;
else if (Q_ == 256)
return chunk_cumsum_kernel<256, 1, 1, true, bf16>;
else
return nullptr;
}
else
{
if (Q_ == 128)
return chunk_cumsum_kernel<128, 1, 1, false, bf16>;
else if (Q_ == 256)
return chunk_cumsum_kernel<256, 1, 1, false, bf16>;
else
return nullptr;
}
}
} // namespace kernels
} // namespace tensorrt_llm
// vim: ts=2 sw=2 sts=2 et sta
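Reading the kernel above: for every (chunk, head) it writes dc[q] = dt_raw[q] + db, passed through an overflow-safe softplus when dtSoftplus_ is set, and dA[q] = A times the running sum of dc over the chunk. A host-side reference of that per-chunk computation (names are illustrative, not part of the diff):

// CPU reference for what chunk_cumsum_kernel computes per (chunk, head), as read
// from the kernel above. dt_raw, db, A are inputs; dc and dA are outputs of length Q.
#include <cmath>
#include <cstddef>
#include <vector>

void chunkCumsumRef(std::vector<float> const& dt_raw, float db, float A, bool softplus,
                    std::vector<float>& dc, std::vector<float>& dA)
{
    float sum = 0.f;
    for (std::size_t q = 0; q < dt_raw.size(); ++q)
    {
        float dt = dt_raw[q] + db;
        if (softplus)
            dt = dt > 32.f ? dt : std::log1p(std::exp(dt)); // numerically safe softplus
        sum += dt;
        dc[q] = dt;      // per-token delta
        dA[q] = sum * A; // running cumulative sum scaled by A
    }
}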

View File

@ -0,0 +1,639 @@
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <cuda_fp8.h>
#include <mma.h>
#include "tensorrt_llm/common/cudaBf16Fallbacks.cuh"
#include "Common.h"
#include "Poly.h"
namespace tensorrt_llm
{
namespace kernels
{
typedef void (*ChunkScanKernelFuncFp16)(int B_, int L_, int H_, int P_, int G_, int N_,
half* g_mxY_, // B*L*H*P
half const* g_mxOs_, // B*C*H*N*P
// const half *g_mxFs_, // B *H*N*P
// const float *g_mxSt_, // B*C*H*N*P
float const* g_mxdc_, // B*C*H*Q
float const* g_mxdA_, // B*C*H*Q
// const half *g_mxdt_, // B*L*H
// const float *g_mxdb_, // H
// const float *g_mxA_, // H
half const* g_mxCB_, // B*C*G*Q*Q
half const* g_mxBC_, // B*L*2*G*N
float const* g_mxD_, // H
half const* g_mxX_, // B*L*H*P
half const* g_mxZ_, // B*L*H*P
bool removePadding_, int const* lastTokenIdsPtr_);
typedef void (*ChunkScanKernelFuncBf16)(int B_, int L_, int H_, int P_, int G_, int N_,
bf16* g_mxY_, // B*L*H*P
bf16 const* g_mxOs_, // B*C*H*N*P
// const bf16 *g_mxFs_, // B *H*N*P
// const float *g_mxSt_, // B*C*H*N*P
float const* g_mxdc_, // B*C*H*Q
float const* g_mxdA_, // B*C*H*Q
// const bf16 *g_mxdt_, // B*L*H
// const float *g_mxdb_, // H
// const float *g_mxA_, // H
bf16 const* g_mxCB_, // B*C*G*Q*Q
bf16 const* g_mxBC_, // B*L*2*G*N
float const* g_mxD_, // H
bf16 const* g_mxX_, // B*L*H*P
bf16 const* g_mxZ_, // B*L*H*P
bool removePadding_, int const* lastTokenIdsPtr_);
template <int Q_, int tileM_, int tileN_, int tileK_, // smem size, per sm
int wmmaM_, int wmmaN_, int wmmaK_, // wmma size, per instruction
int warpM_, int warpN_, // warp number
int pipeS_, class Tp_, class Wt_ = float>
__global__ std::enable_if_t<std::is_same_v<Tp_, half> || std::is_same_v<Tp_, __nv_bfloat16>> chunk_scan_kernel(int B_,
int L_, int H_, int P_, int G_, int N_,
Tp_* g_mxY_, // B*L*H*P
Tp_ const* g_mxOs_, // B*C*H*N*P
// const Tp_ *g_mxFs_, // B *H*N*P
// const float *g_mxSt_, // B*C*H*N*P
float const* g_mxdc_, // B*C*H*Q
float const* g_mxdA_, // B*C*H*Q
// const Tp_ *g_mxdt_, // B*L*H
// const Wt_ *g_mxdb_, // H
// const Wt_ *g_mxA_, // H
Tp_ const* g_mxCB_, // B*C*G*Q*Q
Tp_ const* g_mxBC_, // B*L*2*G*N
Wt_ const* g_mxD_, // H
Tp_ const* g_mxX_, // B*L*H*P
Tp_ const* g_mxZ_, // B*L*H*P
bool removePadding_, int const* lastTokenIdsPtr_)
{
#if __CUDA_ARCH__ >= 800
using namespace tensorrt_llm::common;
auto blockIdx_x = Rn<ID>{int(blockIdx.x)};
auto blockIdx_y = Rn<ID>{int(blockIdx.y)};
auto blockIdx_z = Rn<ID>{int(blockIdx.z)};
auto threadIdx_x = Rn<ID, 32>{int(threadIdx.x)};
auto threadIdx_y = Rn<ID, warpN_>{int(threadIdx.y)};
auto threadIdx_z = Rn<ID, warpM_>{int(threadIdx.z)};
// auto B = Rn<ID>{B_};
auto L = Rn<ID>{L_};
auto H = Rn<ID>{H_};
auto P = Rn<ID>{P_};
auto G = Rn<ID>{G_};
auto N = Rn<ID>{N_};
auto Q = cn<Q_>;
auto C = Rn<ID>{div_up(L.var, Q_)};
auto aStart = blockIdx_z * L;
auto cStart = blockIdx_z * C;
if (removePadding_)
{
aStart = Rn<ID>{int(blockIdx.z ? lastTokenIdsPtr_[blockIdx.z - 1] : 0)};
cStart = Rn<ID>{int(blockIdx.z ? div_up(aStart.var, Q_) + blockIdx.z - 1 : 0)};
L = Rn<ID>{lastTokenIdsPtr_[blockIdx.z] - aStart.var};
C = Rn<ID>{div_up(L.var, Q_)};
}
else
{
L = Rn<ID>{lastTokenIdsPtr_[blockIdx.z]};
C = Rn<ID>{div_up(L.var, Q_)};
}
if (blockIdx_y * Q >= L)
return;
auto hStart = Rn<ID>{blockIdx_x.var / (P_ / cn<tileN_>) / (Q / cn<tileM_>) };
auto mStart = Rn<ID>{blockIdx_x.var / (P_ / cn<tileN_>) % (Q / cn<tileM_>) };
auto nStart = Rn<ID>{blockIdx_x.var % (P_ / cn<tileN_>) };
auto gStart = Rn<ID>{hStart.var / (H_ / G_)};
extern __shared__ float smem[];
Tp_* s_mxC = (Tp_*) smem;
Tp_* s_mxOs = (Tp_*) smem + tileM_ * tileK_ * pipeS_;
Tp_* s_mxY = (Tp_*) smem;
float* s_mxdc = smem + (tileM_ + tileN_) * tileK_ * pipeS_ / 2;
float* s_mxdA = smem + (tileM_ + tileN_) * tileK_ * pipeS_ / 2 + Q_;
unsigned b_base = __nvvm_get_smem_pointer(smem);
unsigned b_mxC = b_base;
unsigned b_mxOs = b_base + tileM_ * tileK_ * pipeS_ * sizeof(Tp_);
unsigned b_mxY = b_base;
using std::array;
register array<array<array<float, wmmaM_ * wmmaN_ / 32>, tileN_ / wmmaN_ / warpN_>, tileM_ / wmmaM_ / warpM_> r_mxY
= array<array<array<float, wmmaM_ * wmmaN_ / 32>, tileN_ / wmmaN_ / warpN_>, tileM_ / wmmaM_ / warpM_>();
register array<array<unsigned, wmmaM_ * wmmaK_ / 64>, tileM_ / wmmaM_ / warpM_> r_mxC;
register array<array<unsigned, wmmaK_ * wmmaN_ / 64>, tileN_ / wmmaN_ / warpN_> r_mxOs;
constexpr int step = std::max(
1, tileM_ / wmmaM_ / warpM_ * tileN_ / wmmaN_ / warpN_ / (tileM_ / wmmaM_ / warpM_ + tileN_ / wmmaN_ / warpN_));
auto baseC = [](auto iK) { return iK % cn<pipeS_> * cn<tileM_> * cn<tileK_>; };
auto baseOs = [](auto iK) { return iK % cn<pipeS_> * cn<tileN_> * cn<tileK_>; };
auto thread = [=](auto iStep)
{
return iStep * cn<warpM_ * warpN_ * 256> + threadIdx_z * cn<warpN_ * 256> + threadIdx_y * cn<256>
+ threadIdx_x * cn<8>;
};
#pragma unroll
for (Rn<UNROLL, div_up(Q_, warpM_ * warpN_ * 256)> iStep; iStep.var < iStep.size; iStep.var++)
if (thread(iStep) < cn<Q_>)
{
#pragma unroll
for (int i = 0; i < 8; i += 4)
{
*(int4*) (s_mxdc + get(thread(iStep)) + i)
= *(int4*) (g_mxdc_ + get((cStart + blockIdx_y) * H * Q + hStart * Q + thread(iStep)) + i);
*(int4*) (s_mxdA + get(thread(iStep)) + i)
= *(int4*) (g_mxdA_ + get((cStart + blockIdx_y) * H * Q + hStart * Q + thread(iStep)) + i);
}
}
#pragma unroll
for (Rn<UNROLL, pipeS_> iK; iK.var < iK.size; iK.var++)
{
#pragma unroll
for (Rn<UNROLL, div_up(tileM_ * tileK_, warpM_ * warpN_ * 256)> iStep; iStep.var < iStep.size; iStep.var++)
if (thread(iStep) < cn<tileM_ * tileK_>
&& thread(iStep) / cn<tileK_> < L - blockIdx_y * Q - mStart * cn<tileM_>)
cp_shared_global<16>(b_mxC + swizzle<tileK_ * 2, tileK_ * 2>(thread(iStep) * cn<2>, baseC(iK) * cn<2>),
g_mxBC_
+ get(
(aStart + blockIdx_y * Q + mStart * cn<tileM_> + thread(iStep) / cn<tileK_>) *cn<2> * G * N
+ cn<1> * G * N + gStart * N + iK * cn<tileK_> + thread(iStep) % cn<tileK_>));
else if (thread(iStep) < cn<tileM_ * tileK_>)
*(int4*) ((char*) s_mxC + swizzle<tileK_ * 2, tileK_ * 2>(thread(iStep) * cn<2>, baseC(iK) * cn<2>))
= int4{0, 0, 0, 0};
#pragma unroll
for (Rn<UNROLL, div_up(tileN_ * tileK_, warpM_ * warpN_ * 256)> iStep; iStep.var < iStep.size; iStep.var++)
if (thread(iStep) < cn<tileN_ * tileK_>)
cp_shared_global<16>(
b_mxOs + swizzle<tileN_ * 2, tileN_ * 2>(thread(iStep) * cn<2>, baseOs(iK) * cn<2>),
g_mxOs_
+ get((cStart + blockIdx_y) * H * N * P + hStart * N * P
+ (iK * cn<tileK_> + thread(iStep) / cn<tileN_>) *P + nStart * cn<tileN_>
+ thread(iStep) % cn<tileN_>));
cp_commit_group();
}
asm volatile("cp.async.wait_group %0;\n" ::"n"(pipeS_ - 1));
__syncthreads();
for (int iK = pipeS_; iK < (N_ + Q_) / tileK_ + pipeS_; iK++)
{
auto jK = Rn<>{iK};
if ((iK - pipeS_) * cn<tileK_> == N_)
{
#pragma unroll
for (int y = 0; y < tileM_ / wmmaM_ / warpM_; y++)
#pragma unroll
for (int x = 0; x < tileN_ / wmmaN_ / warpN_; x++)
{
float2 tmp2 = float2{expf(s_mxdA[get(mStart * cn<tileM_> + Rn<UNROLL>{y} * cn<warpM_ * wmmaM_>
+ threadIdx_z * cn<wmmaM_> + threadIdx_x / cn<4>)]),
expf(s_mxdA[get(mStart * cn<tileM_> + Rn<UNROLL>{y} * cn<warpM_ * wmmaM_> + cn<8>
+ threadIdx_z * cn<wmmaM_> + threadIdx_x / cn<4>)])};
r_mxY[y][x][0] *= tmp2.x;
r_mxY[y][x][1] *= tmp2.x;
r_mxY[y][x][2] *= tmp2.y;
r_mxY[y][x][3] *= tmp2.y;
}
}
if ((iK - pipeS_) * cn<tileK_> >= N_)
{
#pragma unroll
for (Rn<UNROLL, div_up(tileM_ * tileK_, warpM_ * warpN_ * 256)> iStep; iStep.var < iStep.size; iStep.var++)
if (thread(iStep) < cn<tileM_ * tileK_>)
{
register Tp_ tmpCB[8];
*(int4*) &tmpCB[0] = *(int4*) ((char*) s_mxC
+ swizzle<tileK_ * 2, tileK_ * 2>(thread(iStep) * cn<2>, baseC(jK) * cn<2>));
#pragma unroll
for (int i = 0; i < 8; i += 2)
{
float2 tmp2 = std::is_same_v<Tp_, half> ? __half22float2(*(half2*) &tmpCB[i])
: bf1622float2(*(bf162*) &tmpCB[i]);
int kStart = (iK - pipeS_) * cn<tileK_> - N_;
tmp2.x *= expf(s_mxdA[get(mStart * cn<tileM_> + thread(iStep) / cn<tileK_>)]
- s_mxdA[kStart + get(thread(iStep) % cn<tileK_> + Rn<UNROLL>{i})])
* s_mxdc[kStart + get(thread(iStep) % cn<tileK_> + Rn<UNROLL>{i})];
tmp2.y *= expf(s_mxdA[get(mStart * cn<tileM_> + thread(iStep) / cn<tileK_>)]
- s_mxdA[kStart + get(thread(iStep) % cn<tileK_> + Rn<UNROLL>{i + 1})])
* s_mxdc[kStart + get(thread(iStep) % cn<tileK_> + Rn<UNROLL>{i + 1})];
if (get(mStart * cn<tileM_> + thread(iStep) / cn<tileK_>)
< kStart + get(thread(iStep) % cn<tileK_> + Rn<UNROLL>{i}))
tmp2.x = 0;
if (get(mStart * cn<tileM_> + thread(iStep) / cn<tileK_>)
< kStart + get(thread(iStep) % cn<tileK_> + Rn<UNROLL>{i + 1}))
tmp2.y = 0;
if (std::is_same_v<Tp_, half>)
*(half2*) &tmpCB[i] = __float22half2_rn(tmp2);
else
*(bf162*) &tmpCB[i] = __float22bfloat162_rn(tmp2);
}
*(int4*) ((char*) s_mxC + swizzle<tileK_ * 2, tileK_ * 2>(thread(iStep) * cn<2>, baseC(jK) * cn<2>))
= *(int4*) &tmpCB[0];
}
__syncthreads();
}
#pragma unroll
for (int k = 0; k < tileK_ / wmmaK_; k++)
{
#pragma unroll
for (int y = 0; y < tileM_ / wmmaM_ / warpM_; y++)
#pragma unroll
for (int x = 0; x < tileN_ / wmmaN_ / warpN_; x++)
{
if ((y * tileN_ / wmmaN_ / warpN_ + x) % step == 0)
{
int x1 = (y * tileN_ / wmmaN_ / warpN_ + x) / step;
int y1 = x1 - tileN_ / wmmaN_ / warpN_
+ (tileM_ / wmmaM_ / warpM_ == 1 || tileN_ / wmmaN_ / warpN_ == 1);
if (y1 >= 0 && y1 < tileM_ / wmmaM_ / warpM_)
{
if (wmmaK_ == 16)
asm volatile(
"ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];\n"
: "=r"(r_mxC[y1][0]), "=r"(r_mxC[y1][1]), "=r"(r_mxC[y1][2]), "=r"(r_mxC[y1][3])
: "r"(b_mxC + iK % pipeS_ * (tileM_ * tileK_ * 2)
+ 2
* swz<tileK_ * 2, tileK_>(y1 * warpM_ * wmmaM_ * tileK_ + k * wmmaK_
+ threadIdx.z * wmmaM_ * tileK_ + threadIdx.x % 16 * tileK_
+ threadIdx.x / 16 * 8)));
}
if (x1 >= 0 && x1 < tileN_ / wmmaN_ / warpN_)
{
if (wmmaK_ == 16 && x1 % 2 == 0)
asm volatile(
"ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {%0, %1, %2, %3}, [%4];\n"
: "=r"(r_mxOs[x1][0]), "=r"(r_mxOs[x1][1]), "=r"(r_mxOs[x1 + 1][0]),
"=r"(r_mxOs[x1 + 1][1])
: "r"(b_mxOs + iK % pipeS_ * (tileK_ * tileN_ * 2)
+ 2
* swz<tileN_ * 2, tileN_>(x1 * warpN_ * wmmaN_ + k * wmmaK_ * tileN_
+ threadIdx.y * wmmaN_ + threadIdx.x % wmmaK_ * tileN_
+ threadIdx.x / wmmaK_ * warpN_ * wmmaN_)));
}
}
}
#pragma unroll
for (int y = 0; y < tileM_ / wmmaM_ / warpM_; y++)
#pragma unroll
for (int x = 0; x < tileN_ / wmmaN_ / warpN_; x++)
{
if (wmmaK_ == 16)
{
if (std::is_same_v<Tp_, half>)
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 \n"
" {%0, %1, %2, %3}, \n"
" {%4, %5, %6, %7}, \n"
" {%8, %9}, \n"
" {%0, %1, %2, %3}; \n"
: "+f"(r_mxY[y][x][0]), "+f"(r_mxY[y][x][1]), "+f"(r_mxY[y][x][2]), "+f"(r_mxY[y][x][3])
: "r"(r_mxC[y][0]), "r"(r_mxC[y][1]), "r"(r_mxC[y][2]), "r"(r_mxC[y][3]),
"r"(r_mxOs[x][0]), "r"(r_mxOs[x][1]));
else
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 \n"
" {%0, %1, %2, %3}, \n"
" {%4, %5, %6, %7}, \n"
" {%8, %9}, \n"
" {%0, %1, %2, %3}; \n"
: "+f"(r_mxY[y][x][0]), "+f"(r_mxY[y][x][1]), "+f"(r_mxY[y][x][2]), "+f"(r_mxY[y][x][3])
: "r"(r_mxC[y][0]), "r"(r_mxC[y][1]), "r"(r_mxC[y][2]), "r"(r_mxC[y][3]),
"r"(r_mxOs[x][0]), "r"(r_mxOs[x][1]));
}
}
}
__syncthreads();
if (iK * cn<tileK_> < N_)
{
#pragma unroll
for (Rn<UNROLL, div_up(tileM_ * tileK_, warpM_ * warpN_ * 256)> iStep; iStep.var < iStep.size; iStep.var++)
if (thread(iStep) < cn<tileM_ * tileK_>
&& thread(iStep) / cn<tileK_> < L - blockIdx_y * Q - mStart * cn<tileM_>)
cp_shared_global<16>(
b_mxC + swizzle<tileK_ * 2, tileK_ * 2>(thread(iStep) * cn<2>, baseC(jK) * cn<2>),
g_mxBC_
+ get((aStart + blockIdx_y * Q + mStart * cn<tileM_> + thread(iStep) / cn<tileK_>) *cn<2>
* G * N
+ cn<1> * G * N + gStart * N + jK * cn<tileK_> + thread(iStep) % cn<tileK_>));
else if (thread(iStep) < cn<tileM_ * tileK_>)
*(int4*) ((char*) s_mxC + swizzle<tileK_ * 2, tileK_ * 2>(thread(iStep) * cn<2>, baseC(jK) * cn<2>))
= int4{0, 0, 0, 0};
#pragma unroll
for (Rn<UNROLL, div_up(tileN_ * tileK_, warpM_ * warpN_ * 256)> iStep; iStep.var < iStep.size; iStep.var++)
if (thread(iStep) < cn<tileN_ * tileK_>)
cp_shared_global<16>(
b_mxOs + swizzle<tileN_ * 2, tileN_ * 2>(thread(iStep) * cn<2>, baseOs(jK) * cn<2>),
g_mxOs_
+ get((cStart + blockIdx_y) * H * N * P + hStart * N * P
+ (jK * cn<tileK_> + thread(iStep) / cn<tileN_>) *P + nStart * cn<tileN_>
+ thread(iStep) % cn<tileN_>));
}
else if (iK * cn<tileK_> < N_ + Q_)
{
#pragma unroll
for (Rn<UNROLL, div_up(tileM_ * tileK_, warpM_ * warpN_ * 256)> iStep; iStep.var < iStep.size; iStep.var++)
if (thread(iStep) < cn<tileM_ * tileK_>)
cp_shared_global<16>(
b_mxC + swizzle<tileK_ * 2, tileK_ * 2>(thread(iStep) * cn<2>, baseC(jK) * cn<2>),
g_mxCB_
+ get((cStart + blockIdx_y) * G * Q * Q + gStart * Q * Q
+ (mStart * cn<tileM_> + thread(iStep) / cn<tileK_>) *Q + jK * cn<tileK_>
- N + thread(iStep) % cn<tileK_>));
#pragma unroll
for (Rn<UNROLL, div_up(tileN_ * tileK_, warpM_ * warpN_ * 256)> iStep; iStep.var < iStep.size; iStep.var++)
if (thread(iStep) < cn<tileN_ * tileK_>
&& thread(iStep) / cn<tileN_> < L - blockIdx_y * Q - jK * cn<tileK_> + N)
cp_shared_global<16>(
b_mxOs + swizzle<tileN_ * 2, tileN_ * 2>(thread(iStep) * cn<2>, baseOs(jK) * cn<2>),
g_mxX_
+ get((aStart + blockIdx_y * Q + jK * cn<tileK_> - N + thread(iStep) / cn<tileN_>) *H * P
+ hStart * P + nStart * cn<tileN_> + thread(iStep) % cn<tileN_>));
else if (thread(iStep) < cn<tileN_ * tileK_>)
*(int4*) ((char*) s_mxOs
+ swizzle<tileN_ * 2, tileN_ * 2>(thread(iStep) * cn<2>, baseOs(jK) * cn<2>))
= int4{0, 0, 0, 0};
}
asm volatile("cp.async.commit_group;\n" ::);
asm volatile("cp.async.wait_group %0;\n" ::"n"(pipeS_ - 1));
__syncthreads();
}
if (g_mxD_)
{
float r_D = g_mxD_[hStart.var];
#pragma unroll
for (int y = 0; y < tileM_ / wmmaM_ / warpM_; y++)
#pragma unroll
for (int x = 0; x < tileN_ / wmmaN_ / warpN_; x++)
{
Tp_ tmp16[4] = {0};
float tmp32[4] = {0};
if (blockIdx_y * Q + mStart * cn<tileM_> + Rn<UNROLL>{y} * cn<warpM_ * wmmaM_>
+ threadIdx_z * cn<wmmaM_> + threadIdx_x / cn<4>
< L)
{
*(int*) &tmp16[0] = *(int*) (g_mxX_
+ get((aStart + blockIdx_y * Q + mStart * cn<tileM_> + Rn<UNROLL>{y} * cn<warpM_ * wmmaM_>
+ threadIdx_z * cn<wmmaM_> + threadIdx_x / cn<4>) *H
* P
+ hStart * P + nStart * cn<tileN_> + Rn<UNROLL>{x} * cn<warpN_ * wmmaN_>
+ threadIdx_y * cn<wmmaN_> + threadIdx_x % cn<4> * cn<2>));
*(float2*) &tmp32[0] = std::is_same_v<Tp_, half> ? __half22float2(*(half2*) &tmp16[0])
: bf1622float2(*(bf162*) &tmp16[0]);
r_mxY[y][x][0] += r_D * tmp32[0];
r_mxY[y][x][1] += r_D * tmp32[1];
}
if (blockIdx_y * Q + mStart * cn<tileM_> + Rn<UNROLL>{y} * cn<warpM_ * wmmaM_> + cn<8>
+ threadIdx_z * cn<wmmaM_> + threadIdx_x / cn<4>
< L)
{
*(int*) &tmp16[2] = *(int*) (g_mxX_
+ get((aStart + blockIdx_y * Q + mStart * cn<tileM_> + Rn<UNROLL>{y} * cn<warpM_ * wmmaM_>
+ cn<8> + threadIdx_z * cn<wmmaM_> + threadIdx_x / cn<4>) *H
* P
+ hStart * P + nStart * cn<tileN_> + Rn<UNROLL>{x} * cn<warpN_ * wmmaN_>
+ threadIdx_y * cn<wmmaN_> + threadIdx_x % cn<4> * cn<2>));
*(float2*) &tmp32[2] = std::is_same_v<Tp_, half> ? __half22float2(*(half2*) &tmp16[2])
: bf1622float2(*(bf162*) &tmp16[2]);
r_mxY[y][x][2] += r_D * tmp32[2];
r_mxY[y][x][3] += r_D * tmp32[3];
}
}
}
if (g_mxZ_)
{
#pragma unroll
for (int y = 0; y < tileM_ / wmmaM_ / warpM_; y++)
#pragma unroll
for (int x = 0; x < tileN_ / wmmaN_ / warpN_; x++)
{
Tp_ tmp16[4] = {0};
float tmp32[4] = {0};
if (blockIdx_y * Q + mStart * cn<tileM_> + Rn<UNROLL>{y} * cn<warpM_ * wmmaM_>
+ threadIdx_z * cn<wmmaM_> + threadIdx_x / cn<4>
< L)
{
*(int*) &tmp16[0] = *(int*) (g_mxZ_
+ get((aStart + blockIdx_y * Q + mStart * cn<tileM_> + Rn<UNROLL>{y} * cn<warpM_ * wmmaM_>
+ threadIdx_z * cn<wmmaM_> + threadIdx_x / cn<4>) *H
* P
+ hStart * P + nStart * cn<tileN_> + Rn<UNROLL>{x} * cn<warpN_ * wmmaN_>
+ threadIdx_y * cn<wmmaN_> + threadIdx_x % cn<4> * cn<2>));
*(float2*) &tmp32[0] = std::is_same_v<Tp_, half> ? __half22float2(*(half2*) &tmp16[0])
: bf1622float2(*(bf162*) &tmp16[0]);
r_mxY[y][x][0] *= tmp32[0] > 32.f ? tmp32[0] : tmp32[0] / (1.f + expf(-tmp32[0]));
r_mxY[y][x][1] *= tmp32[1] > 32.f ? tmp32[1] : tmp32[1] / (1.f + expf(-tmp32[1]));
}
if (blockIdx_y * Q + mStart * cn<tileM_> + Rn<UNROLL>{y} * cn<warpM_ * wmmaM_> + cn<8>
+ threadIdx_z * cn<wmmaM_> + threadIdx_x / cn<4>
< L)
{
*(int*) &tmp16[2] = *(int*) (g_mxZ_
+ get((aStart + blockIdx_y * Q + mStart * cn<tileM_> + Rn<UNROLL>{y} * cn<warpM_ * wmmaM_>
+ cn<8> + threadIdx_z * cn<wmmaM_> + threadIdx_x / cn<4>) *H
* P
+ hStart * P + nStart * cn<tileN_> + Rn<UNROLL>{x} * cn<warpN_ * wmmaN_>
+ threadIdx_y * cn<wmmaN_> + threadIdx_x % cn<4> * cn<2>));
*(float2*) &tmp32[2] = std::is_same_v<Tp_, half> ? __half22float2(*(half2*) &tmp16[2])
: bf1622float2(*(bf162*) &tmp16[2]);
r_mxY[y][x][2] *= tmp32[2] > 32.f ? tmp32[2] : tmp32[2] / (1.f + expf(-tmp32[2]));
r_mxY[y][x][3] *= tmp32[3] > 32.f ? tmp32[3] : tmp32[3] / (1.f + expf(-tmp32[3]));
}
}
}
#pragma unroll
for (int y = 0; y < tileM_ / wmmaM_ / warpM_; y++)
#pragma unroll
for (int x = 0; x < tileN_ / wmmaN_ / warpN_; x++)
{
if (std::is_same_v<Tp_, half>)
{
*(half2*) &r_mxY[y][x][0] = __floats2half2_rn(r_mxY[y][x][0], r_mxY[y][x][1]);
*(half2*) &r_mxY[y][x][2] = __floats2half2_rn(r_mxY[y][x][2], r_mxY[y][x][3]);
}
else
{
*(bf162*) &r_mxY[y][x][0] = __floats2bfloat162_rn(r_mxY[y][x][0], r_mxY[y][x][1]);
*(bf162*) &r_mxY[y][x][2] = __floats2bfloat162_rn(r_mxY[y][x][2], r_mxY[y][x][3]);
}
}
#pragma unroll
for (int y = 0; y < tileM_ / wmmaM_ / warpM_; y++)
#pragma unroll
for (int x = 0; x < tileN_ / wmmaN_ / warpN_; x++)
{
asm volatile("st.shared.b32 [%0], %1;\n" ::"r"(b_mxY
+ 2
* swz<tileN_ * 2, tileN_>(y * warpM_ * wmmaM_ * tileN_ + x * warpN_ * wmmaN_
+ (threadIdx.z * wmmaM_ + threadIdx.x / 4) * tileN_
+ (threadIdx.y * wmmaN_ + threadIdx.x % 4 * 2))),
"r"(*(unsigned*) &r_mxY[y][x][0]));
asm volatile("st.shared.b32 [%0], %1;\n" ::"r"(b_mxY
+ 2
* swz<tileN_ * 2, tileN_>(y * warpM_ * wmmaM_ * tileN_ + 8 * tileN_
+ x * warpN_ * wmmaN_ + (threadIdx.z * wmmaM_ + threadIdx.x / 4) * tileN_
+ (threadIdx.y * wmmaN_ + threadIdx.x % 4 * 2))),
"r"(*(unsigned*) &r_mxY[y][x][2]));
}
__syncthreads();
#pragma unroll
for (Rn<UNROLL, div_up(tileM_ * tileN_, warpM_ * warpN_ * 256)> iStep; iStep.var < iStep.size; iStep.var++)
if (thread(iStep) < cn<tileM_ * tileN_>
&& thread(iStep) / cn<tileN_> < L - blockIdx_y * Q - mStart * cn<tileM_>)
*(int4*) (g_mxY_
+ get((aStart + blockIdx_y * Q + mStart * cn<tileM_> + thread(iStep) / cn<tileN_>) *H * P + hStart * P
+ nStart * cn<tileN_> + thread(iStep) % cn<tileN_>))
= *(int4*) ((char*) s_mxY + swizzle<tileN_ * 2, tileN_ * 2>(thread(iStep) * cn<2>));
asm volatile("cp.async.wait_group %0;\n" ::"n"(0));
#endif
}
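// For reference, what the kernel above computes, reconstructed from its loads and
// stores (illustrative notation, not part of the commit). For token t in chunk c,
// head h (group g = h / (H/G)) and channel p, with dA/dc taken per chunk and head:
//
//   Y[t][h][p] = silu(Z[t][h][p]) *
//       ( expf(dA[t]) * sum_n C[t][g][n] * Os[c][h][n][p]                  // carry-in state term
//       + sum_{s <= t, same chunk} CB[c][g][t][s]
//             * expf(dA[t] - dA[s]) * dc[s] * X[s][h][p]                   // causal intra-chunk term
//       + D[h] * X[t][h][p] )                                              // skip connection
//
// The silu gate is applied only when g_mxZ_ is non-null and the D term only when
// g_mxD_ is non-null, matching the branches in the kernel.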
ChunkScanKernelFuncFp16 getChunkScanKernelFp16(
int B_, int L_, int H_, int P_, int G_, int N_, int Q_, dim3* blockDims_, dim3* threadDims_, int* sharedMem_)
{
int B = B_;
int L = L_;
int H = H_;
int P = P_;
// int G = G_;
// int N = N_;
int Q = Q_;
int C = div_up(L, Q);
int tileM = 128;
int tileN = 64;
int tileK = 32;
int warpM = 4;
int warpN = 1;
int pipeS = 2;
auto sharedMem = std::max((tileM * tileK + tileK * tileN) * pipeS * 2 + Q * 8, (tileM * tileN) * 2);
*blockDims_ = dim3(H * P / tileN * Q / tileM, C, B);
*threadDims_ = dim3(32, warpN, warpM);
*sharedMem_ = sharedMem;
if (Q_ == 128)
return chunk_scan_kernel<128, 128, 64, 32, 16, 8, 16, 4, 1, 2, half>;
else if (Q_ == 256)
return chunk_scan_kernel<256, 128, 64, 32, 16, 8, 16, 4, 1, 2, half>;
else
return nullptr;
}
ChunkScanKernelFuncBf16 getChunkScanKernelBf16(
int B_, int L_, int H_, int P_, int G_, int N_, int Q_, dim3* blockDims_, dim3* threadDims_, int* sharedMem_)
{
int B = B_;
int L = L_;
int H = H_;
int P = P_;
// int G = G_;
// int N = N_;
int Q = Q_;
int C = div_up(L, Q);
int tileM = 128;
int tileN = 64;
int tileK = 32;
int warpM = 4;
int warpN = 1;
int pipeS = 2;
auto sharedMem = std::max((tileM * tileK + tileK * tileN) * pipeS * 2 + Q * 8, (tileM * tileN) * 2);
*blockDims_ = dim3(H * P / tileN * Q / tileM, C, B);
*threadDims_ = dim3(32, warpN, warpM);
*sharedMem_ = sharedMem;
if (Q_ == 128)
return chunk_scan_kernel<128, 128, 64, 32, 16, 8, 16, 4, 1, 2, bf16>;
else if (Q_ == 256)
return chunk_scan_kernel<256, 128, 64, 32, 16, 8, 16, 4, 1, 2, bf16>;
else
return nullptr;
}
} // namespace kernels
} // namespace tensorrt_llm
// vim: ts=2 sw=2 sts=2 et sta
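The get*Kernel helpers above only select a template instantiation and report the launch geometry; the actual launch happens in the selective-scan host code later in this commit. A minimal sketch of that call pattern, with placeholder names (this hypothetical wrapper is not part of the commit; the device buffers are assumed to be allocated with the layouts noted in the kernel signature comments):

    // Hypothetical helper mirroring how the host code uses the launcher.
    void launchChunkScanFp16(int B, int L, int H, int P, int G, int N, int Q, half* y, half const* os,
        float const* dc, float const* dA, half const* cb, half const* bc, float const* d, half const* x,
        half const* z, bool removePadding, int const* lastTokenIds, cudaStream_t stream)
    {
        dim3 blockDims, threadDims;
        int sharedMem = 0;
        ChunkScanKernelFuncFp16 kernel
            = getChunkScanKernelFp16(B, L, H, P, G, N, Q, &blockDims, &threadDims, &sharedMem);
        if (kernel == nullptr)  // only chunk sizes 128 and 256 are instantiated
            return;
        // Allow the kernel to use the dynamic shared-memory size the launcher computed.
        cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, sharedMem);
        kernel<<<blockDims, threadDims, sharedMem, stream>>>(
            B, L, H, P, G, N, y, os, dc, dA, cb, bc, d, x, z, removePadding, lastTokenIds);
    }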

View File

@ -0,0 +1,453 @@
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <cuda_fp8.h>
#include <mma.h>
#include "tensorrt_llm/common/cudaBf16Fallbacks.cuh"
#include "Common.h"
#include "Poly.h"
namespace tensorrt_llm
{
namespace kernels
{
typedef void (*ChunkStateKernelFuncFp16)(int B_, int L_, int H_, int P_, int G_, int N_,
// const half *g_mxY_, // B*L*H*P
// const half *g_mxOs_, // B*C*H*N*P
// const half *g_mxFs_, // B *H*N*P
float* g_mxSt_, // B*C*H*N*P
float const* g_mxdc_, // B*C*H*Q
float const* g_mxdA_, // B*C*H*Q
// const half *g_mxdt_, // B*L*H
// const float *g_mxdb_, // H
// const float *g_mxA_, // H
// const half *g_mxCB_, // B*C*G*Q*Q
half const* g_mxBC_, // B*L*2*G*N
// const float *g_mxD_, // H
half const* g_mxX_, // B*L*H*P
// const half *g_mxZ_, // B*L*H*P
bool removePadding_, int const* lastTokenIdsPtr_);
typedef void (*ChunkStateKernelFuncBf16)(int B_, int L_, int H_, int P_, int G_, int N_,
// const bf16 *g_mxY_, // B*L*H*P
// const bf16 *g_mxOs_, // B*C*H*N*P
// const bf16 *g_mxFs_, // B *H*N*P
float* g_mxSt_, // B*C*H*N*P
float const* g_mxdc_, // B*C*H*Q
float const* g_mxdA_, // B*C*H*Q
// const bf16 *g_mxdt_, // B*L*H
// const float *g_mxdb_, // H
// const float *g_mxA_, // H
// const bf16 *g_mxCB_, // B*C*G*Q*Q
bf16 const* g_mxBC_, // B*L*2*G*N
// const float *g_mxD_, // H
bf16 const* g_mxX_, // B*L*H*P
// const bf16 *g_mxZ_, // B*L*H*P
bool removePadding_, int const* lastTokenIdsPtr_);
template <int Q_, int tileM_, int tileN_, int tileK_, // smem size, per sm
int wmmaM_, int wmmaN_, int wmmaK_, // wmma size, per instruction
int warpM_, int warpN_, // warp number
int pipeS_, class Tp_>
__global__ std::enable_if_t<std::is_same_v<Tp_, half> || std::is_same_v<Tp_, __nv_bfloat16>> chunk_state_kernel(int B_,
int L_, int H_, int P_, int G_, int N_,
// const Tp_ *g_mxY_, // B*L*H*P
// const Tp_ *g_mxOs_, // B*C*H*N*P
// const Tp_ *g_mxFs_, // B *H*N*P
float* g_mxSt_, // B*C*H*N*P
float const* g_mxdc_, // B*C*H*Q
float const* g_mxdA_, // B*C*H*Q
// const Tp_ *g_mxdt_, // B*L*H
// const Wt_ *g_mxdb_, // H
// const Wt_ *g_mxA_, // H
// const Tp_ *g_mxCB_, // B*C*G*Q*Q
Tp_ const* g_mxBC_, // B*L*2*G*N
// const Wt_ *g_mxD_, // H
Tp_ const* g_mxX_, // B*L*H*P
// const Tp_ *g_mxZ_, // B*L*H*P
bool removePadding_, int const* lastTokenIdsPtr_)
{
#if __CUDA_ARCH__ >= 800
using namespace tensorrt_llm::common;
auto blockIdx_x = Rn<ID>{int(blockIdx.x)};
auto blockIdx_y = Rn<ID>{int(blockIdx.y)};
auto blockIdx_z = Rn<ID>{int(blockIdx.z)};
auto threadIdx_x = Rn<ID, 32>{int(threadIdx.x)};
auto threadIdx_y = Rn<ID, warpN_>{int(threadIdx.y)};
auto threadIdx_z = Rn<ID, warpM_>{int(threadIdx.z)};
// auto B = Rn<ID>{B_};
auto L = Rn<ID>{L_};
auto H = Rn<ID>{H_};
auto P = Rn<ID>{P_};
auto G = Rn<ID>{G_};
auto N = Rn<ID>{N_};
auto Q = cn<Q_>;
auto C = Rn<ID>{div_up(L.var, Q_)};
auto aStart = blockIdx_z * L;
auto cStart = blockIdx_z * C;
if (removePadding_)
{
aStart = Rn<ID>{int(blockIdx.z ? lastTokenIdsPtr_[blockIdx.z - 1] : 0)};
cStart = Rn<ID>{int(blockIdx.z ? div_up(aStart.var, Q_) + blockIdx.z - 1 : 0)};
L = Rn<ID>{lastTokenIdsPtr_[blockIdx.z] - aStart.var};
C = Rn<ID>{div_up(L.var, Q_)};
}
else
{
L = Rn<ID>{lastTokenIdsPtr_[blockIdx.z]};
C = Rn<ID>{div_up(L.var, Q_)};
}
if (blockIdx_y * Q >= L)
return;
auto hStart = Rn<ID>{blockIdx_x.var / (P_ / cn<tileN_>) / (N_ / cn<tileM_>) };
auto mStart = Rn<ID>{blockIdx_x.var / (P_ / cn<tileN_>) % (N_ / cn<tileM_>) };
auto nStart = Rn<ID>{blockIdx_x.var % (P_ / cn<tileN_>) };
auto gStart = Rn<ID>{hStart.var / (H_ / G_)};
extern __shared__ float smem[];
Tp_* s_mxB = (Tp_*) smem;
Tp_* s_mxX = (Tp_*) smem + tileM_ * tileK_ * pipeS_;
float* s_mxdc = smem + (tileM_ + tileN_) * tileK_ * pipeS_ / 2;
float* s_mxdA = smem + (tileM_ + tileN_) * tileK_ * pipeS_ / 2 + Q_;
unsigned b_base = __nvvm_get_smem_pointer(smem);
unsigned b_mxB = b_base;
unsigned b_mxX = b_base + tileM_ * tileK_ * pipeS_ * sizeof(Tp_);
using std::array;
register array<array<array<float, wmmaM_ * wmmaN_ / 32>, tileN_ / wmmaN_ / warpN_>, tileM_ / wmmaM_ / warpM_> r_mxSt
= array<array<array<float, wmmaM_ * wmmaN_ / 32>, tileN_ / wmmaN_ / warpN_>, tileM_ / wmmaM_ / warpM_>();
register array<array<unsigned, wmmaM_ * wmmaK_ / 64>, tileM_ / wmmaM_ / warpM_> r_mxB;
register array<array<unsigned, wmmaK_ * wmmaN_ / 64>, tileN_ / wmmaN_ / warpN_> r_mxX;
constexpr int step = std::max(
1, tileM_ / wmmaM_ / warpM_ * tileN_ / wmmaN_ / warpN_ / (tileM_ / wmmaM_ / warpM_ + tileN_ / wmmaN_ / warpN_));
auto baseB = [](auto iK) { return iK % cn<pipeS_> * cn<tileM_> * cn<tileK_>; };
auto baseX = [](auto iK) { return iK % cn<pipeS_> * cn<tileN_> * cn<tileK_>; };
auto thread = [=](auto iStep)
{
return iStep * cn<warpM_ * warpN_ * 256> + threadIdx_z * cn<warpN_ * 256> + threadIdx_y * cn<256>
+ threadIdx_x * cn<8>;
};
#pragma unroll
for (Rn<UNROLL, div_up(Q_, warpM_ * warpN_ * 256)> iStep; iStep.var < iStep.size; iStep.var++)
if (thread(iStep) < cn<Q_>)
{
#pragma unroll
for (int i = 0; i < 8; i += 4)
{
*(int4*) (s_mxdc + get(thread(iStep)) + i)
= *(int4*) (g_mxdc_ + get((cStart + blockIdx_y) * H * Q + hStart * Q + thread(iStep)) + i);
*(int4*) (s_mxdA + get(thread(iStep)) + i)
= *(int4*) (g_mxdA_ + get((cStart + blockIdx_y) * H * Q + hStart * Q + thread(iStep)) + i);
}
}
#pragma unroll
for (Rn<UNROLL, pipeS_> iK; iK.var < iK.size; iK.var++)
{
#pragma unroll
for (Rn<UNROLL, div_up(tileM_ * tileK_, warpM_ * warpN_ * 256)> iStep; iStep.var < iStep.size; iStep.var++)
if (thread(iStep) < cn<tileM_ * tileK_>
&& thread(iStep) / cn<tileM_> < L - blockIdx_y * Q - iK * cn<tileK_>)
cp_shared_global<16>(b_mxB + swizzle<tileM_ * 2, tileM_ * 2>(thread(iStep) * cn<2>, baseB(iK) * cn<2>),
g_mxBC_
+ get((aStart + blockIdx_y * Q + iK * cn<tileK_> + thread(iStep) / cn<tileM_>) *cn<2> * G * N
+ gStart * N + mStart * cn<tileM_> + thread(iStep) % cn<tileM_>));
else if (thread(iStep) < cn<tileM_ * tileK_>)
*(int4*) ((char*) s_mxB + swizzle<tileM_ * 2, tileM_ * 2>(thread(iStep) * cn<2>, baseB(iK) * cn<2>))
= int4{0, 0, 0, 0};
#pragma unroll
for (Rn<UNROLL, div_up(tileN_ * tileK_, warpM_ * warpN_ * 256)> iStep; iStep.var < iStep.size; iStep.var++)
if (thread(iStep) < cn<tileN_ * tileK_>
&& thread(iStep) / cn<tileN_> < L - blockIdx_y * Q - iK * cn<tileK_>)
cp_shared_global<16>(b_mxX + swizzle<tileN_ * 2, tileN_ * 2>(thread(iStep) * cn<2>, baseX(iK) * cn<2>),
g_mxX_
+ get((aStart + blockIdx_y * Q + iK * cn<tileK_> + thread(iStep) / cn<tileN_>) *H * P
+ hStart * P + nStart * cn<tileN_> + thread(iStep) % cn<tileN_>));
else if (thread(iStep) < cn<tileN_ * tileK_>)
*(int4*) ((char*) s_mxX + swizzle<tileN_ * 2, tileN_ * 2>(thread(iStep) * cn<2>, baseX(iK) * cn<2>))
= int4{0, 0, 0, 0};
cp_commit_group();
}
asm volatile("cp.async.wait_group %0;\n" ::"n"(pipeS_ - 1));
__syncthreads();
for (int iK = pipeS_; iK < Q_ / tileK_ + pipeS_; iK++)
{
auto jK = Rn<>{iK};
#pragma unroll
for (Rn<UNROLL, div_up(tileM_ * tileK_, warpM_ * warpN_ * 256)> iStep; iStep.var < iStep.size; iStep.var++)
if (thread(iStep) < cn<tileM_ * tileK_>)
{
register Tp_ tmpB[8];
*(int4*) &tmpB[0] = *(
int4*) ((char*) s_mxB + swizzle<tileM_ * 2, tileM_ * 2>(thread(iStep) * cn<2>, baseB(jK) * cn<2>));
#pragma unroll
for (int i = 0; i < 8; i += 2)
{
float2 tmp2 = std::is_same_v<Tp_, half> ? __half22float2(*(half2*) &tmpB[i])
: bf1622float2(*(bf162*) &tmpB[i]);
int kStart = (iK - pipeS_) * cn<tileK_>;
tmp2.x *= expf(s_mxdA[Q_ - 1] - s_mxdA[kStart + get(thread(iStep) / cn<tileM_>)])
* s_mxdc[kStart + get(thread(iStep) / cn<tileM_>)];
tmp2.y *= expf(s_mxdA[Q_ - 1] - s_mxdA[kStart + get(thread(iStep) / cn<tileM_>)])
* s_mxdc[kStart + get(thread(iStep) / cn<tileM_>)];
if (std::is_same_v<Tp_, half>)
*(half2*) &tmpB[i] = __float22half2_rn(tmp2);
else
*(bf162*) &tmpB[i] = __float22bfloat162_rn(tmp2);
}
*(int4*) ((char*) s_mxB + swizzle<tileM_ * 2, tileM_ * 2>(thread(iStep) * cn<2>, baseB(jK) * cn<2>))
= *(int4*) &tmpB[0];
}
__syncthreads();
#pragma unroll
for (int k = 0; k < tileK_ / wmmaK_; k++)
{
#pragma unroll
for (int y = 0; y < tileM_ / wmmaM_ / warpM_; y++)
#pragma unroll
for (int x = 0; x < tileN_ / wmmaN_ / warpN_; x++)
{
if ((y * tileN_ / wmmaN_ / warpN_ + x) % step == 0)
{
int x1 = (y * tileN_ / wmmaN_ / warpN_ + x) / step;
int y1 = x1 - tileN_ / wmmaN_ / warpN_
+ (tileM_ / wmmaM_ / warpM_ == 1 || tileN_ / wmmaN_ / warpN_ == 1);
if (y1 >= 0 && y1 < tileM_ / wmmaM_ / warpM_)
{
if (wmmaK_ == 16)
asm volatile(
"ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {%0, %1, %2, %3}, [%4];\n"
: "=r"(r_mxB[y1][0]), "=r"(r_mxB[y1][1]), "=r"(r_mxB[y1][2]), "=r"(r_mxB[y1][3])
: "r"(b_mxB + iK % pipeS_ * (tileM_ * tileK_ * 2)
+ 2
* swz<tileM_ * 2, tileM_>(y1 * warpM_ * wmmaM_ + k * wmmaK_ * tileM_
+ threadIdx.z * wmmaM_ + threadIdx.x % 8 * tileM_
+ threadIdx.x / 8 % 2 * 8 + threadIdx.x / wmmaK_ * 8 * tileM_)));
}
if (x1 >= 0 && x1 < tileN_ / wmmaN_ / warpN_)
{
if (wmmaK_ == 16 && x1 % 2 == 0)
asm volatile(
"ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {%0, %1, %2, %3}, [%4];\n"
: "=r"(r_mxX[x1][0]), "=r"(r_mxX[x1][1]), "=r"(r_mxX[x1 + 1][0]),
"=r"(r_mxX[x1 + 1][1])
: "r"(b_mxX + iK % pipeS_ * (tileK_ * tileN_ * 2)
+ 2
* swz<tileN_ * 2, tileN_>(x1 * warpN_ * wmmaN_ + k * wmmaK_ * tileN_
+ threadIdx.y * wmmaN_ + threadIdx.x % wmmaK_ * tileN_
+ threadIdx.x / wmmaK_ * warpN_ * wmmaN_)));
}
}
}
#pragma unroll
for (int y = 0; y < tileM_ / wmmaM_ / warpM_; y++)
#pragma unroll
for (int x = 0; x < tileN_ / wmmaN_ / warpN_; x++)
{
if (wmmaK_ == 16)
{
if (std::is_same_v<Tp_, half>)
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 \n"
" {%0, %1, %2, %3}, \n"
" {%4, %5, %6, %7}, \n"
" {%8, %9}, \n"
" {%0, %1, %2, %3}; \n"
: "+f"(r_mxSt[y][x][0]), "+f"(r_mxSt[y][x][1]), "+f"(r_mxSt[y][x][2]),
"+f"(r_mxSt[y][x][3])
: "r"(r_mxB[y][0]), "r"(r_mxB[y][1]), "r"(r_mxB[y][2]), "r"(r_mxB[y][3]),
"r"(r_mxX[x][0]), "r"(r_mxX[x][1]));
else
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 \n"
" {%0, %1, %2, %3}, \n"
" {%4, %5, %6, %7}, \n"
" {%8, %9}, \n"
" {%0, %1, %2, %3}; \n"
: "+f"(r_mxSt[y][x][0]), "+f"(r_mxSt[y][x][1]), "+f"(r_mxSt[y][x][2]),
"+f"(r_mxSt[y][x][3])
: "r"(r_mxB[y][0]), "r"(r_mxB[y][1]), "r"(r_mxB[y][2]), "r"(r_mxB[y][3]),
"r"(r_mxX[x][0]), "r"(r_mxX[x][1]));
}
}
}
__syncthreads();
#pragma unroll
for (Rn<UNROLL, div_up(tileM_ * tileK_, warpM_ * warpN_ * 256)> iStep; iStep.var < iStep.size; iStep.var++)
if (thread(iStep) < cn<tileM_ * tileK_> && thread(iStep) / cn<tileM_> < L - blockIdx_y * Q - jK * cn<tileK_>
&& jK * cn<tileK_> < Q)
cp_shared_global<16>(b_mxB + swizzle<tileM_ * 2, tileM_ * 2>(thread(iStep) * cn<2>, baseB(jK) * cn<2>),
g_mxBC_
+ get((aStart + blockIdx_y * Q + jK * cn<tileK_> + thread(iStep) / cn<tileM_>) *cn<2> * G * N
+ gStart * N + mStart * cn<tileM_> + thread(iStep) % cn<tileM_>));
else if (thread(iStep) < cn<tileM_ * tileK_> && jK * cn<tileK_> < Q)
*(int4*) ((char*) s_mxB + swizzle<tileM_ * 2, tileM_ * 2>(thread(iStep) * cn<2>, baseB(jK) * cn<2>))
= int4{0, 0, 0, 0};
#pragma unroll
for (Rn<UNROLL, div_up(tileN_ * tileK_, warpM_ * warpN_ * 256)> iStep; iStep.var < iStep.size; iStep.var++)
if (thread(iStep) < cn<tileN_ * tileK_> && thread(iStep) / cn<tileN_> < L - blockIdx_y * Q - jK * cn<tileK_>
&& jK * cn<tileK_> < Q)
cp_shared_global<16>(b_mxX + swizzle<tileN_ * 2, tileN_ * 2>(thread(iStep) * cn<2>, baseX(jK) * cn<2>),
g_mxX_
+ get((aStart + blockIdx_y * Q + jK * cn<tileK_> + thread(iStep) / cn<tileN_>) *H * P
+ hStart * P + nStart * cn<tileN_> + thread(iStep) % cn<tileN_>));
else if (thread(iStep) < cn<tileN_ * tileK_> && jK * cn<tileK_> < Q)
*(int4*) ((char*) s_mxX + swizzle<tileN_ * 2, tileN_ * 2>(thread(iStep) * cn<2>, baseX(jK) * cn<2>))
= int4{0, 0, 0, 0};
asm volatile("cp.async.commit_group;\n" ::);
asm volatile("cp.async.wait_group %0;\n" ::"n"(pipeS_ - 1));
__syncthreads();
}
#pragma unroll
for (int y = 0; y < tileM_ / wmmaM_ / warpM_; y++)
#pragma unroll
for (int x = 0; x < tileN_ / wmmaN_ / warpN_; x++)
{
*(float2*) (g_mxSt_
+ get((cStart + blockIdx_y) * H * N * P + hStart * N * P
+ (mStart * cn<tileM_> + Rn<UNROLL>{y} * cn<warpM_ * wmmaM_> + threadIdx_z * cn<wmmaM_>
+ threadIdx_x / cn<4>) *P
+ nStart * cn<tileN_> + Rn<UNROLL>{x} * cn<warpN_ * wmmaN_> + threadIdx_y * cn<wmmaN_>
+ threadIdx_x % cn<4> * cn<2>))
= *(float2*) &r_mxSt[y][x][0];
*(float2*) (g_mxSt_
+ get((cStart + blockIdx_y) * H * N * P + hStart * N * P
+ (mStart * cn<tileM_> + Rn<UNROLL>{y} * cn<warpM_ * wmmaM_> + cn<8> + threadIdx_z * cn<wmmaM_>
+ threadIdx_x / cn<4>) *P
+ nStart * cn<tileN_> + Rn<UNROLL>{x} * cn<warpN_ * wmmaN_> + threadIdx_y * cn<wmmaN_>
+ threadIdx_x % cn<4> * cn<2>))
= *(float2*) &r_mxSt[y][x][2];
}
asm volatile("cp.async.wait_group %0;\n" ::"n"(0));
#endif
}
ChunkStateKernelFuncFp16 getChunkStateKernelFp16(
int B_, int L_, int H_, int P_, int G_, int N_, int Q_, dim3* blockDims_, dim3* threadDims_, int* sharedMem_)
{
int B = B_;
int L = L_;
int H = H_;
int P = P_;
// int G = G_;
int N = N_;
int Q = Q_;
int C = div_up(L, Q);
int tileM = 64;
int tileN = 64;
int tileK = 32;
int warpM = 1;
int warpN = 2;
int pipeS = 3;
auto sharedMem = (tileM * tileK + tileK * tileN) * pipeS * 2 + Q * 8;
*blockDims_ = dim3(H * P / tileN * N / tileM, C, B);
*threadDims_ = dim3(32, warpN, warpM);
*sharedMem_ = sharedMem;
if (Q_ == 128)
return chunk_state_kernel<128, 64, 64, 32, 16, 8, 16, 1, 2, 3, half>;
else if (Q_ == 256)
return chunk_state_kernel<256, 64, 64, 32, 16, 8, 16, 1, 2, 3, half>;
else
return nullptr;
}
ChunkStateKernelFuncBf16 getChunkStateKernelBf16(
int B_, int L_, int H_, int P_, int G_, int N_, int Q_, dim3* blockDims_, dim3* threadDims_, int* sharedMem_)
{
int B = B_;
int L = L_;
int H = H_;
int P = P_;
// int G = G_;
int N = N_;
int Q = Q_;
int C = div_up(L, Q);
int tileM = 64;
int tileN = 64;
int tileK = 32;
int warpM = 1;
int warpN = 2;
int pipeS = 3;
auto sharedMem = (tileM * tileK + tileK * tileN) * pipeS * 2 + Q * 8;
*blockDims_ = dim3(H * P / tileN * N / tileM, C, B);
*threadDims_ = dim3(32, warpN, warpM);
*sharedMem_ = sharedMem;
if (Q_ == 128)
return chunk_state_kernel<128, 64, 64, 32, 16, 8, 16, 1, 2, 3, bf16>;
else if (Q_ == 256)
return chunk_state_kernel<256, 64, 64, 32, 16, 8, 16, 1, 2, 3, bf16>;
else
return nullptr;
}
} // namespace kernels
} // namespace tensorrt_llm
// vim: ts=2 sw=2 sts=2 et sta
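For reference, a scalar sketch (not part of the commit; simplified float types, placeholder names) of the value that chunk_state_kernel accumulates for one (chunk, head, state index n, channel p) entry, following the pointer-layout comments above: each token's B*x contribution is decayed to the end of the chunk and weighted by the per-token dc coefficient from the cumsum stage.

    #include <cmath>

    // St[c][h][n][p] for one chunk. B_gn/X_hp walk the tokens of that chunk with the
    // given strides; dc/dA are the per-token coefficients produced by chunk_cumsum
    // for this chunk and head; Q is the chunk size.
    float chunkStateRef(float const* B_gn, int strideB, float const* X_hp, int strideX,
        float const* dc, float const* dA, int tokensInChunk, int Q)
    {
        float acc = 0.f;
        for (int t = 0; t < tokensInChunk; ++t)
        {
            // Decay token t's contribution to the end of the chunk, then weight by dc[t].
            acc += B_gn[t * strideB] * std::exp(dA[Q - 1] - dA[t]) * dc[t] * X_hp[t * strideX];
        }
        return acc;
    }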

View File

@ -0,0 +1,241 @@
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <cuda_fp8.h>
#include <mma.h>
#include "tensorrt_llm/common/cudaBf16Fallbacks.cuh"
#include "Common.h"
#include "Poly.h"
namespace tensorrt_llm
{
namespace kernels
{
typedef void (*StatePassingKernelFuncFp16)(int B_, int L_, int H_, int P_, int N_,
// const half *g_mxY_, // B*L*H*P
half* g_mxOs_, // B*C*H*N*P
half* g_mxFs_, // B *H*N*P
float const* g_mxSt_, // B*C*H*N*P
// const float *g_mxdc_, // B*C*H*Q
float const* g_mxdA_, // B*C*H*Q
// const half *g_mxdt_, // B*L*H
// const float *g_mxdb_, // H
// const float *g_mxA_, // H
// const half *g_mxCB_, // B*C*G*Q*Q
// const half *g_mxBC_, // B*L*2*G*N
// const float *g_mxD_, // H
// const half *g_mxX_, // B*L*H*P
// const half *g_mxZ_, // B*L*H*P
bool removePadding_, int const* lastTokenIdsPtr_, int const* stateSlotMappingPtr_);
typedef void (*StatePassingKernelFuncBf16)(int B_, int L_, int H_, int P_, int N_,
// const bf16 *g_mxY_, // B*L*H*P
bf16* g_mxOs_, // B*C*H*N*P
bf16* g_mxFs_, // B *H*N*P
float const* g_mxSt_, // B*C*H*N*P
// const float *g_mxdc_, // B*C*H*Q
float const* g_mxdA_, // B*C*H*Q
// const bf16 *g_mxdt_, // B*L*H
// const float *g_mxdb_, // H
// const float *g_mxA_, // H
// const bf16 *g_mxCB_, // B*C*G*Q*Q
// const bf16 *g_mxBC_, // B*L*2*G*N
// const float *g_mxD_, // H
// const bf16 *g_mxX_, // B*L*H*P
// const bf16 *g_mxZ_, // B*L*H*P
bool removePadding_, int const* lastTokenIdsPtr_, int const* stateSlotMappingPtr_);
template <int Q_, int tileH_, int warpH_, class Tp_>
__global__ std::enable_if_t<std::is_same_v<Tp_, half> || std::is_same_v<Tp_, __nv_bfloat16>> state_passing_kernel(
int B_, int L_, int H_, int P_, int N_,
// const Tp_ *g_mxY_, // B*L*H*P
Tp_* g_mxOs_, // B*C*H*N*P
Tp_* g_mxFs_, // B *H*N*P
float const* g_mxSt_, // B*C*H*N*P
// const float *g_mxdc_, // B*C*H*Q
float const* g_mxdA_, // B*C*H*Q
// const Tp_ *g_mxdt_, // B*L*H
// const Wt_ *g_mxdb_, // H
// const Wt_ *g_mxA_, // H
// const Tp_ *g_mxCB_, // B*C*G*Q*Q
// const Tp_ *g_mxBC_, // B*L*2*G*N
// const Wt_ *g_mxD_, // H
// const Tp_ *g_mxX_, // B*L*H*P
// const Tp_ *g_mxZ_, // B*L*H*P
bool removePadding_, int const* lastTokenIdsPtr_, int const* stateSlotMappingPtr_)
{
using namespace tensorrt_llm::common;
auto blockIdx_x = Rn<ID>{int(blockIdx.x)};
auto blockIdx_y = Rn<ID>{int(blockIdx.y)};
auto blockIdx_z = Rn<ID>{int(blockIdx.z)};
auto threadIdx_x = Rn<ID, 32>{int(threadIdx.x)};
auto threadIdx_y = Rn<ID, warpH_>{int(threadIdx.y)};
// auto B = Rn<ID>{B_};
auto L = Rn<ID>{L_};
auto H = Rn<ID>{H_};
auto P = Rn<ID>{P_};
// auto G = Rn<ID>{G_};
auto N = Rn<ID>{N_};
auto Q = cn<Q_>;
auto C = Rn<ID>{div_up(L.var, Q_)};
auto aStart = blockIdx_z * L;
auto cStart = blockIdx_z * C;
if (removePadding_)
{
aStart = Rn<ID>{int(blockIdx.z ? lastTokenIdsPtr_[blockIdx.z - 1] : 0)};
cStart = Rn<ID>{int(blockIdx.z ? div_up(aStart.var, Q_) + blockIdx.z - 1 : 0)};
L = Rn<ID>{lastTokenIdsPtr_[blockIdx.z] - aStart.var};
C = Rn<ID>{div_up(L.var, Q_)};
}
else
{
L = Rn<ID>{lastTokenIdsPtr_[blockIdx.z]};
C = Rn<ID>{div_up(L.var, Q_)};
}
if (stateSlotMappingPtr_)
{
g_mxFs_ += stateSlotMappingPtr_[blockIdx.z] * H_ * N_ * P_;
}
else
{
g_mxFs_ += blockIdx.z * H_ * N_ * P_;
}
auto hStart = Rn<ID>{blockIdx_x.var * tileH_ / N_ / P_};
register Tp_ r_mxOs[tileH_ / (warpH_ * 32)] = {0};
register float r_mxSt[tileH_ / (warpH_ * 32)] = {0};
for (int iC = 0; iC < C.var; iC++)
{
if (std::is_same_v<Tp_, half>)
#pragma unroll
for (int i = 0; i < tileH_ / (warpH_ * 32); i += 2)
*(half2*) &r_mxOs[i] = __float22half2_rn(*(float2*) &r_mxSt[i]);
else
#pragma unroll
for (int i = 0; i < tileH_ / (warpH_ * 32); i += 2)
*(bf162*) &r_mxOs[i] = __float22bfloat162_rn(*(float2*) &r_mxSt[i]);
#pragma unroll
for (int i = 0; i < tileH_ / (warpH_ * 32); i += 2)
*(int*) (g_mxOs_
+ get((cStart + Rn<>{iC}) * H * N * P + blockIdx_x * cn<tileH_>
+ (threadIdx_y * cn<32> + threadIdx_x) * cn<tileH_ / (warpH_ * 32)> + Rn<UNROLL>{i}))
= *(int*) &r_mxOs[i];
float scale = expf(g_mxdA_[get((cStart + Rn<>{iC}) * H * Q + hStart * Q + Q - cn<1>)]);
#pragma unroll
for (int i = 0; i < tileH_ / (warpH_ * 32); i++)
{
float tmp = g_mxSt_[get((cStart + Rn<>{iC}) * H * N * P + blockIdx_x * cn<tileH_>
+ (threadIdx_y * cn<32> + threadIdx_x) * cn<tileH_ / (warpH_ * 32)> + Rn<UNROLL>{i})];
r_mxSt[i] = scale * r_mxSt[i] + tmp;
}
}
if (std::is_same_v<Tp_, half>)
#pragma unroll
for (int i = 0; i < tileH_ / (warpH_ * 32); i += 2)
*(half2*) &r_mxOs[i] = __float22half2_rn(*(float2*) &r_mxSt[i]);
else
#pragma unroll
for (int i = 0; i < tileH_ / (warpH_ * 32); i += 2)
*(bf162*) &r_mxOs[i] = __float22bfloat162_rn(*(float2*) &r_mxSt[i]);
#pragma unroll
for (int i = 0; i < tileH_ / (warpH_ * 32); i += 8)
*(int4*) (g_mxFs_
+ get(blockIdx_x * cn<tileH_> + (threadIdx_y * cn<32> + threadIdx_x) * cn<tileH_ / (warpH_ * 32)>
+ Rn<UNROLL>{i}))
= *(int4*) &r_mxOs[i];
}
StatePassingKernelFuncFp16 getStatePassingKernelFp16(
int B_, int L_, int H_, int P_, int N_, int Q_, dim3* blockDims_, dim3* threadDims_, int* sharedMem_)
{
int B = B_;
int L = L_;
int H = H_;
int P = P_;
// int G = G_;
int N = N_;
int Q = Q_;
int C = div_up(L, Q);
int tileH = 1024;
int warpH = 8;
auto sharedMem = 0;
*blockDims_ = dim3(H * N * P / tileH, 1, B);
*threadDims_ = dim3(32, warpH);
*sharedMem_ = sharedMem;
if (Q_ == 128)
return state_passing_kernel<128, 1024, 8, half>;
else if (Q_ == 256)
return state_passing_kernel<256, 1024, 8, half>;
else
return nullptr;
}
StatePassingKernelFuncBf16 getStatePassingKernelBf16(
int B_, int L_, int H_, int P_, int N_, int Q_, dim3* blockDims_, dim3* threadDims_, int* sharedMem_)
{
int B = B_;
int L = L_;
int H = H_;
int P = P_;
// int G = G_;
int N = N_;
int Q = Q_;
int C = div_up(L, Q);
int tileH = 1024;
int warpH = 8;
auto sharedMem = 0;
*blockDims_ = dim3(H * N * P / tileH, 1, B);
*threadDims_ = dim3(32, warpH);
*sharedMem_ = sharedMem;
if (Q_ == 128)
return state_passing_kernel<128, 1024, 8, bf16>;
else if (Q_ == 256)
return state_passing_kernel<256, 1024, 8, bf16>;
else
return nullptr;
}
} // namespace kernels
} // namespace tensorrt_llm
// vim: ts=2 sw=2 sts=2 et sta
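The per-element recurrence that state_passing_kernel applies across chunks, written as a scalar sketch (illustrative only; Os, St, Fs and dA_last stand for one (head, n, p) element of the buffers above):

    #include <cmath>

    // Os[c] receives the state entering chunk c; Fs receives the state after the last
    // chunk; St[c] is the per-chunk contribution computed by chunk_state; dA_last[c]
    // is dA at the last token of chunk c.
    void statePassingRef(float* Os, float& Fs, float const* St, float const* dA_last, int numChunks)
    {
        float state = 0.f;
        for (int c = 0; c < numChunks; ++c)
        {
            Os[c] = state;                                  // state at the start of chunk c
            state = std::exp(dA_last[c]) * state + St[c];   // carry the state across the chunk
        }
        Fs = state;                                         // final state for the sequence
    }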

View File

@ -61,9 +61,19 @@ bool supportConfigCommon(XQAParams const& xqaParams, bool forConfigurePlugin)
{
return false;
}
if (xqaParams.num_kv_heads == 0 || xqaParams.num_q_heads == xqaParams.num_kv_heads)
if (xqaParams.num_kv_heads != 0 && xqaParams.num_q_heads % xqaParams.num_kv_heads != 0)
{
// Do not use XQA kernel for MHA.
return false;
}
bool is_vanilla_mha = xqaParams.num_kv_heads == 0 || xqaParams.num_q_heads == xqaParams.num_kv_heads;
if (is_vanilla_mha && xqaParams.beam_width == 1)
{
// Do not use XQA kernel for vanilla MHA case for performance reasons.
return false;
}
if (is_vanilla_mha && xqaParams.head_size <= 128)
{
// TODO(yaoy): remove this when the kernel bug for head_size <= 128 gets fixed.
return false;
}
if (xqaParams.multi_block_mode)
@ -108,11 +118,7 @@ bool supportConfigQGMMA(XQAParams const& xqaParams, int SM, bool forConfigurePlu
{
return false;
}
if (xqaParams.num_kv_heads == 0 || xqaParams.num_q_heads % xqaParams.num_kv_heads != 0)
{
return false;
}
int32_t head_grp_size = xqaParams.num_q_heads / xqaParams.num_kv_heads;
int32_t head_grp_size = xqaParams.num_kv_heads == 0 ? 1 : xqaParams.num_q_heads / xqaParams.num_kv_heads;
if (head_grp_size * xqaParams.beam_width > 32)
{
return false;
@ -150,11 +156,7 @@ bool supportConfigHMMA(XQAParams const& xqaParams, int SM, bool forConfigurePlug
{
return false;
}
if (xqaParams.num_kv_heads == 0 || xqaParams.num_q_heads % xqaParams.num_kv_heads != 0)
{
return false;
}
int32_t head_grp_size = xqaParams.num_q_heads / xqaParams.num_kv_heads;
int32_t head_grp_size = xqaParams.num_kv_heads == 0 ? 1 : xqaParams.num_q_heads / xqaParams.num_kv_heads;
if (head_grp_size * xqaParams.beam_width > 32)
{
return false;

View File

@ -1,2 +1,2 @@
8b0f8deb35940359b39f876fc5e94e4f libtensorrt_llm_nvrtc_wrapper.so
0e1417f27d93de67940c1062cf230017cd8be5f1 commit
d5f5542d2f1e10c4a6b60be56838ac79a9668665 commit

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:53746a0351295accb650f9e509303914ae8d8dc3c2605baf680f30cfc40d96f6
oid sha256:78209a1351f9f21f635bf9f763f4947031ea12b7526c5782094e9869b667a23f
size 1091072

View File

@ -482,12 +482,11 @@ __global__ void finalizeKernel(BeamHypotheses bh)
void invokeFinalize(BeamHypotheses& bh, cudaStream_t stream)
{
TLLM_LOG_TRACE("%s %s start", __FILE__, __PRETTY_FUNCTION__);
TLLM_LOG_DEBUG("%s %s start", __FILE__, __PRETTY_FUNCTION__);
int const nBM = bh.nBeamWidth;
size_t const smem_size = sizeof(int) * nBM * 2 + sizeof(float) * nBM * 2;
finalizeKernel<<<bh.nBatchSize, roundUp(nBM * 2, 32), smem_size, stream>>>(bh);
TLLM_LOG_TRACE("%s %s stop", __FILE__, __PRETTY_FUNCTION__);
}
__global__ void initializeOutput(TokenIdType* finalOutputIds, TokenIdType const* endIds, SizeType32 const nMaxSeqLen)

View File

@ -28,6 +28,12 @@
#include "selectiveScan.h"
#include "chunkScan/bmmchunk.h"
#include "chunkScan/chunkcumsum.h"
#include "chunkScan/chunkscan.h"
#include "chunkScan/chunkstate.h"
#include "chunkScan/statepassing.h"
namespace tensorrt_llm
{
namespace kernels
@ -319,8 +325,6 @@ void invokeSelectiveScan(SSMParamsBase& params, cudaStream_t stream)
int samples = params.batch;
int channels = params.dim;
TLLM_CHECK(params.is_variable_B);
TLLM_CHECK(params.is_variable_C);
TLLM_CHECK(params.dstate == 16);
int const threads = 128;
@ -331,6 +335,107 @@ void invokeSelectiveScan(SSMParamsBase& params, cudaStream_t stream)
selective_scan_loop_kernel<input_t, weight_t><<<grid, block, 0, stream>>>(params);
}
template <typename input_t, typename weight_t>
void invokeChunkScan(SSMParamsBase& params, cudaStream_t stream)
{
int B = params.batch;
int L = params.max_seqlen;
int H = params.nheads;
int P = params.dim / H;
int G = params.ngroups;
int N = params.dstate;
int Q = params.chunk_size;
bool dtsp = params.delta_softplus;
if constexpr (std::is_same_v<input_t, half>)
{
dim3 bds[5], tds[5];
int shms[5];
ChunkCumsumKernelFuncFp16 chunk_cumsum = getChunkCumsumKernelFp16(B, L, H, Q, dtsp, &bds[0], &tds[0], &shms[0]);
ChunkStateKernelFuncFp16 chunk_state = getChunkStateKernelFp16(B, L, H, P, G, N, Q, &bds[1], &tds[1], &shms[1]);
StatePassingKernelFuncFp16 state_passing
= getStatePassingKernelFp16(B, L, H, P, N, Q, &bds[2], &tds[2], &shms[2]);
BmmChunkKernelFuncFp16 bmm_chunk = getBmmChunkKernelFp16(B, L, G, N, Q, &bds[3], &tds[3], &shms[3]);
ChunkScanKernelFuncFp16 chunk_scan = getChunkScanKernelFp16(B, L, H, P, G, N, Q, &bds[4], &tds[4], &shms[4]);
half* mxY = (half*) params.out_ptr;
half* mxOs = (half*) params.Os_ptr;
half* mxFs = (half*) params.x_ptr;
float* mxSt = (float*) params.St_ptr;
float* mxdc = (float*) params.dc_ptr;
float* mxdA = (float*) params.dA_ptr;
half const* mxdt = (half const*) params.delta_ptr;
float const* mxdb = (float const*) params.delta_bias_ptr;
float const* mxA = (float const*) params.A_ptr;
half* mxCB = (half*) params.CB_ptr;
half const* mxBC = (half const*) params.BC_ptr;
float const* mxD = (float const*) params.D_ptr;
half const* mxX = (half const*) params.u_ptr;
half const* mxZ = (half const*) params.z_ptr;
auto rp = params.remove_padding;
auto ltip = params.last_token_ids_ptr;
auto ssmp = params.slot_mapping_ptr;
cudaFuncSetAttribute(chunk_cumsum, cudaFuncAttributeMaxDynamicSharedMemorySize, shms[0]);
chunk_cumsum<<<bds[0], tds[0], shms[0], stream>>>(B, L, H, mxdc, mxdA, mxdt, mxdb, mxA, rp, ltip);
cudaFuncSetAttribute(chunk_state, cudaFuncAttributeMaxDynamicSharedMemorySize, shms[1]);
chunk_state<<<bds[1], tds[1], shms[1], stream>>>(B, L, H, P, G, N, mxSt, mxdc, mxdA, mxBC, mxX, rp, ltip);
cudaFuncSetAttribute(state_passing, cudaFuncAttributeMaxDynamicSharedMemorySize, shms[2]);
state_passing<<<bds[2], tds[2], shms[2], stream>>>(B, L, H, P, N, mxOs, mxFs, mxSt, mxdA, rp, ltip, ssmp);
cudaFuncSetAttribute(bmm_chunk, cudaFuncAttributeMaxDynamicSharedMemorySize, shms[3]);
bmm_chunk<<<bds[3], tds[3], shms[3], stream>>>(B, L, G, N, mxCB, mxBC, rp, ltip);
cudaFuncSetAttribute(chunk_scan, cudaFuncAttributeMaxDynamicSharedMemorySize, shms[4]);
chunk_scan<<<bds[4], tds[4], shms[4], stream>>>(
B, L, H, P, G, N, mxY, mxOs, mxdc, mxdA, mxCB, mxBC, mxD, mxX, mxZ, rp, ltip);
}
else if constexpr (std::is_same_v<input_t, __nv_bfloat16>)
{
dim3 bds[5], tds[5];
int shms[5];
ChunkCumsumKernelFuncBf16 chunk_cumsum = getChunkCumsumKernelBf16(B, L, H, Q, dtsp, &bds[0], &tds[0], &shms[0]);
ChunkStateKernelFuncBf16 chunk_state = getChunkStateKernelBf16(B, L, H, P, G, N, Q, &bds[1], &tds[1], &shms[1]);
StatePassingKernelFuncBf16 state_passing
= getStatePassingKernelBf16(B, L, H, P, N, Q, &bds[2], &tds[2], &shms[2]);
BmmChunkKernelFuncBf16 bmm_chunk = getBmmChunkKernelBf16(B, L, G, N, Q, &bds[3], &tds[3], &shms[3]);
ChunkScanKernelFuncBf16 chunk_scan = getChunkScanKernelBf16(B, L, H, P, G, N, Q, &bds[4], &tds[4], &shms[4]);
__nv_bfloat16* mxY = (__nv_bfloat16*) params.out_ptr;
__nv_bfloat16* mxOs = (__nv_bfloat16*) params.Os_ptr;
__nv_bfloat16* mxFs = (__nv_bfloat16*) params.x_ptr;
float* mxSt = (float*) params.St_ptr;
float* mxdc = (float*) params.dc_ptr;
float* mxdA = (float*) params.dA_ptr;
__nv_bfloat16 const* mxdt = (__nv_bfloat16 const*) params.delta_ptr;
float const* mxdb = (float const*) params.delta_bias_ptr;
float const* mxA = (float const*) params.A_ptr;
__nv_bfloat16* mxCB = (__nv_bfloat16*) params.CB_ptr;
__nv_bfloat16 const* mxBC = (__nv_bfloat16 const*) params.BC_ptr;
float const* mxD = (float const*) params.D_ptr;
__nv_bfloat16 const* mxX = (__nv_bfloat16 const*) params.u_ptr;
__nv_bfloat16 const* mxZ = (__nv_bfloat16 const*) params.z_ptr;
auto rp = params.remove_padding;
auto ltip = params.last_token_ids_ptr;
auto ssmp = params.slot_mapping_ptr;
cudaFuncSetAttribute(chunk_cumsum, cudaFuncAttributeMaxDynamicSharedMemorySize, shms[0]);
chunk_cumsum<<<bds[0], tds[0], shms[0], stream>>>(B, L, H, mxdc, mxdA, mxdt, mxdb, mxA, rp, ltip);
cudaFuncSetAttribute(chunk_state, cudaFuncAttributeMaxDynamicSharedMemorySize, shms[1]);
chunk_state<<<bds[1], tds[1], shms[1], stream>>>(B, L, H, P, G, N, mxSt, mxdc, mxdA, mxBC, mxX, rp, ltip);
cudaFuncSetAttribute(state_passing, cudaFuncAttributeMaxDynamicSharedMemorySize, shms[2]);
state_passing<<<bds[2], tds[2], shms[2], stream>>>(B, L, H, P, N, mxOs, mxFs, mxSt, mxdA, rp, ltip, ssmp);
cudaFuncSetAttribute(bmm_chunk, cudaFuncAttributeMaxDynamicSharedMemorySize, shms[3]);
bmm_chunk<<<bds[3], tds[3], shms[3], stream>>>(B, L, G, N, mxCB, mxBC, rp, ltip);
cudaFuncSetAttribute(chunk_scan, cudaFuncAttributeMaxDynamicSharedMemorySize, shms[4]);
chunk_scan<<<bds[4], tds[4], shms[4], stream>>>(
B, L, H, P, G, N, mxY, mxOs, mxdc, mxdA, mxCB, mxBC, mxD, mxX, mxZ, rp, ltip);
}
}
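// Worked example of the launch geometry (illustrative values, not taken from the commit):
// with B = 2, L = 1024, Q = 256 (so C = 4 chunks), H = 64, P = 64, N = 128, G = 1,
// the chunk_scan launcher returns
//   blockDims  = dim3(H*P/tileN * Q/tileM, C, B) = dim3(64*64/64 * 256/128, 4, 2) = dim3(128, 4, 2)
//   threadDims = dim3(32, warpN, warpM)          = dim3(32, 1, 4)
//   sharedMem  = max((tileM*tileK + tileK*tileN)*pipeS*2 + Q*8, tileM*tileN*2)
//              = max((128*32 + 32*64)*2*2 + 256*8, 128*64*2) = 26624 bytes,
// which is passed both to cudaFuncSetAttribute and as the dynamic shared-memory size
// of the launch above.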
#define INSTANTIATE_SELECTIVE_SCAN_DATA_TYPE(input_t, weight_t) \
template void invokeSelectiveScan<input_t, weight_t>(SSMParamsBase & params, cudaStream_t stream);
@ -341,9 +446,19 @@ INSTANTIATE_SELECTIVE_SCAN_DATA_TYPE(__nv_bfloat16, float);
#endif
#undef INSTANTIATE_SELECTIVE_SCAN_DATA_TYPE
#define INSTANTIATE_CHUNK_SCAN_DATA_TYPE(input_t, weight_t) \
template void invokeChunkScan<input_t, weight_t>(SSMParamsBase & params, cudaStream_t stream);
INSTANTIATE_CHUNK_SCAN_DATA_TYPE(float, float);
INSTANTIATE_CHUNK_SCAN_DATA_TYPE(half, float);
#ifdef ENABLE_BF16
INSTANTIATE_CHUNK_SCAN_DATA_TYPE(__nv_bfloat16, float);
#endif
#undef INSTANTIATE_CHUNK_SCAN_DATA_TYPE
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename input_t, typename weight_t, int DSTATE = 16, int CHANNELS_PER_BLOCK = 128>
template <typename input_t, typename weight_t, int DSTATE = 16, int CHANNELS_PER_BLOCK = 128, bool MAMBA_V1 = true>
__launch_bounds__(128, 2) __global__ void selective_scan_update_kernel(SSMParamsBase params)
{
@ -359,15 +474,21 @@ __launch_bounds__(128, 2) __global__ void selective_scan_update_kernel(SSMParams
weight_t* dt_bias = reinterpret_cast<weight_t*>(params.delta_bias_ptr);
bool dt_softplus = params.delta_softplus;
int num_channels = params.dim;
int nheads = params.nheads;
int ngroups = params.ngroups;
int const channel = blockIdx.x * blockDim.x + threadIdx.x;
if (channel >= num_channels)
return;
int const sample = blockIdx.y;
int const head_dim = num_channels / nheads;
int const head = channel / head_dim;
int const head_chl = channel % head_dim;
int const group = head / (nheads / ngroups);
int const slot_idx = params.slot_mapping_ptr == nullptr ? sample : params.slot_mapping_ptr[sample];
int const bc_cols = DSTATE * 2 + params.dt_rank;
int const b_offset = params.dt_rank;
int const c_offset = params.dt_rank + DSTATE;
int const bc_offset = MAMBA_V1 ? sample * (DSTATE * 2 + params.dt_rank) : sample * DSTATE * ngroups * 2;
int const b_offset = MAMBA_V1 ? params.dt_rank : DSTATE * group;
int const c_offset = MAMBA_V1 ? params.dt_rank + DSTATE : DSTATE * (ngroups + group);
input_t* my_state = &state[slot_idx * num_channels * DSTATE];
input_t* my_output = &output[sample * num_channels];
@ -375,30 +496,45 @@ __launch_bounds__(128, 2) __global__ void selective_scan_update_kernel(SSMParams
float rA[DSTATE];
float rB[DSTATE];
float rC[DSTATE];
float rState[DSTATE];
#pragma unroll
for (int i = 0; i < DSTATE; i++)
{
rA[i] = toFloat(A[i * num_channels + channel]);
rB[i] = toFloat(B[sample * bc_cols + b_offset + i]);
rC[i] = toFloat(C[sample * bc_cols + c_offset + i]);
rState[i] = toFloat(my_state[i * num_channels + channel]);
}
float my_x, my_dt, my_z, my_dt_bias, my_D;
my_x = toFloat(x[sample * num_channels + channel]);
my_dt = toFloat(dt[sample * num_channels + channel]);
my_z = z ? toFloat(z[sample * num_channels + channel]) : 0.f;
my_dt_bias = dt_bias ? toFloat(dt_bias[channel]) : 0.f;
my_D = D ? toFloat(D[channel]) : 0.f;
if (MAMBA_V1)
{
#pragma unroll
for (int i = 0; i < DSTATE; i++)
{
rA[i] = toFloat(A[i * num_channels + channel]);
rB[i] = toFloat(B[bc_offset + b_offset + i]);
rC[i] = toFloat(C[bc_offset + c_offset + i]);
rState[i] = toFloat(my_state[i * num_channels + channel]);
}
my_dt = toFloat(dt[sample * num_channels + channel]);
my_dt_bias = dt_bias ? toFloat(dt_bias[channel]) : 0.f;
my_D = D ? toFloat(D[channel]) : 0.f;
}
else
{
float A_tmp = toFloat(A[head]);
#pragma unroll
for (int i = 0; i < DSTATE; i++)
{
rA[i] = A_tmp;
rB[i] = toFloat(B[bc_offset + b_offset + i]);
rC[i] = toFloat(C[bc_offset + c_offset + i]);
rState[i] = toFloat(my_state[(head * DSTATE + i) * head_dim + head_chl]);
}
my_dt = toFloat(dt[sample * nheads + head]);
my_dt_bias = dt_bias ? toFloat(dt_bias[head]) : 0.f;
my_D = D ? toFloat(D[head]) : 0.f;
}
float dt_b = my_dt + my_dt_bias;
float dt_b_sp;
if (dt_softplus)
{
// dt_b_sp = dt_b <= 20.f ? logf(1.f + expf(dt_b)) : dt_b; // softplus
dt_b_sp = dt_b <= 20.f ? __logf(1.f + __expf(dt_b)) : dt_b; // softplus
}
@ -407,19 +543,21 @@ __launch_bounds__(128, 2) __global__ void selective_scan_update_kernel(SSMParams
#pragma unroll
for (int i = 0; i < DSTATE; i++)
{
// float dA = expf(rA[i] * dt_b_sp);
float dA = __expf(rA[i] * dt_b_sp);
float dB = rB[i] * dt_b_sp;
float sdA = rState[i] * dA;
float dBx = dB * my_x;
float newState = sdA + dBx;
convertAndStore(&my_state[i * num_channels + channel], newState); // Write the new state back out to the cache
// Write the new state back out to the cache
if (MAMBA_V1)
convertAndStore(&my_state[i * num_channels + channel], newState);
else
convertAndStore(&my_state[(head * DSTATE + i) * head_dim + head_chl], newState);
out += newState * rC[i];
}
if (z)
{
// float sig_z = 1.0 / (1.0 + exp(0.f - my_z));
float sig_z = __fdividef(1.f, (1.f + __expf(0.f - my_z)));
float silu_z = my_z * sig_z;
out *= silu_z;
@ -433,16 +571,25 @@ void invokeSelectiveScanUpdate(SSMParamsBase& params, cudaStream_t stream)
{
int samples = params.batch;
int channels = params.dim;
int nheads = params.nheads;
int ngroups = params.ngroups;
int const threads = 128;
int const blocks = (channels + threads - 1) / threads;
dim3 block(threads, 1);
dim3 grid(blocks, samples);
TLLM_CHECK(params.is_variable_B);
TLLM_CHECK(params.is_variable_C);
TLLM_CHECK(params.dstate == 16);
selective_scan_update_kernel<input_t, weight_t><<<grid, block, 0, stream>>>(params);
TLLM_CHECK_WITH_INFO(nheads % ngroups == 0, "nheads must be divisible by ngroups");
if (params.is_mamab2)
{
TLLM_CHECK(params.dstate == 128);
selective_scan_update_kernel<input_t, weight_t, 128, 128, false><<<grid, block, 0, stream>>>(params);
}
else
{
TLLM_CHECK(params.dstate == 16);
selective_scan_update_kernel<input_t, weight_t, 16, 128, true><<<grid, block, 0, stream>>>(params);
}
}
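For reference, the single-token update applied by selective_scan_update_kernel, written as a scalar sketch (illustrative only; dt_sp is dt + dt_bias after softplus, the z gate and delta_softplus are assumed enabled, and the placement of the D*x skip term follows the standard Mamba formulation since that line is not visible in the excerpt above):

    #include <cmath>

    float ssmUpdateRef(float* state, float const* A, float const* B, float const* C, int dstate,
        float x, float dt_sp, float D, float z)
    {
        float out = 0.f;
        for (int i = 0; i < dstate; ++i)
        {
            float dA = std::exp(A[i] * dt_sp);         // discretized A
            float dB = B[i] * dt_sp;                   // discretized B
            state[i] = state[i] * dA + dB * x;         // new state, written back to the cache
            out += state[i] * C[i];
        }
        out += D * x;                                  // skip connection
        float sig_z = 1.f / (1.f + std::exp(-z));
        return out * (z * sig_z);                      // SiLU(z) gate
    }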
#define INSTANTIATE_SELECTIVE_SCAN_UPDATE_DATA_TYPE(input_t, weight_t) \

View File

@ -40,13 +40,11 @@ namespace kernels
struct SSMParamsBase
{
int batch, dim, dstate, dt_rank;
int batch, dim, dstate, dt_rank, nheads, ngroups, chunk_size;
int max_seqlen; // only valid for padded input.
bool remove_padding;
bool is_variable_B;
bool is_variable_C;
bool delta_softplus;
bool is_mamab2;
// Common data pointers.
void* __restrict__ A_ptr;
@ -58,6 +56,12 @@ struct SSMParamsBase
void* __restrict__ out_ptr;
void* __restrict__ x_ptr;
void* __restrict__ z_ptr;
// Workspace data pointers.
void* __restrict__ Os_ptr;
void* __restrict__ St_ptr;
void* __restrict__ dc_ptr;
void* __restrict__ dA_ptr;
void* __restrict__ CB_ptr;
int const* __restrict__ last_token_ids_ptr;
int const* __restrict__ slot_mapping_ptr;
};
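// Shapes of the new workspace buffers, as read/written by the chunked-scan kernels in
// this commit (B = batch, C = chunks, H = nheads, N = dstate, P = dim/nheads,
// Q = chunk_size, G = ngroups); element types follow the kernel signatures, and the
// role descriptions are inferred from how the kernels use them:
//   Os_ptr : B*C*H*N*P  (input_t)  per-chunk carry-in states (state_passing output)
//   St_ptr : B*C*H*N*P  (float)    per-chunk state contributions (chunk_state output)
//   dc_ptr : B*C*H*Q    (float)    per-token coefficients from chunk_cumsum
//   dA_ptr : B*C*H*Q    (float)    per-token cumulative decay from chunk_cumsum
//   CB_ptr : B*C*G*Q*Q  (input_t)  per-chunk Q*Q C/B inner products from bmm_chunk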
@ -67,6 +71,9 @@ struct SSMParamsBase
template <typename input_t, typename weight_t>
void invokeSelectiveScan(SSMParamsBase& params, cudaStream_t stream);
template <typename input_t, typename weight_t>
void invokeChunkScan(SSMParamsBase& params, cudaStream_t stream);
template <typename input_t, typename weight_t>
void invokeSelectiveScanUpdate(SSMParamsBase& params, cudaStream_t stream);
} // namespace kernels

View File

@ -158,7 +158,7 @@ void GemmPluginProfiler<Config, RunnerPtr, GemmIdType, GemmIdHashType>::profileT
common::check_cuda_error(cudaStreamCreate(&mStream));
int const startMinMRounded = nextPowerOfTwo(dims.minM);
for (int m = startMinMRounded; m < maxM; m *= 2)
for (int m = std::max(1, startMinMRounded); m < maxM; m *= 2)
{
profileTactics(m, dims.n, dims.k);
}
@ -184,7 +184,7 @@ std::optional<Config> GemmPluginProfiler<Config, RunnerPtr, GemmIdType, GemmIdHa
return std::nullopt;
}
int const mRounded = std::min(nextPowerOfTwo(m), getMaxProfileM());
int const mRounded = std::min(std::max(1, nextPowerOfTwo(m)), getMaxProfileM());
fflush(stdout);
return mMNKProfileMap->getMProfileMap(gemmId)->at(mRounded);
}
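The clamp matters because the helper can round very small values down to zero. The commit does not show nextPowerOfTwo, so the sketch below assumes the common bit-twiddling variant, under which an input of 0 rounds to 0; without std::max(1, ...) the profiling loop's m *= 2 would then never advance and the profile-map lookup key could be 0:

    // Assumed shape of nextPowerOfTwo (not the library's actual code).
    int nextPowerOfTwoAssumed(int v)
    {
        v--;
        v |= v >> 1;
        v |= v >> 2;
        v |= v >> 4;
        v |= v >> 8;
        v |= v >> 16;
        return v + 1;  // 0 -> 0, 1 -> 1, 3 -> 4, 100 -> 128
    }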

View File

@ -275,8 +275,23 @@ void GPTAttentionPlugin::configurePluginImpl(nvinfer1::DynamicPluginTensorDesc c
{
TLLM_CHECK(mHeadSize > 0);
int const beamWidth
= isCrossAttention() ? 1 : (useKVCache() ? in[getIdx(IdxEntry::CACHE_INDIR)].desc.dims.d[1] : 1);
int beamWidth = -1;
if (!isCrossAttention() && useKVCache())
{
// desc_val == -1 means beam_width is not static; we should look at min/max/opt.
//
// In prepareEnqueueGeneration, we'll prepare for all cases where beam_width doesn't exceed max.
// TODO(minwei): pass min AND max to prepareEnqueueGeneration instead of max only.
int desc_val = in[getIdx(IdxEntry::CACHE_INDIR)].desc.dims.d[1];
int max_val = in[getIdx(IdxEntry::CACHE_INDIR)].max.d[1];
beamWidth = desc_val == -1 ? max_val : desc_val;
}
else
{
beamWidth = 1;
}
TLLM_CHECK(beamWidth != -1);
// Commonly, cyclic_attention_window_size and max_attention_window_size will be the same
// unless each layer has different attention window sizes.
// the kv_cache capacity.

View File

@ -29,18 +29,22 @@ static char const* SELECTIVE_SCAN_PLUGIN_NAME{"SelectiveScan"};
PluginFieldCollection SelectiveScanPluginCreator::mFC{};
std::vector<nvinfer1::PluginField> SelectiveScanPluginCreator::mPluginAttributes;
SelectiveScanPlugin::SelectiveScanPlugin(int dim, int dstate, int dt_rank, bool isVariableB, bool isVariableC,
bool deltaSoftplus, nvinfer1::DataType type, bool removePadding, bool pagedState)
SelectiveScanPlugin::SelectiveScanPlugin(int dim, int dstate, int dtRank, int nHeads, int nGroups, int chunkSize,
bool deltaSoftplus, nvinfer1::DataType type, bool removePadding, bool pagedState, bool zEnabled, bool isMamba2)
: mDim(dim)
, mDState(dstate)
, mDtRank(dt_rank)
, mIsVariableB(isVariableB)
, mIsVariableC(isVariableC)
, mDtRank(dtRank)
, mNHeads(nHeads)
, mNGroups(nGroups)
, mChunkSize(chunkSize)
, mDeltaSoftplus(deltaSoftplus)
, mType(type)
, mRemovePadding(removePadding)
, mPagedState(pagedState)
, mZEnabled(zEnabled)
, mIsMamba2(isMamba2)
{
TLLM_CHECK_WITH_INFO((getSMVersion() >= 80) || (!mIsMamba2), "Pre SM 80 GPUs do not support Mamba2");
TLLM_CHECK_WITH_INFO((getSMVersion() >= 80) || (mType != DataType::kBF16),
"Unsupported data type, pre SM 80 GPUs do not support bfloat16");
TLLM_CHECK_WITH_INFO((mType == DataType::kBF16) || (mType == DataType::kFLOAT) || (mType == DataType::kHALF),
@ -54,12 +58,15 @@ SelectiveScanPlugin::SelectiveScanPlugin(void const* data, size_t length)
read(d, mDim);
read(d, mDState);
read(d, mDtRank);
read(d, mIsVariableB);
read(d, mIsVariableC);
read(d, mNHeads);
read(d, mNGroups);
read(d, mChunkSize);
read(d, mDeltaSoftplus);
read(d, mType);
read(d, mRemovePadding);
read(d, mPagedState);
read(d, mZEnabled);
read(d, mIsMamba2);
TLLM_CHECK(d == a + length);
TLLM_CHECK_WITH_INFO((getSMVersion() >= 80) || (mType != DataType::kBF16), "Unsupported data type");
TLLM_CHECK_WITH_INFO((mType == DataType::kBF16) || (mType == DataType::kFLOAT) || (mType == DataType::kHALF),
@ -69,8 +76,8 @@ SelectiveScanPlugin::SelectiveScanPlugin(void const* data, size_t length)
// IPluginV2DynamicExt Methods
nvinfer1::IPluginV2DynamicExt* SelectiveScanPlugin::clone() const noexcept
{
auto* plugin = new SelectiveScanPlugin(
mDim, mDState, mDtRank, mIsVariableB, mIsVariableC, mDeltaSoftplus, mType, mRemovePadding, mPagedState);
auto* plugin = new SelectiveScanPlugin(mDim, mDState, mDtRank, mNHeads, mNGroups, mChunkSize, mDeltaSoftplus, mType,
mRemovePadding, mPagedState, mZEnabled, mIsMamba2);
plugin->setPluginNamespace(mNamespace.c_str());
return plugin;
}
@ -91,7 +98,8 @@ nvinfer1::DimsExprs SelectiveScanPlugin::getOutputDimensions(
bool SelectiveScanPlugin::supportsFormatCombination(
int pos, nvinfer1::PluginTensorDesc const* inOut, int nbInputs, int nbOutputs) noexcept
{
if (pos == getHostRequestTypesIdx() || pos == getLastTokenIdsIdx() || (mPagedState && pos == getSlotMappingIdx()))
if (pos == getHostRequestTypesIdx() || pos == getLastTokenIdsIdx()
|| (mRemovePadding && pos == getHostContextLengthIdx()) || (mPagedState && pos == getSlotMappingIdx()))
{
return inOut[pos].type == nvinfer1::DataType::kINT32;
}
@ -117,14 +125,56 @@ void SelectiveScanPlugin::configurePlugin(nvinfer1::DynamicPluginTensorDesc cons
size_t SelectiveScanPlugin::getWorkspaceSize(nvinfer1::PluginTensorDesc const* inputs, int nbInputs,
nvinfer1::PluginTensorDesc const* outputs, int nbOutputs) const noexcept
{
return 0;
if (!mIsMamba2)
return 0;
int const NUM_BUFFERS = 5;
size_t workspaces[NUM_BUFFERS];
if (mRemovePadding)
{
int B = inputs[getLastTokenIdsIdx()].dims.d[0];
int BxL = inputs[getInputTensorIdx()].dims.d[0]; // num_tokens
int H = mNHeads;
int P = inputs[getInputTensorIdx()].dims.d[1] / H;
int G = mNGroups;
int N = inputs[getBCIdx()].dims.d[1] / G / 2;
int Q = mChunkSize;
int BxC = (BxL + Q - 1) / Q + B;
workspaces[0] = BxC * H * N * P * 2; // g_mxOs_
workspaces[1] = BxC * H * N * P * 4; // g_mxSt_ in float
workspaces[2] = BxC * H * Q * 4; // g_mxdc_ in float
workspaces[3] = BxC * H * Q * 4; // g_mxdA_ in float
workspaces[4] = BxC * G * Q * Q * 2; // g_mxCB_
}
else
{
int B = inputs[getInputTensorIdx()].dims.d[0];
int L = inputs[getInputTensorIdx()].dims.d[1];
int H = mNHeads;
int P = inputs[getInputTensorIdx()].dims.d[2] / H;
int G = mNGroups;
int N = inputs[getBCIdx()].dims.d[2] / G / 2;
int Q = mChunkSize;
int C = (L + Q - 1) / Q;
workspaces[0] = B * C * H * N * P * 2; // g_mxOs_
workspaces[1] = B * C * H * N * P * 4; // g_mxSt_ in float
workspaces[2] = B * C * H * Q * 4; // g_mxdc_ in float
workspaces[3] = B * C * H * Q * 4; // g_mxdA_ in float
workspaces[4] = B * C * G * Q * Q * 2; // g_mxCB_
}
return calculateTotalWorkspaceSize(workspaces, NUM_BUFFERS);
}
void SelectiveScanPlugin::setSSMParams(SSMParamsBase& params, const size_t batch, const size_t dim,
const size_t maxSeqLen, const size_t dstate, const size_t dtRank, bool const isVariableB, bool const isVariableC,
void* statePtr, void const* x, void const* delta, void const* deltaBias, void const* A, void const* BC,
void const* D, void const* z, int const* lastTokenIds, int const* slotMapping, void* out, bool deltaSoftplus,
bool removePadding)
const size_t maxSeqLen, const size_t dstate, const size_t dtRank, const size_t nHeads, const size_t nGroups,
const size_t chunkSize, void* statePtr, void const* x, void const* delta, void const* deltaBias, void const* A,
void const* BC, void const* D, void const* z, void const* osPtr, void const* stPtr, void const* dcPtr,
void const* dAPtr, void const* cbPtr, int const* lastTokenIds, int const* slotMapping, void* out,
bool deltaSoftplus, bool removePadding)
{
// Reset the parameters
memset(&params, 0, sizeof(params));
@ -134,12 +184,13 @@ void SelectiveScanPlugin::setSSMParams(SSMParamsBase& params, const size_t batch
params.max_seqlen = maxSeqLen;
params.dstate = dstate;
params.dt_rank = dtRank;
params.nheads = nHeads;
params.ngroups = nGroups;
params.chunk_size = chunkSize;
params.delta_softplus = deltaSoftplus;
params.remove_padding = removePadding;
params.is_variable_B = isVariableB;
params.is_variable_C = isVariableC;
params.is_mamab2 = mIsMamba2;
// Set the pointers and strides.
params.u_ptr = const_cast<void*>(x);
@ -151,6 +202,11 @@ void SelectiveScanPlugin::setSSMParams(SSMParamsBase& params, const size_t batch
params.out_ptr = out;
params.x_ptr = statePtr;
params.z_ptr = const_cast<void*>(z);
params.Os_ptr = const_cast<void*>(osPtr);
params.St_ptr = const_cast<void*>(stPtr);
params.dc_ptr = const_cast<void*>(dcPtr);
params.dA_ptr = const_cast<void*>(dAPtr);
params.CB_ptr = const_cast<void*>(cbPtr);
params.last_token_ids_ptr = lastTokenIds;
params.slot_mapping_ptr = slotMapping;
}
@ -162,24 +218,30 @@ int SelectiveScanPlugin::enqueueImpl(nvinfer1::PluginTensorDesc const* inputDesc
{
// inputs
// 0. input_tensor [batch_size, max_seq_len, dim] or [num_tokens, dim]
// 1. state [batch_size, dstate, dim] or host [1] containing only pointer for paged_state
// 2. delta [batch_size, max_seq_len, dim] or [num_tokens, dim]
// 3. delta_bias [dim]
// 4. A [dstate, dim]
// 5. BC [batch_size, max_seq_len, dt_rank + dstate * 2] or [num_tokens, dt_rank + dstate * 2]
// 6. D [dim]
// 7. z [batch_size, max_seq_len, dim] or [num_tokens, dim]
// 8. host_request_types [batch_size] int32. 0: context; 1: generation.
// 9. last_token_ids [batch_size] int32
// 1. state mamba: [batch_size, dstate, dim] or host [1] containing only pointer for paged_state
// mamba2: [batch_size, nheads, dstate, dim] or host [1] containing only pointer for paged_state
// 2. delta, mamba: [batch_size, seq_len, dim] or [num_tokens, dim] for remove_input_padding
// mamba2: [batch_size, seq_len, nheads] or [num_tokens, nheads] for remove_input_padding
// 3. delta_bias, [dim] for mamba, [nheads] for mamba2
// 4. A, [dstate, dim] for mamba, [nheads] for mamba2
// 5. BC, mamba: [batch_size, seq_len, dstate * 2] or [num_tokens, dstate * 2] for remove_input_padding
// mamba2: [batch_size, seq_len, ngroups * dstate * 2] or [num_tokens, ngroups * dstate * 2] for
// remove_input_padding
// 6. D, [dim] for mamba, [nheads] for mamba2
// 7. host_request_types [batch_size] int32. 0: context; 1: generation.
// 8. last_token_ids [batch_size] int32
// 9. host_context_lengths [batch_size] int32, optional for remove_input_padding
// 10. state_slot_mapping [batch_size] int32, optional for paged state
// 11. z [batch_size, max_seq_len, dim] or [num_tokens, dim]
// outputs
// 0. output_tensor [batch_size, max_seq_len, dim] or [num_tokens, dim]
// 1. state [batch_size, dstate, dim]
// 1. state, [batch_size, dstate, dim] for mamba, [batch_size, nheads, dstate, dim] for mamba2
auto const batch_size = inputDesc[getHostRequestTypesIdx()].dims.d[0];
int max_seq_len;
if (mRemovePadding)
{
max_seq_len = -1;
int const* host_context_length = static_cast<int const*>(inputs[getHostContextLengthIdx()]);
max_seq_len = *std::max_element(host_context_length, host_context_length + batch_size);
}
else
{
@ -192,17 +254,72 @@ int SelectiveScanPlugin::enqueueImpl(nvinfer1::PluginTensorDesc const* inputDesc
SSMParamsBase ssm_params;
int const* slotMapping = mPagedState ? static_cast<int const*>(inputs[getSlotMappingIdx()]) : nullptr;
void const* z = mZEnabled ? inputs[getZIdx()] : nullptr;
void* statePtr = mPagedState ? *reinterpret_cast<void**>(const_cast<void*>(inputs[getStateIdx()])) : outputs[1];
setSSMParams(ssm_params, batch_size, mDim, max_seq_len, mDState, mDtRank, mIsVariableB, mIsVariableC, statePtr,
// Workspace pointer shift
int8_t* workspace_byte_ptr = reinterpret_cast<int8_t*>(workspace);
size_t offset = 0;
T* mxOs = nullptr;
float* mxSt = nullptr;
float* mxdc = nullptr;
float* mxdA = nullptr;
T* mxCB = nullptr;
if (!mIsMamba2) /* no workspace needed */
;
else if (mRemovePadding)
{
int B = inputDesc[getLastTokenIdsIdx()].dims.d[0];
int BxL = inputDesc[getInputTensorIdx()].dims.d[0]; // num_tokens
int H = mNHeads;
int P = inputDesc[getInputTensorIdx()].dims.d[1] / H;
int G = mNGroups;
int N = inputDesc[getBCIdx()].dims.d[1] / G / 2;
int Q = mChunkSize;
int BxC = (BxL + Q - 1) / Q + B;
mxOs = reinterpret_cast<T*>(nextWorkspacePtr(workspace_byte_ptr, offset, BxC * H * N * P * 2));
mxSt = reinterpret_cast<float*>(nextWorkspacePtr(workspace_byte_ptr, offset, BxC * H * N * P * 4));
mxdc = reinterpret_cast<float*>(nextWorkspacePtr(workspace_byte_ptr, offset, BxC * H * Q * 4));
mxdA = reinterpret_cast<float*>(nextWorkspacePtr(workspace_byte_ptr, offset, BxC * H * Q * 4));
mxCB = reinterpret_cast<T*>(nextWorkspacePtr(workspace_byte_ptr, offset, BxC * G * Q * Q * 2));
}
else
{
int B = inputDesc[getInputTensorIdx()].dims.d[0];
int L = inputDesc[getInputTensorIdx()].dims.d[1];
int H = mNHeads;
int P = inputDesc[getInputTensorIdx()].dims.d[2] / H;
int G = mNGroups;
int N = inputDesc[getBCIdx()].dims.d[2] / G / 2;
int Q = mChunkSize;
int C = (L + Q - 1) / Q;
mxOs = reinterpret_cast<T*>(nextWorkspacePtr(workspace_byte_ptr, offset, B * C * H * N * P * 2));
mxSt = reinterpret_cast<float*>(nextWorkspacePtr(workspace_byte_ptr, offset, B * C * H * N * P * 4));
mxdc = reinterpret_cast<float*>(nextWorkspacePtr(workspace_byte_ptr, offset, B * C * H * Q * 4));
mxdA = reinterpret_cast<float*>(nextWorkspacePtr(workspace_byte_ptr, offset, B * C * H * Q * 4));
mxCB = reinterpret_cast<T*>(nextWorkspacePtr(workspace_byte_ptr, offset, B * C * G * Q * Q * 2));
}
setSSMParams(ssm_params, batch_size, mDim, max_seq_len, mDState, mDtRank, mNHeads, mNGroups, mChunkSize, statePtr,
inputs[getInputTensorIdx()], inputs[getDeltaIdx()], inputs[getDeltaBiasIdx()], inputs[getAIdx()],
inputs[getBCIdx()], inputs[getDIdx()], inputs[getZIdx()], static_cast<int const*>(inputs[getLastTokenIdsIdx()]),
slotMapping, outputs[0], mDeltaSoftplus, mRemovePadding);
inputs[getBCIdx()], inputs[getDIdx()], z, mxOs, mxSt, mxdc, mxdA, mxCB,
static_cast<int const*>(inputs[getLastTokenIdsIdx()]), slotMapping, outputs[0], mDeltaSoftplus, mRemovePadding);
if (reqTypes[0] == RequestType::kCONTEXT)
{
invokeSelectiveScan<T, float>(ssm_params, stream);
if (mIsMamba2)
{
invokeChunkScan<T, float>(ssm_params, stream);
}
else
{
invokeSelectiveScan<T, float>(ssm_params, stream);
}
}
else if (reqTypes[0] == RequestType::kGENERATION)
{
@ -276,8 +393,9 @@ void SelectiveScanPlugin::terminate() noexcept {}
size_t SelectiveScanPlugin::getSerializationSize() const noexcept
{
return sizeof(mDim) + sizeof(mDState) + sizeof(mDtRank) + sizeof(mIsVariableB) + sizeof(mIsVariableC)
+ sizeof(mDeltaSoftplus) + sizeof(mType) + sizeof(mRemovePadding) + sizeof(mPagedState);
return sizeof(mDim) + sizeof(mDState) + sizeof(mDtRank) + sizeof(mNHeads) + sizeof(mNGroups) + sizeof(mChunkSize)
+ sizeof(mDeltaSoftplus) + sizeof(mType) + sizeof(mRemovePadding) + sizeof(mPagedState) + sizeof(mZEnabled)
+ sizeof(mIsMamba2);
}
void SelectiveScanPlugin::serialize(void* buffer) const noexcept
@ -286,12 +404,15 @@ void SelectiveScanPlugin::serialize(void* buffer) const noexcept
write(d, mDim);
write(d, mDState);
write(d, mDtRank);
write(d, mIsVariableB);
write(d, mIsVariableC);
write(d, mNHeads);
write(d, mNGroups);
write(d, mChunkSize);
write(d, mDeltaSoftplus);
write(d, mType);
write(d, mRemovePadding);
write(d, mPagedState);
write(d, mZEnabled);
write(d, mIsMamba2);
assert(d == a + getSerializationSize());
}
@ -306,15 +427,18 @@ SelectiveScanPluginCreator::SelectiveScanPluginCreator()
{
// Fill PluginFieldCollection with PluginField arguments metadata
mPluginAttributes.clear();
mPluginAttributes.emplace_back(PluginField("dim", nullptr, PluginFieldType::kINT32, 16));
mPluginAttributes.emplace_back(PluginField("dstate", nullptr, PluginFieldType::kINT32, 16));
mPluginAttributes.emplace_back(PluginField("dt_rank", nullptr, PluginFieldType::kINT32, 16));
mPluginAttributes.emplace_back(PluginField("is_variable_B", nullptr, PluginFieldType::kINT8, 1));
mPluginAttributes.emplace_back(PluginField("is_variable_C", nullptr, PluginFieldType::kINT8, 1));
mPluginAttributes.emplace_back(PluginField("dim", nullptr, PluginFieldType::kINT32, 1));
mPluginAttributes.emplace_back(PluginField("dstate", nullptr, PluginFieldType::kINT32, 1));
mPluginAttributes.emplace_back(PluginField("dt_rank", nullptr, PluginFieldType::kINT32, 1));
mPluginAttributes.emplace_back(PluginField("nheads", nullptr, PluginFieldType::kINT32, 1));
mPluginAttributes.emplace_back(PluginField("ngroups", nullptr, PluginFieldType::kINT32, 1));
mPluginAttributes.emplace_back(PluginField("chunk_size", nullptr, PluginFieldType::kINT32, 1));
mPluginAttributes.emplace_back(PluginField("delta_softplus", nullptr, PluginFieldType::kINT8, 1));
mPluginAttributes.emplace_back(PluginField("type_id", nullptr, PluginFieldType::kINT32, 1));
mPluginAttributes.emplace_back(PluginField("remove_input_padding", nullptr, PluginFieldType::kINT8, 0));
mPluginAttributes.emplace_back(PluginField("paged_state", nullptr, PluginFieldType::kINT8, 0));
mPluginAttributes.emplace_back(PluginField("remove_input_padding", nullptr, PluginFieldType::kINT8, 1));
mPluginAttributes.emplace_back(PluginField("paged_state", nullptr, PluginFieldType::kINT8, 1));
mPluginAttributes.emplace_back(PluginField("z_enabled", nullptr, PluginFieldType::kINT8, 1));
mPluginAttributes.emplace_back(PluginField("is_mamba2", nullptr, PluginFieldType::kINT8, 1));
mFC.nbFields = mPluginAttributes.size();
mFC.fields = mPluginAttributes.data();
}
@ -337,8 +461,8 @@ PluginFieldCollection const* SelectiveScanPluginCreator::getFieldNames() noexcep
IPluginV2* SelectiveScanPluginCreator::createPlugin(char const* name, PluginFieldCollection const* fc) noexcept
{
PluginField const* fields = fc->fields;
int dim, dstate, dtRank;
bool isVariableB, isVariableC, deltaSoftplus, removePadding, pagedState;
int dim, dstate, dtRank, nHeads, nGroups, chunkSize;
bool deltaSoftplus, removePadding, pagedState, zEnabled, isMamba2;
nvinfer1::DataType type;
// Read configurations from each fields
for (int i = 0; i < fc->nbFields; ++i)
@ -359,15 +483,20 @@ IPluginV2* SelectiveScanPluginCreator::createPlugin(char const* name, PluginFiel
TLLM_CHECK(fields[i].type == PluginFieldType::kINT32);
dtRank = static_cast<int>(*(static_cast<int const*>(fields[i].data)));
}
else if (!strcmp(attrName, "is_variable_B"))
else if (!strcmp(attrName, "nheads"))
{
TLLM_CHECK(fields[i].type == PluginFieldType::kINT8);
isVariableB = static_cast<bool>(*(static_cast<bool const*>(fields[i].data)));
TLLM_CHECK(fields[i].type == PluginFieldType::kINT32);
nHeads = static_cast<int>(*(static_cast<int const*>(fields[i].data)));
}
else if (!strcmp(attrName, "is_variable_C"))
else if (!strcmp(attrName, "ngroups"))
{
TLLM_CHECK(fields[i].type == PluginFieldType::kINT8);
isVariableC = static_cast<bool>(*(static_cast<bool const*>(fields[i].data)));
TLLM_CHECK(fields[i].type == PluginFieldType::kINT32);
nGroups = static_cast<int>(*(static_cast<int const*>(fields[i].data)));
}
else if (!strcmp(attrName, "chunk_size"))
{
TLLM_CHECK(fields[i].type == PluginFieldType::kINT32);
chunkSize = static_cast<int>(*(static_cast<int const*>(fields[i].data)));
}
else if (!strcmp(attrName, "delta_softplus"))
{
@ -389,11 +518,21 @@ IPluginV2* SelectiveScanPluginCreator::createPlugin(char const* name, PluginFiel
TLLM_CHECK(fields[i].type == PluginFieldType::kINT8);
pagedState = static_cast<bool>(*(static_cast<bool const*>(fields[i].data)));
}
else if (!strcmp(attrName, "z_enabled"))
{
TLLM_CHECK(fields[i].type == PluginFieldType::kINT8);
zEnabled = static_cast<bool>(*(static_cast<bool const*>(fields[i].data)));
}
else if (!strcmp(attrName, "is_mamba2"))
{
TLLM_CHECK(fields[i].type == PluginFieldType::kINT8);
isMamba2 = static_cast<bool>(*(static_cast<bool const*>(fields[i].data)));
}
}
try
{
auto* obj = new SelectiveScanPlugin(
dim, dstate, dtRank, isVariableB, isVariableC, deltaSoftplus, type, removePadding, pagedState);
auto* obj = new SelectiveScanPlugin(dim, dstate, dtRank, nHeads, nGroups, chunkSize, deltaSoftplus, type,
removePadding, pagedState, zEnabled, isMamba2);
obj->setPluginNamespace(mNamespace.c_str());
return obj;
}

View File

@ -30,25 +30,30 @@ namespace tensorrt_llm::plugins
// inputs
// 0. input_tensor [batch_size, seq_len, dim] or [num_tokens, dim] for remove_input_padding
// 1. state [batch_size, dstate, dim] or host [1] containing only pointer for paged_state
// 2. delta [batch_size, seq_len, dim] or [num_tokens, dim] for remove_input_padding
// 3. delta_bias [dim]
// 4. A [dstate, dim]
// 5. BC [batch_size, seq_len, dstate * 2] or [num_tokens, dstate * 2] for remove_input_padding
// 6. D [dim]
// 7. z [batch_size, seq_len, dim] or [num_tokens, dim] for remove_input_padding
// 8. host_request_types [batch_size] int32. 0: context; 1: generation; 2: none.
// 9. last_token_ids [batch_size] int32
// 1. state, mamba: [batch_size, dstate, dim] or host [1] containing only pointer for paged_state
// mamba2: [batch_size, nheads, dstate, dim] or host [1] containing only pointer for paged_state
// 2. delta, mamba: [batch_size, seq_len, dim] or [num_tokens, dim] for remove_input_padding
// mamba2: [batch_size, seq_len, nheads] or [num_tokens, nheads] for remove_input_padding
// 3. delta_bias, [dim] for mamba, [nheads] for mamba2
// 4. A, [dstate, dim] for mamba, [nheads] for mamba2
// 5. BC, mamba: [batch_size, seq_len, dstate * 2] or [num_tokens, dstate * 2] for remove_input_padding
// mamba2: [batch_size, seq_len, ngroups * dstate * 2] or [num_tokens, ngroups * dstate * 2] for
// remove_input_padding
// 6. D, [dim] for mamba, [nheads] for mamba2
// 7. host_request_types [batch_size] int32. 0: context; 1: generation; 2: none.
// 8. last_token_ids [batch_size] int32
// 9. host_context_lengths [batch_size] int32, optional for remove_input_padding
// 10. state_slot_mapping [batch_size] int32, optional for paged state
// 11. z [batch_size, seq_len, dim] or [num_tokens, dim] for remove_input_padding
// outputs
// 0. output_tensor [batch_size, seq_len, dim] or [num_tokens, dim] for remove_input_padding
// 1. state [batch_size, dstate, dim]
// 1. state, [batch_size, dstate, dim] for mamba, [batch_size, nheads, dstate, dim] for mamba2
class SelectiveScanPlugin : public BasePlugin
{
public:
SelectiveScanPlugin(int dim, int dstate, int dt_rank, bool isVariableB, bool isVariableC, bool deltaSoftplus,
nvinfer1::DataType type, bool removePadding, bool pagedState);
SelectiveScanPlugin(int dim, int dstate, int dtRank, int nHeads, int nGroups, int chunkSize, bool deltaSoftplus,
nvinfer1::DataType type, bool removePadding, bool pagedState, bool zEnabled, bool isMamba2);
SelectiveScanPlugin(void const* data, size_t length);
@ -128,45 +133,63 @@ private:
return 6;
};
IndexType getZIdx() const
IndexType getHostRequestTypesIdx() const
{
return 7;
};
IndexType getHostRequestTypesIdx() const
IndexType getLastTokenIdsIdx() const
{
return 8;
};
IndexType getLastTokenIdsIdx() const
IndexType getHostContextLengthIdx() const
{
return 9;
if (mRemovePadding)
return 9;
else
return 8;
};
IndexType getSlotMappingIdx() const
{
return 10;
if (mPagedState)
return getHostContextLengthIdx() + 1;
else
return getHostContextLengthIdx();
};
IndexType getZIdx() const
{
if (mZEnabled)
return getSlotMappingIdx() + 1;
else
return getSlotMappingIdx();
};
void setSSMParams(tensorrt_llm::kernels::SSMParamsBase& params,
// sizes
const size_t batch, const size_t dim, const size_t maxSeqLen, const size_t dstate, const size_t dtRank,
bool const isVariableB, bool const isVariableC,
const size_t nHeads, const size_t nGroups, const size_t chunkSize,
// device pointers
void* statePtr, void const* x, void const* delta, void const* deltaBias, void const* A, void const* BC,
void const* D, void const* z, int const* lastTokenIds, int const* slotMapping, void* out, bool deltaSoftplus,
void const* D, void const* z, void const* osPtr, void const* stPtr, void const* dcPtr, void const* dAPtr,
void const* cbPtr, int const* lastTokenIds, int const* slotMapping, void* out, bool deltaSoftplus,
bool removePadding);
private:
int mDim;
int mDState;
int mDtRank;
bool mIsVariableB;
bool mIsVariableC;
int mNHeads;
int mNGroups;
int mChunkSize;
bool mDeltaSoftplus;
nvinfer1::DataType mType;
bool mRemovePadding = false;
bool mPagedState = false;
bool mZEnabled = true;
bool mIsMamba2 = false;
};
class SelectiveScanPluginCreator : public BaseCreator

View File

@ -300,6 +300,9 @@ int WeightOnlyQuantMatmulPlugin::enqueue(nvinfer1::PluginTensorDesc const* input
int const n = TLLM_INT32_CAST(inputDesc[1].dims.d[1]);
int const k = TLLM_INT32_CAST(inputDesc[0].dims.d[inputDesc[0].dims.nbDims - 1]);
if (m == 0)
return 0;
bool const use_cuda_kernel = m < SMALL_M_FAST_PATH && mCudaKernelEnabled;
#if defined(ENABLE_BF16)
TLLM_CHECK_WITH_INFO(mType == nvinfer1::DataType::kHALF || mType == nvinfer1::DataType::kBF16,

View File

@ -69,9 +69,9 @@ std::shared_ptr<tb::InferenceRequest> InferenceRequest::toTrtLlm() const
auto inferenceRequest = std::make_shared<tb::InferenceRequest>(std::move(tensorMap), mRequestId);
inferenceRequest->setIsStreaming(isStreaming());
if (mlogitsPostProcessor)
if (mLogitsPostProcessor)
{
inferenceRequest->setLogitsPostProcessor(LlmRequest::callbackAdapter(mlogitsPostProcessor));
inferenceRequest->setLogitsPostProcessor(LlmRequest::callbackAdapter(mLogitsPostProcessor));
}
return inferenceRequest;
@ -79,7 +79,7 @@ std::shared_ptr<tb::InferenceRequest> InferenceRequest::toTrtLlm() const
std::string InferenceRequest::serialize() const
{
TLLM_CHECK_WITH_INFO(mlogitsPostProcessor == std::nullopt,
TLLM_CHECK_WITH_INFO(mLogitsPostProcessor == std::nullopt,
"Serializing InferenceRequest with logitsPostProcessor set is not supported."
"Please set the callback after de-serialization");
std::vector<std::int64_t> serialized{toTrtLlm()->serialize()};

View File

@ -228,15 +228,14 @@ void InitBindings(pybind11::module_& m)
std::optional<SizeType32> const&, std::optional<SizeType32> const&,
std::optional<std::list<VecTokens>>, std::optional<std::list<VecTokens>>, std::optional<Tensor>,
std::optional<tle::ExternalDraftTokensConfig>, std::optional<tle::PromptTuningConfig>,
std::optional<tle::LoraConfig>, std::optional<std::string>, std::optional<VecTokens>, bool>(),
std::optional<tle::LoraConfig>, std::optional<std::string>, std::optional<VecTokens>>(),
py::arg("input_token_ids"), py::arg("max_new_tokens"), py::arg("streaming") = false,
py::arg_v("sampling_config", tle::SamplingConfig(), "SamplingConfig()"),
py::arg_v("output_config", tle::OutputConfig(), "OutputConfig()"), py::arg("end_id") = py::none(),
py::arg("pad_id") = py::none(), py::arg("bad_words") = py::none(), py::arg("stop_words") = py::none(),
py::arg("embedding_bias") = py::none(), py::arg("external_draft_tokens_config") = py::none(),
py::arg("prompt_tuning_config") = py::none(), py::arg("lora_config") = py::none(),
py::arg("logits_post_processor_name") = py::none(), py::arg("encoder_input_token_ids") = py::none(),
py::arg("return_all_generated_tokens") = false)
py::arg("logits_post_processor_name") = py::none(), py::arg("encoder_input_token_ids") = py::none())
.def_property_readonly("input_token_ids", &tle::Request::getInputTokenIds)
.def_property_readonly("max_new_tokens", &tle::Request::getMaxNewTokens)
.def_property("streaming", &tle::Request::getStreaming, &tle::Request::setStreaming)
@ -255,9 +254,7 @@ void InitBindings(pybind11::module_& m)
.def_property("logits_post_processor_name", &tle::Request::getLogitsPostProcessorName,
&tle::Request::setLogitsPostProcessorName)
.def_property(
"encoder_input_token_ids", &tle::Request::getEncoderInputTokenIds, &tle::Request::setEncoderInputTokenIds)
.def_property("return_all_generated_tokens", &tle::Request::getReturnAllGeneratedTokens,
&tle::Request::setReturnAllGeneratedTokens);
"encoder_input_token_ids", &tle::Request::getEncoderInputTokenIds, &tle::Request::setEncoderInputTokenIds);
request.attr("BATCHED_POST_PROCESSOR_NAME") = tle::Request::kBatchedPostProcessorName;
py::class_<tle::Result>(m, "Result")

View File

@ -46,8 +46,8 @@ FieldType parseJsonFieldOr(Json const& json, std::string_view name, FieldType de
}
catch (nlohmann::json::out_of_range& e)
{
TLLM_LOG_INFO("Parameter %s cannot be read from json:", std::string(name).c_str());
TLLM_LOG_INFO(e.what());
TLLM_LOG_DEBUG("Parameter %s cannot be read from json:", std::string(name).c_str());
TLLM_LOG_DEBUG(e.what());
}
return value;
}
@ -62,13 +62,13 @@ std::optional<FieldType> parseJsonFieldOptional(Json const& json, std::string_vi
}
catch (nlohmann::json::out_of_range const& e)
{
TLLM_LOG_INFO(e.what());
TLLM_LOG_INFO("Optional value for parameter %s will not be set.", std::string(name).c_str());
TLLM_LOG_DEBUG(e.what());
TLLM_LOG_DEBUG("Optional value for parameter %s will not be set.", std::string(name).c_str());
}
catch (nlohmann::json::type_error const& e)
{
TLLM_LOG_INFO(e.what());
TLLM_LOG_INFO("Optional value for parameter %s will not be set.", std::string(name).c_str());
TLLM_LOG_DEBUG(e.what());
TLLM_LOG_DEBUG("Optional value for parameter %s will not be set.", std::string(name).c_str());
}
return value;
}
@ -427,10 +427,17 @@ GptJsonConfig parseJson(InputType&& input)
auto const& stateSize = pretrainedConfig.at("state_size").template get<SizeType32>();
auto const& convKernel = pretrainedConfig.at("conv_kernel").template get<SizeType32>();
auto const& rnnHiddenSize = pretrainedConfig.at("rnn_hidden_size").template get<SizeType32>();
auto const& rnnConvDimSize = pretrainedConfig.at("rnn_conv_dim_size").template get<SizeType32>();
ModelConfig::RnnConfig rnnConfig{};
rnnConfig.stateSize = stateSize;
rnnConfig.convKernel = convKernel;
rnnConfig.rnnHiddenSize = rnnHiddenSize;
rnnConfig.rnnConvDimSize = rnnConvDimSize;
if (pretrainedConfig.contains("rnn_head_size"))
{
auto const& rnnHeadSize = pretrainedConfig.at("rnn_head_size").template get<SizeType32>();
rnnConfig.rnnHeadSize = rnnHeadSize;
}
modelConfig.setRnnConfig(rnnConfig);
}
}
@ -449,10 +456,17 @@ GptJsonConfig parseJson(InputType&& input)
auto const& stateSize = builderConfig.at("state_size").template get<SizeType32>();
auto const& convKernel = builderConfig.at("conv_kernel").template get<SizeType32>();
auto const& rnnHiddenSize = builderConfig.at("rnn_hidden_size").template get<SizeType32>();
auto const& rnnConvDimSize = builderConfig.at("rnn_conv_dim_size").template get<SizeType32>();
ModelConfig::RnnConfig rnnConfig{};
rnnConfig.stateSize = stateSize;
rnnConfig.convKernel = convKernel;
rnnConfig.rnnHiddenSize = rnnHiddenSize;
rnnConfig.rnnConvDimSize = rnnConvDimSize;
if (builderConfig.contains("rnn_head_size"))
{
auto const& rnnHeadSize = builderConfig.at("rnn_head_size").template get<SizeType32>();
rnnConfig.rnnHeadSize = rnnHeadSize;
}
modelConfig.setRnnConfig(rnnConfig);
}
}

View File

@ -37,7 +37,7 @@ RnnStateBuffers::RnnStateBuffers(
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
TLLM_CHECK(modelConfig.isRnnBased());
TLLM_CHECK_WITH_INFO(modelConfig.hasRnnConfig(), "RNN only support Mamba now.");
TLLM_CHECK_WITH_INFO(modelConfig.hasRnnConfig(), "RNN only support Mamba1/Mamba2/RecurrentGemma now.");
auto maxBatchSize = modelConfig.getMaxBatchSize();
auto maxBeamWidth = modelConfig.getMaxBeamWidth();
auto maxBatchBeam = maxBatchSize * maxBeamWidth;
@ -46,23 +46,37 @@ RnnStateBuffers::RnnStateBuffers(
mConvKernel = rnnConfig->convKernel;
mStateSize = rnnConfig->stateSize;
mRnnHiddenSize = rnnConfig->rnnHiddenSize;
mRnnHeadSize = rnnConfig->rnnHeadSize;
mRnnConvDimSize = rnnConfig->rnnConvDimSize;
auto dType = modelConfig.getDataType();
auto const localNbLayers = modelConfig.getNbRnnLayers(worldConfig.getPipelineParallelism());
mLocalNbLayers = localNbLayers;
mMaxBeamWidth = maxBeamWidth;
mUseMambaConv1dPlugin = modelConfig.useMambaConv1dPlugin();
auto rnnStatesShape = ITensor::makeShape({localNbLayers * maxBatchBeam, mStateSize, mRnnHiddenSize});
auto const rnnStatesShape = [&]()
{
if (mRnnHeadSize > 0)
{
return tensorrt_llm::runtime::ITensor::makeShape(
{localNbLayers * maxBatchBeam, mRnnHiddenSize / mRnnHeadSize, mStateSize, mRnnHeadSize});
}
else
{
return tensorrt_llm::runtime::ITensor::makeShape(
{localNbLayers * maxBatchBeam, mStateSize, mRnnHiddenSize});
}
}();
auto const convStatesShape = [&]()
{
if (mUseMambaConv1dPlugin)
{
return tensorrt_llm::runtime::ITensor::makeShape(
{localNbLayers * maxBatchBeam, mConvKernel - 1, mRnnHiddenSize});
{localNbLayers * maxBatchBeam, mConvKernel - 1, mRnnConvDimSize});
}
else
{
return tensorrt_llm::runtime::ITensor::makeShape(
{localNbLayers * maxBatchBeam, mRnnHiddenSize, mConvKernel - 1});
{localNbLayers * maxBatchBeam, mRnnConvDimSize, mConvKernel - 1});
}
}();
auto& bufferManager = runtime.getBufferManager();
@ -96,18 +110,30 @@ RnnStateBuffers::RnnStateBuffers(
void RnnStateBuffers::reshape(SizeType32 batchSize)
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
auto rnnStatesShape = ITensor::makeShape({mLocalNbLayers * batchSize * mMaxBeamWidth, mStateSize, mRnnHiddenSize});
auto const rnnStatesShape = [&]()
{
if (mRnnHeadSize > 0)
{
return tensorrt_llm::runtime::ITensor::makeShape(
{mLocalNbLayers * batchSize * mMaxBeamWidth, mRnnHiddenSize / mRnnHeadSize, mStateSize, mRnnHeadSize});
}
else
{
return tensorrt_llm::runtime::ITensor::makeShape(
{mLocalNbLayers * batchSize * mMaxBeamWidth, mStateSize, mRnnHiddenSize});
}
}();
auto const convStatesShape = [&]()
{
if (mUseMambaConv1dPlugin)
{
return tensorrt_llm::runtime::ITensor::makeShape(
{mLocalNbLayers * batchSize * mMaxBeamWidth, mConvKernel - 1, mRnnHiddenSize});
{mLocalNbLayers * batchSize * mMaxBeamWidth, mConvKernel - 1, mRnnConvDimSize});
}
else
{
return tensorrt_llm::runtime::ITensor::makeShape(
{mLocalNbLayers * batchSize * mMaxBeamWidth, mRnnHiddenSize, mConvKernel - 1});
{mLocalNbLayers * batchSize * mMaxBeamWidth, mRnnConvDimSize, mConvKernel - 1});
}
}();
rnnStates->reshape(rnnStatesShape);

View File

@ -39,7 +39,8 @@ public:
TensorPtr convStates; // [layer_count * batch_beam, conv_kernel - 1, rnn_hidden_size]
TensorPtr convStatesAlt; // [layer_count * batch_beam, conv_kernel - 1, rnn_hidden_size]
std::vector<TensorPtr> rnnState; // [batch_beam, state_size, rnn_hidden_size]
std::vector<TensorPtr> rnnState; // [batch_beam, state_size, rnn_hidden_size] or
// [batch_beam, num_heads, rnn_hidden_size, rnn_head_size]
std::vector<TensorPtr> convState; // [batch_beam, conv_kernel - 1, rnn_hidden_size]
std::vector<TensorPtr> convStateAlt; // [batch_beam, conv_kernel - 1, rnn_hidden_size]
@ -83,6 +84,8 @@ private:
SizeType32 mConvKernel = 0;
SizeType32 mStateSize = 0;
SizeType32 mRnnHiddenSize = 0;
SizeType32 mRnnHeadSize = 0;
SizeType32 mRnnConvDimSize = 0;
int mLocalNbLayers = 0;
int mMaxBeamWidth = 0;

View File

@ -120,13 +120,27 @@ def main():
parser.add_argument('--tp-size', type=int, default=1)
parser.add_argument('--out-dir', type=Path, required=True)
parser.add_argument('--num-loras', type=int, default=1)
parser.add_argument('--num-layers', type=int, default=2)
parser.add_argument('--adapter-size', type=int, default=8)
parser.add_argument('--hidden-size', type=int, default=16)
parser.add_argument('--mlp-hidden-size', type=int, default=32)
parser.add_argument('--no-generate-cache-pages',
action='store_true',
default=False)
parser.add_argument(
'--config-ids-filter',
type=str,
default=None,
help=
"Comma separated list of ids to include. For example, use --config-ids-filter=0 for attn_qkv only."
)
args = parser.parse_args()
num_layers = 2
adapter_size = 8
hidden_size = 16
mlp_hidden_size = 32
num_layers = args.num_layers
adapter_size = args.adapter_size
hidden_size = args.hidden_size
mlp_hidden_size = args.mlp_hidden_size
configs = [
(0, num_layers, adapter_size, hidden_size, 3 * hidden_size), # attn_qkv
(1, num_layers, adapter_size // 2, hidden_size, hidden_size), # attn_q
@ -149,6 +163,9 @@ def main():
(12, num_layers, adapter_size, hidden_size,
hidden_size), # cross_attn_dense
]
if args.config_ids_filter:
config_ids_filter = [int(x) for x in args.config_ids_filter.split(",")]
configs = [c for c in configs if c[0] in config_ids_filter]
for lora_idx in range(args.num_loras):
all_source = []
@ -178,19 +195,20 @@ def main():
os.makedirs(output_dir, exist_ok=True)
# copy weights into cache pages
for rank in range(args.tp_size):
page_block = torch.zeros((8, 18, 128),
dtype=torch.float32,
device='cpu')
copy_to_cache_pages(all_source,
all_config,
page_block,
configs,
tp_rank=rank,
tp_size=args.tp_size)
if not args.no_generate_cache_pages:
for rank in range(args.tp_size):
page_block = torch.zeros((8, 18, 128),
dtype=torch.float32,
device='cpu')
copy_to_cache_pages(all_source,
all_config,
page_block,
configs,
tp_rank=rank,
tp_size=args.tp_size)
out_path = output_dir / f'cache_pages_rank{rank}.npy'
np.save(out_path, page_block)
out_path = output_dir / f'cache_pages_rank{rank}.npy'
np.save(out_path, page_block)
source_out_path = output_dir / 'source.npy'
config_out_path = output_dir / 'config.npy'

View File

@ -15,6 +15,7 @@
# limitations under the License.
import argparse as _arg
import copy
import logging as _log
import os as _os
import pathlib as _pl
@ -135,9 +136,18 @@ def run_tests(build_dir: _pl.Path,
"--num-loras=128",
]
generate_gpt2_lora_data_args_tp1 = [
python_exe,
str(resources_dir / "scripts" / "generate_test_lora_weights.py"),
"--out-dir=cpp/tests/resources/data/lora-test-weights-gpt2-tp1",
"--tp-size=1", "--hidden-size=768", "--num-layers=12",
"--config-ids-filter=0", "--no-generate-cache-pages"
]
run_command(generate_lora_data_args_tp1, cwd=root_dir, timeout=100)
run_command(generate_lora_data_args_tp2, cwd=root_dir, timeout=100)
run_command(generate_multi_lora_tp2_args, cwd=root_dir, timeout=100)
run_command(generate_gpt2_lora_data_args_tp1, cwd=root_dir, timeout=100)
if not skip_unit_tests:
run_unit_tests(build_dir=build_dir, timeout=test_timeout)
@ -484,9 +494,15 @@ def run_multi_gpu_tests(build_dir: _pl.Path, timeout=1500):
]
run_command(trt_model_test, cwd=tests_dir, env=cpp_env,
timeout=timeout) # expecting ~ 1200s
cpp_blocking_env = copy.copy(cpp_env)
cpp_blocking_env["CUDA_LAUNCH_BLOCKING"] = '1'
run_command(trt_model_test,
cwd=tests_dir,
env=cpp_blocking_env,
timeout=timeout) # expecting ~ 1200s
#Executor test in leader mode
new_env = cpp_env
new_env = copy.copy(cpp_env)
xml_output_file = build_dir / "results-multi-gpu-llama-exec-leader-mode.xml"
new_env["RUN_LLAMA_MULTI_GPU"] = "true"
trt_model_test = [
@ -507,7 +523,7 @@ def run_multi_gpu_tests(build_dir: _pl.Path, timeout=1500):
run_command(trt_model_test, cwd=tests_dir, env=new_env, timeout=1500)
#EncDec test in leader mode
new_env = cpp_env
new_env = copy.copy(cpp_env)
xml_output_file = build_dir / "results-multi-gpu-t5-exec-leader-mode.xml"
trt_model_test = [
"mpirun", "-n", "4", "--allow-run-as-root", "executor/executorTest",

View File

@ -52,7 +52,7 @@ The `awaitResponses` method of the `Executor` class returns a vector of response
### The Result Class
The `Result` class holds the result for a given request. It contains a Boolean parameter called `isFinal` that indicates if this is the last `Result` that will be returned for the given request id. It also contains the generated tokens. If the request is configured with `streaming = false`, the `isFinal` Boolean will be set to `true` and all generated tokens will be included in the `outputTokenIds`. If `streaming = false` is used, a `Result` will only include 1 token and the `isFinal` flag will be set to `true` for the last result associated with this request.
The `Result` class holds the result for a given request. It contains a Boolean parameter called `isFinal` that indicates if this is the last `Result` that will be returned for the given request id. It also contains the generated tokens. If the request is configured with `streaming = false`, the `isFinal` Boolean will be set to `true` and all generated tokens will be included in the `outputTokenIds`. If `streaming = true` is used, a `Result` will only include 1 token and the `isFinal` flag will be set to `true` for the last result associated with this request.
## C++ Executor API Example

View File

@ -23,7 +23,7 @@ git clone https://huggingface.co/meta-llama/Llama-2-7b-chat-hf
(quick-start-guide-compile)=
## Compile the Model into a TensorRT Engine
Use the included [Llama model definition](https://nvidia.github.io/TensorRT-LLM/_modules/tensorrt_llm/models/llama/model.html#LLaMAModel). This is a minimal example that includes some of the optimizations available in TensorRT-LLM.
Use the included [Llama model definition](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/llama). This is a minimal example that includes some of the optimizations available in TensorRT-LLM.
```bash
# Launch the TensorRT-LLM container
@ -138,3 +138,7 @@ In this Quick Start Guide, you:
For more examples, refer to:
- [examples/](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples) for showcases of how to run a quick benchmark on latest LLMs.
## Links
- [Best Practices Guide](https://github.com/NVIDIA/TensorRT-LLM/blob/main/docs/source/performance/perf-best-practices.md)
- [Support Matrix](https://nvidia.github.io/TensorRT-LLM/reference/support-matrix.html)

View File

@ -83,7 +83,7 @@ The following table shows the supported software for TensorRT-LLM.
- [mT5](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/enc_dec)
- [OPT](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/opt)
- [Phi-1.5/Phi-2/Phi-3](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/phi)
- [Qwen](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/qwen)
- [Qwen/Qwen1.5/Qwen2](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/qwen)
- [Qwen-VL](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/qwenvl)
- [RecurrentGemma](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/recurrentgemma)
- [Replit Code](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/mpt)
@ -103,6 +103,7 @@ The following table shows the supported software for TensorRT-LLM.
- [Fuyu](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/multimodal)
- [Kosmos](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/multimodal)
- [LLaVA-v1.5](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/multimodal)
- [LLaVa-Next](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/multimodal)
- [NeVA](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/multimodal)
- [Nougat](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/multimodal)
- [Phi-3-vision](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/multimodal)

View File

@ -1,5 +1,5 @@
--extra-index-url https://pypi.nvidia.com
tensorrt_llm==0.12.0.dev2024070200
tensorrt_llm==0.12.0.dev2024070900
datasets~=2.15.0
evaluate~=0.4.1
rouge_score~=0.1.2

View File

@ -1,5 +1,5 @@
--extra-index-url https://pypi.nvidia.com
tensorrt_llm==0.12.0.dev2024070200
tensorrt_llm==0.12.0.dev2024070900
datasets~=2.14.5
evaluate~=0.4.1
rouge_score~=0.1.2

View File

@ -1,5 +1,5 @@
--extra-index-url https://pypi.nvidia.com
tensorrt_llm==0.12.0.dev2024070200
tensorrt_llm==0.12.0.dev2024070900
datasets~=2.14.5
evaluate~=0.4.1
protobuf

View File

@ -1,5 +1,5 @@
--extra-index-url https://pypi.nvidia.com
tensorrt_llm==0.12.0.dev2024070200
tensorrt_llm==0.12.0.dev2024070900
datasets~=2.14.5
evaluate~=0.4.1
rouge_score~=0.1.2

View File

@ -29,8 +29,10 @@ from transformers import (AutoConfig, AutoModelForSeq2SeqLM, AutoTokenizer,
import tensorrt_llm
from tensorrt_llm import logger
from tensorrt_llm._ipc_utils import set_peer_access
from tensorrt_llm._utils import torch_to_numpy, trt_dtype_to_torch
from tensorrt_llm.lora_manager import LoraManager
from tensorrt_llm.plugin.plugin import CustomAllReduceHelper
from tensorrt_llm.runtime import ModelConfig, SamplingConfig
@ -387,19 +389,27 @@ class TRTLLMEncDecModel:
(max_input_length, ),
dtype=hidden_states_dtype('max_input_length'),
device=self.device).contiguous()
batch_size = input_lengths.size(0)
inputs['host_request_types'] = torch.IntTensor([0] *
batch_size).to('cpu')
if self.encoder_model_config.remove_input_padding:
inputs['host_context_lengths'] = input_lengths.to('cpu')
if self.encoder_model_config.lora_plugin and self.encoder_lora_manager is not None:
if self.encoder_model_config.use_custom_all_reduce and self.encoder_runtime_mapping.tp_size > 1:
set_peer_access(self.encoder_runtime_mapping)
ipc_buffers, all_reduce_workspace = CustomAllReduceHelper.allocate_workspace(
self.encoder_runtime_mapping,
CustomAllReduceHelper.max_workspace_size_auto(
self.encoder_runtime_mapping.tp_size))
inputs['all_reduce_workspace'] = all_reduce_workspace
if self.encoder_model_config.lora_plugin:
inputs.update(
self.encoder_lora_manager.input_buffers(
self.lora_task_uids,
self.encoder_runtime_mapping,
self.encoder_model_config.num_layers,
))
batch_size = input_lengths.size(0)
inputs['host_request_types'] = torch.IntTensor([0] *
batch_size).to('cpu')
if self.encoder_model_config.remove_input_padding:
inputs['host_context_lengths'] = input_lengths.to('cpu')
# Note: runtime.Session's run() method will set input/output tensor address, here we only need to provide tensor shape
self.encoder_session.set_shapes(inputs)

View File

@ -1,5 +1,5 @@
--extra-index-url https://pypi.nvidia.com
tensorrt_llm==0.12.0.dev2024070200
tensorrt_llm==0.12.0.dev2024070900
transformers>=4.31.0
datasets~=2.14.5
evaluate~=0.4.1

View File

@ -3,7 +3,7 @@
# WAR the new posting of "nvidia-cudnn-cu12~=9.0".
# "jax[cuda12_pip]~=0.4.19" specifies "nvidia-cudnn-cu12>=8.9" but actually requires "nvidia-cudnn-cu12~=8.9".
nvidia-cudnn-cu12~=8.9; platform_machine == "x86_64"
tensorrt_llm==0.12.0.dev2024070200
tensorrt_llm==0.12.0.dev2024070900
flax~=0.8.0
# jax[cuda12_pip]~=0.4.19; platform_system != "Windows"
jax~=0.4.19; platform_system == "Windows"

View File

@ -414,6 +414,21 @@ trtllm-build --checkpoint_dir gpt2/trt_ckpt/int8-sq-ptpc/1-gpu \
Note that GPT attention plugin is required to be enabled for SmoothQuant for now.
You can also use `ModelOpt` to do INT8 quantization, especially for the GPT variant Starcoder2.
```bash
python3 examples/quantization/quantize.py --model_dir starcoder2 \
--dtype float16 \
--qformat int8_sq \
--output_dir starcoder2/trt_ckpt/int8-sq/
```
Then, use `trtllm-build` to build engine(s).
```bash
trtllm-build --checkpoint_dir starcoder2/trt_ckpt/int8-sq/ \
--output_dir starcoder2/trt_engine/int8-sq/ \
--builder_opt 4
```
### INT8 KV Cache

View File

@ -1,5 +1,5 @@
--extra-index-url https://pypi.nvidia.com
tensorrt_llm==0.12.0.dev2024070200
tensorrt_llm==0.12.0.dev2024070900
datasets~=2.14.5
evaluate~=0.4.1
rouge_score~=0.1.2

View File

@ -1,5 +1,5 @@
--extra-index-url https://pypi.nvidia.com
tensorrt_llm==0.12.0.dev2024070200
tensorrt_llm==0.12.0.dev2024070900
datasets~=2.14.5
evaluate~=0.4.1
rouge_score~=0.1.2

View File

@ -1,5 +1,5 @@
--extra-index-url https://pypi.nvidia.com
tensorrt_llm==0.12.0.dev2024070200
tensorrt_llm==0.12.0.dev2024070900
datasets~=2.14.5
rouge_score~=0.1.2
evaluate~=0.4.1

View File

@ -1,6 +1,6 @@
-f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
--extra-index-url https://pypi.nvidia.com
tensorrt_llm==0.12.0.dev2024070200
tensorrt_llm==0.12.0.dev2024070900
datasets==2.14.6
evaluate~=0.4.1
rouge_score~=0.1.2

View File

@ -4,12 +4,49 @@ Here we show you a preview of how it works and how to use it.
Note that the APIs are not stable and only support the LLaMA model. We appreciate your patience and understanding as we improve this API.
## Quick start
Please install the required packages first:
```bash
pip install -r requirements.txt
```
Here is a simple example to show how to use the HLAPI:
First, import `LLM` and `SamplingParams` from the `tensorrt_llm` package and create an `LLM` object directly from a HuggingFace (HF) model. Here we use the TinyLlama model as an example; `LLM` will download the model from the HuggingFace model hub automatically. You can also specify local models, either in HF format, TensorRT-LLM engine format, or TensorRT-LLM checkpoint format.
```python
from tensorrt_llm import LLM, SamplingParams
llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0")
```
Second, generate text by calling the `generate` method of the `LLM` object with a batch of prompts. The `sampling_params` argument is optional; use it to customize the sampling strategy.
```python
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
```
Please refer to the [LLM quickstart](./quickstart_example.py) for the complete example.
## Examples
Refer to [llm_examples.py](llm_examples.py) for all of the examples, and run them with the [run_examples.py](./run_examples.py) script as follows:
```sh
@ -34,15 +71,16 @@ python3 llm_examples.py --task run_llm_on_tensor_parallel \
```
## Model preparation
The HLAPI supports three kinds of model formats:
The `LLM` class supports four kinds of model inputs:
1. HuggingFace models
2. TensorRT-LLM engine built by trtllm-build tool or saved by the HLAPI
3. TensorRT-LLM checkpoints, converted by `convert_checkpoint.py` in examples
1. **HuggingFace model name**: triggers a download from the HuggingFace model hub, e.g. `TinyLlama/TinyLlama-1.1B-Chat-v1.0` in the quickstart.
1. **Local HuggingFace models**: uses a locally stored HuggingFace model.
2. **Local TensorRT-LLM engine**: built by `trtllm-build` tool or saved by the HLAPI
3. **Local TensorRT-LLM checkpoints**: converted by `convert_checkpoint.py` script in the examples
All kinds of models could be used directly by the HLAPI, and the `LLM(model=<any-model-path>)` could accept any kind of them.
All of these model inputs can be seamlessly integrated with the HLAPI, and the `LLM(model=<any-model-path>)` constructor can accommodate models in any of the above formats.
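As a minimal sketch, the same constructor accepts all four kinds of inputs (the local paths below are placeholders):
```python
from tensorrt_llm import LLM

# 1. HuggingFace model name, downloaded from the HF model hub automatically
llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0")

# 2. Local HuggingFace model (placeholder path)
llm = LLM(model="/models/TinyLlama-1.1B-Chat-v1.0")

# 3. Local TensorRT-LLM engine, built by trtllm-build or saved by the HLAPI (placeholder path)
llm = LLM(model="/engines/tinyllama-1.1b-chat")

# 4. Local TensorRT-LLM checkpoint, converted by convert_checkpoint.py (placeholder path)
llm = LLM(model="/checkpoints/tinyllama-1.1b-chat")
```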
Let's elaborate on the preparation of the three kinds of model formats.
Let's delve into the preparation of the three kinds of local model formats.
### Option 1: From HuggingFace models
@ -143,16 +181,16 @@ It is easy to enable Tensor Parallelism in the HLAPI. For example, setting `para
```python
from tensorrt_llm.hlapi import LLM
llm = LLM(<llama_model_path>, tensor_parallel_size=2)
llm = LLM(<llama_model_path>,
tensor_parallel_size=2)
```
### Pipeline Parallelism
Similar to Tensor Parallelism, you can enable Pipeline Parallelism in the HLAPI with following code:
```python
config.parallel_config.pp_size = 4
# you can also mix TP and PP
# config.parallel_config.tp_size = 2
llm = LLM(<llama_model_path>,
pipeline_parallel_size=4)
```
### Automatic Parallelism (in preview)
@ -266,17 +304,28 @@ Please refer to these classes for more details.
## LLM pipeline configuration
### Runtime customization
### Build configuration
Apart from the arguments mentioned above, you can also customize the build configuration with the `build_config` class and other arguments borrowed from the lower-level APIs. For example:
```python
llm = LLM(<model-path>,
build_config=BuildConfig(
max_new_tokens=4096,
max_batch_size=128,
max_beam_width=4))
```
### Runtime customization
Similar to `build_config`, you can also customize the runtime configuration with the `runtime_config`, `peft_cache_config` or other arguments borrowed from the lower-level APIs. For example:
For the `kv_cache_config` and `streaming_llm` features, please refer to LLaMA's [README](../llama/README.md) for more details; the high-level API supports these features as well by setting the corresponding fields in the `LLM()` constructor.
```python
from tensorrt_llm.hlapi import LLM, KvCacheConfig
llm = LLM(<llama_model_path>,
kv_cache_config=KvCacheConfig(
max_new_tokens=128,
free_gpu_memory_fraction=0.8))
max_new_tokens=128,
free_gpu_memory_fraction=0.8))
```
### Tokenizer customization
@ -313,3 +362,13 @@ RequestOutput(request_id=1, prompt=None, prompt_token_ids=[1, 15043, 29892, 590,
```
Note that the `text` field in `CompletionOutput` is empty since the tokenizer is deactivated.
### Build caching
Although the HLAPI builds the engine in the background, you can also cache the built engine to disk and load it on the next run to save the engine-building time.
There are two ways to enable the build cache:
1. Use the environment variable: `export TLLM_HLAPI_BUILD_CACHE=1` to enable the build cache globally, and optionally export `TLLM_HLAPI_BUILD_CACHE_ROOT` to specify the cache root directory.
2. Pass a `build_cache_config` to the `LLM` constructor (see the sketch below).
The build cache reuses the built engine if all the build settings are the same; otherwise it rebuilds the engine.
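As a rough sketch of the two options above (the runnable part only uses the environment variables and the `LLM` constructor shown in this document; the `BuildCacheConfig` class and its `cache_root` argument in the commented-out part are assumptions, so check the HLAPI source for the exact interface):
```python
import os

from tensorrt_llm import LLM

# Option 1: enable the build cache globally via environment variables,
# assuming they are read when the LLM object is constructed
# (equivalent to `export TLLM_HLAPI_BUILD_CACHE=1` in the shell).
os.environ["TLLM_HLAPI_BUILD_CACHE"] = "1"
os.environ["TLLM_HLAPI_BUILD_CACHE_ROOT"] = "/tmp/trtllm_build_cache"  # optional cache root

llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0")  # first run builds and caches the engine

# Option 2 (hypothetical names): pass a build cache configuration to the constructor.
# from tensorrt_llm.hlapi import BuildCacheConfig
# llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
#           build_cache_config=BuildCacheConfig(cache_root="/tmp/trtllm_build_cache"))
```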

View File

@ -1,2 +1,2 @@
--extra-index-url https://pypi.nvidia.com
tensorrt_llm==0.12.0.dev2024070200
tensorrt_llm==0.12.0.dev2024070900

View File

@ -1,5 +1,5 @@
--extra-index-url https://pypi.nvidia.com
tensorrt_llm==0.12.0.dev2024070200
tensorrt_llm==0.12.0.dev2024070900
datasets==2.14.5
rouge_score~=0.1.2
sentencepiece~=0.1.99

View File

@ -440,11 +440,17 @@ expected results:
#### 1M long context test case
- Prepare 1M needle-in-a-haystack datasets
```bash
python examples/infinitebench/construct_synthetic_dataset.py --test_case build_passkey --test_level 7
```
- Llama-3-8B example
```bash
git-lfs clone https://huggingface.co/gradientai/Llama-3-8B-Instruct-Gradient-1048k/
python examples/infinitebench/construct_synthetic_dataset.py --test_case build_passkey --test_level 7
python examples/llama/convert_checkpoint.py --model_dir ./Llama-3-8B-Instruct-Gradient-1048k/ \
--output_dir /tmp/llama-3-8B-1048k/trt_ckpts \
--dtype float16 \
@ -454,8 +460,8 @@ python -m tensorrt_llm.commands.build --checkpoint_dir /tmp/llama-3-8B-1048k/trt
--output_dir /tmp/llama-3-8B-1048k/trt_engines \
--gemm_plugin float16 \
--max_num_tokens 4096 \
--max_input_len 1048576 \
--max_output_len 10 \
--max_input_len 1048566 \
--max_seq_len 1048576 \
--use_paged_context_fmha enable \
--workers 4
@ -463,7 +469,37 @@ mpirun -n 4 --allow-run-as-root python examples/eval_long_context.py --task pas
--engine_dir /tmp/llama-3-8B-1048k/trt_engines \
--tokenizer_dir ./Llama-3-8B-Instruct-Gradient-1048k/ \
--stop_idx 1 \
--max_input_length 1048576 \
--max_input_length 1048566 \
--enable_chunked_context \
--max_tokens_in_paged_kv_cache 1100000
```
- Llama-3-70B example
For the 70B model, at least 8 A100 80GB GPUs are required.
```bash
git-lfs clone https://huggingface.co/gradientai/Llama-3-70B-Instruct-Gradient-1048k/
python examples/llama/convert_checkpoint.py --model_dir ./Llama-3-70B-Instruct-Gradient-1048k/ \
--output_dir /tmp/llama-3-70B-1048k/trt_ckpts \
--dtype float16 \
--tp_size 8
python -m tensorrt_llm.commands.build --checkpoint_dir /tmp/llama-3-70B-1048k/trt_ckpts \
--output_dir /tmp/llama-3-70B-1048k/trt_engines \
--gemm_plugin float16 \
--max_num_tokens 4096 \
--max_input_len 1048566 \
--max_seq_len 1048576 \
--use_paged_context_fmha enable \
--workers 8
mpirun -n 8 --allow-run-as-root python examples/eval_long_context.py --task passkey \
--engine_dir /tmp/llama-3-70B-1048k/trt_engines \
--tokenizer_dir ./Llama-3-70B-Instruct-Gradient-1048k/ \
--stop_idx 1 \
--max_input_length 1048566 \
--enable_chunked_context \
--max_tokens_in_paged_kv_cache 1100000
```

View File

@ -1,5 +1,5 @@
--extra-index-url https://pypi.nvidia.com
tensorrt_llm==0.12.0.dev2024070200
tensorrt_llm==0.12.0.dev2024070900
datasets==2.14.6
evaluate~=0.4.1
rouge_score~=0.1.2

View File

@ -2,6 +2,15 @@
This document shows how to build and run a [Mamba](https://github.com/state-spaces/mamba) model in TensorRT-LLM on a single GPU.
- [Mamba](#mamba)
- [Overview](#overview)
- [Support Matrix](#support-matrix)
- [Usage](#usage)
- [1. Download weights from HuggingFace Transformers](#1-download-weights-from-huggingface-transformers)
- [2. Convert weights from HF Transformers to TensorRT-LLM format](#2-convert-weights-from-hf-transformers-to-tensorrt-llm-format)
- [3. Build TensorRT engine(s)](#3-build-tensorrt-engines)
- [4. Run summarization task with the TensorRT engine(s)](#4-run-summarization-task-with-the-tensorrt-engines)
## Overview
The TensorRT-LLM Mamba implementation can be found in [`tensorrt_llm/models/mamba/model.py`](../../tensorrt_llm/models/mamba/model.py). The TensorRT-LLM Mamba example code is located in [`examples/mamba`](./). There is one main file:
@ -15,8 +24,13 @@ In addition, there are two shared files in the parent folder [`examples`](../) f
## Support Matrix
* FP16
* BF16
| Model Name | FP16 | BF16 |
| :--------------: | :---: | :---: |
| Mamba1 | Y | Y |
| Mamba2 | Y | Y |
* Mamba2: TensorRT-LLM only supports the pure Mamba model for now; hybrid models will be supported later.
## Usage
@ -32,23 +46,20 @@ pip install -r requirements.txt
git lfs install
```
There are five HF checkpoints available. Use one of the following commands to fetch the checkpoint you are interested in.
There are different HF checkpoints available. For Mamba1, TensorRT-LLM supports the Transformers-compatible models. Here are some examples of fetching the checkpoints.
```bash
# mamba-2.8b
git clone https://huggingface.co/state-spaces/mamba-2.8b-hf ./mamba_model/mamba-2.8b
# mamba-1.4b
git clone https://huggingface.co/state-spaces/mamba-1.4b-hf ./mamba_model/mamba-1.4b
# mamba-790m
git clone https://huggingface.co/state-spaces/mamba-790m-hf ./mamba_model/mamba-790m
# mamba-370m
git clone https://huggingface.co/state-spaces/mamba-370m-hf ./mamba_model/mamba-370m
# mamba-130m
git clone https://huggingface.co/state-spaces/mamba-130m-hf ./mamba_model/mamba-130m
# mamba2-2.7b
git clone https://huggingface.co/state-spaces/mamba2-2.7b ./mamba_model/mamba2-2.7b
# mamba2-130m
git clone https://huggingface.co/state-spaces/mamba2-130m ./mamba_model/mamba2-130m
```
Since Mamba models use the tokenizer from the gpt-neox-20b model, use the following command to fetch the gpt-neox-20b checkpoint.
@ -67,25 +78,20 @@ python convert_checkpoint.py --model_dir ./mamba_model/mamba-2.8b/ \
--dtype bfloat16 \
--output_dir ./mamba_model/mamba-2.8b/trt_ckpt/bf16/1-gpu/
# mamba-1.4b
python convert_checkpoint.py --model_dir ./mamba_model/mamba-1.4b/ \
--dtype float16 \
--output_dir ./mamba_model/mamba-1.4b/trt_ckpt/fp16/1-gpu/
# mamba-790m
python convert_checkpoint.py --model_dir ./mamba_model/mamba-790m/ \
--dtype float16 \
--output_dir ./mamba_model/mamba-790m/trt_ckpt/fp16/1-gpu/
# mamba-370m
python convert_checkpoint.py --model_dir ./mamba_model/mamba-370m/ \
--dtype float16 \
--output_dir ./mamba_model/mamba-370m/trt_ckpt/fp16/1-gpu/
# mamba-130m
python convert_checkpoint.py --model_dir ./mamba_model/mamba-130m/ \
--dtype float16 \
--output_dir ./mamba_model/mamba-130m/trt_ckpt/fp16/1-gpu/
# mamba2-2.7b
python convert_checkpoint.py --model_dir ./mamba_model/mamba2-2.7b/ \
--dtype float16 \
--output_dir ./mamba_model/mamba2-2.7b/trt_ckpt/fp16/1-gpu/
# mamba2-130m
python convert_checkpoint.py --model_dir ./mamba_model/mamba2-130m/ \
--dtype float16 \
--output_dir ./mamba_model/mamba2-130m/trt_ckpt/fp16/1-gpu/
```
### 3. Build TensorRT engine(s)
@ -101,33 +107,6 @@ trtllm-build --checkpoint_dir ./mamba_model/mamba-2.8b/trt_ckpt/bf16/1-gpu/ \
--max_seq_len 1024 \
--output_dir ./mamba_model/mamba-2.8b/trt_engines/bf16/1-gpu/
# mamba-1.4b
trtllm-build --checkpoint_dir ./mamba_model/mamba-1.4b/trt_ckpt/fp16/1-gpu/ \
--paged_kv_cache disable \
--gemm_plugin auto \
--max_batch_size 8 \
--max_input_len 924 \
--max_seq_len 1024 \
--output_dir ./mamba_model/mamba-1.4b/trt_engines/fp16/1-gpu/
# mamba-790m
trtllm-build --checkpoint_dir ./mamba_model/mamba-790m/trt_ckpt/fp16/1-gpu/ \
--paged_kv_cache disable \
--gemm_plugin auto \
--max_batch_size 8 \
--max_input_len 924 \
--max_seq_len 1024 \
--output_dir ./mamba_model/mamba-790m/trt_engines/fp16/1-gpu/
# mamba-370m
trtllm-build --checkpoint_dir ./mamba_model/mamba-370m/trt_ckpt/fp16/1-gpu/ \
--paged_kv_cache disable \
--gemm_plugin auto \
--max_batch_size 8 \
--max_input_len 924 \
--max_seq_len 1024 \
--output_dir ./mamba_model/mamba-370m/trt_engines/fp16/1-gpu/
# mamba-130m
trtllm-build --checkpoint_dir ./mamba_model/mamba-130m/trt_ckpt/fp16/1-gpu/ \
--paged_kv_cache disable \
@ -136,6 +115,24 @@ trtllm-build --checkpoint_dir ./mamba_model/mamba-130m/trt_ckpt/fp16/1-gpu/ \
--max_input_len 924 \
--max_seq_len 1024 \
--output_dir ./mamba_model/mamba-130m/trt_engines/fp16/1-gpu/
# mamba2-2.7b
trtllm-build --checkpoint_dir ./mamba_model/mamba2-2.7b/trt_ckpt/fp16/1-gpu/ \
--paged_kv_cache disable \
--gemm_plugin auto \
--max_batch_size 8 \
--max_input_len 924 \
--max_seq_len 1024 \
--output_dir ./mamba_model/mamba2-2.7b/trt_engines/fp16/1-gpu/
# mamba2-130m
trtllm-build --checkpoint_dir ./mamba_model/mamba2-130m/trt_ckpt/fp16/1-gpu/ \
--paged_kv_cache disable \
--gemm_plugin auto \
--max_batch_size 8 \
--max_input_len 924 \
--max_seq_len 1024 \
--output_dir ./mamba_model/mamba2-130m/trt_engines/fp16/1-gpu/
```
Note that when building Mamba models, you need to disable the `paged_kv_cache` as it is used for
@ -148,7 +145,6 @@ The following section describes how to run a TensorRT-LLM Mamba model to summari
[cnn_dailymail](https://huggingface.co/datasets/cnn_dailymail) dataset. For each summary, the script can compute the
[ROUGE](https://en.wikipedia.org/wiki/ROUGE_(metric)) scores and use the `ROUGE-1` score to validate the implementation.
### Run
```bash
# mamba-2.8b
python ../summarize.py --test_trt_llm \
@ -157,31 +153,24 @@ python ../summarize.py --test_trt_llm \
--data_type bf16 \
--engine_dir ./mamba_model/mamba-2.8b/trt_engines/bf16/1-gpu/
# mamba-1.4b
python ../summarize.py --test_trt_llm \
--hf_model_dir ./mamba_model/mamba-1.4b/ \
--tokenizer_dir ./mamba_model/gpt-neox-20b/ \
--data_type fp16 \
--engine_dir ./mamba_model/mamba-1.4b/trt_engines/fp16/1-gpu/
# mamba-790m
python ../summarize.py --test_trt_llm \
--hf_model_dir ./mamba_model/mamba-790m/ \
--tokenizer_dir ./mamba_model/gpt-neox-20b/ \
--data_type fp16 \
--engine_dir ./mamba_model/mamba-790m/trt_engines/fp16/1-gpu/
# mamba-370m
python ../summarize.py --test_trt_llm \
--hf_model_dir ./mamba_model/mamba-370m/ \
--tokenizer_dir ./mamba_model/gpt-neox-20b/ \
--data_type fp16 \
--engine_dir ./mamba_model/mamba-370m/trt_engines/fp16/1-gpu/
# mamba-130m
python ../summarize.py --test_trt_llm \
--hf_model_dir ./mamba_model/mamba-130m/ \
--tokenizer_dir ./mamba_model/gpt-neox-20b/ \
--data_type fp16 \
--engine_dir ./mamba_model/mamba-130m/trt_engines/fp16/1-gpu/
# mamba2-2.7b
python ../summarize.py --test_trt_llm \
--hf_model_dir ./mamba_model/mamba2-2.7b/ \
--tokenizer_dir ./mamba_model/gpt-neox-20b/ \
--data_type fp16 \
--engine_dir ./mamba_model/mamba2-2.7b/trt_engines/fp16/1-gpu/
# mamba2-130m
python ../summarize.py --test_trt_llm \
--hf_model_dir ./mamba_model/mamba2-130m/ \
--tokenizer_dir ./mamba_model/gpt-neox-20b/ \
--data_type fp16 \
--engine_dir ./mamba_model/mamba2-130m/trt_engines/fp16/1-gpu/
```

View File

@ -1,13 +1,17 @@
import argparse
import copy
import json
import re
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Union
import safetensors.torch
import torch
from transformers import AutoConfig, AutoModelForCausalLM
from transformers.utils import CONFIG_NAME
from transformers.utils.hub import cached_file
import tensorrt_llm
from tensorrt_llm import logger
@ -55,7 +59,10 @@ def get_tllm_linear_weight(weight, prefix, bias=None):
return results
def convert_hf_mamba(hf_mamba, rank=0, dtype='float32'):
def convert_hf_mamba(hf_mamba,
rank=0,
dtype='float32',
mamba_version: str = 'Mamba1'):
weights = {}
tik = time.time()
@ -130,9 +137,12 @@ def rename_hf_to_tllm(name: str):
# change layer name
if 'embeddings.' in name:
name = name.replace('embeddings', 'vocab_embedding')
elif 'embedding.' in name:
name = name.replace('embedding', 'vocab_embedding')
norm_pattern = r'\d\.norm\.'
if 'mixer.' in name:
name = name.replace('mixer.', 'ssm.')
elif 'norm.' in name:
elif re.search(norm_pattern, name):
name = name.replace('norm.', 'input_layernorm.')
elif 'norm_f.' in name:
name = name.replace('norm_f.', 'ln_f.')
@ -147,7 +157,8 @@ def rename_hf_to_tllm(name: str):
def convert_from_hf_checkpoint(model_dir: Union[str, Path],
rank=0,
dtype: Union[str, torch.dtype] = torch.float32):
dtype: Union[str, torch.dtype] = torch.float32,
mamba_version: str = 'Mamba1'):
logger.info('Loading weights from HF Mamba...')
tik = time.time()
@ -164,15 +175,19 @@ def convert_from_hf_checkpoint(model_dir: Union[str, Path],
param = param.detach().cpu()
if 'A_log' in name:
param = -torch.exp(param.float())
param = param.permute(1, 0).contiguous()
if mamba_version == 'Mamba1':
param = param.permute(1, 0).contiguous()
elif 'D' in name:
param = param.float()
elif 'dt_proj.bias' in name:
param = param.float()
elif 'dt_bias' in name:
param = param.float()
elif 'conv1d.weight' in name:
param = param.unsqueeze(3)
if 'in_proj' in name:
# split in_proj in Mamba1
if 'in_proj' in name and mamba_version == 'Mamba1':
in_proj_params = torch.split(param, param.size(0) // 2, dim=0)
weights[tllm_name.replace('proj', 'proj_x')] = in_proj_params[0]
weights[tllm_name.replace('proj', 'proj_z')] = in_proj_params[1]
@ -181,9 +196,10 @@ def convert_from_hf_checkpoint(model_dir: Union[str, Path],
del model_params
# lm_head
if 'lm_head.weight' not in weights:
weights['lm_head.weight'] = copy.deepcopy(
weights['backbone.vocab_embedding.weight'])
emb = weights['backbone.vocab_embedding.weight']
if 'lm_head.weight' not in weights or weights['lm_head.weight'].data_ptr(
) == emb.data_ptr():
weights['lm_head.weight'] = copy.deepcopy(emb)
tok = time.time()
t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
@ -208,6 +224,72 @@ def convert(worker_rank, args, convert_args):
args.output_dir / f'rank{rank}.safetensors')
@dataclass
class MambaConfig:
d_model: int = 2560
d_intermediate: int = 0
n_layer: int = 64
vocab_size: int = 50277
ssm_cfg: dict = field(default_factory=dict)
attn_layer_idx: list = field(default_factory=list)
attn_cfg: dict = field(default_factory=dict)
rms_norm: bool = True
residual_in_fp32: bool = True
fused_add_norm: bool = True
pad_vocab_size_multiple: int = 8
tie_embeddings: bool = True
hidden_size: int = 2560
num_hidden_layers: int = 64
intermediate_size: int = 0
state_size: int = 128
conv_kernel: int = 4
use_bias: bool = False
headdim: int = 64
ngroups: int = 1
chunk_size: int = 256
ssm_rmsnorm: bool = True
def update(self, data_dict):
self.__dict__.update(data_dict)
def load_config_hf(model_name):
resolved_archive_file = cached_file(
model_name, CONFIG_NAME, _raise_exceptions_for_missing_entries=False)
config = json.load(open(resolved_archive_file))
if 'transformers_version' in config: # transformer compatible models
hf_config = AutoConfig.from_pretrained(model_name,
trust_remote_code=True)
# TODO: change mamba_version when transformers can support Mamba2 models
mamba_version = 'Mamba1'
else: # state-spaces/mamba models
hf_config = MambaConfig(**config)
hf_config.hidden_size = hf_config.d_model
hf_config.num_hidden_layers = hf_config.n_layer
if 'expand' in hf_config.ssm_cfg:
            expand = hf_config.ssm_cfg['expand']
hf_config.intermediate_size = expand * hf_config.d_model
else:
hf_config.intermediate_size = 2 * hf_config.d_model
ssm_cfg_to_hf_cfg = {
'd_state': 'state_size',
'd_conv': 'conv_kernel',
'bias': 'use_bias',
'headdim': 'headdim',
'ngroups': 'ngroups',
'chunk_size': 'chunk_size',
'rmsnorm': 'ssm_rmsnorm',
}
cfg_dict = {}
for k, v in hf_config.ssm_cfg.items():
if k in ssm_cfg_to_hf_cfg:
cfg_dict[ssm_cfg_to_hf_cfg[k]] = v
hf_config.update(cfg_dict)
mamba_version = hf_config.ssm_cfg.pop("layer", "Mamba1")
return hf_config, mamba_version
def main():
print(tensorrt_llm.__version__)
@ -217,8 +299,8 @@ def main():
args.output_dir.mkdir(exist_ok=True, parents=True)
hf_config = AutoConfig.from_pretrained(args.model_dir,
trust_remote_code=True)
hf_config, mamba_version = load_config_hf(args.model_dir)
vocab_size = hf_config.vocab_size
pad_vocab_size_multiple = hf_config.pad_vocab_size_multiple
if vocab_size % pad_vocab_size_multiple != 0:
@ -239,15 +321,29 @@ def main():
'hidden_act': 'silu',
'num_attention_heads': 1,
'rnn_hidden_size': hf_config.intermediate_size,
'rnn_conv_dim_size': hf_config.intermediate_size,
'state_size': hf_config.state_size,
'conv_kernel': hf_config.conv_kernel,
'use_bias': hf_config.use_bias,
'mamba_version': mamba_version,
}
if mamba_version == 'Mamba2':
conv_dim = hf_config.intermediate_size + 2 * hf_config.ngroups * hf_config.state_size
mamba2_cfg = {
'rnn_head_size': hf_config.headdim,
'rnn_conv_dim_size': conv_dim,
'ngroups': hf_config.ngroups,
'chunk_size': hf_config.chunk_size,
'ssm_rmsnorm': hf_config.ssm_rmsnorm,
}
config.update(mamba2_cfg)
with (args.output_dir / 'config.json').open('w') as f:
json.dump(config, f, indent=4)
convert_from_ckpt = do_convert_from_ckpt(args)
# TODO: Add convert_hf_mamba support for Mamba2 when transformers can support Mamba2 models
    assert convert_from_ckpt or mamba_version != 'Mamba2', "Mamba2 only supports conversion from checkpoints."
if not convert_from_ckpt:
logger.info(f'Convert by using model')
hf_mamba = AutoModelForCausalLM.from_pretrained(args.model_dir,
@ -264,6 +360,7 @@ def main():
convert_args['model_dir'] = args.model_dir
else:
convert_args['hf_mamba'] = hf_mamba
convert_args['mamba_version'] = mamba_version
convert(0, args, convert_args)
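The new `load_config_hf` helper above decides between a Transformers-compatible Mamba1 checkpoint and a native state-spaces config before conversion. A minimal standalone sketch of that detection rule (not part of this change; the function name is illustrative):

```python
import json


def detect_mamba_version(config_path: str) -> str:
    """Return 'Mamba1' or 'Mamba2' for a checkpoint's config.json."""
    with open(config_path) as f:
        config = json.load(f)
    # Configs written by HF Transformers carry a 'transformers_version' key;
    # those checkpoints are treated as Mamba1 for now.
    if "transformers_version" in config:
        return "Mamba1"
    # Native state-spaces configs declare the SSM layer type in ssm_cfg.
    return config.get("ssm_cfg", {}).get("layer", "Mamba1")
```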

View File

@ -1,5 +1,5 @@
--extra-index-url https://pypi.nvidia.com
tensorrt_llm==0.12.0.dev2024070200
tensorrt_llm==0.12.0.dev2024070900
transformers>=4.39.0
datasets~=2.14.5
evaluate

View File

@ -1,5 +1,5 @@
--extra-index-url https://pypi.nvidia.com
tensorrt_llm==0.12.0.dev2024070200
tensorrt_llm==0.12.0.dev2024070900
datasets~=2.14.5
rouge_score~=0.1.2
sentencepiece~=0.1.99

View File

@ -1,4 +1,4 @@
--extra-index-url https://pypi.nvidia.com
tensorrt_llm==0.12.0.dev2024070200
tensorrt_llm==0.12.0.dev2024070900
transformers==4.38.2
accelerate==0.25.0

View File

@ -255,6 +255,7 @@ class Pipeline:
self.pad_id = pad_id
self.end_id = end_id
self.max_attention_window_size = max_attention_window_size
self.output_len = 2
def __call__(self, prompt):
rank = tensorrt_llm.mpi_rank()
@ -263,7 +264,7 @@ class Pipeline:
batch_input_ids = [inputs]
# For multi-choice tasks like MMLU, we don't need to adjust following parameters
output_len = 2
output_len = self.output_len
top_k = 1
top_p = 0.0
@ -313,7 +314,8 @@ class Pipeline:
def check_valid_length(self, prompt):
if isinstance(self.model, nn.Module):
return True
return len(self.tokenizer.encode(prompt)) <= self.model.max_input_len
input_len = len(self.tokenizer.encode(prompt))
return input_len <= self.model.max_input_len and input_len + self.output_len <= self.model.max_seq_len
def parse_args():
@ -391,7 +393,7 @@ def main():
model = auto_model_cls.from_pretrained(
args.hf_model_dir,
trust_remote_code=True,
torch_dtype=DTYPE_STR_MAPPING[args.data_type],
torch_dtype=DTYPE_STR_MAPPING[args.hf_data_type],
device_map="auto" if args.hf_device_map_auto else None,
)
if not args.hf_device_map_auto:

View File

@ -1,5 +1,5 @@
--extra-index-url https://pypi.nvidia.com
tensorrt_llm==0.12.0.dev2024070200
tensorrt_llm==0.12.0.dev2024070900
datasets~=2.14.5
evaluate~=0.4.1
rouge_score~=0.1.2

View File

@ -12,7 +12,7 @@ We first describe how to run each model on a single GPU. We then provide general
- [Deplot](#deplot)
- [Fuyu](#fuyu)
- [Kosmos-2](#kosmos-2)
- [LLaVA and VILA](#llava-and-vila)
- [LLaVA, LLaVA-NeXT and VILA](#llava-llava-next-and-vila)
- [NeVA](#neva)
- [Nougat](#nougat)
- [Phi-3-vision](#phi-3-vision)
@ -361,9 +361,9 @@ Currently, CogVLM only support bfloat16 precision and doesn't support `remove_in
--llm_engine_dir trt_engines/${MODEL_NAME}/fp16/1-gpu
```
## LLaVA and VILA
## LLaVA, LLaVA-NeXT and VILA
[LLaVA](https://github.com/haotian-liu/LLaVA) and [VILA](https://github.com/Efficient-Large-Model/VILA) are both visual language models (VLM) that can be deployed in TensorRT-LLM with many quantization options.
[LLaVA](https://github.com/haotian-liu/LLaVA) and [VILA](https://github.com/Efficient-Large-Model/VILA) are both visual language models (VLM) that can be deployed in TensorRT-LLM with many quantization options. [LLaVA-NeXT](https://huggingface.co/collections/llava-hf/llava-next-65f75c4afac77fd37dbbe6cf) is an extension of LLaVA. TRT-LLM currently supports the [Mistral-7B](https://huggingface.co/llava-hf/llava-v1.6-mistral-7b-hf) and [Nous-Hermes-2-Yi-34B](https://huggingface.co/llava-hf/llava-v1.6-34b-hf) variants of LLaVA-NeXT.
1. Download Huggingface model weights. These models have both visual and LLM components
unlike BLIP2 example which downloads only LLM components from Huggingface.
@ -374,6 +374,12 @@ Currently, CogVLM only support bfloat16 precision and doesn't support `remove_in
export MODEL_NAME="llava-1.5-7b-hf" # also llava-1.5-13b-hf
git clone https://huggingface.co/llava-hf/${MODEL_NAME} tmp/hf_models/${MODEL_NAME}
```
For LLaVA-NeXT,
```bash
export MODEL_NAME="llava-v1.6-mistral-7b-hf" #for 34b variant "llava-v1.6-34b-hf"
git clone https://huggingface.co/llava-hf/${MODEL_NAME} tmp/hf_models/${MODEL_NAME}
```
For VILA, we need a few more steps until it is added to HF model zoo
@ -408,6 +414,18 @@ Currently, CogVLM only support bfloat16 precision and doesn't support `remove_in
--max_seq_len 2560 \
--max_multimodal_len 576 # 1 (max_batch_size) * 576 (num_visual_features) for LLaVA
trtllm-build \
--checkpoint_dir tmp/trt_models/${MODEL_NAME}/fp16/1-gpu \
--output_dir trt_engines/${MODEL_NAME}/fp16/1-gpu \
--gpt_attention_plugin float16 \
--gemm_plugin float16 \
--max_batch_size 1 \
--max_input_len 4096 \
--max_seq_len 5120 \
    --max_num_tokens 4096 \
    --max_multimodal_len 4096 \
    --use_fused_mlp # max_num_tokens and max_multimodal_len are 1 (max_batch_size) * 4096 (max_input_len)
trtllm-build \
--checkpoint_dir tmp/trt_models/${MODEL_NAME}/fp16/1-gpu \
--output_dir trt_engines/${MODEL_NAME}/fp16/1-gpu \
@ -426,6 +444,8 @@ Currently, CogVLM only support bfloat16 precision and doesn't support `remove_in
```bash
python build_visual_engine.py --model_path tmp/hf_models/${MODEL_NAME} --model_type llava # for LLaVA
python build_visual_engine.py --model_path tmp/hf_models/${MODEL_NAME} --model_type llava_next --max_batch_size 5 # 1 (max_batch_size) * 5 (LLaVA-NeXT visual encoder can produce at most 5 patches) # for LLaVA-NeXT
python build_visual_engine.py --model_path tmp/hf_models/${MODEL_NAME} --model_type vila --vila_path ${VILA_PATH} # for VILA
```
@ -435,7 +455,7 @@ Currently, CogVLM only support bfloat16 precision and doesn't support `remove_in
--hf_model_dir tmp/hf_models/${MODEL_NAME} \
--visual_engine_dir visual_engines/${MODEL_NAME} \
--llm_engine_dir trt_engines/${MODEL_NAME}/fp16/1-gpu \
--input_text "Question: which city is this? Answer:" # for LLaVA
--input_text "Question: which city is this? Answer:" # for LLaVA and for LLaVA-NeXT
```
For VILA, you can use either local file or web url as input images.

View File

@ -11,13 +11,6 @@ import yaml
import torch
import tensorrt as trt
from tensorrt_llm.builder import Builder
# isort: on
import json
import math
import torch.nn.functional as F
from PIL import Image
from safetensors.torch import save_file
from transformers import (AutoConfig, AutoModel, AutoModelForCausalLM,
AutoModelForVision2Seq, AutoProcessor,
Blip2ForConditionalGeneration, Blip2Processor,
@ -25,6 +18,13 @@ from transformers import (AutoConfig, AutoModel, AutoModelForCausalLM,
LlavaForConditionalGeneration, NougatProcessor,
Pix2StructForConditionalGeneration,
VisionEncoderDecoderModel)
# isort: on
import json
import math
import torch.nn.functional as F
from PIL import Image
from safetensors.torch import save_file
def parse_arguments():
@ -33,8 +33,8 @@ def parse_arguments():
type=str,
default=None,
choices=[
'blip2', 'llava', 'vila', 'nougat', 'cogvlm',
'fuyu', 'pix2struct', 'neva', 'kosmos-2',
'blip2', 'llava', 'llava_next', 'vila', 'nougat',
'cogvlm', 'fuyu', 'pix2struct', 'neva', 'kosmos-2',
'video-neva', 'phi-3-vision'
],
help="Model type")
@ -80,7 +80,7 @@ class VisionEngineBuilder:
build_blip2_engine(args)
elif args.model_type == 'pix2struct':
build_pix2struct_engine(args)
elif args.model_type == 'llava':
elif 'llava' in args.model_type:
build_llava_engine(args)
elif args.model_type == 'vila':
assert args.vila_path is not None, "Please clone and provide VILA source code path"
@ -305,30 +305,59 @@ def build_pix2struct_engine(args):
def build_llava_engine(args):
processor = AutoProcessor.from_pretrained(args.model_path)
raw_image = Image.new('RGB', [10, 10]) # dummy image
image = processor(text="dummy", images=raw_image,
return_tensors="pt")['pixel_values'].to(
args.device, torch.float16)
if args.model_type == "llava":
raw_image = Image.new('RGB', [10, 10]) # dummy image
image = processor(text="dummy", images=raw_image,
return_tensors="pt")['pixel_values'].to(
args.device, torch.float16)
class LlavaVisionWrapper(torch.nn.Module):
class LlavaVisionWrapper(torch.nn.Module):
def __init__(self, tower, projector, feature_layer):
super().__init__()
self.tower = tower
self.projector = projector
self.feature_layer = feature_layer
def __init__(self, tower, projector, feature_layer):
super().__init__()
self.tower = tower
self.projector = projector
self.feature_layer = feature_layer
def forward(self, image):
all_hidden_states = self.tower(
image, output_hidden_states=True).hidden_states
features = all_hidden_states[self.feature_layer][:, 1:]
return self.projector(features)
def forward(self, image):
all_hidden_states = self.tower(
image, output_hidden_states=True).hidden_states
features = all_hidden_states[self.feature_layer][:, 1:]
return self.projector(features)
model = LlavaForConditionalGeneration.from_pretrained(
args.model_path, torch_dtype=torch.float16)
wrapper = LlavaVisionWrapper(model.vision_tower.to(args.device),
model.multi_modal_projector.to(args.device),
model.config.vision_feature_layer)
model = LlavaForConditionalGeneration.from_pretrained(
args.model_path, torch_dtype=torch.float16)
wrapper = LlavaVisionWrapper(
model.vision_tower.to(args.device),
model.multi_modal_projector.to(args.device),
model.config.vision_feature_layer)
elif args.model_type == "llava_next":
from transformers import LlavaNextForConditionalGeneration
raw_image = Image.new('RGB', [512, 512])
image = processor(text="dummy", images=raw_image,
return_tensors="pt")['pixel_values'].to(
args.device, torch.float16)[0]
class LlavaNextVisionWrapper(torch.nn.Module):
def __init__(self, vision_tower, projector):
super().__init__()
self.vision_tower = vision_tower
self.projector = projector
def forward(self, pixel_values):
image_features = self.vision_tower(pixel_values,
output_hidden_states=True)
selected_image_feature = image_features.hidden_states[-2][:, 1:]
image_features = self.projector(selected_image_feature)
return image_features # (bs, 576, c)
model = LlavaNextForConditionalGeneration.from_pretrained(
args.model_path, torch_dtype=torch.float16)
wrapper = LlavaNextVisionWrapper(
model.vision_tower.vision_model.to(args.device),
model.multi_modal_projector.to(args.device),
)
export_visual_wrapper_onnx(wrapper, image, args.output_dir)
build_trt_engine(
@ -336,6 +365,11 @@ def build_llava_engine(args):
[image.shape[1], image.shape[2], image.shape[3]], # [3, H, W]
args.output_dir,
args.max_batch_size)
if args.model_type == "llava_next":
image_newline = model.image_newline.data
tensor_img_newline = {"image_newline": image_newline}
save_file(tensor_img_newline,
os.path.join(args.output_dir, "image_newline.safetensors"))
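For reference, the `(bs, 576, c)` shape returned by `LlavaNextVisionWrapper` above follows from the CLIP vision tower assumed here (336-pixel input, 14-pixel patches) once the CLS token is dropped. A quick sanity check under that assumption:

```python
# Assumed CLIP ViT-L/14-336 configuration for the LLaVA / LLaVA-NeXT vision tower.
image_size = 336
patch_size = 14

patches_per_side = image_size // patch_size    # 24
num_visual_features = patches_per_side ** 2    # 24 * 24 = 576 tokens after dropping CLS
assert num_visual_features == 576
```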
def build_vila_engine(args):
@ -517,7 +551,12 @@ def build_neva_engine(args):
vision_x = self.connector(vision_x)
return vision_x
encoder = AutoModel.from_pretrained(vision_config["from_pretrained"],
vision_path = vision_config["from_pretrained"]
joined_path = os.path.join(os.path.dirname(args.model_path),
os.path.basename(vision_path))
if os.path.isdir(joined_path):
vision_path = joined_path
encoder = AutoModel.from_pretrained(vision_path,
torch_dtype=torch.bfloat16,
trust_remote_code=True)
vision_encoder = encoder.vision_model

View File

@ -95,6 +95,130 @@ def trt_dtype_to_torch(dtype):
raise TypeError("%s is not supported" % dtype)
class LlavaNextUtils:
# https://github.com/haotian-liu/LLaVA/blob/main/llava/mm_utils.py
@staticmethod
def select_best_resolution(original_size, possible_resolutions):
"""
Selects the best resolution from a list of possible resolutions based on the original size.
Args:
original_size (tuple): The original size of the image in the format (width, height).
possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...].
Returns:
tuple: The best fit resolution in the format (width, height).
"""
original_width, original_height = original_size
best_fit = None
max_effective_resolution = 0
min_wasted_resolution = float('inf')
for width, height in possible_resolutions:
scale = min(width / original_width, height / original_height)
downscaled_width, downscaled_height = int(
original_width * scale), int(original_height * scale)
effective_resolution = min(downscaled_width * downscaled_height,
original_width * original_height)
wasted_resolution = (width * height) - effective_resolution
if effective_resolution > max_effective_resolution or (
effective_resolution == max_effective_resolution
and wasted_resolution < min_wasted_resolution):
max_effective_resolution = effective_resolution
min_wasted_resolution = wasted_resolution
best_fit = (width, height)
return best_fit
@staticmethod
def get_anyres_image_grid_shape(image_size, patch_size):
"""
Calculate the shape of the image patch grid after the preprocessing for images of any resolution.
Args:
image_size (tuple): The size of the input image in the format (width, height).
patch_size (int): The size of each image patch.
Returns:
tuple: The shape of the image patch grid in the format (width, height).
"""
IMAGE_GRID_PINPOINTS = [[336, 672], [672, 336], [672, 672], [1008, 336],
[336, 1008]]
width, height = LlavaNextUtils.select_best_resolution(
image_size, IMAGE_GRID_PINPOINTS)
return width // patch_size, height // patch_size
@staticmethod
def unpad_image(tensor, original_size):
"""
Unpads a PyTorch tensor of a padded and resized image.
Args:
tensor (torch.Tensor): The image tensor, assumed to be in CxHxW format.
original_size (tuple): The original size of the image (width, height).
Returns:
torch.Tensor: The unpadded image tensor.
"""
original_width, original_height = original_size
current_height, current_width = tensor.shape[1:]
original_aspect_ratio = original_width / original_height
current_aspect_ratio = current_width / current_height
if original_aspect_ratio > current_aspect_ratio:
scale_factor = current_width / original_width
new_height = int(original_height * scale_factor)
padding = (current_height - new_height) // 2
unpadded_tensor = tensor[:, padding:current_height - padding, :]
else:
scale_factor = current_height / original_height
new_width = int(original_width * scale_factor)
padding = (current_width - new_width) // 2
unpadded_tensor = tensor[:, :, padding:current_width - padding]
return unpadded_tensor
@staticmethod
def rearrange_image_features(image_feature, image_newline, image_size):
"""
Combine PyTorch feature grids from image patches.
Args:
image_feature (torch.Tensor): The feature grids, assumed to be in NxCxHxW format.
image_newline (torch.Tensor): The newline embedding.
image_size (tuple): Size of the original image (width, height).
"""
CLIP_IMAGE_SIZE = 336
CLIP_PATCH_SIZE = 14
NUM_PATCHES_PER_SIDE = CLIP_IMAGE_SIZE // CLIP_PATCH_SIZE
if image_feature.shape[0] == 1:
return torch.cat((image_feature, image_newline[None]), dim=0)
base_image_feature = image_feature[0]
image_feature = image_feature[1:]
height = width = NUM_PATCHES_PER_SIDE
assert height * width == base_image_feature.shape[0]
num_patch_width, num_patch_height = LlavaNextUtils.get_anyres_image_grid_shape(
image_size, CLIP_IMAGE_SIZE)
image_feature = image_feature.view(num_patch_height, num_patch_width,
height, width, -1)
image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
image_feature = image_feature.flatten(1, 2).flatten(2, 3)
image_feature = LlavaNextUtils.unpad_image(image_feature, image_size)
image_feature = torch.cat(
(image_feature, image_newline[:, None, None].expand(
*image_feature.shape[:-1], 1)),
dim=-1)
image_feature = image_feature.flatten(1, 2).transpose(0, 1)
image_feature = torch.cat((base_image_feature, image_feature), dim=0)
return image_feature
class MultimodalModelRunner:
def __init__(self, args):
@ -123,7 +247,9 @@ class MultimodalModelRunner:
if self.model_type == 'video-neva':
self.num_frames = config['builder_config'].get('num_frames', None)
if self.model_type == "llava_next":
self.llm_name = AutoConfig.from_pretrained(
args.hf_model_dir).text_config._name_or_path
self.profiling_iterations = 20
self.init_image_encoder()
@ -203,6 +329,14 @@ class MultimodalModelRunner:
device="cuda") as f:
for k in f.keys():
self.image_newlines[k] = f.get_tensor(k)
if self.model_type == "llava_next":
self.image_newlines = {}
image_newlines_path = os.path.join(self.args.visual_engine_dir,
'image_newline.safetensors')
with safe_open(image_newlines_path, framework="pt",
device="cuda") as f:
for k in f.keys():
self.image_newlines[k] = f.get_tensor(k)
def init_llm(self):
if self.decoder_llm:
@ -276,6 +410,12 @@ class MultimodalModelRunner:
image = input['pixel_values']
bs = image.shape[0]
image = image.flatten(0, 1)
elif self.model_type == 'llava_next':
input = image
image = input['pixel_values']
bs = image.shape[0]
image = image[0]
image_size = input['image_sizes'][0].cpu()
if not warmup:
profiler.start("Vision")
@ -366,6 +506,13 @@ class MultimodalModelRunner:
input_ids = self.ptuning_setup_phi3(visual_features, input_ids,
num_img_tokens)
length = input_ids.shape[1]
elif self.model_type == 'llava_next':
visual_features = LlavaNextUtils.rearrange_image_features(
visual_features, self.image_newlines["image_newline"],
image_size)
input_ids = self.ptuning_setup_llava_next(visual_features,
pre_prompt, post_prompt)
length = input_ids.shape[1]
else:
pre_input_ids = self.tokenizer(pre_prompt,
return_tensors="pt",
@ -387,7 +534,9 @@ class MultimodalModelRunner:
input_lengths = torch.IntTensor([length] * args.batch_size).to(
torch.int32)
if self.model_type in ['fuyu', 'kosmos-2', 'phi-3-vision']:
if self.model_type in [
'fuyu', 'kosmos-2', 'phi-3-vision', 'llava_next'
]:
return input_ids, input_lengths, [visual_features], visual_features
input_ids, ptuning_args = self.setup_fake_prompts(
@ -667,6 +816,19 @@ class MultimodalModelRunner:
res_input_ids.append(cur_input_ids)
return res_input_ids
def ptuning_setup_llava_next(self, visual_features, pre_prompt,
post_prompt):
input_ids = []
fake_prompt_ids = list(
range(self.model_config.vocab_size,
self.model_config.vocab_size + visual_features.shape[0]))
input_ids = self.tokenizer.encode(
pre_prompt[0]) + fake_prompt_ids + self.tokenizer.encode(
post_prompt[0])[self.tokenizer.add_bos_token:]
input_ids = [input_ids] * len(pre_prompt)
input_ids = torch.tensor(input_ids)
return input_ids
def ptuning_setup_phi3(self, visual_features, input_ids, num_img_tokens):
fake_prompt_id = torch.arange(
self.model_config.vocab_size,
@ -869,6 +1031,32 @@ class MultimodalModelRunner:
pre_prompt = """<extra_id_0>System\nA chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n\n<extra_id_1>User"""
post_prompt = f"\n{input_text}\n<extra_id_1>Assistant\n<extra_id_2>quality:4,toxicity:0,humor:0,creativity:0,helpfulness:4,correctness:4,coherence:4,complexity:4,verbosity:4\n" ""
elif self.model_type == "llava_next":
if self.llm_name == "mistralai/Mistral-7B-Instruct-v0.2":
pre_prompt = "[INST] "
if input_text is None:
input_text = "Question: which city is this? Answer:"
post_prompt = f"\n{input_text} [/INST]"
prompt = pre_prompt + post_prompt
elif self.llm_name == "NousResearch/Nous-Hermes-2-Yi-34B":
pre_prompt = "<|im_start|>system\nAnswer the questions.<|im_end|><|im_start|>user\n"
if input_text is None:
input_text = "Question: which city is this? Answer:"
post_prompt = f"\n{input_text}<|im_end|><|im_start|>assistant\n"
prompt = pre_prompt + post_prompt
else:
raise Exception(
f"Prompt template for {self.llm_name} for not included currently"
)
processor = AutoProcessor.from_pretrained(args.hf_model_dir,
trust_remote_code=True)
image = processor(text=prompt,
images=raw_image,
return_tensors="pt")
elif self.model_type in ['llava', 'vila', 'fuyu', 'kosmos-2']:
# LLaVA and VILA
if self.model_type == "llava":
@ -924,7 +1112,8 @@ class MultimodalModelRunner:
pre_prompt = [pre_prompt] * self.args.batch_size
post_prompt = [post_prompt] * self.args.batch_size
if self.model_type not in [
'fuyu', 'pix2struct', 'kosmos-2', 'vila', 'phi-3-vision'
'fuyu', 'pix2struct', 'kosmos-2', 'vila', 'phi-3-vision',
'llava_next'
]:
if image.dim() == 5:
image = image.expand(args.batch_size, -1, -1, -1,
@ -932,7 +1121,6 @@ class MultimodalModelRunner:
else:
image = image.expand(args.batch_size, -1, -1, -1).contiguous()
image = image.to(self.device)
# Generate decoder_input_ids for enc-dec models
# Custom prompts can be added as:
# decoder_input_ids = model.tokenizer(decoder_prompt).input_ids
@ -955,7 +1143,6 @@ class MultimodalModelRunner:
def run(self, input_text, input_image, max_new_tokens):
input_text, pre_prompt, post_prompt, processed_image, decoder_input_ids, attention_mask = model.setup_inputs(
input_text, input_image)
model.generate(pre_prompt,
post_prompt,
processed_image,
@ -999,7 +1186,9 @@ class MultimodalModelRunner:
elif self.model_type == "pix2struct":
assert "characteristic | cat food, day | cat food, wet | cat treats" in output_text[
0][0].lower()
elif self.model_type in ['blip2', 'neva', 'phi-3-vision']:
elif self.model_type in [
'blip2', 'neva', 'phi-3-vision', 'llava_next'
]:
assert 'singapore' in output_text[0][0].lower()
elif self.model_type == 'video-neva':
assert 'robot' in output_text[0][0].lower()
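As a side note, `ptuning_setup_llava_next` above splices the visual features into the token stream through "fake" ids that start right past the text vocabulary, so the runtime fetches their embeddings from the prompt table instead of the embedding matrix. A minimal sketch of that id layout (all sizes and token ids below are illustrative):

```python
vocab_size = 32000           # assumed text vocabulary size
num_visual_features = 576    # assumed number of visual tokens for one image

pre_ids = [1, 733, 16289, 28793]   # stand-in for tokenizer.encode(pre_prompt)
post_ids = [13, 16693, 28747]      # stand-in for tokenizer.encode(post_prompt) without BOS
fake_prompt_ids = list(range(vocab_size, vocab_size + num_visual_features))

input_ids = pre_ids + fake_prompt_ids + post_ids
# Every id >= vocab_size marks a position whose embedding comes from the visual features.
assert sum(i >= vocab_size for i in input_ids) == num_visual_features
```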

View File

@ -1,5 +1,5 @@
--extra-index-url https://pypi.nvidia.com
tensorrt_llm==0.12.0.dev2024070200
tensorrt_llm==0.12.0.dev2024070900
transformers==4.40.2
datasets~=2.14.5
evaluate~=0.4.1

View File

@ -1,5 +1,5 @@
--extra-index-url https://pypi.nvidia.com
tensorrt_llm==0.12.0.dev2024070200
tensorrt_llm==0.12.0.dev2024070900
datasets~=2.14.5
evaluate~=0.4.1
rouge_score~=0.1.2

View File

@ -1,5 +1,5 @@
--extra-index-url https://pypi.nvidia.com
tensorrt_llm==0.12.0.dev2024070200
tensorrt_llm==0.12.0.dev2024070900
datasets~=2.14.5
evaluate~=0.4.1
rouge_score~=0.1.2

View File

@ -1,5 +1,5 @@
--extra-index-url https://pypi.nvidia.com
tensorrt_llm==0.12.0.dev2024070200
tensorrt_llm==0.12.0.dev2024070900
datasets>=2.14.4
nemo-toolkit[all]<=1.20.0,>=1.18.0
rouge_score~=0.1.2

View File

@ -1,5 +1,5 @@
--extra-index-url https://pypi.nvidia.com
tensorrt_llm==0.12.0.dev2024070200
tensorrt_llm==0.12.0.dev2024070900
datasets~=2.16.0
evaluate~=0.4.1
rouge_score~=0.1.2

View File

@ -1,5 +1,5 @@
--extra-index-url https://pypi.nvidia.com
tensorrt_llm==0.12.0.dev2024070200
tensorrt_llm==0.12.0.dev2024070900
datasets~=2.16.0
evaluate~=0.4.1
rouge_score~=0.1.2

View File

@ -496,6 +496,7 @@ def main():
rnn_hidden_size=ckpt_config["lru_width"],
logits_soft_cap=ckpt_config["logits_soft_cap"],
emb_scale_by_sqrt_dim=ckpt_config["embeddings_scale_by_sqrt_dim"],
rnn_conv_dim_size=ckpt_config["lru_width"],
)
trt_llm_config_dict = trt_llm_config.to_dict()

View File

@ -1,5 +1,5 @@
--extra-index-url https://pypi.nvidia.com
tensorrt_llm==0.12.0.dev2024070200
tensorrt_llm==0.12.0.dev2024070900
git+https://github.com/google-deepmind/recurrentgemma.git
flax>=0.8.2
jax~=0.4.23

View File

@ -36,7 +36,6 @@ if PYTHON_BINDINGS:
def parse_arguments(args=None):
# see `add_common_args` for extended list of arguments
parser = argparse.ArgumentParser()
parser.add_argument('--max_input_length', type=int, default=923)
parser.add_argument('--max_output_len', type=int, required=True)
@ -319,18 +318,6 @@ def main(args):
"Debug mode is not supported in C++ session for now, fallback to Python session."
)
args.use_py_session = True
if args.return_all_generated_tokens and args.use_py_session:
raise ValueError(
"Returning all the generated tokens at each step is not supported in the Python session, use C++ session instead."
)
if (not args.return_all_generated_tokens) and args.streaming and (
args.num_beams > 1):
logger.warning(
"Setting return_all_generated_tokens to True since streaming AND beam search are done simultaneously. "
"Returning the full beams at each streaming step is needed because beam search + streaming can change previous outputs. "
"WARNING: using this option may increase network usage significantly (quadratically w.r.t output length)."
)
args.return_all_generated_tokens = True
runner_cls = ModelRunner if args.use_py_session else ModelRunnerCpp
runner_kwargs = dict(
engine_dir=args.engine_dir,
@ -360,7 +347,8 @@ def main(args):
kv_cache_enable_block_reuse=args.kv_cache_enable_block_reuse,
kv_cache_free_gpu_memory_fraction=args.
kv_cache_free_gpu_memory_fraction,
enable_chunked_context=args.enable_chunked_context)
enable_chunked_context=args.enable_chunked_context,
)
runner = runner_cls.from_dir(**runner_kwargs)
with torch.no_grad():
@ -394,8 +382,7 @@ def main(args):
output_sequence_lengths=True,
no_repeat_ngram_size=args.no_repeat_ngram_size,
return_dict=True,
medusa_choices=args.medusa_choices,
return_all_generated_tokens=args.return_all_generated_tokens)
medusa_choices=args.medusa_choices)
torch.cuda.synchronize()
if args.streaming:
@ -482,9 +469,7 @@ def main(args):
prompt_tasks=args.prompt_tasks,
streaming=args.streaming,
output_sequence_lengths=True,
return_dict=True,
return_all_generated_tokens=args.return_all_generated_tokens
)
return_dict=True)
torch.cuda.synchronize()
tensorrt_llm.profiler.start("tmp")
@ -516,9 +501,7 @@ def main(args):
prompt_tasks=args.prompt_tasks,
streaming=args.streaming,
output_sequence_lengths=True,
return_dict=True,
return_all_generated_tokens=args.return_all_generated_tokens
)
return_dict=True)
torch.cuda.synchronize()
tensorrt_llm.profiler.stop("tmp")

View File

@ -1,5 +1,5 @@
--extra-index-url https://pypi.nvidia.com
tensorrt_llm==0.12.0.dev2024070200
tensorrt_llm==0.12.0.dev2024070900
datasets~=2.16.1
evaluate~=0.4.1
rouge_score~=0.1.2

View File

@ -1,5 +1,5 @@
--extra-index-url https://pypi.nvidia.com
tensorrt_llm==0.12.0.dev2024070200
tensorrt_llm==0.12.0.dev2024070900
datasets==2.14.6
evaluate~=0.4.1
rouge_score~=0.1.2

View File

@ -415,10 +415,6 @@ def main(args):
"Python bindings of C++ session is unavailable, fallback to Python session."
)
args.use_py_session = True
if args.return_all_generated_tokens:
raise ValueError(
"Returning all the generated tokens at each step is not supported in summarize.py"
)
runner_cls = ModelRunner if args.use_py_session else ModelRunnerCpp
runner_kwargs = dict(engine_dir=args.engine_dir,
rank=runtime_rank,

View File

@ -329,16 +329,4 @@ def add_common_args(parser):
action='store_true',
help="Use device map 'auto' to load a pretrained HF model. This may "
"help to test a large model that cannot fit into a singlue GPU.")
parser.add_argument(
"--return_all_generated_tokens",
default=False,
action="store_true",
help="if false, return only generated tokens at each streaming step."
"If true, return the full beams/outputs at each step"
"Overwritten to True if num_beams>1 and streaming"
"(only available with cpp session). "
"WARNING: using this option may increase network usage significantly (quadratically w.r.t output length)."
)
return parser

View File

@ -1,5 +1,5 @@
--extra-index-url https://pypi.nvidia.com
tensorrt_llm==0.12.0.dev2024070200
tensorrt_llm==0.12.0.dev2024070900
tiktoken
datasets
kaldialign

View File

@ -20,7 +20,7 @@ tokenizers>=0.14
# Default torch is CPU-only on Windows, so need to specify a torch version with GPU support
torch @ https://download.pytorch.org/whl/cu121/torch-2.3.1%2Bcu121-cp310-cp310-win_amd64.whl
nvidia-modelopt~=0.13,<0.14
transformers==4.38.2
transformers>=4.38.2
wheel
optimum
evaluate

View File

@ -14,6 +14,7 @@
# limitations under the License.
import array
import struct
import sys
from contextlib import contextmanager
from typing import List, Tuple
@ -83,7 +84,7 @@ class IpcMemory():
self.local_ptr = 0
def __del__(self):
if self.open_ipc:
if not sys.is_finalizing() and self.open_ipc:
IpcMemory.close_ipc_memory(self.mapping, self.peer_ptrs)
def serialize(self) -> List[int]:

View File

@ -69,23 +69,6 @@ _bandwidths = {
"PCIe-5": 64,
}
_templates = {
"H100-SXM":
dict(
inter_node_bw_per_device=50,
intra_node_bw_per_device=450,
intra_node_sharp=True,
memory_bw=3350,
math_throughput=MathThroughput(
int8=1979,
fp8=1979,
float16=989,
bfloat16=989,
float32=495,
),
),
}
cluster_infos = {
# from https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/a100/pdf/nvidia-a100-datasheet-us-nvidia-1758950-r4-web.pdf
"A100-SXM-80GB":
@ -119,18 +102,18 @@ cluster_infos = {
# from https://resources.nvidia.com/en-us-tensor-core/nvidia-tensor-core-gpu-datasheet
"H100-SXM":
ClusterInfo(
**_templates["H100-SXM"],
inter_node_bw_per_device=50,
intra_node_bw_per_device=450,
intra_node_sharp=True,
memory_bw=3350,
memory_budget_per_device=80,
),
"H100-SXM-64G":
ClusterInfo(
**_templates["H100-SXM"],
memory_budget_per_device=64,
),
"H100-SXM-94G":
ClusterInfo(
**_templates["H100-SXM"],
memory_budget_per_device=94,
math_throughput=MathThroughput(
int8=1979,
fp8=1979,
float16=989,
bfloat16=989,
float32=495,
),
),
"H100-PCIe":
ClusterInfo(
@ -369,12 +352,6 @@ def infer_cluster_key() -> str:
return "H100-SXM"
else:
return "H100-PCIe"
elif match("H100XS", device_name):
return "H100-SXM-64G"
elif match("H100XM", device_name):
return "H100-SXM"
elif match("H100XL", device_name):
return "H100-SXM-94G"
elif match("L40S", device_name):
return "L40S"
elif match("L40", device_name):

View File

@ -27,11 +27,11 @@ import torch
from ..auto_parallel import infer_cluster_config
from ..auto_parallel.cluster_info import cluster_infos
from ..builder import BuildConfig, Engine, build
from ..functional import PositionEmbeddingType
from ..logger import logger
from ..lora_manager import LoraConfig, LoraManager
from ..models import MODEL_MAP, PretrainedConfig
from ..models.modeling_utils import (WEIGHT_LOADER_MODELS,
SpeculativeDecodingMode)
from ..models.modeling_utils import SpeculativeDecodingMode
from ..plugin import PluginConfig, add_plugin_argument
@ -248,13 +248,6 @@ def parse_arguments():
return args
def preprocess_model_config(model_config, **kwargs):
if model_config.architecture in WEIGHT_LOADER_MODELS:
model_config.mapping.tp_size = kwargs['tp_size']
model_config.mapping.pp_size = kwargs['pp_size']
model_config.mapping.world_size = kwargs['tp_size'] * kwargs['pp_size']
def build_model(
build_config: BuildConfig,
rank: int = 0,
@ -428,7 +421,6 @@ def main():
ckpt_dir = ckpt_dir_or_model_config
model_config = PretrainedConfig.from_json_file(config_path)
preprocess_model_config(model_config, **kwargs)
if args.build_config is None:
if args.multiple_profiles == "enable" and args.opt_num_tokens is not None:
@ -472,7 +464,6 @@ def main():
deduced_max_seq_len = model_config.max_position_embeddings
# Step 2: Scale max_seq_len with rotary scaling
rotary_scaling = getattr(model_config, "rotary_scaling", None)
if rotary_factor != 1:
deduced_max_seq_len *= rotary_factor
logger.warning(
@ -485,8 +476,18 @@ def main():
f'max_seq_len is not specified, using value {deduced_max_seq_len}'
)
else:
if not plugin_config.streamingllm and model_config.max_position_embeddings is not None:
assert args.max_seq_len <= model_config.max_position_embeddings * rotary_factor, f'max_seq_len {args.max_seq_len} can\'t be larger than max_position_embeddings {model_config.max_position_embeddings} * rotary scaling {rotary_factor}'
if not plugin_config.streamingllm and model_config.max_position_embeddings is not None \
and model_config.position_embedding_type != PositionEmbeddingType.relative:
if args.max_seq_len > model_config.max_position_embeddings * rotary_factor:
logger.warning(
f'max_seq_len {args.max_seq_len} is larger than max_position_embeddings {model_config.max_position_embeddings} * rotary scaling {rotary_factor}, '
'the model accuracy might be affected')
if args.max_input_len > args.max_seq_len:
logger.warning(
                    f'max_input_len {args.max_input_len} is larger than max_seq_len {args.max_seq_len}, clipping it to max_seq_len'
)
args.max_input_len = args.max_seq_len
build_config = BuildConfig.from_dict(
{

View File

@ -76,9 +76,9 @@ class GenerationRequest:
# The following options in the Executor API are not yet exposed by the HLAPI:
# https://jirasw.nvidia.com/browse/TRTLLM-489
"bad_words":
self.sampling_params.bad_words or [],
self.sampling_params._get_bad_words(),
"stop_words":
self.sampling_params.stop_words or [],
self.sampling_params._get_stop_words(),
"embedding_bias":
self.sampling_params.embedding_bias,
"external_draft_tokens_config":
@ -182,6 +182,15 @@ class GenerationResult:
self.outputs[i].generation_logits = generation_logits[
i, :self.outputs[i].length]
if self.finished and not self._generation_request.sampling_params.include_stop_str_in_output:
for beam_output in self.outputs:
for stop_ids in self._generation_request.sampling_params._get_stop_words(
):
if beam_output.token_ids[-len(stop_ids):] == stop_ids:
beam_output.token_ids = beam_output.token_ids[:-len(
stop_ids)]
break
if context_logits is not None:
self.context_logits = context_logits
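The stop-word handling added above trims a terminating stop sequence from each beam when `include_stop_str_in_output` is disabled. A standalone sketch of that trimming rule (the helper name is illustrative):

```python
def trim_stop_words(token_ids, stop_words):
    """Drop a trailing stop sequence if one terminated the output."""
    for stop_ids in stop_words:
        if stop_ids and token_ids[-len(stop_ids):] == stop_ids:
            return token_ids[:-len(stop_ids)]
    return token_ids


assert trim_stop_words([5, 6, 7, 2], [[2]]) == [5, 6, 7]   # trailing stop word removed
assert trim_stop_words([5, 6, 7], [[2]]) == [5, 6, 7]      # unchanged when no stop word matches
```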

View File

@ -5637,17 +5637,20 @@ def selective_scan(input: Tensor,
A: Tensor,
BC: Tensor,
D: Tensor,
z: Tensor,
host_request_types: Tensor,
last_token_ids: Tensor,
dim: int,
dstate: int,
dt_rank: int,
is_variable_B: bool,
is_variable_C: bool,
delta_softplus: bool,
dtype: str,
slot_mapping: Optional[Tensor] = None):
z: Optional[Tensor] = None,
host_context_lengths: Optional[Tensor] = None,
slot_mapping: Optional[Tensor] = None,
nheads: int = 1,
ngroups: int = 1,
chunk_size: int = 256,
mamba_version: str = 'Mamba1'):
'''
Parameters:
input : Tensor (On GPU)
@ -5658,27 +5661,34 @@ def selective_scan(input: Tensor,
Or the CPU tensor of shape [1] for the pointer of paged states.
delta : Tensor (On GPU)
The delta tensor. Its shape is [batch_size, seq_len, dim] or [num_tokens, dim] for remove_input_padding
The delta tensor.
mamba: Its shape is [batch_size, seq_len, dim] or [num_tokens, dim] for remove_input_padding
mamba2: Its shape is [batch_size, seq_len, nheads] or [num_tokens, nheads] for remove_input_padding
delta_bias : Tensor (On GPU)
The delta bias tensor. Its shape is [dim]
The delta bias tensor.
mamba: Its shape is [dim]
mamba2: Its shape is [nheads]
A : Tensor (On GPU)
A matrix. Its shape is [dstate, dim]
A matrix.
mamba: Its shape is [dstate, dim]
mamba2: Its shape is [nheads]
BC : Tensor (On GPU)
B matrix. Its shape is [batch_size, seq_len, dstate * 2] or [num_tokens, dstate * 2] for remove_input_padding
B and C matrix.
mamba: Its shape is [batch_size, seq_len, dstate * 2] or [num_tokens, dstate * 2] for remove_input_padding
mamba2: Its shape is [batch_size, seq_len, ngroups * dstate * 2] or [num_tokens, ngroups * dstate * 2] for remove_input_padding
D : Tensor (On GPU)
D matrix. Its shape is [dim]
z : Tensor (On GPU)
The z tensor. Its shape is [batch_size, seq_len, dim] or [num_tokens, dim] for remove_input_padding
D matrix.
mamba: Its shape is [dim]
mamba2: Its shape is [nheads]
host_request_types : Tensor (On CPU)
The tensor on the host that indicates if a request is in context or
generation phase. Its shape is [batch_size]. See Inflight Batching
in docs/gpt_attention.md,
in docs/gpt_attention.md
last_token_ids : Tensor (On GPU)
The inclusive prefix-sum of the lengths or the lengths of the
@ -5693,22 +5703,32 @@ def selective_scan(input: Tensor,
dt_rank: int
The rank dimension of dt_proj
is_variable_B : bool
Is the matrix B a variable? Set to 'True' if B is a dynamic matrix
during inference, 'False' otherwise
is_variable_C : bool
Is the matrix C a variable? Set to 'True' if C is a dynamic matrix
during inference, 'False' otherwise
delta_softplus : bool
Do we apply softplus to the delta.
dtype: str
data type
z : Tensor (On GPU) (Optional)
The z tensor. Its shape is [batch_size, seq_len, dim] or [num_tokens, dim] for remove_input_padding
host_context_lengths: Tensor (On CPU) (Optional)
A host tensor that contains the lengths of the different inputs,
slot_mapping: Tensor (On GPU) (Optional)
Real page index in state. Its shape is [dim], used for paged state, each page shape is [dstate, dim]
nheads: int (Optional)
The number of heads.
ngroups: int (Optional)
The number of groups.
chunk_size: int (Optional)
The chunk_size is used for the chunk_scan kernel.
        mamba_version: str (Optional)
            The Mamba version to use; defaults to 'Mamba1'.
'''
assert host_request_types is not None
selective_scan_plg_creator = trt.get_plugin_registry().get_plugin_creator(
@ -5721,12 +5741,13 @@ def selective_scan(input: Tensor,
trt.PluginFieldType.INT32)
dt_rank = trt.PluginField("dt_rank", np.array(dt_rank, dtype=np.int32),
trt.PluginFieldType.INT32)
is_variable_B = trt.PluginField(
"is_variable_B", np.array(np.int8(is_variable_B), dtype=np.int8),
trt.PluginFieldType.INT8)
is_variable_C = trt.PluginField(
"is_variable_C", np.array(np.int8(is_variable_C), dtype=np.int8),
trt.PluginFieldType.INT8)
nheads = trt.PluginField("nheads", np.array(nheads, dtype=np.int32),
trt.PluginFieldType.INT32)
ngroups = trt.PluginField("ngroups", np.array(ngroups, dtype=np.int32),
trt.PluginFieldType.INT32)
chunk_size = trt.PluginField("chunk_size",
np.array(chunk_size, dtype=np.int32),
trt.PluginFieldType.INT32)
delta_softplus = trt.PluginField(
"delta_softplus", np.array(np.int8(delta_softplus), dtype=np.int8),
trt.PluginFieldType.INT8)
@ -5741,20 +5762,34 @@ def selective_scan(input: Tensor,
"paged_state",
np.array(np.int8(default_net().plugin_config.paged_state),
dtype=np.int8), trt.PluginFieldType.INT8)
if z is None:
z_enabled = trt.PluginField("z_enabled", np.array(0, dtype=np.int8),
trt.PluginFieldType.INT8)
else:
z_enabled = trt.PluginField("z_enabled", np.array(1, dtype=np.int8),
trt.PluginFieldType.INT8)
is_mamba2 = trt.PluginField(
"is_mamba2",
np.array(1 if mamba_version == 'Mamba2' else 0, dtype=np.int8),
trt.PluginFieldType.INT8)
pfc = trt.PluginFieldCollection([
dim, dstate, dt_rank, is_variable_B, is_variable_C, delta_softplus,
pf_type, remove_input_padding, paged_state
dim, dstate, dt_rank, nheads, ngroups, chunk_size, delta_softplus,
pf_type, remove_input_padding, paged_state, z_enabled, is_mamba2
])
selective_scan_plug = selective_scan_plg_creator.create_plugin(
"selective_scan", pfc)
plug_inputs = [
input, state_or_ptr, delta, delta_bias, A, BC, D, z, host_request_types,
input, state_or_ptr, delta, delta_bias, A, BC, D, host_request_types,
last_token_ids
]
if default_net().plugin_config.remove_input_padding:
plug_inputs += [host_context_lengths]
if default_net().plugin_config.paged_state:
plug_inputs += [slot_mapping]
if z is not None:
plug_inputs += [z]
plug_inputs = [i.trt_tensor for i in plug_inputs]
layer = default_trtnet().add_plugin_v2(plug_inputs, selective_scan_plug)
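To make the new Mamba2 parameter shapes in the docstring concrete, the table below restates them for a small example (batch, head, and group sizes are illustrative):

```python
batch_size, seq_len = 2, 16
nheads, ngroups, dstate = 8, 1, 128   # illustrative Mamba2 sizes

mamba2_shapes = {
    "delta":      (batch_size, seq_len, nheads),                # per-head dt
    "delta_bias": (nheads,),
    "A":          (nheads,),                                    # one scalar decay per head
    "BC":         (batch_size, seq_len, ngroups * dstate * 2),  # B and C packed together
    "D":          (nheads,),
}
print(mamba2_shapes)
```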

View File

@ -27,32 +27,59 @@ def get_build_cache_config_from_env() -> tuple[bool, str]:
return build_cache_enabled, build_cache_root
class BuildCache:
'''
The BuildCache class is a class that manages the intermediate products from the build steps.
class BuildCacheConfig:
"""
Configuration for the build cache.
NOTE: currently, only engine-building is supported
TODO[chunweiy]: add support for other build steps, such as quantization, convert_checkpoint, etc.
'''
# The version of the cache, will be used to determine if the cache is compatible
CACHE_VERSION = 0
Attributes:
        cache_root (Path): The root directory for the build cache.
max_records (int): The maximum number of records to store in the cache.
max_cache_storage_gb (float): The maximum amount of storage (in GB) to use for the cache.
"""
def __init__(self,
cache_root: Optional[Path] = None,
max_records: int = 10,
max_cache_storage_gb: int = 256):
'''
Args:
cache_root (Path): The root directory of the cache
max_records (int): The maximum number of records to keep in the cache
max_cache_storage_gb (int): The maximum storage size of the cache
'''
_, default_cache_root = get_build_cache_config_from_env()
self.cache_root = cache_root or Path(default_cache_root)
self.max_records = max_records
self.max_cache_storage_gb = max_cache_storage_gb
max_cache_storage_gb: float = 256):
self._cache_root = cache_root
self._max_records = max_records
self._max_cache_storage_gb = max_cache_storage_gb
if max_records < 1:
@property
def cache_root(self) -> Path:
_build_cache_enabled, _build_cache_root = get_build_cache_config_from_env(
)
return self._cache_root or Path(_build_cache_root)
@property
def max_records(self) -> int:
return self._max_records
@property
def max_cache_storage_gb(self) -> float:
return self._max_cache_storage_gb
class BuildCache:
"""
The BuildCache class is a class that manages the intermediate products from the build steps.
NOTE: currently, only engine-building is supported
TODO[chunweiy]: add support for other build steps, such as quantization, convert_checkpoint, etc.
"""
# The version of the cache, will be used to determine if the cache is compatible
CACHE_VERSION = 0
def __init__(self, config: Optional[BuildCacheConfig] = None):
_, default_cache_root = get_build_cache_config_from_env()
config = config or BuildCacheConfig()
self.cache_root = config.cache_root or Path(default_cache_root)
self.max_records = config.max_records
self.max_cache_storage_gb = config.max_cache_storage_gb
if config.max_records < 1:
raise ValueError("max_records should be greater than 0")
def get_engine_building_cache_stage(self,

Some files were not shown because too many files have changed in this diff.