Update TensorRT-LLM (#1954)

* Update TensorRT-LLM

---------

Co-authored-by: Altair-Alpha <62340011+Altair-Alpha@users.noreply.github.com>
Kaiyu Xie 2024-07-16 15:30:25 +08:00 committed by GitHub
parent a96cccafcf
commit 2d234357c6
207 changed files with 7410 additions and 1334 deletions

.gitignore
View File

@ -40,3 +40,11 @@ tensorrt_llm/bindings/*.pyi
# Testing
.coverage.*
results_trt/
# build/debug
*.safetensors
*/tllm_debug/**
*.patch
# Generated files
cpp/include/tensorrt_llm/executor/version.h

View File

@ -177,13 +177,6 @@ def parse_arguments():
'If this option is specified, it will override the max decoder input len of TRT engines to the specified value instead of using pre-defined one'
'By default when this option is not used, it will use pre-defined max decoder input len'
))
parser.add_argument(
'--max_output_len',
type=int,
default=None,
help=
('If this option is specified, it will override the max output len of '
'TRT engines to the specified value instead of using pre-defined one'))
parser.add_argument(
'--max_seq_len',
'--max_decoder_seq_len',
@ -360,21 +353,6 @@ def main(args):
rank = tensorrt_llm.mpi_rank()
world_size = tensorrt_llm.mpi_world_size()
if args.max_output_len:
logger.warning(
'--max_output_len has been deprecated in favor of --max_seq_len')
if args.max_input_len:
if args.max_seq_len:
logger.warning(
'--max_seq_len has been overwritten due to --max_output_len being specified'
)
args.max_seq_len = args.max_input_len + args.max_output_len
else:
raise Exception(
f"--max_output_len is specified but not --max_input_len")
del args.max_output_len
# TODO: Re-enable memory monitor for multi-gpu benchmarks.
# Current Mem Monitor will cause benchmark script hang
# because MPI does not work well with multiprocessing.

View File

@ -129,13 +129,6 @@ def parse_arguments():
help=
('If this option is specified, it will override the max input len of '
'TRT engines to the specified value instead of using pre-defined one'))
parser.add_argument(
'--max_output_len',
type=int,
default=None,
help=
('If this option is specified, it will override the max output len of '
'TRT engines to the specified value instead of using pre-defined one'))
parser.add_argument(
'--max_seq_len',
'--max_decoder_seq_len',

View File

@ -185,7 +185,7 @@ When the benchmark runs successfully, you will see a report out of the run simil
[RANK 0] Completed request submission.
[RANK 0] Calculating results.
[RANK 0] Reporting...
[RANK 0] JSON: {'benchmark_cmd': '', 'binary': '', 'build_cmd': 'trtllm-build --output_dir /tmp/meta-llama/llama-2-7b-hf --model_config /tmp/generated_config.json --workers 1 --max_batch_size 1024 --max_input_len 128 --max_output_len 128 --max_num_tokens 8000 --context_fmha enable --gpt_attention_plugin float16 --paged_kv_cache enable --multiple_profiles enable --gemm_plugin float16', 'first_token_latency': 0.0, 'inflight_batching': True, 'kv_mem_fraction': 0.98, 'latency_units': 'ms', 'max_batch_size': 1024, 'max_tokens': 8000, 'model': 'meta-llama/Llama-2-7b-hf', 'peak_gpu_mem_units': 'GB', 'peak_gpu_mem': 0.0, 'scheduler': 'Max Utilization', 'throughput_units': 'tokens/second', 'throughput': 17634.422523488243, 'time_per_output_token': 0.0, 'total_input_tokens': 128000, 'total_latency': 7.258530855178833, 'total_output_tokens': 128000}
[RANK 0] JSON: {'benchmark_cmd': '', 'binary': '', 'build_cmd': 'trtllm-build --output_dir /tmp/meta-llama/llama-2-7b-hf --model_config /tmp/generated_config.json --workers 1 --max_batch_size 1024 --max_input_len 128 --max_seq_len 256 --max_num_tokens 8000 --context_fmha enable --gpt_attention_plugin float16 --paged_kv_cache enable --multiple_profiles enable --gemm_plugin float16', 'first_token_latency': 0.0, 'inflight_batching': True, 'kv_mem_fraction': 0.98, 'latency_units': 'ms', 'max_batch_size': 1024, 'max_tokens': 8000, 'model': 'meta-llama/Llama-2-7b-hf', 'peak_gpu_mem_units': 'GB', 'peak_gpu_mem': 0.0, 'scheduler': 'Max Utilization', 'throughput_units': 'tokens/second', 'throughput': 17634.422523488243, 'time_per_output_token': 0.0, 'total_input_tokens': 128000, 'total_latency': 7.258530855178833, 'total_output_tokens': 128000}
===========================================================
= METADATA
===========================================================

View File

@ -128,6 +128,27 @@ if(INDEX_RANGE_CHECK)
message(WARNING "Check index range to detect OOB accesses")
endif()
# Read the project version
set(TRTLLM_VERSION_DIR ${PROJECT_SOURCE_DIR}/../tensorrt_llm)
set_directory_properties(PROPERTIES CMAKE_CONFIGURE_DEPENDS
${TRTLLM_VERSION_DIR}/version.py)
execute_process(
COMMAND python3 -c "import version; print(version.__version__)"
WORKING_DIRECTORY ${TRTLLM_VERSION_DIR}
OUTPUT_VARIABLE TRTLLM_VERSION
RESULT_VARIABLE TRTLLM_VERSION_RESULT
OUTPUT_STRIP_TRAILING_WHITESPACE)
if(TRTLLM_VERSION_RESULT EQUAL 0)
message(STATUS "TensorRT-LLM version: ${TRTLLM_VERSION}")
else()
message(FATAL_ERROR "Failed to determine TensorRT-LLM version")
endif()
configure_file(
cmake/templates/version.h
${CMAKE_CURRENT_SOURCE_DIR}/include/tensorrt_llm/executor/version.h)
# Determine CUDA version before enabling the language extension
check_language(CUDA)
if(CMAKE_CUDA_COMPILER)
@ -139,7 +160,7 @@ if(CMAKE_CUDA_COMPILER)
"${CMAKE_CUDA_COMPILER} --version | egrep -o 'V[0-9]+.[0-9]+.[0-9]+' | cut -c2-"
RESULT_VARIABLE _BASH_SUCCESS
OUTPUT_VARIABLE CMAKE_CUDA_COMPILER_VERSION
OUTPUT_STRIP_TRAILING_WHITESPACE)
OUTPUT_STRIP_TRAILING_WHITESPACE)
if(NOT _BASH_SUCCESS EQUAL 0)
message(FATAL_ERROR "Failed to determine CUDA version")

View File

@ -0,0 +1,24 @@
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
// THIS FILE IS AUTO GENERATED FROM cmake/templates/version.h. DO NOT EDIT.
namespace tensorrt_llm::executor
{
static auto constexpr kTensorRtLlmVersion = "@TRTLLM_VERSION@";
}
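
Once the configure_file step above substitutes @TRTLLM_VERSION@, client code can read the version as a compile-time constant. A minimal sketch, assuming the generated include/tensorrt_llm/executor/version.h is on the include path:

#include "tensorrt_llm/executor/version.h"
#include <cstdio>

int main()
{
    // kTensorRtLlmVersion holds the substituted version string (whatever tensorrt_llm/version.py reports)
    std::printf("built against TensorRT-LLM %s\n", tensorrt_llm::executor::kTensorRtLlmVersion);
    return 0;
}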

View File

@ -122,6 +122,9 @@ private:
void decoupled_execution_loop();
std::shared_ptr<std::thread> worker_thread_;
std::shared_ptr<nvinfer1::ILogger> mLogger{};
inline static std::string const kPROFILE_START_STOP_ENV_VAR_NAME = "TLLM_PROFILE_START_STOP";
inline static std::string const kLEGACY_PROFILE_START_STOP_ENV_VAR_NAME = "TLLM_GPTM_PROFILE_START_STOP";
};
} // namespace tensorrt_llm::batch_manager

View File

@ -375,6 +375,11 @@ public:
return mSecondaryPool;
}
[[nodiscard]] SizeType32 getNumLayers() const
{
return mNumLayers;
}
//! \brief Get index in pool to K or V block.
//! \param blockId the blockId as returned by getBlockId()
//! \param fieldIdx either 0 (K) or 1 (V),
@ -592,6 +597,8 @@ public:
void removeToken(SizeType32 seqSlotIdx);
void rewindKVCache(SizeType32 seqSlotIdx, SizeType32 rewindLengths);
[[nodiscard]] GenerationRequest const& getSequence(SizeType32 seqSlotIdx) const;
[[nodiscard]] bool isCrossKv() const
{
return mCacheType == CacheType::kCROSS;
@ -634,4 +641,5 @@ private:
// KV cache type (self or cross)
CacheType mCacheType;
};
} // namespace tensorrt_llm::batch_manager::kv_cache_manager

View File

@ -0,0 +1,99 @@
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "tensorrt_llm/batch_manager/kvCacheManager.h"
namespace tensorrt_llm::batch_manager::kv_cache_manager
{
class BlockIterator
{
public:
using iterator_category = std::forward_iterator_tag;
using value_type = runtime::ITensor;
using pointer = runtime::ITensor::SharedPtr;
using reference = value_type&;
using SizeType32 = tensorrt_llm::runtime::SizeType32;
BlockIterator(runtime::ITensor::SharedPtr blockPoolPtr, std::vector<SizeType32> blockIds, size_t idx)
: mPool{std::move(blockPoolPtr)}
, mBlockIds{std::move(blockIds)}
, mIdx{idx}
{
TLLM_CHECK(mPool);
TLLM_CHECK(mIdx <= mBlockIds.size());
update();
}
[[nodiscard]] pointer operator->()
{
return mCurrent;
}
[[nodiscard]] reference operator*()
{
return *mCurrent;
}
BlockIterator& operator++()
{
mIdx++;
update();
return *this;
}
BlockIterator operator++(int)
{
auto ret = *this;
ret.update();
mIdx++;
return ret;
}
[[nodiscard]] bool operator==(BlockIterator const& other) const
{
return mIdx == other.mIdx && mPool.get() == other.mPool.get();
}
[[nodiscard]] bool operator!=(BlockIterator const& other) const
{
return !(*this == other);
}
private:
void update()
{
if (mIdx < mBlockIds.size())
{
mCurrent = runtime::ITensor::slice(mPool, mBlockIds.at(mIdx), 1);
}
}
runtime::ITensor::SharedPtr mPool;
runtime::ITensor::SharedPtr mCurrent;
const std::vector<SizeType32> mBlockIds;
size_t mIdx;
};
[[nodiscard]] BlockIterator getBlockBeginIt(
KVCacheManager const& cacheManager, LlmRequest const& request, SizeType32 beam);
[[nodiscard]] BlockIterator getBlockEndIt(
KVCacheManager const& cacheManager, LlmRequest const& request, SizeType32 beam);
} // namespace tensorrt_llm::batch_manager::kv_cache_manager
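
For context, a minimal sketch of how the iterator pair might be consumed, assuming the header above is included (its file name is not shown in this dump) and that LlmRequest lives in the enclosing tensorrt_llm::batch_manager namespace:

#include <cstddef>

namespace tbm = tensorrt_llm::batch_manager;
namespace kvcm = tbm::kv_cache_manager;

// Hypothetical helper: count the KV-cache blocks backing one beam of a request.
std::size_t countBlocks(kvcm::KVCacheManager const& cacheManager, tbm::LlmRequest const& request)
{
    std::size_t numBlocks = 0;
    for (auto it = kvcm::getBlockBeginIt(cacheManager, request, /*beam=*/0);
         it != kvcm::getBlockEndIt(cacheManager, request, /*beam=*/0); ++it)
    {
        ++numBlocks; // *it dereferences to an ITensor view of a single pool block
    }
    return numBlocks;
}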

View File

@ -16,6 +16,7 @@
#pragma once
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/executor/executor.h"
#include "tensorrt_llm/runtime/bufferManager.h"
@ -62,7 +63,8 @@ public:
using VecLogProbs = std::vector<float>;
using BeamTokens = std::vector<VecTokens>;
using TensorPtr = TTensor;
using LogitsPostProcessor = std::function<void(RequestIdType, TensorPtr&, BeamTokens const&, TStream const&)>;
using LogitsPostProcessor = std::function<void(
RequestIdType, TensorPtr&, BeamTokens const&, TStream const&, std::optional<RequestIdType>)>;
GenericLlmRequest(RequestIdType requestId, SizeType32 maxNewTokens, std::shared_ptr<VecTokens> inputTokens,
runtime::SamplingConfig const& samplingConfig, bool isStreaming, std::optional<SizeType32> endId = std::nullopt,
@ -77,19 +79,21 @@ public:
std::optional<TensorPtr> draftLogits = std::nullopt, bool excludeInputFromOutput = false,
std::optional<LogitsPostProcessor> logitsPostProcessor = std::nullopt,
bool applyLogitsPostProcessorBatched = false,
std::optional<std::shared_ptr<VecTokens>> encoderInputTokens = std::nullopt, bool returnEncoderOutput = false)
std::optional<std::shared_ptr<VecTokens>> encoderInputTokens = std::nullopt, bool returnEncoderOutput = false,
std::optional<RequestIdType> clientId = std::nullopt)
: mRequestId(requestId)
, mPromptLen(inputTokens->size())
, mMaxNewTokens(maxNewTokens)
, mSamplingConfig(samplingConfig)
, mState(REQUEST_STATE_CONTEXT_INIT)
, mIsStreaming(isStreaming)
, mEndId(endId)
, mPadId(padId)
, mLogitsPostProcessor(logitsPostProcessor)
, mApplyLogitsPostProcessorBatched(applyLogitsPostProcessorBatched)
, mClientId(clientId)
, mIsStreaming(isStreaming)
, mOrigPromptLen(mPromptLen)
, mMaxSentTokenPos(mPromptLen - 1)
, mMaxSentTokenLen(mPromptLen)
, mEmbeddingBias(std::move(embeddingBias))
, mBadWordsList(std::move(badWordsList))
, mStopWordsList(std::move(stopWordsList))
@ -105,6 +109,7 @@ public:
, mDraftTokens(draftTokens.value_or(std::make_shared<VecTokens>()))
, mDraftLogits(draftLogits)
, mNumTokensPerIteration(1)
, mReturnAllGeneratedTokens(isStreaming && (samplingConfig.beamWidth > 1))
, mReturnContextLogits(returnContextLogits)
, mReturnGenerationLogits(returnGenerationLogits)
, mExcludeInputFromOutput(excludeInputFromOutput)
@ -125,11 +130,12 @@ public:
, mMaxNewTokens(req.getMaxNewTokens())
, mSamplingConfig(req.getSamplingConfig(), req.getExternalDraftTokensConfig())
, mState(REQUEST_STATE_CONTEXT_INIT)
, mIsStreaming(req.getStreaming())
, mEndId(req.getEndId())
, mPadId(req.getPadId())
, mClientId(req.getClientId())
, mIsStreaming(req.getStreaming())
, mOrigPromptLen(mPromptLen)
, mMaxSentTokenPos(mPromptLen - 1)
, mMaxSentTokenLen(mPromptLen)
, mEmbeddingBias(std::nullopt)
, mBadWordsList(std::nullopt)
, mStopWordsList(std::nullopt)
@ -145,6 +151,7 @@ public:
, mDraftTokens(std::make_shared<VecTokens>())
, mDraftLogits(std::nullopt)
, mNumTokensPerIteration(1)
, mReturnAllGeneratedTokens(req.getReturnAllGeneratedTokens())
, mReturnContextLogits(req.getOutputConfig().returnContextLogits)
, mReturnGenerationLogits(req.getOutputConfig().returnGenerationLogits)
, mExcludeInputFromOutput(req.getOutputConfig().excludeInputFromOutput)
@ -152,6 +159,16 @@ public:
, mReturnEncoderOutput(req.getOutputConfig().returnEncoderOutput)
, mDecodingIter(0)
{
if (mIsStreaming && mSamplingConfig.beamWidth > 1 && mReturnAllGeneratedTokens == false)
{
TLLM_LOG_WARNING(
"Setting mReturnAllGeneratedTokens to True since streaming AND beam search are done simultaneously. "
"Returning the full beams at each streaming step is needed because beam search + streaming can change "
"previous outputs. Initialize request with mReturnAllGeneratedTokens = True to dismiss this error. "
"WARNING: using this option may increase network usage significantly (quadratically w.r.t output "
"length).");
mReturnAllGeneratedTokens = true;
}
if (req.getEncoderInputTokenIds())
{
mState = REQUEST_STATE_ENCODER_INIT;
@ -440,20 +457,20 @@ public:
mSeqSlot.reset();
}
/// @brief Get the maximum position of the tokens returned to the client. Use to ensure we don't return to
/// client duplicated token positions.
/// @return The maximum position of the tokens sent to the client
[[nodiscard]] SizeType32 getMaxSentTokenPos() const
/// @brief Get the maximum length of tokens returned to the client. Use to ensure we don't return to
/// client duplicated tokens.
/// @return The maximum length of the tokens sent to the client.
[[nodiscard]] SizeType32 getMaxSentTokenLen() const
{
return mMaxSentTokenPos;
return mMaxSentTokenLen;
}
/// @brief Sets the maximum position of the tokens returned to the client. Use to ensure we don't return to
/// client duplicated token positions.
/// @param pos The maximum position
void setMaxSentTokenPos(SizeType32 pos)
/// @brief Sets the maximum length of tokens returned to the client. Use to ensure we don't return to
/// client duplicated tokens.
/// @param maxSentLength The new maximum length.
void setMaxSentTokenLen(SizeType32 maxSentLength)
{
mMaxSentTokenPos = pos;
mMaxSentTokenLen = maxSentLength;
}
[[nodiscard]] std::optional<TensorPtr> getPromptEmbeddingTable() const
@ -599,7 +616,7 @@ public:
void setNumTokensPerIteration(SizeType32 numTokensPerIteration)
{
mNumTokensPerIteration = numTokensPerIteration;
mNumTokensPerIteration = std::max(1, numTokensPerIteration);
}
[[nodiscard]] SizeType32 getNumTokensPerIteration() const
@ -669,6 +686,23 @@ public:
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
[[nodiscard]] bool constexpr isStreaming() const noexcept
{
return mIsStreaming;
}
void constexpr setStreaming(bool isStreaming) noexcept
{
mIsStreaming = isStreaming;
}
void setReturnAllGeneratedTokens(bool const returnAllGeneratedTokens)
{
TLLM_CHECK_WITH_INFO(!mIsStreaming || mSamplingConfig.beamWidth == 1 || returnAllGeneratedTokens,
"returnAllGeneratedTokens must be true if streaming AND beam search are used.");
mReturnAllGeneratedTokens = returnAllGeneratedTokens;
}
void setReturnContextLogits(bool const returnContextLogits)
{
mReturnContextLogits = returnContextLogits;
@ -739,7 +773,7 @@ public:
return mGenerationLogitsFragments;
}
void addGenerationFragments(TensorPtr& genLogits)
void addGenerationLogitsFragment(TensorPtr& genLogits)
{
mGenerationLogitsFragments.push_back(genLogits);
}
@ -876,21 +910,27 @@ public:
executor::Result result;
result.isFinal = isGenerationCompleteState();
auto nbBeams = mSamplingConfig.beamWidth;
auto maxNbTokens = getMaxBeamNumTokens();
// FIXME(nkorobov): For streaming we do not allow beam search and
// streaming index calculation here applies only for sampling
// getNumTokensPerIteration takes accepted draft tokens into account
int nbTokensOut = mIsStreaming ? std::max(getNumTokensPerIteration(), 1) : maxNbTokens;
if (mExcludeInputFromOutput && !mIsStreaming)
auto const nbBeams = mSamplingConfig.beamWidth;
auto const maxNbTokens = getMaxBeamNumTokens();
auto const calculateNbTokensOut = [this](SizeType32 maxNbTokens)
{
nbTokensOut -= getOrigPromptLen();
}
if (!mIsStreaming)
{
return maxNbTokens - (mExcludeInputFromOutput ? getOrigPromptLen() : 0);
}
return mReturnAllGeneratedTokens ? maxNbTokens - getOrigPromptLen()
: maxNbTokens - getMaxSentTokenLen();
};
auto const maxNbTokensOut = calculateNbTokensOut(maxNbTokens);
result.outputTokenIds.resize(nbBeams);
SizeType32 tokenPos = maxNbTokens - nbTokensOut;
bool shouldSendResponse = isGenerationCompleteState() || (mIsStreaming && tokenPos > getMaxSentTokenPos());
auto const startTokenPos = maxNbTokens - maxNbTokensOut;
auto const shouldSendResponse
= isGenerationCompleteState() || (mIsStreaming && maxNbTokens > getMaxSentTokenLen());
if (!shouldSendResponse)
{
@ -900,24 +940,14 @@ public:
{
for (SizeType32 beam = 0; beam < nbBeams; ++beam)
{
auto tokens = getTokens(beam);
auto nbTokens = mIsStreaming ? (tokenPos - getMaxSentTokenPos()) : tokens.size();
auto const& tokens = getTokens(beam);
auto const nbTokensOut = calculateNbTokensOut(tokens.size());
// Take accepted draft tokens into account when streaming
auto const numAcceptedTokens = std::max(0, getNumTokensPerIteration() - 1);
nbTokens += mIsStreaming ? numAcceptedTokens : 0;
if (mExcludeInputFromOutput && !mIsStreaming)
if (nbTokensOut > 0)
{
nbTokens -= getOrigPromptLen();
auto const first = tokens.data() + startTokenPos;
result.outputTokenIds.at(beam).assign(first, first + nbTokensOut);
}
if (nbTokens > 0)
{
result.outputTokenIds.at(beam).assign(
tokens.data() + tokenPos, tokens.data() + tokenPos + nbTokens);
}
// Correct next token position by accepted draft tokens
tokenPos += numAcceptedTokens;
}
if (returnLogProbs())
@ -954,7 +984,7 @@ public:
}
// Update position of last sent response
mMaxSentTokenPos = tokenPos;
setMaxSentTokenLen(maxNbTokens);
auto response = executor::Response(mRequestId, std::move(result));
return response;
@ -972,21 +1002,25 @@ public:
// Tokens [beam_size, mPromptLen + getMaxNumGeneratedTokens()]
runtime::SamplingConfig mSamplingConfig;
LlmRequestState_t mState;
bool mIsStreaming;
std::optional<TokenIdType> mEndId;
std::optional<TokenIdType> mPadId;
std::optional<SizeType32> mSeqSlot;
std::optional<LogitsPostProcessor> mLogitsPostProcessor;
bool mApplyLogitsPostProcessorBatched;
std::optional<RequestIdType> mClientId;
// Position of mask token in GLM model inputs
SizeType32 mMaskPosition{0};
protected:
bool mIsStreaming;
BeamTokens mTokens;
SizeType32 mOrigPromptLen;
// Number of tokens already in KV cache before context phase.
// A value > 0 indicates cached KV cache blocks were reused.
// Up to inputLen - 1 tokens can be reused.
SizeType32 mPrepopulatedPromptLen{0};
SizeType32 mMaxSentTokenPos;
SizeType32 mMaxSentTokenLen;
std::optional<TensorPtr> mEmbeddingBias;
std::optional<TensorPtr> mBadWordsList;
@ -1011,6 +1045,8 @@ protected:
std::optional<TensorPtr> mDraftLogits;
SizeType32 mNumTokensPerIteration;
// whether to return the full beams on each iteration. True when doing streaming + beamsearch
bool mReturnAllGeneratedTokens;
// Save logits
bool mReturnContextLogits;
bool mReturnGenerationLogits;
@ -1108,13 +1144,14 @@ public:
std::optional<TensorPtr> draftLogits = std::nullopt, bool excludeInputFromOutput = false,
std::optional<LogitsPostProcessor> logitsPostProcessor = std::nullopt,
bool applyLogitsPostProcessorBatched = false,
std::optional<std::shared_ptr<VecTokens>> encoderInputTokens = std::nullopt, bool returnEncoderOutput = false)
std::optional<std::shared_ptr<VecTokens>> encoderInputTokens = std::nullopt, bool returnEncoderOutput = false,
std::optional<RequestIdType> clientId = std::nullopt)
: Base(requestId, maxNewTokens, std::move(inputTokens), samplingConfig, isStreaming, endId, padId,
std::move(embeddingBias), std::move(badWordsList), std::move(stopWordsList),
std::move(promptEmbeddingTable), promptVocabSize, loraTaskId, std::move(loraWeights), std::move(loraConfig),
returnLogProbs, returnContextLogits, returnGenerationLogits, std::move(draftTokens), std::move(draftLogits),
excludeInputFromOutput, std::move(logitsPostProcessor), applyLogitsPostProcessorBatched,
std::move(encoderInputTokens), returnEncoderOutput)
std::move(encoderInputTokens), returnEncoderOutput, clientId)
{
}
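
The reworked createResponse above replaces per-position bookkeeping (mMaxSentTokenPos) with a sent-token length (mMaxSentTokenLen). A standalone sketch of the same counting rule, not the class method itself:

#include <cstdint>

// Mirrors calculateNbTokensOut: how many tokens one beam of the response carries.
int32_t nbTokensOut(int32_t maxNbTokens, int32_t origPromptLen, int32_t maxSentTokenLen,
    bool isStreaming, bool excludeInputFromOutput, bool returnAllGeneratedTokens)
{
    if (!isStreaming)
    {
        return maxNbTokens - (excludeInputFromOutput ? origPromptLen : 0);
    }
    // Streaming: either resend everything generated so far (needed for beam search) or only the new delta.
    return returnAllGeneratedTokens ? maxNbTokens - origPromptLen : maxNbTokens - maxSentTokenLen;
}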

View File

@ -0,0 +1,31 @@
/*
* Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <cstdint>
#include <optional>
#include <string>
#include <unordered_set>
namespace tensorrt_llm::common
{
/// @brief Populate the start and end profiling iteration indexes from the provided environment variables
/// Try to set from envVarName first, and if that fails, try to set from legacyEnvVarName
/// Env variable values are expected to be in the format "1,2,3-5,6-8,9"
std::pair<std::unordered_set<int32_t>, std::unordered_set<int32_t>> populateIterationIndexes(
std::string const& envVarName, std::optional<std::string> const& legacyEnvVarName = std::nullopt);
} // namespace tensorrt_llm::common

View File

@ -24,6 +24,7 @@
#include <memory> // std::make_unique
#include <sstream> // std::stringstream
#include <string>
#include <unordered_set>
#include <vector>
namespace tensorrt_llm::common
@ -106,4 +107,7 @@ inline bool strStartsWith(std::string const& str, std::string const& prefix)
return str.rfind(prefix, 0) == 0;
}
/// @brief Split a string into a set of strings using a delimiter
std::unordered_set<std::string> str2set(std::string const& input, char delimiter);
} // namespace tensorrt_llm::common

View File

@ -16,8 +16,6 @@
#pragma once
#include "tensorrt_llm/common/stringUtils.h"
#include <array>
#include <cstddef>
#include <stdexcept>

View File

@ -37,6 +37,9 @@ class MpiComm;
namespace tensorrt_llm::executor
{
/// @brief Version of TRT-LLM
char const* version() noexcept;
class Model;
class Serialization;
@ -252,6 +255,8 @@ public:
/// @param logitsPostProcessorName The logits postprocessor name. Must correspond to one of the logits postprocessor
/// name provided to the ExecutorConfig.
/// @param encoderInputTokenIds The encoder input token ids for encoder-decoder models, or encoder-only models
/// @param returnAllGeneratedTokens Indicates whether to return the full beams or just the newly generated tokens
/// after every streaming step.
Request(VecTokens inputTokenIds, SizeType32 maxNewTokens, bool streaming = false,
SamplingConfig const& samplingConfig = SamplingConfig(), OutputConfig const& outputConfig = OutputConfig(),
std::optional<SizeType32> const& endId = std::nullopt, std::optional<SizeType32> const& padId = std::nullopt,
@ -262,7 +267,8 @@ public:
std::optional<PromptTuningConfig> pTuningConfig = std::nullopt,
std::optional<LoraConfig> loraConfig = std::nullopt,
std::optional<std::string> logitsPostProcessorName = std::nullopt,
std::optional<VecTokens> encoderInputTokenIds = std::nullopt);
std::optional<VecTokens> encoderInputTokenIds = std::nullopt, std::optional<IdType> clientId = std::nullopt,
bool returnAllGeneratedTokens = false);
/// @brief This logits postprocessor name will dispatch to the batched logits postprocessor
static auto constexpr kBatchedPostProcessorName = "batched";
@ -288,6 +294,8 @@ public:
[[nodiscard]] std::optional<LoraConfig> getLoraConfig() const;
[[nodiscard]] std::optional<std::string> getLogitsPostProcessorName() const;
[[nodiscard]] std::optional<VecTokens> getEncoderInputTokenIds() const;
[[nodiscard]] std::optional<IdType> getClientId() const;
[[nodiscard]] bool getReturnAllGeneratedTokens() const;
void setStreaming(bool streaming);
void setSamplingConfig(SamplingConfig const& config);
@ -302,6 +310,8 @@ public:
void setLoraConfig(LoraConfig const& loraConfig);
void setLogitsPostProcessorName(std::string const& logitsPostProcessorName);
void setEncoderInputTokenIds(VecTokens const& encoderInputTokenIds);
void setClientId(IdType clientId);
void setReturnAllGeneratedTokens(bool returnAllGeneratedTokens);
private:
friend class Serialization;
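
A minimal sketch of the new executor API surface: the free version() function plus the clientId and returnAllGeneratedTokens knobs on Request (prompt tokens and parameter values are illustrative):

#include "tensorrt_llm/executor/executor.h"
#include <iostream>

int main()
{
    namespace tle = tensorrt_llm::executor;
    std::cout << "TensorRT-LLM " << tle::version() << std::endl;

    tle::VecTokens promptTokens{1, 2, 3, 4};
    tle::Request request(promptTokens, /*maxNewTokens=*/32, /*streaming=*/true);
    request.setClientId(42);                   // new: optional client identifier attached to the request
    request.setReturnAllGeneratedTokens(true); // new: re-send the full beams on every streaming step
    return 0;
}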

View File

@ -53,10 +53,12 @@ using IterationType = std::uint64_t;
using RandomSeedType = std::uint64_t;
using VecLogProbs = std::vector<FloatType>;
using StreamPtr = std::shared_ptr<tensorrt_llm::runtime::CudaStream>;
using LogitsPostProcessor = std::function<void(IdType, Tensor&, BeamTokens const&, StreamPtr const&)>;
using LogitsPostProcessor
= std::function<void(IdType, Tensor&, BeamTokens const&, StreamPtr const&, std::optional<IdType>)>;
using LogitsPostProcessorMap = std::unordered_map<std::string, LogitsPostProcessor>;
using LogitsPostProcessorBatched = std::function<void(std::vector<IdType> const&, std::vector<Tensor>&,
std::vector<std::reference_wrapper<BeamTokens const>> const&, StreamPtr const&)>;
std::vector<std::reference_wrapper<BeamTokens const>> const&, StreamPtr const&,
std::vector<std::optional<IdType>> const&)>;
using MedusaChoices = std::vector<std::vector<SizeType32>>;
enum class DataType
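
Both callback aliases gain a trailing std::optional<IdType> argument (presumably the new clientId). A sketch of a callback matching the new per-request LogitsPostProcessor signature, assuming the executor types header shown above is included; the body is a no-op placeholder:

using namespace tensorrt_llm::executor;

LogitsPostProcessor myPostProcessor = [](IdType reqId, Tensor& logits, BeamTokens const& beamTokens,
                                          StreamPtr const& stream, std::optional<IdType> maybeClientId)
{
    // A real callback would adjust `logits` in place on `stream`; here we only silence unused warnings.
    (void) reqId; (void) logits; (void) beamTokens; (void) stream; (void) maybeClientId;
};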

View File

@ -41,30 +41,32 @@ public:
class Inputs
{
public:
//! [batchSize]
//! [maxBatchSize]
TensorPtr temperatures;
//! [batchSize]
//! [maxBatchSize]
TensorPtr positionIdsBase;
//! [batchSize] or [numGenSequences]
//! [maxBatchSize] or [numGenSequences]
TensorPtr generationLengths;
//! [batchSize]
//! [maxBatchSize]
TensorPtr randomDataSample;
//! [batchSize, maxNumPaths, maxPathDraftLen] or [numGenSequences, maxNumPaths, maxPathDraftLen]
//! [maxBatchSize, maxNumPaths, maxPathDraftLen] or [numGenSequences, maxNumPaths, maxPathDraftLen]
TensorPtr randomDataValidation;
//! [batchSize, maxNumPaths, maxPathLen] or [numGenSequences, maxNumPaths, maxPathLen]
//! [maxBatchSize, maxNumPaths, maxPathLen] or [numGenSequences, maxNumPaths, maxPathLen]
TensorPtr draftTokens;
//! [batchSize, maxNumPaths, maxPathLen] or [numGenSequences, maxNumPaths, maxPathLen]
//! [maxBatchSize, maxNumPaths, maxPathLen] or [numGenSequences, maxNumPaths, maxPathLen]
TensorPtr draftIndices;
//! [batchSize, maxNumPaths, maxPathDraftLen, vocabSize]
//! [maxBatchSize, maxNumPaths, maxPathDraftLen, vocabSize]
//! or [numGenSequences, maxNumPaths, maxPathDraftLen, vocabSize]
TensorPtr draftProbs;
//! [batchSize, maxDecodingTokens, ceil(maxDecodingTokens / 32)]
//! [maxBatchSize, maxDecodingTokens, ceil(maxDecodingTokens / 32)]
//! or [numGenSequences, maxDecodingTokens, ceil(maxDecodingTokens / 32)]
TensorPtr packedMasks;
//! [batchSize] or [numGenSequences]
//! [maxBatchSize] or [numGenSequences]
TensorPtr positionIds;
// [1], on pinned
TensorPtr maxGenLengthHost;
// [maxBatchSize]
TensorPtr generationLengthsHost;
void create(SizeType32 maxNumSequences, runtime::TllmRuntime const& runtime,
runtime::ModelConfig const& modelConfig, runtime::WorldConfig const& worldConfig);

View File

@ -40,10 +40,11 @@ public:
enum class ModelVariant : std::int32_t
{
kGpt = 0,
kGlm = 1, // https://github.com/THUDM/GLM and https://github.com/THUDM/ChatGLM-6B
kMamba = 2, // https://github.com/state-spaces/mamba
kRecurrentGemma = 3, // https://github.com/google-deepmind/recurrentgemma
kEncDec = 4,
kChatGlm = 1, // https://github.com/THUDM/ChatGLM-6B
kGlm = 2, // https://github.com/THUDM/GLM
kMamba = 3, // https://github.com/state-spaces/mamba
kRecurrentGemma = 4, // https://github.com/google-deepmind/recurrentgemma
kEncDec = 5,
};
struct RnnConfig
@ -526,7 +527,7 @@ public:
[[nodiscard]] bool constexpr isTransformerBased() const noexcept
{
return mModelVariant == ModelVariant::kGpt || mModelVariant == ModelVariant::kGlm
|| mModelVariant == ModelVariant::kRecurrentGemma;
|| mModelVariant == ModelVariant::kChatGlm || mModelVariant == ModelVariant::kRecurrentGemma;
}
[[nodiscard]] bool hasRnnConfig() const noexcept
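
The variant enum now separates ChatGLM from GLM and renumbers the remaining entries; isTransformerBased() treats both as transformer-based. A sketch, assuming the enum stays nested in runtime::ModelConfig as in the public runtime headers:

using tensorrt_llm::runtime::ModelConfig;

// kChatGlm and kGlm are distinct variants after this change; both count as transformer-based.
bool isGlmFamily(ModelConfig::ModelVariant variant)
{
    return variant == ModelConfig::ModelVariant::kChatGlm || variant == ModelConfig::ModelVariant::kGlm;
}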

View File

@ -106,6 +106,7 @@ int parseTacticToId(nlohmann::json tactic_config)
{K(128, 64), CutlassTileConfigSM90::CtaShape128x64x128B},
{K(128, 128), CutlassTileConfigSM90::CtaShape128x128x128B},
{K(128, 256), CutlassTileConfigSM90::CtaShape128x256x128B},
{K(256, 128), CutlassTileConfigSM90::CtaShape256x128x128B},
};
if (c.tile_config_sm90 != tile_map.at(K(tile_shape[0], tile_shape[1])))

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:5804fde474d6489db29204259b7e6c368117acadb7fb6dc807868ee0391c458b
size 3953206
oid sha256:f41188ef30e21d12ebcb92ee6546badb330f6c63a90fff535f3e613d61f103f9
size 4268820

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:85802a0e66148acb17d017a64dd982287775ce7bf5aa4e8bb7e5466b3736c7ee
size 4019734
oid sha256:510d90d67edcdbbe164493637772e50ef2f8d88d927f561c46512052aed7624c
size 4365768

View File

@ -1,3 +1,3 @@
00fb525bdf4ff217c16940540b2357c4 libtensorrt_llm_batch_manager_static.a
97d2db7f62745001d871bc89fb38eed6 libtensorrt_llm_batch_manager_static.pre_cxx11.a
d5f5542d2f1e10c4a6b60be56838ac79a9668665 commit
6aa84afae42ed1e725b8809b8f8a38fb libtensorrt_llm_batch_manager_static.a
cc359afb584b086456510f084f6617ed libtensorrt_llm_batch_manager_static.pre_cxx11.a
db055e58b6c6c8cf7350b66a583f9c388c4eac07 commit

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:33a724d7e9eabc358c0d674151d45cef8849ae702cc5f2f88b259299a8306574
size 3842582
oid sha256:920951af1730c7304fd1a7c286ddc8f96a17f918aaaf7815da385bf92c37e54c
size 4129858

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:490a93ff13a67949a30e279fc3df27456c7f5d4084158c3089befccf78118b7f
size 3799140
oid sha256:1d9b525a3855dd5a853604031efb306b08afd1ee425aba2a7846f7cd77f89ddb
size 4107114

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:663a163c3177644ed86fa7a2145fe5e9dbf6f2f0ed06c96d367236da323a3432
size 22523526
oid sha256:2a903a8cae43ec88d69fba666c3da1f301f1cb0aaf37256715da7363ee04a236
size 23909614

View File

@ -0,0 +1,84 @@
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/common/cudaProfilerUtils.h"
#include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/common/stringUtils.h"
#include <cstdint>
#include <optional>
namespace
{
std::tuple<std::unordered_set<int32_t>, std::unordered_set<int32_t>> populateIterationIndexesImpl(
std::string const& envVarName)
{
auto envVarVal = std::getenv(envVarName.c_str());
auto envVarValStr = std::string{envVarVal != nullptr ? envVarVal : ""};
auto values = tensorrt_llm::common::str2set(envVarValStr, ',');
std::unordered_set<int32_t> startSet;
std::unordered_set<int32_t> endSet;
for (std::string const& value : values)
{
size_t dashIdx = value.find("-");
if (dashIdx != std::string::npos)
{
int32_t start = std::stoi(value.substr(0, dashIdx));
startSet.insert(start);
int32_t end = std::stoi(value.substr(dashIdx + 1));
endSet.insert(end);
}
else
{
int32_t start_end = std::stoi(value);
startSet.insert(start_end);
endSet.insert(start_end);
}
}
return std::make_pair(startSet, endSet);
}
} // namespace
namespace tensorrt_llm::common
{
std::pair<std::unordered_set<int32_t>, std::unordered_set<int32_t>> populateIterationIndexes(
std::string const& envVarName, std::optional<std::string> const& legacyEnvVarName)
{
auto [profileIterIdxs, stopIterIdxs] = populateIterationIndexesImpl(envVarName);
// If empty, try to use legacy env var name
if (legacyEnvVarName && profileIterIdxs.empty() && stopIterIdxs.empty())
{
std::tie(profileIterIdxs, stopIterIdxs) = populateIterationIndexesImpl(legacyEnvVarName.value());
if (!profileIterIdxs.empty() || !stopIterIdxs.empty())
{
TLLM_LOG_WARNING(
"Using deprecated environment variable %s to specify cudaProfiler start and stop iterations. "
"Please "
"use %s "
"instead.",
legacyEnvVarName.value().c_str(), envVarName.c_str());
}
}
return std::make_pair(profileIterIdxs, stopIterIdxs);
}
} // namespace tensorrt_llm::common
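
The accepted format is a comma-separated list where "a-b" means start profiling at iteration a and stop at iteration b, and a single value acts as both start and stop. A small sketch, assuming the cudaProfilerUtils header above and a POSIX setenv:

#include "tensorrt_llm/common/cudaProfilerUtils.h"
#include <cstdlib>

void profileRangeExample()
{
    setenv("TLLM_PROFILE_START_STOP", "10,20-30", /*overwrite=*/1);
    auto [startIters, stopIters] = tensorrt_llm::common::populateIterationIndexes(
        "TLLM_PROFILE_START_STOP", "TLLM_GPTM_PROFILE_START_STOP");
    // startIters == {10, 20}, stopIters == {10, 30}
    (void) startIters;
    (void) stopIters;
}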

View File

@ -20,6 +20,7 @@
#include <cerrno>
#include <cstdarg>
#include <cstring>
#include <iostream>
#include <string>
namespace tensorrt_llm::common
@ -54,4 +55,22 @@ std::string fmtstr(char const* format, ...)
return result;
};
std::unordered_set<std::string> str2set(std::string const& input, char delimiter)
{
std::unordered_set<std::string> values;
if (!input.empty())
{
std::stringstream valStream(input);
std::string val;
while (std::getline(valStream, val, delimiter))
{
if (!val.empty())
{
values.insert(val);
}
}
}
return values;
};
} // namespace tensorrt_llm::common
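
A quick sketch of the expected behaviour of str2set (empty fields are skipped):

#include "tensorrt_llm/common/stringUtils.h"
#include <cassert>

void str2setExample()
{
    auto values = tensorrt_llm::common::str2set("10,20-30,", ',');
    assert(values.size() == 2); // "10" and "20-30"; the trailing empty field is dropped
    assert(values.count("10") == 1 && values.count("20-30") == 1);
}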

View File

@ -0,0 +1,42 @@
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <chrono>
#include <iomanip>
#include <sstream>
#include "tensorrt_llm/common/timestampUtils.h"
namespace tensorrt_llm::common
{
std::string getCurrentTimestamp()
{
auto now = std::chrono::system_clock::now();
auto now_t = std::chrono::system_clock::to_time_t(now);
auto tm = *std::localtime(&now_t);
auto epoch_to_now = now.time_since_epoch();
auto seconds = std::chrono::duration_cast<std::chrono::seconds>(epoch_to_now);
auto us = std::chrono::duration_cast<std::chrono::microseconds>(epoch_to_now - seconds);
std::ostringstream stream;
stream << std::put_time(&tm, "%m-%d-%Y %H:%M:%S");
stream << "." << std::setfill('0') << std::setw(6) << us.count();
return stream.str();
}
} // namespace tensorrt_llm::common
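
For reference, the helper formats local time with microsecond precision, e.g. "07-16-2024 15:30:25.123456" (illustrative value):

#include "tensorrt_llm/common/timestampUtils.h"
#include <iostream>

int main()
{
    std::cout << tensorrt_llm::common::getCurrentTimestamp() << std::endl;
    return 0;
}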

View File

@ -0,0 +1,25 @@
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <string>
namespace tensorrt_llm::common
{
/// @brief Get the current timestamp in the format "MM-DD-YYYY HH:MM:SS.uuuuuu"
std::string getCurrentTimestamp();
} // namespace tensorrt_llm::common

View File

@ -15,6 +15,7 @@
*/
#include "tensorrt_llm/common/tllmException.h"
#include "tensorrt_llm/common/stringUtils.h"
#include <cstdlib>
#if !defined(_MSC_VER)

View File

@ -93,6 +93,8 @@ enum class CutlassTileConfigSM90
CtaShape128x128x128B,
CtaShape128x256x128B,
// CTA configs for M=128
CtaShape256x128x128B,
};
enum class MainloopScheduleType

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:497b00031131c1dc705e848e52f3d43148f55505e37bdad97f4933b2c074469d
size 1400502
oid sha256:af8889214b82f8e65a226b6558dbdef474552850b50f07df76cbc24aeac94d6c
size 1410084

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:417978bdb5c19f97d9758475acacfa18a4038fc3c5a83f981b02ee220104e0c7
size 1425792
oid sha256:09b641ce17db25301b7c4e9049bc11dc105f749be0742b91986bf47601c1bbc7
size 1437532

View File

@ -1,3 +1,3 @@
1df55ac2948ca7b7fe2d5e79934e660e libtensorrt_llm_executor_static.a
ea1641928d184d117deec0696763b274 libtensorrt_llm_executor_static.pre_cxx11.a
d5f5542d2f1e10c4a6b60be56838ac79a9668665 commit
56835da1fe4f3edc7f28bac50b253b2d libtensorrt_llm_executor_static.a
4f1d05581c8c8d4663500c5300446778 libtensorrt_llm_executor_static.pre_cxx11.a
db055e58b6c6c8cf7350b66a583f9c388c4eac07 commit

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d0441d473852d11f50bcf23f4934b38d7e4c6d4a42f057eb04beb8aea4211cac
size 1451118
oid sha256:cb73df78859b9bf2d425a4c307403863b9de62820ff3b7e0ff2bbe6ac9f35894
size 1459664

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:dc8619f99cf5a2e04bdb1482f157a9852bd745e90cf9e03a7878f73ed07e5610
size 1383936
oid sha256:c7bf468b3c45d0c8e605ada27e16edeaf7b22928883d28eca4e8f9b568a01eff
size 1391962

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:772d1b83e739b926729b99999fbb81768569ffb172c2e120665b2d31b987bb47
size 14071986
oid sha256:6491a8b88087cb0be7af82f9523dae800f5d217730941441f1804e5ccd4770b5
size 14289284

View File

@ -29,38 +29,38 @@ namespace tensorrt_llm
namespace kernels
{
typedef void (*BmmChunkKernelFuncFp16)(int B_, int L_, int G_, int N_,
typedef void (*BmmChunkKernelFuncFp16)(int B_, int L_, int H_, int P_, int G_, int N_,
// const half *g_mxY_, // B*L*H*P
// const half *g_mxOs_, // B*C*H*N*P
// const half *g_mxFs_, // B *H*N*P
// const float *g_mxSt_, // B*C*H*N*P
// const float *g_mxdc_, // B*C*H*Q
// const float *g_mxdA_, // B*C*H*Q
// const half *g_mxdt_, // B*L*H
// const half *g_mxdt_, // B*L*(2*H*P+2*G*N+H) or B*L*(H*P+2*G*N+H)
// const float *g_mxdb_, // H
// const float *g_mxA_, // H
half* g_mxCB_, // B*C*G*Q*Q
half const* g_mxBC_, // B*L*2*G*N
// const float *g_mxD_, // H
// const half *g_mxX_, // B*L*H*P
// const half *g_mxZ_, // B*L*H*P
half* g_mxCB_, // B*C*G*Q*Q
half const* g_mxXBC_, // B*L*(H*P+2*G*N)
// const float *g_mxD_, // H
// const half *g_mxX_, // B*L*H*P
// const half *g_mxZ_, // B*L*(2*H*P+2*G*N+H)
bool removePadding_, int const* lastTokenIdsPtr_);
typedef void (*BmmChunkKernelFuncBf16)(int B_, int L_, int G_, int N_,
typedef void (*BmmChunkKernelFuncBf16)(int B_, int L_, int H_, int P_, int G_, int N_,
// const bf16 *g_mxY_, // B*L*H*P
// const bf16 *g_mxOs_, // B*C*H*N*P
// const bf16 *g_mxFs_, // B *H*N*P
// const float *g_mxSt_, // B*C*H*N*P
// const float *g_mxdc_, // B*C*H*Q
// const float *g_mxdA_, // B*C*H*Q
// const bf16 *g_mxdt_, // B*L*H
// const bf16 *g_mxdt_, // B*L*(2*H*P+2*G*N+H) or B*L*(H*P+2*G*N+H)
// const float *g_mxdb_, // H
// const float *g_mxA_, // H
bf16* g_mxCB_, // B*C*G*Q*Q
bf16 const* g_mxBC_, // B*L*2*G*N
// const float *g_mxD_, // H
// const bf16 *g_mxX_, // B*L*H*P
// const bf16 *g_mxZ_, // B*L*H*P
bf16* g_mxCB_, // B*C*G*Q*Q
bf16 const* g_mxXBC_, // B*L*(H*P+2*G*N)
// const float *g_mxD_, // H
// const bf16 *g_mxX_, // B*L*H*P
// const bf16 *g_mxZ_, // B*L*(2*H*P+2*G*N+H)
bool removePadding_, int const* lastTokenIdsPtr_);
template <int Q_, int tileM_, int tileN_, int tileK_, // smem size, per sm
@ -68,21 +68,21 @@ template <int Q_, int tileM_, int tileN_, int tileK_, // smem size, per sm
int warpM_, int warpN_, // warp number
int pipeS_, class Tp_>
__global__ std::enable_if_t<std::is_same_v<Tp_, half> || std::is_same_v<Tp_, __nv_bfloat16>> bmm_chunk_kernel(int B_,
int L_, int G_, int N_,
int L_, int H_, int P_, int G_, int N_,
// const Tp_ *g_mxY_, // B*L*H*P
// const Tp_ *g_mxOs_, // B*C*H*N*P
// const Tp_ *g_mxFs_, // B *H*N*P
// const float *g_mxSt_, // B*C*H*N*P
// const float *g_mxdc_, // B*C*H*Q
// const float *g_mxdA_, // B*C*H*Q
// const Tp_ *g_mxdt_, // B*L*H
// const Tp_ *g_mxdt_, // B*L*(2*H*P+2*G*N+H) or B*L*(H*P+2*G*N+H)
// const Wt_ *g_mxdb_, // H
// const Wt_ *g_mxA_, // H
Tp_* g_mxCB_, // B*C*G*Q*Q
Tp_ const* g_mxBC_, // B*L*2*G*N
// const Wt_ *g_mxD_, // H
// const Tp_ *g_mxX_, // B*L*H*P
// const Tp_ *g_mxZ_, // B*L*H*P
Tp_* g_mxCB_, // B*C*G*Q*Q
Tp_ const* g_mxXBC_, // B*L*(H*P+2*G*N)
// const Wt_ *g_mxD_, // H
// const Tp_ *g_mxX_, // B*L*H*P
// const Tp_ *g_mxZ_, // B*L*(2*H*P+2*G*N+H)
bool removePadding_, int const* lastTokenIdsPtr_)
{
#if __CUDA_ARCH__ >= 800
@ -98,13 +98,17 @@ __global__ std::enable_if_t<std::is_same_v<Tp_, half> || std::is_same_v<Tp_, __n
// auto B = Rn<ID>{B_};
auto L = Rn<ID>{L_};
// auto H = Rn<ID>{H_};
// auto P = Rn<ID>{P_};
auto H = Rn<ID>{H_};
auto P = Rn<ID>{P_};
auto G = Rn<ID>{G_};
auto N = Rn<ID>{N_};
auto Q = cn<Q_>;
auto C = Rn<ID>{div_up(L.var, Q_)};
auto xbcDim = Rn<ID>{H_ * P_ + 2 * G_ * N_};
auto bOffset = Rn<ID>{H_ * P_};
auto cOffset = Rn<ID>{H_ * P_ + G_ * N_};
auto aStart = blockIdx_z * L;
auto cStart = blockIdx_z * C;
@ -167,10 +171,9 @@ __global__ std::enable_if_t<std::is_same_v<Tp_, half> || std::is_same_v<Tp_, __n
if (thread(iStep) < cn<tileM_ * tileK_>
&& thread(iStep) / cn<tileK_> < L - blockIdx_y * Q - mStart * cn<tileM_>)
cp_shared_global<16>(b_mxC + swizzle<tileK_ * 2, tileK_ * 2>(thread(iStep) * cn<2>, baseC(iK) * cn<2>),
g_mxBC_
+ get(
(aStart + blockIdx_y * Q + mStart * cn<tileM_> + thread(iStep) / cn<tileK_>) *cn<2> * G * N
+ cn<1> * G * N + gStart * N + iK * cn<tileK_> + thread(iStep) % cn<tileK_>));
g_mxXBC_
+ get((aStart + blockIdx_y * Q + mStart * cn<tileM_> + thread(iStep) / cn<tileK_>) *xbcDim
+ cOffset + gStart * N + iK * cn<tileK_> + thread(iStep) % cn<tileK_>));
else if (thread(iStep) < cn<tileM_ * tileK_>)
*(int4*) ((char*) s_mxC + swizzle<tileK_ * 2, tileK_ * 2>(thread(iStep) * cn<2>, baseC(iK) * cn<2>))
= int4{0, 0, 0, 0};
@ -180,10 +183,9 @@ __global__ std::enable_if_t<std::is_same_v<Tp_, half> || std::is_same_v<Tp_, __n
if (thread(iStep) < cn<tileN_ * tileK_>
&& thread(iStep) / cn<tileK_> < L - blockIdx_y * Q - nStart * cn<tileN_>)
cp_shared_global<16>(b_mxB + swizzle<tileK_ * 2, tileK_ * 2>(thread(iStep) * cn<2>, baseB(iK) * cn<2>),
g_mxBC_
+ get(
(aStart + blockIdx_y * Q + nStart * cn<tileN_> + thread(iStep) / cn<tileK_>) *cn<2> * G * N
+ cn<0> * G * N + gStart * N + iK * cn<tileK_> + thread(iStep) % cn<tileK_>));
g_mxXBC_
+ get((aStart + blockIdx_y * Q + nStart * cn<tileN_> + thread(iStep) / cn<tileK_>) *xbcDim
+ bOffset + gStart * N + iK * cn<tileK_> + thread(iStep) % cn<tileK_>));
else if (thread(iStep) < cn<tileN_ * tileK_>)
*(int4*) ((char*) s_mxB + swizzle<tileK_ * 2, tileK_ * 2>(thread(iStep) * cn<2>, baseB(iK) * cn<2>))
= int4{0, 0, 0, 0};
@ -285,10 +287,9 @@ __global__ std::enable_if_t<std::is_same_v<Tp_, half> || std::is_same_v<Tp_, __n
&& thread(iStep) / cn<tileK_> < L - blockIdx_y * Q - mStart * cn<tileM_>)
cp_shared_global<16>(
b_mxC + swizzle<tileK_ * 2, tileK_ * 2>(thread(iStep) * cn<2>, baseC(jK) * cn<2>),
g_mxBC_
+ get((aStart + blockIdx_y * Q + mStart * cn<tileM_> + thread(iStep) / cn<tileK_>) *cn<2>
* G * N
+ cn<1> * G * N + gStart * N + jK * cn<tileK_> + thread(iStep) % cn<tileK_>));
g_mxXBC_
+ get((aStart + blockIdx_y * Q + mStart * cn<tileM_> + thread(iStep) / cn<tileK_>) *xbcDim
+ cOffset + gStart * N + jK * cn<tileK_> + thread(iStep) % cn<tileK_>));
else if (thread(iStep) < cn<tileM_ * tileK_>)
*(int4*) ((char*) s_mxC + swizzle<tileK_ * 2, tileK_ * 2>(thread(iStep) * cn<2>, baseC(jK) * cn<2>))
= int4{0, 0, 0, 0};
@ -299,10 +300,9 @@ __global__ std::enable_if_t<std::is_same_v<Tp_, half> || std::is_same_v<Tp_, __n
&& thread(iStep) / cn<tileK_> < L - blockIdx_y * Q - nStart * cn<tileN_>)
cp_shared_global<16>(
b_mxB + swizzle<tileK_ * 2, tileK_ * 2>(thread(iStep) * cn<2>, baseB(jK) * cn<2>),
g_mxBC_
+ get((aStart + blockIdx_y * Q + nStart * cn<tileN_> + thread(iStep) / cn<tileK_>) *cn<2>
* G * N
+ cn<0> * G * N + gStart * N + jK * cn<tileK_> + thread(iStep) % cn<tileK_>));
g_mxXBC_
+ get((aStart + blockIdx_y * Q + nStart * cn<tileN_> + thread(iStep) / cn<tileK_>) *xbcDim
+ bOffset + gStart * N + jK * cn<tileK_> + thread(iStep) % cn<tileK_>));
else if (thread(iStep) < cn<tileN_ * tileK_>)
*(int4*) ((char*) s_mxB + swizzle<tileK_ * 2, tileK_ * 2>(thread(iStep) * cn<2>, baseB(jK) * cn<2>))
= int4{0, 0, 0, 0};
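
The Mamba chunk kernels now read B and C out of a single fused xBC tensor of shape B*L*(H*P+2*G*N) instead of a separate B*L*2*G*N tensor; within one token's row the H*P x values come first, then G*N B values, then G*N C values. A standalone sketch of the offsets the kernels derive (xbcDim, bOffset, cOffset above):

#include <cstdint>

struct XbcLayout
{
    int64_t H, P, G, N;
    int64_t rowDim() const { return H * P + 2 * G * N; }      // xbcDim: width of one token's fused row
    int64_t bOffset() const { return H * P; }                  // where the B block starts within a row
    int64_t cOffset() const { return H * P + G * N; }          // where the C block starts within a row
    // Flat element index of C[g][n] for token t, matching the kernels' indexing above.
    int64_t cIndex(int64_t t, int64_t g, int64_t n) const { return t * rowDim() + cOffset() + g * N + n; }
};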

View File

@ -29,57 +29,54 @@ namespace tensorrt_llm
namespace kernels
{
typedef void (*ChunkCumsumKernelFuncFp16)(int B_, int L_, int H_,
typedef void (*ChunkCumsumKernelFuncFp16)(int B_, int L_, int H_, int P_, int G_, int N_,
// const half *g_mxY_, // B*L*H*P
// const half *g_mxOs_, // B*C*H*N*P
// const half *g_mxFs_, // B *H*N*P
// const float *g_mxSt_, // B*C*H*N*P
float* g_mxdc_, // B*C*H*Q
float* g_mxdA_, // B*C*H*Q
half const* g_mxdt_, // B*L*H
half const* g_mxdt_, // B*L*(2*H*P+2*G*N+H) or B*L*(H*P+2*G*N+H)
float const* g_mxdb_, // H
float const* g_mxA_, // H
// const half *g_mxCB_, // B*C*G*Q*Q
// const half *g_mxBC_, // B*L*2*G*N
// const float *g_mxD_, // H
// const half *g_mxX_, // B*L*H*P
// const half *g_mxZ_, // B*L*H*P
// const half *g_mxCB_, // B*C*G*Q*Q
// const float *g_mxD_, // H
// const half *g_mxXBC_, // B*L*(H*P+2*G*N)
half const* g_mxZ_, // B*L*(2*H*P+2*G*N+H)
bool removePadding_, int const* lastTokenIdsPtr_);
typedef void (*ChunkCumsumKernelFuncBf16)(int B_, int L_, int H_,
typedef void (*ChunkCumsumKernelFuncBf16)(int B_, int L_, int H_, int P_, int G_, int N_,
// const bf16 *g_mxY_, // B*L*H*P
// const bf16 *g_mxOs_, // B*C*H*N*P
// const bf16 *g_mxFs_, // B *H*N*P
// const float *g_mxSt_, // B*C*H*N*P
float* g_mxdc_, // B*C*H*Q
float* g_mxdA_, // B*C*H*Q
bf16 const* g_mxdt_, // B*L*H
bf16 const* g_mxdt_, // B*L*(2*H*P+2*G*N+H) or B*L*(H*P+2*G*N+H)
float const* g_mxdb_, // H
float const* g_mxA_, // H
// const bf16 *g_mxCB_, // B*C*G*Q*Q
// const bf16 *g_mxBC_, // B*L*2*G*N
// const float *g_mxD_, // H
// const bf16 *g_mxX_, // B*L*H*P
// const bf16 *g_mxZ_, // B*L*H*P
// const bf16 *g_mxCB_, // B*C*G*Q*Q
// const float *g_mxD_, // H
// const bf16 *g_mxXBC_, // B*L*(H*P+2*G*N)
bf16 const* g_mxZ_, // B*L*(2*H*P+2*G*N+H)
bool removePadding_, int const* lastTokenIdsPtr_);
template <int Q_, int tileH_, int warpH_, bool dtSoftplus_, class Tp_, class Wt_ = float>
__global__ std::enable_if_t<std::is_same_v<Tp_, half> || std::is_same_v<Tp_, __nv_bfloat16>> chunk_cumsum_kernel(int B_,
int L_, int H_,
int L_, int H_, int P_, int G_, int N_,
// const Tp_ *g_mxY_, // B*L*H*P
// const Tp_ *g_mxOs_, // B*C*H*N*P
// const Tp_ *g_mxFs_, // B *H*N*P
// const float *g_mxSt_, // B*C*H*N*P
float* g_mxdc_, // B*C*H*Q
float* g_mxdA_, // B*C*H*Q
Tp_ const* g_mxdt_, // B*L*H
Tp_ const* g_mxdt_, // B*L*(2*H*P+2*G*N+H) or B*L*(H*P+2*G*N+H)
Wt_ const* g_mxdb_, // H
Wt_ const* g_mxA_, // H
// const Tp_ *g_mxCB_, // B*C*G*Q*Q
// const Tp_ *g_mxBC_, // B*L*2*G*N
// const Wt_ *g_mxD_, // H
// const Tp_ *g_mxX_, // B*L*H*P
// const Tp_ *g_mxZ_, // B*L*H*P
// const Tp_ *g_mxCB_, // B*C*G*Q*Q
// const Wt_ *g_mxD_, // H
// const Tp_ *g_mxXBC_, // B*L*(H*P+2*G*N)
Tp_ const* g_mxZ_, // B*L*(2*H*P+2*G*N+H)
bool removePadding_, int const* lastTokenIdsPtr_)
{
using namespace tensorrt_llm::common;
@ -94,11 +91,12 @@ __global__ std::enable_if_t<std::is_same_v<Tp_, half> || std::is_same_v<Tp_, __n
// auto B = Rn<ID>{B_};
auto L = Rn<ID>{L_};
auto H = Rn<ID>{H_};
// auto P = Rn<ID>{P_};
// auto G = Rn<ID>{G_};
// auto N = Rn<ID>{N_};
auto P = Rn<ID>{P_};
auto G = Rn<ID>{G_};
auto N = Rn<ID>{N_};
auto Q = cn<Q_>;
auto C = Rn<ID>{div_up(L.var, Q_)};
auto dt_dim = g_mxZ_ ? Rn<ID>{2 * H_ * P_ + 2 * G_ * N_ + H_} : Rn<ID>{H_ * P_ + 2 * G_ * N_ + H_};
auto aStart = blockIdx_z * L;
auto cStart = blockIdx_z * C;
@ -138,7 +136,8 @@ __global__ std::enable_if_t<std::is_same_v<Tp_, half> || std::is_same_v<Tp_, __n
if (thread(iStep) < cn<tileH_> && blockIdx_y * Q + iQ < L)
{
r_dt = float(g_mxdt_[get((aStart + blockIdx_y * Q + iQ) * H + blockIdx_x * cn<tileH_> + thread(iStep))])
r_dt = float(g_mxdt_[get((aStart + blockIdx_y * Q + iQ) * dt_dim + dt_dim - H + blockIdx_x * cn<tileH_>
+ thread(iStep))])
+ r_db;
if (dtSoftplus_)

View File

@ -36,14 +36,13 @@ typedef void (*ChunkScanKernelFuncFp16)(int B_, int L_, int H_, int P_, int G_,
// const float *g_mxSt_, // B*C*H*N*P
float const* g_mxdc_, // B*C*H*Q
float const* g_mxdA_, // B*C*H*Q
// const half *g_mxdt_, // B*L*H
// const half *g_mxdt_, // B*L*(2*H*P+2*G*N+H) or B*L*(H*P+2*G*N+H)
// const float *g_mxdb_, // H
// const float *g_mxA_, // H
half const* g_mxCB_, // B*C*G*Q*Q
half const* g_mxBC_, // B*L*2*G*N
float const* g_mxD_, // H
half const* g_mxX_, // B*L*H*P
half const* g_mxZ_, // B*L*H*P
half const* g_mxXBC_, // B*L*(H*P+2*G*N)
half const* g_mxZ_, // B*L*(2*H*P+2*G*N+H)
bool removePadding_, int const* lastTokenIdsPtr_);
typedef void (*ChunkScanKernelFuncBf16)(int B_, int L_, int H_, int P_, int G_, int N_,
@ -53,14 +52,13 @@ typedef void (*ChunkScanKernelFuncBf16)(int B_, int L_, int H_, int P_, int G_,
// const float *g_mxSt_, // B*C*H*N*P
float const* g_mxdc_, // B*C*H*Q
float const* g_mxdA_, // B*C*H*Q
// const bf16 *g_mxdt_, // B*L*H
// const bf16 *g_mxdt_, // B*L*(2*H*P+2*G*N+H) or B*L*(H*P+2*G*N+H)
// const float *g_mxdb_, // H
// const float *g_mxA_, // H
bf16 const* g_mxCB_, // B*C*G*Q*Q
bf16 const* g_mxBC_, // B*L*2*G*N
float const* g_mxD_, // H
bf16 const* g_mxX_, // B*L*H*P
bf16 const* g_mxZ_, // B*L*H*P
bf16 const* g_mxXBC_, // B*L*(H*P+2*G*N)
bf16 const* g_mxZ_, // B*L*(2*H*P+2*G*N+H)
bool removePadding_, int const* lastTokenIdsPtr_);
template <int Q_, int tileM_, int tileN_, int tileK_, // smem size, per sm
@ -75,14 +73,13 @@ __global__ std::enable_if_t<std::is_same_v<Tp_, half> || std::is_same_v<Tp_, __n
// const float *g_mxSt_, // B*C*H*N*P
float const* g_mxdc_, // B*C*H*Q
float const* g_mxdA_, // B*C*H*Q
// const Tp_ *g_mxdt_, // B*L*H
// const Tp_ *g_mxdt_, // B*L*(2*H*P+2*G*N+H) or B*L*(H*P+2*G*N+H)
// const Wt_ *g_mxdb_, // H
// const Wt_ *g_mxA_, // H
Tp_ const* g_mxCB_, // B*C*G*Q*Q
Tp_ const* g_mxBC_, // B*L*2*G*N
Wt_ const* g_mxD_, // H
Tp_ const* g_mxX_, // B*L*H*P
Tp_ const* g_mxZ_, // B*L*H*P
Tp_ const* g_mxXBC_, // B*L*(H*P+2*G*N)
Tp_ const* g_mxZ_, // B*L*(2*H*P+2*G*N+H)
bool removePadding_, int const* lastTokenIdsPtr_)
{
#if __CUDA_ARCH__ >= 800
@ -105,6 +102,10 @@ __global__ std::enable_if_t<std::is_same_v<Tp_, half> || std::is_same_v<Tp_, __n
auto Q = cn<Q_>;
auto C = Rn<ID>{div_up(L.var, Q_)};
auto xbcDim = Rn<ID>{H_ * P_ + 2 * G_ * N_};
auto zdtDim = Rn<ID>{2 * H_ * P_ + 2 * G_ * N_ + H_};
auto cOffset = Rn<ID>{H_ * P_ + G_ * N_};
auto aStart = blockIdx_z * L;
auto cStart = blockIdx_z * C;
@ -185,10 +186,9 @@ __global__ std::enable_if_t<std::is_same_v<Tp_, half> || std::is_same_v<Tp_, __n
if (thread(iStep) < cn<tileM_ * tileK_>
&& thread(iStep) / cn<tileK_> < L - blockIdx_y * Q - mStart * cn<tileM_>)
cp_shared_global<16>(b_mxC + swizzle<tileK_ * 2, tileK_ * 2>(thread(iStep) * cn<2>, baseC(iK) * cn<2>),
g_mxBC_
+ get(
(aStart + blockIdx_y * Q + mStart * cn<tileM_> + thread(iStep) / cn<tileK_>) *cn<2> * G * N
+ cn<1> * G * N + gStart * N + iK * cn<tileK_> + thread(iStep) % cn<tileK_>));
g_mxXBC_
+ get((aStart + blockIdx_y * Q + mStart * cn<tileM_> + thread(iStep) / cn<tileK_>) *xbcDim
+ cOffset + gStart * N + iK * cn<tileK_> + thread(iStep) % cn<tileK_>));
else if (thread(iStep) < cn<tileM_ * tileK_>)
*(int4*) ((char*) s_mxC + swizzle<tileK_ * 2, tileK_ * 2>(thread(iStep) * cn<2>, baseC(iK) * cn<2>))
= int4{0, 0, 0, 0};
@ -365,10 +365,9 @@ __global__ std::enable_if_t<std::is_same_v<Tp_, half> || std::is_same_v<Tp_, __n
&& thread(iStep) / cn<tileK_> < L - blockIdx_y * Q - mStart * cn<tileM_>)
cp_shared_global<16>(
b_mxC + swizzle<tileK_ * 2, tileK_ * 2>(thread(iStep) * cn<2>, baseC(jK) * cn<2>),
g_mxBC_
+ get((aStart + blockIdx_y * Q + mStart * cn<tileM_> + thread(iStep) / cn<tileK_>) *cn<2>
* G * N
+ cn<1> * G * N + gStart * N + jK * cn<tileK_> + thread(iStep) % cn<tileK_>));
g_mxXBC_
+ get((aStart + blockIdx_y * Q + mStart * cn<tileM_> + thread(iStep) / cn<tileK_>) *xbcDim
+ cOffset + gStart * N + jK * cn<tileK_> + thread(iStep) % cn<tileK_>));
else if (thread(iStep) < cn<tileM_ * tileK_>)
*(int4*) ((char*) s_mxC + swizzle<tileK_ * 2, tileK_ * 2>(thread(iStep) * cn<2>, baseC(jK) * cn<2>))
= int4{0, 0, 0, 0};
@ -402,8 +401,8 @@ __global__ std::enable_if_t<std::is_same_v<Tp_, half> || std::is_same_v<Tp_, __n
&& thread(iStep) / cn<tileN_> < L - blockIdx_y * Q - jK * cn<tileK_> + N)
cp_shared_global<16>(
b_mxOs + swizzle<tileN_ * 2, tileN_ * 2>(thread(iStep) * cn<2>, baseOs(jK) * cn<2>),
g_mxX_
+ get((aStart + blockIdx_y * Q + jK * cn<tileK_> - N + thread(iStep) / cn<tileN_>) *H * P
g_mxXBC_
+ get((aStart + blockIdx_y * Q + jK * cn<tileK_> - N + thread(iStep) / cn<tileN_>) *xbcDim
+ hStart * P + nStart * cn<tileN_> + thread(iStep) % cn<tileN_>));
else if (thread(iStep) < cn<tileN_ * tileK_>)
*(int4*) ((char*) s_mxOs
@ -434,10 +433,9 @@ __global__ std::enable_if_t<std::is_same_v<Tp_, half> || std::is_same_v<Tp_, __n
+ threadIdx_z * cn<wmmaM_> + threadIdx_x / cn<4>
< L)
{
*(int*) &tmp16[0] = *(int*) (g_mxX_
*(int*) &tmp16[0] = *(int*) (g_mxXBC_
+ get((aStart + blockIdx_y * Q + mStart * cn<tileM_> + Rn<UNROLL>{y} * cn<warpM_ * wmmaM_>
+ threadIdx_z * cn<wmmaM_> + threadIdx_x / cn<4>) *H
* P
+ threadIdx_z * cn<wmmaM_> + threadIdx_x / cn<4>) *xbcDim
+ hStart * P + nStart * cn<tileN_> + Rn<UNROLL>{x} * cn<warpN_ * wmmaN_>
+ threadIdx_y * cn<wmmaN_> + threadIdx_x % cn<4> * cn<2>));
@ -452,10 +450,9 @@ __global__ std::enable_if_t<std::is_same_v<Tp_, half> || std::is_same_v<Tp_, __n
+ threadIdx_z * cn<wmmaM_> + threadIdx_x / cn<4>
< L)
{
*(int*) &tmp16[2] = *(int*) (g_mxX_
*(int*) &tmp16[2] = *(int*) (g_mxXBC_
+ get((aStart + blockIdx_y * Q + mStart * cn<tileM_> + Rn<UNROLL>{y} * cn<warpM_ * wmmaM_>
+ cn<8> + threadIdx_z * cn<wmmaM_> + threadIdx_x / cn<4>) *H
* P
+ cn<8> + threadIdx_z * cn<wmmaM_> + threadIdx_x / cn<4>) *xbcDim
+ hStart * P + nStart * cn<tileN_> + Rn<UNROLL>{x} * cn<warpN_ * wmmaN_>
+ threadIdx_y * cn<wmmaN_> + threadIdx_x % cn<4> * cn<2>));
@ -484,8 +481,7 @@ __global__ std::enable_if_t<std::is_same_v<Tp_, half> || std::is_same_v<Tp_, __n
{
*(int*) &tmp16[0] = *(int*) (g_mxZ_
+ get((aStart + blockIdx_y * Q + mStart * cn<tileM_> + Rn<UNROLL>{y} * cn<warpM_ * wmmaM_>
+ threadIdx_z * cn<wmmaM_> + threadIdx_x / cn<4>) *H
* P
+ threadIdx_z * cn<wmmaM_> + threadIdx_x / cn<4>) *zdtDim
+ hStart * P + nStart * cn<tileN_> + Rn<UNROLL>{x} * cn<warpN_ * wmmaN_>
+ threadIdx_y * cn<wmmaN_> + threadIdx_x % cn<4> * cn<2>));
@ -502,8 +498,7 @@ __global__ std::enable_if_t<std::is_same_v<Tp_, half> || std::is_same_v<Tp_, __n
{
*(int*) &tmp16[2] = *(int*) (g_mxZ_
+ get((aStart + blockIdx_y * Q + mStart * cn<tileM_> + Rn<UNROLL>{y} * cn<warpM_ * wmmaM_>
+ cn<8> + threadIdx_z * cn<wmmaM_> + threadIdx_x / cn<4>) *H
* P
+ cn<8> + threadIdx_z * cn<wmmaM_> + threadIdx_x / cn<4>) *zdtDim
+ hStart * P + nStart * cn<tileN_> + Rn<UNROLL>{x} * cn<warpN_ * wmmaN_>
+ threadIdx_y * cn<wmmaN_> + threadIdx_x % cn<4> * cn<2>));

View File

@ -36,14 +36,13 @@ typedef void (*ChunkStateKernelFuncFp16)(int B_, int L_, int H_, int P_, int G_,
float* g_mxSt_, // B*C*H*N*P
float const* g_mxdc_, // B*C*H*Q
float const* g_mxdA_, // B*C*H*Q
// const half *g_mxdt_, // B*L*H
// const half *g_mxdt_, // B*L*(2*H*P+2*G*N+H) or B*L*(H*P+2*G*N+H)
// const float *g_mxdb_, // H
// const float *g_mxA_, // H
// const half *g_mxCB_, // B*C*G*Q*Q
half const* g_mxBC_, // B*L*2*G*N
// const float *g_mxD_, // H
half const* g_mxX_, // B*L*H*P
// const half *g_mxZ_, // B*L*H*P
half const* g_mxXBC_, // B*L*(H*P+2*G*N)
// const half *g_mxZ_, // B*L*(2*H*P+2*G*N+H)
bool removePadding_, int const* lastTokenIdsPtr_);
typedef void (*ChunkStateKernelFuncBf16)(int B_, int L_, int H_, int P_, int G_, int N_,
@ -53,14 +52,13 @@ typedef void (*ChunkStateKernelFuncBf16)(int B_, int L_, int H_, int P_, int G_,
float* g_mxSt_, // B*C*H*N*P
float const* g_mxdc_, // B*C*H*Q
float const* g_mxdA_, // B*C*H*Q
// const bf16 *g_mxdt_, // B*L*H
// const bf16 *g_mxdt_, // B*L*(2*H*P+2*G*N+H) or B*L*(H*P+2*G*N+H)
// const float *g_mxdb_, // H
// const float *g_mxA_, // H
// const bf16 *g_mxCB_, // B*C*G*Q*Q
bf16 const* g_mxBC_, // B*L*2*G*N
// const float *g_mxD_, // H
bf16 const* g_mxX_, // B*L*H*P
// const bf16 *g_mxZ_, // B*L*H*P
bf16 const* g_mxXBC_, // B*L*(H*P+2*G*N)
// const bf16 *g_mxZ_, // B*L*(2*H*P+2*G*N+H)
bool removePadding_, int const* lastTokenIdsPtr_);
template <int Q_, int tileM_, int tileN_, int tileK_, // smem size, per sm
@ -75,14 +73,13 @@ __global__ std::enable_if_t<std::is_same_v<Tp_, half> || std::is_same_v<Tp_, __n
float* g_mxSt_, // B*C*H*N*P
float const* g_mxdc_, // B*C*H*Q
float const* g_mxdA_, // B*C*H*Q
// const Tp_ *g_mxdt_, // B*L*H
// const Tp_ *g_mxdt_, // B*L*(2*H*P+2*G*N+H) or B*L*(H*P+2*G*N+H)
// const Wt_ *g_mxdb_, // H
// const Wt_ *g_mxA_, // H
// const Tp_ *g_mxCB_, // B*C*G*Q*Q
Tp_ const* g_mxBC_, // B*L*2*G*N
// const Wt_ *g_mxD_, // H
Tp_ const* g_mxX_, // B*L*H*P
// const Tp_ *g_mxZ_, // B*L*H*P
Tp_ const* g_mxXBC_, // B*L*(H*P+2*G*N)
// const Tp_ *g_mxZ_, // B*L*(2*H*P+2*G*N+H)
bool removePadding_, int const* lastTokenIdsPtr_)
{
#if __CUDA_ARCH__ >= 800
@ -105,6 +102,10 @@ __global__ std::enable_if_t<std::is_same_v<Tp_, half> || std::is_same_v<Tp_, __n
auto Q = cn<Q_>;
auto C = Rn<ID>{div_up(L.var, Q_)};
auto xbcDim = Rn<ID>{H_ * P_ + 2 * G_ * N_};
auto bOffset = Rn<ID>{H_ * P_};
auto cOffset = Rn<ID>{H_ * P_ + G_ * N_};
auto aStart = blockIdx_z * L;
auto cStart = blockIdx_z * C;
@ -183,8 +184,8 @@ __global__ std::enable_if_t<std::is_same_v<Tp_, half> || std::is_same_v<Tp_, __n
if (thread(iStep) < cn<tileM_ * tileK_>
&& thread(iStep) / cn<tileM_> < L - blockIdx_y * Q - iK * cn<tileK_>)
cp_shared_global<16>(b_mxB + swizzle<tileM_ * 2, tileM_ * 2>(thread(iStep) * cn<2>, baseB(iK) * cn<2>),
g_mxBC_
+ get((aStart + blockIdx_y * Q + iK * cn<tileK_> + thread(iStep) / cn<tileM_>) *cn<2> * G * N
g_mxXBC_
+ get((aStart + blockIdx_y * Q + iK * cn<tileK_> + thread(iStep) / cn<tileM_>) *xbcDim + bOffset
+ gStart * N + mStart * cn<tileM_> + thread(iStep) % cn<tileM_>));
else if (thread(iStep) < cn<tileM_ * tileK_>)
*(int4*) ((char*) s_mxB + swizzle<tileM_ * 2, tileM_ * 2>(thread(iStep) * cn<2>, baseB(iK) * cn<2>))
@ -195,8 +196,8 @@ __global__ std::enable_if_t<std::is_same_v<Tp_, half> || std::is_same_v<Tp_, __n
if (thread(iStep) < cn<tileN_ * tileK_>
&& thread(iStep) / cn<tileN_> < L - blockIdx_y * Q - iK * cn<tileK_>)
cp_shared_global<16>(b_mxX + swizzle<tileN_ * 2, tileN_ * 2>(thread(iStep) * cn<2>, baseX(iK) * cn<2>),
g_mxX_
+ get((aStart + blockIdx_y * Q + iK * cn<tileK_> + thread(iStep) / cn<tileN_>) *H * P
g_mxXBC_
+ get((aStart + blockIdx_y * Q + iK * cn<tileK_> + thread(iStep) / cn<tileN_>) *xbcDim
+ hStart * P + nStart * cn<tileN_> + thread(iStep) % cn<tileN_>));
else if (thread(iStep) < cn<tileN_ * tileK_>)
*(int4*) ((char*) s_mxX + swizzle<tileN_ * 2, tileN_ * 2>(thread(iStep) * cn<2>, baseX(iK) * cn<2>))
@ -329,8 +330,8 @@ __global__ std::enable_if_t<std::is_same_v<Tp_, half> || std::is_same_v<Tp_, __n
if (thread(iStep) < cn<tileM_ * tileK_> && thread(iStep) / cn<tileM_> < L - blockIdx_y * Q - jK * cn<tileK_>
&& jK * cn<tileK_> < Q)
cp_shared_global<16>(b_mxB + swizzle<tileM_ * 2, tileM_ * 2>(thread(iStep) * cn<2>, baseB(jK) * cn<2>),
g_mxBC_
+ get((aStart + blockIdx_y * Q + jK * cn<tileK_> + thread(iStep) / cn<tileM_>) *cn<2> * G * N
g_mxXBC_
+ get((aStart + blockIdx_y * Q + jK * cn<tileK_> + thread(iStep) / cn<tileM_>) *xbcDim + bOffset
+ gStart * N + mStart * cn<tileM_> + thread(iStep) % cn<tileM_>));
else if (thread(iStep) < cn<tileM_ * tileK_> && jK * cn<tileK_> < Q)
*(int4*) ((char*) s_mxB + swizzle<tileM_ * 2, tileM_ * 2>(thread(iStep) * cn<2>, baseB(jK) * cn<2>))
@ -341,8 +342,8 @@ __global__ std::enable_if_t<std::is_same_v<Tp_, half> || std::is_same_v<Tp_, __n
if (thread(iStep) < cn<tileN_ * tileK_> && thread(iStep) / cn<tileN_> < L - blockIdx_y * Q - jK * cn<tileK_>
&& jK * cn<tileK_> < Q)
cp_shared_global<16>(b_mxX + swizzle<tileN_ * 2, tileN_ * 2>(thread(iStep) * cn<2>, baseX(jK) * cn<2>),
g_mxX_
+ get((aStart + blockIdx_y * Q + jK * cn<tileK_> + thread(iStep) / cn<tileN_>) *H * P
g_mxXBC_
+ get((aStart + blockIdx_y * Q + jK * cn<tileK_> + thread(iStep) / cn<tileN_>) *xbcDim
+ hStart * P + nStart * cn<tileN_> + thread(iStep) % cn<tileN_>));
else if (thread(iStep) < cn<tileN_ * tileK_> && jK * cn<tileK_> < Q)
*(int4*) ((char*) s_mxX + swizzle<tileN_ * 2, tileN_ * 2>(thread(iStep) * cn<2>, baseX(jK) * cn<2>))
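
As a minimal standalone sketch of the fused xBC row layout this kernel now indexes (x first, then B, then C within each token's row of length H*P + 2*G*N); the sizes below are illustrative, not the real model's:

#include <cstdio>
int main()
{
    int H = 8, P = 64, G = 1, N = 128;       // heads, head dim, groups, state dim (made up)
    int xbcDim = H * P + 2 * G * N;          // per-token row length of the fused xBC tensor
    int bOffset = H * P;                     // B block starts right after the x block
    int cOffset = H * P + G * N;             // C block starts right after the B block
    int token = 5, g = 0, n = 17;
    // Flat index of element n of group g inside the B block of this token's row:
    long long bIndex = (long long) token * xbcDim + bOffset + g * N + n;
    std::printf("xbcDim=%d bOffset=%d cOffset=%d bIndex=%lld\n", xbcDim, bOffset, cOffset, bIndex);
    return 0;
}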

View File

@ -36,14 +36,13 @@ typedef void (*StatePassingKernelFuncFp16)(int B_, int L_, int H_, int P_, int N
float const* g_mxSt_, // B*C*H*N*P
// const float *g_mxdc_, // B*C*H*Q
float const* g_mxdA_, // B*C*H*Q
// const half *g_mxdt_, // B*L*H
// const half *g_mxdt_, // B*L*(2*H*P+2*G*N+H) or B*L*(H*P+2*G*N+H)
// const float *g_mxdb_, // H
// const float *g_mxA_, // H
// const half *g_mxCB_, // B*C*G*Q*Q
// const half *g_mxBC_, // B*L*2*G*N
// const float *g_mxD_, // H
// const half *g_mxX_, // B*L*H*P
// const half *g_mxZ_, // B*L*H*P
// const half *g_mxXBC_, // B*L*(H*P+2*G*N)
// const half *g_mxZ_, // B*L*(2*H*P+2*G*N+H)
bool removePadding_, int const* lastTokenIdsPtr_, int const* stateSlotMappingPtr_);
typedef void (*StatePassingKernelFuncBf16)(int B_, int L_, int H_, int P_, int N_,
@ -53,14 +52,13 @@ typedef void (*StatePassingKernelFuncBf16)(int B_, int L_, int H_, int P_, int N
float const* g_mxSt_, // B*C*H*N*P
// const float *g_mxdc_, // B*C*H*Q
float const* g_mxdA_, // B*C*H*Q
// const bf16 *g_mxdt_, // B*L*H
// const bf16 *g_mxdt_, // B*L*(2*H*P+2*G*N+H) or B*L*(H*P+2*G*N+H)
// const float *g_mxdb_, // H
// const float *g_mxA_, // H
// const bf16 *g_mxCB_, // B*C*G*Q*Q
// const bf16 *g_mxBC_, // B*L*2*G*N
// const float *g_mxD_, // H
// const bf16 *g_mxX_, // B*L*H*P
// const bf16 *g_mxZ_, // B*L*H*P
// const bf16 *g_mxXBC_, // B*L*(H*P+2*G*N)
// const bf16 *g_mxZ_, // B*L*(2*H*P+2*G*N+H)
bool removePadding_, int const* lastTokenIdsPtr_, int const* stateSlotMappingPtr_);
template <int Q_, int tileH_, int warpH_, class Tp_>
@ -72,14 +70,13 @@ __global__ std::enable_if_t<std::is_same_v<Tp_, half> || std::is_same_v<Tp_, __n
float const* g_mxSt_, // B*C*H*N*P
// const float *g_mxdc_, // B*C*H*Q
float const* g_mxdA_, // B*C*H*Q
// const Tp_ *g_mxdt_, // B*L*H
// const Tp_ *g_mxdt_, // B*L*(2*H*P+2*G*N+H) or B*L*(H*P+2*G*N+H)
// const Wt_ *g_mxdb_, // H
// const Wt_ *g_mxA_, // H
// const Tp_ *g_mxCB_, // B*C*G*Q*Q
// const Tp_ *g_mxBC_, // B*L*2*G*N
// const Wt_ *g_mxD_, // H
// const Tp_ *g_mxX_, // B*L*H*P
// const Tp_ *g_mxZ_, // B*L*H*P
// const Tp_ *g_mxXBC_, // B*L*(H*P+2*G*N)
// const Tp_ *g_mxZ_, // B*L*(2*H*P+2*G*N+H)
bool removePadding_, int const* lastTokenIdsPtr_, int const* stateSlotMappingPtr_)
{
using namespace tensorrt_llm::common;

View File

@ -91,6 +91,7 @@ __global__ void cumsum_last_dim(T const* d_in, T* d_out, int length)
// Store items from a blocked arrangement
BlockStoreT(temp_storage.store).Store(cur_d_out, data, cur_tile_size);
__syncthreads();
}
}

View File

@ -70,8 +70,8 @@ TileShape get_cta_shape_for_config(CutlassTileConfig tile_config)
}
}
bool is_valid_split_k_factor(const int64_t m, const int64_t n, const int64_t k, const TileShape tile_shape,
int const split_k_factor, const size_t workspace_bytes, bool const is_weight_only)
bool is_valid_split_k_factor(int64_t const m, int64_t const n, int64_t const k, TileShape const tile_shape,
int const split_k_factor, size_t const workspace_bytes, bool const is_weight_only)
{
// All tile sizes have a k_tile of 64.
@ -181,7 +181,7 @@ std::vector<CutlassTileConfigSM90> get_candidate_tiles_sm90(
{
return {CutlassTileConfigSM90::CtaShape128x16x128B, CutlassTileConfigSM90::CtaShape128x32x128B,
CutlassTileConfigSM90::CtaShape128x64x128B, CutlassTileConfigSM90::CtaShape128x128x128B,
CutlassTileConfigSM90::CtaShape128x256x128B};
CutlassTileConfigSM90::CtaShape128x256x128B, CutlassTileConfigSM90::CtaShape256x128x128B};
}
else
{
@ -196,28 +196,29 @@ std::vector<CutlassTileConfigSM90> get_candidate_tiles_sm90(
// We only compile CUTLASS kernels with multi-cast along M if the M tile is >= 128. This is purely to improve
// compilation speed.
bool supports_mcast_along_m(const CutlassTileConfigSM90 tile)
bool supports_mcast_along_m(CutlassTileConfigSM90 const tile)
{
#ifdef FAST_BUILD
return false;
#else
std::set<CutlassTileConfigSM90> valid_tiles{CutlassTileConfigSM90::CtaShape128x16x128B,
CutlassTileConfigSM90::CtaShape128x32x128B, CutlassTileConfigSM90::CtaShape128x64x128B,
CutlassTileConfigSM90::CtaShape128x128x128B, CutlassTileConfigSM90::CtaShape128x256x128B};
CutlassTileConfigSM90::CtaShape128x128x128B, CutlassTileConfigSM90::CtaShape128x256x128B,
CutlassTileConfigSM90::CtaShape256x128x128B};
return valid_tiles.count(tile) == 1;
#endif
}
// We only compile CUTLASS kernels with multi-cast along N if the N tile is >= 128. This is purely to improve
// compilation speed.
bool supports_mcast_along_n(const CutlassTileConfigSM90 tile)
bool supports_mcast_along_n(CutlassTileConfigSM90 const tile)
{
#ifdef FAST_BUILD
return false;
#else
std::set<CutlassTileConfigSM90> valid_tiles{CutlassTileConfigSM90::CtaShape64x128x128B,
CutlassTileConfigSM90::CtaShape64x256x128B, CutlassTileConfigSM90::CtaShape128x128x128B,
CutlassTileConfigSM90::CtaShape128x256x128B};
CutlassTileConfigSM90::CtaShape128x256x128B, CutlassTileConfigSM90::CtaShape256x128x128B};
return valid_tiles.count(tile) == 1;
#endif
}
@ -288,8 +289,8 @@ std::vector<CutlassGemmConfig> get_candidate_configs(
}
CutlassGemmConfig estimate_best_config_from_occupancies(std::vector<CutlassGemmConfig> const& candidate_configs,
std::vector<int> const& occupancies, const int64_t m, const int64_t n, const int64_t k, const int64_t num_experts,
int const split_k_limit, const size_t workspace_bytes, int const multi_processor_count, int const is_weight_only)
std::vector<int> const& occupancies, int64_t const m, int64_t const n, int64_t const k, int64_t const num_experts,
int const split_k_limit, size_t const workspace_bytes, int const multi_processor_count, int const is_weight_only)
{
if (occupancies.size() != candidate_configs.size())

View File

@ -247,7 +247,7 @@ void sm90_generic_moe_gemm_kernelLauncher(HopperGroupedGemmInput hopper_input, i
TLLM_CHECK(hopper_input.ptr_b);
TLLM_CHECK(hopper_input.ptr_d);
const MainloopArguments mainloop_params = {reinterpret_cast<ElementB const**>(hopper_input.ptr_b),
MainloopArguments const mainloop_params = {reinterpret_cast<ElementB const**>(hopper_input.ptr_b),
hopper_input.stride_b, reinterpret_cast<ElementA const**>(hopper_input.ptr_a), hopper_input.stride_a};
typename GemmGrouped::EpilogueOutputOp::Params epilogue_scalars{
@ -255,12 +255,15 @@ void sm90_generic_moe_gemm_kernelLauncher(HopperGroupedGemmInput hopper_input, i
epilogue_scalars.alpha_ptr_array = hopper_input.alpha_scale_ptr_array;
using EpilogueArguments = typename CollectiveEpilogue::Arguments;
// TODO(dastokes) ptr_c casts to ElementCNoVoid** because there is a workaround in CUTLASS
const EpilogueArguments epilogue_params
EpilogueArguments const epilogue_params
= {epilogue_scalars, reinterpret_cast<ElementCNoVoid const**>(hopper_input.ptr_c), hopper_input.stride_c,
reinterpret_cast<ElementD**>(hopper_input.ptr_d), hopper_input.stride_d};
typename GemmKernel::TileScheduler::Arguments scheduler_args{
1, GemmKernel::TileScheduler::RasterOrderOptions::AlongN};
typename GemmGrouped::Arguments args{cutlass::gemm::GemmUniversalMode::kGrouped, hopper_input.shape_info,
mainloop_params, epilogue_params, hw_info};
mainloop_params, epilogue_params, hw_info, scheduler_args};
size_t calculated_ws_size = gemm.get_workspace_size(args);
TLLM_CHECK_WITH_INFO(calculated_ws_size <= hopper_input.gemm_workspace_size,

View File

@ -190,6 +190,7 @@ void dispatchMoeGemmSelectTileShapeSM90(HopperGroupedGemmInput hopper_input, int
SHAPE_CASE(128, 64, 128)
SHAPE_CASE(128, 128, 128)
SHAPE_CASE(128, 256, 128)
SHAPE_CASE(256, 128, 128)
#undef SHAPE_CASE
case cutlass_extensions::CutlassTileConfigSM90::Undefined: TLLM_THROW("GEMM config undefined."); break;

View File

@ -323,7 +323,7 @@ def generate_sm90_grouped_gemm_operations():
epi_tags = [TrtLlm_EpilogueTag.epilogue_op_default]
M_TILES = [128] # Currently M tile must be 128 for Grouped GEMM
N_TILES = [16, 32, 64, 128, 256]
cta_shapes_mn = product(M_TILES, N_TILES)
cta_shapes_mn = list(product(M_TILES, N_TILES)) + [(256, 128)]
warp_shape = [0, 0, 0] # ignored except for naming
stages = 0 # auto

View File

@ -1,2 +1,2 @@
8b0f8deb35940359b39f876fc5e94e4f libtensorrt_llm_nvrtc_wrapper.so
d5f5542d2f1e10c4a6b60be56838ac79a9668665 commit
db055e58b6c6c8cf7350b66a583f9c388c4eac07 commit

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:78209a1351f9f21f635bf9f763f4947031ea12b7526c5782094e9869b667a23f
oid sha256:c439d4074454207e5a26887a041d3e7868dd05dab30b903536bac5428758c9eb
size 1091072

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:38d470721122f47b75e91b7967fa56cebcb48c8abcc4b3ddefe4f39c85d061f8
oid sha256:ba9784dd196da1c35f9326de49b0851395f7d858131b9a60334211feaec34c52
size 3488

View File

@ -231,8 +231,11 @@ public:
unsigned int nbTokenBlocksPerGrp = divUp(qSeqLen * headGrpSize, mTileSize);
int const* maskPtr = xqaParams.spec_decoding_packed_mask;
int const* cuQSeqLens = launchParams.cu_seq_lens;
unsigned int maxQSeqLen = xqaParams.spec_decoding_is_generation_length_variable ? // true for ReDrafter
xqaParams.spec_decoding_max_generation_length
: qSeqLen;
// TODO: merge SingleQueryToken params and MultiQueryTokens params into one kernelParams.
void* kernelParams[] = {&qSeqLen, &launchParams.num_k_heads, &headGrpSize, &cuQSeqLens,
void* kernelParams[] = {&maxQSeqLen, &launchParams.num_k_heads, &headGrpSize, &cuQSeqLens,
&launchParams.output, &xqa_q_input_ptr, &maskPtr, &launchParams.kvCacheParams, &launchParams.batch_size,
&launchParams.kv_scale_quant_orig, &launchParams.scratch};
int multi_block = 1;
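
A tiny sketch of the assumed sizing rule: when per-request generation lengths can vary, the kernel parameter must cover the largest possible query length rather than the nominal one (values below are made up):

#include <cstdio>
int main()
{
    bool generationLengthIsVariable = true;   // e.g. ReDrafter-style drafting
    int qSeqLen = 4;                          // uniform length otherwise
    int maxGenerationLength = 8;              // upper bound across the batch
    int maxQSeqLen = generationLengthIsVariable ? maxGenerationLength : qSeqLen;
    std::printf("kernel sized for %d query tokens\n", maxQSeqLen);
    return 0;
}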

View File

@ -45,12 +45,14 @@ struct XQAParams
int32_t sink_token_length = 0;
int timestep = 0;
void const* qkv_bias;
int32_t const* sequence_lengths; //
int32_t const* context_lengths; // maybe not used now
void const* alibi_slopes; // maybe not used now
int32_t const* sequence_lengths; //
int32_t const* context_lengths; // maybe not used now
void const* alibi_slopes; // maybe not used now
int32_t const* spec_decoding_packed_mask;
int const* spec_decoding_position_offsets; // rotary embedding.
int const* spec_decoding_generation_lengths; // variable input lengths.
int const* spec_decoding_position_offsets; // for position embedding.
int const* spec_decoding_generation_lengths; // variable input lengths.
bool spec_decoding_is_generation_length_variable; // whether the generation lengths actually vary
int32_t spec_decoding_max_generation_length; // max possible input length
// almost copy from GPTAttentionPluginCommon.
// maybe use one struct for parameters in GPTAttentionPluginCommon and share the same here.

View File

@ -482,11 +482,12 @@ __global__ void finalizeKernel(BeamHypotheses bh)
void invokeFinalize(BeamHypotheses& bh, cudaStream_t stream)
{
TLLM_LOG_DEBUG("%s %s start", __FILE__, __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s %s start", __FILE__, __PRETTY_FUNCTION__);
int const nBM = bh.nBeamWidth;
size_t const smem_size = sizeof(int) * nBM * 2 + sizeof(float) * nBM * 2;
finalizeKernel<<<bh.nBatchSize, roundUp(nBM * 2, 32), smem_size, stream>>>(bh);
TLLM_LOG_TRACE("%s %s stop", __FILE__, __PRETTY_FUNCTION__);
}
__global__ void initializeOutput(TokenIdType* finalOutputIds, TokenIdType const* endIds, SizeType32 const nMaxSeqLen)

View File

@ -197,9 +197,9 @@ __global__
|| std::is_same_v<T_, __nv_bfloat16>
#endif
>
mambaConv1dContextKernel(int B_, int L_, int D_, T_* g_mxYa_, T_* g_mxYs_, T_ const* g_mxXa_, T_ const* g_mxXs_,
T_ const* g_mxW_, T_ const* g_mxB_, bool removePadding_, bool applySilu_, int const* lastTokenIdsPtr_,
int const* stateSlotMappingPtr_ = nullptr)
mambaConv1dContextKernel(int B_, int L_, int D_, int S_pre_, int S_post_, T_* g_mxYa_, T_* g_mxYs_,
T_ const* g_mxXa_, T_ const* g_mxXs_, T_ const* g_mxW_, T_ const* g_mxB_, bool removePadding_, bool applySilu_,
int const* lastTokenIdsPtr_, int const* stateSlotMappingPtr_ = nullptr)
{
using namespace tensorrt_llm::common;
@ -262,17 +262,20 @@ __global__
int STEP = 256 * warpL_ * warpD_;
int L = L_;
int DS_ = (D_ + S_pre_ + S_post_);
int aStart = blockIdx.z * L_ * D_;
int aIStart = blockIdx.z * L_ * DS_;
int aOStart = blockIdx.z * L_ * D_;
int sStart = blockIdx.z * (K_ - 1) * D_;
int lStart = blockIdx.y * tileL_;
int dStart = blockIdx.x * tileD_;
if (removePadding_)
{
aStart = blockIdx.z ? lastTokenIdsPtr_[blockIdx.z - 1] : 0;
L = lastTokenIdsPtr_[blockIdx.z] - aStart;
aStart = aStart * D_;
aIStart = blockIdx.z ? lastTokenIdsPtr_[blockIdx.z - 1] : 0;
L = lastTokenIdsPtr_[blockIdx.z] - aIStart;
aOStart = aIStart * D_;
aIStart = aIStart * DS_;
}
else
{
@ -295,8 +298,8 @@ __global__
+ 2
* swizzle<tileD_ * 2, tileD_, T_>(
i + thread * 8 + (warpL_ * laneL * pipe_ + 1 - K_) * tileD_),
g_mxXa_ + aStart + (1 - K_ + lStart + thread * 8 / tileD_) * D_ + i * (D_ / tileD_) + dStart
+ thread * 8 % tileD_);
g_mxXa_ + aIStart + (1 - K_ + lStart + thread * 8 / tileD_) * DS_ + S_pre_ + i * (D_ / tileD_)
+ dStart + thread * 8 % tileD_);
}
else if (g_mxXs_)
{
@ -331,7 +334,7 @@ __global__
if (i + STEP <= warpL_ * laneL * tileD_ || i + thread * 8 < warpL_ * laneL * tileD_)
cp_shared_global<16>(
b_mxX + 2 * swizzle<tileD_ * 2, tileD_, T_>(i + thread * 8 + iL * warpL_ * laneL * tileD_),
g_mxXa_ + aStart + iL * warpL_ * laneL * D_ + (lStart + thread * 8 / tileD_) * D_
g_mxXa_ + aIStart + iL * warpL_ * laneL * DS_ + (lStart + thread * 8 / tileD_) * DS_ + S_pre_
+ i * (D_ / tileD_) + dStart + thread * 8 % tileD_);
}
else
@ -341,7 +344,7 @@ __global__
if (i + thread * 8 < (L - lStart - iL * warpL_ * laneL) * tileD_)
cp_shared_global<16>(
b_mxX + 2 * swizzle<tileD_ * 2, tileD_, T_>(i + thread * 8 + iL * warpL_ * laneL * tileD_),
g_mxXa_ + aStart + iL * warpL_ * laneL * D_ + (lStart + thread * 8 / tileD_) * D_
g_mxXa_ + aIStart + iL * warpL_ * laneL * DS_ + (lStart + thread * 8 / tileD_) * DS_ + S_pre_
+ i * (D_ / tileD_) + dStart + thread * 8 % tileD_);
}
@ -436,7 +439,7 @@ __global__
+ 2
* swizzle<tileD_ * 2, tileD_, T_>(
i + thread * 8 + jL % pipe_ * warpL_ * laneL * tileD_),
g_mxXa_ + aStart + jL * warpL_ * laneL * D_ + (lStart + thread * 8 / tileD_) * D_
g_mxXa_ + aIStart + jL * warpL_ * laneL * DS_ + (lStart + thread * 8 / tileD_) * DS_ + S_pre_
+ i * (D_ / tileD_) + dStart + thread * 8 % tileD_);
}
else if (jL < tileL_ / (warpL_ * laneL))
@ -448,14 +451,14 @@ __global__
+ 2
* swizzle<tileD_ * 2, tileD_, T_>(
i + thread * 8 + jL % pipe_ * warpL_ * laneL * tileD_),
g_mxXa_ + aStart + jL * warpL_ * laneL * D_ + (lStart + thread * 8 / tileD_) * D_
g_mxXa_ + aIStart + jL * warpL_ * laneL * DS_ + (lStart + thread * 8 / tileD_) * DS_ + S_pre_
+ i * (D_ / tileD_) + dStart + thread * 8 % tileD_);
}
cp_commit_group();
int offset
= aStart + iL * warpL_ * laneL * D_ + (lStart + thread * 8 / tileD_) * D_ + dStart + thread * 8 % tileD_;
= aOStart + iL * warpL_ * laneL * D_ + (lStart + thread * 8 / tileD_) * D_ + dStart + thread * 8 % tileD_;
if (lStart + (iL + 1) * warpL_ * laneL <= L)
{
@ -501,9 +504,9 @@ __global__
}
template <int K_, int tileL_, int tileD_, int warpL_, int warpD_, int laneD_, int pipe_, typename T_>
__global__ std::enable_if_t<std::is_same_v<T_, float>> mambaConv1dContextKernel(int B_, int L_, int D_, T_* g_mxYa_,
T_* g_mxYs_, T_ const* g_mxXa_, T_ const* g_mxXs_, T_ const* g_mxW_, T_ const* g_mxB_, bool removePadding_,
bool applySilu_, int const* lastTokenIdsPtr_, int const* stateSlotMappingPtr_ = nullptr)
__global__ std::enable_if_t<std::is_same_v<T_, float>> mambaConv1dContextKernel(int B_, int L_, int D_, int S_pre_,
int S_post_, T_* g_mxYa_, T_* g_mxYs_, T_ const* g_mxXa_, T_ const* g_mxXs_, T_ const* g_mxW_, T_ const* g_mxB_,
bool removePadding_, bool applySilu_, int const* lastTokenIdsPtr_, int const* stateSlotMappingPtr_ = nullptr)
{
static_assert(laneD_ >= 1 && laneD_ <= 32 && (laneD_ & (laneD_ - 1)) == 0);
@ -542,17 +545,20 @@ __global__ std::enable_if_t<std::is_same_v<T_, float>> mambaConv1dContextKernel(
int STEP = 128 * warpL_ * warpD_;
int L = L_;
int DS_ = (D_ + S_pre_ + S_post_);
int aStart = blockIdx.z * L_ * D_;
int aIStart = blockIdx.z * L_ * DS_;
int aOStart = blockIdx.z * L_ * D_;
int sStart = blockIdx.z * (K_ - 1) * D_;
int lStart = blockIdx.y * tileL_;
int dStart = blockIdx.x * tileD_;
if (removePadding_)
{
aStart = blockIdx.z ? lastTokenIdsPtr_[blockIdx.z - 1] : 0;
L = lastTokenIdsPtr_[blockIdx.z] - aStart;
aStart = aStart * D_;
aIStart = blockIdx.z ? lastTokenIdsPtr_[blockIdx.z - 1] : 0;
L = lastTokenIdsPtr_[blockIdx.z] - aIStart;
aOStart = aIStart * D_;
aIStart = aIStart * DS_;
}
else
{
@ -575,8 +581,8 @@ __global__ std::enable_if_t<std::is_same_v<T_, float>> mambaConv1dContextKernel(
+ 4
* swizzle<tileD_ * 4, tileD_, T_>(
i + thread * 4 + (warpL_ * laneL * pipe_ + 1 - K_) * tileD_),
g_mxXa_ + aStart + (1 - K_ + lStart + thread * 4 / tileD_) * D_ + i * (D_ / tileD_) + dStart
+ thread * 4 % tileD_);
g_mxXa_ + aIStart + (1 - K_ + lStart + thread * 4 / tileD_) * DS_ + S_pre_ + i * (D_ / tileD_)
+ dStart + thread * 4 % tileD_);
}
else if (g_mxXs_)
{
@ -611,7 +617,7 @@ __global__ std::enable_if_t<std::is_same_v<T_, float>> mambaConv1dContextKernel(
if (i + STEP <= warpL_ * laneL * tileD_ || i + thread * 4 < warpL_ * laneL * tileD_)
cp_shared_global<16>(
b_mxX + 4 * swizzle<tileD_ * 4, tileD_, T_>(i + thread * 4 + iL * warpL_ * laneL * tileD_),
g_mxXa_ + aStart + iL * warpL_ * laneL * D_ + (lStart + thread * 4 / tileD_) * D_
g_mxXa_ + aIStart + iL * warpL_ * laneL * DS_ + (lStart + thread * 4 / tileD_) * DS_ + S_pre_
+ i * (D_ / tileD_) + dStart + thread * 4 % tileD_);
}
else
@ -621,7 +627,7 @@ __global__ std::enable_if_t<std::is_same_v<T_, float>> mambaConv1dContextKernel(
if (i + thread * 4 < (L - lStart - iL * warpL_ * laneL) * tileD_)
cp_shared_global<16>(
b_mxX + 4 * swizzle<tileD_ * 4, tileD_, T_>(i + thread * 4 + iL * warpL_ * laneL * tileD_),
g_mxXa_ + aStart + iL * warpL_ * laneL * D_ + (lStart + thread * 4 / tileD_) * D_
g_mxXa_ + aIStart + iL * warpL_ * laneL * DS_ + (lStart + thread * 4 / tileD_) * DS_ + S_pre_
+ i * (D_ / tileD_) + dStart + thread * 4 % tileD_);
}
@ -684,7 +690,7 @@ __global__ std::enable_if_t<std::is_same_v<T_, float>> mambaConv1dContextKernel(
+ 4
* swizzle<tileD_ * 4, tileD_, T_>(
i + thread * 4 + jL % pipe_ * warpL_ * laneL * tileD_),
g_mxXa_ + aStart + jL * warpL_ * laneL * D_ + (lStart + thread * 4 / tileD_) * D_
g_mxXa_ + aIStart + jL * warpL_ * laneL * DS_ + (lStart + thread * 4 / tileD_) * DS_ + S_pre_
+ i * (D_ / tileD_) + dStart + thread * 4 % tileD_);
}
else if (jL < tileL_ / (warpL_ * laneL))
@ -696,14 +702,14 @@ __global__ std::enable_if_t<std::is_same_v<T_, float>> mambaConv1dContextKernel(
+ 4
* swizzle<tileD_ * 4, tileD_, T_>(
i + thread * 4 + jL % pipe_ * warpL_ * laneL * tileD_),
g_mxXa_ + aStart + jL * warpL_ * laneL * D_ + (lStart + thread * 4 / tileD_) * D_
g_mxXa_ + aIStart + jL * warpL_ * laneL * DS_ + (lStart + thread * 4 / tileD_) * DS_ + S_pre_
+ i * (D_ / tileD_) + dStart + thread * 4 % tileD_);
}
cp_commit_group();
int offset
= aStart + iL * warpL_ * laneL * D_ + (lStart + thread * 4 / tileD_) * D_ + dStart + thread * 4 % tileD_;
= aOStart + iL * warpL_ * laneL * D_ + (lStart + thread * 4 / tileD_) * D_ + dStart + thread * 4 % tileD_;
if (lStart + (iL + 1) * warpL_ * laneL <= L)
{
@ -755,6 +761,8 @@ void invokeMambaConv1dContext(MambaConv1dParamsBase& params, cudaStream_t stream
int L = params.max_seqlen;
int D = params.dim;
int K = params.dconv;
int S_pre = params.pre_stride;
int S_post = params.post_stride;
int tileL = 32;
int tileD = 128;
@ -763,9 +771,9 @@ void invokeMambaConv1dContext(MambaConv1dParamsBase& params, cudaStream_t stream
int laneD = 4;
int pipe = 4;
void (*f)(int B_, int L_, int D_, input_t* g_mxYa_, input_t* g_mxYs_, input_t const* g_mxXa_,
input_t const* g_mxXs_, input_t const* g_mxW_, input_t const* g_mxB_, bool removePadding_, bool applySilu_,
int const* lastTokenIdsPtr_, int const* stateSlotMappingPtr_);
void (*f)(int B_, int L_, int D_, int S_pre_, int S_post_, input_t* g_mxYa_, input_t* g_mxYs_,
input_t const* g_mxXa_, input_t const* g_mxXs_, input_t const* g_mxW_, input_t const* g_mxB_,
bool removePadding_, bool applySilu_, int const* lastTokenIdsPtr_, int const* stateSlotMappingPtr_);
if (std::is_same_v<input_t, float>)
{
@ -1071,7 +1079,7 @@ void invokeMambaConv1dContext(MambaConv1dParamsBase& params, cudaStream_t stream
cudaFuncSetAttribute(f, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem);
f<<<blks, thds, shmem, stream>>>(B, L, D, ya, ys, xa, xs, w, b, rmpd, silu, ltip, ssmp);
f<<<blks, thds, shmem, stream>>>(B, L, D, S_pre, S_post, ya, ys, xa, xs, w, b, rmpd, silu, ltip, ssmp);
}
template <typename input_t, int DCONV = 4, int CHANNELS_PER_THREAD = 4>
@ -1088,6 +1096,7 @@ __launch_bounds__(64, 8) __global__
int num_channels = params.dim;
int const micro_batch = blockIdx.y;
int const channel = (blockIdx.x * blockDim.x + threadIdx.x) * CHANNELS_PER_THREAD;
int const num_channels_in = num_channels + params.pre_stride + params.post_stride;
if (channel >= num_channels)
{
@ -1096,7 +1105,7 @@ __launch_bounds__(64, 8) __global__
weight += channel;
bias += channel;
input += channel;
input += (channel + params.pre_stride);
output += channel;
state_in += channel;
state_out += channel;
@ -1119,7 +1128,7 @@ __launch_bounds__(64, 8) __global__
++sample)
{
int const slot_idx = params.state_slot_mapping_ptr == nullptr ? sample : params.state_slot_mapping_ptr[sample];
input_t* token_input = input + sample * params.dim;
input_t* token_input = input + sample * num_channels_in;
input_t* token_output = output + sample * params.dim;
input_t* token_state_in = state_in + slot_idx * (params.dconv - 1) * params.dim;
input_t* token_state_out = state_out + slot_idx * (params.dconv - 1) * params.dim;
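
A sketch of the strided input view the conv path now reads, with illustrative sizes: each token's input row carries pre_stride extra channels before and post_stride after the dim channels being convolved, so input rows advance by dim + pre_stride + post_stride while the output keeps only dim channels:

#include <cstdio>
int main()
{
    int dim = 6, preStride = 2, postStride = 3;                // made-up sizes
    int numChannelsIn = dim + preStride + postStride;          // stride between tokens in the fused input
    int sample = 4, channel = 1;
    long long inIndex = (long long) sample * numChannelsIn + preStride + channel;
    long long outIndex = (long long) sample * dim + channel;   // output row holds only the dim channels
    std::printf("in=%lld out=%lld\n", inIndex, outIndex);
    return 0;
}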

View File

@ -26,7 +26,7 @@ namespace kernels
struct MambaConv1dParamsBase
{
int batch, dim, max_seqlen, dconv;
int batch, dim, max_seqlen, dconv, pre_stride, post_stride;
bool remove_padding;
bool apply_silu;
void* __restrict__ in_ptr;

View File

@ -370,9 +370,8 @@ void invokeChunkScan(SSMParamsBase& params, cudaStream_t stream)
float const* mxdb = (float const*) params.delta_bias_ptr;
float const* mxA = (float const*) params.A_ptr;
half* mxCB = (half*) params.CB_ptr;
half const* mxBC = (half const*) params.BC_ptr;
float const* mxD = (float const*) params.D_ptr;
half const* mxX = (half const*) params.u_ptr;
half const* mxXBC = (half const*) params.u_ptr;
half const* mxZ = (half const*) params.z_ptr;
auto rp = params.remove_padding;
@ -380,16 +379,16 @@ void invokeChunkScan(SSMParamsBase& params, cudaStream_t stream)
auto ssmp = params.slot_mapping_ptr;
cudaFuncSetAttribute(chunk_cumsum, cudaFuncAttributeMaxDynamicSharedMemorySize, shms[0]);
chunk_cumsum<<<bds[0], tds[0], shms[0], stream>>>(B, L, H, mxdc, mxdA, mxdt, mxdb, mxA, rp, ltip);
chunk_cumsum<<<bds[0], tds[0], shms[0], stream>>>(B, L, H, P, G, N, mxdc, mxdA, mxdt, mxdb, mxA, mxZ, rp, ltip);
cudaFuncSetAttribute(chunk_state, cudaFuncAttributeMaxDynamicSharedMemorySize, shms[1]);
chunk_state<<<bds[1], tds[1], shms[1], stream>>>(B, L, H, P, G, N, mxSt, mxdc, mxdA, mxBC, mxX, rp, ltip);
chunk_state<<<bds[1], tds[1], shms[1], stream>>>(B, L, H, P, G, N, mxSt, mxdc, mxdA, mxXBC, rp, ltip);
cudaFuncSetAttribute(state_passing, cudaFuncAttributeMaxDynamicSharedMemorySize, shms[2]);
state_passing<<<bds[2], tds[2], shms[2], stream>>>(B, L, H, P, N, mxOs, mxFs, mxSt, mxdA, rp, ltip, ssmp);
cudaFuncSetAttribute(bmm_chunk, cudaFuncAttributeMaxDynamicSharedMemorySize, shms[3]);
bmm_chunk<<<bds[3], tds[3], shms[3], stream>>>(B, L, G, N, mxCB, mxBC, rp, ltip);
bmm_chunk<<<bds[3], tds[3], shms[3], stream>>>(B, L, H, P, G, N, mxCB, mxXBC, rp, ltip);
cudaFuncSetAttribute(chunk_scan, cudaFuncAttributeMaxDynamicSharedMemorySize, shms[4]);
chunk_scan<<<bds[4], tds[4], shms[4], stream>>>(
B, L, H, P, G, N, mxY, mxOs, mxdc, mxdA, mxCB, mxBC, mxD, mxX, mxZ, rp, ltip);
B, L, H, P, G, N, mxY, mxOs, mxdc, mxdA, mxCB, mxD, mxXBC, mxZ, rp, ltip);
}
else if constexpr (std::is_same_v<input_t, __nv_bfloat16>)
{
@ -413,9 +412,8 @@ void invokeChunkScan(SSMParamsBase& params, cudaStream_t stream)
float const* mxdb = (float const*) params.delta_bias_ptr;
float const* mxA = (float const*) params.A_ptr;
__nv_bfloat16* mxCB = (__nv_bfloat16*) params.CB_ptr;
__nv_bfloat16 const* mxBC = (__nv_bfloat16 const*) params.BC_ptr;
float const* mxD = (float const*) params.D_ptr;
__nv_bfloat16 const* mxX = (__nv_bfloat16 const*) params.u_ptr;
__nv_bfloat16 const* mxXBC = (__nv_bfloat16 const*) params.u_ptr;
__nv_bfloat16 const* mxZ = (__nv_bfloat16 const*) params.z_ptr;
auto rp = params.remove_padding;
@ -423,16 +421,16 @@ void invokeChunkScan(SSMParamsBase& params, cudaStream_t stream)
auto ssmp = params.slot_mapping_ptr;
cudaFuncSetAttribute(chunk_cumsum, cudaFuncAttributeMaxDynamicSharedMemorySize, shms[0]);
chunk_cumsum<<<bds[0], tds[0], shms[0], stream>>>(B, L, H, mxdc, mxdA, mxdt, mxdb, mxA, rp, ltip);
chunk_cumsum<<<bds[0], tds[0], shms[0], stream>>>(B, L, H, P, G, N, mxdc, mxdA, mxdt, mxdb, mxA, mxZ, rp, ltip);
cudaFuncSetAttribute(chunk_state, cudaFuncAttributeMaxDynamicSharedMemorySize, shms[1]);
chunk_state<<<bds[1], tds[1], shms[1], stream>>>(B, L, H, P, G, N, mxSt, mxdc, mxdA, mxBC, mxX, rp, ltip);
chunk_state<<<bds[1], tds[1], shms[1], stream>>>(B, L, H, P, G, N, mxSt, mxdc, mxdA, mxXBC, rp, ltip);
cudaFuncSetAttribute(state_passing, cudaFuncAttributeMaxDynamicSharedMemorySize, shms[2]);
state_passing<<<bds[2], tds[2], shms[2], stream>>>(B, L, H, P, N, mxOs, mxFs, mxSt, mxdA, rp, ltip, ssmp);
cudaFuncSetAttribute(bmm_chunk, cudaFuncAttributeMaxDynamicSharedMemorySize, shms[3]);
bmm_chunk<<<bds[3], tds[3], shms[3], stream>>>(B, L, G, N, mxCB, mxBC, rp, ltip);
bmm_chunk<<<bds[3], tds[3], shms[3], stream>>>(B, L, H, P, G, N, mxCB, mxXBC, rp, ltip);
cudaFuncSetAttribute(chunk_scan, cudaFuncAttributeMaxDynamicSharedMemorySize, shms[4]);
chunk_scan<<<bds[4], tds[4], shms[4], stream>>>(
B, L, H, P, G, N, mxY, mxOs, mxdc, mxdA, mxCB, mxBC, mxD, mxX, mxZ, rp, ltip);
B, L, H, P, G, N, mxY, mxOs, mxdc, mxdA, mxCB, mxD, mxXBC, mxZ, rp, ltip);
}
}
@ -458,7 +456,8 @@ INSTANTIATE_CHUNK_SCAN_DATA_TYPE(__nv_bfloat16, float);
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename input_t, typename weight_t, int DSTATE = 16, int CHANNELS_PER_BLOCK = 128, bool MAMBA_V1 = true>
template <typename input_t, typename weight_t, int DSTATE = 16, int CHANNELS_PER_BLOCK = 128, bool MAMBA_V1 = true,
int STATE_UNROLL = 16>
__launch_bounds__(128, 2) __global__ void selective_scan_update_kernel(SSMParamsBase params)
{
@ -485,24 +484,43 @@ __launch_bounds__(128, 2) __global__ void selective_scan_update_kernel(SSMParams
int const head = channel / head_dim;
int const head_chl = channel % head_dim;
int const group = head / (nheads / ngroups);
int const slot_idx = params.slot_mapping_ptr == nullptr ? sample : params.slot_mapping_ptr[sample];
int const bc_offset = MAMBA_V1 ? sample * (DSTATE * 2 + params.dt_rank) : sample * DSTATE * ngroups * 2;
int const b_offset = MAMBA_V1 ? params.dt_rank : DSTATE * group;
int const c_offset = MAMBA_V1 ? params.dt_rank + DSTATE : DSTATE * (ngroups + group);
int const dt_d_idx = MAMBA_V1 ? channel : head;
int const bc_dim = MAMBA_V1 ? 2 * DSTATE : 2 * ngroups * DSTATE;
int const x_dim = MAMBA_V1 ? num_channels : num_channels + bc_dim;
int const z_dim = MAMBA_V1 ? num_channels : 2 * num_channels + bc_dim + nheads;
int const dt_dim = MAMBA_V1 ? num_channels : (z ? z_dim : z_dim - num_channels);
int const dt_offset = MAMBA_V1 ? sample * dt_dim : sample * dt_dim + dt_dim - nheads;
int const bc_offset = MAMBA_V1 ? sample * (bc_dim + params.dt_rank) : sample * (num_channels + bc_dim);
int const b_offset = MAMBA_V1 ? params.dt_rank : num_channels + DSTATE * group;
int const c_offset = MAMBA_V1 ? params.dt_rank + DSTATE : num_channels + DSTATE * (ngroups + group);
input_t* my_state = &state[slot_idx * num_channels * DSTATE];
input_t* my_output = &output[sample * num_channels];
float rA[DSTATE];
float rB[DSTATE];
float rC[DSTATE];
float rState[DSTATE];
float my_x, my_dt, my_z, my_dt_bias, my_D;
my_x = toFloat(x[sample * num_channels + channel]);
my_z = z ? toFloat(z[sample * num_channels + channel]) : 0.f;
int const state_loops = (DSTATE + STATE_UNROLL - 1) / STATE_UNROLL;
float my_x, my_dt, my_z, my_dt_bias, out;
my_x = toFloat(x[sample * x_dim + channel]);
my_z = z ? toFloat(z[sample * z_dim + channel]) : 0.f;
my_dt = toFloat(dt[dt_offset + dt_d_idx]);
my_dt_bias = dt_bias ? toFloat(dt_bias[dt_d_idx]) : 0.f;
out = D ? toFloat(D[dt_d_idx]) * my_x : 0.f;
float dt_b = my_dt + my_dt_bias;
float dt_b_sp = 1.0f;
if (dt_softplus)
{
dt_b_sp = dt_b <= 20.f ? __logf(1.f + __expf(dt_b)) : dt_b; // softplus
}
if (MAMBA_V1)
{
float rA[DSTATE];
float rB[DSTATE];
float rC[DSTATE];
float rState[DSTATE];
#pragma unroll
for (int i = 0; i < DSTATE; i++)
{
@ -511,49 +529,48 @@ __launch_bounds__(128, 2) __global__ void selective_scan_update_kernel(SSMParams
rC[i] = toFloat(C[bc_offset + c_offset + i]);
rState[i] = toFloat(my_state[i * num_channels + channel]);
}
my_dt = toFloat(dt[sample * num_channels + channel]);
my_dt_bias = dt_bias ? toFloat(dt_bias[channel]) : 0.f;
my_D = D ? toFloat(D[channel]) : 0.f;
#pragma unroll
for (int i = 0; i < DSTATE; i++)
{
float dA = __expf(rA[i] * dt_b_sp);
float dB = rB[i] * dt_b_sp;
float sdA = rState[i] * dA;
float dBx = dB * my_x;
float newState = sdA + dBx;
// Write the new state back out to the cache
convertAndStore(&my_state[i * num_channels + channel], newState);
out += newState * rC[i];
}
}
else
{
float A_tmp = toFloat(A[head]);
#pragma unroll
for (int i = 0; i < DSTATE; i++)
float rB[STATE_UNROLL];
float rC[STATE_UNROLL];
float rState[STATE_UNROLL];
for (int si = 0; si < state_loops; si++)
{
rA[i] = A_tmp;
rB[i] = toFloat(B[bc_offset + b_offset + i]);
rC[i] = toFloat(C[bc_offset + c_offset + i]);
rState[i] = toFloat(my_state[(head * DSTATE + i) * head_dim + head_chl]);
}
my_dt = toFloat(dt[sample * nheads + head]);
my_dt_bias = dt_bias ? toFloat(dt_bias[head]) : 0.f;
my_D = D ? toFloat(D[head]) : 0.f;
}
float dt_b = my_dt + my_dt_bias;
float dt_b_sp;
if (dt_softplus)
{
dt_b_sp = dt_b <= 20.f ? __logf(1.f + __expf(dt_b)) : dt_b; // softplus
}
float out = D ? my_D * my_x : 0.f;
int i_offset = si * STATE_UNROLL;
#pragma unroll
for (int i = 0; i < DSTATE; i++)
{
float dA = __expf(rA[i] * dt_b_sp);
float dB = rB[i] * dt_b_sp;
float sdA = rState[i] * dA;
float dBx = dB * my_x;
float newState = sdA + dBx;
// Write the new state back out to the cache
if (MAMBA_V1)
convertAndStore(&my_state[i * num_channels + channel], newState);
else
convertAndStore(&my_state[(head * DSTATE + i) * head_dim + head_chl], newState);
out += newState * rC[i];
for (int i = 0; i < STATE_UNROLL; i++)
{
rB[i] = toFloat(B[bc_offset + b_offset + i_offset + i]);
rC[i] = toFloat(C[bc_offset + c_offset + i_offset + i]);
rState[i] = toFloat(my_state[(head * DSTATE + i_offset + i) * head_dim + head_chl]);
}
#pragma unroll
for (int i = 0; i < STATE_UNROLL; i++)
{
float dA = __expf(A_tmp * dt_b_sp);
float dB = rB[i] * dt_b_sp;
float sdA = rState[i] * dA;
float dBx = dB * my_x;
float newState = sdA + dBx;
// Write the new state back out to the cache
convertAndStore(&my_state[(head * DSTATE + i_offset + i) * head_dim + head_chl], newState);
out += newState * rC[i];
}
}
}
if (z)

View File

@ -477,10 +477,10 @@ __global__ void packExplicitDraftTokens(PackExplicitDraftTokensParams<T> params)
params.outputTemperatures[batchIdx] = params.inputTemperatures[batchSlot];
}
// Copy random validation data.
auto const numDecodingDraftTokens = params.numPaths * (params.maxPathLength - 1);
if (isGenerationRequest)
{
// Copy random validation data.
auto const numDecodingDraftTokens = params.numPaths * (params.maxPathLength - 1);
auto outputRandomDataValidation = params.outputRandomDataValidation + genIdx * numDecodingDraftTokens;
auto const inputRandomDataValidation = params.inputRandomDataValidation + batchSlot * numDecodingDraftTokens;
for (auto ti = static_cast<SizeType32>(threadIdx.x); ti < numDecodingDraftTokens;
@ -488,11 +488,8 @@ __global__ void packExplicitDraftTokens(PackExplicitDraftTokensParams<T> params)
{
outputRandomDataValidation[ti] = inputRandomDataValidation[ti];
}
}
// Copy draft tokens and indices
if (isGenerationRequest)
{
// Copy draft tokens and indices
auto const numUnpackedTokens = numDecodingDraftTokens + params.numPaths;
auto outputNextDraftTokens = params.outputNextDraftTokens + genIdx * numUnpackedTokens;
auto outputNextDraftIndices = params.outputNextDraftIndices + genIdx * numUnpackedTokens;
@ -504,41 +501,37 @@ __global__ void packExplicitDraftTokens(PackExplicitDraftTokensParams<T> params)
outputNextDraftTokens[ti] = inputNextDraftTokens[ti];
outputNextDraftIndices[ti] = inputNextDraftIndices[ti];
}
}
auto const maxGenerationLength = params.maxGenerationLength[0];
auto const maxDecodingTokens = numDecodingDraftTokens + 1;
auto const numPackedMasks = divUp(maxGenerationLength, 32);
auto const outputMaskStartId = (genIdx == 0) ? 0 : params.cumSumGenerationLengths[genIdx - 1];
auto const numTokens = (genIdx == 0)
? params.cumSumGenerationLengths[0]
: params.cumSumGenerationLengths[genIdx] - params.cumSumGenerationLengths[genIdx - 1];
// Copy packed masks.
// Masks are placed next to each other with offsets of cumSumGenerationLengths[bi-1]
if (isGenerationRequest)
{
auto const maxGenerationLength = params.maxGenerationLength[0];
auto const maxDecodingTokens = numDecodingDraftTokens + 1;
auto const numPackedMasks = divUp(maxGenerationLength, 32);
auto const outputStartId = (genIdx == 0) ? 0 : params.cumSumGenerationLengths[genIdx - 1];
auto const numTokens = (genIdx == 0)
? params.cumSumGenerationLengths[0]
: params.cumSumGenerationLengths[genIdx] - params.cumSumGenerationLengths[genIdx - 1];
// Copy packed masks.
// Masks are placed next to each other with offsets of cumSumGenerationLengths[bi-1]
auto const inputPackedMask = params.inputPackedMask + batchSlot * numPackedMasks * maxDecodingTokens;
auto outputPackedMask = params.outputPackedMask + outputMaskStartId * numPackedMasks;
auto outputPackedMask = params.outputPackedMask + outputStartId * numPackedMasks;
for (auto ti = static_cast<SizeType32>(threadIdx.x); ti < numTokens * numPackedMasks;
ti += static_cast<SizeType32>(blockDim.x))
{
outputPackedMask[ti] = inputPackedMask[ti];
}
}
auto const inputPositionIds = params.inputPositionIds + batchSlot * maxDecodingTokens;
auto outputPositionIds = params.outputPositionIds + params.numContextTokens + outputStartId;
for (auto ti = static_cast<SizeType32>(threadIdx.x); ti < numTokens; ti += static_cast<SizeType32>(blockDim.x))
{
outputPositionIds[ti] = inputPositionIds[ti];
}
// Copy pos offsets. Copy only for maxGenerationLength
if (isGenerationRequest)
{
// Copy pos offsets. Copy only for maxGenerationLength
auto const basePosId = params.outputPositionIdsBase[batchIdx];
auto outputPositionOffsets = params.outputPositionOffsets + genIdx * maxGenerationLength;
auto outputPositionIds = params.outputPositionIds + genIdx * maxGenerationLength;
auto const inputPositionIds = params.inputPositionIds + batchSlot * maxDecodingTokens;
for (auto ti = static_cast<SizeType32>(threadIdx.x); ti < maxGenerationLength;
ti += static_cast<SizeType32>(blockDim.x))
{
auto const posId = inputPositionIds[ti];
outputPositionIds[params.numContextTokens + ti] = posId;
outputPositionOffsets[ti] = posId - basePosId + 1;
outputPositionOffsets[ti] = inputPositionIds[ti] - basePosId + 1;
}
}
}
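
A small sketch of the offset arithmetic used above (lengths are made up): packed masks and position ids for generation requests are laid out back to back, so each request's output offset and token count come from the inclusive cumulative sum of generation lengths:

#include <cstdio>
int main()
{
    int generationLengths[] = {3, 5, 2};
    int cumSum[3];
    int running = 0;
    for (int i = 0; i < 3; ++i)
    {
        running += generationLengths[i];
        cumSum[i] = running;              // inclusive cumulative sum
    }
    for (int genIdx = 0; genIdx < 3; ++genIdx)
    {
        int outputStartId = (genIdx == 0) ? 0 : cumSum[genIdx - 1];
        int numTokens = (genIdx == 0) ? cumSum[0] : cumSum[genIdx] - cumSum[genIdx - 1];
        std::printf("request %d: offset %d, %d tokens\n", genIdx, outputStartId, numTokens);
    }
    return 0;
}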

View File

@ -148,8 +148,8 @@ void BeamSearchLayer<T>::forwardAsync(
bh.nMaxBatchSize = static_cast<std::int32_t>(op->outputIdsPtr.shape[0]);
bh.nBatchSize = ip->localBatchSize;
bh.batchSlots = ip->batchSlots ? ip->batchSlots->template getPtr<SizeType32 const>() : nullptr;
bh.nBeamWidth = static_cast<std::int32_t>(op->outputIdsPtr.shape[1]);
bh.nMaxSeqLen = static_cast<std::int32_t>(op->outputIdsPtr.shape[2]);
bh.nBeamWidth = static_cast<std::int32_t>(op->outputIds.shape[1]);
bh.nMaxSeqLen = static_cast<std::int32_t>(op->outputIds.shape[2]);
bh.nVocabSize = mVocabSizePadded;
bh.diversityRates = mDiversityRateDevice;
bh.lengthPenalties = mLengthPenaltyDevice;

View File

@ -548,6 +548,8 @@ public:
tc::Tensor temperatures; // [maxBatchSize] on gpu
//! Next generation lengths.
tc::Tensor generationLengths; // [maxBatchSize] on gpu
//! Next generation lengths on host.
tc::Tensor generationLengthsHost; // [maxBatchSize] on pinned
//! Maximum number of generated tokens for the next step across whole batch
tc::Tensor maxGenLengthHost; // [1] on pinned
};

View File

@ -245,14 +245,10 @@ void DynamicDecodeLayer<T>::prepareIdsPtrs(std::shared_ptr<BaseDecodingOutputs>
}
}
outputs->outputIdsPtr = Tensor(MEMORY_GPU, DataType::TYPE_INT32_PTR,
{static_cast<size_t>(mDecoderDomain.getBatchSize()), static_cast<size_t>(beamWidth),
static_cast<size_t>(maxSeqLen)},
idsPtrHost);
outputs->outputIdsPtr = Tensor(
MEMORY_GPU, DataType::TYPE_INT32_PTR, {static_cast<size_t>(mDecoderDomain.getBatchSize())}, idsPtrHost);
outputs->parentIdsPtr = Tensor(MEMORY_GPU, DataType::TYPE_INT32_PTR,
{static_cast<size_t>(mDecoderDomain.getBatchSize()), static_cast<size_t>(beamWidth),
static_cast<size_t>(maxSeqLen)},
idsPtrHost + mDecoderDomain.getBatchSize());
{static_cast<size_t>(mDecoderDomain.getBatchSize())}, idsPtrHost + mDecoderDomain.getBatchSize());
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

View File

@ -269,13 +269,6 @@ void ExplicitDraftTokensLayer<T>::splitInputDataToBatchSlots(
params.checkParams();
// Copy max generation length
cudaMemcpyAsync(outputs.maxGenLengthHost.template getPtr<SizeType32>(),
inputs.maxGenLengthDevice.template getPtr<SizeType32 const>(), sizeof(SizeType32), cudaMemcpyDeviceToHost,
mStream);
params.checkParams();
// Copy max generation length
cudaMemcpyAsync(outputs.maxGenLengthHost.template getPtr<SizeType32>(),
inputs.maxGenLengthDevice.template getPtr<SizeType32 const>(), sizeof(SizeType32), cudaMemcpyDeviceToHost,
@ -285,6 +278,11 @@ void ExplicitDraftTokensLayer<T>::splitInputDataToBatchSlots(
invokeCopyProbs(params, mStream);
// Copy max generation length
cudaMemcpyAsync(outputs.generationLengthsHost.template getPtr<SizeType32>(),
outputs.generationLengths.template getPtr<SizeType32 const>(),
sizeof(SizeType32) * mDecoderDomain.getBatchSize(), cudaMemcpyDeviceToHost, mStream);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

View File

@ -242,6 +242,9 @@ bool GPTAttentionPluginCommon::convertMMHAParamsToXQAParams(tensorrt_llm::kernel
xqaParams.spec_decoding_packed_mask = generationsParams.spec_decoding_packed_mask;
xqaParams.spec_decoding_position_offsets = generationsParams.spec_decoding_position_offsets;
xqaParams.spec_decoding_generation_lengths = generationsParams.spec_decoding_generation_lengths;
xqaParams.spec_decoding_is_generation_length_variable
= generationsParams.spec_decoding_is_generation_length_variable;
xqaParams.spec_decoding_max_generation_length = generationsParams.spec_decoding_max_generation_length;
xqaParams.total_num_input_tokens = generationsParams.total_num_input_tokens;
xqaParams.fp8_out_scale = (mFP8ContextFMHA ? generationsParams.attention_output_orig_quant : nullptr);
@ -399,7 +402,8 @@ GPTAttentionPluginCommon::GPTAttentionPluginCommon(int layer_idx, int num_heads,
tensorrt_llm::kernels::BlockSparseParams block_sparse_params, bool paged_kv_cache, int tokens_per_block,
nvinfer1::DataType type, int32_t max_context_length, bool qkv_bias_enabled, bool cross_attention, int max_distance,
bool pos_shift_enabled, bool dense_context_fmha, bool use_paged_context_fmha, bool use_fp8_context_fmha,
bool use_cache, bool is_spec_decoding_enabled)
bool use_cache, bool is_spec_decoding_enabled, bool spec_decoding_is_generation_length_variable,
int32_t spec_decoding_max_generation_length)
: mLayerIdx(layer_idx)
, mNumHeads(num_heads)
, mVisionStart(vision_start)
@ -443,6 +447,8 @@ GPTAttentionPluginCommon::GPTAttentionPluginCommon(int layer_idx, int num_heads,
, mFP8ContextFMHA(use_fp8_context_fmha)
, mUseKVCache(use_cache)
, mIsSpecDecodingEnabled(is_spec_decoding_enabled)
, mSpecDecodingIsGenerationLengthVariable(spec_decoding_is_generation_length_variable)
, mSpecDecodingMaxGenerationLength(spec_decoding_max_generation_length)
, mDriver(CUDADriverWrapper::getInstance())
{
// Pre-check whether FMHA is supported in order to save memory allocation.
@ -559,6 +565,8 @@ GPTAttentionPluginCommon::GPTAttentionPluginCommon(void const* data, size_t leng
read(d, mFP8ContextFMHA);
read(d, mUseKVCache);
read(d, mIsSpecDecodingEnabled);
read(d, mSpecDecodingIsGenerationLengthVariable);
read(d, mSpecDecodingMaxGenerationLength);
read(d, mNbMultiBlockSemaphores);
mKVCacheQuantMode = tc::QuantMode(kvCacheQuantMode);
@ -1655,7 +1663,8 @@ size_t GPTAttentionPluginCommon::getCommonSerializationSize() const noexcept
+ sizeof(mTokensPerBlock) + sizeof(mType) + sizeof(mMaxContextLength) + sizeof(mQKVBiasEnabled)
+ sizeof(mCrossAttention) + sizeof(mMaxDistance) + sizeof(mPosShiftEnabled) + sizeof(mDenseContextFMHA)
+ sizeof(mPagedContextFMHA) + sizeof(mFP8ContextFMHA) + sizeof(mUseKVCache) + sizeof(mUnfuseQkvGemm)
+ sizeof(mIsSpecDecodingEnabled) + sizeof(mNbMultiBlockSemaphores)
+ sizeof(mIsSpecDecodingEnabled) + sizeof(mSpecDecodingIsGenerationLengthVariable)
+ sizeof(mSpecDecodingMaxGenerationLength) + sizeof(mNbMultiBlockSemaphores)
+ sizeof(uint32_t) // size of DecoderXQARunnerResource buffer.
+ DecoderXQARunner::getResourceGlobal()->getSerializationSize();
}
@ -1705,6 +1714,8 @@ void GPTAttentionPluginCommon::serializeCommon(void* buffer) const noexcept
write(d, mFP8ContextFMHA);
write(d, mUseKVCache);
write(d, mIsSpecDecodingEnabled);
write(d, mSpecDecodingIsGenerationLengthVariable);
write(d, mSpecDecodingMaxGenerationLength);
write(d, mNbMultiBlockSemaphores);
// An uint32_t that specifies the size of the serialized buffer, followed by the actual content.
@ -1797,6 +1808,10 @@ GPTAttentionPluginCreatorCommon::GPTAttentionPluginCreatorCommon()
mPluginAttributes.emplace_back(PluginField("use_fp8_context_fmha", nullptr, PluginFieldType::kINT8, 0));
mPluginAttributes.emplace_back(PluginField("use_cache", nullptr, PluginFieldType::kINT32, 0));
mPluginAttributes.emplace_back(PluginField("is_spec_decoding_enabled", nullptr, PluginFieldType::kINT8, 0));
mPluginAttributes.emplace_back(
PluginField("spec_decoding_is_generation_length_variable", nullptr, PluginFieldType::kINT8, 0));
mPluginAttributes.emplace_back(
PluginField("spec_decoding_max_generation_length", nullptr, PluginFieldType::kINT32, 0));
mFC.nbFields = mPluginAttributes.size();
mFC.fields = mPluginAttributes.data();
}

View File

@ -52,7 +52,8 @@ public:
nvinfer1::DataType type, int32_t max_context_length, bool qkv_bias_enabled, bool cross_attention = false,
int max_distance = 0, bool pos_shift_enabled = false, bool dense_context_fmha = false,
bool use_paged_context_fmha = false, bool use_fp8_context_fmha = false, bool use_cache = true,
bool is_spec_decoding_enabled = false);
bool is_spec_decoding_enabled = false, bool spec_decoding_is_generation_length_variable = false,
int32_t spec_decoding_max_generation_length = 1);
GPTAttentionPluginCommon(void const* data, size_t length);
@ -224,6 +225,8 @@ protected:
int32_t const* spec_decoding_packed_mask = nullptr;
int32_t const* spec_decoding_position_offsets = nullptr;
int32_t const* spec_decoding_generation_lengths = nullptr;
bool spec_decoding_is_generation_length_variable = false;
int32_t spec_decoding_max_generation_length = 1;
int32_t total_num_input_tokens;
};
@ -324,6 +327,8 @@ protected:
bool mFP8ContextFMHA = false;
bool mDenseContextFMHA = false;
bool mIsSpecDecodingEnabled = false;
bool mSpecDecodingIsGenerationLengthVariable = false;
int32_t mSpecDecodingMaxGenerationLength = 1;
// Speculative decoding packed mask.
uint4* mSpecDecodingPackedMask;

View File

@ -55,7 +55,8 @@ GPTAttentionPlugin::GPTAttentionPlugin(int layer_idx, int num_heads, int vision_
tensorrt_llm::kernels::BlockSparseParams block_sparse_params, bool paged_kv_cache, int tokens_per_block,
nvinfer1::DataType type, int32_t max_context_length, bool qkv_bias_enabled, bool cross_attention, int max_distance,
bool pos_shift_enabled, bool dense_context_fmha, bool use_paged_context_fmha, bool use_fp8_context_fmha,
bool use_cache, bool is_spec_decoding_enabled)
bool use_cache, bool is_spec_decoding_enabled, bool spec_decoding_is_generation_length_variable,
int spec_decoding_max_generation_length)
: GPTAttentionPluginCommon(layer_idx, num_heads, vision_start, vision_length, num_kv_heads, head_size,
unidirectional, q_scaling, qk_tanh_scale, position_embedding_type, rotary_embedding_dim, rotary_embedding_base,
rotary_embedding_scale_type, rotary_embedding_scale, rotary_embedding_short_m_scale,
@ -63,7 +64,8 @@ GPTAttentionPlugin::GPTAttentionPlugin(int layer_idx, int num_heads, int vision_
tp_rank, unfuse_qkv_gemm, context_fmha_type, multi_block_mode, enable_xqa, kv_cache_quant_mode,
remove_input_padding, mask_type, block_sparse_params, paged_kv_cache, tokens_per_block, type,
max_context_length, qkv_bias_enabled, cross_attention, max_distance, pos_shift_enabled, dense_context_fmha,
use_paged_context_fmha, use_fp8_context_fmha, use_cache, is_spec_decoding_enabled)
use_paged_context_fmha, use_fp8_context_fmha, use_cache, is_spec_decoding_enabled,
spec_decoding_is_generation_length_variable, spec_decoding_max_generation_length)
{
initEntryIdx();
}
@ -789,6 +791,8 @@ int GPTAttentionPlugin::enqueueSome(int32_t seqIdxBeg, int32_t localNbSeq, int32
enqueue_params.spec_decoding_packed_mask = spec_decoding_packed_mask;
enqueue_params.spec_decoding_position_offsets = spec_decoding_position_offsets;
enqueue_params.spec_decoding_generation_lengths = spec_decoding_generation_lengths;
enqueue_params.spec_decoding_is_generation_length_variable = mSpecDecodingIsGenerationLengthVariable;
enqueue_params.spec_decoding_max_generation_length = mSpecDecodingMaxGenerationLength;
}
enqueue_params.total_num_input_tokens = localNbTokens;
@ -975,7 +979,9 @@ IPluginV2* GPTAttentionPluginCreator::createPlugin(char const* name, PluginField
static_cast<bool>(p.getScalar<int8_t>("use_paged_context_fmha").value()),
static_cast<bool>(p.getScalar<int8_t>("use_fp8_context_fmha").value()),
static_cast<bool>(p.getScalar<int32_t>("use_cache").value()),
static_cast<bool>(p.getScalar<int8_t>("is_spec_decoding_enabled").value()));
static_cast<bool>(p.getScalar<int8_t>("is_spec_decoding_enabled").value()),
static_cast<bool>(p.getScalar<int8_t>("spec_decoding_is_generation_length_variable").value()),
p.getScalar<int32_t>("spec_decoding_max_generation_length").value());
obj->setPluginNamespace(mNamespace.c_str());
return obj;
}

View File

@ -94,7 +94,8 @@ public:
nvinfer1::DataType type, int32_t max_context_length, bool qkv_bias_enabled, bool cross_attention = false,
int max_distance = 0, bool pos_shift_enabled = false, bool dense_context_fmha = false,
bool use_paged_context_fmha = false, bool use_fp8_context_fmha = false, bool use_cache = true,
bool is_spec_decoding_enabled = false);
bool is_spec_decoding_enabled = false, bool spec_decoding_is_generation_length_variable = false,
int spec_decoding_max_generation_length = 1);
GPTAttentionPlugin(void const* data, size_t length);

View File

@ -31,10 +31,12 @@ static char const* MAMBA_CONV1D_PLUGIN_NAME{"MambaConv1d"};
PluginFieldCollection MambaConv1dPluginCreator::mFC{};
std::vector<nvinfer1::PluginField> MambaConv1dPluginCreator::mPluginAttributes;
MambaConv1dPlugin::MambaConv1dPlugin(
int dim, int dconv, nvinfer1::DataType type, bool removePadding, bool pagedState, bool applySilu)
MambaConv1dPlugin::MambaConv1dPlugin(int dim, int dconv, int preStride, int postStride, nvinfer1::DataType type,
bool removePadding, bool pagedState, bool applySilu)
: mDim(dim)
, mDConv(dconv)
, mPreStride(preStride)
, mPostStride(postStride)
, mType(type)
, mRemovePadding(removePadding)
, mPagedState(pagedState)
@ -52,6 +54,8 @@ MambaConv1dPlugin::MambaConv1dPlugin(void const* data, size_t length)
char const *d = reinterpret_cast<char const*>(data), *a = d;
read(d, mDim);
read(d, mDConv);
read(d, mPreStride);
read(d, mPostStride);
read(d, mType);
read(d, mRemovePadding);
read(d, mPagedState);
@ -65,7 +69,8 @@ MambaConv1dPlugin::MambaConv1dPlugin(void const* data, size_t length)
// IPluginV2DynamicExt Methods
nvinfer1::IPluginV2DynamicExt* MambaConv1dPlugin::clone() const noexcept
{
auto* plugin = new MambaConv1dPlugin(mDim, mDConv, mType, mRemovePadding, mPagedState, mApplySilu);
auto* plugin
= new MambaConv1dPlugin(mDim, mDConv, mPreStride, mPostStride, mType, mRemovePadding, mPagedState, mApplySilu);
plugin->setPluginNamespace(mNamespace.c_str());
return plugin;
}
@ -78,7 +83,9 @@ nvinfer1::DimsExprs MambaConv1dPlugin::getOutputDimensions(
{
if (outputIndex == 0)
{
return inputs[getInputTensorIdx()];
auto ret = inputs[getInputTensorIdx()];
ret.d[mRemovePadding ? 1 : 2] = exprBuilder.constant(mDim);
return ret;
}
return inputs[getConvStateIdx()];
}
@ -113,9 +120,9 @@ size_t MambaConv1dPlugin::getWorkspaceSize(nvinfer1::PluginTensorDesc const* inp
}
void MambaConv1dPlugin::setMambaConv1dParams(tensorrt_llm::kernels::MambaConv1dParamsBase& params, const size_t batch,
const size_t dim, const size_t maxSeqLen, const size_t dconv, void const* inPtr, void const* stateInPtr,
void* stateOutPtr, void const* convWeight, void const* convBias, void* outPtr, int const* lastTokenIds,
int const* stateSlotMapping, bool removePadding, bool applySilu)
const size_t dim, const size_t maxSeqLen, const size_t dconv, const size_t preStride, const size_t postStride,
void const* inPtr, void const* stateInPtr, void* stateOutPtr, void const* convWeight, void const* convBias,
void* outPtr, int const* lastTokenIds, int const* stateSlotMapping, bool removePadding, bool applySilu)
{
// Reset the parameters
memset(&params, 0, sizeof(params));
@ -124,6 +131,8 @@ void MambaConv1dPlugin::setMambaConv1dParams(tensorrt_llm::kernels::MambaConv1dP
params.dim = dim;
params.max_seqlen = maxSeqLen;
params.dconv = dconv;
params.pre_stride = preStride;
params.post_stride = postStride;
params.remove_padding = removePadding;
params.apply_silu = applySilu;
@ -179,8 +188,8 @@ int MambaConv1dPlugin::enqueueImpl(nvinfer1::PluginTensorDesc const* inputDesc,
void* stateOutPtr
= mPagedState ? *reinterpret_cast<void**>(const_cast<void*>(inputs[getConvStateIdx()])) : outputs[1];
setMambaConv1dParams(mambaConv1dParams, batchSize, mDim, maxSeqLen, mDConv, inputs[getInputTensorIdx()], stateInPtr,
stateOutPtr, inputs[getWeightIdx()], inputs[getBiasIdx()], outputs[0],
setMambaConv1dParams(mambaConv1dParams, batchSize, mDim, maxSeqLen, mDConv, mPreStride, mPostStride,
inputs[getInputTensorIdx()], stateInPtr, stateOutPtr, inputs[getWeightIdx()], inputs[getBiasIdx()], outputs[0],
static_cast<int const*>(inputs[getLastTokenIdsIdx()]), slotMapping, mRemovePadding, mApplySilu);
if (reqTypes[0] == RequestType::kCONTEXT)
@ -252,8 +261,8 @@ void MambaConv1dPlugin::terminate() noexcept {}
size_t MambaConv1dPlugin::getSerializationSize() const noexcept
{
return sizeof(mDim) + sizeof(mDConv) + sizeof(mType) + sizeof(mRemovePadding) + sizeof(mPagedState)
+ sizeof(mApplySilu);
return sizeof(mDim) + sizeof(mDConv) + sizeof(mPreStride) + sizeof(mPostStride) + sizeof(mType)
+ sizeof(mRemovePadding) + sizeof(mPagedState) + sizeof(mApplySilu);
}
void MambaConv1dPlugin::serialize(void* buffer) const noexcept
@ -261,6 +270,8 @@ void MambaConv1dPlugin::serialize(void* buffer) const noexcept
char *d = static_cast<char*>(buffer), *a = d;
write(d, mDim);
write(d, mDConv);
write(d, mPreStride);
write(d, mPostStride);
write(d, mType);
write(d, mRemovePadding);
write(d, mPagedState);
@ -281,6 +292,8 @@ MambaConv1dPluginCreator::MambaConv1dPluginCreator()
mPluginAttributes.clear();
mPluginAttributes.emplace_back(PluginField("dim", nullptr, PluginFieldType::kINT32, 0));
mPluginAttributes.emplace_back(PluginField("dconv", nullptr, PluginFieldType::kINT32, 0));
mPluginAttributes.emplace_back(PluginField("pre_stride", nullptr, PluginFieldType::kINT32, 0));
mPluginAttributes.emplace_back(PluginField("post_stride", nullptr, PluginFieldType::kINT32, 0));
mPluginAttributes.emplace_back(PluginField("type_id", nullptr, PluginFieldType::kINT32, 0));
mPluginAttributes.emplace_back(PluginField("remove_input_padding", nullptr, PluginFieldType::kINT8, 0));
mPluginAttributes.emplace_back(PluginField("paged_state", nullptr, PluginFieldType::kINT8, 0));
@ -307,7 +320,7 @@ PluginFieldCollection const* MambaConv1dPluginCreator::getFieldNames() noexcept
IPluginV2* MambaConv1dPluginCreator::createPlugin(char const* name, PluginFieldCollection const* fc) noexcept
{
PluginField const* fields = fc->fields;
int dim, dconv;
int dim, dconv, pre_stride, post_stride;
bool removePadding;
bool pagedState;
bool applySilu;
@ -326,6 +339,16 @@ IPluginV2* MambaConv1dPluginCreator::createPlugin(char const* name, PluginFieldC
TLLM_CHECK(fields[i].type == PluginFieldType::kINT32);
dconv = static_cast<int>(*(static_cast<int const*>(fields[i].data)));
}
else if (!strcmp(attrName, "pre_stride"))
{
TLLM_CHECK(fields[i].type == PluginFieldType::kINT32);
pre_stride = static_cast<int>(*(static_cast<int const*>(fields[i].data)));
}
else if (!strcmp(attrName, "post_stride"))
{
TLLM_CHECK(fields[i].type == PluginFieldType::kINT32);
post_stride = static_cast<int>(*(static_cast<int const*>(fields[i].data)));
}
else if (!strcmp(attrName, "type_id"))
{
TLLM_CHECK(fields[i].type == PluginFieldType::kINT32);
@ -349,7 +372,8 @@ IPluginV2* MambaConv1dPluginCreator::createPlugin(char const* name, PluginFieldC
}
try
{
auto* obj = new MambaConv1dPlugin(dim, dconv, type, removePadding, pagedState, applySilu);
auto* obj
= new MambaConv1dPlugin(dim, dconv, pre_stride, post_stride, type, removePadding, pagedState, applySilu);
obj->setPluginNamespace(mNamespace.c_str());
return obj;
}

View File

@ -44,7 +44,8 @@ namespace tensorrt_llm::plugins
class MambaConv1dPlugin : public BasePlugin
{
public:
MambaConv1dPlugin(int dim, int dconv, nvinfer1::DataType type, bool removePadding, bool pagedState, bool applySilu);
MambaConv1dPlugin(int dim, int dconv, int preStride, int postStride, nvinfer1::DataType type, bool removePadding,
bool pagedState, bool applySilu);
MambaConv1dPlugin(void const* data, size_t length);
@ -132,7 +133,8 @@ private:
void setMambaConv1dParams(tensorrt_llm::kernels::MambaConv1dParamsBase& params,
// sizes
const size_t batch, const size_t dim, const size_t maxSeqLen, const size_t dconv,
const size_t batch, const size_t dim, const size_t maxSeqLen, const size_t dconv, const size_t preStride,
const size_t postStride,
// device pointers
void const* inPtr, void const* stateInPtr, void* stateOutPtr, void const* convWeight, void const* convBias,
void* outPtr, int const* lastTokenIds, int const* stateSlotMapping, bool removePadding, bool applySilu);
@ -140,6 +142,8 @@ private:
private:
int mDim;
int mDConv;
int mPreStride;
int mPostStride;
nvinfer1::DataType mType;
bool mRemovePadding = false;
bool mPagedState = false;

View File

@ -90,7 +90,16 @@ nvinfer1::DimsExprs SelectiveScanPlugin::getOutputDimensions(
{
if (outputIndex == 0)
{
return inputs[getInputTensorIdx()];
if (mIsMamba2)
{
auto ret = inputs[getInputTensorIdx()];
ret.d[mRemovePadding ? 1 : 2] = exprBuilder.constant(mDim);
return ret;
}
else
{
return inputs[getInputTensorIdx()];
}
}
return inputs[getStateIdx()];
}
@ -136,9 +145,9 @@ size_t SelectiveScanPlugin::getWorkspaceSize(nvinfer1::PluginTensorDesc const* i
int B = inputs[getLastTokenIdsIdx()].dims.d[0];
int BxL = inputs[getInputTensorIdx()].dims.d[0]; // num_tokens
int H = mNHeads;
int P = inputs[getInputTensorIdx()].dims.d[1] / H;
int P = mDim / H;
int G = mNGroups;
int N = inputs[getBCIdx()].dims.d[1] / G / 2;
int N = mDState;
int Q = mChunkSize;
int BxC = (BxL + Q - 1) / Q + B;
@ -153,9 +162,9 @@ size_t SelectiveScanPlugin::getWorkspaceSize(nvinfer1::PluginTensorDesc const* i
int B = inputs[getInputTensorIdx()].dims.d[0];
int L = inputs[getInputTensorIdx()].dims.d[1];
int H = mNHeads;
int P = inputs[getInputTensorIdx()].dims.d[2] / H;
int P = mDim / H;
int G = mNGroups;
int N = inputs[getBCIdx()].dims.d[2] / G / 2;
int N = mDState;
int Q = mChunkSize;
int C = (L + Q - 1) / Q;
@ -268,16 +277,16 @@ int SelectiveScanPlugin::enqueueImpl(nvinfer1::PluginTensorDesc const* inputDesc
float* mxdA = nullptr;
T* mxCB = nullptr;
if (!mIsMamba2) /* no workspace needed */
if (!mIsMamba2 || reqTypes[0] == RequestType::kGENERATION) /* no workspace needed */
;
else if (mRemovePadding)
{
int B = inputDesc[getLastTokenIdsIdx()].dims.d[0];
int BxL = inputDesc[getInputTensorIdx()].dims.d[0]; // num_tokens
int H = mNHeads;
int P = inputDesc[getInputTensorIdx()].dims.d[1] / H;
int P = mDim / H;
int G = mNGroups;
int N = inputDesc[getBCIdx()].dims.d[1] / G / 2;
int N = mDState;
int Q = mChunkSize;
int BxC = (BxL + Q - 1) / Q + B;
@ -292,9 +301,9 @@ int SelectiveScanPlugin::enqueueImpl(nvinfer1::PluginTensorDesc const* inputDesc
int B = inputDesc[getInputTensorIdx()].dims.d[0];
int L = inputDesc[getInputTensorIdx()].dims.d[1];
int H = mNHeads;
int P = inputDesc[getInputTensorIdx()].dims.d[2] / H;
int P = mDim / H;
int G = mNGroups;
int N = inputDesc[getBCIdx()].dims.d[2] / G / 2;
int N = mDState;
int Q = mChunkSize;
int C = (L + Q - 1) / Q;
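The workspace sizing above now takes the per-head dimension and state size from the plugin's stored `mDim`/`mDState` instead of re-deriving them from the input tensor shapes. A minimal pure-Python sketch of the chunking arithmetic, with made-up sizes (only the ceil-division formulas mirror the plugin code):

```python
# Made-up sizes; only the ceil-division formulas mirror the plugin code above.
def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b

B = 4              # number of sequences
num_tokens = 1000  # BxL, total tokens when input padding is removed
H, P = 8, 64       # heads and per-head dim, so dim = H * P = mDim
N = 128            # state size (mDState)
Q = 256            # chunk size (mChunkSize)

BxC = ceil_div(num_tokens, Q) + B  # upper bound: each sequence may add one partial chunk
print(f"dim={H * P}, state={N}, chunks<={BxC}")
```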

View File

@ -57,10 +57,10 @@ std::optional<tb::LlmRequest::LogitsPostProcessor> LlmRequest::callbackAdapter(
return [callback](RequestIdType reqId, tensorrt_llm::runtime::ITensor::SharedPtr& tensor,
tensorrt_llm::batch_manager::LlmRequest::BeamTokens const& tokens,
tensorrt_llm::runtime::BufferManager::CudaStreamPtr stream)
tensorrt_llm::runtime::BufferManager::CudaStreamPtr stream, std::optional<RequestIdType> clientId)
{
at::Tensor atTensor = tr::Torch::tensor(tensor);
callback.value()(reqId, atTensor, tokens, runtime::TorchUtils::stream(*stream).unwrap());
callback.value()(reqId, atTensor, tokens, runtime::TorchUtils::stream(*stream).unwrap(), clientId);
};
}
@ -112,7 +112,7 @@ void LlmRequest::initBindings(py::module_& m)
.def("add_new_tokens", &LlmRequest::addNewTokens, py::arg("beam_tokens"))
.def("set_generated_tokens", &LlmRequest::setGeneratedTokens, py::arg("generated_beam_tokens"))
.def("pause", &LlmRequest::pause, py::arg("max_input_len"))
.def_property("max_sent_token_pos", &LlmRequest::getMaxSentTokenPos, &LlmRequest::setMaxSentTokenPos)
.def_property("max_sent_token_len", &LlmRequest::getMaxSentTokenLen, &LlmRequest::setMaxSentTokenLen)
.def_property_readonly("prompt_embedding_table", &LlmRequest::getPromptEmbeddingTable)
.def_property_readonly("prompt_vocab_size", &LlmRequest::getPromptVocabSize)
.def_property_readonly("lora_task_id", &LlmRequest::getLoraTaskId)

View File

@ -126,6 +126,7 @@ PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m)
py::enum_<tr::ModelConfig::ModelVariant>(m, "GptModelVariant")
.value("GPT", tr::ModelConfig::ModelVariant::kGpt)
.value("GLM", tr::ModelConfig::ModelVariant::kGlm)
.value("CHATGLM", tr::ModelConfig::ModelVariant::kChatGlm)
.value("MAMBA", tr::ModelConfig::ModelVariant::kMamba)
.value("RECURRENTGEMMA", tr::ModelConfig::ModelVariant::kRecurrentGemma);

View File

@ -44,6 +44,7 @@ namespace tensorrt_llm::pybind::executor
void InitBindings(pybind11::module_& m)
{
m.attr("__version__") = tle::version();
py::enum_<tle::ModelType>(m, "ModelType")
.value("DECODER_ONLY", tle::ModelType::kDECODER_ONLY)
.value("ENCODER_ONLY", tle::ModelType::kENCODER_ONLY)
@ -228,14 +229,16 @@ void InitBindings(pybind11::module_& m)
std::optional<SizeType32> const&, std::optional<SizeType32> const&,
std::optional<std::list<VecTokens>>, std::optional<std::list<VecTokens>>, std::optional<Tensor>,
std::optional<tle::ExternalDraftTokensConfig>, std::optional<tle::PromptTuningConfig>,
std::optional<tle::LoraConfig>, std::optional<std::string>, std::optional<VecTokens>>(),
std::optional<tle::LoraConfig>, std::optional<std::string>, std::optional<VecTokens>,
std::optional<IdType>, bool>(),
py::arg("input_token_ids"), py::arg("max_new_tokens"), py::arg("streaming") = false,
py::arg_v("sampling_config", tle::SamplingConfig(), "SamplingConfig()"),
py::arg_v("output_config", tle::OutputConfig(), "OutputConfig()"), py::arg("end_id") = py::none(),
py::arg("pad_id") = py::none(), py::arg("bad_words") = py::none(), py::arg("stop_words") = py::none(),
py::arg("embedding_bias") = py::none(), py::arg("external_draft_tokens_config") = py::none(),
py::arg("prompt_tuning_config") = py::none(), py::arg("lora_config") = py::none(),
py::arg("logits_post_processor_name") = py::none(), py::arg("encoder_input_token_ids") = py::none())
py::arg("logits_post_processor_name") = py::none(), py::arg("encoder_input_token_ids") = py::none(),
py::arg("client_id") = py::none(), py::arg("return_all_generated_tokens") = false)
.def_property_readonly("input_token_ids", &tle::Request::getInputTokenIds)
.def_property_readonly("max_new_tokens", &tle::Request::getMaxNewTokens)
.def_property("streaming", &tle::Request::getStreaming, &tle::Request::setStreaming)
@ -254,7 +257,10 @@ void InitBindings(pybind11::module_& m)
.def_property("logits_post_processor_name", &tle::Request::getLogitsPostProcessorName,
&tle::Request::setLogitsPostProcessorName)
.def_property(
"encoder_input_token_ids", &tle::Request::getEncoderInputTokenIds, &tle::Request::setEncoderInputTokenIds);
"encoder_input_token_ids", &tle::Request::getEncoderInputTokenIds, &tle::Request::setEncoderInputTokenIds)
.def_property("client_id", &tle::Request::getClientId, &tle::Request::setClientId)
.def_property("return_all_generated_tokens", &tle::Request::getReturnAllGeneratedTokens,
&tle::Request::setReturnAllGeneratedTokens);
request.attr("BATCHED_POST_PROCESSOR_NAME") = tle::Request::kBatchedPostProcessorName;
py::class_<tle::Result>(m, "Result")
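The `Request` binding above gains `client_id` and `return_all_generated_tokens` keyword arguments. A hypothetical usage sketch, assuming the executor bindings are importable as `tensorrt_llm.bindings.executor` in a TensorRT-LLM installation:

```python
# Hypothetical usage of the new kwargs; the import path is an assumption.
from tensorrt_llm.bindings import executor as trtllm_executor

request = trtllm_executor.Request(
    input_token_ids=[1, 2, 3, 4],
    max_new_tokens=16,
    streaming=True,
    client_id=42,                      # opaque id echoed back with the results
    return_all_generated_tokens=True,  # presumably the full sequence each time, not just deltas
)
```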

View File

@ -46,6 +46,7 @@ void ExplicitDraftTokensBuffers::Inputs::create(SizeType32 maxNumSequences, Tllm
temperatures = manager.gpu(ITensor::makeShape({maxNumSequences}), dtype);
positionIdsBase = manager.gpu(ITensor::makeShape({maxNumSequences}), nvinfer1::DataType::kINT32);
generationLengths = manager.gpu(ITensor::makeShape({maxNumSequences}), nvinfer1::DataType::kINT32);
generationLengthsHost = manager.pinned(ITensor::makeShape({maxNumSequences}), nvinfer1::DataType::kINT32);
randomDataSample = manager.gpu(ITensor::makeShape({maxNumSequences}), dtype);
randomDataValidation = manager.gpu(ITensor::makeShape({maxNumSequences, maxNumPaths, maxDraftPathLen}), dtype);
draftTokens = manager.gpu(ITensor::makeShape({maxNumSequences, maxNumPaths, maxPathLen}), TRTTokenIdType);
@ -270,6 +271,20 @@ void ExplicitDraftTokensBuffers::setFromInputs(SizeType32 numCtxSequences, SizeT
auto const& manager = runtime.getBufferManager();
auto const& stream = runtime.getStream();
auto const explicitDraftTokensModule = std::dynamic_pointer_cast<runtime::ExplicitDraftTokensModule const>(
modelConfig.getSpeculativeDecodingModulePtr());
auto const seqSlotsPtr = bufferCast<SizeType32>(seqSlots);
auto const generationLengthsPtr = bufferCast<SizeType32>(*draftBuffers.generationLengthsHost);
SizeType32 totalGenLengths = 0;
for (SizeType32 si = 0; si < numGenSequences; ++si)
{
auto const slot = seqSlotsPtr[numCtxSequences + si];
totalGenLengths += generationLengthsPtr[slot];
}
// Reshape position ids.
engineInputs.positionIds->reshape(ITensor::makeShape({contextPositionIds.getShape().d[0] + totalGenLengths}));
// Copy position ids -- hacky solution to avoid filling them for the context requests.
TensorPtr posIdsSlice = ITensor::slice(engineInputs.positionIds, 0, contextPositionIds.getShape().d[0]);
manager.copy(contextPositionIds, *posIdsSlice);
@ -279,9 +294,6 @@ void ExplicitDraftTokensBuffers::setFromInputs(SizeType32 numCtxSequences, SizeT
auto const numSequences = numCtxSequences + numGenSequences;
auto const vocabSizePadded = modelConfig.getVocabSizePadded(worldConfig.getSize());
auto const explicitDraftTokensModule = std::dynamic_pointer_cast<runtime::ExplicitDraftTokensModule const>(
modelConfig.getSpeculativeDecodingModulePtr());
auto const dtype = modelConfig.getDataType();
switch (dtype)
@ -327,7 +339,7 @@ void ExplicitDraftTokensBuffers::insertInputTensors(
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
// inputs
inputBuffers.insert_or_assign("explicit_inverted_temperature", engineInputs.temperatures);
inputBuffers.insert_or_assign("redrafter_inverted_temperature", engineInputs.temperatures);
inputBuffers.insert_or_assign("device_request_types", engineInputs.requestTypesDevice);
inputBuffers.insert_or_assign("spec_decoding_generation_lengths", engineInputs.generationLengths);

View File

@ -438,6 +438,7 @@ void prepareSpeculativeDecodingOutputs(DecodingOutput& output, std::shared_ptr<t
outputParams->randomDataValidation = tcc::toTllmTensor(*explicitDraftTokensBuffers->randomDataValidation);
outputParams->temperatures = tcc::toTllmTensor(*explicitDraftTokensBuffers->temperatures);
outputParams->generationLengths = tcc::toTllmTensor(*explicitDraftTokensBuffers->generationLengths);
outputParams->generationLengthsHost = tcc::toTllmTensor(*explicitDraftTokensBuffers->generationLengthsHost);
outputParams->maxGenLengthHost = tcc::toTllmTensor(*explicitDraftTokensBuffers->maxGenLengthHost);
}

View File

@ -340,10 +340,15 @@ GptJsonConfig parseJson(InputType&& input)
if (engineVersionNone)
{
if (name == std::string("chatglm_6b") || name == std::string("glm_10b"))
if (name == std::string("chatglm_6b"))
{
modelConfig.setModelVariant(ModelConfig::ModelVariant::kChatGlm);
// kChatGlm is only for ChatGLM-6B
}
if (name == std::string("glm_10b"))
{
modelConfig.setModelVariant(ModelConfig::ModelVariant::kGlm);
// kGlm is only for ChatGLM-6B and GLM-10B
// kGlm is only for GLM-10B
}
}
else
@ -352,10 +357,15 @@ GptJsonConfig parseJson(InputType&& input)
{
auto const& pretrainedConfig = json.at("pretrained_config");
auto const chatglmVersion = pretrainedConfig.at("chatglm_version").template get<std::string>();
if (chatglmVersion == "glm" || chatglmVersion == "chatglm")
if (chatglmVersion == "chatglm")
{
modelConfig.setModelVariant(ModelConfig::ModelVariant::kChatGlm);
// kChatGlm is only for ChatGLM-6B
}
if (chatglmVersion == "glm")
{
modelConfig.setModelVariant(ModelConfig::ModelVariant::kGlm);
// kGlm is only for ChatGLM-6B and GLM-10B
// kGlm is only for GLM-10B
}
}
}
@ -368,8 +378,8 @@ GptJsonConfig parseJson(InputType&& input)
auto const& pretrainedConfig = json.at("pretrained_config");
// TODO(rkobus): adjust param names
auto const maxNumPaths = parseJsonFieldOr(pretrainedConfig, "explicit_num_beams", 0);
auto const maxDraftPathLen = parseJsonFieldOr(pretrainedConfig, "explicit_draft_len_per_beam", 0);
auto const maxNumPaths = parseJsonFieldOr(pretrainedConfig, "redrafter_num_beams", 0);
auto const maxDraftPathLen = parseJsonFieldOr(pretrainedConfig, "redrafter_draft_len_per_beam", 0);
auto const maxDraftLen = maxNumPaths * maxDraftPathLen;
auto explicitDraftTokensModule

View File

@ -31,6 +31,7 @@ namespace tc = tensorrt_llm::common;
ITensor::UniquePtr ITensor::slice(SharedPtr tensor, std::size_t offset, std::size_t size)
{
TLLM_CHECK(tensor);
return std::make_unique<TensorView>(std::move(tensor), offset, size);
}

View File

@ -308,7 +308,7 @@ void TransformerBuffers::prepareContextStep(RuntimeBuffers* runtimeBuffers, Tens
}
positionIds = manager.copyFrom(positionIdsVec, inputShape, MemoryType::kGPU);
}
else if (modelVariant == ModelConfig::ModelVariant::kGlm)
else if (modelVariant == ModelConfig::ModelVariant::kChatGlm)
{
auto const positionIdsVec = getPositionIdsContextPhaseGlm(batchSize, maxInputLength, contextLengthsHostPtr,
modelConfig.useGptAttentionPlugin(), modelConfig.usePackedInput());
@ -578,7 +578,7 @@ void TransformerBuffers::prepareNextStep(RuntimeBuffers* runtimeBuffers, SizeTyp
manager.copy(*contextLengthsDevice, *positionIds);
kernels::invokeAdd(*positionIds, step, stream);
}
else if (modelVariant == ModelConfig::ModelVariant::kGlm)
else if (modelVariant == ModelConfig::ModelVariant::kChatGlm)
{
auto const positionIdsVec = getPositionIdsGenerationPhaseGlm(batchSize, beamWidth, step,
contextLengthsHostPtr, modelConfig.useGptAttentionPlugin(), modelConfig.usePackedInput());

View File

@ -35,8 +35,10 @@ find_library_create_target(nvonnxparser ${ONNX_PARSER_LIB_NAME} SHARED
include_directories(
${PROJECT_SOURCE_DIR}/tensorrt_llm/cutlass_extensions/include
${PROJECT_SOURCE_DIR}/include ${3RDPARTY_DIR}/cutlass/include
${3RDPARTY_DIR}/cutlass/tools/util/include)
${PROJECT_SOURCE_DIR}/include
${3RDPARTY_DIR}/cutlass/include
${3RDPARTY_DIR}/cutlass/tools/util/include
${PROJECT_SOURCE_DIR}/tests/batch_manager)
set(TOP_LEVEL_DIR "${PROJECT_SOURCE_DIR}/..")
@ -77,6 +79,7 @@ add_gtest(transposeKVKernelTest runtime/transposeKVKernelTest.cpp)
add_gtest(gptDecoderTest runtime/gptDecoderTest.cpp)
add_gtest(gptDecoderBatchTest runtime/gptDecoderBatchTest.cpp)
add_gtest(gptSessionTest runtime/gptSessionTest.cpp)
target_link_libraries(gptSessionTest PRIVATE modelSpecStatic)
add_gtest(allocatorTest common/allocatorTest.cpp)
add_gtest(memoryUtilsTest common/memoryUtilsTest.cu)
if(ENABLE_MULTI_DEVICE EQUAL 1)
@ -87,6 +90,8 @@ add_gtest(stringUtilsTest common/stringUtilsTest.cpp)
add_gtest(tllmExceptionTest common/tllmExceptionTest.cpp)
add_gtest(tensorTest common/tensorTest.cpp)
add_gtest(stlUtilsTest common/stlUtilsTest.cpp)
add_gtest(cudaProfilerUtilsTest common/cudaProfilerUtilsTest.cpp)
add_gtest(timestampUtilsTest common/timestampUtilsTest.cpp)
add_gtest(tllmRuntimeTest runtime/tllmRuntimeTest.cpp)
add_gtest(tllmBuffersTest runtime/tllmBuffersTest.cpp)
add_gtest(bufferManagerTest runtime/bufferManagerTest.cpp)

View File

@ -58,6 +58,7 @@ PYTHONPATH=examples/gptj:$PYTHONPATH python3 cpp/tests/resources/scripts/build_g
PYTHONPATH=examples/llama:$PYTHONPATH python3 cpp/tests/resources/scripts/build_llama_engines.py
PYTHONPATH=examples/chatglm:$PYTHONPATH python3 cpp/tests/resources/scripts/build_chatglm_engines.py
PYTHONPATH=examples/medusa:$PYTHONPATH python3 cpp/tests/resources/scripts/build_medusa_engines.py
PYTHONPATH=examples/redrafter:$PYTHONPATH python3 cpp/tests/resources/scripts/build_redrafter_engines.py --has_tllm_checkpoint
```
It is possible to build engines with tensor and pipeline parallelism for LLaMA using 4 GPUs.
@ -76,6 +77,7 @@ PYTHONPATH=examples:$PYTHONPATH python3 cpp/tests/resources/scripts/generate_exp
PYTHONPATH=examples:$PYTHONPATH python3 cpp/tests/resources/scripts/generate_expected_llama_output.py
PYTHONPATH=examples:$PYTHONPATH python3 cpp/tests/resources/scripts/generate_expected_chatglm_output.py
PYTHONPATH=examples:$PYTHONPATH python3 cpp/tests/resources/scripts/generate_expected_medusa_output.py
PYTHONPATH=examples:$PYTHONPATH python3 cpp/tests/resources/scripts/generate_expected_redrafter_output.py
```
#### Generate data with tensor and pipeline parallelism

View File

@ -0,0 +1,76 @@
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <gtest/gtest.h>
#include <stdlib.h>
#include "tensorrt_llm/common/cudaProfilerUtils.h"
#ifdef _WIN32
int setenv(char const* name, char const* value, int overwrite)
{
int errcode = 0;
if (!overwrite)
{
size_t envsize = 0;
errcode = getenv_s(&envsize, NULL, 0, name);
if (errcode || envsize)
return errcode;
}
return _putenv_s(name, value);
}
#endif
std::string kEnvVarName{"TLLM_PROFILE_START_STOP"};
std::string kLegacyEnvVarName{"TLLM_GPTM_PROFILE_START_STOP"};
struct TestCase
{
std::optional<std::string> legacyEnvVarVal;
std::string envVarVal;
std::pair<std::unordered_set<int32_t>, std::unordered_set<int32_t>> result;
};
TEST(CudaProfilerUtils, populateIterationIndexes)
{
std::vector<TestCase> testCases;
testCases.emplace_back(TestCase{std::nullopt, "", {{}, {}}});
testCases.emplace_back(TestCase{std::nullopt, "1", {{1}, {1}}});
testCases.emplace_back(TestCase{std::nullopt, "1,2,3", {{1, 2, 3}, {1, 2, 3}}});
testCases.emplace_back(TestCase{std::nullopt, "1-4,7-8", {{1, 7}, {4, 8}}});
testCases.emplace_back(TestCase{std::nullopt, "1,2,10-15", {{1, 2, 10}, {1, 2, 15}}});
testCases.emplace_back(TestCase{std::nullopt, "1,,10-15", {{1, 10}, {1, 15}}});
// Only legacy env var set
testCases.emplace_back(TestCase{"1-4,7-8", "", {{1, 7}, {4, 8}}});
// Both set, non-legacy has priority
testCases.emplace_back(TestCase{"1-4,7-8", "2-10,88-99", {{2, 88}, {10, 99}}});
for (auto const& testCase : testCases)
{
auto ret = setenv(kEnvVarName.c_str(), testCase.envVarVal.c_str(), 1); // does overwrite
EXPECT_EQ(ret, 0);
ret = setenv(
kLegacyEnvVarName.c_str(), testCase.legacyEnvVarVal.value_or(std::string()).c_str(), 1); // does overwrite
auto const [profileIterIdxs, stopIterIdxs]
= tensorrt_llm::common::populateIterationIndexes(kEnvVarName, kLegacyEnvVarName);
EXPECT_EQ(profileIterIdxs, testCase.result.first)
<< testCase.envVarVal << " " << testCase.legacyEnvVarVal.value_or("");
EXPECT_EQ(stopIterIdxs, testCase.result.second)
<< testCase.envVarVal << " " << testCase.legacyEnvVarVal.value_or("");
}
}
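The new test above pins down the `TLLM_PROFILE_START_STOP` format: a comma-separated list where `n` starts and stops profiling at iteration `n` and `a-b` starts at `a` and stops at `b`, with the legacy variable as a fallback. A pure-Python sketch that reproduces the expected outputs (it mirrors the test cases, not the C++ implementation):

```python
# Mirrors the expected outputs of the test above, not the C++ implementation.
def parse_profile_spec(spec: str):
    start_iters, stop_iters = set(), set()
    for part in spec.split(","):
        if not part:
            continue  # tolerate empty entries such as "1,,10-15"
        if "-" in part:
            start, stop = part.split("-", 1)
            start_iters.add(int(start))
            stop_iters.add(int(stop))
        else:
            start_iters.add(int(part))
            stop_iters.add(int(part))
    return start_iters, stop_iters

assert parse_profile_spec("1-4,7-8") == ({1, 7}, {4, 8})
assert parse_profile_spec("1,,10-15") == ({1, 10}, {1, 15})
```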

View File

@ -77,3 +77,25 @@ TEST(StringUtil, FormatFixedDecimals)
EXPECT_EQ(prefix, formatFixed(num));
}
}
TEST(StringUtil, str2set)
{
{
char delimiter{','};
std::vector<std::string> inputs{
{"apple,car,dog"}, {",apple,car,dog"}, {"apple,car,dog"}, {"apple,car,dog,,"}, {"apple,,,car,dog,"}};
for (auto const& input : inputs)
{
auto out = str2set(input, delimiter);
EXPECT_EQ(out.size(), 3);
EXPECT_EQ(out.count("apple"), 1);
EXPECT_EQ(out.count("car"), 1);
EXPECT_EQ(out.count("dog"), 1);
EXPECT_EQ(out.count("cat"), 0);
}
}
}
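The test above fixes the contract of `str2set`: split on the delimiter and discard empty tokens. For reference, the equivalent behavior in a line of Python:

```python
# Python equivalent of the str2set behavior pinned down by the test above.
def str2set(s: str, delimiter: str) -> set[str]:
    return {token for token in s.split(delimiter) if token}

assert str2set(",apple,,car,dog,", ",") == {"apple", "car", "dog"}
```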

View File

@ -0,0 +1,46 @@
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <gtest/gtest.h>
#include "tensorrt_llm/common/timestampUtils.h"
#include <chrono>
#include <sstream>
#include <thread>
using namespace tensorrt_llm::common;
TEST(TimestampUtils, getCurrentTimestamp)
{
int32_t sleepMs = 100;
int32_t sleepUs = sleepMs * 1000;
int32_t tolUs = 5000;
auto timestamp = getCurrentTimestamp();
std::this_thread::sleep_for(std::chrono::milliseconds(sleepMs));
auto timestamp2 = getCurrentTimestamp();
auto microseconds = std::stoi(timestamp.erase(0, timestamp.find('.') + 1));
std::cout << microseconds << std::endl;
auto microseconds2 = std::stoi(timestamp2.erase(0, timestamp2.find('.') + 1));
int32_t delta = (microseconds2 - microseconds);
if (delta < 0)
{
delta += 1000000;
}
EXPECT_NEAR(delta, sleepUs, tolUs) << "delta: " << delta << " expected " << sleepUs << std::endl;
}
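The test above compares only the fractional (microsecond) part of the two timestamps, so it has to correct for the case where the seconds roll over between samples. The same arithmetic in Python, with made-up values:

```python
# Wraparound correction used by the test above, with made-up fractional parts.
def microsecond_delta(us_before: int, us_after: int) -> int:
    delta = us_after - us_before
    if delta < 0:
        delta += 1_000_000  # the second rolled over between the two samples
    return delta

assert microsecond_delta(950_000, 50_000) == 100_000   # crossed a second boundary
assert microsecond_delta(100_000, 200_000) == 100_000
```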

View File

@ -710,6 +710,9 @@ void ExplicitDraftTokensLayerTest<T>::allocateBuffers()
mOutputGenerationLengths
= BufferManager::pinned(ITensor::makeShape({mSamplingParams.getMaxBatchSize()}), nvinfer1::DataType::kINT32);
mOutputGenerationLengthsHost
= BufferManager::pinned(ITensor::makeShape({mSamplingParams.getMaxBatchSize()}), nvinfer1::DataType::kINT32);
mMaxGenLengthHost = BufferManager::pinned(ITensor::makeShape({1}), nvinfer1::DataType::kINT32);
// inputs
@ -1012,6 +1015,8 @@ std::shared_ptr<ExplicitDraftTokensOutputs> ExplicitDraftTokensLayerTest<T>::cre
outputParams->generationLengths = tcc::toTllmTensor(*mOutputGenerationLengths);
outputParams->generationLengthsHost = tcc::toTllmTensor(*mOutputGenerationLengthsHost);
outputParams->maxGenLengthHost = tcc::toTllmTensor(*mMaxGenLengthHost);
return outputParams;

View File

@ -273,6 +273,7 @@ private:
TensorPtr mOutputDraftProbs;
TensorPtr mOutputTemperatures;
TensorPtr mOutputGenerationLengths;
TensorPtr mOutputGenerationLengthsHost;
TensorPtr mMaxGenLengthHost;
// inputs

Binary file not shown.

Binary file not shown.

View File

@ -21,7 +21,12 @@ import sys
import typing
from pathlib import Path
from build_engines_utils import run_command, wincopy
from build_engines_utils import init_model_spec_module, run_command, wincopy
init_model_spec_module()
import model_spec
import tensorrt_llm.bindings as _tb
resources_dir = Path(__file__).parent.resolve().parent
model_dir = resources_dir / "models"
@ -42,7 +47,7 @@ def convert_ckpt(model_dir: str, output_dir: str, world_size: int):
def build_engine(ckpt_dir: str,
engine_dir: str,
is_ifb: bool = False,
is_chatglm_6b: bool = False):
is_chatglm_6b_or_glm_10b: bool = False):
build_cmd = [
"trtllm-build",
f"--checkpoint_dir={ckpt_dir}",
@ -70,8 +75,8 @@ def build_engine(ckpt_dir: str,
"--paged_kv_cache=disable",
])
if is_chatglm_6b:
print("Disable Context FMHA for ChatGLM-6B")
if is_chatglm_6b_or_glm_10b:
print("Disable Context FMHA for ChatGLM-6B and GLM-10B")
build_cmd.extend(
["--context_fmha=disable", "--context_fmha_fp32_acc=disable"])
@ -81,7 +86,8 @@ def build_engine(ckpt_dir: str,
def build_engines(model_cache: typing.Optional[str] = None,
world_size: int = 1):
for model_name in ["chatglm-6b", "chatglm2-6b", "chatglm3-6b"]:
for model_name in ["chatglm-6b", "chatglm2-6b", "chatglm3-6b", "glm-10b"]:
is_chatglm_6b_or_glm_10b = model_name in ["chatglm-6b", "glm-10b"]
if model_cache and (Path(model_cache) / model_name).is_dir():
model_cache_dir = Path(model_cache) / model_name
if bCopyModel or model_name == "chatglm-6b":
@ -101,15 +107,16 @@ def build_engines(model_cache: typing.Optional[str] = None,
hf_dir = Path(model_cache)
else:
print("Clone model from HF")
hf_dir = model_dir / model_name
run_command(
[
"git", "clone",
f"https://huggingface.co/THUDM/{model_name}", model_name
],
cwd=model_dir,
)
if not hf_dir.is_dir():
print("Clone model from HF")
run_command(
[
"git", "clone",
f"https://huggingface.co/THUDM/{model_name}", model_name
],
cwd=model_dir,
)
# Build engines
print(f"Building {model_name}")
@ -125,14 +132,25 @@ def build_engines(model_cache: typing.Optional[str] = None,
convert_ckpt(hf_dir, ckpt_dir, world_size)
for engine_kind in ["fp16-plugin", "fp16-plugin-packed-paged"]:
engine_dir = Path(
model_dir
) / "rt_engine" / model_name / engine_kind / "tp1-pp1-gpu"
engine_dir.mkdir(parents=True, exist_ok=True)
model_spec_obj = model_spec.ModelSpec('input_tokens.npy',
_tb.DataType.HALF)
model_spec_obj.set_kv_cache_type(model_spec.KVCacheType.CONTINUOUS)
model_spec_obj.use_gpt_plugin()
engine_dir = Path(
model_dir
) / "rt_engine" / model_name / model_spec_obj.get_model_path(
) / "tp1-pp1-gpu"
engine_dir.mkdir(parents=True, exist_ok=True)
build_engine(ckpt_dir, engine_dir, False, is_chatglm_6b_or_glm_10b)
build_engine(ckpt_dir, engine_dir, "paged" in engine_kind,
model_name == "chatglm-6b")
model_spec_obj.use_packed_input()
model_spec_obj.set_kv_cache_type(model_spec.KVCacheType.PAGED)
engine_dir = Path(
model_dir
) / "rt_engine" / model_name / model_spec_obj.get_model_path(
) / "tp1-pp1-gpu"
engine_dir.mkdir(parents=True, exist_ok=True)
build_engine(ckpt_dir, engine_dir, True, is_chatglm_6b_or_glm_10b)
print("Done")

View File

@ -62,3 +62,37 @@ def wincopy(source: str, dest: str, isdir: bool, cwd=None) -> None:
raise _sp.CalledProcessError(returncode=result.returncode,
cmd=copy_cmd,
output=result.stderr)
# Helper function to locate model_spec module.
def init_model_spec_module():
import os
# Rely on unique built model_spec to locate the module.
cpp_root_dir = _pl.Path(__file__).parent.resolve().parent.parent.parent
found_locations = []
def find_model_spec_module(directory, found_locations):
for root, d, files in os.walk(directory):
for item in files:
if item == 'model_spec.so':
found_locations.append(root)
for d in os.listdir(cpp_root_dir):
if d.startswith("build"):
find_model_spec_module(os.path.join(cpp_root_dir, d),
found_locations)
if len(found_locations) == 0:
# In CI package, model_spec module is copied to its source directory.
find_model_spec_module(
os.path.join(cpp_root_dir, 'tests', 'batch_manager'),
found_locations)
assert len(
found_locations
) == 1, f'Can\'t uniquely locate model_spec module, found {found_locations}'
import sys
sys.path.insert(0, found_locations[0])
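The helper above puts the test-built `model_spec` extension on `sys.path`; the engine-build scripts then describe each engine with a `ModelSpec` and use `get_model_path()` instead of hand-written directory names such as `fp16-plugin-packed-paged`. A hedged sketch of that pattern, assuming the extension and the `tensorrt_llm` bindings are importable:

```python
# Sketch only; requires the test-built model_spec extension and tensorrt_llm.
from build_engines_utils import init_model_spec_module

init_model_spec_module()
import model_spec
import tensorrt_llm.bindings as _tb

spec = model_spec.ModelSpec('input_tokens.npy', _tb.DataType.HALF)
spec.use_gpt_plugin()
spec.use_packed_input()
spec.set_kv_cache_type(model_spec.KVCacheType.PAGED)

# Replaces hard-coded names when composing rt_engine/<model>/<path>/tp1-pp1-gpu.
print(spec.get_model_path())
```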

View File

@ -21,7 +21,12 @@ import sys
from pathlib import Path
from typing import Optional
from build_engines_utils import run_command, wincopy
from build_engines_utils import init_model_spec_module, run_command, wincopy
init_model_spec_module()
import model_spec
import tensorrt_llm.bindings as _tb
def convert_ckpt(model_dir: str,
@ -44,6 +49,9 @@ def build_engine(
max_input_len: int = 256,
max_seq_len: int = 384,
):
if os.path.exists(engine_dir):
    assert False, f"engine_dir already exists: {engine_dir}"

build_cmd = [
"trtllm-build",
'--log_level=error',
@ -139,11 +147,14 @@ def build_engines(model_cache: Optional[str] = None, world_size: int = 1):
world_size=tp_size,
dtype='float32')
input_file = 'input_tokens.npy'
print("\nBuilding fp32 engines")
model_spec_obj = model_spec.ModelSpec(input_file, _tb.DataType.FLOAT)
build_engine(str(fp32_ckpt_dir),
str(engine_dir / 'fp32-default' / tp_pp_dir))
str(engine_dir / model_spec_obj.get_model_path() / tp_pp_dir))
model_spec_obj.use_gpt_plugin()
build_engine(str(fp32_ckpt_dir),
str(engine_dir / 'fp32-plugin' / tp_pp_dir),
str(engine_dir / model_spec_obj.get_model_path() / tp_pp_dir),
'--gpt_attention_plugin=float32', '--context_fmha=enable',
'--context_fmha_fp32_acc=enable')
@ -155,13 +166,16 @@ def build_engines(model_cache: Optional[str] = None, world_size: int = 1):
dtype='float16')
print("\nBuilding fp16 engines")
model_spec_obj = model_spec.ModelSpec(input_file, _tb.DataType.HALF)
build_engine(str(fp16_ckpt_dir),
str(engine_dir / 'fp16-default' / tp_pp_dir))
str(engine_dir / model_spec_obj.get_model_path() / tp_pp_dir))
model_spec_obj.use_gpt_plugin()
build_engine(str(fp16_ckpt_dir),
str(engine_dir / 'fp16-plugin' / tp_pp_dir),
str(engine_dir / model_spec_obj.get_model_path() / tp_pp_dir),
'--gpt_attention_plugin=float16')
model_spec_obj.use_packed_input()
build_engine(str(fp16_ckpt_dir),
str(engine_dir / 'fp16-plugin-packed' / tp_pp_dir),
str(engine_dir / model_spec_obj.get_model_path() / tp_pp_dir),
'--gpt_attention_plugin=float16',
'--remove_input_padding=enable')
@ -175,30 +189,57 @@ def build_engines(model_cache: Optional[str] = None, world_size: int = 1):
'--max_num_tokens=10000',
'--use_paged_context_fmha=enable',
]
model_spec_obj = model_spec.ModelSpec(input_file, _tb.DataType.HALF)
model_spec_obj.use_gpt_plugin()
model_spec_obj.set_kv_cache_type(model_spec.KVCacheType.PAGED)
model_spec_obj.use_packed_input()
build_engine(str(fp16_ckpt_dir),
str(engine_dir / 'fp16-plugin-packed-paged' / tp_pp_dir),
str(engine_dir / model_spec_obj.get_model_path() / tp_pp_dir),
*ifb_args)
model_spec_current = model_spec_obj.__copy__()
max_draft_tokens = 5
model_spec_current.use_draft_tokens_external_decoding()
model_spec_current.set_draft_tokens(max_draft_tokens)
build_engine(
str(fp16_ckpt_dir),
str(engine_dir / 'fp16-plugin-packed-paged-draft-tokens' / tp_pp_dir),
'--max_draft_len=5',
str(engine_dir / model_spec_current.get_model_path() / tp_pp_dir),
f'--max_draft_len={max_draft_tokens}',
'--speculative_decoding_mode=draft_tokens_external', *ifb_args)
model_spec_current = model_spec_obj.__copy__()
model_spec_current.use_multiple_profiles()
build_engine(
str(fp16_ckpt_dir),
str(engine_dir / 'fp16-plugin-packed-paged-nprofiles' / tp_pp_dir),
str(engine_dir / model_spec_current.get_model_path() / tp_pp_dir),
'--multiple_profiles=enable', *ifb_args)
model_spec_current = model_spec_obj.__copy__()
max_input_len = 128
model_spec_current.set_max_input_length(max_input_len)
build_engine(str(fp16_ckpt_dir),
str(engine_dir / 'fp16-plugin-packed-paged-in128' / tp_pp_dir),
str(engine_dir / model_spec_current.get_model_path() /
tp_pp_dir),
*ifb_args,
max_input_len=128)
max_input_len=max_input_len)
# Build the target model with return accepted token logits
# Build with '--max_draft_len', '--speculative_decoding_mode' and '--gather_generation_logits'
model_spec_current = model_spec_obj.__copy__()
max_draft_len = 5
model_spec_current.use_draft_tokens_external_decoding()
model_spec_current.set_draft_tokens(max_draft_len)
model_spec_current.gather_logits()
model_spec_current.return_accepted_tokens_logits()
build_engine(
str(fp16_ckpt_dir),
str(engine_dir /
'fp16-plugin-packed-paged-return-accepted-tokens-logits' /
tp_pp_dir), '--max_draft_len=5',
str(engine_dir / model_spec_current.get_model_path() / tp_pp_dir),
f'--max_draft_len={max_draft_len}',
'--speculative_decoding_mode=draft_tokens_external',
'--gather_generation_logits', *ifb_args)
@ -206,22 +247,31 @@ def build_engines(model_cache: Optional[str] = None, world_size: int = 1):
# to extract logits from python runtime and uses context FMHA for generation to match draft model executions,
# which uses context FMHA for draft tokens prediction.
# Currently the gather_all_token_logits is not supported with target model of speculative decoding
model_spec_current = model_spec_obj.__copy__()
model_spec_current.gather_logits()
build_engine(
str(fp16_ckpt_dir),
str(engine_dir / 'fp16-plugin-packed-paged-gather' / tp_pp_dir),
str(engine_dir / model_spec_current.get_model_path() / tp_pp_dir),
'--gather_all_token_logits', *ifb_args)
model_spec_current = model_spec_obj.__copy__()
model_spec_current.use_look_ahead_decoding()
max_draft_len = 64
model_spec_current.set_draft_tokens(max_draft_len)
build_engine(
str(fp16_ckpt_dir),
str(engine_dir / 'fp16-plugin-packed-paged-la-decoding' / tp_pp_dir),
'--max_draft_len=64', '--speculative_decoding_mode=lookahead_decoding',
*ifb_args)
str(engine_dir / model_spec_current.get_model_path() / tp_pp_dir),
f'--max_draft_len={max_draft_len}',
'--speculative_decoding_mode=lookahead_decoding', *ifb_args)
# build engine with lora enabled
build_engine(str(fp16_ckpt_dir),
str(engine_dir / "fp16-plugin-packed-paged-lora" / tp_pp_dir),
"--lora_target_modules=attn_qkv", '--lora_plugin=float16',
*ifb_args)
model_spec_current = model_spec_obj.__copy__()
model_spec_current.use_lora_plugin()
build_engine(
str(fp16_ckpt_dir),
str(engine_dir / model_spec_current.get_model_path() / tp_pp_dir),
"--lora_target_modules=attn_qkv", '--lora_plugin=float16', *ifb_args)
if model_cache:
llm_datasets_root = Path(model_cache) / "datasets"
@ -238,9 +288,16 @@ def build_engines(model_cache: Optional[str] = None, world_size: int = 1):
dtype='float16')
print("\nBuilding fp16 SQ engines")
build_engine(str(fp16_sq_ckpt_dir),
str(engine_dir / 'fp16-plugin-packed-paged-sq' / tp_pp_dir),
*ifb_args)
model_spec_current = model_spec.ModelSpec(input_file, _tb.DataType.HALF)
model_spec_current.use_gpt_plugin()
model_spec_current.use_packed_input()
model_spec_current.set_kv_cache_type(model_spec.KVCacheType.PAGED)
model_spec_current.set_quant_method(model_spec.QuantMethod.SMOOTH_QUANT)
build_engine(
str(fp16_sq_ckpt_dir),
str(engine_dir / model_spec_current.get_model_path() / tp_pp_dir),
*ifb_args)
if has_safetensor:
Path(str(safetensor_file) + ".bak").rename(safetensor_file)

View File

@ -21,7 +21,12 @@ import platform as _pf
import sys as _sys
import typing as _tp
from build_engines_utils import run_command, wincopy
from build_engines_utils import init_model_spec_module, run_command, wincopy
init_model_spec_module()
import model_spec
import tensorrt_llm.bindings as _tb
def get_ckpt_without_quatization(model_dir, output_dir):
@ -114,6 +119,7 @@ def build_engines(model_cache: _tp.Optional[str] = None, only_fp8=False):
tp_size = 1
pp_size = 1
tp_pp_dir = f"tp{tp_size}-pp{pp_size}-gpu"
input_file = 'input_tokens.npy'
if only_fp8:
# with ifb, new plugin
@ -125,7 +131,12 @@ def build_engines(model_cache: _tp.Optional[str] = None, only_fp8=False):
# str(_pl.Path(model_cache) / 'fp8-quantized-modelopt' / 'gptj_tp1_rank0.npz')
fp8_ckpt_path = engine_dir / 'fp8' / tp_pp_dir
get_ckpt_with_modelopt_quant(hf_dir, fp8_ckpt_path, model_cache)
build_engine(fp8_ckpt_path, engine_dir / 'fp8-plugin' / tp_pp_dir,
model_spec_obj = model_spec.ModelSpec(input_file, _tb.DataType.FP8)
model_spec_obj.use_gpt_plugin()
model_spec_obj.set_kv_cache_type(model_spec.KVCacheType.PAGED)
model_spec_obj.use_packed_input()
build_engine(fp8_ckpt_path,
engine_dir / model_spec_obj.get_model_path() / tp_pp_dir,
'--gpt_attention_plugin=float16',
'--paged_kv_cache=enable', '--remove_input_padding=enable',
"--context_fmha=disable")
@ -133,21 +144,28 @@ def build_engines(model_cache: _tp.Optional[str] = None, only_fp8=False):
fp16_ckpt_path = engine_dir / 'fp16' / tp_pp_dir
get_ckpt_without_quatization(hf_dir, fp16_ckpt_path)
print("\nBuilding fp16-plugin engine")
build_engine(fp16_ckpt_path, engine_dir / 'fp16-plugin' / tp_pp_dir,
model_spec_obj = model_spec.ModelSpec(input_file, _tb.DataType.HALF)
model_spec_obj.use_gpt_plugin()
model_spec_obj.set_kv_cache_type(model_spec.KVCacheType.CONTINUOUS)
build_engine(fp16_ckpt_path,
engine_dir / model_spec_obj.get_model_path() / tp_pp_dir,
'--gpt_attention_plugin=float16',
'--paged_kv_cache=disable',
'--remove_input_padding=disable', "--context_fmha=disable")
print("\nBuilding fp16-plugin-packed engine")
model_spec_obj.use_packed_input()
build_engine(fp16_ckpt_path,
engine_dir / 'fp16-plugin-packed' / tp_pp_dir,
engine_dir / model_spec_obj.get_model_path() / tp_pp_dir,
'--gpt_attention_plugin=float16',
'--paged_kv_cache=disable',
'--remove_input_padding=enable', "--context_fmha=disable")
print("\nBuilding fp16-plugin-packed-paged engine")
model_spec_obj.set_kv_cache_type(model_spec.KVCacheType.PAGED)
build_engine(fp16_ckpt_path,
engine_dir / 'fp16-plugin-packed-paged' / tp_pp_dir,
engine_dir / model_spec_obj.get_model_path() / tp_pp_dir,
'--gpt_attention_plugin=float16',
'--paged_kv_cache=enable', '--remove_input_padding=enable',
"--context_fmha=disable")

View File

@ -19,7 +19,12 @@ import pathlib as _pl
import platform as _pf
import sys as _sys
from build_engines_utils import run_command, wincopy
from build_engines_utils import init_model_spec_module, run_command, wincopy
init_model_spec_module()
import model_spec
import tensorrt_llm.bindings as _tb
def build_engine(weight_dir: _pl.Path, engine_dir: _pl.Path, *args):
@ -61,7 +66,7 @@ def build_engines(model_cache: str, only_multi_gpu: bool):
if model_cache:
print("Copy model from model_cache")
model_cache_dir = _pl.Path(model_cache) / 'llama-models' / model_name
assert (model_cache_dir.is_dir())
assert (model_cache_dir.is_dir()), model_cache_dir
if _pf.system() == "Windows":
wincopy(source=str(model_cache_dir),
@ -77,14 +82,22 @@ def build_engines(model_cache: str, only_multi_gpu: bool):
engine_dir = models_dir / 'rt_engine' / model_name
model_spec_obj = model_spec.ModelSpec('input_tokens.npy', _tb.DataType.HALF)
model_spec_obj.use_gpt_plugin()
model_spec_obj.set_kv_cache_type(model_spec.KVCacheType.PAGED)
model_spec_obj.use_packed_input()
tp_pp_sizes = [(1, 1)]
if only_multi_gpu:
tp_pp_sizes = [(1, 4), (4, 1), (2, 2)]
for tp_size, pp_size in tp_pp_sizes:
tp_pp_dir = f"tp{tp_size}-pp{pp_size}-gpu"
print(f"\nBuilding fp16 tp{tp_size} pp{pp_size} engine")
model_spec_obj.use_tensor_parallelism(tp_size)
model_spec_obj.use_pipeline_parallelism(pp_size)
build_engine(hf_dir,
engine_dir / f'fp16-plugin-packed-paged/{tp_pp_dir}',
engine_dir / model_spec_obj.get_model_path() / tp_pp_dir,
f'--tp_size={tp_size}', f'--pp_size={pp_size}')
print("Done.")

View File

@ -21,7 +21,12 @@ import platform as _pf
import sys as _sys
import typing as _tp
from build_engines_utils import run_command, wincopy
from build_engines_utils import init_model_spec_module, run_command, wincopy
init_model_spec_module()
import model_spec
import tensorrt_llm.bindings as _tb
def build_engine(weight_dir: _pl.Path, ckpt_dir: _pl.Path, engine_dir: _pl.Path,
@ -106,23 +111,30 @@ def build_engines(model_cache: _tp.Optional[str] = None):
ckpt_dir = models_dir / 'rt_ckpt' / model_name
engine_dir = models_dir / 'rt_engine' / model_name
model_spec_obj = model_spec.ModelSpec('input_tokens.npy', _tb.DataType.HALF)
model_spec_obj.set_kv_cache_type(model_spec.KVCacheType.CONTINUOUS)
model_spec_obj.use_tensor_parallelism(tp_size)
model_spec_obj.use_pipeline_parallelism(pp_size)
print("\nBuilding fp16 engine")
build_engine(hf_dir, ckpt_dir / 'fp16-default' / tp_pp_dir,
engine_dir / 'fp16-default' / tp_pp_dir,
build_engine(hf_dir, ckpt_dir / model_spec_obj.get_model_path() / tp_pp_dir,
engine_dir / model_spec_obj.get_model_path() / tp_pp_dir,
'--remove_input_padding=disable', '--paged_state=disable',
'--mamba_conv1d_plugin=disable')
print("\nBuilding fp16-plugin engine")
build_engine(hf_dir, ckpt_dir / 'fp16-plugin' / tp_pp_dir,
engine_dir / 'fp16-plugin' / tp_pp_dir,
model_spec_obj.use_mamba_plugin()
build_engine(hf_dir, ckpt_dir / model_spec_obj.get_model_path() / tp_pp_dir,
engine_dir / model_spec_obj.get_model_path() / tp_pp_dir,
'--remove_input_padding=disable', '--paged_state=disable')
print("\nBuilding fp16-plugin-packed engine")
build_engine(hf_dir, ckpt_dir / 'fp16-plugin-packed' / tp_pp_dir,
engine_dir / 'fp16-plugin-packed' / tp_pp_dir,
model_spec_obj.use_packed_input()
build_engine(hf_dir, ckpt_dir / model_spec_obj.get_model_path() / tp_pp_dir,
engine_dir / model_spec_obj.get_model_path() / tp_pp_dir,
'--remove_input_padding=enable', '--paged_state=disable')
print("\nBuilding fp16-plugin-packed-paged engine")
build_engine(hf_dir, ckpt_dir / 'fp16-plugin-packed-paged' / tp_pp_dir,
engine_dir / 'fp16-plugin-packed-paged' / tp_pp_dir,
model_spec_obj.set_kv_cache_type(model_spec.KVCacheType.PAGED)
build_engine(hf_dir, ckpt_dir / model_spec_obj.get_model_path() / tp_pp_dir,
engine_dir / model_spec_obj.get_model_path() / tp_pp_dir,
'--remove_input_padding=enable', '--paged_state=enable')
print("Done.")

View File

@ -19,7 +19,12 @@ import pathlib as _pl
import platform as _pf
import sys as _sys
from build_engines_utils import run_command, wincopy
from build_engines_utils import init_model_spec_module, run_command, wincopy
init_model_spec_module()
import model_spec
import tensorrt_llm.bindings as _tb
def build_engine(weight_dir: _pl.Path, medusa_dir: _pl.Path,
@ -60,8 +65,8 @@ def build_engines(model_cache: str):
print("Copy model from model_cache")
model_cache_dir = _pl.Path(model_cache) / model_name
medusa_cache_dir = _pl.Path(model_cache) / medusa_name
assert model_cache_dir.is_dir()
assert medusa_cache_dir.is_dir()
assert model_cache_dir.is_dir(), model_cache_dir
assert medusa_cache_dir.is_dir(), medusa_cache_dir
if _pf.system() == "Windows":
wincopy(source=str(model_cache_dir),
@ -85,9 +90,15 @@ def build_engines(model_cache: str):
engine_dir = models_dir / 'rt_engine' / model_name
model_spec_obj = model_spec.ModelSpec('input_tokens.npy', _tb.DataType.HALF)
model_spec_obj.use_gpt_plugin()
model_spec_obj.set_kv_cache_type(model_spec.KVCacheType.PAGED)
model_spec_obj.use_packed_input()
model_spec_obj.use_medusa()
print(f"\nBuilding fp16 engine")
build_engine(model_dir, medusa_dir,
engine_dir / 'fp16-plugin-packed-paged/tp1-pp1-gpu')
engine_dir / model_spec_obj.get_model_path() / 'tp1-pp1-gpu')
print("Done.")

View File

@ -21,7 +21,12 @@ import platform as _pf
import sys as _sys
import typing as _tp
from build_engines_utils import run_command, wincopy
from build_engines_utils import init_model_spec_module, run_command, wincopy
init_model_spec_module()
import model_spec
import tensorrt_llm.bindings as _tb
def build_engine(weight_dir: _pl.Path, ckpt_dir: _pl.Path, engine_dir: _pl.Path,
@ -105,10 +110,15 @@ def build_engines(model_cache: _tp.Optional[str] = None):
run_command([python_exe, "-m", "pip", "install", "transformers>=4.40.0"],
env=_os.environ,
timeout=300)
input_file = 'input_tokens.npy'
model_spec_obj = model_spec.ModelSpec(input_file, _tb.DataType.HALF)
model_spec_obj.use_gpt_plugin()
model_spec_obj.use_packed_input()
model_spec_obj.set_kv_cache_type(model_spec.KVCacheType.PAGED)
print("\nBuilding fp16-plugin-packed-paged engine")
build_engine(hf_dir, ckpt_dir / 'fp16-plugin-packed-paged' / tp_pp_dir,
engine_dir / 'fp16-plugin-packed-paged' / tp_pp_dir,
build_engine(hf_dir, ckpt_dir / model_spec_obj.get_model_path() / tp_pp_dir,
engine_dir / model_spec_obj.get_model_path() / tp_pp_dir,
'--remove_input_padding=enable', '--paged_state=enable')
# Restore transformers version

View File

@ -0,0 +1,109 @@
#!/usr/bin/env python3
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse as _arg
import pathlib as _pl
import platform as _pf
from build_engines_utils import run_command, wincopy
def build_engine(weight_dir: _pl.Path, engine_dir: _pl.Path,
has_tllm_checkpoint: bool):
if not has_tllm_checkpoint:
raise RuntimeError(
'Convert checkpoint is not supported for ReDrafter. '
'Provide a path that contains a checkpoint in the tllm_ckpt folder and set --has_tllm_checkpoint flag'
)
else:
checkpoint_dir = weight_dir / 'tllm_ckpt'
build_args = ["trtllm-build"] + (
['--checkpoint_dir', str(checkpoint_dir)] if engine_dir else []) + [
'--output_dir',
str(engine_dir),
'--gemm_plugin=float16',
'--max_batch_size=8',
'--max_input_len=64',
'--max_seq_len=1024',
'--log_level=error',
'--paged_kv_cache=enable',
'--remove_input_padding=enable',
'--speculative_decoding_mode=explicit_draft_tokens',
]
run_command(build_args)
def build_engines(model_cache: str, has_tllm_checkpoint: bool):
resources_dir = _pl.Path(__file__).parent.resolve().parent
models_dir = resources_dir / 'models'
model_name = 'vicuna_redrafter'
if has_tllm_checkpoint:
base_model_name = 'vicuna-7b-v1.3'
# FIXME(nkorobov): rename folder in the cache
# model_name = 'redrafter-vicuna-7b-v1.3'
if model_cache:
print("Copy model from model_cache")
model_cache_dir = _pl.Path(model_cache) / model_name
assert model_cache_dir.is_dir()
if has_tllm_checkpoint:
base_model_cache_dir = _pl.Path(model_cache) / base_model_name
assert base_model_cache_dir.is_dir()
if _pf.system() == "Windows":
wincopy(source=str(model_cache_dir),
dest=model_name,
isdir=True,
cwd=models_dir)
if has_tllm_checkpoint:
wincopy(source=str(base_model_cache_dir),
dest=base_model_name,
isdir=True,
cwd=models_dir)
else:
run_command(
["rsync", "-av", str(model_cache_dir), "."], cwd=models_dir)
if has_tllm_checkpoint:
run_command(["rsync", "-av",
str(base_model_cache_dir), "."],
cwd=models_dir)
model_dir = models_dir / model_name
assert model_dir.is_dir()
engine_dir = models_dir / 'rt_engine' / model_name
print(f"\nBuilding fp16 engine")
build_engine(model_dir, engine_dir / 'fp16-plugin-packed-paged/tp1-pp1-gpu',
has_tllm_checkpoint)
print("Done.")
if __name__ == "__main__":
parser = _arg.ArgumentParser()
parser.add_argument("--model_cache",
type=str,
help="Directory where models are stored")
parser.add_argument(
"--has_tllm_checkpoint",
action='store_true',
help="True if the provided path contains the trt-llm checkpoint.")
build_engines(**vars(parser.parse_args()))

View File

@ -18,6 +18,14 @@ from pathlib import Path
import numpy as np
import run
from build_engines_utils import init_model_spec_module
init_model_spec_module()
import os
import model_spec
import tensorrt_llm.bindings as _tb
resources_dir = Path(__file__).parent.resolve().parent
model_path = resources_dir / "models"
@ -35,18 +43,30 @@ def generate_output(
tp_size = 1
pp_size = 1
tp_pp_dir = f"tp{tp_size}-pp{pp_size}-gpu/"
input_file = f"input_tokens_{model_name}.npy"
data_input_file_name = resources_dir / "data" / f"input_tokens_{model_name}.npy"
data_input_file_name = resources_dir / "data" / input_file
if num_beams == 1:
output_dir = resources_dir / "data" / model_name / "sampling"
else:
output_dir = resources_dir / "data" / model_name / f"beam_search_{num_beams}"
output_dir.mkdir(exist_ok=True, parents=True)
for engine_kind in ["fp16-plugin", "fp16-plugin-packed-paged"]:
engine_dir = model_path / 'rt_engine' / model_name / engine_kind / tp_pp_dir
output_npy_file_name = output_dir / f"output_tokens_{engine_kind.replace('-', '_')}_tp{tp_size}_pp{pp_size}.npy"
output_csv_file_name = output_dir / f"output_tokens_{engine_kind.replace('-', '_')}_tp{tp_size}_pp{pp_size}.csv"
model_spec_obj_list = [
model_spec.ModelSpec(
input_file, _tb.DataType.HALF).use_gpt_plugin().set_kv_cache_type(
model_spec.KVCacheType.CONTINUOUS),
model_spec.ModelSpec(input_file, _tb.DataType.HALF).use_gpt_plugin().
use_packed_input().set_kv_cache_type(model_spec.KVCacheType.PAGED),
]
for model_spec_obj in model_spec_obj_list:
engine_dir = model_path / 'rt_engine' / model_name / model_spec_obj.get_model_path(
) / tp_pp_dir
base_output_name = os.path.splitext(
model_spec_obj.get_results_file())[0]
output_npy_file_name = output_dir / f'{base_output_name}.npy'
output_csv_file_name = output_dir / f'{base_output_name}.csv'
args_list = [
'--engine_dir',
@ -85,8 +105,13 @@ def generate_output(
data = np.load(str(output_npy_file_name))
if model_name == 'chatglm-6b':
data[data == 3] = 130005
else:
elif model_name == 'chatglm2-6b' or model_name == 'chatglm3-6b':
data[data == 0] = 2
elif model_name == 'glm-10b':
data[data == 50256] = 50258
else:
raise NameError('bad model name')
np.save(str(output_npy_file_name), data)
@ -97,4 +122,5 @@ if __name__ == '__main__':
generate_output(model_name='chatglm2-6b', num_beams=2)
generate_output(model_name='chatglm3-6b', num_beams=1)
generate_output(model_name='chatglm3-6b', num_beams=2)
generate_output(model_name='glm-10b', num_beams=1)
print("Done")

View File

@ -17,12 +17,21 @@
from pathlib import Path
import run
from build_engines_utils import init_model_spec_module
init_model_spec_module()
import os
import model_spec
import tensorrt_llm.bindings as _tb
def generate_output(engine: str,
num_beams: int,
input_name: str,
output_name: str,
model_spec_obj: model_spec.ModelSpec,
max_output_len: int = 8,
output_logits: bool = False,
output_cum_log_probs: bool = False,
@ -36,106 +45,140 @@ def generate_output(engine: str,
engine_dir = models_dir / 'rt_engine' / model / engine / tp_pp_dir
data_dir = resources_dir / 'data'
input_file = data_dir / (input_name + '.npy')
input_file = data_dir / input_name
model_data_dir = data_dir / model
if num_beams <= 1:
output_dir = model_data_dir / 'sampling'
else:
output_dir = model_data_dir / ('beam_search_' + str(num_beams))
output_name += '_tp' + str(tp_size) + '_pp' + str(pp_size)
model_spec_obj.use_tensor_parallelism(tp_size).use_pipeline_parallelism(
pp_size)
base_output_name = os.path.splitext(model_spec_obj.get_results_file())[0]
output_logits_npy = None
if output_logits:
output_logits_npy = str(output_dir / (output_name + '_logits' + '.npy'))
logits_file = base_output_name + '_logits.npy'
output_logits_npy = str(output_dir / logits_file)
results_file = str(output_dir / (base_output_name + '.npy'))
results_csv = str(output_dir / (base_output_name + '.csv'))
args_list = [
'--engine_dir',
str(engine_dir), '--input_file',
str(input_file), '--tokenizer_dir',
str(models_dir / model), '--output_npy',
str(output_dir / (output_name + '.npy')), '--output_csv',
str(output_dir / (output_name + '.csv')), '--max_output_len',
str(models_dir / model), '--output_npy', results_file, '--output_csv',
results_csv, '--max_output_len',
str(max_output_len), '--num_beams',
str(num_beams), '--output_logits_npy',
str(output_logits_npy), '--use_py_session'
]
output_cum_log_probs_npy = None
if output_cum_log_probs:
output_cum_log_probs_npy = str(
output_dir / (output_name + '_cum_log_probs' + '.npy'))
args_list.extend(
['--output_cum_log_probs_npy',
str(output_cum_log_probs_npy)])
assert not os.path.exists(results_file) and not os.path.exists(results_csv)
if output_cum_log_probs:
args_list.extend([
'--output_cum_log_probs_npy',
f'{output_dir / model_spec_obj.get_cum_log_probs_file()}'
])
output_log_probs_npy = None
if output_log_probs:
output_log_probs_npy = str(output_dir /
(output_name + '_log_probs' + '.npy'))
args_list.extend(['--output_log_probs_npy', str(output_log_probs_npy)])
args_list.extend([
'--output_log_probs_npy',
f'{output_dir / model_spec_obj.get_log_probs_file()}'
])
args = run.parse_arguments(args_list)
print(args_list)
run.main(args)
def generate_outputs(num_beams):
print('Generating GPT2 FP32 outputs')
input_name = 'input_tokens.npy'
input_name_long = 'input_tokens_long.npy'
model_spec_obj = model_spec.ModelSpec(input_name, _tb.DataType.FLOAT)
model_spec_obj.set_kv_cache_type(model_spec.KVCacheType.CONTINUOUS)
if num_beams == 1:
generate_output(engine='fp32-default',
generate_output(engine=model_spec_obj.get_model_path(),
num_beams=num_beams,
input_name='input_tokens',
output_name='output_tokens_fp32')
generate_output(engine='fp32-plugin',
input_name=input_name,
model_spec_obj=model_spec_obj)
model_spec_obj.use_gpt_plugin()
generate_output(engine=model_spec_obj.get_model_path(),
num_beams=num_beams,
input_name='input_tokens',
output_name='output_tokens_fp32_plugin')
input_name=input_name,
model_spec_obj=model_spec_obj)
print('Generating GPT2 FP16 outputs')
model_spec_obj = model_spec.ModelSpec(input_name, _tb.DataType.HALF)
model_spec_obj.set_kv_cache_type(model_spec.KVCacheType.CONTINUOUS)
if num_beams == 1:
generate_output(engine='fp16-default',
generate_output(engine=model_spec_obj.get_model_path(),
num_beams=num_beams,
input_name='input_tokens',
output_name='output_tokens_fp16')
generate_output(engine='fp16-plugin',
input_name=input_name,
model_spec_obj=model_spec_obj)
model_spec_obj.use_gpt_plugin()
generate_output(engine=model_spec_obj.get_model_path(),
num_beams=num_beams,
input_name='input_tokens',
output_name='output_tokens_fp16_plugin')
generate_output(engine='fp16-plugin-packed',
input_name=input_name,
model_spec_obj=model_spec_obj)
model_spec_obj.use_packed_input()
generate_output(engine=model_spec_obj.get_model_path(),
num_beams=num_beams,
input_name='input_tokens',
output_name='output_tokens_fp16_plugin_packed')
generate_output(engine='fp16-plugin-packed-paged-gather',
input_name=input_name,
model_spec_obj=model_spec_obj)
model_spec_obj.set_kv_cache_type(model_spec.KVCacheType.PAGED)
model_spec_obj.gather_logits()
generate_output(engine=model_spec_obj.get_model_path(),
num_beams=num_beams,
input_name='input_tokens',
output_name='output_tokens_fp16_plugin_packed_paged_gather',
input_name=input_name,
model_spec_obj=model_spec_obj,
output_logits=True,
output_log_probs=True,
output_cum_log_probs=True)
generate_output(engine='fp16-plugin-packed-paged',
model_spec_obj = model_spec.ModelSpec(input_name, _tb.DataType.HALF)
model_spec_obj.use_gpt_plugin()
model_spec_obj.set_kv_cache_type(model_spec.KVCacheType.PAGED)
model_spec_obj.use_packed_input()
generate_output(engine=model_spec_obj.get_model_path(),
num_beams=num_beams,
input_name='input_tokens',
output_name='output_tokens_fp16_plugin_packed_paged',
input_name=input_name,
model_spec_obj=model_spec_obj,
output_logits=False,
output_log_probs=True,
output_cum_log_probs=True)
generate_output(engine='fp16-plugin-packed-paged',
model_spec_obj.set_max_output_length(128)
generate_output(engine=model_spec_obj.get_model_path(),
num_beams=num_beams,
input_name='input_tokens',
output_name='output_tokens_long_fp16_plugin_packed_paged',
input_name=input_name,
model_spec_obj=model_spec_obj,
output_logits=False,
max_output_len=128)
generate_output(
engine='fp16-plugin-packed-paged',
num_beams=num_beams,
input_name='input_tokens_long',
output_name='output_tokens_long_input_fp16_plugin_packed_paged',
output_logits=False)
generate_output(engine='fp16-plugin-packed-paged-sq',
model_spec_obj = model_spec.ModelSpec(input_name_long, _tb.DataType.HALF)
model_spec_obj.use_gpt_plugin()
model_spec_obj.use_packed_input()
model_spec_obj.set_kv_cache_type(model_spec.KVCacheType.PAGED)
generate_output(engine=model_spec_obj.get_model_path(),
num_beams=num_beams,
input_name='input_tokens',
output_name='output_tokens_fp16_plugin_packed_paged_sq',
input_name=input_name_long,
model_spec_obj=model_spec_obj,
output_logits=False)
model_spec_obj = model_spec.ModelSpec(input_name, _tb.DataType.HALF)
model_spec_obj.use_gpt_plugin()
model_spec_obj.use_packed_input()
model_spec_obj.set_kv_cache_type(model_spec.KVCacheType.PAGED)
model_spec_obj.set_quant_method(model_spec.QuantMethod.SMOOTH_QUANT)
generate_output(engine=model_spec_obj.get_model_path(),
num_beams=num_beams,
input_name=input_name,
model_spec_obj=model_spec_obj,
output_logits=False)

View File

@ -18,11 +18,19 @@ import argparse as _arg
from pathlib import Path
import run
from build_engines_utils import init_model_spec_module
init_model_spec_module()
import os
import model_spec
import tensorrt_llm.bindings as _tb
def generate_output(engine: str,
num_beams: int,
output_name: str,
model_spec_obj: model_spec.ModelSpec,
max_output_len: int = 4):
tp_size = 1
@ -42,15 +50,15 @@ def generate_output(engine: str,
else:
output_dir = model_data_dir / ('beam_search_' + str(num_beams))
output_name += '_tp' + str(tp_size) + '_pp' + str(pp_size)
base_output_name = os.path.splitext(model_spec_obj.get_results_file())[0]
args = run.parse_arguments([
'--engine_dir',
str(engine_dir), '--input_file',
str(input_file), '--tokenizer_dir',
str(hf_dir), '--output_npy',
str(output_dir / (output_name + '.npy')), '--output_csv',
str(output_dir / (output_name + '.csv')), '--max_output_len',
str(output_dir / (base_output_name + '.npy')), '--output_csv',
str(output_dir / (base_output_name + '.csv')), '--max_output_len',
str(max_output_len), '--num_beams',
str(num_beams), '--use_py_session'
])
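
The hunk above also changes how result files are named: rather than an explicit output_name plus a '_tpX_ppY' suffix, the base name is taken from the spec via get_results_file(). A small sketch of that naming step; output_dir is a placeholder (the script derives it from num_beams) and the imports mirror the scripts above:

import os
from pathlib import Path

import tensorrt_llm.bindings as _tb
from build_engines_utils import init_model_spec_module

init_model_spec_module()
import model_spec

output_dir = Path('data/sampling')  # placeholder; the script picks 'sampling' or 'beam_search_<n>'
model_spec_obj = model_spec.ModelSpec('input_tokens.npy', _tb.DataType.HALF)
model_spec_obj.use_gpt_plugin()

base_output_name = os.path.splitext(model_spec_obj.get_results_file())[0]
output_npy = str(output_dir / (base_output_name + '.npy'))  # fed to run.parse_arguments as --output_npy
output_csv = str(output_dir / (base_output_name + '.csv'))  # fed to run.parse_arguments as --output_csv
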
@ -58,22 +66,35 @@ def generate_output(engine: str,
def generate_outputs(only_fp8, num_beams):
input_file = 'input_tokens.npy'
if only_fp8 and num_beams == 1:
model_spec_obj = model_spec.ModelSpec(input_file, _tb.DataType.FP8)
model_spec_obj.use_gpt_plugin()
model_spec_obj.set_kv_cache_type(model_spec.KVCacheType.PAGED)
model_spec_obj.use_packed_input()
print('Generating GPT-J FP8-kv-cache outputs')
generate_output(engine='fp8-plugin',
generate_output(engine=model_spec_obj.get_model_path(),
num_beams=num_beams,
output_name='output_tokens_fp8_plugin')
model_spec_obj=model_spec_obj)
elif not only_fp8:
print('Generating GPT-J FP16 outputs')
generate_output(engine='fp16-plugin',
model_spec_obj = model_spec.ModelSpec(input_file, _tb.DataType.HALF)
model_spec_obj.use_gpt_plugin()
model_spec_obj.set_kv_cache_type(model_spec.KVCacheType.CONTINUOUS)
generate_output(engine=model_spec_obj.get_model_path(),
num_beams=num_beams,
output_name='output_tokens_fp16_plugin')
generate_output(engine='fp16-plugin-packed',
model_spec_obj=model_spec_obj)
model_spec_obj.use_packed_input()
generate_output(engine=model_spec_obj.get_model_path(),
num_beams=num_beams,
output_name='output_tokens_fp16_plugin_packed')
generate_output(engine='fp16-plugin-packed-paged',
model_spec_obj=model_spec_obj)
model_spec_obj.set_kv_cache_type(model_spec.KVCacheType.PAGED)
generate_output(engine=model_spec_obj.get_model_path(),
num_beams=num_beams,
output_name='output_tokens_fp16_plugin_packed_paged')
model_spec_obj=model_spec_obj)
if __name__ == '__main__':

View File

@ -18,11 +18,19 @@ import argparse as _arg
from pathlib import Path
import run
from build_engines_utils import init_model_spec_module
init_model_spec_module()
import os
import model_spec
import tensorrt_llm.bindings as _tb
def generate_output(engine: str,
num_beams: int,
output_name: str,
model_spec_obj: model_spec.ModelSpec,
tp_size: int = 1,
pp_size: int = 1,
max_output_len: int = 8):
@ -42,15 +50,15 @@ def generate_output(engine: str,
else:
output_dir = model_data_dir / ('beam_search_' + str(num_beams))
output_name += '_tp' + str(tp_size) + '_pp' + str(pp_size)
base_output_name = os.path.splitext(model_spec_obj.get_results_file())[0]
args = run.parse_arguments([
'--engine_dir',
str(engine_dir), '--input_file',
str(input_file), '--tokenizer_dir',
str(hf_dir), '--output_npy',
str(output_dir / (output_name + '.npy')), '--output_csv',
str(output_dir / (output_name + '.csv')), '--max_output_len',
str(output_dir / (base_output_name + '.npy')), '--output_csv',
str(output_dir / (base_output_name + '.csv')), '--max_output_len',
str(max_output_len), '--num_beams',
str(num_beams), '--use_py_session'
])
@ -59,15 +67,22 @@ def generate_output(engine: str,
def generate_outputs(num_beams, only_multi_gpu=False):
tp_pp_sizes = [(1, 1)] if not only_multi_gpu else [(4, 1), (2, 2), (1, 4)]
model_spec_obj = model_spec.ModelSpec('input_tokens.npy', _tb.DataType.HALF)
model_spec_obj.use_gpt_plugin()
model_spec_obj.set_kv_cache_type(model_spec.KVCacheType.PAGED)
model_spec_obj.use_packed_input()
for tp_size, pp_size in tp_pp_sizes:
print(
f'Generating outputs for Llama FP16 with TP={tp_size} and PP={pp_size}'
)
generate_output(engine='fp16-plugin-packed-paged',
model_spec_obj.use_tensor_parallelism(tp_size)
model_spec_obj.use_pipeline_parallelism(pp_size)
generate_output(engine=model_spec_obj.get_model_path(),
num_beams=num_beams,
tp_size=tp_size,
pp_size=pp_size,
output_name='output_tokens_fp16_plugin_packed_paged')
model_spec_obj=model_spec_obj)
if __name__ == '__main__':
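
For the multi-GPU Llama runs, the tensor- and pipeline-parallel sizes are now recorded on the spec as well, so the name returned by get_results_file() can carry what the old code appended as a '_tpX_ppY' suffix. A condensed sketch of the loop above (imports as in the scripts; the TP/PP combinations are the only_multi_gpu list from the hunk):

import tensorrt_llm.bindings as _tb
from build_engines_utils import init_model_spec_module

init_model_spec_module()
import model_spec

model_spec_obj = model_spec.ModelSpec('input_tokens.npy', _tb.DataType.HALF)
model_spec_obj.use_gpt_plugin()
model_spec_obj.set_kv_cache_type(model_spec.KVCacheType.PAGED)
model_spec_obj.use_packed_input()

for tp_size, pp_size in [(4, 1), (2, 2), (1, 4)]:
    model_spec_obj.use_tensor_parallelism(tp_size)
    model_spec_obj.use_pipeline_parallelism(pp_size)
    engine = model_spec_obj.get_model_path()       # engine name for this configuration
    results = model_spec_obj.get_results_file()    # presumably encodes the TP/PP sizes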

View File

@ -17,12 +17,20 @@
from pathlib import Path
import run
from build_engines_utils import init_model_spec_module
init_model_spec_module()
import os
import model_spec
import tensorrt_llm.bindings as _tb
def generate_output(engine: str,
num_beams: int,
input_name: str,
output_name: str,
model_spec_obj: model_spec.ModelSpec,
max_output_len: int = 8,
output_logits: bool = False):
tp_size = 1
@ -34,26 +42,27 @@ def generate_output(engine: str,
engine_dir = models_dir / 'rt_engine' / model / engine / tp_pp_dir
data_dir = resources_dir / 'data'
input_file = data_dir / (input_name + '.npy')
input_file = data_dir / input_name
model_data_dir = data_dir / model
if num_beams <= 1:
output_dir = model_data_dir / 'sampling'
else:
output_dir = model_data_dir / ('beam_search_' + str(num_beams))
output_name += '_tp' + str(tp_size) + '_pp' + str(pp_size)
base_output_name = os.path.splitext(model_spec_obj.get_results_file())[0]
output_logits_npy = None
if output_logits:
output_logits_npy = str(output_dir / (output_name + '_logits' + '.npy'))
output_logits_npy = str(output_dir /
(base_output_name + '_logits' + '.npy'))
args = run.parse_arguments([
'--engine_dir',
str(engine_dir), '--input_file',
str(input_file), '--tokenizer_dir',
str(models_dir / 'gpt-neox-20b'), '--output_npy',
str(output_dir / (output_name + '.npy')), '--output_csv',
str(output_dir / (output_name + '.csv')), '--max_output_len',
str(output_dir / (base_output_name + '.npy')), '--output_csv',
str(output_dir / (base_output_name + '.csv')), '--max_output_len',
str(max_output_len), '--num_beams',
str(num_beams), '--output_logits_npy',
str(output_logits_npy), '--use_py_session'
@ -63,25 +72,35 @@ def generate_output(engine: str,
def generate_outputs(num_beams):
print('Generating Mamba FP16 outputs')
generate_output(engine='fp16-default',
input_name = 'input_tokens.npy'
model_spec_obj = model_spec.ModelSpec(input_name, _tb.DataType.HALF)
model_spec_obj.set_kv_cache_type(model_spec.KVCacheType.CONTINUOUS)
generate_output(engine=model_spec_obj.get_model_path(),
num_beams=num_beams,
input_name='input_tokens',
output_name='output_tokens_fp16')
input_name=input_name,
model_spec_obj=model_spec_obj)
print('Generating Mamba FP16-plugin outputs')
generate_output(engine='fp16-plugin',
model_spec_obj.use_gpt_plugin()
generate_output(engine=model_spec_obj.get_model_path(),
num_beams=num_beams,
input_name='input_tokens',
output_name='output_tokens_fp16_plugin')
input_name=input_name,
model_spec_obj=model_spec_obj)
print('Generating Mamba FP16-plugin-packed outputs')
generate_output(engine='fp16-plugin-packed',
model_spec_obj.use_packed_input()
generate_output(engine=model_spec_obj.get_model_path(),
num_beams=num_beams,
input_name='input_tokens',
output_name='output_tokens_fp16_plugin_packed')
input_name=input_name,
model_spec_obj=model_spec_obj)
print('Generating Mamba FP16-plugin-packed-paged outputs')
generate_output(engine='fp16-plugin-packed-paged',
model_spec_obj.set_kv_cache_type(model_spec.KVCacheType.PAGED)
generate_output(engine=model_spec_obj.get_model_path(),
num_beams=num_beams,
input_name='input_tokens',
output_name='output_tokens_fp16_plugin_packed_paged')
input_name=input_name,
model_spec_obj=model_spec_obj)
if __name__ == '__main__':
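
The shape of the engine path these scripts resolve is unchanged; only the engine-name component now comes from the spec instead of a literal. A rough sketch of the assembly shown in the hunk above, with placeholder values for everything the script derives from its own location:

from pathlib import Path

import tensorrt_llm.bindings as _tb
from build_engines_utils import init_model_spec_module

init_model_spec_module()
import model_spec

models_dir = Path('/path/to/models')   # placeholder; the script resolves this relative to itself
model = 'mamba'                        # placeholder model directory name
tp_pp_dir = 'tp1-pp1'                  # placeholder; the actual single-GPU directory name may differ

model_spec_obj = model_spec.ModelSpec('input_tokens.npy', _tb.DataType.HALF)
model_spec_obj.use_gpt_plugin()
model_spec_obj.use_packed_input()
model_spec_obj.set_kv_cache_type(model_spec.KVCacheType.PAGED)

engine = model_spec_obj.get_model_path()  # formerly the literal 'fp16-plugin-packed-paged'
engine_dir = models_dir / 'rt_engine' / model / engine / tp_pp_dir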

View File

@ -18,9 +18,19 @@ import argparse as _arg
from pathlib import Path
import run
from build_engines_utils import init_model_spec_module
init_model_spec_module()
import os
import model_spec
import tensorrt_llm.bindings as _tb
def generate_output(engine: str, output_name: str, max_output_len: int = 8):
def generate_output(engine: str,
model_spec_obj: model_spec.ModelSpec,
max_output_len: int = 8):
model = 'vicuna-7b-v1.3'
resources_dir = Path(__file__).parent.resolve().parent
@ -30,19 +40,19 @@ def generate_output(engine: str, output_name: str, max_output_len: int = 8):
engine_dir = models_dir / 'rt_engine' / model / engine / tp_pp_dir
data_dir = resources_dir / 'data'
input_file = data_dir / 'input_tokens.npy'
input_file = data_dir / 'input_vicuna.npy'
model_data_dir = data_dir / model
output_dir = model_data_dir / 'sampling'
output_name += '_tp1_pp1'
base_output_name = os.path.splitext(model_spec_obj.get_results_file())[0]
args = run.parse_arguments([
'--engine_dir',
str(engine_dir), '--input_file',
str(input_file), '--tokenizer_dir',
str(hf_dir), '--output_npy',
str(output_dir / (output_name + '.npy')), '--output_csv',
str(output_dir / (output_name + '.csv')), '--max_output_len',
str(output_dir / (base_output_name + '.npy')), '--output_csv',
str(output_dir / (base_output_name + '.csv')), '--max_output_len',
str(max_output_len), '--use_py_session',
'--medusa_choices=[[0], [0, 0], [1], [0, 1], [2], [0, 0, 0], [1, 0], [0, 2], [3], [0, 3], [4], [0, 4], [2, 0], [0, 5], [0, 0, 1], [5], [0, 6], [6], [0, 7], [0, 1, 0], [1, 1], [7], [0, 8], [0, 0, 2], [3, 0], [0, 9], [8], [9], [1, 0, 0], [0, 2, 0], [1, 2], [0, 0, 3], [4, 0], [2, 1], [0, 0, 4], [0, 0, 5], [0, 0, 0, 0], [0, 1, 1], [0, 0, 6], [0, 3, 0], [5, 0], [1, 3], [0, 0, 7], [0, 0, 8], [0, 0, 9], [6, 0], [0, 4, 0], [1, 4], [7, 0], [0, 1, 2], [2, 0, 0], [3, 1], [2, 2], [8, 0], [0, 5, 0], [1, 5], [1, 0, 1], [0, 2, 1], [9, 0], [0, 6, 0], [0, 0, 0, 1], [1, 6], [0, 7, 0]]',
'--temperature', '1.0'
@ -52,9 +62,18 @@ def generate_output(engine: str, output_name: str, max_output_len: int = 8):
def generate_outputs():
print(f'Generating outputs for Medusa FP16')
generate_output(engine='fp16-plugin-packed-paged',
output_name='output_tokens_long_fp16_plugin_packed_paged',
max_output_len=128)
max_output_len = 128
model_spec_obj = model_spec.ModelSpec('input_tokens_long.npy',
_tb.DataType.HALF)
model_spec_obj.use_gpt_plugin()
model_spec_obj.set_max_output_length(max_output_len)
model_spec_obj.use_packed_input()
model_spec_obj.set_kv_cache_type(model_spec.KVCacheType.PAGED)
model_spec_obj.use_medusa()
generate_output(engine=model_spec_obj.get_model_path(),
model_spec_obj=model_spec_obj,
max_output_len=max_output_len)
if __name__ == '__main__':
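
The Medusa case layers two extra toggles on top of the packed/paged FP16 configuration and raises the output length to 128, which previously had to be spelled out in the '..._long_...' output_name. A condensed sketch of the spec built in the hunk above (imports as in the scripts):

import tensorrt_llm.bindings as _tb
from build_engines_utils import init_model_spec_module

init_model_spec_module()
import model_spec

max_output_len = 128
model_spec_obj = model_spec.ModelSpec('input_tokens_long.npy', _tb.DataType.HALF)
model_spec_obj.use_gpt_plugin()
model_spec_obj.set_max_output_length(max_output_len)
model_spec_obj.use_packed_input()
model_spec_obj.set_kv_cache_type(model_spec.KVCacheType.PAGED)
model_spec_obj.use_medusa()   # marks the Medusa engine / results variant

engine = model_spec_obj.get_model_path()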

View File

@ -17,12 +17,20 @@
from pathlib import Path
import run
from build_engines_utils import init_model_spec_module
init_model_spec_module()
import os
import model_spec
import tensorrt_llm.bindings as _tb
def generate_output(engine: str,
num_beams: int,
input_name: str,
output_name: str,
model_spec_obj: model_spec.ModelSpec,
max_output_len: int = 8,
output_logits: bool = False):
tp_size = 1
@ -34,26 +42,27 @@ def generate_output(engine: str,
engine_dir = models_dir / 'rt_engine' / model / engine / tp_pp_dir
data_dir = resources_dir / 'data'
input_file = data_dir / (input_name + '.npy')
input_file = data_dir / input_name
model_data_dir = data_dir / model
if num_beams <= 1:
output_dir = model_data_dir / 'sampling'
else:
output_dir = model_data_dir / ('beam_search_' + str(num_beams))
output_name += '_tp' + str(tp_size) + '_pp' + str(pp_size)
base_output_name = os.path.splitext(model_spec_obj.get_results_file())[0]
output_logits_npy = None
if output_logits:
output_logits_npy = str(output_dir / (output_name + '_logits' + '.npy'))
output_logits_npy = str(output_dir /
(base_output_name + '_logits' + '.npy'))
args = run.parse_arguments([
'--engine_dir',
str(engine_dir), '--input_file',
str(input_file), '--tokenizer_dir',
str(models_dir / model), '--output_npy',
str(output_dir / (output_name + '.npy')), '--output_csv',
str(output_dir / (output_name + '.csv')), '--max_output_len',
str(output_dir / (base_output_name + '.npy')), '--output_csv',
str(output_dir / (base_output_name + '.csv')), '--max_output_len',
str(max_output_len), '--num_beams',
str(num_beams), '--output_logits_npy',
str(output_logits_npy), '--use_py_session'
@ -62,11 +71,17 @@ def generate_output(engine: str,
def generate_outputs(num_beams):
input_file = 'input_tokens.npy'
model_spec_obj = model_spec.ModelSpec(input_file, _tb.DataType.HALF)
model_spec_obj.use_gpt_plugin()
model_spec_obj.set_kv_cache_type(model_spec.KVCacheType.PAGED)
model_spec_obj.use_packed_input()
print('Generating RecurrentGemma FP16-plugin-packed-paged outputs')
generate_output(engine='fp16-plugin-packed-paged',
generate_output(engine=model_spec_obj.get_model_path(),
num_beams=num_beams,
input_name='input_tokens',
output_name='output_tokens_fp16_plugin_packed_paged')
input_name=input_file,
model_spec_obj=model_spec_obj)
if __name__ == '__main__':
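
One smaller change, visible in the Mamba and RecurrentGemma hunks above: callers now pass the input file name with its '.npy' extension (the same string the spec is constructed from), and the helper no longer appends the extension itself. Sketch, with data_dir as a placeholder:

from pathlib import Path

data_dir = Path('data')             # placeholder for the scripts' resources/data directory
input_file = 'input_tokens.npy'     # extension now included by the caller
input_path = data_dir / input_file  # was: data_dir / (input_name + '.npy')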

Some files were not shown because too many files have changed in this diff.