Mirror of https://github.com/NVIDIA/TensorRT-LLM.git

Commit 2d234357c6 (parent a96cccafcf): Update TensorRT-LLM (#1954)
Co-authored-by: Altair-Alpha <62340011+Altair-Alpha@users.noreply.github.com>

.gitignore (vendored), 8 lines changed
@ -40,3 +40,11 @@ tensorrt_llm/bindings/*.pyi
|
||||
# Testing
|
||||
.coverage.*
|
||||
results_trt/
|
||||
|
||||
# build/debug
|
||||
*.safetensors
|
||||
*/tllm_debug/**
|
||||
*.patch
|
||||
|
||||
# Generated files
|
||||
cpp/include/tensorrt_llm/executor/version.h
|
||||
|
||||
@ -177,13 +177,6 @@ def parse_arguments():
|
||||
'If this option is specified, it will override the max decoder input len of TRT engines to the specified value instead of using pre-defined one'
|
||||
'By default when this option is not used, it will use pre-defined max decoder input len'
|
||||
))
|
||||
parser.add_argument(
|
||||
'--max_output_len',
|
||||
type=int,
|
||||
default=None,
|
||||
help=
|
||||
('If this option is specified, it will override the max output len of '
|
||||
'TRT engines to the specified value instead of using pre-defined one'))
|
||||
parser.add_argument(
|
||||
'--max_seq_len',
|
||||
'--max_decoder_seq_len',
|
||||
@ -360,21 +353,6 @@ def main(args):
|
||||
rank = tensorrt_llm.mpi_rank()
|
||||
world_size = tensorrt_llm.mpi_world_size()
|
||||
|
||||
if args.max_output_len:
|
||||
logger.warning(
|
||||
'--max_output_len has been deprecated in favor of --max_seq_len')
|
||||
if args.max_input_len:
|
||||
if args.max_seq_len:
|
||||
logger.warning(
|
||||
'--max_seq_len has been overwritten due to --max_output_len being specified'
|
||||
)
|
||||
args.max_seq_len = args.max_input_len + args.max_output_len
|
||||
else:
|
||||
raise Exception(
|
||||
f"--max_output_len is specified but not --max_input_len")
|
||||
|
||||
del args.max_output_len
|
||||
|
||||
# TODO: Re-enable memory monitor for multi-gpu benchmarks.
|
||||
# Current Mem Monitor will cause benchmark script hang
|
||||
# because MPI does not work well with multiprocessing.
|
||||
|
||||
@ -129,13 +129,6 @@ def parse_arguments():
|
||||
help=
|
||||
('If this option is specified, it will override the max input len of '
|
||||
'TRT engines to the specified value instead of using pre-defined one'))
|
||||
parser.add_argument(
|
||||
'--max_output_len',
|
||||
type=int,
|
||||
default=None,
|
||||
help=
|
||||
('If this option is specified, it will override the max output len of '
|
||||
'TRT engines to the specified value instead of using pre-defined one'))
|
||||
parser.add_argument(
|
||||
'--max_seq_len',
|
||||
'--max_decoder_seq_len',
|
||||
|
||||
@ -185,7 +185,7 @@ When the benchmark runs successfully, you will see a report out of the run simil
|
||||
[RANK 0] Completed request submission.
|
||||
[RANK 0] Calculating results.
|
||||
[RANK 0] Reporting...
|
||||
[RANK 0] JSON: {'benchmark_cmd': '', 'binary': '', 'build_cmd': 'trtllm-build --output_dir /tmp/meta-llama/llama-2-7b-hf --model_config /tmp/generated_config.json --workers 1 --max_batch_size 1024 --max_input_len 128 --max_output_len 128 --max_num_tokens 8000 --context_fmha enable --gpt_attention_plugin float16 --paged_kv_cache enable --multiple_profiles enable --gemm_plugin float16', 'first_token_latency': 0.0, 'inflight_batching': True, 'kv_mem_fraction': 0.98, 'latency_units': 'ms', 'max_batch_size': 1024, 'max_tokens': 8000, 'model': 'meta-llama/Llama-2-7b-hf', 'peak_gpu_mem_units': 'GB', 'peak_gpu_mem': 0.0, 'scheduler': 'Max Utilization', 'throughput_units': 'tokens/second', 'throughput': 17634.422523488243, 'time_per_output_token': 0.0, 'total_input_tokens': 128000, 'total_latency': 7.258530855178833, 'total_output_tokens': 128000}
|
||||
[RANK 0] JSON: {'benchmark_cmd': '', 'binary': '', 'build_cmd': 'trtllm-build --output_dir /tmp/meta-llama/llama-2-7b-hf --model_config /tmp/generated_config.json --workers 1 --max_batch_size 1024 --max_input_len 128 --max_seq_len 256 --max_num_tokens 8000 --context_fmha enable --gpt_attention_plugin float16 --paged_kv_cache enable --multiple_profiles enable --gemm_plugin float16', 'first_token_latency': 0.0, 'inflight_batching': True, 'kv_mem_fraction': 0.98, 'latency_units': 'ms', 'max_batch_size': 1024, 'max_tokens': 8000, 'model': 'meta-llama/Llama-2-7b-hf', 'peak_gpu_mem_units': 'GB', 'peak_gpu_mem': 0.0, 'scheduler': 'Max Utilization', 'throughput_units': 'tokens/second', 'throughput': 17634.422523488243, 'time_per_output_token': 0.0, 'total_input_tokens': 128000, 'total_latency': 7.258530855178833, 'total_output_tokens': 128000}
|
||||
===========================================================
|
||||
= METADATA
|
||||
===========================================================
|
||||
|
||||
@ -128,6 +128,27 @@ if(INDEX_RANGE_CHECK)
|
||||
message(WARNING "Check index range to detect OOB accesses")
|
||||
endif()
|
||||
|
||||
# Read the project version
|
||||
set(TRTLLM_VERSION_DIR ${PROJECT_SOURCE_DIR}/../tensorrt_llm)
|
||||
set_directory_properties(PROPERTIES CMAKE_CONFIGURE_DEPENDS
|
||||
${TRTLLM_VERSION_DIR}/version.py)
|
||||
execute_process(
|
||||
COMMAND python3 -c "import version; print(version.__version__)"
|
||||
WORKING_DIRECTORY ${TRTLLM_VERSION_DIR}
|
||||
OUTPUT_VARIABLE TRTLLM_VERSION
|
||||
RESULT_VARIABLE TRTLLM_VERSION_RESULT
|
||||
OUTPUT_STRIP_TRAILING_WHITESPACE)
|
||||
|
||||
if(TRTLLM_VERSION_RESULT EQUAL 0)
|
||||
message(STATUS "TensorRT-LLM version: ${TRTLLM_VERSION}")
|
||||
else()
|
||||
message(FATAL_ERROR "Failed to determine TensorRT-LLM version")
|
||||
endif()
|
||||
|
||||
configure_file(
|
||||
cmake/templates/version.h
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/include/tensorrt_llm/executor/version.h)
|
||||
|
||||
# Determine CUDA version before enabling the language extension
|
||||
check_language(CUDA)
|
||||
if(CMAKE_CUDA_COMPILER)
|
||||
@ -139,7 +160,7 @@ if(CMAKE_CUDA_COMPILER)
|
||||
"${CMAKE_CUDA_COMPILER} --version | egrep -o 'V[0-9]+.[0-9]+.[0-9]+' | cut -c2-"
|
||||
RESULT_VARIABLE _BASH_SUCCESS
|
||||
OUTPUT_VARIABLE CMAKE_CUDA_COMPILER_VERSION
|
||||
OUTPUT_STRIP_TRAILING_WHITESPACE)
|
||||
OUTPUT_STRIP_TRAILING_WHITESPACE)
|
||||
|
||||
if(NOT _BASH_SUCCESS EQUAL 0)
|
||||
message(FATAL_ERROR "Failed to determine CUDA version")
|
||||
|
||||
cpp/cmake/templates/version.h (new file, 24 lines)
@ -0,0 +1,24 @@
|
||||
/*
|
||||
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
// THIS FILE IS AUTO GENERATED FROM cmake/templates/version.h. DO NOT EDIT.
|
||||
|
||||
namespace tensorrt_llm::executor
|
||||
{
|
||||
static auto constexpr kTensorRtLlmVersion = "@TRTLLM_VERSION@";
|
||||
}
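
As a hedged illustration of how the new template is consumed: after the configure_file step above generates include/tensorrt_llm/executor/version.h, downstream code can read the substituted constant directly (assuming the generated header is on the include path).

#include <iostream>

// Generated from cpp/cmake/templates/version.h; "@TRTLLM_VERSION@" is replaced at
// configure time with the value read from tensorrt_llm/version.py.
#include "tensorrt_llm/executor/version.h"

int main()
{
    std::cout << "TensorRT-LLM version: " << tensorrt_llm::executor::kTensorRtLlmVersion << "\n";
    return 0;
}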
@ -122,6 +122,9 @@ private:
|
||||
void decoupled_execution_loop();
|
||||
std::shared_ptr<std::thread> worker_thread_;
|
||||
std::shared_ptr<nvinfer1::ILogger> mLogger{};
|
||||
|
||||
inline static std::string const kPROFILE_START_STOP_ENV_VAR_NAME = "TLLM_PROFILE_START_STOP";
|
||||
inline static std::string const kLEGACY_PROFILE_START_STOP_ENV_VAR_NAME = "TLLM_GPTM_PROFILE_START_STOP";
|
||||
};
|
||||
|
||||
} // namespace tensorrt_llm::batch_manager
|
||||
|
||||
@ -375,6 +375,11 @@ public:
|
||||
return mSecondaryPool;
|
||||
}
|
||||
|
||||
[[nodiscard]] SizeType32 getNumLayers() const
|
||||
{
|
||||
return mNumLayers;
|
||||
}
|
||||
|
||||
//! \brief Get index in pool to K or V block.
|
||||
//! \param blockId the blockId as returned by getBlockId()
|
||||
//! \param fieldIdx either 0 (K) or 1 (V),
|
||||
@ -592,6 +597,8 @@ public:
|
||||
void removeToken(SizeType32 seqSlotIdx);
|
||||
void rewindKVCache(SizeType32 seqSlotIdx, SizeType32 rewindLengths);
|
||||
|
||||
[[nodiscard]] GenerationRequest const& getSequence(SizeType32 seqSlotIdx) const;
|
||||
|
||||
[[nodiscard]] bool isCrossKv() const
|
||||
{
|
||||
return mCacheType == CacheType::kCROSS;
|
||||
@ -634,4 +641,5 @@ private:
|
||||
// KV cache type (self or cross)
|
||||
CacheType mCacheType;
|
||||
};
|
||||
|
||||
} // namespace tensorrt_llm::batch_manager::kv_cache_manager
|
||||
|
||||
cpp/include/tensorrt_llm/batch_manager/kvCacheUtils.h (new file, 99 lines)
@ -0,0 +1,99 @@
|
||||
/*
|
||||
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "tensorrt_llm/batch_manager/kvCacheManager.h"
|
||||
|
||||
namespace tensorrt_llm::batch_manager::kv_cache_manager
|
||||
{
|
||||
|
||||
class BlockIterator
|
||||
{
|
||||
public:
|
||||
using iterator_category = std::forward_iterator_tag;
|
||||
using value_type = runtime::ITensor;
|
||||
using pointer = runtime::ITensor::SharedPtr;
|
||||
using reference = value_type&;
|
||||
using SizeType32 = tensorrt_llm::runtime::SizeType32;
|
||||
|
||||
BlockIterator(runtime::ITensor::SharedPtr blockPoolPtr, std::vector<SizeType32> blockIds, size_t idx)
|
||||
: mPool{std::move(blockPoolPtr)}
|
||||
, mBlockIds{std::move(blockIds)}
|
||||
, mIdx{idx}
|
||||
{
|
||||
TLLM_CHECK(mPool);
|
||||
TLLM_CHECK(mIdx <= mBlockIds.size());
|
||||
update();
|
||||
}
|
||||
|
||||
[[nodiscard]] pointer operator->()
|
||||
{
|
||||
return mCurrent;
|
||||
}
|
||||
|
||||
[[nodiscard]] reference operator*()
|
||||
{
|
||||
return *mCurrent;
|
||||
}
|
||||
|
||||
BlockIterator& operator++()
|
||||
{
|
||||
mIdx++;
|
||||
update();
|
||||
return *this;
|
||||
}
|
||||
|
||||
BlockIterator operator++(int)
|
||||
{
|
||||
auto ret = *this;
|
||||
ret.update();
|
||||
mIdx++;
|
||||
return ret;
|
||||
}
|
||||
|
||||
[[nodiscard]] bool operator==(BlockIterator const& other) const
|
||||
{
|
||||
return mIdx == other.mIdx && mPool.get() == other.mPool.get();
|
||||
}
|
||||
|
||||
[[nodiscard]] bool operator!=(BlockIterator const& other) const
|
||||
{
|
||||
return !(*this == other);
|
||||
}
|
||||
|
||||
private:
|
||||
void update()
|
||||
{
|
||||
if (mIdx < mBlockIds.size())
|
||||
{
|
||||
mCurrent = runtime::ITensor::slice(mPool, mBlockIds.at(mIdx), 1);
|
||||
}
|
||||
}
|
||||
|
||||
runtime::ITensor::SharedPtr mPool;
|
||||
runtime::ITensor::SharedPtr mCurrent;
|
||||
const std::vector<SizeType32> mBlockIds;
|
||||
size_t mIdx;
|
||||
};
|
||||
|
||||
[[nodiscard]] BlockIterator getBlockBeginIt(
|
||||
KVCacheManager const& cacheManager, LlmRequest const& request, SizeType32 beam);
|
||||
|
||||
[[nodiscard]] BlockIterator getBlockEndIt(
|
||||
KVCacheManager const& cacheManager, LlmRequest const& request, SizeType32 beam);
|
||||
|
||||
} // namespace tensorrt_llm::batch_manager::kv_cache_manager
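
A hedged usage sketch of the new iterator helpers declared above; the cache manager, request, and beam index are assumed to come from an existing batch-manager setup, and the loop body only inspects each block view.

#include "tensorrt_llm/batch_manager/kvCacheUtils.h"

namespace kvcm = tensorrt_llm::batch_manager::kv_cache_manager;

// Walk every KV-cache block tensor owned by one beam of a request (sketch only).
void visitBlocks(kvcm::KVCacheManager const& cacheManager,
    tensorrt_llm::batch_manager::LlmRequest const& request, tensorrt_llm::runtime::SizeType32 beam)
{
    for (auto it = kvcm::getBlockBeginIt(cacheManager, request, beam);
         it != kvcm::getBlockEndIt(cacheManager, request, beam); ++it)
    {
        auto const& blockView = *it; // ITensor view of a single block in the pool
        (void) blockView;            // a real caller would read or copy the block here
    }
}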
@ -16,6 +16,7 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "tensorrt_llm/common/assert.h"
|
||||
#include "tensorrt_llm/common/logger.h"
|
||||
#include "tensorrt_llm/executor/executor.h"
|
||||
#include "tensorrt_llm/runtime/bufferManager.h"
|
||||
@ -62,7 +63,8 @@ public:
|
||||
using VecLogProbs = std::vector<float>;
|
||||
using BeamTokens = std::vector<VecTokens>;
|
||||
using TensorPtr = TTensor;
|
||||
using LogitsPostProcessor = std::function<void(RequestIdType, TensorPtr&, BeamTokens const&, TStream const&)>;
|
||||
using LogitsPostProcessor = std::function<void(
|
||||
RequestIdType, TensorPtr&, BeamTokens const&, TStream const&, std::optional<RequestIdType>)>;
|
||||
|
||||
GenericLlmRequest(RequestIdType requestId, SizeType32 maxNewTokens, std::shared_ptr<VecTokens> inputTokens,
|
||||
runtime::SamplingConfig const& samplingConfig, bool isStreaming, std::optional<SizeType32> endId = std::nullopt,
|
||||
@ -77,19 +79,21 @@ public:
|
||||
std::optional<TensorPtr> draftLogits = std::nullopt, bool excludeInputFromOutput = false,
|
||||
std::optional<LogitsPostProcessor> logitsPostProcessor = std::nullopt,
|
||||
bool applyLogitsPostProcessorBatched = false,
|
||||
std::optional<std::shared_ptr<VecTokens>> encoderInputTokens = std::nullopt, bool returnEncoderOutput = false)
|
||||
std::optional<std::shared_ptr<VecTokens>> encoderInputTokens = std::nullopt, bool returnEncoderOutput = false,
|
||||
std::optional<RequestIdType> clientId = std::nullopt)
|
||||
: mRequestId(requestId)
|
||||
, mPromptLen(inputTokens->size())
|
||||
, mMaxNewTokens(maxNewTokens)
|
||||
, mSamplingConfig(samplingConfig)
|
||||
, mState(REQUEST_STATE_CONTEXT_INIT)
|
||||
, mIsStreaming(isStreaming)
|
||||
, mEndId(endId)
|
||||
, mPadId(padId)
|
||||
, mLogitsPostProcessor(logitsPostProcessor)
|
||||
, mApplyLogitsPostProcessorBatched(applyLogitsPostProcessorBatched)
|
||||
, mClientId(clientId)
|
||||
, mIsStreaming(isStreaming)
|
||||
, mOrigPromptLen(mPromptLen)
|
||||
, mMaxSentTokenPos(mPromptLen - 1)
|
||||
, mMaxSentTokenLen(mPromptLen)
|
||||
, mEmbeddingBias(std::move(embeddingBias))
|
||||
, mBadWordsList(std::move(badWordsList))
|
||||
, mStopWordsList(std::move(stopWordsList))
|
||||
@ -105,6 +109,7 @@ public:
|
||||
, mDraftTokens(draftTokens.value_or(std::make_shared<VecTokens>()))
|
||||
, mDraftLogits(draftLogits)
|
||||
, mNumTokensPerIteration(1)
|
||||
, mReturnAllGeneratedTokens(isStreaming && (samplingConfig.beamWidth > 1))
|
||||
, mReturnContextLogits(returnContextLogits)
|
||||
, mReturnGenerationLogits(returnGenerationLogits)
|
||||
, mExcludeInputFromOutput(excludeInputFromOutput)
|
||||
@ -125,11 +130,12 @@ public:
|
||||
, mMaxNewTokens(req.getMaxNewTokens())
|
||||
, mSamplingConfig(req.getSamplingConfig(), req.getExternalDraftTokensConfig())
|
||||
, mState(REQUEST_STATE_CONTEXT_INIT)
|
||||
, mIsStreaming(req.getStreaming())
|
||||
, mEndId(req.getEndId())
|
||||
, mPadId(req.getPadId())
|
||||
, mClientId(req.getClientId())
|
||||
, mIsStreaming(req.getStreaming())
|
||||
, mOrigPromptLen(mPromptLen)
|
||||
, mMaxSentTokenPos(mPromptLen - 1)
|
||||
, mMaxSentTokenLen(mPromptLen)
|
||||
, mEmbeddingBias(std::nullopt)
|
||||
, mBadWordsList(std::nullopt)
|
||||
, mStopWordsList(std::nullopt)
|
||||
@ -145,6 +151,7 @@ public:
|
||||
, mDraftTokens(std::make_shared<VecTokens>())
|
||||
, mDraftLogits(std::nullopt)
|
||||
, mNumTokensPerIteration(1)
|
||||
, mReturnAllGeneratedTokens(req.getReturnAllGeneratedTokens())
|
||||
, mReturnContextLogits(req.getOutputConfig().returnContextLogits)
|
||||
, mReturnGenerationLogits(req.getOutputConfig().returnGenerationLogits)
|
||||
, mExcludeInputFromOutput(req.getOutputConfig().excludeInputFromOutput)
|
||||
@ -152,6 +159,16 @@ public:
|
||||
, mReturnEncoderOutput(req.getOutputConfig().returnEncoderOutput)
|
||||
, mDecodingIter(0)
|
||||
{
|
||||
if (mIsStreaming && mSamplingConfig.beamWidth > 1 && mReturnAllGeneratedTokens == false)
|
||||
{
|
||||
TLLM_LOG_WARNING(
|
||||
"Setting mReturnAllGeneratedTokens to True since streaming AND beam search are done simultaneously. "
|
||||
"Returning the full beams at each streaming step is needed because beam search + streaming can change "
|
||||
"previous outputs. Initialize request with mReturnAllGeneratedTokens = True to dismiss this error. "
|
||||
"WARNING: using this option may increase network usage significantly (quadratically w.r.t output "
|
||||
"length).");
|
||||
mReturnAllGeneratedTokens = true;
|
||||
}
|
||||
if (req.getEncoderInputTokenIds())
|
||||
{
|
||||
mState = REQUEST_STATE_ENCODER_INIT;
|
||||
@ -440,20 +457,20 @@ public:
|
||||
mSeqSlot.reset();
|
||||
}
|
||||
|
||||
/// @brief Get the maximum position of the tokens returned to the client. Use to ensure we don't return to
|
||||
/// client duplicated token positions.
|
||||
/// @return The maximum position of the tokens sent to the client
|
||||
[[nodiscard]] SizeType32 getMaxSentTokenPos() const
|
||||
/// @brief Get the maximum length of tokens returned to the client. Use to ensure we don't return to
|
||||
/// client duplicated tokens.
|
||||
/// @return The maximum length of the tokens sent to the client.
|
||||
[[nodiscard]] SizeType32 getMaxSentTokenLen() const
|
||||
{
|
||||
return mMaxSentTokenPos;
|
||||
return mMaxSentTokenLen;
|
||||
}
|
||||
|
||||
/// @brief Sets the maximum position of the tokens returned to the client. Use to ensure we don't return to
|
||||
/// client duplicated token positions.
|
||||
/// @param pos The maximum position
|
||||
void setMaxSentTokenPos(SizeType32 pos)
|
||||
/// @brief Sets the maximum length of tokens returned to the client. Use to ensure we don't return to
|
||||
/// client duplicated tokens.
|
||||
/// @param maxSentLength The new maximum length.
|
||||
void setMaxSentTokenLen(SizeType32 maxSentLength)
|
||||
{
|
||||
mMaxSentTokenPos = pos;
|
||||
mMaxSentTokenLen = maxSentLength;
|
||||
}
|
||||
|
||||
[[nodiscard]] std::optional<TensorPtr> getPromptEmbeddingTable() const
|
||||
@ -599,7 +616,7 @@ public:
|
||||
|
||||
void setNumTokensPerIteration(SizeType32 numTokensPerIteration)
|
||||
{
|
||||
mNumTokensPerIteration = numTokensPerIteration;
|
||||
mNumTokensPerIteration = std::max(1, numTokensPerIteration);
|
||||
}
|
||||
|
||||
[[nodiscard]] SizeType32 getNumTokensPerIteration() const
|
||||
@ -669,6 +686,23 @@ public:
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
[[nodiscard]] bool constexpr isStreaming() const noexcept
|
||||
{
|
||||
return mIsStreaming;
|
||||
}
|
||||
|
||||
void constexpr setStreaming(bool isStreaming) noexcept
|
||||
{
|
||||
mIsStreaming = isStreaming;
|
||||
}
|
||||
|
||||
void setReturnAllGeneratedTokens(bool const returnAllGeneratedTokens)
|
||||
{
|
||||
TLLM_CHECK_WITH_INFO(!mIsStreaming || mSamplingConfig.beamWidth == 1 || returnAllGeneratedTokens,
|
||||
"returnAllGeneratedTokens must be true if streaming AND beam search are used.");
|
||||
mReturnAllGeneratedTokens = returnAllGeneratedTokens;
|
||||
}
|
||||
|
||||
void setReturnContextLogits(bool const returnContextLogits)
|
||||
{
|
||||
mReturnContextLogits = returnContextLogits;
|
||||
@ -739,7 +773,7 @@ public:
|
||||
return mGenerationLogitsFragments;
|
||||
}
|
||||
|
||||
void addGenerationFragments(TensorPtr& genLogits)
|
||||
void addGenerationLogitsFragment(TensorPtr& genLogits)
|
||||
{
|
||||
mGenerationLogitsFragments.push_back(genLogits);
|
||||
}
|
||||
@ -876,21 +910,27 @@ public:
|
||||
executor::Result result;
|
||||
result.isFinal = isGenerationCompleteState();
|
||||
|
||||
auto nbBeams = mSamplingConfig.beamWidth;
|
||||
auto maxNbTokens = getMaxBeamNumTokens();
|
||||
// FIXME(nkorobov): For streaming we do not allow beam search and
|
||||
// streaming index calculation here applies only for sampling
|
||||
// getNumTokensPerIteration takes accepted draft tokens into account
|
||||
int nbTokensOut = mIsStreaming ? std::max(getNumTokensPerIteration(), 1) : maxNbTokens;
|
||||
if (mExcludeInputFromOutput && !mIsStreaming)
|
||||
auto const nbBeams = mSamplingConfig.beamWidth;
|
||||
auto const maxNbTokens = getMaxBeamNumTokens();
|
||||
|
||||
auto const calculateNbTokensOut = [this](SizeType32 maxNbTokens)
|
||||
{
|
||||
nbTokensOut -= getOrigPromptLen();
|
||||
}
|
||||
if (!mIsStreaming)
|
||||
{
|
||||
return maxNbTokens - (mExcludeInputFromOutput ? getOrigPromptLen() : 0);
|
||||
}
|
||||
return mReturnAllGeneratedTokens ? maxNbTokens - getOrigPromptLen()
|
||||
: maxNbTokens - getMaxSentTokenLen();
|
||||
};
|
||||
|
||||
auto const maxNbTokensOut = calculateNbTokensOut(maxNbTokens);
|
||||
|
||||
result.outputTokenIds.resize(nbBeams);
|
||||
SizeType32 tokenPos = maxNbTokens - nbTokensOut;
|
||||
|
||||
bool shouldSendResponse = isGenerationCompleteState() || (mIsStreaming && tokenPos > getMaxSentTokenPos());
|
||||
auto const startTokenPos = maxNbTokens - maxNbTokensOut;
|
||||
|
||||
auto const shouldSendResponse
|
||||
= isGenerationCompleteState() || (mIsStreaming && maxNbTokens > getMaxSentTokenLen());
|
||||
|
||||
if (!shouldSendResponse)
|
||||
{
|
||||
@ -900,24 +940,14 @@ public:
|
||||
{
|
||||
for (SizeType32 beam = 0; beam < nbBeams; ++beam)
|
||||
{
|
||||
auto tokens = getTokens(beam);
|
||||
auto nbTokens = mIsStreaming ? (tokenPos - getMaxSentTokenPos()) : tokens.size();
|
||||
auto const& tokens = getTokens(beam);
|
||||
auto const nbTokensOut = calculateNbTokensOut(tokens.size());
|
||||
|
||||
// Take accepted draft tokens into account when streaming
|
||||
auto const numAcceptedTokens = std::max(0, getNumTokensPerIteration() - 1);
|
||||
nbTokens += mIsStreaming ? numAcceptedTokens : 0;
|
||||
|
||||
if (mExcludeInputFromOutput && !mIsStreaming)
|
||||
if (nbTokensOut > 0)
|
||||
{
|
||||
nbTokens -= getOrigPromptLen();
|
||||
auto const first = tokens.data() + startTokenPos;
|
||||
result.outputTokenIds.at(beam).assign(first, first + nbTokensOut);
|
||||
}
|
||||
if (nbTokens > 0)
|
||||
{
|
||||
result.outputTokenIds.at(beam).assign(
|
||||
tokens.data() + tokenPos, tokens.data() + tokenPos + nbTokens);
|
||||
}
|
||||
// Correct next token position by accepted draft tokens
|
||||
tokenPos += numAcceptedTokens;
|
||||
}
|
||||
|
||||
if (returnLogProbs())
|
||||
@ -954,7 +984,7 @@ public:
|
||||
}
|
||||
|
||||
// Update position of last sent response
|
||||
mMaxSentTokenPos = tokenPos;
|
||||
setMaxSentTokenLen(maxNbTokens);
|
||||
|
||||
auto response = executor::Response(mRequestId, std::move(result));
|
||||
return response;
|
|
||||
// Tokens [beam_size, mPromptLen + getMaxNumGeneratedTokens()]
|
||||
runtime::SamplingConfig mSamplingConfig;
|
||||
LlmRequestState_t mState;
|
||||
bool mIsStreaming;
|
||||
std::optional<TokenIdType> mEndId;
|
||||
std::optional<TokenIdType> mPadId;
|
||||
std::optional<SizeType32> mSeqSlot;
|
||||
std::optional<LogitsPostProcessor> mLogitsPostProcessor;
|
||||
bool mApplyLogitsPostProcessorBatched;
|
||||
std::optional<RequestIdType> mClientId;
|
||||
// Position of mask token in GLM model inputs
|
||||
SizeType32 mMaskPosition{0};
|
||||
|
||||
protected:
|
||||
bool mIsStreaming;
|
||||
|
||||
BeamTokens mTokens;
|
||||
SizeType32 mOrigPromptLen;
|
||||
// Number of tokens already in KV cache before context phase.
|
||||
// A value > 0 indicates cached KV cache blocks were reused.
|
||||
// Up to inputLen - 1 tokens can be reused.
|
||||
SizeType32 mPrepopulatedPromptLen{0};
|
||||
SizeType32 mMaxSentTokenPos;
|
||||
SizeType32 mMaxSentTokenLen;
|
||||
|
||||
std::optional<TensorPtr> mEmbeddingBias;
|
||||
std::optional<TensorPtr> mBadWordsList;
|
||||
@ -1011,6 +1045,8 @@ protected:
|
||||
std::optional<TensorPtr> mDraftLogits;
|
||||
SizeType32 mNumTokensPerIteration;
|
||||
|
||||
// whether to return the full beams on each iteration. True when doing streaming + beamsearch
|
||||
bool mReturnAllGeneratedTokens;
|
||||
// Save logits
|
||||
bool mReturnContextLogits;
|
||||
bool mReturnGenerationLogits;
|
||||
@ -1108,13 +1144,14 @@ public:
|
||||
std::optional<TensorPtr> draftLogits = std::nullopt, bool excludeInputFromOutput = false,
|
||||
std::optional<LogitsPostProcessor> logitsPostProcessor = std::nullopt,
|
||||
bool applyLogitsPostProcessorBatched = false,
|
||||
std::optional<std::shared_ptr<VecTokens>> encoderInputTokens = std::nullopt, bool returnEncoderOutput = false)
|
||||
std::optional<std::shared_ptr<VecTokens>> encoderInputTokens = std::nullopt, bool returnEncoderOutput = false,
|
||||
std::optional<RequestIdType> clientId = std::nullopt)
|
||||
: Base(requestId, maxNewTokens, std::move(inputTokens), samplingConfig, isStreaming, endId, padId,
|
||||
std::move(embeddingBias), std::move(badWordsList), std::move(stopWordsList),
|
||||
std::move(promptEmbeddingTable), promptVocabSize, loraTaskId, std::move(loraWeights), std::move(loraConfig),
|
||||
returnLogProbs, returnContextLogits, returnGenerationLogits, std::move(draftTokens), std::move(draftLogits),
|
||||
excludeInputFromOutput, std::move(logitsPostProcessor), applyLogitsPostProcessorBatched,
|
||||
std::move(encoderInputTokens), returnEncoderOutput)
|
||||
std::move(encoderInputTokens), returnEncoderOutput, clientId)
|
||||
{
|
||||
}
|
||||
|
||||
|
||||
cpp/include/tensorrt_llm/common/cudaProfilerUtils.h (new file, 31 lines)
@ -0,0 +1,31 @@
|
||||
/*
|
||||
* Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <cstdint>
|
||||
#include <optional>
|
||||
#include <string>
|
||||
#include <unordered_set>
|
||||
|
||||
namespace tensorrt_llm::common
|
||||
{
|
||||
|
||||
/// @brief Populate the start and end profiling iteration indexes from the provided environment variables
|
||||
/// Try to set from envVarName first, and if that fails, try to set from legacyEnvVarName
|
||||
/// Env variable values are expected to be in the format "1,2,3-5,6-8,9"
|
||||
std::pair<std::unordered_set<int32_t>, std::unordered_set<int32_t>> populateIterationIndexes(
|
||||
std::string const& envVarName, std::optional<std::string> const& legacyEnvVarName = std::nullopt);
|
||||
|
||||
} // namespace tensorrt_llm::common
|
||||
@ -24,6 +24,7 @@
|
||||
#include <memory> // std::make_unique
|
||||
#include <sstream> // std::stringstream
|
||||
#include <string>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
|
||||
namespace tensorrt_llm::common
|
||||
@ -106,4 +107,7 @@ inline bool strStartsWith(std::string const& str, std::string const& prefix)
|
||||
return str.rfind(prefix, 0) == 0;
|
||||
}
|
||||
|
||||
/// @brief Split a string into a set of strings using a delimiter
|
||||
std::unordered_set<std::string> str2set(std::string const& input, char delimiter);
|
||||
|
||||
} // namespace tensorrt_llm::common
|
||||
|
||||
@ -16,8 +16,6 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "tensorrt_llm/common/stringUtils.h"
|
||||
|
||||
#include <array>
|
||||
#include <cstddef>
|
||||
#include <stdexcept>
|
||||
|
||||
@ -37,6 +37,9 @@ class MpiComm;
|
||||
namespace tensorrt_llm::executor
|
||||
{
|
||||
|
||||
/// @brief Version of TRT-LLM
|
||||
char const* version() noexcept;
|
||||
|
||||
class Model;
|
||||
class Serialization;
|
||||
|
||||
@ -252,6 +255,8 @@ public:
|
||||
/// @param logitsPostProcessorName The logits postprocessor name. Must correspond to one of the logits postprocessor
|
||||
/// name provided to the ExecutorConfig.
|
||||
/// @param encoderInputTokenIds The encoder input token ids for encoder-decoder models, or encoder-only models
|
||||
/// @param returnAllGeneratedTokens Indicates whether to return the full beams or just the newly generated tokens
|
||||
/// after every streaming step.
|
||||
Request(VecTokens inputTokenIds, SizeType32 maxNewTokens, bool streaming = false,
|
||||
SamplingConfig const& samplingConfig = SamplingConfig(), OutputConfig const& outputConfig = OutputConfig(),
|
||||
std::optional<SizeType32> const& endId = std::nullopt, std::optional<SizeType32> const& padId = std::nullopt,
|
||||
@ -262,7 +267,8 @@ public:
|
||||
std::optional<PromptTuningConfig> pTuningConfig = std::nullopt,
|
||||
std::optional<LoraConfig> loraConfig = std::nullopt,
|
||||
std::optional<std::string> logitsPostProcessorName = std::nullopt,
|
||||
std::optional<VecTokens> encoderInputTokenIds = std::nullopt);
|
||||
std::optional<VecTokens> encoderInputTokenIds = std::nullopt, std::optional<IdType> clientId = std::nullopt,
|
||||
bool returnAllGeneratedTokens = false);
|
||||
|
||||
/// @brief This logits postprocessor name will dispatch to the batched logits postprocessor
|
||||
static auto constexpr kBatchedPostProcessorName = "batched";
|
||||
@ -288,6 +294,8 @@ public:
|
||||
[[nodiscard]] std::optional<LoraConfig> getLoraConfig() const;
|
||||
[[nodiscard]] std::optional<std::string> getLogitsPostProcessorName() const;
|
||||
[[nodiscard]] std::optional<VecTokens> getEncoderInputTokenIds() const;
|
||||
[[nodiscard]] std::optional<IdType> getClientId() const;
|
||||
[[nodiscard]] bool getReturnAllGeneratedTokens() const;
|
||||
|
||||
void setStreaming(bool streaming);
|
||||
void setSamplingConfig(SamplingConfig const& config);
|
||||
@ -302,6 +310,8 @@ public:
|
||||
void setLoraConfig(LoraConfig const& loraConfig);
|
||||
void setLogitsPostProcessorName(std::string const& logitsPostProcessorName);
|
||||
void setEncoderInputTokenIds(VecTokens const& encoderInputTokenIds);
|
||||
void setClientId(IdType clientId);
|
||||
void setReturnAllGeneratedTokens(bool returnAllGeneratedTokens);
|
||||
|
||||
private:
|
||||
friend class Serialization;
|
||||
|
||||
@ -53,10 +53,12 @@ using IterationType = std::uint64_t;
|
||||
using RandomSeedType = std::uint64_t;
|
||||
using VecLogProbs = std::vector<FloatType>;
|
||||
using StreamPtr = std::shared_ptr<tensorrt_llm::runtime::CudaStream>;
|
||||
using LogitsPostProcessor = std::function<void(IdType, Tensor&, BeamTokens const&, StreamPtr const&)>;
|
||||
using LogitsPostProcessor
|
||||
= std::function<void(IdType, Tensor&, BeamTokens const&, StreamPtr const&, std::optional<IdType>)>;
|
||||
using LogitsPostProcessorMap = std::unordered_map<std::string, LogitsPostProcessor>;
|
||||
using LogitsPostProcessorBatched = std::function<void(std::vector<IdType> const&, std::vector<Tensor>&,
|
||||
std::vector<std::reference_wrapper<BeamTokens const>> const&, StreamPtr const&)>;
|
||||
std::vector<std::reference_wrapper<BeamTokens const>> const&, StreamPtr const&,
|
||||
std::vector<std::optional<IdType>> const&)>;
|
||||
using MedusaChoices = std::vector<std::vector<SizeType32>>;
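
A hedged sketch of a callback matching the widened per-request signature; the extra std::optional<IdType> argument carries the request's client id. The no-op body is a placeholder, and the include is assumed to pull in the executor type aliases.

#include <optional>

#include "tensorrt_llm/executor/executor.h" // assumed to provide the executor type aliases

namespace texec = tensorrt_llm::executor;

// Per-request logits post-processor using the new five-argument signature.
texec::LogitsPostProcessor makeNoopLogitsPostProcessor()
{
    return [](texec::IdType reqId, texec::Tensor& logits, texec::BeamTokens const& beamTokens,
               texec::StreamPtr const& stream, std::optional<texec::IdType> clientId)
    {
        // A real callback would modify `logits` in place on `stream`,
        // optionally branching on `clientId`; this sketch does nothing.
        (void) reqId; (void) logits; (void) beamTokens; (void) stream; (void) clientId;
    };
}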
enum class DataType
|
||||
|
||||
@ -41,30 +41,32 @@ public:
|
||||
class Inputs
|
||||
{
|
||||
public:
|
||||
//! [batchSize]
|
||||
//! [maxBatchSize]
|
||||
TensorPtr temperatures;
|
||||
//! [batchSize]
|
||||
//! [maxBatchSize]
|
||||
TensorPtr positionIdsBase;
|
||||
//! [batchSize] or [numGenSequences]
|
||||
//! [maxBatchSize] or [numGenSequences]
|
||||
TensorPtr generationLengths;
|
||||
//! [batchSize]
|
||||
//! [maxBatchSize]
|
||||
TensorPtr randomDataSample;
|
||||
//! [batchSize, maxNumPaths, maxPathDraftLen] or [numGenSequences, maxNumPaths, maxPathDraftLen]
|
||||
//! [maxBatchSize, maxNumPaths, maxPathDraftLen] or [numGenSequences, maxNumPaths, maxPathDraftLen]
|
||||
TensorPtr randomDataValidation;
|
||||
//! [batchSize, maxNumPaths, maxPathLen] or [numGenSequences, maxNumPaths, maxPathLen]
|
||||
//! [maxBatchSize, maxNumPaths, maxPathLen] or [numGenSequences, maxNumPaths, maxPathLen]
|
||||
TensorPtr draftTokens;
|
||||
//! [batchSize, maxNumPaths, maxPathLen] or [numGenSequences, maxNumPaths, maxPathLen]
|
||||
//! [maxBatchSize, maxNumPaths, maxPathLen] or [numGenSequences, maxNumPaths, maxPathLen]
|
||||
TensorPtr draftIndices;
|
||||
//! [batchSize, maxNumPaths, maxPathDraftLen, vocabSize]
|
||||
//! [maxBatchSize, maxNumPaths, maxPathDraftLen, vocabSize]
|
||||
//! or [numGenSequences, maxNumPaths, maxPathDraftLen, vocabSize]
|
||||
TensorPtr draftProbs;
|
||||
//! [batchSize, maxDecodingTokens, ceil(maxDecodingTokens / 32)]
|
||||
//! [maxBatchSize, maxDecodingTokens, ceil(maxDecodingTokens / 32)]
|
||||
//! or [numGenSequences, maxDecodingTokens, ceil(maxDecodingTokens / 32)]
|
||||
TensorPtr packedMasks;
|
||||
//! [batchSize] or [numGenSequences]
|
||||
//! [maxBatchSize] or [numGenSequences]
|
||||
TensorPtr positionIds;
|
||||
// [1], on pinned
|
||||
TensorPtr maxGenLengthHost;
|
||||
// [maxBatchSize]
|
||||
TensorPtr generationLengthsHost;
|
||||
|
||||
void create(SizeType32 maxNumSequences, runtime::TllmRuntime const& runtime,
|
||||
runtime::ModelConfig const& modelConfig, runtime::WorldConfig const& worldConfig);
|
||||
|
||||
@ -40,10 +40,11 @@ public:
|
||||
enum class ModelVariant : std::int32_t
|
||||
{
|
||||
kGpt = 0,
|
||||
kGlm = 1, // https://github.com/THUDM/GLM and https://github.com/THUDM/ChatGLM-6B
|
||||
kMamba = 2, // https://github.com/state-spaces/mamba
|
||||
kRecurrentGemma = 3, // https://github.com/google-deepmind/recurrentgemma
|
||||
kEncDec = 4,
|
||||
kChatGlm = 1, // https://github.com/THUDM/ChatGLM-6B
|
||||
kGlm = 2, // https://github.com/THUDM/GLM
|
||||
kMamba = 3, // https://github.com/state-spaces/mamba
|
||||
kRecurrentGemma = 4, // https://github.com/google-deepmind/recurrentgemma
|
||||
kEncDec = 5,
|
||||
};
|
||||
|
||||
struct RnnConfig
|
||||
@ -526,7 +527,7 @@ public:
|
||||
[[nodiscard]] bool constexpr isTransformerBased() const noexcept
|
||||
{
|
||||
return mModelVariant == ModelVariant::kGpt || mModelVariant == ModelVariant::kGlm
|
||||
|| mModelVariant == ModelVariant::kRecurrentGemma;
|
||||
|| mModelVariant == ModelVariant::kChatGlm || mModelVariant == ModelVariant::kRecurrentGemma;
|
||||
}
|
||||
|
||||
[[nodiscard]] bool hasRnnConfig() const noexcept
|
||||
|
||||
@ -106,6 +106,7 @@ int parseTacticToId(nlohmann::json tactic_config)
|
||||
{K(128, 64), CutlassTileConfigSM90::CtaShape128x64x128B},
|
||||
{K(128, 128), CutlassTileConfigSM90::CtaShape128x128x128B},
|
||||
{K(128, 256), CutlassTileConfigSM90::CtaShape128x256x128B},
|
||||
{K(256, 128), CutlassTileConfigSM90::CtaShape256x128x128B},
|
||||
};
|
||||
|
||||
if (c.tile_config_sm90 != tile_map.at(K(tile_shape[0], tile_shape[1])))
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:5804fde474d6489db29204259b7e6c368117acadb7fb6dc807868ee0391c458b
|
||||
size 3953206
|
||||
oid sha256:f41188ef30e21d12ebcb92ee6546badb330f6c63a90fff535f3e613d61f103f9
|
||||
size 4268820
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:85802a0e66148acb17d017a64dd982287775ce7bf5aa4e8bb7e5466b3736c7ee
|
||||
size 4019734
|
||||
oid sha256:510d90d67edcdbbe164493637772e50ef2f8d88d927f561c46512052aed7624c
|
||||
size 4365768
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
00fb525bdf4ff217c16940540b2357c4 libtensorrt_llm_batch_manager_static.a
|
||||
97d2db7f62745001d871bc89fb38eed6 libtensorrt_llm_batch_manager_static.pre_cxx11.a
|
||||
d5f5542d2f1e10c4a6b60be56838ac79a9668665 commit
|
||||
6aa84afae42ed1e725b8809b8f8a38fb libtensorrt_llm_batch_manager_static.a
|
||||
cc359afb584b086456510f084f6617ed libtensorrt_llm_batch_manager_static.pre_cxx11.a
|
||||
db055e58b6c6c8cf7350b66a583f9c388c4eac07 commit
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:33a724d7e9eabc358c0d674151d45cef8849ae702cc5f2f88b259299a8306574
|
||||
size 3842582
|
||||
oid sha256:920951af1730c7304fd1a7c286ddc8f96a17f918aaaf7815da385bf92c37e54c
|
||||
size 4129858
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:490a93ff13a67949a30e279fc3df27456c7f5d4084158c3089befccf78118b7f
|
||||
size 3799140
|
||||
oid sha256:1d9b525a3855dd5a853604031efb306b08afd1ee425aba2a7846f7cd77f89ddb
|
||||
size 4107114
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:663a163c3177644ed86fa7a2145fe5e9dbf6f2f0ed06c96d367236da323a3432
|
||||
size 22523526
|
||||
oid sha256:2a903a8cae43ec88d69fba666c3da1f301f1cb0aaf37256715da7363ee04a236
|
||||
size 23909614
|
||||
|
||||
cpp/tensorrt_llm/common/cudaProfilerUtils.cpp (new file, 84 lines)
@ -0,0 +1,84 @@
|
||||
/*
|
||||
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tensorrt_llm/common/cudaProfilerUtils.h"
|
||||
#include "tensorrt_llm/common/logger.h"
|
||||
#include "tensorrt_llm/common/stringUtils.h"
|
||||
#include <cstdint>
|
||||
#include <optional>
|
||||
|
||||
namespace
|
||||
{
|
||||
|
||||
std::tuple<std::unordered_set<int32_t>, std::unordered_set<int32_t>> populateIterationIndexesImpl(
|
||||
std::string const& envVarName)
|
||||
{
|
||||
auto envVarVal = std::getenv(envVarName.c_str());
|
||||
auto envVarValStr = std::string{envVarVal != nullptr ? envVarVal : ""};
|
||||
auto values = tensorrt_llm::common::str2set(envVarValStr, ',');
|
||||
std::unordered_set<int32_t> startSet;
|
||||
std::unordered_set<int32_t> endSet;
|
||||
for (std::string const& value : values)
|
||||
{
|
||||
size_t dashIdx = value.find("-");
|
||||
if (dashIdx != std::string::npos)
|
||||
{
|
||||
int32_t start = std::stoi(value.substr(0, dashIdx));
|
||||
startSet.insert(start);
|
||||
int32_t end = std::stoi(value.substr(dashIdx + 1));
|
||||
endSet.insert(end);
|
||||
}
|
||||
else
|
||||
{
|
||||
int32_t start_end = std::stoi(value);
|
||||
startSet.insert(start_end);
|
||||
endSet.insert(start_end);
|
||||
}
|
||||
}
|
||||
|
||||
return std::make_pair(startSet, endSet);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
namespace tensorrt_llm::common
|
||||
{
|
||||
|
||||
std::pair<std::unordered_set<int32_t>, std::unordered_set<int32_t>> populateIterationIndexes(
|
||||
std::string const& envVarName, std::optional<std::string> const& legacyEnvVarName)
|
||||
{
|
||||
auto [profileIterIdxs, stopIterIdxs] = populateIterationIndexesImpl(envVarName);
|
||||
|
||||
// If empty, try to use legacy env var name
|
||||
if (legacyEnvVarName && profileIterIdxs.empty() && stopIterIdxs.empty())
|
||||
{
|
||||
std::tie(profileIterIdxs, stopIterIdxs) = populateIterationIndexesImpl(legacyEnvVarName.value());
|
||||
|
||||
if (!profileIterIdxs.empty() || !stopIterIdxs.empty())
|
||||
{
|
||||
TLLM_LOG_WARNING(
|
||||
"Using deprecated environment variable %s to specify cudaProfiler start and stop iterations. "
|
||||
"Please "
|
||||
"use %s "
|
||||
"instead.",
|
||||
legacyEnvVarName.value().c_str(), envVarName.c_str());
|
||||
}
|
||||
}
|
||||
|
||||
return std::make_pair(profileIterIdxs, stopIterIdxs);
|
||||
}
|
||||
|
||||
} // namespace tensorrt_llm::common
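
A hedged usage sketch of the helper defined above; the environment variable value is made up for illustration, and setenv assumes a POSIX platform. The first set holds the iterations at which profiling starts, the second the iterations at which it stops.

#include <cstdlib>
#include <iostream>

#include "tensorrt_llm/common/cudaProfilerUtils.h"

int main()
{
    // "10-20,30" -> start profiling at iterations {10, 30}, stop at {20, 30}.
    setenv("TLLM_PROFILE_START_STOP", "10-20,30", /*overwrite=*/1); // POSIX only

    auto const [startIters, stopIters] = tensorrt_llm::common::populateIterationIndexes(
        "TLLM_PROFILE_START_STOP", "TLLM_GPTM_PROFILE_START_STOP");

    std::cout << startIters.size() << " start iterations, " << stopIters.size() << " stop iterations\n";
    return 0;
}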
@ -20,6 +20,7 @@
|
||||
#include <cerrno>
|
||||
#include <cstdarg>
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
|
||||
namespace tensorrt_llm::common
|
||||
@ -54,4 +55,22 @@ std::string fmtstr(char const* format, ...)
|
||||
return result;
|
||||
};
|
||||
|
||||
std::unordered_set<std::string> str2set(std::string const& input, char delimiter)
|
||||
{
|
||||
std::unordered_set<std::string> values;
|
||||
if (!input.empty())
|
||||
{
|
||||
std::stringstream valStream(input);
|
||||
std::string val;
|
||||
while (std::getline(valStream, val, delimiter))
|
||||
{
|
||||
if (!val.empty())
|
||||
{
|
||||
values.insert(val);
|
||||
}
|
||||
}
|
||||
}
|
||||
return values;
|
||||
};
|
||||
|
||||
} // namespace tensorrt_llm::common
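
A hedged example of the new helper in isolation: splitting a delimited string into an unordered set, with empty fields dropped as in the implementation above.

#include <cassert>

#include "tensorrt_llm/common/stringUtils.h"

int main()
{
    auto const values = tensorrt_llm::common::str2set("a,b,,c", ',');
    // Empty fields are skipped, so the set holds exactly {"a", "b", "c"}.
    assert(values.size() == 3);
    assert(values.count("b") == 1);
    return 0;
}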
cpp/tensorrt_llm/common/timestampUtils.cpp (new file, 42 lines)
@ -0,0 +1,42 @@
|
||||
/*
|
||||
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <chrono>
|
||||
#include <iomanip>
|
||||
#include <sstream>
|
||||
|
||||
#include "tensorrt_llm/common/timestampUtils.h"
|
||||
|
||||
namespace tensorrt_llm::common
|
||||
{
|
||||
|
||||
std::string getCurrentTimestamp()
|
||||
{
|
||||
auto now = std::chrono::system_clock::now();
|
||||
auto now_t = std::chrono::system_clock::to_time_t(now);
|
||||
auto tm = *std::localtime(&now_t);
|
||||
|
||||
auto epoch_to_now = now.time_since_epoch();
|
||||
auto seconds = std::chrono::duration_cast<std::chrono::seconds>(epoch_to_now);
|
||||
auto us = std::chrono::duration_cast<std::chrono::microseconds>(epoch_to_now - seconds);
|
||||
|
||||
std::ostringstream stream;
|
||||
stream << std::put_time(&tm, "%m-%d-%Y %H:%M:%S");
|
||||
stream << "." << std::setfill('0') << std::setw(6) << us.count();
|
||||
return stream.str();
|
||||
}
|
||||
|
||||
} // namespace tensorrt_llm::common
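
For reference, a minimal call site; the printed value is illustrative and uses the local time zone.

#include <iostream>

#include "tensorrt_llm/common/timestampUtils.h"

int main()
{
    // Prints something like "06-28-2024 14:03:07.123456" (microsecond suffix).
    std::cout << tensorrt_llm::common::getCurrentTimestamp() << "\n";
    return 0;
}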
cpp/tensorrt_llm/common/timestampUtils.h (new file, 25 lines)
@ -0,0 +1,25 @@
|
||||
/*
|
||||
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <string>
|
||||
|
||||
namespace tensorrt_llm::common
|
||||
{
|
||||
|
||||
/// @brief Get the current timestamp in the format "MM-DD-YYYY HH:MM:SS.uuuuuu"
|
||||
std::string getCurrentTimestamp();
|
||||
|
||||
} // namespace tensorrt_llm::common
|
||||
@ -15,6 +15,7 @@
|
||||
*/
|
||||
|
||||
#include "tensorrt_llm/common/tllmException.h"
|
||||
#include "tensorrt_llm/common/stringUtils.h"
|
||||
|
||||
#include <cstdlib>
|
||||
#if !defined(_MSC_VER)
|
||||
|
||||
@ -93,6 +93,8 @@ enum class CutlassTileConfigSM90
|
||||
CtaShape128x128x128B,
|
||||
CtaShape128x256x128B,
|
||||
|
||||
// CTA configs for M=128
|
||||
CtaShape256x128x128B,
|
||||
};
|
||||
|
||||
enum class MainloopScheduleType
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:497b00031131c1dc705e848e52f3d43148f55505e37bdad97f4933b2c074469d
|
||||
size 1400502
|
||||
oid sha256:af8889214b82f8e65a226b6558dbdef474552850b50f07df76cbc24aeac94d6c
|
||||
size 1410084
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:417978bdb5c19f97d9758475acacfa18a4038fc3c5a83f981b02ee220104e0c7
|
||||
size 1425792
|
||||
oid sha256:09b641ce17db25301b7c4e9049bc11dc105f749be0742b91986bf47601c1bbc7
|
||||
size 1437532
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
1df55ac2948ca7b7fe2d5e79934e660e libtensorrt_llm_executor_static.a
|
||||
ea1641928d184d117deec0696763b274 libtensorrt_llm_executor_static.pre_cxx11.a
|
||||
d5f5542d2f1e10c4a6b60be56838ac79a9668665 commit
|
||||
56835da1fe4f3edc7f28bac50b253b2d libtensorrt_llm_executor_static.a
|
||||
4f1d05581c8c8d4663500c5300446778 libtensorrt_llm_executor_static.pre_cxx11.a
|
||||
db055e58b6c6c8cf7350b66a583f9c388c4eac07 commit
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:d0441d473852d11f50bcf23f4934b38d7e4c6d4a42f057eb04beb8aea4211cac
|
||||
size 1451118
|
||||
oid sha256:cb73df78859b9bf2d425a4c307403863b9de62820ff3b7e0ff2bbe6ac9f35894
|
||||
size 1459664
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:dc8619f99cf5a2e04bdb1482f157a9852bd745e90cf9e03a7878f73ed07e5610
|
||||
size 1383936
|
||||
oid sha256:c7bf468b3c45d0c8e605ada27e16edeaf7b22928883d28eca4e8f9b568a01eff
|
||||
size 1391962
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:772d1b83e739b926729b99999fbb81768569ffb172c2e120665b2d31b987bb47
|
||||
size 14071986
|
||||
oid sha256:6491a8b88087cb0be7af82f9523dae800f5d217730941441f1804e5ccd4770b5
|
||||
size 14289284
|
||||
|
||||
@ -29,38 +29,38 @@ namespace tensorrt_llm
|
||||
namespace kernels
|
||||
{
|
||||
|
||||
typedef void (*BmmChunkKernelFuncFp16)(int B_, int L_, int G_, int N_,
|
||||
typedef void (*BmmChunkKernelFuncFp16)(int B_, int L_, int H_, int P_, int G_, int N_,
|
||||
// const half *g_mxY_, // B*L*H*P
|
||||
// const half *g_mxOs_, // B*C*H*N*P
|
||||
// const half *g_mxFs_, // B *H*N*P
|
||||
// const float *g_mxSt_, // B*C*H*N*P
|
||||
// const float *g_mxdc_, // B*C*H*Q
|
||||
// const float *g_mxdA_, // B*C*H*Q
|
||||
// const half *g_mxdt_, // B*L*H
|
||||
// const half *g_mxdt_, // B*L*(2*H*P+2*G*N+H) or B*L*(H*P+2*G*N+H)
|
||||
// const float *g_mxdb_, // H
|
||||
// const float *g_mxA_, // H
|
||||
half* g_mxCB_, // B*C*G*Q*Q
|
||||
half const* g_mxBC_, // B*L*2*G*N
|
||||
// const float *g_mxD_, // H
|
||||
// const half *g_mxX_, // B*L*H*P
|
||||
// const half *g_mxZ_, // B*L*H*P
|
||||
half* g_mxCB_, // B*C*G*Q*Q
|
||||
half const* g_mxXBC_, // B*L*(H*P+2*G*N)
|
||||
// const float *g_mxD_, // H
|
||||
// const half *g_mxX_, // B*L*H*P
|
||||
// const half *g_mxZ_, // B*L*(2*H*P+2*G*N+H)
|
||||
bool removePadding_, int const* lastTokenIdsPtr_);
|
||||
|
||||
typedef void (*BmmChunkKernelFuncBf16)(int B_, int L_, int G_, int N_,
|
||||
typedef void (*BmmChunkKernelFuncBf16)(int B_, int L_, int H_, int P_, int G_, int N_,
|
||||
// const bf16 *g_mxY_, // B*L*H*P
|
||||
// const bf16 *g_mxOs_, // B*C*H*N*P
|
||||
// const bf16 *g_mxFs_, // B *H*N*P
|
||||
// const float *g_mxSt_, // B*C*H*N*P
|
||||
// const float *g_mxdc_, // B*C*H*Q
|
||||
// const float *g_mxdA_, // B*C*H*Q
|
||||
// const bf16 *g_mxdt_, // B*L*H
|
||||
// const bf16 *g_mxdt_, // B*L*(2*H*P+2*G*N+H) or B*L*(H*P+2*G*N+H)
|
||||
// const float *g_mxdb_, // H
|
||||
// const float *g_mxA_, // H
|
||||
bf16* g_mxCB_, // B*C*G*Q*Q
|
||||
bf16 const* g_mxBC_, // B*L*2*G*N
|
||||
// const float *g_mxD_, // H
|
||||
// const bf16 *g_mxX_, // B*L*H*P
|
||||
// const bf16 *g_mxZ_, // B*L*H*P
|
||||
bf16* g_mxCB_, // B*C*G*Q*Q
|
||||
bf16 const* g_mxXBC_, // B*L*(H*P+2*G*N)
|
||||
// const float *g_mxD_, // H
|
||||
// const bf16 *g_mxX_, // B*L*H*P
|
||||
// const bf16 *g_mxZ_, // B*L*(2*H*P+2*G*N+H)
|
||||
bool removePadding_, int const* lastTokenIdsPtr_);
|
||||
|
||||
template <int Q_, int tileM_, int tileN_, int tileK_, // smem size, per sm
|
||||
@ -68,21 +68,21 @@ template <int Q_, int tileM_, int tileN_, int tileK_, // smem size, per sm
|
||||
int warpM_, int warpN_, // warp number
|
||||
int pipeS_, class Tp_>
|
||||
__global__ std::enable_if_t<std::is_same_v<Tp_, half> || std::is_same_v<Tp_, __nv_bfloat16>> bmm_chunk_kernel(int B_,
|
||||
int L_, int G_, int N_,
|
||||
int L_, int H_, int P_, int G_, int N_,
|
||||
// const Tp_ *g_mxY_, // B*L*H*P
|
||||
// const Tp_ *g_mxOs_, // B*C*H*N*P
|
||||
// const Tp_ *g_mxFs_, // B *H*N*P
|
||||
// const float *g_mxSt_, // B*C*H*N*P
|
||||
// const float *g_mxdc_, // B*C*H*Q
|
||||
// const float *g_mxdA_, // B*C*H*Q
|
||||
// const Tp_ *g_mxdt_, // B*L*H
|
||||
// const Tp_ *g_mxdt_, // B*L*(2*H*P+2*G*N+H) or B*L*(H*P+2*G*N+H)
|
||||
// const Wt_ *g_mxdb_, // H
|
||||
// const Wt_ *g_mxA_, // H
|
||||
Tp_* g_mxCB_, // B*C*G*Q*Q
|
||||
Tp_ const* g_mxBC_, // B*L*2*G*N
|
||||
// const Wt_ *g_mxD_, // H
|
||||
// const Tp_ *g_mxX_, // B*L*H*P
|
||||
// const Tp_ *g_mxZ_, // B*L*H*P
|
||||
Tp_* g_mxCB_, // B*C*G*Q*Q
|
||||
Tp_ const* g_mxXBC_, // B*L*(H*P+2*G*N)
|
||||
// const Wt_ *g_mxD_, // H
|
||||
// const Tp_ *g_mxX_, // B*L*H*P
|
||||
// const Tp_ *g_mxZ_, // B*L*(2*H*P+2*G*N+H)
|
||||
bool removePadding_, int const* lastTokenIdsPtr_)
|
||||
{
|
||||
#if __CUDA_ARCH__ >= 800
|
||||
@ -98,13 +98,17 @@ __global__ std::enable_if_t<std::is_same_v<Tp_, half> || std::is_same_v<Tp_, __n
|
||||
|
||||
// auto B = Rn<ID>{B_};
|
||||
auto L = Rn<ID>{L_};
|
||||
// auto H = Rn<ID>{H_};
|
||||
// auto P = Rn<ID>{P_};
|
||||
auto H = Rn<ID>{H_};
|
||||
auto P = Rn<ID>{P_};
|
||||
auto G = Rn<ID>{G_};
|
||||
auto N = Rn<ID>{N_};
|
||||
auto Q = cn<Q_>;
|
||||
auto C = Rn<ID>{div_up(L.var, Q_)};
|
||||
|
||||
auto xbcDim = Rn<ID>{H_ * P_ + 2 * G_ * N_};
|
||||
auto bOffset = Rn<ID>{H_ * P_};
|
||||
auto cOffset = Rn<ID>{H_ * P_ + G_ * N_};
|
||||
|
||||
auto aStart = blockIdx_z * L;
|
||||
auto cStart = blockIdx_z * C;
|
||||
|
||||
@ -167,10 +171,9 @@ __global__ std::enable_if_t<std::is_same_v<Tp_, half> || std::is_same_v<Tp_, __n
|
||||
if (thread(iStep) < cn<tileM_ * tileK_>
|
||||
&& thread(iStep) / cn<tileK_> < L - blockIdx_y * Q - mStart * cn<tileM_>)
|
||||
cp_shared_global<16>(b_mxC + swizzle<tileK_ * 2, tileK_ * 2>(thread(iStep) * cn<2>, baseC(iK) * cn<2>),
|
||||
g_mxBC_
|
||||
+ get(
|
||||
(aStart + blockIdx_y * Q + mStart * cn<tileM_> + thread(iStep) / cn<tileK_>) *cn<2> * G * N
|
||||
+ cn<1> * G * N + gStart * N + iK * cn<tileK_> + thread(iStep) % cn<tileK_>));
|
||||
g_mxXBC_
|
||||
+ get((aStart + blockIdx_y * Q + mStart * cn<tileM_> + thread(iStep) / cn<tileK_>) *xbcDim
|
||||
+ cOffset + gStart * N + iK * cn<tileK_> + thread(iStep) % cn<tileK_>));
|
||||
else if (thread(iStep) < cn<tileM_ * tileK_>)
|
||||
*(int4*) ((char*) s_mxC + swizzle<tileK_ * 2, tileK_ * 2>(thread(iStep) * cn<2>, baseC(iK) * cn<2>))
|
||||
= int4{0, 0, 0, 0};
|
||||
@ -180,10 +183,9 @@ __global__ std::enable_if_t<std::is_same_v<Tp_, half> || std::is_same_v<Tp_, __n
|
||||
if (thread(iStep) < cn<tileN_ * tileK_>
|
||||
&& thread(iStep) / cn<tileK_> < L - blockIdx_y * Q - nStart * cn<tileN_>)
|
||||
cp_shared_global<16>(b_mxB + swizzle<tileK_ * 2, tileK_ * 2>(thread(iStep) * cn<2>, baseB(iK) * cn<2>),
|
||||
g_mxBC_
|
||||
+ get(
|
||||
(aStart + blockIdx_y * Q + nStart * cn<tileN_> + thread(iStep) / cn<tileK_>) *cn<2> * G * N
|
||||
+ cn<0> * G * N + gStart * N + iK * cn<tileK_> + thread(iStep) % cn<tileK_>));
|
||||
g_mxXBC_
|
||||
+ get((aStart + blockIdx_y * Q + nStart * cn<tileN_> + thread(iStep) / cn<tileK_>) *xbcDim
|
||||
+ bOffset + gStart * N + iK * cn<tileK_> + thread(iStep) % cn<tileK_>));
|
||||
else if (thread(iStep) < cn<tileN_ * tileK_>)
|
||||
*(int4*) ((char*) s_mxB + swizzle<tileK_ * 2, tileK_ * 2>(thread(iStep) * cn<2>, baseB(iK) * cn<2>))
|
||||
= int4{0, 0, 0, 0};
|
||||
@ -285,10 +287,9 @@ __global__ std::enable_if_t<std::is_same_v<Tp_, half> || std::is_same_v<Tp_, __n
|
||||
&& thread(iStep) / cn<tileK_> < L - blockIdx_y * Q - mStart * cn<tileM_>)
|
||||
cp_shared_global<16>(
|
||||
b_mxC + swizzle<tileK_ * 2, tileK_ * 2>(thread(iStep) * cn<2>, baseC(jK) * cn<2>),
|
||||
g_mxBC_
|
||||
+ get((aStart + blockIdx_y * Q + mStart * cn<tileM_> + thread(iStep) / cn<tileK_>) *cn<2>
|
||||
* G * N
|
||||
+ cn<1> * G * N + gStart * N + jK * cn<tileK_> + thread(iStep) % cn<tileK_>));
|
||||
g_mxXBC_
|
||||
+ get((aStart + blockIdx_y * Q + mStart * cn<tileM_> + thread(iStep) / cn<tileK_>) *xbcDim
|
||||
+ cOffset + gStart * N + jK * cn<tileK_> + thread(iStep) % cn<tileK_>));
|
||||
else if (thread(iStep) < cn<tileM_ * tileK_>)
|
||||
*(int4*) ((char*) s_mxC + swizzle<tileK_ * 2, tileK_ * 2>(thread(iStep) * cn<2>, baseC(jK) * cn<2>))
|
||||
= int4{0, 0, 0, 0};
|
||||
@ -299,10 +300,9 @@ __global__ std::enable_if_t<std::is_same_v<Tp_, half> || std::is_same_v<Tp_, __n
|
||||
&& thread(iStep) / cn<tileK_> < L - blockIdx_y * Q - nStart * cn<tileN_>)
|
||||
cp_shared_global<16>(
|
||||
b_mxB + swizzle<tileK_ * 2, tileK_ * 2>(thread(iStep) * cn<2>, baseB(jK) * cn<2>),
|
||||
g_mxBC_
|
||||
+ get((aStart + blockIdx_y * Q + nStart * cn<tileN_> + thread(iStep) / cn<tileK_>) *cn<2>
|
||||
* G * N
|
||||
+ cn<0> * G * N + gStart * N + jK * cn<tileK_> + thread(iStep) % cn<tileK_>));
|
||||
g_mxXBC_
|
||||
+ get((aStart + blockIdx_y * Q + nStart * cn<tileN_> + thread(iStep) / cn<tileK_>) *xbcDim
|
||||
+ bOffset + gStart * N + jK * cn<tileK_> + thread(iStep) % cn<tileK_>));
|
||||
else if (thread(iStep) < cn<tileN_ * tileK_>)
|
||||
*(int4*) ((char*) s_mxB + swizzle<tileK_ * 2, tileK_ * 2>(thread(iStep) * cn<2>, baseB(jK) * cn<2>))
|
||||
= int4{0, 0, 0, 0};
|
||||
|
||||
@ -29,57 +29,54 @@ namespace tensorrt_llm
|
||||
namespace kernels
|
||||
{
|
||||
|
||||
typedef void (*ChunkCumsumKernelFuncFp16)(int B_, int L_, int H_,
|
||||
typedef void (*ChunkCumsumKernelFuncFp16)(int B_, int L_, int H_, int P_, int G_, int N_,
|
||||
// const half *g_mxY_, // B*L*H*P
|
||||
// const half *g_mxOs_, // B*C*H*N*P
|
||||
// const half *g_mxFs_, // B *H*N*P
|
||||
// const float *g_mxSt_, // B*C*H*N*P
|
||||
float* g_mxdc_, // B*C*H*Q
|
||||
float* g_mxdA_, // B*C*H*Q
|
||||
half const* g_mxdt_, // B*L*H
|
||||
half const* g_mxdt_, // B*L*(2*H*P+2*G*N+H) or B*L*(H*P+2*G*N+H)
|
||||
float const* g_mxdb_, // H
|
||||
float const* g_mxA_, // H
|
||||
// const half *g_mxCB_, // B*C*G*Q*Q
|
||||
// const half *g_mxBC_, // B*L*2*G*N
|
||||
// const float *g_mxD_, // H
|
||||
// const half *g_mxX_, // B*L*H*P
|
||||
// const half *g_mxZ_, // B*L*H*P
|
||||
// const half *g_mxCB_, // B*C*G*Q*Q
|
||||
// const float *g_mxD_, // H
|
||||
// const half *g_mxXBC_, // B*L*(H*P+2*G*N)
|
||||
half const* g_mxZ_, // B*L*(2*H*P+2*G*N+H)
|
||||
bool removePadding_, int const* lastTokenIdsPtr_);
|
||||
|
||||
typedef void (*ChunkCumsumKernelFuncBf16)(int B_, int L_, int H_,
|
||||
typedef void (*ChunkCumsumKernelFuncBf16)(int B_, int L_, int H_, int P_, int G_, int N_,
|
||||
// const bf16 *g_mxY_, // B*L*H*P
|
||||
// const bf16 *g_mxOs_, // B*C*H*N*P
|
||||
// const bf16 *g_mxFs_, // B *H*N*P
|
||||
// const float *g_mxSt_, // B*C*H*N*P
|
||||
float* g_mxdc_, // B*C*H*Q
|
||||
float* g_mxdA_, // B*C*H*Q
|
||||
bf16 const* g_mxdt_, // B*L*H
|
||||
bf16 const* g_mxdt_, // B*L*(2*H*P+2*G*N+H) or B*L*(H*P+2*G*N+H)
|
||||
float const* g_mxdb_, // H
|
||||
float const* g_mxA_, // H
|
||||
// const bf16 *g_mxCB_, // B*C*G*Q*Q
|
||||
// const bf16 *g_mxBC_, // B*L*2*G*N
|
||||
// const float *g_mxD_, // H
|
||||
// const bf16 *g_mxX_, // B*L*H*P
|
||||
// const bf16 *g_mxZ_, // B*L*H*P
|
||||
// const bf16 *g_mxCB_, // B*C*G*Q*Q
|
||||
// const float *g_mxD_, // H
|
||||
// const bf16 *g_mxXBC_, // B*L*(H*P+2*G*N)
|
||||
bf16 const* g_mxZ_, // B*L*(2*H*P+2*G*N+H)
|
||||
bool removePadding_, int const* lastTokenIdsPtr_);
|
||||
|
||||
template <int Q_, int tileH_, int warpH_, bool dtSoftplus_, class Tp_, class Wt_ = float>
__global__ std::enable_if_t<std::is_same_v<Tp_, half> || std::is_same_v<Tp_, __nv_bfloat16>> chunk_cumsum_kernel(int B_,
int L_, int H_,
int L_, int H_, int P_, int G_, int N_,
// const Tp_ *g_mxY_, // B*L*H*P
// const Tp_ *g_mxOs_, // B*C*H*N*P
// const Tp_ *g_mxFs_, // B *H*N*P
// const float *g_mxSt_, // B*C*H*N*P
float* g_mxdc_, // B*C*H*Q
float* g_mxdA_, // B*C*H*Q
Tp_ const* g_mxdt_, // B*L*H
Tp_ const* g_mxdt_, // B*L*(2*H*P+2*G*N+H) or B*L*(H*P+2*G*N+H)
Wt_ const* g_mxdb_, // H
Wt_ const* g_mxA_, // H
// const Tp_ *g_mxCB_, // B*C*G*Q*Q
// const Tp_ *g_mxBC_, // B*L*2*G*N
// const Wt_ *g_mxD_, // H
// const Tp_ *g_mxX_, // B*L*H*P
// const Tp_ *g_mxZ_, // B*L*H*P
// const Tp_ *g_mxCB_, // B*C*G*Q*Q
// const Wt_ *g_mxD_, // H
// const Tp_ *g_mxXBC_, // B*L*(H*P+2*G*N)
Tp_ const* g_mxZ_, // B*L*(2*H*P+2*G*N+H)
bool removePadding_, int const* lastTokenIdsPtr_)
{
using namespace tensorrt_llm::common;
@ -94,11 +91,12 @@ __global__ std::enable_if_t<std::is_same_v<Tp_, half> || std::is_same_v<Tp_, __n
// auto B = Rn<ID>{B_};
auto L = Rn<ID>{L_};
auto H = Rn<ID>{H_};
// auto P = Rn<ID>{P_};
// auto G = Rn<ID>{G_};
// auto N = Rn<ID>{N_};
auto P = Rn<ID>{P_};
auto G = Rn<ID>{G_};
auto N = Rn<ID>{N_};
auto Q = cn<Q_>;
auto C = Rn<ID>{div_up(L.var, Q_)};
auto dt_dim = g_mxZ_ ? Rn<ID>{2 * H_ * P_ + 2 * G_ * N_ + H_} : Rn<ID>{H_ * P_ + 2 * G_ * N_ + H_};

auto aStart = blockIdx_z * L;
auto cStart = blockIdx_z * C;
@ -138,7 +136,8 @@ __global__ std::enable_if_t<std::is_same_v<Tp_, half> || std::is_same_v<Tp_, __n

if (thread(iStep) < cn<tileH_> && blockIdx_y * Q + iQ < L)
{
r_dt = float(g_mxdt_[get((aStart + blockIdx_y * Q + iQ) * H + blockIdx_x * cn<tileH_> + thread(iStep))])
r_dt = float(g_mxdt_[get((aStart + blockIdx_y * Q + iQ) * dt_dim + dt_dim - H + blockIdx_x * cn<tileH_>
+ thread(iStep))])
+ r_db;

if (dtSoftplus_)

@ -36,14 +36,13 @@ typedef void (*ChunkScanKernelFuncFp16)(int B_, int L_, int H_, int P_, int G_,
// const float *g_mxSt_, // B*C*H*N*P
float const* g_mxdc_, // B*C*H*Q
float const* g_mxdA_, // B*C*H*Q
// const half *g_mxdt_, // B*L*H
// const half *g_mxdt_, // B*L*(2*H*P+2*G*N+H) or B*L*(H*P+2*G*N+H)
// const float *g_mxdb_, // H
// const float *g_mxA_, // H
half const* g_mxCB_, // B*C*G*Q*Q
half const* g_mxBC_, // B*L*2*G*N
float const* g_mxD_, // H
half const* g_mxX_, // B*L*H*P
half const* g_mxZ_, // B*L*H*P
half const* g_mxXBC_, // B*L*(H*P+2*G*N)
half const* g_mxZ_, // B*L*(2*H*P+2*G*N+H)
bool removePadding_, int const* lastTokenIdsPtr_);

typedef void (*ChunkScanKernelFuncBf16)(int B_, int L_, int H_, int P_, int G_, int N_,
@ -53,14 +52,13 @@ typedef void (*ChunkScanKernelFuncBf16)(int B_, int L_, int H_, int P_, int G_,
// const float *g_mxSt_, // B*C*H*N*P
float const* g_mxdc_, // B*C*H*Q
float const* g_mxdA_, // B*C*H*Q
// const bf16 *g_mxdt_, // B*L*H
// const bf16 *g_mxdt_, // B*L*(2*H*P+2*G*N+H) or B*L*(H*P+2*G*N+H)
// const float *g_mxdb_, // H
// const float *g_mxA_, // H
bf16 const* g_mxCB_, // B*C*G*Q*Q
bf16 const* g_mxBC_, // B*L*2*G*N
float const* g_mxD_, // H
bf16 const* g_mxX_, // B*L*H*P
bf16 const* g_mxZ_, // B*L*H*P
bf16 const* g_mxXBC_, // B*L*(H*P+2*G*N)
bf16 const* g_mxZ_, // B*L*(2*H*P+2*G*N+H)
bool removePadding_, int const* lastTokenIdsPtr_);

template <int Q_, int tileM_, int tileN_, int tileK_, // smem size, per sm
@ -75,14 +73,13 @@ __global__ std::enable_if_t<std::is_same_v<Tp_, half> || std::is_same_v<Tp_, __n
// const float *g_mxSt_, // B*C*H*N*P
float const* g_mxdc_, // B*C*H*Q
float const* g_mxdA_, // B*C*H*Q
// const Tp_ *g_mxdt_, // B*L*H
// const Tp_ *g_mxdt_, // B*L*(2*H*P+2*G*N+H) or B*L*(H*P+2*G*N+H)
// const Wt_ *g_mxdb_, // H
// const Wt_ *g_mxA_, // H
Tp_ const* g_mxCB_, // B*C*G*Q*Q
Tp_ const* g_mxBC_, // B*L*2*G*N
Wt_ const* g_mxD_, // H
Tp_ const* g_mxX_, // B*L*H*P
Tp_ const* g_mxZ_, // B*L*H*P
Tp_ const* g_mxXBC_, // B*L*(H*P+2*G*N)
Tp_ const* g_mxZ_, // B*L*(2*H*P+2*G*N+H)
bool removePadding_, int const* lastTokenIdsPtr_)
{
#if __CUDA_ARCH__ >= 800
@ -105,6 +102,10 @@ __global__ std::enable_if_t<std::is_same_v<Tp_, half> || std::is_same_v<Tp_, __n
auto Q = cn<Q_>;
auto C = Rn<ID>{div_up(L.var, Q_)};

auto xbcDim = Rn<ID>{H_ * P_ + 2 * G_ * N_};
auto zdtDim = Rn<ID>{2 * H_ * P_ + 2 * G_ * N_ + H_};
auto cOffset = Rn<ID>{H_ * P_ + G_ * N_};

auto aStart = blockIdx_z * L;
auto cStart = blockIdx_z * C;

@ -185,10 +186,9 @@ __global__ std::enable_if_t<std::is_same_v<Tp_, half> || std::is_same_v<Tp_, __n
if (thread(iStep) < cn<tileM_ * tileK_>
&& thread(iStep) / cn<tileK_> < L - blockIdx_y * Q - mStart * cn<tileM_>)
cp_shared_global<16>(b_mxC + swizzle<tileK_ * 2, tileK_ * 2>(thread(iStep) * cn<2>, baseC(iK) * cn<2>),
g_mxBC_
+ get(
(aStart + blockIdx_y * Q + mStart * cn<tileM_> + thread(iStep) / cn<tileK_>) *cn<2> * G * N
+ cn<1> * G * N + gStart * N + iK * cn<tileK_> + thread(iStep) % cn<tileK_>));
g_mxXBC_
+ get((aStart + blockIdx_y * Q + mStart * cn<tileM_> + thread(iStep) / cn<tileK_>) *xbcDim
+ cOffset + gStart * N + iK * cn<tileK_> + thread(iStep) % cn<tileK_>));
else if (thread(iStep) < cn<tileM_ * tileK_>)
*(int4*) ((char*) s_mxC + swizzle<tileK_ * 2, tileK_ * 2>(thread(iStep) * cn<2>, baseC(iK) * cn<2>))
= int4{0, 0, 0, 0};
@ -365,10 +365,9 @@ __global__ std::enable_if_t<std::is_same_v<Tp_, half> || std::is_same_v<Tp_, __n
&& thread(iStep) / cn<tileK_> < L - blockIdx_y * Q - mStart * cn<tileM_>)
cp_shared_global<16>(
b_mxC + swizzle<tileK_ * 2, tileK_ * 2>(thread(iStep) * cn<2>, baseC(jK) * cn<2>),
g_mxBC_
+ get((aStart + blockIdx_y * Q + mStart * cn<tileM_> + thread(iStep) / cn<tileK_>) *cn<2>
* G * N
+ cn<1> * G * N + gStart * N + jK * cn<tileK_> + thread(iStep) % cn<tileK_>));
g_mxXBC_
+ get((aStart + blockIdx_y * Q + mStart * cn<tileM_> + thread(iStep) / cn<tileK_>) *xbcDim
+ cOffset + gStart * N + jK * cn<tileK_> + thread(iStep) % cn<tileK_>));
else if (thread(iStep) < cn<tileM_ * tileK_>)
*(int4*) ((char*) s_mxC + swizzle<tileK_ * 2, tileK_ * 2>(thread(iStep) * cn<2>, baseC(jK) * cn<2>))
= int4{0, 0, 0, 0};
@ -402,8 +401,8 @@ __global__ std::enable_if_t<std::is_same_v<Tp_, half> || std::is_same_v<Tp_, __n
&& thread(iStep) / cn<tileN_> < L - blockIdx_y * Q - jK * cn<tileK_> + N)
cp_shared_global<16>(
b_mxOs + swizzle<tileN_ * 2, tileN_ * 2>(thread(iStep) * cn<2>, baseOs(jK) * cn<2>),
g_mxX_
+ get((aStart + blockIdx_y * Q + jK * cn<tileK_> - N + thread(iStep) / cn<tileN_>) *H * P
g_mxXBC_
+ get((aStart + blockIdx_y * Q + jK * cn<tileK_> - N + thread(iStep) / cn<tileN_>) *xbcDim
+ hStart * P + nStart * cn<tileN_> + thread(iStep) % cn<tileN_>));
else if (thread(iStep) < cn<tileN_ * tileK_>)
*(int4*) ((char*) s_mxOs
@ -434,10 +433,9 @@ __global__ std::enable_if_t<std::is_same_v<Tp_, half> || std::is_same_v<Tp_, __n
+ threadIdx_z * cn<wmmaM_> + threadIdx_x / cn<4>
< L)
{
*(int*) &tmp16[0] = *(int*) (g_mxX_
*(int*) &tmp16[0] = *(int*) (g_mxXBC_
+ get((aStart + blockIdx_y * Q + mStart * cn<tileM_> + Rn<UNROLL>{y} * cn<warpM_ * wmmaM_>
+ threadIdx_z * cn<wmmaM_> + threadIdx_x / cn<4>) *H
* P
+ threadIdx_z * cn<wmmaM_> + threadIdx_x / cn<4>) *xbcDim
+ hStart * P + nStart * cn<tileN_> + Rn<UNROLL>{x} * cn<warpN_ * wmmaN_>
+ threadIdx_y * cn<wmmaN_> + threadIdx_x % cn<4> * cn<2>));

@ -452,10 +450,9 @@ __global__ std::enable_if_t<std::is_same_v<Tp_, half> || std::is_same_v<Tp_, __n
+ threadIdx_z * cn<wmmaM_> + threadIdx_x / cn<4>
< L)
{
*(int*) &tmp16[2] = *(int*) (g_mxX_
*(int*) &tmp16[2] = *(int*) (g_mxXBC_
+ get((aStart + blockIdx_y * Q + mStart * cn<tileM_> + Rn<UNROLL>{y} * cn<warpM_ * wmmaM_>
+ cn<8> + threadIdx_z * cn<wmmaM_> + threadIdx_x / cn<4>) *H
* P
+ cn<8> + threadIdx_z * cn<wmmaM_> + threadIdx_x / cn<4>) *xbcDim
+ hStart * P + nStart * cn<tileN_> + Rn<UNROLL>{x} * cn<warpN_ * wmmaN_>
+ threadIdx_y * cn<wmmaN_> + threadIdx_x % cn<4> * cn<2>));

@ -484,8 +481,7 @@ __global__ std::enable_if_t<std::is_same_v<Tp_, half> || std::is_same_v<Tp_, __n
{
*(int*) &tmp16[0] = *(int*) (g_mxZ_
+ get((aStart + blockIdx_y * Q + mStart * cn<tileM_> + Rn<UNROLL>{y} * cn<warpM_ * wmmaM_>
+ threadIdx_z * cn<wmmaM_> + threadIdx_x / cn<4>) *H
* P
+ threadIdx_z * cn<wmmaM_> + threadIdx_x / cn<4>) *zdtDim
+ hStart * P + nStart * cn<tileN_> + Rn<UNROLL>{x} * cn<warpN_ * wmmaN_>
+ threadIdx_y * cn<wmmaN_> + threadIdx_x % cn<4> * cn<2>));

@ -502,8 +498,7 @@ __global__ std::enable_if_t<std::is_same_v<Tp_, half> || std::is_same_v<Tp_, __n
{
*(int*) &tmp16[2] = *(int*) (g_mxZ_
+ get((aStart + blockIdx_y * Q + mStart * cn<tileM_> + Rn<UNROLL>{y} * cn<warpM_ * wmmaM_>
+ cn<8> + threadIdx_z * cn<wmmaM_> + threadIdx_x / cn<4>) *H
* P
+ cn<8> + threadIdx_z * cn<wmmaM_> + threadIdx_x / cn<4>) *zdtDim
+ hStart * P + nStart * cn<tileN_> + Rn<UNROLL>{x} * cn<warpN_ * wmmaN_>
+ threadIdx_y * cn<wmmaN_> + threadIdx_x % cn<4> * cn<2>));

@ -36,14 +36,13 @@ typedef void (*ChunkStateKernelFuncFp16)(int B_, int L_, int H_, int P_, int G_,
float* g_mxSt_, // B*C*H*N*P
float const* g_mxdc_, // B*C*H*Q
float const* g_mxdA_, // B*C*H*Q
// const half *g_mxdt_, // B*L*H
// const half *g_mxdt_, // B*L*(2*H*P+2*G*N+H) or B*L*(H*P+2*G*N+H)
// const float *g_mxdb_, // H
// const float *g_mxA_, // H
// const half *g_mxCB_, // B*C*G*Q*Q
half const* g_mxBC_, // B*L*2*G*N
// const float *g_mxD_, // H
half const* g_mxX_, // B*L*H*P
// const half *g_mxZ_, // B*L*H*P
half const* g_mxXBC_, // B*L*(H*P+2*G*N)
// const half *g_mxZ_, // B*L*(2*H*P+2*G*N+H)
bool removePadding_, int const* lastTokenIdsPtr_);

typedef void (*ChunkStateKernelFuncBf16)(int B_, int L_, int H_, int P_, int G_, int N_,
@ -53,14 +52,13 @@ typedef void (*ChunkStateKernelFuncBf16)(int B_, int L_, int H_, int P_, int G_,
float* g_mxSt_, // B*C*H*N*P
float const* g_mxdc_, // B*C*H*Q
float const* g_mxdA_, // B*C*H*Q
// const bf16 *g_mxdt_, // B*L*H
// const bf16 *g_mxdt_, // B*L*(2*H*P+2*G*N+H) or B*L*(H*P+2*G*N+H)
// const float *g_mxdb_, // H
// const float *g_mxA_, // H
// const bf16 *g_mxCB_, // B*C*G*Q*Q
bf16 const* g_mxBC_, // B*L*2*G*N
// const float *g_mxD_, // H
bf16 const* g_mxX_, // B*L*H*P
// const bf16 *g_mxZ_, // B*L*H*P
bf16 const* g_mxXBC_, // B*L*(H*P+2*G*N)
// const bf16 *g_mxZ_, // B*L*(2*H*P+2*G*N+H)
bool removePadding_, int const* lastTokenIdsPtr_);

template <int Q_, int tileM_, int tileN_, int tileK_, // smem size, per sm
@ -75,14 +73,13 @@ __global__ std::enable_if_t<std::is_same_v<Tp_, half> || std::is_same_v<Tp_, __n
float* g_mxSt_, // B*C*H*N*P
float const* g_mxdc_, // B*C*H*Q
float const* g_mxdA_, // B*C*H*Q
// const Tp_ *g_mxdt_, // B*L*H
// const Tp_ *g_mxdt_, // B*L*(2*H*P+2*G*N+H) or B*L*(H*P+2*G*N+H)
// const Wt_ *g_mxdb_, // H
// const Wt_ *g_mxA_, // H
// const Tp_ *g_mxCB_, // B*C*G*Q*Q
Tp_ const* g_mxBC_, // B*L*2*G*N
// const Wt_ *g_mxD_, // H
Tp_ const* g_mxX_, // B*L*H*P
// const Tp_ *g_mxZ_, // B*L*H*P
Tp_ const* g_mxXBC_, // B*L*(H*P+2*G*N)
// const Tp_ *g_mxZ_, // B*L*(2*H*P+2*G*N+H)
bool removePadding_, int const* lastTokenIdsPtr_)
{
#if __CUDA_ARCH__ >= 800
@ -105,6 +102,10 @@ __global__ std::enable_if_t<std::is_same_v<Tp_, half> || std::is_same_v<Tp_, __n
auto Q = cn<Q_>;
auto C = Rn<ID>{div_up(L.var, Q_)};

auto xbcDim = Rn<ID>{H_ * P_ + 2 * G_ * N_};
auto bOffset = Rn<ID>{H_ * P_};
auto cOffset = Rn<ID>{H_ * P_ + G_ * N_};

auto aStart = blockIdx_z * L;
auto cStart = blockIdx_z * C;

@ -183,8 +184,8 @@ __global__ std::enable_if_t<std::is_same_v<Tp_, half> || std::is_same_v<Tp_, __n
if (thread(iStep) < cn<tileM_ * tileK_>
&& thread(iStep) / cn<tileM_> < L - blockIdx_y * Q - iK * cn<tileK_>)
cp_shared_global<16>(b_mxB + swizzle<tileM_ * 2, tileM_ * 2>(thread(iStep) * cn<2>, baseB(iK) * cn<2>),
g_mxBC_
+ get((aStart + blockIdx_y * Q + iK * cn<tileK_> + thread(iStep) / cn<tileM_>) *cn<2> * G * N
g_mxXBC_
+ get((aStart + blockIdx_y * Q + iK * cn<tileK_> + thread(iStep) / cn<tileM_>) *xbcDim + bOffset
+ gStart * N + mStart * cn<tileM_> + thread(iStep) % cn<tileM_>));
else if (thread(iStep) < cn<tileM_ * tileK_>)
*(int4*) ((char*) s_mxB + swizzle<tileM_ * 2, tileM_ * 2>(thread(iStep) * cn<2>, baseB(iK) * cn<2>))
@ -195,8 +196,8 @@ __global__ std::enable_if_t<std::is_same_v<Tp_, half> || std::is_same_v<Tp_, __n
if (thread(iStep) < cn<tileN_ * tileK_>
&& thread(iStep) / cn<tileN_> < L - blockIdx_y * Q - iK * cn<tileK_>)
cp_shared_global<16>(b_mxX + swizzle<tileN_ * 2, tileN_ * 2>(thread(iStep) * cn<2>, baseX(iK) * cn<2>),
g_mxX_
+ get((aStart + blockIdx_y * Q + iK * cn<tileK_> + thread(iStep) / cn<tileN_>) *H * P
g_mxXBC_
+ get((aStart + blockIdx_y * Q + iK * cn<tileK_> + thread(iStep) / cn<tileN_>) *xbcDim
+ hStart * P + nStart * cn<tileN_> + thread(iStep) % cn<tileN_>));
else if (thread(iStep) < cn<tileN_ * tileK_>)
*(int4*) ((char*) s_mxX + swizzle<tileN_ * 2, tileN_ * 2>(thread(iStep) * cn<2>, baseX(iK) * cn<2>))
@ -329,8 +330,8 @@ __global__ std::enable_if_t<std::is_same_v<Tp_, half> || std::is_same_v<Tp_, __n
if (thread(iStep) < cn<tileM_ * tileK_> && thread(iStep) / cn<tileM_> < L - blockIdx_y * Q - jK * cn<tileK_>
&& jK * cn<tileK_> < Q)
cp_shared_global<16>(b_mxB + swizzle<tileM_ * 2, tileM_ * 2>(thread(iStep) * cn<2>, baseB(jK) * cn<2>),
g_mxBC_
+ get((aStart + blockIdx_y * Q + jK * cn<tileK_> + thread(iStep) / cn<tileM_>) *cn<2> * G * N
g_mxXBC_
+ get((aStart + blockIdx_y * Q + jK * cn<tileK_> + thread(iStep) / cn<tileM_>) *xbcDim + bOffset
+ gStart * N + mStart * cn<tileM_> + thread(iStep) % cn<tileM_>));
else if (thread(iStep) < cn<tileM_ * tileK_> && jK * cn<tileK_> < Q)
*(int4*) ((char*) s_mxB + swizzle<tileM_ * 2, tileM_ * 2>(thread(iStep) * cn<2>, baseB(jK) * cn<2>))
@ -341,8 +342,8 @@ __global__ std::enable_if_t<std::is_same_v<Tp_, half> || std::is_same_v<Tp_, __n
if (thread(iStep) < cn<tileN_ * tileK_> && thread(iStep) / cn<tileN_> < L - blockIdx_y * Q - jK * cn<tileK_>
&& jK * cn<tileK_> < Q)
cp_shared_global<16>(b_mxX + swizzle<tileN_ * 2, tileN_ * 2>(thread(iStep) * cn<2>, baseX(jK) * cn<2>),
g_mxX_
+ get((aStart + blockIdx_y * Q + jK * cn<tileK_> + thread(iStep) / cn<tileN_>) *H * P
g_mxXBC_
+ get((aStart + blockIdx_y * Q + jK * cn<tileK_> + thread(iStep) / cn<tileN_>) *xbcDim
+ hStart * P + nStart * cn<tileN_> + thread(iStep) % cn<tileN_>));
else if (thread(iStep) < cn<tileN_ * tileK_> && jK * cn<tileK_> < Q)
*(int4*) ((char*) s_mxX + swizzle<tileN_ * 2, tileN_ * 2>(thread(iStep) * cn<2>, baseX(jK) * cn<2>))

@ -36,14 +36,13 @@ typedef void (*StatePassingKernelFuncFp16)(int B_, int L_, int H_, int P_, int N
float const* g_mxSt_, // B*C*H*N*P
// const float *g_mxdc_, // B*C*H*Q
float const* g_mxdA_, // B*C*H*Q
// const half *g_mxdt_, // B*L*H
// const half *g_mxdt_, // B*L*(2*H*P+2*G*N+H) or B*L*(H*P+2*G*N+H)
// const float *g_mxdb_, // H
// const float *g_mxA_, // H
// const half *g_mxCB_, // B*C*G*Q*Q
// const half *g_mxBC_, // B*L*2*G*N
// const float *g_mxD_, // H
// const half *g_mxX_, // B*L*H*P
// const half *g_mxZ_, // B*L*H*P
// const half *g_mxXBC_, // B*L*(H*P+2*G*N)
// const half *g_mxZ_, // B*L*(2*H*P+2*G*N+H)
bool removePadding_, int const* lastTokenIdsPtr_, int const* stateSlotMappingPtr_);

typedef void (*StatePassingKernelFuncBf16)(int B_, int L_, int H_, int P_, int N_,
@ -53,14 +52,13 @@ typedef void (*StatePassingKernelFuncBf16)(int B_, int L_, int H_, int P_, int N
float const* g_mxSt_, // B*C*H*N*P
// const float *g_mxdc_, // B*C*H*Q
float const* g_mxdA_, // B*C*H*Q
// const bf16 *g_mxdt_, // B*L*H
// const bf16 *g_mxdt_, // B*L*(2*H*P+2*G*N+H) or B*L*(H*P+2*G*N+H)
// const float *g_mxdb_, // H
// const float *g_mxA_, // H
// const bf16 *g_mxCB_, // B*C*G*Q*Q
// const bf16 *g_mxBC_, // B*L*2*G*N
// const float *g_mxD_, // H
// const bf16 *g_mxX_, // B*L*H*P
// const bf16 *g_mxZ_, // B*L*H*P
// const bf16 *g_mxXBC_, // B*L*(H*P+2*G*N)
// const bf16 *g_mxZ_, // B*L*(2*H*P+2*G*N+H)
bool removePadding_, int const* lastTokenIdsPtr_, int const* stateSlotMappingPtr_);

template <int Q_, int tileH_, int warpH_, class Tp_>
@ -72,14 +70,13 @@ __global__ std::enable_if_t<std::is_same_v<Tp_, half> || std::is_same_v<Tp_, __n
float const* g_mxSt_, // B*C*H*N*P
// const float *g_mxdc_, // B*C*H*Q
float const* g_mxdA_, // B*C*H*Q
// const Tp_ *g_mxdt_, // B*L*H
// const Tp_ *g_mxdt_, // B*L*(2*H*P+2*G*N+H) or B*L*(H*P+2*G*N+H)
// const Wt_ *g_mxdb_, // H
// const Wt_ *g_mxA_, // H
// const Tp_ *g_mxCB_, // B*C*G*Q*Q
// const Tp_ *g_mxBC_, // B*L*2*G*N
// const Wt_ *g_mxD_, // H
// const Tp_ *g_mxX_, // B*L*H*P
// const Tp_ *g_mxZ_, // B*L*H*P
// const Tp_ *g_mxXBC_, // B*L*(H*P+2*G*N)
// const Tp_ *g_mxZ_, // B*L*(2*H*P+2*G*N+H)
bool removePadding_, int const* lastTokenIdsPtr_, int const* stateSlotMappingPtr_)
{
using namespace tensorrt_llm::common;

@ -91,6 +91,7 @@ __global__ void cumsum_last_dim(T const* d_in, T* d_out, int length)

// Store items from a blocked arrangement
BlockStoreT(temp_storage.store).Store(cur_d_out, data, cur_tile_size);
__syncthreads();
}
}

@ -70,8 +70,8 @@ TileShape get_cta_shape_for_config(CutlassTileConfig tile_config)
}
}

bool is_valid_split_k_factor(const int64_t m, const int64_t n, const int64_t k, const TileShape tile_shape,
int const split_k_factor, const size_t workspace_bytes, bool const is_weight_only)
bool is_valid_split_k_factor(int64_t const m, int64_t const n, int64_t const k, TileShape const tile_shape,
int const split_k_factor, size_t const workspace_bytes, bool const is_weight_only)
{

// All tile sizes have a k_tile of 64.
@ -181,7 +181,7 @@ std::vector<CutlassTileConfigSM90> get_candidate_tiles_sm90(
{
return {CutlassTileConfigSM90::CtaShape128x16x128B, CutlassTileConfigSM90::CtaShape128x32x128B,
CutlassTileConfigSM90::CtaShape128x64x128B, CutlassTileConfigSM90::CtaShape128x128x128B,
CutlassTileConfigSM90::CtaShape128x256x128B};
CutlassTileConfigSM90::CtaShape128x256x128B, CutlassTileConfigSM90::CtaShape256x128x128B};
}
else
{
@ -196,28 +196,29 @@ std::vector<CutlassTileConfigSM90> get_candidate_tiles_sm90(

// We only compile CUTLASS kernels with multi-cast along M if the M tile is >= 128. This is purely to improve
// compilation speed.
bool supports_mcast_along_m(const CutlassTileConfigSM90 tile)
bool supports_mcast_along_m(CutlassTileConfigSM90 const tile)
{
#ifdef FAST_BUILD
return false;
#else
std::set<CutlassTileConfigSM90> valid_tiles{CutlassTileConfigSM90::CtaShape128x16x128B,
CutlassTileConfigSM90::CtaShape128x32x128B, CutlassTileConfigSM90::CtaShape128x64x128B,
CutlassTileConfigSM90::CtaShape128x128x128B, CutlassTileConfigSM90::CtaShape128x256x128B};
CutlassTileConfigSM90::CtaShape128x128x128B, CutlassTileConfigSM90::CtaShape128x256x128B,
CutlassTileConfigSM90::CtaShape256x128x128B};
return valid_tiles.count(tile) == 1;
#endif
}

// We only compile CUTLASS kernels with multi-cast along N if the N tile is >= 128. This is purely to improve
// compilation speed.
bool supports_mcast_along_n(const CutlassTileConfigSM90 tile)
bool supports_mcast_along_n(CutlassTileConfigSM90 const tile)
{
#ifdef FAST_BUILD
return false;
#else
std::set<CutlassTileConfigSM90> valid_tiles{CutlassTileConfigSM90::CtaShape64x128x128B,
CutlassTileConfigSM90::CtaShape64x256x128B, CutlassTileConfigSM90::CtaShape128x128x128B,
CutlassTileConfigSM90::CtaShape128x256x128B};
CutlassTileConfigSM90::CtaShape128x256x128B, CutlassTileConfigSM90::CtaShape256x128x128B};
return valid_tiles.count(tile) == 1;
#endif
}
@ -288,8 +289,8 @@ std::vector<CutlassGemmConfig> get_candidate_configs(
}

CutlassGemmConfig estimate_best_config_from_occupancies(std::vector<CutlassGemmConfig> const& candidate_configs,
std::vector<int> const& occupancies, const int64_t m, const int64_t n, const int64_t k, const int64_t num_experts,
int const split_k_limit, const size_t workspace_bytes, int const multi_processor_count, int const is_weight_only)
std::vector<int> const& occupancies, int64_t const m, int64_t const n, int64_t const k, int64_t const num_experts,
int const split_k_limit, size_t const workspace_bytes, int const multi_processor_count, int const is_weight_only)
{

if (occupancies.size() != candidate_configs.size())

@ -247,7 +247,7 @@ void sm90_generic_moe_gemm_kernelLauncher(HopperGroupedGemmInput hopper_input, i
TLLM_CHECK(hopper_input.ptr_b);
TLLM_CHECK(hopper_input.ptr_d);

const MainloopArguments mainloop_params = {reinterpret_cast<ElementB const**>(hopper_input.ptr_b),
MainloopArguments const mainloop_params = {reinterpret_cast<ElementB const**>(hopper_input.ptr_b),
hopper_input.stride_b, reinterpret_cast<ElementA const**>(hopper_input.ptr_a), hopper_input.stride_a};

typename GemmGrouped::EpilogueOutputOp::Params epilogue_scalars{
@ -255,12 +255,15 @@ void sm90_generic_moe_gemm_kernelLauncher(HopperGroupedGemmInput hopper_input, i
epilogue_scalars.alpha_ptr_array = hopper_input.alpha_scale_ptr_array;
using EpilogueArguments = typename CollectiveEpilogue::Arguments;
// TODO(dastokes) ptr_c casts to ElementCNoVoid** because there is a workaround in CUTLASS
const EpilogueArguments epilogue_params
EpilogueArguments const epilogue_params
= {epilogue_scalars, reinterpret_cast<ElementCNoVoid const**>(hopper_input.ptr_c), hopper_input.stride_c,
reinterpret_cast<ElementD**>(hopper_input.ptr_d), hopper_input.stride_d};

typename GemmKernel::TileScheduler::Arguments scheduler_args{
1, GemmKernel::TileScheduler::RasterOrderOptions::AlongN};

typename GemmGrouped::Arguments args{cutlass::gemm::GemmUniversalMode::kGrouped, hopper_input.shape_info,
mainloop_params, epilogue_params, hw_info};
mainloop_params, epilogue_params, hw_info, scheduler_args};

size_t calculated_ws_size = gemm.get_workspace_size(args);
TLLM_CHECK_WITH_INFO(calculated_ws_size <= hopper_input.gemm_workspace_size,

@ -190,6 +190,7 @@ void dispatchMoeGemmSelectTileShapeSM90(HopperGroupedGemmInput hopper_input, int
SHAPE_CASE(128, 64, 128)
SHAPE_CASE(128, 128, 128)
SHAPE_CASE(128, 256, 128)
SHAPE_CASE(256, 128, 128)

#undef SHAPE_CASE
case cutlass_extensions::CutlassTileConfigSM90::Undefined: TLLM_THROW("GEMM config undefined."); break;

@ -323,7 +323,7 @@ def generate_sm90_grouped_gemm_operations():
epi_tags = [TrtLlm_EpilogueTag.epilogue_op_default]
M_TILES = [128] # Currently M tile must be 128 for Grouped GEMM
N_TILES = [16, 32, 64, 128, 256]
cta_shapes_mn = product(M_TILES, N_TILES)
cta_shapes_mn = list(product(M_TILES, N_TILES)) + [(256, 128)]

warp_shape = [0, 0, 0] # ignored except for naming
stages = 0 # auto

@ -1,2 +1,2 @@
8b0f8deb35940359b39f876fc5e94e4f libtensorrt_llm_nvrtc_wrapper.so
d5f5542d2f1e10c4a6b60be56838ac79a9668665 commit
db055e58b6c6c8cf7350b66a583f9c388c4eac07 commit
@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:78209a1351f9f21f635bf9f763f4947031ea12b7526c5782094e9869b667a23f
oid sha256:c439d4074454207e5a26887a041d3e7868dd05dab30b903536bac5428758c9eb
size 1091072

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:38d470721122f47b75e91b7967fa56cebcb48c8abcc4b3ddefe4f39c85d061f8
oid sha256:ba9784dd196da1c35f9326de49b0851395f7d858131b9a60334211feaec34c52
size 3488

@ -231,8 +231,11 @@ public:
unsigned int nbTokenBlocksPerGrp = divUp(qSeqLen * headGrpSize, mTileSize);
int const* maskPtr = xqaParams.spec_decoding_packed_mask;
int const* cuQSeqLens = launchParams.cu_seq_lens;
unsigned int maxQSeqLen = xqaParams.spec_decoding_is_generation_length_variable ? // true for ReDrafter
xqaParams.spec_decoding_max_generation_length
: qSeqLen;
// TODO: merge SingleQueryToken params and MultiQueryTokens params into one kernelParams.
void* kernelParams[] = {&qSeqLen, &launchParams.num_k_heads, &headGrpSize, &cuQSeqLens,
void* kernelParams[] = {&maxQSeqLen, &launchParams.num_k_heads, &headGrpSize, &cuQSeqLens,
&launchParams.output, &xqa_q_input_ptr, &maskPtr, &launchParams.kvCacheParams, &launchParams.batch_size,
&launchParams.kv_scale_quant_orig, &launchParams.scratch};
int multi_block = 1;

@ -45,12 +45,14 @@ struct XQAParams
int32_t sink_token_length = 0;
int timestep = 0;
void const* qkv_bias;
int32_t const* sequence_lengths; //
int32_t const* context_lengths; // maybe not used now
void const* alibi_slopes; // maybe not used now
int32_t const* sequence_lengths; //
int32_t const* context_lengths; // maybe not used now
void const* alibi_slopes; // maybe not used now
int32_t const* spec_decoding_packed_mask;
int const* spec_decoding_position_offsets; // rotary embedding.
int const* spec_decoding_generation_lengths; // variable input lengths.
int const* spec_decoding_position_offsets; // for position embedding.
int const* spec_decoding_generation_lengths; // variable input lengths.
bool spec_decoding_is_generation_length_variable; // whether the generation lengths actually vary
int32_t spec_decoding_max_generation_length; // max possible input length

// almost copy from GPTAttentionPluginCommon.
// maybe use one struct for parameters in GPTAttentionPluginCommon and share the same here.

@ -482,11 +482,12 @@ __global__ void finalizeKernel(BeamHypotheses bh)

void invokeFinalize(BeamHypotheses& bh, cudaStream_t stream)
{
TLLM_LOG_DEBUG("%s %s start", __FILE__, __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s %s start", __FILE__, __PRETTY_FUNCTION__);

int const nBM = bh.nBeamWidth;
size_t const smem_size = sizeof(int) * nBM * 2 + sizeof(float) * nBM * 2;
finalizeKernel<<<bh.nBatchSize, roundUp(nBM * 2, 32), smem_size, stream>>>(bh);
TLLM_LOG_TRACE("%s %s stop", __FILE__, __PRETTY_FUNCTION__);
}

__global__ void initializeOutput(TokenIdType* finalOutputIds, TokenIdType const* endIds, SizeType32 const nMaxSeqLen)

@ -197,9 +197,9 @@ __global__
|| std::is_same_v<T_, __nv_bfloat16>
#endif
>
mambaConv1dContextKernel(int B_, int L_, int D_, T_* g_mxYa_, T_* g_mxYs_, T_ const* g_mxXa_, T_ const* g_mxXs_,
T_ const* g_mxW_, T_ const* g_mxB_, bool removePadding_, bool applySilu_, int const* lastTokenIdsPtr_,
int const* stateSlotMappingPtr_ = nullptr)
mambaConv1dContextKernel(int B_, int L_, int D_, int S_pre_, int S_post_, T_* g_mxYa_, T_* g_mxYs_,
T_ const* g_mxXa_, T_ const* g_mxXs_, T_ const* g_mxW_, T_ const* g_mxB_, bool removePadding_, bool applySilu_,
int const* lastTokenIdsPtr_, int const* stateSlotMappingPtr_ = nullptr)
{
using namespace tensorrt_llm::common;

@ -262,17 +262,20 @@ __global__
int STEP = 256 * warpL_ * warpD_;

int L = L_;
int DS_ = (D_ + S_pre_ + S_post_);

int aStart = blockIdx.z * L_ * D_;
int aIStart = blockIdx.z * L_ * DS_;
int aOStart = blockIdx.z * L_ * D_;
int sStart = blockIdx.z * (K_ - 1) * D_;
int lStart = blockIdx.y * tileL_;
int dStart = blockIdx.x * tileD_;

if (removePadding_)
{
aStart = blockIdx.z ? lastTokenIdsPtr_[blockIdx.z - 1] : 0;
L = lastTokenIdsPtr_[blockIdx.z] - aStart;
aStart = aStart * D_;
aIStart = blockIdx.z ? lastTokenIdsPtr_[blockIdx.z - 1] : 0;
L = lastTokenIdsPtr_[blockIdx.z] - aIStart;
aOStart = aIStart * D_;
aIStart = aIStart * DS_;
}
else
{
@ -295,8 +298,8 @@ __global__
+ 2
* swizzle<tileD_ * 2, tileD_, T_>(
i + thread * 8 + (warpL_ * laneL * pipe_ + 1 - K_) * tileD_),
g_mxXa_ + aStart + (1 - K_ + lStart + thread * 8 / tileD_) * D_ + i * (D_ / tileD_) + dStart
+ thread * 8 % tileD_);
g_mxXa_ + aIStart + (1 - K_ + lStart + thread * 8 / tileD_) * DS_ + S_pre_ + i * (D_ / tileD_)
+ dStart + thread * 8 % tileD_);
}
else if (g_mxXs_)
{
@ -331,7 +334,7 @@ __global__
if (i + STEP <= warpL_ * laneL * tileD_ || i + thread * 8 < warpL_ * laneL * tileD_)
cp_shared_global<16>(
b_mxX + 2 * swizzle<tileD_ * 2, tileD_, T_>(i + thread * 8 + iL * warpL_ * laneL * tileD_),
g_mxXa_ + aStart + iL * warpL_ * laneL * D_ + (lStart + thread * 8 / tileD_) * D_
g_mxXa_ + aIStart + iL * warpL_ * laneL * DS_ + (lStart + thread * 8 / tileD_) * DS_ + S_pre_
+ i * (D_ / tileD_) + dStart + thread * 8 % tileD_);
}
else
@ -341,7 +344,7 @@ __global__
if (i + thread * 8 < (L - lStart - iL * warpL_ * laneL) * tileD_)
cp_shared_global<16>(
b_mxX + 2 * swizzle<tileD_ * 2, tileD_, T_>(i + thread * 8 + iL * warpL_ * laneL * tileD_),
g_mxXa_ + aStart + iL * warpL_ * laneL * D_ + (lStart + thread * 8 / tileD_) * D_
g_mxXa_ + aIStart + iL * warpL_ * laneL * DS_ + (lStart + thread * 8 / tileD_) * DS_ + S_pre_
+ i * (D_ / tileD_) + dStart + thread * 8 % tileD_);
}

@ -436,7 +439,7 @@ __global__
+ 2
* swizzle<tileD_ * 2, tileD_, T_>(
i + thread * 8 + jL % pipe_ * warpL_ * laneL * tileD_),
g_mxXa_ + aStart + jL * warpL_ * laneL * D_ + (lStart + thread * 8 / tileD_) * D_
g_mxXa_ + aIStart + jL * warpL_ * laneL * DS_ + (lStart + thread * 8 / tileD_) * DS_ + S_pre_
+ i * (D_ / tileD_) + dStart + thread * 8 % tileD_);
}
else if (jL < tileL_ / (warpL_ * laneL))
@ -448,14 +451,14 @@ __global__
+ 2
* swizzle<tileD_ * 2, tileD_, T_>(
i + thread * 8 + jL % pipe_ * warpL_ * laneL * tileD_),
g_mxXa_ + aStart + jL * warpL_ * laneL * D_ + (lStart + thread * 8 / tileD_) * D_
g_mxXa_ + aIStart + jL * warpL_ * laneL * DS_ + (lStart + thread * 8 / tileD_) * DS_ + S_pre_
+ i * (D_ / tileD_) + dStart + thread * 8 % tileD_);
}

cp_commit_group();

int offset
= aStart + iL * warpL_ * laneL * D_ + (lStart + thread * 8 / tileD_) * D_ + dStart + thread * 8 % tileD_;
= aOStart + iL * warpL_ * laneL * D_ + (lStart + thread * 8 / tileD_) * D_ + dStart + thread * 8 % tileD_;

if (lStart + (iL + 1) * warpL_ * laneL <= L)
{
@ -501,9 +504,9 @@ __global__
}

template <int K_, int tileL_, int tileD_, int warpL_, int warpD_, int laneD_, int pipe_, typename T_>
__global__ std::enable_if_t<std::is_same_v<T_, float>> mambaConv1dContextKernel(int B_, int L_, int D_, T_* g_mxYa_,
T_* g_mxYs_, T_ const* g_mxXa_, T_ const* g_mxXs_, T_ const* g_mxW_, T_ const* g_mxB_, bool removePadding_,
bool applySilu_, int const* lastTokenIdsPtr_, int const* stateSlotMappingPtr_ = nullptr)
__global__ std::enable_if_t<std::is_same_v<T_, float>> mambaConv1dContextKernel(int B_, int L_, int D_, int S_pre_,
int S_post_, T_* g_mxYa_, T_* g_mxYs_, T_ const* g_mxXa_, T_ const* g_mxXs_, T_ const* g_mxW_, T_ const* g_mxB_,
bool removePadding_, bool applySilu_, int const* lastTokenIdsPtr_, int const* stateSlotMappingPtr_ = nullptr)
{
static_assert(laneD_ >= 1 && laneD_ <= 32 && (laneD_ & (laneD_ - 1)) == 0);

@ -542,17 +545,20 @@ __global__ std::enable_if_t<std::is_same_v<T_, float>> mambaConv1dContextKernel(
int STEP = 128 * warpL_ * warpD_;

int L = L_;
int DS_ = (D_ + S_pre_ + S_post_);

int aStart = blockIdx.z * L_ * D_;
int aIStart = blockIdx.z * L_ * DS_;
int aOStart = blockIdx.z * L_ * D_;
int sStart = blockIdx.z * (K_ - 1) * D_;
int lStart = blockIdx.y * tileL_;
int dStart = blockIdx.x * tileD_;

if (removePadding_)
{
aStart = blockIdx.z ? lastTokenIdsPtr_[blockIdx.z - 1] : 0;
L = lastTokenIdsPtr_[blockIdx.z] - aStart;
aStart = aStart * D_;
aIStart = blockIdx.z ? lastTokenIdsPtr_[blockIdx.z - 1] : 0;
L = lastTokenIdsPtr_[blockIdx.z] - aIStart;
aOStart = aIStart * D_;
aIStart = aIStart * DS_;
}
else
{
@ -575,8 +581,8 @@ __global__ std::enable_if_t<std::is_same_v<T_, float>> mambaConv1dContextKernel(
+ 4
* swizzle<tileD_ * 4, tileD_, T_>(
i + thread * 4 + (warpL_ * laneL * pipe_ + 1 - K_) * tileD_),
g_mxXa_ + aStart + (1 - K_ + lStart + thread * 4 / tileD_) * D_ + i * (D_ / tileD_) + dStart
+ thread * 4 % tileD_);
g_mxXa_ + aIStart + (1 - K_ + lStart + thread * 4 / tileD_) * DS_ + S_pre_ + i * (D_ / tileD_)
+ dStart + thread * 4 % tileD_);
}
else if (g_mxXs_)
{
@ -611,7 +617,7 @@ __global__ std::enable_if_t<std::is_same_v<T_, float>> mambaConv1dContextKernel(
if (i + STEP <= warpL_ * laneL * tileD_ || i + thread * 4 < warpL_ * laneL * tileD_)
cp_shared_global<16>(
b_mxX + 4 * swizzle<tileD_ * 4, tileD_, T_>(i + thread * 4 + iL * warpL_ * laneL * tileD_),
g_mxXa_ + aStart + iL * warpL_ * laneL * D_ + (lStart + thread * 4 / tileD_) * D_
g_mxXa_ + aIStart + iL * warpL_ * laneL * DS_ + (lStart + thread * 4 / tileD_) * DS_ + S_pre_
+ i * (D_ / tileD_) + dStart + thread * 4 % tileD_);
}
else
@ -621,7 +627,7 @@ __global__ std::enable_if_t<std::is_same_v<T_, float>> mambaConv1dContextKernel(
if (i + thread * 4 < (L - lStart - iL * warpL_ * laneL) * tileD_)
cp_shared_global<16>(
b_mxX + 4 * swizzle<tileD_ * 4, tileD_, T_>(i + thread * 4 + iL * warpL_ * laneL * tileD_),
g_mxXa_ + aStart + iL * warpL_ * laneL * D_ + (lStart + thread * 4 / tileD_) * D_
g_mxXa_ + aIStart + iL * warpL_ * laneL * DS_ + (lStart + thread * 4 / tileD_) * DS_ + S_pre_
+ i * (D_ / tileD_) + dStart + thread * 4 % tileD_);
}

@ -684,7 +690,7 @@ __global__ std::enable_if_t<std::is_same_v<T_, float>> mambaConv1dContextKernel(
+ 4
* swizzle<tileD_ * 4, tileD_, T_>(
i + thread * 4 + jL % pipe_ * warpL_ * laneL * tileD_),
g_mxXa_ + aStart + jL * warpL_ * laneL * D_ + (lStart + thread * 4 / tileD_) * D_
g_mxXa_ + aIStart + jL * warpL_ * laneL * DS_ + (lStart + thread * 4 / tileD_) * DS_ + S_pre_
+ i * (D_ / tileD_) + dStart + thread * 4 % tileD_);
}
else if (jL < tileL_ / (warpL_ * laneL))
@ -696,14 +702,14 @@ __global__ std::enable_if_t<std::is_same_v<T_, float>> mambaConv1dContextKernel(
+ 4
* swizzle<tileD_ * 4, tileD_, T_>(
i + thread * 4 + jL % pipe_ * warpL_ * laneL * tileD_),
g_mxXa_ + aStart + jL * warpL_ * laneL * D_ + (lStart + thread * 4 / tileD_) * D_
g_mxXa_ + aIStart + jL * warpL_ * laneL * DS_ + (lStart + thread * 4 / tileD_) * DS_ + S_pre_
+ i * (D_ / tileD_) + dStart + thread * 4 % tileD_);
}

cp_commit_group();

int offset
= aStart + iL * warpL_ * laneL * D_ + (lStart + thread * 4 / tileD_) * D_ + dStart + thread * 4 % tileD_;
= aOStart + iL * warpL_ * laneL * D_ + (lStart + thread * 4 / tileD_) * D_ + dStart + thread * 4 % tileD_;

if (lStart + (iL + 1) * warpL_ * laneL <= L)
{
@ -755,6 +761,8 @@ void invokeMambaConv1dContext(MambaConv1dParamsBase& params, cudaStream_t stream
int L = params.max_seqlen;
int D = params.dim;
int K = params.dconv;
int S_pre = params.pre_stride;
int S_post = params.post_stride;

int tileL = 32;
int tileD = 128;
@ -763,9 +771,9 @@ void invokeMambaConv1dContext(MambaConv1dParamsBase& params, cudaStream_t stream
int laneD = 4;
int pipe = 4;

void (*f)(int B_, int L_, int D_, input_t* g_mxYa_, input_t* g_mxYs_, input_t const* g_mxXa_,
input_t const* g_mxXs_, input_t const* g_mxW_, input_t const* g_mxB_, bool removePadding_, bool applySilu_,
int const* lastTokenIdsPtr_, int const* stateSlotMappingPtr_);
void (*f)(int B_, int L_, int D_, int S_pre_, int S_post_, input_t* g_mxYa_, input_t* g_mxYs_,
input_t const* g_mxXa_, input_t const* g_mxXs_, input_t const* g_mxW_, input_t const* g_mxB_,
bool removePadding_, bool applySilu_, int const* lastTokenIdsPtr_, int const* stateSlotMappingPtr_);

if (std::is_same_v<input_t, float>)
{
@ -1071,7 +1079,7 @@ void invokeMambaConv1dContext(MambaConv1dParamsBase& params, cudaStream_t stream

cudaFuncSetAttribute(f, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem);

f<<<blks, thds, shmem, stream>>>(B, L, D, ya, ys, xa, xs, w, b, rmpd, silu, ltip, ssmp);
f<<<blks, thds, shmem, stream>>>(B, L, D, S_pre, S_post, ya, ys, xa, xs, w, b, rmpd, silu, ltip, ssmp);
}

template <typename input_t, int DCONV = 4, int CHANNELS_PER_THREAD = 4>
@ -1088,6 +1096,7 @@ __launch_bounds__(64, 8) __global__
int num_channels = params.dim;
int const micro_batch = blockIdx.y;
int const channel = (blockIdx.x * blockDim.x + threadIdx.x) * CHANNELS_PER_THREAD;
int const num_channels_in = num_channels + params.pre_stride + params.post_stride;

if (channel >= num_channels)
{
@ -1096,7 +1105,7 @@ __launch_bounds__(64, 8) __global__

weight += channel;
bias += channel;
input += channel;
input += (channel + params.pre_stride);
output += channel;
state_in += channel;
state_out += channel;
@ -1119,7 +1128,7 @@ __launch_bounds__(64, 8) __global__
++sample)
{
int const slot_idx = params.state_slot_mapping_ptr == nullptr ? sample : params.state_slot_mapping_ptr[sample];
input_t* token_input = input + sample * params.dim;
input_t* token_input = input + sample * num_channels_in;
input_t* token_output = output + sample * params.dim;
input_t* token_state_in = state_in + slot_idx * (params.dconv - 1) * params.dim;
input_t* token_state_out = state_out + slot_idx * (params.dconv - 1) * params.dim;

@ -26,7 +26,7 @@ namespace kernels

struct MambaConv1dParamsBase
{
int batch, dim, max_seqlen, dconv;
int batch, dim, max_seqlen, dconv, pre_stride, post_stride;
bool remove_padding;
bool apply_silu;
void* __restrict__ in_ptr;

@ -370,9 +370,8 @@ void invokeChunkScan(SSMParamsBase& params, cudaStream_t stream)
float const* mxdb = (float const*) params.delta_bias_ptr;
float const* mxA = (float const*) params.A_ptr;
half* mxCB = (half*) params.CB_ptr;
half const* mxBC = (half const*) params.BC_ptr;
float const* mxD = (float const*) params.D_ptr;
half const* mxX = (half const*) params.u_ptr;
half const* mxXBC = (half const*) params.u_ptr;
half const* mxZ = (half const*) params.z_ptr;

auto rp = params.remove_padding;
@ -380,16 +379,16 @@ void invokeChunkScan(SSMParamsBase& params, cudaStream_t stream)
auto ssmp = params.slot_mapping_ptr;

cudaFuncSetAttribute(chunk_cumsum, cudaFuncAttributeMaxDynamicSharedMemorySize, shms[0]);
chunk_cumsum<<<bds[0], tds[0], shms[0], stream>>>(B, L, H, mxdc, mxdA, mxdt, mxdb, mxA, rp, ltip);
chunk_cumsum<<<bds[0], tds[0], shms[0], stream>>>(B, L, H, P, G, N, mxdc, mxdA, mxdt, mxdb, mxA, mxZ, rp, ltip);
cudaFuncSetAttribute(chunk_state, cudaFuncAttributeMaxDynamicSharedMemorySize, shms[1]);
chunk_state<<<bds[1], tds[1], shms[1], stream>>>(B, L, H, P, G, N, mxSt, mxdc, mxdA, mxBC, mxX, rp, ltip);
chunk_state<<<bds[1], tds[1], shms[1], stream>>>(B, L, H, P, G, N, mxSt, mxdc, mxdA, mxXBC, rp, ltip);
cudaFuncSetAttribute(state_passing, cudaFuncAttributeMaxDynamicSharedMemorySize, shms[2]);
state_passing<<<bds[2], tds[2], shms[2], stream>>>(B, L, H, P, N, mxOs, mxFs, mxSt, mxdA, rp, ltip, ssmp);
cudaFuncSetAttribute(bmm_chunk, cudaFuncAttributeMaxDynamicSharedMemorySize, shms[3]);
bmm_chunk<<<bds[3], tds[3], shms[3], stream>>>(B, L, G, N, mxCB, mxBC, rp, ltip);
bmm_chunk<<<bds[3], tds[3], shms[3], stream>>>(B, L, H, P, G, N, mxCB, mxXBC, rp, ltip);
cudaFuncSetAttribute(chunk_scan, cudaFuncAttributeMaxDynamicSharedMemorySize, shms[4]);
chunk_scan<<<bds[4], tds[4], shms[4], stream>>>(
B, L, H, P, G, N, mxY, mxOs, mxdc, mxdA, mxCB, mxBC, mxD, mxX, mxZ, rp, ltip);
B, L, H, P, G, N, mxY, mxOs, mxdc, mxdA, mxCB, mxD, mxXBC, mxZ, rp, ltip);
}
else if constexpr (std::is_same_v<input_t, __nv_bfloat16>)
{
@ -413,9 +412,8 @@ void invokeChunkScan(SSMParamsBase& params, cudaStream_t stream)
float const* mxdb = (float const*) params.delta_bias_ptr;
float const* mxA = (float const*) params.A_ptr;
__nv_bfloat16* mxCB = (__nv_bfloat16*) params.CB_ptr;
__nv_bfloat16 const* mxBC = (__nv_bfloat16 const*) params.BC_ptr;
float const* mxD = (float const*) params.D_ptr;
__nv_bfloat16 const* mxX = (__nv_bfloat16 const*) params.u_ptr;
__nv_bfloat16 const* mxXBC = (__nv_bfloat16 const*) params.u_ptr;
__nv_bfloat16 const* mxZ = (__nv_bfloat16 const*) params.z_ptr;

auto rp = params.remove_padding;
@ -423,16 +421,16 @@ void invokeChunkScan(SSMParamsBase& params, cudaStream_t stream)
auto ssmp = params.slot_mapping_ptr;

cudaFuncSetAttribute(chunk_cumsum, cudaFuncAttributeMaxDynamicSharedMemorySize, shms[0]);
chunk_cumsum<<<bds[0], tds[0], shms[0], stream>>>(B, L, H, mxdc, mxdA, mxdt, mxdb, mxA, rp, ltip);
chunk_cumsum<<<bds[0], tds[0], shms[0], stream>>>(B, L, H, P, G, N, mxdc, mxdA, mxdt, mxdb, mxA, mxZ, rp, ltip);
cudaFuncSetAttribute(chunk_state, cudaFuncAttributeMaxDynamicSharedMemorySize, shms[1]);
chunk_state<<<bds[1], tds[1], shms[1], stream>>>(B, L, H, P, G, N, mxSt, mxdc, mxdA, mxBC, mxX, rp, ltip);
chunk_state<<<bds[1], tds[1], shms[1], stream>>>(B, L, H, P, G, N, mxSt, mxdc, mxdA, mxXBC, rp, ltip);
cudaFuncSetAttribute(state_passing, cudaFuncAttributeMaxDynamicSharedMemorySize, shms[2]);
state_passing<<<bds[2], tds[2], shms[2], stream>>>(B, L, H, P, N, mxOs, mxFs, mxSt, mxdA, rp, ltip, ssmp);
cudaFuncSetAttribute(bmm_chunk, cudaFuncAttributeMaxDynamicSharedMemorySize, shms[3]);
bmm_chunk<<<bds[3], tds[3], shms[3], stream>>>(B, L, G, N, mxCB, mxBC, rp, ltip);
bmm_chunk<<<bds[3], tds[3], shms[3], stream>>>(B, L, H, P, G, N, mxCB, mxXBC, rp, ltip);
cudaFuncSetAttribute(chunk_scan, cudaFuncAttributeMaxDynamicSharedMemorySize, shms[4]);
chunk_scan<<<bds[4], tds[4], shms[4], stream>>>(
B, L, H, P, G, N, mxY, mxOs, mxdc, mxdA, mxCB, mxBC, mxD, mxX, mxZ, rp, ltip);
B, L, H, P, G, N, mxY, mxOs, mxdc, mxdA, mxCB, mxD, mxXBC, mxZ, rp, ltip);
}
}

@ -458,7 +456,8 @@ INSTANTIATE_CHUNK_SCAN_DATA_TYPE(__nv_bfloat16, float);

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename input_t, typename weight_t, int DSTATE = 16, int CHANNELS_PER_BLOCK = 128, bool MAMBA_V1 = true>
template <typename input_t, typename weight_t, int DSTATE = 16, int CHANNELS_PER_BLOCK = 128, bool MAMBA_V1 = true,
int STATE_UNROLL = 16>
__launch_bounds__(128, 2) __global__ void selective_scan_update_kernel(SSMParamsBase params)
{

@ -485,24 +484,43 @@ __launch_bounds__(128, 2) __global__ void selective_scan_update_kernel(SSMParams
int const head = channel / head_dim;
int const head_chl = channel % head_dim;
int const group = head / (nheads / ngroups);

int const slot_idx = params.slot_mapping_ptr == nullptr ? sample : params.slot_mapping_ptr[sample];
int const bc_offset = MAMBA_V1 ? sample * (DSTATE * 2 + params.dt_rank) : sample * DSTATE * ngroups * 2;
int const b_offset = MAMBA_V1 ? params.dt_rank : DSTATE * group;
int const c_offset = MAMBA_V1 ? params.dt_rank + DSTATE : DSTATE * (ngroups + group);
int const dt_d_idx = MAMBA_V1 ? channel : head;
int const bc_dim = MAMBA_V1 ? 2 * DSTATE : 2 * ngroups * DSTATE;
int const x_dim = MAMBA_V1 ? num_channels : num_channels + bc_dim;
int const z_dim = MAMBA_V1 ? num_channels : 2 * num_channels + bc_dim + nheads;
int const dt_dim = MAMBA_V1 ? num_channels : (z ? z_dim : z_dim - num_channels);
int const dt_offset = MAMBA_V1 ? sample * dt_dim : sample * dt_dim + dt_dim - nheads;
int const bc_offset = MAMBA_V1 ? sample * (bc_dim + params.dt_rank) : sample * (num_channels + bc_dim);
int const b_offset = MAMBA_V1 ? params.dt_rank : num_channels + DSTATE * group;
int const c_offset = MAMBA_V1 ? params.dt_rank + DSTATE : num_channels + DSTATE * (ngroups + group);

input_t* my_state = &state[slot_idx * num_channels * DSTATE];
input_t* my_output = &output[sample * num_channels];

float rA[DSTATE];
float rB[DSTATE];
float rC[DSTATE];
float rState[DSTATE];
float my_x, my_dt, my_z, my_dt_bias, my_D;
my_x = toFloat(x[sample * num_channels + channel]);
my_z = z ? toFloat(z[sample * num_channels + channel]) : 0.f;
int const state_loops = (DSTATE + STATE_UNROLL - 1) / STATE_UNROLL;

float my_x, my_dt, my_z, my_dt_bias, out;
my_x = toFloat(x[sample * x_dim + channel]);
my_z = z ? toFloat(z[sample * z_dim + channel]) : 0.f;
my_dt = toFloat(dt[dt_offset + dt_d_idx]);
my_dt_bias = dt_bias ? toFloat(dt_bias[dt_d_idx]) : 0.f;
out = D ? toFloat(D[dt_d_idx]) * my_x : 0.f;

float dt_b = my_dt + my_dt_bias;
float dt_b_sp = 1.0f;
if (dt_softplus)
{
dt_b_sp = dt_b <= 20.f ? __logf(1.f + __expf(dt_b)) : dt_b; // softplus
}

if (MAMBA_V1)
{
float rA[DSTATE];
float rB[DSTATE];
float rC[DSTATE];
float rState[DSTATE];
#pragma unroll
for (int i = 0; i < DSTATE; i++)
{
@ -511,49 +529,48 @@ __launch_bounds__(128, 2) __global__ void selective_scan_update_kernel(SSMParams
rC[i] = toFloat(C[bc_offset + c_offset + i]);
rState[i] = toFloat(my_state[i * num_channels + channel]);
}
my_dt = toFloat(dt[sample * num_channels + channel]);
my_dt_bias = dt_bias ? toFloat(dt_bias[channel]) : 0.f;
my_D = D ? toFloat(D[channel]) : 0.f;
#pragma unroll
for (int i = 0; i < DSTATE; i++)
{
float dA = __expf(rA[i] * dt_b_sp);
float dB = rB[i] * dt_b_sp;
float sdA = rState[i] * dA;
float dBx = dB * my_x;
float newState = sdA + dBx;
// Write the new state back out to the cache
convertAndStore(&my_state[i * num_channels + channel], newState);
out += newState * rC[i];
}
}
else
{
float A_tmp = toFloat(A[head]);
#pragma unroll
for (int i = 0; i < DSTATE; i++)
float rB[STATE_UNROLL];
float rC[STATE_UNROLL];
float rState[STATE_UNROLL];
for (int si = 0; si < state_loops; si++)
{
rA[i] = A_tmp;
rB[i] = toFloat(B[bc_offset + b_offset + i]);
rC[i] = toFloat(C[bc_offset + c_offset + i]);
rState[i] = toFloat(my_state[(head * DSTATE + i) * head_dim + head_chl]);
}
my_dt = toFloat(dt[sample * nheads + head]);
my_dt_bias = dt_bias ? toFloat(dt_bias[head]) : 0.f;
my_D = D ? toFloat(D[head]) : 0.f;
}

float dt_b = my_dt + my_dt_bias;
float dt_b_sp;
if (dt_softplus)
{
dt_b_sp = dt_b <= 20.f ? __logf(1.f + __expf(dt_b)) : dt_b; // softplus
}

float out = D ? my_D * my_x : 0.f;

int i_offset = si * STATE_UNROLL;
#pragma unroll
for (int i = 0; i < DSTATE; i++)
{
float dA = __expf(rA[i] * dt_b_sp);
float dB = rB[i] * dt_b_sp;
float sdA = rState[i] * dA;
float dBx = dB * my_x;
float newState = sdA + dBx;
// Write the new state back out to the cache
if (MAMBA_V1)
convertAndStore(&my_state[i * num_channels + channel], newState);
else
convertAndStore(&my_state[(head * DSTATE + i) * head_dim + head_chl], newState);
out += newState * rC[i];
for (int i = 0; i < STATE_UNROLL; i++)
{
rB[i] = toFloat(B[bc_offset + b_offset + i_offset + i]);
rC[i] = toFloat(C[bc_offset + c_offset + i_offset + i]);
rState[i] = toFloat(my_state[(head * DSTATE + i_offset + i) * head_dim + head_chl]);
}
#pragma unroll
for (int i = 0; i < STATE_UNROLL; i++)
{
float dA = __expf(A_tmp * dt_b_sp);
float dB = rB[i] * dt_b_sp;
float sdA = rState[i] * dA;
float dBx = dB * my_x;
float newState = sdA + dBx;
// Write the new state back out to the cache
convertAndStore(&my_state[(head * DSTATE + i_offset + i) * head_dim + head_chl], newState);
out += newState * rC[i];
}
}
}

if (z)

@ -477,10 +477,10 @@ __global__ void packExplicitDraftTokens(PackExplicitDraftTokensParams<T> params)
params.outputTemperatures[batchIdx] = params.inputTemperatures[batchSlot];
}

// Copy random validation data.
auto const numDecodingDraftTokens = params.numPaths * (params.maxPathLength - 1);
if (isGenerationRequest)
{
// Copy random validation data.
auto const numDecodingDraftTokens = params.numPaths * (params.maxPathLength - 1);
auto outputRandomDataValidation = params.outputRandomDataValidation + genIdx * numDecodingDraftTokens;
auto const inputRandomDataValidation = params.inputRandomDataValidation + batchSlot * numDecodingDraftTokens;
for (auto ti = static_cast<SizeType32>(threadIdx.x); ti < numDecodingDraftTokens;
@ -488,11 +488,8 @@ __global__ void packExplicitDraftTokens(PackExplicitDraftTokensParams<T> params)
{
outputRandomDataValidation[ti] = inputRandomDataValidation[ti];
}
}

// Copy draft tokens and indices
if (isGenerationRequest)
{
// Copy draft tokens and indices
auto const numUnpackedTokens = numDecodingDraftTokens + params.numPaths;
auto outputNextDraftTokens = params.outputNextDraftTokens + genIdx * numUnpackedTokens;
auto outputNextDraftIndices = params.outputNextDraftIndices + genIdx * numUnpackedTokens;
@ -504,41 +501,37 @@ __global__ void packExplicitDraftTokens(PackExplicitDraftTokensParams<T> params)
outputNextDraftTokens[ti] = inputNextDraftTokens[ti];
outputNextDraftIndices[ti] = inputNextDraftIndices[ti];
}
}

auto const maxGenerationLength = params.maxGenerationLength[0];
auto const maxDecodingTokens = numDecodingDraftTokens + 1;
auto const numPackedMasks = divUp(maxGenerationLength, 32);
auto const outputMaskStartId = (genIdx == 0) ? 0 : params.cumSumGenerationLengths[genIdx - 1];
auto const numTokens = (genIdx == 0)
? params.cumSumGenerationLengths[0]
: params.cumSumGenerationLengths[genIdx] - params.cumSumGenerationLengths[genIdx - 1];
// Copy packed masks.
// Masks are placed next to each other with offsets of cumSumGenerationLengths[bi-1]
if (isGenerationRequest)
{
auto const maxGenerationLength = params.maxGenerationLength[0];
auto const maxDecodingTokens = numDecodingDraftTokens + 1;
auto const numPackedMasks = divUp(maxGenerationLength, 32);
auto const outputStartId = (genIdx == 0) ? 0 : params.cumSumGenerationLengths[genIdx - 1];
auto const numTokens = (genIdx == 0)
? params.cumSumGenerationLengths[0]
: params.cumSumGenerationLengths[genIdx] - params.cumSumGenerationLengths[genIdx - 1];
// Copy packed masks.
// Masks are placed next to each other with offsets of cumSumGenerationLengths[bi-1]
auto const inputPackedMask = params.inputPackedMask + batchSlot * numPackedMasks * maxDecodingTokens;
auto outputPackedMask = params.outputPackedMask + outputMaskStartId * numPackedMasks;
auto outputPackedMask = params.outputPackedMask + outputStartId * numPackedMasks;
for (auto ti = static_cast<SizeType32>(threadIdx.x); ti < numTokens * numPackedMasks;
ti += static_cast<SizeType32>(blockDim.x))
{
outputPackedMask[ti] = inputPackedMask[ti];
}
}
auto const inputPositionIds = params.inputPositionIds + batchSlot * maxDecodingTokens;
auto outputPositionIds = params.outputPositionIds + params.numContextTokens + outputStartId;
for (auto ti = static_cast<SizeType32>(threadIdx.x); ti < numTokens; ti += static_cast<SizeType32>(blockDim.x))
{
outputPositionIds[ti] = inputPositionIds[ti];
}

// Copy pos offsets. Copy only for maxGenerationLength
if (isGenerationRequest)
{
// Copy pos offsets. Copy only for maxGenerationLength
auto const basePosId = params.outputPositionIdsBase[batchIdx];
auto outputPositionOffsets = params.outputPositionOffsets + genIdx * maxGenerationLength;
auto outputPositionIds = params.outputPositionIds + genIdx * maxGenerationLength;
auto const inputPositionIds = params.inputPositionIds + batchSlot * maxDecodingTokens;
for (auto ti = static_cast<SizeType32>(threadIdx.x); ti < maxGenerationLength;
ti += static_cast<SizeType32>(blockDim.x))
{
auto const posId = inputPositionIds[ti];
outputPositionIds[params.numContextTokens + ti] = posId;
outputPositionOffsets[ti] = posId - basePosId + 1;
outputPositionOffsets[ti] = inputPositionIds[ti] - basePosId + 1;
}
}
}

@ -148,8 +148,8 @@ void BeamSearchLayer<T>::forwardAsync(
bh.nMaxBatchSize = static_cast<std::int32_t>(op->outputIdsPtr.shape[0]);
bh.nBatchSize = ip->localBatchSize;
bh.batchSlots = ip->batchSlots ? ip->batchSlots->template getPtr<SizeType32 const>() : nullptr;
bh.nBeamWidth = static_cast<std::int32_t>(op->outputIdsPtr.shape[1]);
bh.nMaxSeqLen = static_cast<std::int32_t>(op->outputIdsPtr.shape[2]);
bh.nBeamWidth = static_cast<std::int32_t>(op->outputIds.shape[1]);
bh.nMaxSeqLen = static_cast<std::int32_t>(op->outputIds.shape[2]);
bh.nVocabSize = mVocabSizePadded;
bh.diversityRates = mDiversityRateDevice;
bh.lengthPenalties = mLengthPenaltyDevice;

@ -548,6 +548,8 @@ public:
tc::Tensor temperatures; // [maxBatchSize] on gpu
//! Next generation lengths.
tc::Tensor generationLengths; // [maxBatchSize] on gpu
//! Next generation lengths on host.
tc::Tensor generationLengthsHost; // [maxBatchSize] on pinned
//! Maximum number of generated tokens for the next step across whole batch
tc::Tensor maxGenLengthHost; // [1] on pinned
};

@ -245,14 +245,10 @@ void DynamicDecodeLayer<T>::prepareIdsPtrs(std::shared_ptr<BaseDecodingOutputs>
}
}

outputs->outputIdsPtr = Tensor(MEMORY_GPU, DataType::TYPE_INT32_PTR,
{static_cast<size_t>(mDecoderDomain.getBatchSize()), static_cast<size_t>(beamWidth),
static_cast<size_t>(maxSeqLen)},
idsPtrHost);
outputs->outputIdsPtr = Tensor(
MEMORY_GPU, DataType::TYPE_INT32_PTR, {static_cast<size_t>(mDecoderDomain.getBatchSize())}, idsPtrHost);
outputs->parentIdsPtr = Tensor(MEMORY_GPU, DataType::TYPE_INT32_PTR,
{static_cast<size_t>(mDecoderDomain.getBatchSize()), static_cast<size_t>(beamWidth),
static_cast<size_t>(maxSeqLen)},
idsPtrHost + mDecoderDomain.getBatchSize());
{static_cast<size_t>(mDecoderDomain.getBatchSize())}, idsPtrHost + mDecoderDomain.getBatchSize());
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

@ -269,13 +269,6 @@ void ExplicitDraftTokensLayer<T>::splitInputDataToBatchSlots(
|
||||
|
||||
params.checkParams();
|
||||
|
||||
// Copy max generation length
|
||||
cudaMemcpyAsync(outputs.maxGenLengthHost.template getPtr<SizeType32>(),
|
||||
inputs.maxGenLengthDevice.template getPtr<SizeType32 const>(), sizeof(SizeType32), cudaMemcpyDeviceToHost,
|
||||
mStream);
|
||||
|
||||
params.checkParams();
|
||||
|
||||
// Copy max generation length
|
||||
cudaMemcpyAsync(outputs.maxGenLengthHost.template getPtr<SizeType32>(),
|
||||
inputs.maxGenLengthDevice.template getPtr<SizeType32 const>(), sizeof(SizeType32), cudaMemcpyDeviceToHost,
|
||||
@ -285,6 +278,11 @@ void ExplicitDraftTokensLayer<T>::splitInputDataToBatchSlots(
|
||||
|
||||
invokeCopyProbs(params, mStream);
|
||||
|
||||
// Copy max generation length
|
||||
cudaMemcpyAsync(outputs.generationLengthsHost.template getPtr<SizeType32>(),
|
||||
outputs.generationLengths.template getPtr<SizeType32 const>(),
|
||||
sizeof(SizeType32) * mDecoderDomain.getBatchSize(), cudaMemcpyDeviceToHost, mStream);
|
||||
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
}
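Note: the copies added above stage the per-batch generation lengths and the max generation length from device to host. A minimal sketch of that pattern, assuming the destination buffer is pinned; the names are illustrative and this is not the layer code itself:

```cpp
#include <cstdint>
#include <cuda_runtime_api.h>

void copyGenerationLengthsToHost(
    int32_t* generationLengthsHost, int32_t const* generationLengthsDevice, int32_t batchSize, cudaStream_t stream)
{
    // The async D2H copy only behaves asynchronously if generationLengthsHost is page-locked (pinned) memory.
    cudaMemcpyAsync(generationLengthsHost, generationLengthsDevice, sizeof(int32_t) * batchSize,
        cudaMemcpyDeviceToHost, stream);
    // The caller must synchronize the stream (or record/wait on an event) before reading the host values.
}
```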
|
||||
|
||||
|
||||
@ -242,6 +242,9 @@ bool GPTAttentionPluginCommon::convertMMHAParamsToXQAParams(tensorrt_llm::kernel
|
||||
xqaParams.spec_decoding_packed_mask = generationsParams.spec_decoding_packed_mask;
|
||||
xqaParams.spec_decoding_position_offsets = generationsParams.spec_decoding_position_offsets;
|
||||
xqaParams.spec_decoding_generation_lengths = generationsParams.spec_decoding_generation_lengths;
|
||||
xqaParams.spec_decoding_is_generation_length_variable
|
||||
= generationsParams.spec_decoding_is_generation_length_variable;
|
||||
xqaParams.spec_decoding_max_generation_length = generationsParams.spec_decoding_max_generation_length;
|
||||
|
||||
xqaParams.total_num_input_tokens = generationsParams.total_num_input_tokens;
|
||||
xqaParams.fp8_out_scale = (mFP8ContextFMHA ? generationsParams.attention_output_orig_quant : nullptr);
|
||||
@ -399,7 +402,8 @@ GPTAttentionPluginCommon::GPTAttentionPluginCommon(int layer_idx, int num_heads,
|
||||
tensorrt_llm::kernels::BlockSparseParams block_sparse_params, bool paged_kv_cache, int tokens_per_block,
|
||||
nvinfer1::DataType type, int32_t max_context_length, bool qkv_bias_enabled, bool cross_attention, int max_distance,
|
||||
bool pos_shift_enabled, bool dense_context_fmha, bool use_paged_context_fmha, bool use_fp8_context_fmha,
|
||||
bool use_cache, bool is_spec_decoding_enabled)
|
||||
bool use_cache, bool is_spec_decoding_enabled, bool spec_decoding_is_generation_length_variable,
|
||||
int32_t spec_decoding_max_generation_length)
|
||||
: mLayerIdx(layer_idx)
|
||||
, mNumHeads(num_heads)
|
||||
, mVisionStart(vision_start)
|
||||
@ -443,6 +447,8 @@ GPTAttentionPluginCommon::GPTAttentionPluginCommon(int layer_idx, int num_heads,
|
||||
, mFP8ContextFMHA(use_fp8_context_fmha)
|
||||
, mUseKVCache(use_cache)
|
||||
, mIsSpecDecodingEnabled(is_spec_decoding_enabled)
|
||||
, mSpecDecodingIsGenerationLengthVariable(spec_decoding_is_generation_length_variable)
|
||||
, mSpecDecodingMaxGenerationLength(spec_decoding_max_generation_length)
|
||||
, mDriver(CUDADriverWrapper::getInstance())
|
||||
{
|
||||
// Pre-check whether FMHA is supported in order to save memory allocation.
|
||||
@ -559,6 +565,8 @@ GPTAttentionPluginCommon::GPTAttentionPluginCommon(void const* data, size_t leng
|
||||
read(d, mFP8ContextFMHA);
|
||||
read(d, mUseKVCache);
|
||||
read(d, mIsSpecDecodingEnabled);
|
||||
read(d, mSpecDecodingIsGenerationLengthVariable);
|
||||
read(d, mSpecDecodingMaxGenerationLength);
|
||||
read(d, mNbMultiBlockSemaphores);
|
||||
|
||||
mKVCacheQuantMode = tc::QuantMode(kvCacheQuantMode);
|
||||
@ -1655,7 +1663,8 @@ size_t GPTAttentionPluginCommon::getCommonSerializationSize() const noexcept
|
||||
+ sizeof(mTokensPerBlock) + sizeof(mType) + sizeof(mMaxContextLength) + sizeof(mQKVBiasEnabled)
|
||||
+ sizeof(mCrossAttention) + sizeof(mMaxDistance) + sizeof(mPosShiftEnabled) + sizeof(mDenseContextFMHA)
|
||||
+ sizeof(mPagedContextFMHA) + sizeof(mFP8ContextFMHA) + sizeof(mUseKVCache) + sizeof(mUnfuseQkvGemm)
|
||||
+ sizeof(mIsSpecDecodingEnabled) + sizeof(mNbMultiBlockSemaphores)
|
||||
+ sizeof(mIsSpecDecodingEnabled) + sizeof(mSpecDecodingIsGenerationLengthVariable)
|
||||
+ sizeof(mSpecDecodingMaxGenerationLength) + sizeof(mNbMultiBlockSemaphores)
|
||||
+ sizeof(uint32_t) // size of DecoderXQARunnerResource buffer.
|
||||
+ DecoderXQARunner::getResourceGlobal()->getSerializationSize();
|
||||
}
|
||||
@ -1705,6 +1714,8 @@ void GPTAttentionPluginCommon::serializeCommon(void* buffer) const noexcept
|
||||
write(d, mFP8ContextFMHA);
|
||||
write(d, mUseKVCache);
|
||||
write(d, mIsSpecDecodingEnabled);
|
||||
write(d, mSpecDecodingIsGenerationLengthVariable);
|
||||
write(d, mSpecDecodingMaxGenerationLength);
|
||||
write(d, mNbMultiBlockSemaphores);
|
||||
|
||||
// An uint32_t that specifies the size of the serialized buffer, followed by the actual content.
|
||||
@ -1797,6 +1808,10 @@ GPTAttentionPluginCreatorCommon::GPTAttentionPluginCreatorCommon()
|
||||
mPluginAttributes.emplace_back(PluginField("use_fp8_context_fmha", nullptr, PluginFieldType::kINT8, 0));
|
||||
mPluginAttributes.emplace_back(PluginField("use_cache", nullptr, PluginFieldType::kINT32, 0));
|
||||
mPluginAttributes.emplace_back(PluginField("is_spec_decoding_enabled", nullptr, PluginFieldType::kINT8, 0));
|
||||
mPluginAttributes.emplace_back(
|
||||
PluginField("spec_decoding_is_generation_length_variable", nullptr, PluginFieldType::kINT8, 0));
|
||||
mPluginAttributes.emplace_back(
|
||||
PluginField("spec_decoding_max_generation_length", nullptr, PluginFieldType::kINT32, 0));
|
||||
mFC.nbFields = mPluginAttributes.size();
|
||||
mFC.fields = mPluginAttributes.data();
|
||||
}
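Note: the two new members (mSpecDecodingIsGenerationLengthVariable, mSpecDecodingMaxGenerationLength) have to appear in lock-step in the size estimate, in serializeCommon, and in the deserializing constructor above. A generic sketch of that write/read pattern for trivially-copyable members, assuming helpers along these lines (not the actual plugin helpers):

```cpp
#include <cstring>

template <typename T>
void writeField(char*& buffer, T const& value)
{
    std::memcpy(buffer, &value, sizeof(T));
    buffer += sizeof(T);
}

template <typename T>
void readField(char const*& buffer, T& value)
{
    std::memcpy(&value, buffer, sizeof(T));
    buffer += sizeof(T);
}

// getSerializationSize adds sizeof(member), serialize calls writeField, and the
// deserializing constructor calls readField for every member, in the same order on both sides.
```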
|
||||
|
||||
@ -52,7 +52,8 @@ public:
|
||||
nvinfer1::DataType type, int32_t max_context_length, bool qkv_bias_enabled, bool cross_attention = false,
|
||||
int max_distance = 0, bool pos_shift_enabled = false, bool dense_context_fmha = false,
|
||||
bool use_paged_context_fmha = false, bool use_fp8_context_fmha = false, bool use_cache = true,
|
||||
bool is_spec_decoding_enabled = false);
|
||||
bool is_spec_decoding_enabled = false, bool spec_decoding_is_generation_length_variable = false,
|
||||
int32_t spec_decoding_max_generation_length = 1);
|
||||
|
||||
GPTAttentionPluginCommon(void const* data, size_t length);
|
||||
|
||||
@ -224,6 +225,8 @@ protected:
|
||||
int32_t const* spec_decoding_packed_mask = nullptr;
|
||||
int32_t const* spec_decoding_position_offsets = nullptr;
|
||||
int32_t const* spec_decoding_generation_lengths = nullptr;
|
||||
bool spec_decoding_is_generation_length_variable = false;
|
||||
int32_t spec_decoding_max_generation_length = 1;
|
||||
int32_t total_num_input_tokens;
|
||||
};
|
||||
|
||||
@ -324,6 +327,8 @@ protected:
|
||||
bool mFP8ContextFMHA = false;
|
||||
bool mDenseContextFMHA = false;
|
||||
bool mIsSpecDecodingEnabled = false;
|
||||
bool mSpecDecodingIsGenerationLengthVariable = false;
|
||||
int32_t mSpecDecodingMaxGenerationLength = 1;
|
||||
|
||||
// Speculative decoding packed mask.
|
||||
uint4* mSpecDecodingPackedMask;
|
||||
|
||||
@ -55,7 +55,8 @@ GPTAttentionPlugin::GPTAttentionPlugin(int layer_idx, int num_heads, int vision_
|
||||
tensorrt_llm::kernels::BlockSparseParams block_sparse_params, bool paged_kv_cache, int tokens_per_block,
|
||||
nvinfer1::DataType type, int32_t max_context_length, bool qkv_bias_enabled, bool cross_attention, int max_distance,
|
||||
bool pos_shift_enabled, bool dense_context_fmha, bool use_paged_context_fmha, bool use_fp8_context_fmha,
|
||||
bool use_cache, bool is_spec_decoding_enabled)
|
||||
bool use_cache, bool is_spec_decoding_enabled, bool spec_decoding_is_generation_length_variable,
|
||||
int spec_decoding_max_generation_length)
|
||||
: GPTAttentionPluginCommon(layer_idx, num_heads, vision_start, vision_length, num_kv_heads, head_size,
|
||||
unidirectional, q_scaling, qk_tanh_scale, position_embedding_type, rotary_embedding_dim, rotary_embedding_base,
|
||||
rotary_embedding_scale_type, rotary_embedding_scale, rotary_embedding_short_m_scale,
|
||||
@ -63,7 +64,8 @@ GPTAttentionPlugin::GPTAttentionPlugin(int layer_idx, int num_heads, int vision_
|
||||
tp_rank, unfuse_qkv_gemm, context_fmha_type, multi_block_mode, enable_xqa, kv_cache_quant_mode,
|
||||
remove_input_padding, mask_type, block_sparse_params, paged_kv_cache, tokens_per_block, type,
|
||||
max_context_length, qkv_bias_enabled, cross_attention, max_distance, pos_shift_enabled, dense_context_fmha,
|
||||
use_paged_context_fmha, use_fp8_context_fmha, use_cache, is_spec_decoding_enabled)
|
||||
use_paged_context_fmha, use_fp8_context_fmha, use_cache, is_spec_decoding_enabled,
|
||||
spec_decoding_is_generation_length_variable, spec_decoding_max_generation_length)
|
||||
{
|
||||
initEntryIdx();
|
||||
}
|
||||
@ -789,6 +791,8 @@ int GPTAttentionPlugin::enqueueSome(int32_t seqIdxBeg, int32_t localNbSeq, int32
|
||||
enqueue_params.spec_decoding_packed_mask = spec_decoding_packed_mask;
|
||||
enqueue_params.spec_decoding_position_offsets = spec_decoding_position_offsets;
|
||||
enqueue_params.spec_decoding_generation_lengths = spec_decoding_generation_lengths;
|
||||
enqueue_params.spec_decoding_is_generation_length_variable = mSpecDecodingIsGenerationLengthVariable;
|
||||
enqueue_params.spec_decoding_max_generation_length = mSpecDecodingMaxGenerationLength;
|
||||
}
|
||||
enqueue_params.total_num_input_tokens = localNbTokens;
|
||||
|
||||
@ -975,7 +979,9 @@ IPluginV2* GPTAttentionPluginCreator::createPlugin(char const* name, PluginField
|
||||
static_cast<bool>(p.getScalar<int8_t>("use_paged_context_fmha").value()),
|
||||
static_cast<bool>(p.getScalar<int8_t>("use_fp8_context_fmha").value()),
|
||||
static_cast<bool>(p.getScalar<int32_t>("use_cache").value()),
|
||||
static_cast<bool>(p.getScalar<int8_t>("is_spec_decoding_enabled").value()));
|
||||
static_cast<bool>(p.getScalar<int8_t>("is_spec_decoding_enabled").value()),
|
||||
static_cast<bool>(p.getScalar<int8_t>("spec_decoding_is_generation_length_variable").value()),
|
||||
p.getScalar<int32_t>("spec_decoding_max_generation_length").value());
|
||||
obj->setPluginNamespace(mNamespace.c_str());
|
||||
return obj;
|
||||
}
|
||||
|
||||
@ -94,7 +94,8 @@ public:
|
||||
nvinfer1::DataType type, int32_t max_context_length, bool qkv_bias_enabled, bool cross_attention = false,
|
||||
int max_distance = 0, bool pos_shift_enabled = false, bool dense_context_fmha = false,
|
||||
bool use_paged_context_fmha = false, bool use_fp8_context_fmha = false, bool use_cache = true,
|
||||
bool is_spec_decoding_enabled = false);
|
||||
bool is_spec_decoding_enabled = false, bool spec_decoding_is_generation_length_variable = false,
|
||||
int spec_decoding_max_generation_length = 1);
|
||||
|
||||
GPTAttentionPlugin(void const* data, size_t length);
|
||||
|
||||
|
||||
@ -31,10 +31,12 @@ static char const* MAMBA_CONV1D_PLUGIN_NAME{"MambaConv1d"};
|
||||
PluginFieldCollection MambaConv1dPluginCreator::mFC{};
|
||||
std::vector<nvinfer1::PluginField> MambaConv1dPluginCreator::mPluginAttributes;
|
||||
|
||||
MambaConv1dPlugin::MambaConv1dPlugin(
|
||||
int dim, int dconv, nvinfer1::DataType type, bool removePadding, bool pagedState, bool applySilu)
|
||||
MambaConv1dPlugin::MambaConv1dPlugin(int dim, int dconv, int preStride, int postStride, nvinfer1::DataType type,
|
||||
bool removePadding, bool pagedState, bool applySilu)
|
||||
: mDim(dim)
|
||||
, mDConv(dconv)
|
||||
, mPreStride(preStride)
|
||||
, mPostStride(postStride)
|
||||
, mType(type)
|
||||
, mRemovePadding(removePadding)
|
||||
, mPagedState(pagedState)
|
||||
@ -52,6 +54,8 @@ MambaConv1dPlugin::MambaConv1dPlugin(void const* data, size_t length)
|
||||
char const *d = reinterpret_cast<char const*>(data), *a = d;
|
||||
read(d, mDim);
|
||||
read(d, mDConv);
|
||||
read(d, mPreStride);
|
||||
read(d, mPostStride);
|
||||
read(d, mType);
|
||||
read(d, mRemovePadding);
|
||||
read(d, mPagedState);
|
||||
@ -65,7 +69,8 @@ MambaConv1dPlugin::MambaConv1dPlugin(void const* data, size_t length)
|
||||
// IPluginV2DynamicExt Methods
|
||||
nvinfer1::IPluginV2DynamicExt* MambaConv1dPlugin::clone() const noexcept
|
||||
{
|
||||
auto* plugin = new MambaConv1dPlugin(mDim, mDConv, mType, mRemovePadding, mPagedState, mApplySilu);
|
||||
auto* plugin
|
||||
= new MambaConv1dPlugin(mDim, mDConv, mPreStride, mPostStride, mType, mRemovePadding, mPagedState, mApplySilu);
|
||||
plugin->setPluginNamespace(mNamespace.c_str());
|
||||
return plugin;
|
||||
}
|
||||
@ -78,7 +83,9 @@ nvinfer1::DimsExprs MambaConv1dPlugin::getOutputDimensions(
|
||||
{
|
||||
if (outputIndex == 0)
|
||||
{
|
||||
return inputs[getInputTensorIdx()];
|
||||
auto ret = inputs[getInputTensorIdx()];
|
||||
ret.d[mRemovePadding ? 1 : 2] = exprBuilder.constant(mDim);
|
||||
return ret;
|
||||
}
|
||||
return inputs[getConvStateIdx()];
|
||||
}
|
||||
@ -113,9 +120,9 @@ size_t MambaConv1dPlugin::getWorkspaceSize(nvinfer1::PluginTensorDesc const* inp
|
||||
}
|
||||
|
||||
void MambaConv1dPlugin::setMambaConv1dParams(tensorrt_llm::kernels::MambaConv1dParamsBase& params, const size_t batch,
|
||||
const size_t dim, const size_t maxSeqLen, const size_t dconv, void const* inPtr, void const* stateInPtr,
|
||||
void* stateOutPtr, void const* convWeight, void const* convBias, void* outPtr, int const* lastTokenIds,
|
||||
int const* stateSlotMapping, bool removePadding, bool applySilu)
|
||||
const size_t dim, const size_t maxSeqLen, const size_t dconv, const size_t preStride, const size_t postStride,
|
||||
void const* inPtr, void const* stateInPtr, void* stateOutPtr, void const* convWeight, void const* convBias,
|
||||
void* outPtr, int const* lastTokenIds, int const* stateSlotMapping, bool removePadding, bool applySilu)
|
||||
{
|
||||
// Reset the parameters
|
||||
memset(¶ms, 0, sizeof(params));
|
||||
@ -124,6 +131,8 @@ void MambaConv1dPlugin::setMambaConv1dParams(tensorrt_llm::kernels::MambaConv1dP
|
||||
params.dim = dim;
|
||||
params.max_seqlen = maxSeqLen;
|
||||
params.dconv = dconv;
|
||||
params.pre_stride = preStride;
|
||||
params.post_stride = postStride;
|
||||
|
||||
params.remove_padding = removePadding;
|
||||
params.apply_silu = applySilu;
|
||||
@ -179,8 +188,8 @@ int MambaConv1dPlugin::enqueueImpl(nvinfer1::PluginTensorDesc const* inputDesc,
|
||||
void* stateOutPtr
|
||||
= mPagedState ? *reinterpret_cast<void**>(const_cast<void*>(inputs[getConvStateIdx()])) : outputs[1];
|
||||
|
||||
setMambaConv1dParams(mambaConv1dParams, batchSize, mDim, maxSeqLen, mDConv, inputs[getInputTensorIdx()], stateInPtr,
|
||||
stateOutPtr, inputs[getWeightIdx()], inputs[getBiasIdx()], outputs[0],
|
||||
setMambaConv1dParams(mambaConv1dParams, batchSize, mDim, maxSeqLen, mDConv, mPreStride, mPostStride,
|
||||
inputs[getInputTensorIdx()], stateInPtr, stateOutPtr, inputs[getWeightIdx()], inputs[getBiasIdx()], outputs[0],
|
||||
static_cast<int const*>(inputs[getLastTokenIdsIdx()]), slotMapping, mRemovePadding, mApplySilu);
|
||||
|
||||
if (reqTypes[0] == RequestType::kCONTEXT)
|
||||
@ -252,8 +261,8 @@ void MambaConv1dPlugin::terminate() noexcept {}
|
||||
|
||||
size_t MambaConv1dPlugin::getSerializationSize() const noexcept
|
||||
{
|
||||
return sizeof(mDim) + sizeof(mDConv) + sizeof(mType) + sizeof(mRemovePadding) + sizeof(mPagedState)
|
||||
+ sizeof(mApplySilu);
|
||||
return sizeof(mDim) + sizeof(mDConv) + sizeof(mPreStride) + sizeof(mPostStride) + sizeof(mType)
|
||||
+ sizeof(mRemovePadding) + sizeof(mPagedState) + sizeof(mApplySilu);
|
||||
}
|
||||
|
||||
void MambaConv1dPlugin::serialize(void* buffer) const noexcept
|
||||
@ -261,6 +270,8 @@ void MambaConv1dPlugin::serialize(void* buffer) const noexcept
|
||||
char *d = static_cast<char*>(buffer), *a = d;
|
||||
write(d, mDim);
|
||||
write(d, mDConv);
|
||||
write(d, mPreStride);
|
||||
write(d, mPostStride);
|
||||
write(d, mType);
|
||||
write(d, mRemovePadding);
|
||||
write(d, mPagedState);
|
||||
@ -281,6 +292,8 @@ MambaConv1dPluginCreator::MambaConv1dPluginCreator()
|
||||
mPluginAttributes.clear();
|
||||
mPluginAttributes.emplace_back(PluginField("dim", nullptr, PluginFieldType::kINT32, 0));
|
||||
mPluginAttributes.emplace_back(PluginField("dconv", nullptr, PluginFieldType::kINT32, 0));
|
||||
mPluginAttributes.emplace_back(PluginField("pre_stride", nullptr, PluginFieldType::kINT32, 0));
|
||||
mPluginAttributes.emplace_back(PluginField("post_stride", nullptr, PluginFieldType::kINT32, 0));
|
||||
mPluginAttributes.emplace_back(PluginField("type_id", nullptr, PluginFieldType::kINT32, 0));
|
||||
mPluginAttributes.emplace_back(PluginField("remove_input_padding", nullptr, PluginFieldType::kINT8, 0));
|
||||
mPluginAttributes.emplace_back(PluginField("paged_state", nullptr, PluginFieldType::kINT8, 0));
|
||||
@ -307,7 +320,7 @@ PluginFieldCollection const* MambaConv1dPluginCreator::getFieldNames() noexcept
|
||||
IPluginV2* MambaConv1dPluginCreator::createPlugin(char const* name, PluginFieldCollection const* fc) noexcept
|
||||
{
|
||||
PluginField const* fields = fc->fields;
|
||||
int dim, dconv;
|
||||
int dim, dconv, pre_stride, post_stride;
|
||||
bool removePadding;
|
||||
bool pagedState;
|
||||
bool applySilu;
|
||||
@ -326,6 +339,16 @@ IPluginV2* MambaConv1dPluginCreator::createPlugin(char const* name, PluginFieldC
|
||||
TLLM_CHECK(fields[i].type == PluginFieldType::kINT32);
|
||||
dconv = static_cast<int>(*(static_cast<int const*>(fields[i].data)));
|
||||
}
|
||||
else if (!strcmp(attrName, "pre_stride"))
|
||||
{
|
||||
TLLM_CHECK(fields[i].type == PluginFieldType::kINT32);
|
||||
pre_stride = static_cast<int>(*(static_cast<int const*>(fields[i].data)));
|
||||
}
|
||||
else if (!strcmp(attrName, "post_stride"))
|
||||
{
|
||||
TLLM_CHECK(fields[i].type == PluginFieldType::kINT32);
|
||||
post_stride = static_cast<int>(*(static_cast<int const*>(fields[i].data)));
|
||||
}
|
||||
else if (!strcmp(attrName, "type_id"))
|
||||
{
|
||||
TLLM_CHECK(fields[i].type == PluginFieldType::kINT32);
|
||||
@ -349,7 +372,8 @@ IPluginV2* MambaConv1dPluginCreator::createPlugin(char const* name, PluginFieldC
|
||||
}
|
||||
try
|
||||
{
|
||||
auto* obj = new MambaConv1dPlugin(dim, dconv, type, removePadding, pagedState, applySilu);
|
||||
auto* obj
|
||||
= new MambaConv1dPlugin(dim, dconv, pre_stride, post_stride, type, removePadding, pagedState, applySilu);
|
||||
obj->setPluginNamespace(mNamespace.c_str());
|
||||
return obj;
|
||||
}
|
||||
|
||||
@ -44,7 +44,8 @@ namespace tensorrt_llm::plugins
|
||||
class MambaConv1dPlugin : public BasePlugin
|
||||
{
|
||||
public:
|
||||
MambaConv1dPlugin(int dim, int dconv, nvinfer1::DataType type, bool removePadding, bool pagedState, bool applySilu);
|
||||
MambaConv1dPlugin(int dim, int dconv, int preStride, int postStride, nvinfer1::DataType type, bool removePadding,
|
||||
bool pagedState, bool applySilu);
|
||||
|
||||
MambaConv1dPlugin(void const* data, size_t length);
|
||||
|
||||
@ -132,7 +133,8 @@ private:
|
||||
|
||||
void setMambaConv1dParams(tensorrt_llm::kernels::MambaConv1dParamsBase& params,
|
||||
// sizes
|
||||
const size_t batch, const size_t dim, const size_t maxSeqLen, const size_t dconv,
|
||||
const size_t batch, const size_t dim, const size_t maxSeqLen, const size_t dconv, const size_t preStride,
|
||||
const size_t postStride,
|
||||
// device pointers
|
||||
void const* inPtr, void const* stateInPtr, void* stateOutPtr, void const* convWeight, void const* convBias,
|
||||
void* outPtr, int const* lastTokenIds, int const* stateSlotMapping, bool removePadding, bool applySilu);
|
||||
@ -140,6 +142,8 @@ private:
|
||||
private:
|
||||
int mDim;
|
||||
int mDConv;
|
||||
int mPreStride;
|
||||
int mPostStride;
|
||||
nvinfer1::DataType mType;
|
||||
bool mRemovePadding = false;
|
||||
bool mPagedState = false;
|
||||
|
||||
@ -90,7 +90,16 @@ nvinfer1::DimsExprs SelectiveScanPlugin::getOutputDimensions(
|
||||
{
|
||||
if (outputIndex == 0)
|
||||
{
|
||||
return inputs[getInputTensorIdx()];
|
||||
if (mIsMamba2)
|
||||
{
|
||||
auto ret = inputs[getInputTensorIdx()];
|
||||
ret.d[mRemovePadding ? 1 : 2] = exprBuilder.constant(mDim);
|
||||
return ret;
|
||||
}
|
||||
else
|
||||
{
|
||||
return inputs[getInputTensorIdx()];
|
||||
}
|
||||
}
|
||||
return inputs[getStateIdx()];
|
||||
}
|
||||
@ -136,9 +145,9 @@ size_t SelectiveScanPlugin::getWorkspaceSize(nvinfer1::PluginTensorDesc const* i
|
||||
int B = inputs[getLastTokenIdsIdx()].dims.d[0];
|
||||
int BxL = inputs[getInputTensorIdx()].dims.d[0]; // num_tokens
|
||||
int H = mNHeads;
|
||||
int P = inputs[getInputTensorIdx()].dims.d[1] / H;
|
||||
int P = mDim / H;
|
||||
int G = mNGroups;
|
||||
int N = inputs[getBCIdx()].dims.d[1] / G / 2;
|
||||
int N = mDState;
|
||||
int Q = mChunkSize;
|
||||
int BxC = (BxL + Q - 1) / Q + B;
|
||||
|
||||
@ -153,9 +162,9 @@ size_t SelectiveScanPlugin::getWorkspaceSize(nvinfer1::PluginTensorDesc const* i
|
||||
int B = inputs[getInputTensorIdx()].dims.d[0];
|
||||
int L = inputs[getInputTensorIdx()].dims.d[1];
|
||||
int H = mNHeads;
|
||||
int P = inputs[getInputTensorIdx()].dims.d[2] / H;
|
||||
int P = mDim / H;
|
||||
int G = mNGroups;
|
||||
int N = inputs[getBCIdx()].dims.d[2] / G / 2;
|
||||
int N = mDState;
|
||||
int Q = mChunkSize;
|
||||
int C = (L + Q - 1) / Q;
|
||||
|
||||
@ -268,16 +277,16 @@ int SelectiveScanPlugin::enqueueImpl(nvinfer1::PluginTensorDesc const* inputDesc
|
||||
float* mxdA = nullptr;
|
||||
T* mxCB = nullptr;
|
||||
|
||||
if (!mIsMamba2) /* no workspace needed */
|
||||
if (!mIsMamba2 || reqTypes[0] == RequestType::kGENERATION) /* no workspace needed */
|
||||
;
|
||||
else if (mRemovePadding)
|
||||
{
|
||||
int B = inputDesc[getLastTokenIdsIdx()].dims.d[0];
|
||||
int BxL = inputDesc[getInputTensorIdx()].dims.d[0]; // num_tokens
|
||||
int H = mNHeads;
|
||||
int P = inputDesc[getInputTensorIdx()].dims.d[1] / H;
|
||||
int P = mDim / H;
|
||||
int G = mNGroups;
|
||||
int N = inputDesc[getBCIdx()].dims.d[1] / G / 2;
|
||||
int N = mDState;
|
||||
int Q = mChunkSize;
|
||||
int BxC = (BxL + Q - 1) / Q + B;
|
||||
|
||||
@ -292,9 +301,9 @@ int SelectiveScanPlugin::enqueueImpl(nvinfer1::PluginTensorDesc const* inputDesc
|
||||
int B = inputDesc[getInputTensorIdx()].dims.d[0];
|
||||
int L = inputDesc[getInputTensorIdx()].dims.d[1];
|
||||
int H = mNHeads;
|
||||
int P = inputDesc[getInputTensorIdx()].dims.d[2] / H;
|
||||
int P = mDim / H;
|
||||
int G = mNGroups;
|
||||
int N = inputDesc[getBCIdx()].dims.d[2] / G / 2;
|
||||
int N = mDState;
|
||||
int Q = mChunkSize;
|
||||
int C = (L + Q - 1) / Q;
|
||||
|
||||
|
||||
@ -57,10 +57,10 @@ std::optional<tb::LlmRequest::LogitsPostProcessor> LlmRequest::callbackAdapter(
|
||||
|
||||
return [callback](RequestIdType reqId, tensorrt_llm::runtime::ITensor::SharedPtr& tensor,
|
||||
tensorrt_llm::batch_manager::LlmRequest::BeamTokens const& tokens,
|
||||
tensorrt_llm::runtime::BufferManager::CudaStreamPtr stream)
|
||||
tensorrt_llm::runtime::BufferManager::CudaStreamPtr stream, std::optional<RequestIdType> clientId)
|
||||
{
|
||||
at::Tensor atTensor = tr::Torch::tensor(tensor);
|
||||
callback.value()(reqId, atTensor, tokens, runtime::TorchUtils::stream(*stream).unwrap());
|
||||
callback.value()(reqId, atTensor, tokens, runtime::TorchUtils::stream(*stream).unwrap(), clientId);
|
||||
};
|
||||
}
|
||||
|
||||
@ -112,7 +112,7 @@ void LlmRequest::initBindings(py::module_& m)
|
||||
.def("add_new_tokens", &LlmRequest::addNewTokens, py::arg("beam_tokens"))
|
||||
.def("set_generated_tokens", &LlmRequest::setGeneratedTokens, py::arg("generated_beam_tokens"))
|
||||
.def("pause", &LlmRequest::pause, py::arg("max_input_len"))
|
||||
.def_property("max_sent_token_pos", &LlmRequest::getMaxSentTokenPos, &LlmRequest::setMaxSentTokenPos)
|
||||
.def_property("max_sent_token_len", &LlmRequest::getMaxSentTokenLen, &LlmRequest::setMaxSentTokenLen)
|
||||
.def_property_readonly("prompt_embedding_table", &LlmRequest::getPromptEmbeddingTable)
|
||||
.def_property_readonly("prompt_vocab_size", &LlmRequest::getPromptVocabSize)
|
||||
.def_property_readonly("lora_task_id", &LlmRequest::getLoraTaskId)
|
||||
|
||||
@ -126,6 +126,7 @@ PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m)
|
||||
py::enum_<tr::ModelConfig::ModelVariant>(m, "GptModelVariant")
|
||||
.value("GPT", tr::ModelConfig::ModelVariant::kGpt)
|
||||
.value("GLM", tr::ModelConfig::ModelVariant::kGlm)
|
||||
.value("CHATGLM", tr::ModelConfig::ModelVariant::kChatGlm)
|
||||
.value("MAMBA", tr::ModelConfig::ModelVariant::kMamba)
|
||||
.value("RECURRENTGEMMA", tr::ModelConfig::ModelVariant::kRecurrentGemma);
|
||||
|
||||
|
||||
@ -44,6 +44,7 @@ namespace tensorrt_llm::pybind::executor
|
||||
|
||||
void InitBindings(pybind11::module_& m)
|
||||
{
|
||||
m.attr("__version__") = tle::version();
|
||||
py::enum_<tle::ModelType>(m, "ModelType")
|
||||
.value("DECODER_ONLY", tle::ModelType::kDECODER_ONLY)
|
||||
.value("ENCODER_ONLY", tle::ModelType::kENCODER_ONLY)
|
||||
@ -228,14 +229,16 @@ void InitBindings(pybind11::module_& m)
|
||||
std::optional<SizeType32> const&, std::optional<SizeType32> const&,
|
||||
std::optional<std::list<VecTokens>>, std::optional<std::list<VecTokens>>, std::optional<Tensor>,
|
||||
std::optional<tle::ExternalDraftTokensConfig>, std::optional<tle::PromptTuningConfig>,
|
||||
std::optional<tle::LoraConfig>, std::optional<std::string>, std::optional<VecTokens>>(),
|
||||
std::optional<tle::LoraConfig>, std::optional<std::string>, std::optional<VecTokens>,
|
||||
std::optional<IdType>, bool>(),
|
||||
py::arg("input_token_ids"), py::arg("max_new_tokens"), py::arg("streaming") = false,
|
||||
py::arg_v("sampling_config", tle::SamplingConfig(), "SamplingConfig()"),
|
||||
py::arg_v("output_config", tle::OutputConfig(), "OutputConfig()"), py::arg("end_id") = py::none(),
|
||||
py::arg("pad_id") = py::none(), py::arg("bad_words") = py::none(), py::arg("stop_words") = py::none(),
|
||||
py::arg("embedding_bias") = py::none(), py::arg("external_draft_tokens_config") = py::none(),
|
||||
py::arg("prompt_tuning_config") = py::none(), py::arg("lora_config") = py::none(),
|
||||
py::arg("logits_post_processor_name") = py::none(), py::arg("encoder_input_token_ids") = py::none())
|
||||
py::arg("logits_post_processor_name") = py::none(), py::arg("encoder_input_token_ids") = py::none(),
|
||||
py::arg("client_id") = py::none(), py::arg("return_all_generated_tokens") = false)
|
||||
.def_property_readonly("input_token_ids", &tle::Request::getInputTokenIds)
|
||||
.def_property_readonly("max_new_tokens", &tle::Request::getMaxNewTokens)
|
||||
.def_property("streaming", &tle::Request::getStreaming, &tle::Request::setStreaming)
|
||||
@ -254,7 +257,10 @@ void InitBindings(pybind11::module_& m)
|
||||
.def_property("logits_post_processor_name", &tle::Request::getLogitsPostProcessorName,
|
||||
&tle::Request::setLogitsPostProcessorName)
|
||||
.def_property(
|
||||
"encoder_input_token_ids", &tle::Request::getEncoderInputTokenIds, &tle::Request::setEncoderInputTokenIds);
|
||||
"encoder_input_token_ids", &tle::Request::getEncoderInputTokenIds, &tle::Request::setEncoderInputTokenIds)
|
||||
.def_property("client_id", &tle::Request::getClientId, &tle::Request::setClientId)
|
||||
.def_property("return_all_generated_tokens", &tle::Request::getReturnAllGeneratedTokens,
|
||||
&tle::Request::setReturnAllGeneratedTokens);
|
||||
request.attr("BATCHED_POST_PROCESSOR_NAME") = tle::Request::kBatchedPostProcessorName;
|
||||
|
||||
py::class_<tle::Result>(m, "Result")
|
||||
|
||||
@ -46,6 +46,7 @@ void ExplicitDraftTokensBuffers::Inputs::create(SizeType32 maxNumSequences, Tllm
|
||||
temperatures = manager.gpu(ITensor::makeShape({maxNumSequences}), dtype);
|
||||
positionIdsBase = manager.gpu(ITensor::makeShape({maxNumSequences}), nvinfer1::DataType::kINT32);
|
||||
generationLengths = manager.gpu(ITensor::makeShape({maxNumSequences}), nvinfer1::DataType::kINT32);
|
||||
generationLengthsHost = manager.pinned(ITensor::makeShape({maxNumSequences}), nvinfer1::DataType::kINT32);
|
||||
randomDataSample = manager.gpu(ITensor::makeShape({maxNumSequences}), dtype);
|
||||
randomDataValidation = manager.gpu(ITensor::makeShape({maxNumSequences, maxNumPaths, maxDraftPathLen}), dtype);
|
||||
draftTokens = manager.gpu(ITensor::makeShape({maxNumSequences, maxNumPaths, maxPathLen}), TRTTokenIdType);
|
||||
@ -270,6 +271,20 @@ void ExplicitDraftTokensBuffers::setFromInputs(SizeType32 numCtxSequences, SizeT
|
||||
auto const& manager = runtime.getBufferManager();
|
||||
auto const& stream = runtime.getStream();
|
||||
|
||||
auto const explicitDraftTokensModule = std::dynamic_pointer_cast<runtime::ExplicitDraftTokensModule const>(
|
||||
modelConfig.getSpeculativeDecodingModulePtr());
|
||||
|
||||
auto const seqSlotsPtr = bufferCast<SizeType32>(seqSlots);
|
||||
auto const generationLengthsPtr = bufferCast<SizeType32>(*draftBuffers.generationLengthsHost);
|
||||
SizeType32 totalGenLengths = 0;
|
||||
for (SizeType32 si = 0; si < numGenSequences; ++si)
|
||||
{
|
||||
auto const slot = seqSlotsPtr[numCtxSequences + si];
|
||||
totalGenLengths += generationLengthsPtr[slot];
|
||||
}
|
||||
|
||||
// Reshape position ids.
|
||||
engineInputs.positionIds->reshape(ITensor::makeShape({contextPositionIds.getShape().d[0] + totalGenLengths}));
|
||||
// Copy position ids -- hacky solution to avoid filling them for the context requests.
|
||||
TensorPtr posIdsSlice = ITensor::slice(engineInputs.positionIds, 0, contextPositionIds.getShape().d[0]);
|
||||
manager.copy(contextPositionIds, *posIdsSlice);
|
||||
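Note: the loop added above sums the per-slot next-step generation lengths (read from the new pinned generationLengthsHost buffer) so that the packed positionIds engine input can be resized before enqueue. A small sketch of that sizing, with illustrative names:

```cpp
#include <cstdint>

int32_t packedPositionIdLength(int32_t numContextTokens, int32_t const* generationLengthsHost,
    int32_t const* seqSlots, int32_t numCtxSequences, int32_t numGenSequences)
{
    int32_t totalGenLengths = 0;
    for (int32_t si = 0; si < numGenSequences; ++si)
    {
        // Generation requests follow the context requests in the slot mapping.
        totalGenLengths += generationLengthsHost[seqSlots[numCtxSequences + si]];
    }
    return numContextTokens + totalGenLengths;
}
```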
@ -279,9 +294,6 @@ void ExplicitDraftTokensBuffers::setFromInputs(SizeType32 numCtxSequences, SizeT
|
||||
auto const numSequences = numCtxSequences + numGenSequences;
|
||||
auto const vocabSizePadded = modelConfig.getVocabSizePadded(worldConfig.getSize());
|
||||
|
||||
auto const explicitDraftTokensModule = std::dynamic_pointer_cast<runtime::ExplicitDraftTokensModule const>(
|
||||
modelConfig.getSpeculativeDecodingModulePtr());
|
||||
|
||||
auto const dtype = modelConfig.getDataType();
|
||||
|
||||
switch (dtype)
|
||||
@ -327,7 +339,7 @@ void ExplicitDraftTokensBuffers::insertInputTensors(
|
||||
{
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
// inputs
|
||||
inputBuffers.insert_or_assign("explicit_inverted_temperature", engineInputs.temperatures);
|
||||
inputBuffers.insert_or_assign("redrafter_inverted_temperature", engineInputs.temperatures);
|
||||
inputBuffers.insert_or_assign("device_request_types", engineInputs.requestTypesDevice);
|
||||
|
||||
inputBuffers.insert_or_assign("spec_decoding_generation_lengths", engineInputs.generationLengths);
|
||||
|
||||
@ -438,6 +438,7 @@ void prepareSpeculativeDecodingOutputs(DecodingOutput& output, std::shared_ptr<t
|
||||
outputParams->randomDataValidation = tcc::toTllmTensor(*explicitDraftTokensBuffers->randomDataValidation);
|
||||
outputParams->temperatures = tcc::toTllmTensor(*explicitDraftTokensBuffers->temperatures);
|
||||
outputParams->generationLengths = tcc::toTllmTensor(*explicitDraftTokensBuffers->generationLengths);
|
||||
outputParams->generationLengthsHost = tcc::toTllmTensor(*explicitDraftTokensBuffers->generationLengthsHost);
|
||||
outputParams->maxGenLengthHost = tcc::toTllmTensor(*explicitDraftTokensBuffers->maxGenLengthHost);
|
||||
}
|
||||
|
||||
|
||||
@ -340,10 +340,15 @@ GptJsonConfig parseJson(InputType&& input)
|
||||
|
||||
if (engineVersionNone)
|
||||
{
|
||||
if (name == std::string("chatglm_6b") || name == std::string("glm_10b"))
|
||||
if (name == std::string("chatglm_6b"))
|
||||
{
|
||||
modelConfig.setModelVariant(ModelConfig::ModelVariant::kChatGlm);
|
||||
// kChatGlm is only for ChatGLM-6B
|
||||
}
|
||||
if (name == std::string("glm_10b"))
|
||||
{
|
||||
modelConfig.setModelVariant(ModelConfig::ModelVariant::kGlm);
|
||||
// kGlm is only for ChatGLM-6B and GLM-10B
|
||||
// kGlm is only for GLM-10B
|
||||
}
|
||||
}
|
||||
else
|
||||
@ -352,10 +357,15 @@ GptJsonConfig parseJson(InputType&& input)
|
||||
{
|
||||
auto const& pretrainedConfig = json.at("pretrained_config");
|
||||
auto const chatglmVersion = pretrainedConfig.at("chatglm_version").template get<std::string>();
|
||||
if (chatglmVersion == "glm" || chatglmVersion == "chatglm")
|
||||
if (chatglmVersion == "chatglm")
|
||||
{
|
||||
modelConfig.setModelVariant(ModelConfig::ModelVariant::kChatGlm);
|
||||
// kChatGlm is only for ChatGLM-6B
|
||||
}
|
||||
if (chatglmVersion == "glm")
|
||||
{
|
||||
modelConfig.setModelVariant(ModelConfig::ModelVariant::kGlm);
|
||||
// kGlm is only for ChatGLM-6B and GLM-10B
|
||||
// kGlm is only for GLM-10B
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -368,8 +378,8 @@ GptJsonConfig parseJson(InputType&& input)
|
||||
auto const& pretrainedConfig = json.at("pretrained_config");
|
||||
|
||||
// TODO(rkobus): adjust param names
|
||||
auto const maxNumPaths = parseJsonFieldOr(pretrainedConfig, "explicit_num_beams", 0);
|
||||
auto const maxDraftPathLen = parseJsonFieldOr(pretrainedConfig, "explicit_draft_len_per_beam", 0);
|
||||
auto const maxNumPaths = parseJsonFieldOr(pretrainedConfig, "redrafter_num_beams", 0);
|
||||
auto const maxDraftPathLen = parseJsonFieldOr(pretrainedConfig, "redrafter_draft_len_per_beam", 0);
|
||||
auto const maxDraftLen = maxNumPaths * maxDraftPathLen;
|
||||
|
||||
auto explicitDraftTokensModule
|
||||
|
||||
@ -31,6 +31,7 @@ namespace tc = tensorrt_llm::common;

ITensor::UniquePtr ITensor::slice(SharedPtr tensor, std::size_t offset, std::size_t size)
{
    TLLM_CHECK(tensor);
    return std::make_unique<TensorView>(std::move(tensor), offset, size);
}

@ -308,7 +308,7 @@ void TransformerBuffers::prepareContextStep(RuntimeBuffers* runtimeBuffers, Tens
|
||||
}
|
||||
positionIds = manager.copyFrom(positionIdsVec, inputShape, MemoryType::kGPU);
|
||||
}
|
||||
else if (modelVariant == ModelConfig::ModelVariant::kGlm)
|
||||
else if (modelVariant == ModelConfig::ModelVariant::kChatGlm)
|
||||
{
|
||||
auto const positionIdsVec = getPositionIdsContextPhaseGlm(batchSize, maxInputLength, contextLengthsHostPtr,
|
||||
modelConfig.useGptAttentionPlugin(), modelConfig.usePackedInput());
|
||||
@ -578,7 +578,7 @@ void TransformerBuffers::prepareNextStep(RuntimeBuffers* runtimeBuffers, SizeTyp
|
||||
manager.copy(*contextLengthsDevice, *positionIds);
|
||||
kernels::invokeAdd(*positionIds, step, stream);
|
||||
}
|
||||
else if (modelVariant == ModelConfig::ModelVariant::kGlm)
|
||||
else if (modelVariant == ModelConfig::ModelVariant::kChatGlm)
|
||||
{
|
||||
auto const positionIdsVec = getPositionIdsGenerationPhaseGlm(batchSize, beamWidth, step,
|
||||
contextLengthsHostPtr, modelConfig.useGptAttentionPlugin(), modelConfig.usePackedInput());
|
||||
|
||||
@ -35,8 +35,10 @@ find_library_create_target(nvonnxparser ${ONNX_PARSER_LIB_NAME} SHARED
|
||||
|
||||
include_directories(
|
||||
${PROJECT_SOURCE_DIR}/tensorrt_llm/cutlass_extensions/include
|
||||
${PROJECT_SOURCE_DIR}/include ${3RDPARTY_DIR}/cutlass/include
|
||||
${3RDPARTY_DIR}/cutlass/tools/util/include)
|
||||
${PROJECT_SOURCE_DIR}/include
|
||||
${3RDPARTY_DIR}/cutlass/include
|
||||
${3RDPARTY_DIR}/cutlass/tools/util/include
|
||||
${PROJECT_SOURCE_DIR}/tests/batch_manager)
|
||||
|
||||
set(TOP_LEVEL_DIR "${PROJECT_SOURCE_DIR}/..")
|
||||
|
||||
@ -77,6 +79,7 @@ add_gtest(transposeKVKernelTest runtime/transposeKVKernelTest.cpp)
|
||||
add_gtest(gptDecoderTest runtime/gptDecoderTest.cpp)
|
||||
add_gtest(gptDecoderBatchTest runtime/gptDecoderBatchTest.cpp)
|
||||
add_gtest(gptSessionTest runtime/gptSessionTest.cpp)
|
||||
target_link_libraries(gptSessionTest PRIVATE modelSpecStatic)
|
||||
add_gtest(allocatorTest common/allocatorTest.cpp)
|
||||
add_gtest(memoryUtilsTest common/memoryUtilsTest.cu)
|
||||
if(ENABLE_MULTI_DEVICE EQUAL 1)
|
||||
@ -87,6 +90,8 @@ add_gtest(stringUtilsTest common/stringUtilsTest.cpp)
|
||||
add_gtest(tllmExceptionTest common/tllmExceptionTest.cpp)
|
||||
add_gtest(tensorTest common/tensorTest.cpp)
|
||||
add_gtest(stlUtilsTest common/stlUtilsTest.cpp)
|
||||
add_gtest(cudaProfilerUtilsTest common/cudaProfilerUtilsTest.cpp)
|
||||
add_gtest(timestampUtilsTest common/timestampUtilsTest.cpp)
|
||||
add_gtest(tllmRuntimeTest runtime/tllmRuntimeTest.cpp)
|
||||
add_gtest(tllmBuffersTest runtime/tllmBuffersTest.cpp)
|
||||
add_gtest(bufferManagerTest runtime/bufferManagerTest.cpp)
|
||||
|
||||
@ -58,6 +58,7 @@ PYTHONPATH=examples/gptj:$PYTHONPATH python3 cpp/tests/resources/scripts/build_g
PYTHONPATH=examples/llama:$PYTHONPATH python3 cpp/tests/resources/scripts/build_llama_engines.py
PYTHONPATH=examples/chatglm:$PYTHONPATH python3 cpp/tests/resources/scripts/build_chatglm_engines.py
PYTHONPATH=examples/medusa:$PYTHONPATH python3 cpp/tests/resources/scripts/build_medusa_engines.py
PYTHONPATH=examples/redrafter:$PYTHONPATH python3 cpp/tests/resources/scripts/build_redrafter_engines.py --has_tllm_checkpoint
```

It is possible to build engines with tensor and pipeline parallelism for LLaMA using 4 GPUs.

@ -76,6 +77,7 @@ PYTHONPATH=examples:$PYTHONPATH python3 cpp/tests/resources/scripts/generate_exp
PYTHONPATH=examples:$PYTHONPATH python3 cpp/tests/resources/scripts/generate_expected_llama_output.py
PYTHONPATH=examples:$PYTHONPATH python3 cpp/tests/resources/scripts/generate_expected_chatglm_output.py
PYTHONPATH=examples:$PYTHONPATH python3 cpp/tests/resources/scripts/generate_expected_medusa_output.py
PYTHONPATH=examples:$PYTHONPATH python3 cpp/tests/resources/scripts/generate_expected_redrafter_output.py
```

#### Generate data with tensor and pipeline parallelism

76
cpp/tests/common/cudaProfilerUtilsTest.cpp
Normal file
@ -0,0 +1,76 @@
/*
|
||||
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <stdlib.h>
|
||||
|
||||
#include "tensorrt_llm/common/cudaProfilerUtils.h"
|
||||
|
||||
#ifdef _WIN32
|
||||
int setenv(char const* name, char const* value, int overwrite)
|
||||
{
|
||||
int errcode = 0;
|
||||
if (!overwrite)
|
||||
{
|
||||
size_t envsize = 0;
|
||||
errcode = getenv_s(&envsize, NULL, 0, name);
|
||||
if (errcode || envsize)
|
||||
return errcode;
|
||||
}
|
||||
return _putenv_s(name, value);
|
||||
}
|
||||
#endif
|
||||
|
||||
std::string kEnvVarName{"TLLM_PROFILE_START_STOP"};
|
||||
std::string kLegacyEnvVarName{"TLLM_GPTM_PROFILE_START_STOP"};
|
||||
|
||||
struct TestCase
|
||||
{
|
||||
std::optional<std::string> legacyEnvVarVal;
|
||||
std::string envVarVal;
|
||||
std::pair<std::unordered_set<int32_t>, std::unordered_set<int32_t>> result;
|
||||
};
|
||||
|
||||
TEST(CudaProfilerUtils, populateIterationIndexes)
|
||||
{
|
||||
std::vector<TestCase> testCases;
|
||||
testCases.emplace_back(TestCase{std::nullopt, "", {{}, {}}});
|
||||
testCases.emplace_back(TestCase{std::nullopt, "1", {{1}, {1}}});
|
||||
testCases.emplace_back(TestCase{std::nullopt, "1,2,3", {{1, 2, 3}, {1, 2, 3}}});
|
||||
testCases.emplace_back(TestCase{std::nullopt, "1-4,7-8", {{1, 7}, {4, 8}}});
|
||||
testCases.emplace_back(TestCase{std::nullopt, "1,2,10-15", {{1, 2, 10}, {1, 2, 15}}});
|
||||
testCases.emplace_back(TestCase{std::nullopt, "1,,10-15", {{1, 10}, {1, 15}}});
|
||||
|
||||
// Only legacy env var set
|
||||
testCases.emplace_back(TestCase{"1-4,7-8", "", {{1, 7}, {4, 8}}});
|
||||
|
||||
// Both set, non-legacy has priority
|
||||
testCases.emplace_back(TestCase{"1-4,7-8", "2-10,88-99", {{2, 88}, {10, 99}}});
|
||||
|
||||
for (auto const& testCase : testCases)
|
||||
{
|
||||
auto ret = setenv(kEnvVarName.c_str(), testCase.envVarVal.c_str(), 1); // does overwrite
|
||||
EXPECT_EQ(ret, 0);
|
||||
ret = setenv(
|
||||
kLegacyEnvVarName.c_str(), testCase.legacyEnvVarVal.value_or(std::string()).c_str(), 1); // does overwrite
|
||||
auto const [profileIterIdxs, stopIterIdxs]
|
||||
= tensorrt_llm::common::populateIterationIndexes(kEnvVarName, kLegacyEnvVarName);
|
||||
EXPECT_EQ(profileIterIdxs, testCase.result.first)
|
||||
<< testCase.envVarVal << " " << testCase.legacyEnvVarVal.value_or("");
|
||||
EXPECT_EQ(stopIterIdxs, testCase.result.second)
|
||||
<< testCase.envVarVal << " " << testCase.legacyEnvVarVal.value_or("");
|
||||
}
|
||||
}
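Note: the test above pins down the TLLM_PROFILE_START_STOP syntax: a comma-separated list where "3" both starts and stops profiling at iteration 3 and "1-4" starts at 1 and stops at 4, empty entries are ignored, and the non-legacy variable takes priority over TLLM_GPTM_PROFILE_START_STOP. A minimal parser sketch consistent with those cases (illustrative only, not the cudaProfilerUtils implementation):

```cpp
#include <cstdint>
#include <sstream>
#include <string>
#include <unordered_set>
#include <utility>

std::pair<std::unordered_set<int32_t>, std::unordered_set<int32_t>> parseProfileRanges(std::string const& spec)
{
    std::unordered_set<int32_t> starts, stops;
    std::stringstream ss(spec);
    std::string entry;
    while (std::getline(ss, entry, ','))
    {
        if (entry.empty())
        {
            continue; // "1,,10-15" tolerates empty entries
        }
        auto const dash = entry.find('-');
        if (dash == std::string::npos)
        {
            auto const it = std::stoi(entry);
            starts.insert(it); // a single index starts and stops at the same iteration
            stops.insert(it);
        }
        else
        {
            starts.insert(std::stoi(entry.substr(0, dash)));
            stops.insert(std::stoi(entry.substr(dash + 1)));
        }
    }
    return {starts, stops};
}

// parseProfileRanges("1-4,7-8") -> starts {1, 7}, stops {4, 8}, matching the test expectations.
```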
@ -77,3 +77,25 @@ TEST(StringUtil, FormatFixedDecimals)
        EXPECT_EQ(prefix, formatFixed(num));
    }
}

TEST(StringUtil, str2set)
{
    {
        char delimiter{','};

        std::vector<std::string> inputs{
            {"apple,car,dog"}, {",apple,car,dog"}, {"apple,car,dog"}, {"apple,car,dog,,"}, {"apple,,,car,dog,"}};

        for (auto const& input : inputs)
        {
            auto out = str2set(input, delimiter);

            EXPECT_EQ(out.size(), 3);
            EXPECT_EQ(out.count("apple"), 1);
            EXPECT_EQ(out.count("car"), 1);
            EXPECT_EQ(out.count("dog"), 1);
            EXPECT_EQ(out.count("cat"), 0);
        }
    }
}

46
cpp/tests/common/timestampUtilsTest.cpp
Normal file
@ -0,0 +1,46 @@
/*
|
||||
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "tensorrt_llm/common/timestampUtils.h"
|
||||
|
||||
#include <chrono>
|
||||
#include <sstream>
|
||||
#include <thread>
|
||||
|
||||
using namespace tensorrt_llm::common;
|
||||
|
||||
TEST(TimestampUtils, getCurrentTimestamp)
|
||||
{
|
||||
int32_t sleepMs = 100;
|
||||
int32_t sleepUs = sleepMs * 1000;
|
||||
int32_t tolUs = 5000;
|
||||
auto timestamp = getCurrentTimestamp();
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(sleepMs));
|
||||
auto timestamp2 = getCurrentTimestamp();
|
||||
auto microseconds = std::stoi(timestamp.erase(0, timestamp.find('.') + 1));
|
||||
std::cout << microseconds << std::endl;
|
||||
auto microseconds2 = std::stoi(timestamp2.erase(0, timestamp2.find('.') + 1));
|
||||
|
||||
int32_t delta = (microseconds2 - microseconds);
|
||||
if (delta < 0)
|
||||
{
|
||||
delta += 1000000;
|
||||
}
|
||||
EXPECT_NEAR(delta, sleepUs, tolUs) << "delta: " << delta << " expected " << sleepUs << std::endl;
|
||||
}
|
||||
@ -710,6 +710,9 @@ void ExplicitDraftTokensLayerTest<T>::allocateBuffers()
|
||||
mOutputGenerationLengths
|
||||
= BufferManager::pinned(ITensor::makeShape({mSamplingParams.getMaxBatchSize()}), nvinfer1::DataType::kINT32);
|
||||
|
||||
mOutputGenerationLengthsHost
|
||||
= BufferManager::pinned(ITensor::makeShape({mSamplingParams.getMaxBatchSize()}), nvinfer1::DataType::kINT32);
|
||||
|
||||
mMaxGenLengthHost = BufferManager::pinned(ITensor::makeShape({1}), nvinfer1::DataType::kINT32);
|
||||
|
||||
// inputs
|
||||
@ -1012,6 +1015,8 @@ std::shared_ptr<ExplicitDraftTokensOutputs> ExplicitDraftTokensLayerTest<T>::cre
|
||||
|
||||
outputParams->generationLengths = tcc::toTllmTensor(*mOutputGenerationLengths);
|
||||
|
||||
outputParams->generationLengthsHost = tcc::toTllmTensor(*mOutputGenerationLengthsHost);
|
||||
|
||||
outputParams->maxGenLengthHost = tcc::toTllmTensor(*mMaxGenLengthHost);
|
||||
|
||||
return outputParams;
|
||||
|
||||
@ -273,6 +273,7 @@ private:
|
||||
TensorPtr mOutputDraftProbs;
|
||||
TensorPtr mOutputTemperatures;
|
||||
TensorPtr mOutputGenerationLengths;
|
||||
TensorPtr mOutputGenerationLengthsHost;
|
||||
TensorPtr mMaxGenLengthHost;
|
||||
|
||||
// inputs
|
||||
|
||||
BIN
cpp/tests/resources/data/input_tokens_glm-10b.npy
Normal file
Binary file not shown.
BIN
cpp/tests/resources/data/input_vicuna.npy
Normal file
Binary file not shown.
@ -21,7 +21,12 @@ import sys
|
||||
import typing
|
||||
from pathlib import Path
|
||||
|
||||
from build_engines_utils import run_command, wincopy
|
||||
from build_engines_utils import init_model_spec_module, run_command, wincopy
|
||||
|
||||
init_model_spec_module()
|
||||
import model_spec
|
||||
|
||||
import tensorrt_llm.bindings as _tb
|
||||
|
||||
resources_dir = Path(__file__).parent.resolve().parent
|
||||
model_dir = resources_dir / "models"
|
||||
@ -42,7 +47,7 @@ def convert_ckpt(model_dir: str, output_dir: str, world_size: int):
|
||||
def build_engine(ckpt_dir: str,
|
||||
engine_dir: str,
|
||||
is_ifb: bool = False,
|
||||
is_chatglm_6b: bool = False):
|
||||
is_chatglm_6b_or_glm_10b: bool = False):
|
||||
build_cmd = [
|
||||
"trtllm-build",
|
||||
f"--checkpoint_dir={ckpt_dir}",
|
||||
@ -70,8 +75,8 @@ def build_engine(ckpt_dir: str,
|
||||
"--paged_kv_cache=disable",
|
||||
])
|
||||
|
||||
if is_chatglm_6b:
|
||||
print("Disable Context FMHA for ChatGLM-6B")
|
||||
if is_chatglm_6b_or_glm_10b:
|
||||
print("Disable Context FMHA for ChatGLM-6B and GLM-10B")
|
||||
build_cmd.extend(
|
||||
["--context_fmha=disable", "--context_fmha_fp32_acc=disable"])
|
||||
|
||||
@ -81,7 +86,8 @@ def build_engine(ckpt_dir: str,
|
||||
def build_engines(model_cache: typing.Optional[str] = None,
|
||||
world_size: int = 1):
|
||||
|
||||
for model_name in ["chatglm-6b", "chatglm2-6b", "chatglm3-6b"]:
|
||||
for model_name in ["chatglm-6b", "chatglm2-6b", "chatglm3-6b", "glm-10b"]:
|
||||
is_chatglm_6b_or_glm_10b = model_name in ["chatglm-6b", "glm-10b"]
|
||||
if model_cache and (Path(model_cache) / model_name).is_dir():
|
||||
model_cache_dir = Path(model_cache) / model_name
|
||||
if bCopyModel or model_name == "chatglm-6b":
|
||||
@ -101,15 +107,16 @@ def build_engines(model_cache: typing.Optional[str] = None,
|
||||
hf_dir = Path(model_cache)
|
||||
|
||||
else:
|
||||
print("Clone model from HF")
|
||||
hf_dir = model_dir / model_name
|
||||
run_command(
|
||||
[
|
||||
"git", "clone",
|
||||
f"https://huggingface.co/THUDM/{model_name}", model_name
|
||||
],
|
||||
cwd=model_dir,
|
||||
)
|
||||
if not hf_dir.is_dir():
|
||||
print("Clone model from HF")
|
||||
run_command(
|
||||
[
|
||||
"git", "clone",
|
||||
f"https://huggingface.co/THUDM/{model_name}", model_name
|
||||
],
|
||||
cwd=model_dir,
|
||||
)
|
||||
|
||||
# Build engines
|
||||
print(f"Building {model_name}")
|
||||
@ -125,14 +132,25 @@ def build_engines(model_cache: typing.Optional[str] = None,
|
||||
|
||||
convert_ckpt(hf_dir, ckpt_dir, world_size)
|
||||
|
||||
for engine_kind in ["fp16-plugin", "fp16-plugin-packed-paged"]:
|
||||
engine_dir = Path(
|
||||
model_dir
|
||||
) / "rt_engine" / model_name / engine_kind / "tp1-pp1-gpu"
|
||||
engine_dir.mkdir(parents=True, exist_ok=True)
|
||||
model_spec_obj = model_spec.ModelSpec('input_tokens.npy',
|
||||
_tb.DataType.HALF)
|
||||
model_spec_obj.set_kv_cache_type(model_spec.KVCacheType.CONTINUOUS)
|
||||
model_spec_obj.use_gpt_plugin()
|
||||
engine_dir = Path(
|
||||
model_dir
|
||||
) / "rt_engine" / model_name / model_spec_obj.get_model_path(
|
||||
) / "tp1-pp1-gpu"
|
||||
engine_dir.mkdir(parents=True, exist_ok=True)
|
||||
build_engine(ckpt_dir, engine_dir, False, is_chatglm_6b_or_glm_10b)
|
||||
|
||||
build_engine(ckpt_dir, engine_dir, "paged" in engine_kind,
|
||||
model_name == "chatglm-6b")
|
||||
model_spec_obj.use_packed_input()
|
||||
model_spec_obj.set_kv_cache_type(model_spec.KVCacheType.PAGED)
|
||||
engine_dir = Path(
|
||||
model_dir
|
||||
) / "rt_engine" / model_name / model_spec_obj.get_model_path(
|
||||
) / "tp1-pp1-gpu"
|
||||
engine_dir.mkdir(parents=True, exist_ok=True)
|
||||
build_engine(ckpt_dir, engine_dir, True, is_chatglm_6b_or_glm_10b)
|
||||
|
||||
print("Done")
|
||||
|
||||
|
||||
@ -62,3 +62,37 @@ def wincopy(source: str, dest: str, isdir: bool, cwd=None) -> None:
|
||||
raise _sp.CalledProcessError(returncode=result.returncode,
|
||||
cmd=copy_cmd,
|
||||
output=result.stderr)
|
||||
|
||||
|
||||
# Helper function to locate model_spec module.
|
||||
def init_model_spec_module():
|
||||
import os
|
||||
|
||||
# Rely on unique built model_spec to locate the module.
|
||||
cpp_root_dir = _pl.Path(__file__).parent.resolve().parent.parent.parent
|
||||
|
||||
found_locations = []
|
||||
|
||||
def find_model_spec_module(directory, found_locations):
|
||||
for root, d, files in os.walk(directory):
|
||||
for item in files:
|
||||
if item == 'model_spec.so':
|
||||
found_locations.append(root)
|
||||
|
||||
for d in os.listdir(cpp_root_dir):
|
||||
if d.startswith("build"):
|
||||
find_model_spec_module(os.path.join(cpp_root_dir, d),
|
||||
found_locations)
|
||||
|
||||
if len(found_locations) == 0:
|
||||
# In CI package, model_spec module is copied to its source directory.
|
||||
find_model_spec_module(
|
||||
os.path.join(cpp_root_dir, 'tests', 'batch_manager'),
|
||||
found_locations)
|
||||
|
||||
assert len(
|
||||
found_locations
|
||||
) == 1, f'Can\'t uniquely locate model_spec module, found {found_locations}'
|
||||
|
||||
import sys
|
||||
sys.path.insert(0, found_locations[0])
|
||||
|
||||
@ -21,7 +21,12 @@ import sys
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from build_engines_utils import run_command, wincopy
|
||||
from build_engines_utils import init_model_spec_module, run_command, wincopy
|
||||
|
||||
init_model_spec_module()
|
||||
import model_spec
|
||||
|
||||
import tensorrt_llm.bindings as _tb
|
||||
|
||||
|
||||
def convert_ckpt(model_dir: str,
|
||||
@ -44,6 +49,9 @@ def build_engine(
|
||||
max_input_len: int = 256,
|
||||
max_seq_len: int = 384,
|
||||
):
|
||||
|
||||
if os.path.exists(engine_dir):
|
||||
assert False
|
||||
build_cmd = [
|
||||
"trtllm-build",
|
||||
'--log_level=error',
|
||||
@ -139,11 +147,14 @@ def build_engines(model_cache: Optional[str] = None, world_size: int = 1):
|
||||
world_size=tp_size,
|
||||
dtype='float32')
|
||||
|
||||
input_file = 'input_tokens.npy'
|
||||
print("\nBuilding fp32 engines")
|
||||
model_spec_obj = model_spec.ModelSpec(input_file, _tb.DataType.FLOAT)
|
||||
build_engine(str(fp32_ckpt_dir),
|
||||
str(engine_dir / 'fp32-default' / tp_pp_dir))
|
||||
str(engine_dir / model_spec_obj.get_model_path() / tp_pp_dir))
|
||||
model_spec_obj.use_gpt_plugin()
|
||||
build_engine(str(fp32_ckpt_dir),
|
||||
str(engine_dir / 'fp32-plugin' / tp_pp_dir),
|
||||
str(engine_dir / model_spec_obj.get_model_path() / tp_pp_dir),
|
||||
'--gpt_attention_plugin=float32', '--context_fmha=enable',
|
||||
'--context_fmha_fp32_acc=enable')
|
||||
|
||||
@ -155,13 +166,16 @@ def build_engines(model_cache: Optional[str] = None, world_size: int = 1):
|
||||
dtype='float16')
|
||||
|
||||
print("\nBuilding fp16 engines")
|
||||
model_spec_obj = model_spec.ModelSpec(input_file, _tb.DataType.HALF)
|
||||
build_engine(str(fp16_ckpt_dir),
|
||||
str(engine_dir / 'fp16-default' / tp_pp_dir))
|
||||
str(engine_dir / model_spec_obj.get_model_path() / tp_pp_dir))
|
||||
model_spec_obj.use_gpt_plugin()
|
||||
build_engine(str(fp16_ckpt_dir),
|
||||
str(engine_dir / 'fp16-plugin' / tp_pp_dir),
|
||||
str(engine_dir / model_spec_obj.get_model_path() / tp_pp_dir),
|
||||
'--gpt_attention_plugin=float16')
|
||||
model_spec_obj.use_packed_input()
|
||||
build_engine(str(fp16_ckpt_dir),
|
||||
str(engine_dir / 'fp16-plugin-packed' / tp_pp_dir),
|
||||
str(engine_dir / model_spec_obj.get_model_path() / tp_pp_dir),
|
||||
'--gpt_attention_plugin=float16',
|
||||
'--remove_input_padding=enable')
|
||||
|
||||
@ -175,30 +189,57 @@ def build_engines(model_cache: Optional[str] = None, world_size: int = 1):
|
||||
'--max_num_tokens=10000',
|
||||
'--use_paged_context_fmha=enable',
|
||||
]
|
||||
model_spec_obj = model_spec.ModelSpec(input_file, _tb.DataType.HALF)
|
||||
model_spec_obj.use_gpt_plugin()
|
||||
model_spec_obj.set_kv_cache_type(model_spec.KVCacheType.PAGED)
|
||||
model_spec_obj.use_packed_input()
|
||||
|
||||
build_engine(str(fp16_ckpt_dir),
|
||||
str(engine_dir / 'fp16-plugin-packed-paged' / tp_pp_dir),
|
||||
str(engine_dir / model_spec_obj.get_model_path() / tp_pp_dir),
|
||||
*ifb_args)
|
||||
|
||||
model_spec_current = model_spec_obj.__copy__()
|
||||
max_draft_tokens = 5
|
||||
model_spec_current.use_draft_tokens_external_decoding()
|
||||
model_spec_current.set_draft_tokens(max_draft_tokens)
|
||||
|
||||
build_engine(
|
||||
str(fp16_ckpt_dir),
|
||||
str(engine_dir / 'fp16-plugin-packed-paged-draft-tokens' / tp_pp_dir),
|
||||
'--max_draft_len=5',
|
||||
str(engine_dir / model_spec_current.get_model_path() / tp_pp_dir),
|
||||
f'--max_draft_len={max_draft_tokens}',
|
||||
'--speculative_decoding_mode=draft_tokens_external', *ifb_args)
|
||||
|
||||
model_spec_current = model_spec_obj.__copy__()
|
||||
model_spec_current.use_multiple_profiles()
|
||||
|
||||
build_engine(
|
||||
str(fp16_ckpt_dir),
|
||||
str(engine_dir / 'fp16-plugin-packed-paged-nprofiles' / tp_pp_dir),
|
||||
str(engine_dir / model_spec_current.get_model_path() / tp_pp_dir),
|
||||
'--multiple_profiles=enable', *ifb_args)
|
||||
|
||||
model_spec_current = model_spec_obj.__copy__()
|
||||
max_input_len = 128
|
||||
model_spec_current.set_max_input_length(max_input_len)
|
||||
|
||||
build_engine(str(fp16_ckpt_dir),
|
||||
str(engine_dir / 'fp16-plugin-packed-paged-in128' / tp_pp_dir),
|
||||
str(engine_dir / model_spec_current.get_model_path() /
|
||||
tp_pp_dir),
|
||||
*ifb_args,
|
||||
max_input_len=128)
|
||||
max_input_len=max_input_len)
|
||||
|
||||
# Build the target model with return accepted token logits
|
||||
# Build with '--max_draft_len', '--speculative_decoding_mode' and '--gather_generation_logits'
|
||||
model_spec_current = model_spec_obj.__copy__()
|
||||
max_draft_len = 5
|
||||
model_spec_current.use_draft_tokens_external_decoding()
|
||||
model_spec_current.set_draft_tokens(max_draft_len)
|
||||
model_spec_current.gather_logits()
|
||||
model_spec_current.return_accepted_tokens_logits()
|
||||
|
||||
build_engine(
|
||||
str(fp16_ckpt_dir),
|
||||
str(engine_dir /
|
||||
'fp16-plugin-packed-paged-return-accepted-tokens-logits' /
|
||||
tp_pp_dir), '--max_draft_len=5',
|
||||
str(engine_dir / model_spec_current.get_model_path() / tp_pp_dir),
|
||||
f'--max_draft_len={max_draft_len}',
|
||||
'--speculative_decoding_mode=draft_tokens_external',
|
||||
'--gather_generation_logits', *ifb_args)
|
||||
|
||||
@ -206,22 +247,31 @@ def build_engines(model_cache: Optional[str] = None, world_size: int = 1):
|
||||
# to extract logits from python runtime and uses context FMHA for generation to match draft model executions,
|
||||
# which uses context FMHA for draft tokens prediction.
|
||||
# Currently, gather_all_token_logits is not supported with the target model of speculative decoding
|
||||
model_spec_current = model_spec_obj.__copy__()
|
||||
model_spec_current.gather_logits()
|
||||
|
||||
build_engine(
|
||||
str(fp16_ckpt_dir),
|
||||
str(engine_dir / 'fp16-plugin-packed-paged-gather' / tp_pp_dir),
|
||||
str(engine_dir / model_spec_current.get_model_path() / tp_pp_dir),
|
||||
'--gather_all_token_logits', *ifb_args)
|
||||
|
||||
model_spec_current = model_spec_obj.__copy__()
|
||||
model_spec_current.use_look_ahead_decoding()
|
||||
max_draft_len = 64
|
||||
model_spec_current.set_draft_tokens(max_draft_len)
|
||||
build_engine(
|
||||
str(fp16_ckpt_dir),
|
||||
str(engine_dir / 'fp16-plugin-packed-paged-la-decoding' / tp_pp_dir),
|
||||
'--max_draft_len=64', '--speculative_decoding_mode=lookahead_decoding',
|
||||
*ifb_args)
|
||||
str(engine_dir / model_spec_current.get_model_path() / tp_pp_dir),
|
||||
f'--max_draft_len={max_draft_len}',
|
||||
'--speculative_decoding_mode=lookahead_decoding', *ifb_args)
|
||||
|
||||
# build engine with lora enabled
|
||||
build_engine(str(fp16_ckpt_dir),
|
||||
str(engine_dir / "fp16-plugin-packed-paged-lora" / tp_pp_dir),
|
||||
"--lora_target_modules=attn_qkv", '--lora_plugin=float16',
|
||||
*ifb_args)
|
||||
model_spec_current = model_spec_obj.__copy__()
|
||||
model_spec_current.use_lora_plugin()
|
||||
build_engine(
|
||||
str(fp16_ckpt_dir),
|
||||
str(engine_dir / model_spec_current.get_model_path() / tp_pp_dir),
|
||||
"--lora_target_modules=attn_qkv", '--lora_plugin=float16', *ifb_args)
|
||||
|
||||
if model_cache:
|
||||
llm_datasets_root = Path(model_cache) / "datasets"
|
||||
@ -238,9 +288,16 @@ def build_engines(model_cache: Optional[str] = None, world_size: int = 1):
|
||||
dtype='float16')
|
||||
|
||||
print("\nBuilding fp16 SQ engines")
|
||||
build_engine(str(fp16_sq_ckpt_dir),
|
||||
str(engine_dir / 'fp16-plugin-packed-paged-sq' / tp_pp_dir),
|
||||
*ifb_args)
|
||||
model_spec_current = model_spec.ModelSpec(input_file, _tb.DataType.HALF)
|
||||
model_spec_current.use_gpt_plugin()
|
||||
model_spec_current.use_packed_input()
|
||||
model_spec_current.set_kv_cache_type(model_spec.KVCacheType.PAGED)
|
||||
model_spec_current.set_quant_method(model_spec.QuantMethod.SMOOTH_QUANT)
|
||||
|
||||
build_engine(
|
||||
str(fp16_sq_ckpt_dir),
|
||||
str(engine_dir / model_spec_current.get_model_path() / tp_pp_dir),
|
||||
*ifb_args)
|
||||
|
||||
if has_safetensor:
|
||||
Path(str(safetensor_file) + ".bak").rename(safetensor_file)
|
||||
|
||||
@ -21,7 +21,12 @@ import platform as _pf
|
||||
import sys as _sys
|
||||
import typing as _tp
|
||||
|
||||
from build_engines_utils import run_command, wincopy
|
||||
from build_engines_utils import init_model_spec_module, run_command, wincopy
|
||||
|
||||
init_model_spec_module()
|
||||
import model_spec
|
||||
|
||||
import tensorrt_llm.bindings as _tb
|
||||
|
||||
|
||||
def get_ckpt_without_quatization(model_dir, output_dir):
|
||||
@ -114,6 +119,7 @@ def build_engines(model_cache: _tp.Optional[str] = None, only_fp8=False):
|
||||
tp_size = 1
|
||||
pp_size = 1
|
||||
tp_pp_dir = f"tp{tp_size}-pp{pp_size}-gpu"
|
||||
input_file = 'input_tokens.npy'
|
||||
|
||||
if only_fp8:
|
||||
# with ifb, new plugin
|
||||
@ -125,7 +131,12 @@ def build_engines(model_cache: _tp.Optional[str] = None, only_fp8=False):
|
||||
# str(_pl.Path(model_cache) / 'fp8-quantized-modelopt' / 'gptj_tp1_rank0.npz')
|
||||
fp8_ckpt_path = engine_dir / 'fp8' / tp_pp_dir
|
||||
get_ckpt_with_modelopt_quant(hf_dir, fp8_ckpt_path, model_cache)
|
||||
build_engine(fp8_ckpt_path, engine_dir / 'fp8-plugin' / tp_pp_dir,
|
||||
model_spec_obj = model_spec.ModelSpec(input_file, _tb.DataType.FP8)
|
||||
model_spec_obj.use_gpt_plugin()
|
||||
model_spec_obj.set_kv_cache_type(model_spec.KVCacheType.PAGED)
|
||||
model_spec_obj.use_packed_input()
|
||||
build_engine(fp8_ckpt_path,
|
||||
engine_dir / model_spec_obj.get_model_path() / tp_pp_dir,
|
||||
'--gpt_attention_plugin=float16',
|
||||
'--paged_kv_cache=enable', '--remove_input_padding=enable',
|
||||
"--context_fmha=disable")
|
||||
@ -133,21 +144,28 @@ def build_engines(model_cache: _tp.Optional[str] = None, only_fp8=False):
|
||||
fp16_ckpt_path = engine_dir / 'fp16' / tp_pp_dir
|
||||
get_ckpt_without_quatization(hf_dir, fp16_ckpt_path)
|
||||
print("\nBuilding fp16-plugin engine")
|
||||
build_engine(fp16_ckpt_path, engine_dir / 'fp16-plugin' / tp_pp_dir,
|
||||
model_spec_obj = model_spec.ModelSpec(input_file, _tb.DataType.HALF)
|
||||
model_spec_obj.use_gpt_plugin()
|
||||
model_spec_obj.set_kv_cache_type(model_spec.KVCacheType.CONTINUOUS)
|
||||
|
||||
build_engine(fp16_ckpt_path,
|
||||
engine_dir / model_spec_obj.get_model_path() / tp_pp_dir,
|
||||
'--gpt_attention_plugin=float16',
|
||||
'--paged_kv_cache=disable',
|
||||
'--remove_input_padding=disable', "--context_fmha=disable")
|
||||
|
||||
print("\nBuilding fp16-plugin-packed engine")
|
||||
model_spec_obj.use_packed_input()
|
||||
build_engine(fp16_ckpt_path,
|
||||
engine_dir / 'fp16-plugin-packed' / tp_pp_dir,
|
||||
engine_dir / model_spec_obj.get_model_path() / tp_pp_dir,
|
||||
'--gpt_attention_plugin=float16',
|
||||
'--paged_kv_cache=disable',
|
||||
'--remove_input_padding=enable', "--context_fmha=disable")
|
||||
|
||||
print("\nBuilding fp16-plugin-packed-paged engine")
|
||||
model_spec_obj.set_kv_cache_type(model_spec.KVCacheType.PAGED)
|
||||
build_engine(fp16_ckpt_path,
|
||||
engine_dir / 'fp16-plugin-packed-paged' / tp_pp_dir,
|
||||
engine_dir / model_spec_obj.get_model_path() / tp_pp_dir,
|
||||
'--gpt_attention_plugin=float16',
|
||||
'--paged_kv_cache=enable', '--remove_input_padding=enable',
|
||||
"--context_fmha=disable")
|
||||
|
||||
@ -19,7 +19,12 @@ import pathlib as _pl
|
||||
import platform as _pf
|
||||
import sys as _sys
|
||||
|
||||
from build_engines_utils import run_command, wincopy
|
||||
from build_engines_utils import init_model_spec_module, run_command, wincopy
|
||||
|
||||
init_model_spec_module()
|
||||
import model_spec
|
||||
|
||||
import tensorrt_llm.bindings as _tb
|
||||
|
||||
|
||||
def build_engine(weight_dir: _pl.Path, engine_dir: _pl.Path, *args):
|
||||
@ -61,7 +66,7 @@ def build_engines(model_cache: str, only_multi_gpu: bool):
|
||||
if model_cache:
|
||||
print("Copy model from model_cache")
|
||||
model_cache_dir = _pl.Path(model_cache) / 'llama-models' / model_name
|
||||
assert (model_cache_dir.is_dir())
|
||||
assert (model_cache_dir.is_dir()), model_cache_dir
|
||||
|
||||
if _pf.system() == "Windows":
|
||||
wincopy(source=str(model_cache_dir),
|
||||
@ -77,14 +82,22 @@ def build_engines(model_cache: str, only_multi_gpu: bool):
|
||||
|
||||
engine_dir = models_dir / 'rt_engine' / model_name
|
||||
|
||||
model_spec_obj = model_spec.ModelSpec('input_tokens.npy', _tb.DataType.HALF)
|
||||
model_spec_obj.use_gpt_plugin()
|
||||
model_spec_obj.set_kv_cache_type(model_spec.KVCacheType.PAGED)
|
||||
model_spec_obj.use_packed_input()
|
||||
|
||||
tp_pp_sizes = [(1, 1)]
|
||||
if only_multi_gpu:
|
||||
tp_pp_sizes = [(1, 4), (4, 1), (2, 2)]
|
||||
for tp_size, pp_size in tp_pp_sizes:
|
||||
tp_pp_dir = f"tp{tp_size}-pp{pp_size}-gpu"
|
||||
print(f"\nBuilding fp16 tp{tp_size} pp{pp_size} engine")
|
||||
model_spec_obj.use_tensor_parallelism(tp_size)
|
||||
model_spec_obj.use_pipeline_parallelism(pp_size)
|
||||
|
||||
build_engine(hf_dir,
|
||||
engine_dir / f'fp16-plugin-packed-paged/{tp_pp_dir}',
|
||||
engine_dir / model_spec_obj.get_model_path() / tp_pp_dir,
|
||||
f'--tp_size={tp_size}', f'--pp_size={pp_size}')
|
||||
|
||||
print("Done.")
|
||||
|
||||
@ -21,7 +21,12 @@ import platform as _pf
|
||||
import sys as _sys
|
||||
import typing as _tp
|
||||
|
||||
from build_engines_utils import run_command, wincopy
|
||||
from build_engines_utils import init_model_spec_module, run_command, wincopy
|
||||
|
||||
init_model_spec_module()
|
||||
import model_spec
|
||||
|
||||
import tensorrt_llm.bindings as _tb
|
||||
|
||||
|
||||
def build_engine(weight_dir: _pl.Path, ckpt_dir: _pl.Path, engine_dir: _pl.Path,
|
||||
@ -106,23 +111,30 @@ def build_engines(model_cache: _tp.Optional[str] = None):
|
||||
|
||||
ckpt_dir = models_dir / 'rt_ckpt' / model_name
|
||||
engine_dir = models_dir / 'rt_engine' / model_name
|
||||
model_spec_obj = model_spec.ModelSpec('input_tokens.npy', _tb.DataType.HALF)
|
||||
model_spec_obj.set_kv_cache_type(model_spec.KVCacheType.CONTINUOUS)
|
||||
model_spec_obj.use_tensor_parallelism(tp_size)
|
||||
model_spec_obj.use_pipeline_parallelism(pp_size)
|
||||
|
||||
print("\nBuilding fp16 engine")
|
||||
build_engine(hf_dir, ckpt_dir / 'fp16-default' / tp_pp_dir,
|
||||
engine_dir / 'fp16-default' / tp_pp_dir,
|
||||
build_engine(hf_dir, ckpt_dir / model_spec_obj.get_model_path() / tp_pp_dir,
|
||||
engine_dir / model_spec_obj.get_model_path() / tp_pp_dir,
|
||||
'--remove_input_padding=disable', '--paged_state=disable',
|
||||
'--mamba_conv1d_plugin=disable')
|
||||
print("\nBuilding fp16-plugin engine")
|
||||
build_engine(hf_dir, ckpt_dir / 'fp16-plugin' / tp_pp_dir,
|
||||
engine_dir / 'fp16-plugin' / tp_pp_dir,
|
||||
model_spec_obj.use_mamba_plugin()
|
||||
build_engine(hf_dir, ckpt_dir / model_spec_obj.get_model_path() / tp_pp_dir,
|
||||
engine_dir / model_spec_obj.get_model_path() / tp_pp_dir,
|
||||
'--remove_input_padding=disable', '--paged_state=disable')
|
||||
print("\nBuilding fp16-plugin-packed engine")
|
||||
build_engine(hf_dir, ckpt_dir / 'fp16-plugin-packed' / tp_pp_dir,
|
||||
engine_dir / 'fp16-plugin-packed' / tp_pp_dir,
|
||||
model_spec_obj.use_packed_input()
|
||||
build_engine(hf_dir, ckpt_dir / model_spec_obj.get_model_path() / tp_pp_dir,
|
||||
engine_dir / model_spec_obj.get_model_path() / tp_pp_dir,
|
||||
'--remove_input_padding=enable', '--paged_state=disable')
|
||||
print("\nBuilding fp16-plugin-packed-paged engine")
|
||||
build_engine(hf_dir, ckpt_dir / 'fp16-plugin-packed-paged' / tp_pp_dir,
|
||||
engine_dir / 'fp16-plugin-packed-paged' / tp_pp_dir,
|
||||
model_spec_obj.set_kv_cache_type(model_spec.KVCacheType.PAGED)
|
||||
build_engine(hf_dir, ckpt_dir / model_spec_obj.get_model_path() / tp_pp_dir,
|
||||
engine_dir / model_spec_obj.get_model_path() / tp_pp_dir,
|
||||
'--remove_input_padding=enable', '--paged_state=enable')
|
||||
print("Done.")
|
||||
|
||||
|
||||
@ -19,7 +19,12 @@ import pathlib as _pl
|
||||
import platform as _pf
|
||||
import sys as _sys
|
||||
|
||||
from build_engines_utils import run_command, wincopy
|
||||
from build_engines_utils import init_model_spec_module, run_command, wincopy
|
||||
|
||||
init_model_spec_module()
|
||||
import model_spec
|
||||
|
||||
import tensorrt_llm.bindings as _tb
|
||||
|
||||
|
||||
def build_engine(weight_dir: _pl.Path, medusa_dir: _pl.Path,
|
||||
@ -60,8 +65,8 @@ def build_engines(model_cache: str):
|
||||
print("Copy model from model_cache")
|
||||
model_cache_dir = _pl.Path(model_cache) / model_name
|
||||
medusa_cache_dir = _pl.Path(model_cache) / medusa_name
|
||||
assert model_cache_dir.is_dir()
|
||||
assert medusa_cache_dir.is_dir()
|
||||
assert model_cache_dir.is_dir(), model_cache_dir
|
||||
assert medusa_cache_dir.is_dir(), medusa_cache_dir
|
||||
|
||||
if _pf.system() == "Windows":
|
||||
wincopy(source=str(model_cache_dir),
|
||||
@ -85,9 +90,15 @@ def build_engines(model_cache: str):
|
||||
|
||||
engine_dir = models_dir / 'rt_engine' / model_name
|
||||
|
||||
model_spec_obj = model_spec.ModelSpec('input_tokens.npy', _tb.DataType.HALF)
|
||||
model_spec_obj.use_gpt_plugin()
|
||||
model_spec_obj.set_kv_cache_type(model_spec.KVCacheType.PAGED)
|
||||
model_spec_obj.use_packed_input()
|
||||
model_spec_obj.use_medusa()
|
||||
|
||||
print(f"\nBuilding fp16 engine")
|
||||
build_engine(model_dir, medusa_dir,
|
||||
engine_dir / 'fp16-plugin-packed-paged/tp1-pp1-gpu')
|
||||
engine_dir / model_spec_obj.get_model_path() / 'tp1-pp1-gpu')
|
||||
|
||||
print("Done.")
|
||||
|
||||
|
||||
@ -21,7 +21,12 @@ import platform as _pf
|
||||
import sys as _sys
|
||||
import typing as _tp
|
||||
|
||||
from build_engines_utils import run_command, wincopy
|
||||
from build_engines_utils import init_model_spec_module, run_command, wincopy
|
||||
|
||||
init_model_spec_module()
|
||||
import model_spec
|
||||
|
||||
import tensorrt_llm.bindings as _tb
|
||||
|
||||
|
||||
def build_engine(weight_dir: _pl.Path, ckpt_dir: _pl.Path, engine_dir: _pl.Path,
|
||||
@ -105,10 +110,15 @@ def build_engines(model_cache: _tp.Optional[str] = None):
|
||||
run_command([python_exe, "-m", "pip", "install", "transformers>=4.40.0"],
|
||||
env=_os.environ,
|
||||
timeout=300)
|
||||
input_file = 'input_tokens.npy'
|
||||
model_spec_obj = model_spec.ModelSpec(input_file, _tb.DataType.HALF)
|
||||
model_spec_obj.use_gpt_plugin()
|
||||
model_spec_obj.use_packed_input()
|
||||
model_spec_obj.set_kv_cache_type(model_spec.KVCacheType.PAGED)
|
||||
|
||||
print("\nBuilding fp16-plugin-packed-paged engine")
|
||||
build_engine(hf_dir, ckpt_dir / 'fp16-plugin-packed-paged' / tp_pp_dir,
|
||||
engine_dir / 'fp16-plugin-packed-paged' / tp_pp_dir,
|
||||
build_engine(hf_dir, ckpt_dir / model_spec_obj.get_model_path() / tp_pp_dir,
|
||||
engine_dir / model_spec_obj.get_model_path() / tp_pp_dir,
|
||||
'--remove_input_padding=enable', '--paged_state=enable')
|
||||
|
||||
# Restore transformers version
|
||||
|
||||
109
cpp/tests/resources/scripts/build_redrafter_engines.py
Executable file
@ -0,0 +1,109 @@
|
||||
#!/usr/bin/env python3
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse as _arg
|
||||
import pathlib as _pl
|
||||
import platform as _pf
|
||||
|
||||
from build_engines_utils import run_command, wincopy
|
||||
|
||||
|
||||
def build_engine(weight_dir: _pl.Path, engine_dir: _pl.Path,
|
||||
has_tllm_checkpoint: bool):
|
||||
|
||||
if not has_tllm_checkpoint:
|
||||
raise RuntimeError(
|
||||
'Convert checkpoint is not supported for ReDrafter. '
|
||||
'Provide a path that contains a checkpoint in the tllm_ckpt folder and set the --has_tllm_checkpoint flag'
|
||||
)
|
||||
else:
|
||||
checkpoint_dir = weight_dir / 'tllm_ckpt'
|
||||
|
||||
build_args = ["trtllm-build"] + (
|
||||
['--checkpoint_dir', str(checkpoint_dir)] if engine_dir else []) + [
|
||||
'--output_dir',
|
||||
str(engine_dir),
|
||||
'--gemm_plugin=float16',
|
||||
'--max_batch_size=8',
|
||||
'--max_input_len=64',
|
||||
'--max_seq_len=1024',
|
||||
'--log_level=error',
|
||||
'--paged_kv_cache=enable',
|
||||
'--remove_input_padding=enable',
|
||||
'--speculative_decoding_mode=explicit_draft_tokens',
|
||||
]
|
||||
|
||||
run_command(build_args)
|
||||
|
||||
|
||||
def build_engines(model_cache: str, has_tllm_checkpoint: bool):
|
||||
resources_dir = _pl.Path(__file__).parent.resolve().parent
|
||||
models_dir = resources_dir / 'models'
|
||||
model_name = 'vicuna_redrafter'
|
||||
if has_tllm_checkpoint:
|
||||
base_model_name = 'vicuna-7b-v1.3'
|
||||
# FIXME(nkorobov): rename folder in the cache
|
||||
# model_name = 'redrafter-vicuna-7b-v1.3'
|
||||
|
||||
if model_cache:
|
||||
print("Copy model from model_cache")
|
||||
model_cache_dir = _pl.Path(model_cache) / model_name
|
||||
assert model_cache_dir.is_dir()
|
||||
if has_tllm_checkpoint:
|
||||
base_model_cache_dir = _pl.Path(model_cache) / base_model_name
|
||||
assert base_model_cache_dir.is_dir()
|
||||
|
||||
if _pf.system() == "Windows":
|
||||
wincopy(source=str(model_cache_dir),
|
||||
dest=model_name,
|
||||
isdir=True,
|
||||
cwd=models_dir)
|
||||
if has_tllm_checkpoint:
|
||||
wincopy(source=str(base_model_cache_dir),
|
||||
dest=base_model_name,
|
||||
isdir=True,
|
||||
cwd=models_dir)
|
||||
else:
|
||||
run_command(
|
||||
["rsync", "-av", str(model_cache_dir), "."], cwd=models_dir)
|
||||
if has_tllm_checkpoint:
|
||||
run_command(["rsync", "-av",
|
||||
str(base_model_cache_dir), "."],
|
||||
cwd=models_dir)
|
||||
|
||||
model_dir = models_dir / model_name
|
||||
assert model_dir.is_dir()
|
||||
|
||||
engine_dir = models_dir / 'rt_engine' / model_name
|
||||
|
||||
print(f"\nBuilding fp16 engine")
|
||||
build_engine(model_dir, engine_dir / 'fp16-plugin-packed-paged/tp1-pp1-gpu',
|
||||
has_tllm_checkpoint)
|
||||
|
||||
print("Done.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = _arg.ArgumentParser()
|
||||
parser.add_argument("--model_cache",
|
||||
type=str,
|
||||
help="Directory where models are stored")
|
||||
parser.add_argument(
|
||||
"--has_tllm_checkpoint",
|
||||
action='store_true',
|
||||
help="True if the provided path contains the trt-llm checkpoint.")
|
||||
|
||||
build_engines(**vars(parser.parse_args()))
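The new script does not run a checkpoint conversion itself: it requires a ready tllm_ckpt folder inside the cached model directory and wraps a single trtllm-build call with --speculative_decoding_mode=explicit_draft_tokens. A hedged usage sketch (the paths are placeholders, not taken from the diff):

# python3 build_redrafter_engines.py --model_cache /path/to/model_cache --has_tllm_checkpoint
build_engines(model_cache='/path/to/model_cache', has_tllm_checkpoint=True)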
|
||||
@ -18,6 +18,14 @@ from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import run
|
||||
from build_engines_utils import init_model_spec_module
|
||||
|
||||
init_model_spec_module()
|
||||
import os
|
||||
|
||||
import model_spec
|
||||
|
||||
import tensorrt_llm.bindings as _tb
|
||||
|
||||
resources_dir = Path(__file__).parent.resolve().parent
|
||||
model_path = resources_dir / "models"
|
||||
@ -35,18 +43,30 @@ def generate_output(
|
||||
tp_size = 1
|
||||
pp_size = 1
|
||||
tp_pp_dir = f"tp{tp_size}-pp{pp_size}-gpu/"
|
||||
input_file = f"input_tokens_{model_name}.npy"
|
||||
|
||||
data_input_file_name = resources_dir / "data" / f"input_tokens_{model_name}.npy"
|
||||
data_input_file_name = resources_dir / "data" / input_file
|
||||
if num_beams == 1:
|
||||
output_dir = resources_dir / "data" / model_name / "sampling"
|
||||
else:
|
||||
output_dir = resources_dir / "data" / model_name / f"beam_search_{num_beams}"
|
||||
output_dir.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
for engine_kind in ["fp16-plugin", "fp16-plugin-packed-paged"]:
|
||||
engine_dir = model_path / 'rt_engine' / model_name / engine_kind / tp_pp_dir
|
||||
output_npy_file_name = output_dir / f"output_tokens_{engine_kind.replace('-', '_')}_tp{tp_size}_pp{pp_size}.npy"
|
||||
output_csv_file_name = output_dir / f"output_tokens_{engine_kind.replace('-', '_')}_tp{tp_size}_pp{pp_size}.csv"
|
||||
model_spec_obj_list = [
|
||||
model_spec.ModelSpec(
|
||||
input_file, _tb.DataType.HALF).use_gpt_plugin().set_kv_cache_type(
|
||||
model_spec.KVCacheType.CONTINUOUS),
|
||||
model_spec.ModelSpec(input_file, _tb.DataType.HALF).use_gpt_plugin().
|
||||
use_packed_input().set_kv_cache_type(model_spec.KVCacheType.PAGED),
|
||||
]
|
||||
|
||||
for model_spec_obj in model_spec_obj_list:
|
||||
engine_dir = model_path / 'rt_engine' / model_name / model_spec_obj.get_model_path(
|
||||
) / tp_pp_dir
|
||||
base_output_name = os.path.splitext(
|
||||
model_spec_obj.get_results_file())[0]
|
||||
output_npy_file_name = output_dir / f'{base_output_name}.npy'
|
||||
output_csv_file_name = output_dir / f'{base_output_name}.csv'
|
||||
|
||||
args_list = [
|
||||
'--engine_dir',
|
||||
@ -85,8 +105,13 @@ def generate_output(
|
||||
data = np.load(str(output_npy_file_name))
|
||||
if model_name == 'chatglm-6b':
|
||||
data[data == 3] = 130005
|
||||
else:
|
||||
elif model_name == 'chatglm2-6b' or model_name == 'chatglm3-6b':
|
||||
data[data == 0] = 2
|
||||
elif model_name == 'glm-10b':
|
||||
data[data == 50256] = 50258
|
||||
else:
|
||||
raise NameError('bad model name')
|
||||
|
||||
np.save(str(output_npy_file_name), data)
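The per-model fixup above rewrites selected token ids in the saved outputs; the id pairs come straight from the script, while the rationale (normalizing each model's padding/end-of-sequence id so the runtime tests compare equal outputs) is an assumption. An equivalent table-driven form:

# Sketch only: ids are copied from the branches above; model_name and data come
# from the enclosing generate_output() call.
TOKEN_FIXUP = {
    'chatglm-6b': (3, 130005),
    'chatglm2-6b': (0, 2),
    'chatglm3-6b': (0, 2),
    'glm-10b': (50256, 50258),
}
src_id, dst_id = TOKEN_FIXUP[model_name]
data[data == src_id] = dst_id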
|
||||
|
||||
|
||||
@ -97,4 +122,5 @@ if __name__ == '__main__':
|
||||
generate_output(model_name='chatglm2-6b', num_beams=2)
|
||||
generate_output(model_name='chatglm3-6b', num_beams=1)
|
||||
generate_output(model_name='chatglm3-6b', num_beams=2)
|
||||
generate_output(model_name='glm-10b', num_beams=1)
|
||||
print("Done")
|
||||
|
||||
@ -17,12 +17,21 @@
|
||||
from pathlib import Path
|
||||
|
||||
import run
|
||||
from build_engines_utils import init_model_spec_module
|
||||
|
||||
init_model_spec_module()
|
||||
|
||||
import os
|
||||
|
||||
import model_spec
|
||||
|
||||
import tensorrt_llm.bindings as _tb
|
||||
|
||||
|
||||
def generate_output(engine: str,
|
||||
num_beams: int,
|
||||
input_name: str,
|
||||
output_name: str,
|
||||
model_spec_obj: model_spec.ModelSpec,
|
||||
max_output_len: int = 8,
|
||||
output_logits: bool = False,
|
||||
output_cum_log_probs: bool = False,
|
||||
@ -36,106 +45,140 @@ def generate_output(engine: str,
|
||||
engine_dir = models_dir / 'rt_engine' / model / engine / tp_pp_dir
|
||||
|
||||
data_dir = resources_dir / 'data'
|
||||
input_file = data_dir / (input_name + '.npy')
|
||||
input_file = data_dir / input_name
|
||||
model_data_dir = data_dir / model
|
||||
if num_beams <= 1:
|
||||
output_dir = model_data_dir / 'sampling'
|
||||
else:
|
||||
output_dir = model_data_dir / ('beam_search_' + str(num_beams))
|
||||
|
||||
output_name += '_tp' + str(tp_size) + '_pp' + str(pp_size)
|
||||
model_spec_obj.use_tensor_parallelism(tp_size).use_pipeline_parallelism(
|
||||
pp_size)
|
||||
|
||||
base_output_name = os.path.splitext(model_spec_obj.get_results_file())[0]
|
||||
|
||||
output_logits_npy = None
|
||||
if output_logits:
|
||||
output_logits_npy = str(output_dir / (output_name + '_logits' + '.npy'))
|
||||
logits_file = base_output_name + '_logits.npy'
|
||||
output_logits_npy = str(output_dir / logits_file)
|
||||
|
||||
results_file = str(output_dir / (base_output_name + '.npy'))
|
||||
results_csv = str(output_dir / (base_output_name + '.csv'))
|
||||
|
||||
args_list = [
|
||||
'--engine_dir',
|
||||
str(engine_dir), '--input_file',
|
||||
str(input_file), '--tokenizer_dir',
|
||||
str(models_dir / model), '--output_npy',
|
||||
str(output_dir / (output_name + '.npy')), '--output_csv',
|
||||
str(output_dir / (output_name + '.csv')), '--max_output_len',
|
||||
str(models_dir / model), '--output_npy', results_file, '--output_csv',
|
||||
results_csv, '--max_output_len',
|
||||
str(max_output_len), '--num_beams',
|
||||
str(num_beams), '--output_logits_npy',
|
||||
str(output_logits_npy), '--use_py_session'
|
||||
]
|
||||
|
||||
output_cum_log_probs_npy = None
|
||||
if output_cum_log_probs:
|
||||
output_cum_log_probs_npy = str(
|
||||
output_dir / (output_name + '_cum_log_probs' + '.npy'))
|
||||
args_list.extend(
|
||||
['--output_cum_log_probs_npy',
|
||||
str(output_cum_log_probs_npy)])
|
||||
assert not os.path.exists(results_file) and not os.path.exists(results_csv)
|
||||
|
||||
if output_cum_log_probs:
|
||||
args_list.extend([
|
||||
'--output_cum_log_probs_npy',
|
||||
f'{output_dir / model_spec_obj.get_cum_log_probs_file()}'
|
||||
])
|
||||
|
||||
output_log_probs_npy = None
|
||||
if output_log_probs:
|
||||
output_log_probs_npy = str(output_dir /
|
||||
(output_name + '_log_probs' + '.npy'))
|
||||
args_list.extend(['--output_log_probs_npy', str(output_log_probs_npy)])
|
||||
args_list.extend([
|
||||
'--output_log_probs_npy',
|
||||
f'{output_dir / model_spec_obj.get_log_probs_file()}'
|
||||
])
|
||||
|
||||
args = run.parse_arguments(args_list)
|
||||
|
||||
print(args_list)
|
||||
run.main(args)
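With this change the output artifact names are no longer assembled by hand from output_name; they are derived from the same ModelSpec that selected the engine. A short sketch of the naming scheme (only the method names are taken from the change; the exact pattern returned by get_results_file() is assumed, and output_dir, tp_size, pp_size come from the enclosing function):

model_spec_obj.use_tensor_parallelism(tp_size).use_pipeline_parallelism(pp_size)
base = os.path.splitext(model_spec_obj.get_results_file())[0]
results_npy = output_dir / f'{base}.npy'
results_csv = output_dir / f'{base}.csv'
logits_npy = output_dir / f'{base}_logits.npy'  # written only when output_logits=True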
|
||||
|
||||
|
||||
def generate_outputs(num_beams):
|
||||
print('Generating GPT2 FP32 outputs')
|
||||
input_name = 'input_tokens.npy'
|
||||
input_name_long = 'input_tokens_long.npy'
|
||||
|
||||
model_spec_obj = model_spec.ModelSpec(input_name, _tb.DataType.FLOAT)
|
||||
model_spec_obj.set_kv_cache_type(model_spec.KVCacheType.CONTINUOUS)
|
||||
if num_beams == 1:
|
||||
generate_output(engine='fp32-default',
|
||||
generate_output(engine=model_spec_obj.get_model_path(),
|
||||
num_beams=num_beams,
|
||||
input_name='input_tokens',
|
||||
output_name='output_tokens_fp32')
|
||||
generate_output(engine='fp32-plugin',
|
||||
input_name=input_name,
|
||||
model_spec_obj=model_spec_obj)
|
||||
model_spec_obj.use_gpt_plugin()
|
||||
generate_output(engine=model_spec_obj.get_model_path(),
|
||||
num_beams=num_beams,
|
||||
input_name='input_tokens',
|
||||
output_name='output_tokens_fp32_plugin')
|
||||
input_name=input_name,
|
||||
model_spec_obj=model_spec_obj)
|
||||
|
||||
print('Generating GPT2 FP16 outputs')
|
||||
model_spec_obj = model_spec.ModelSpec(input_name, _tb.DataType.HALF)
|
||||
model_spec_obj.set_kv_cache_type(model_spec.KVCacheType.CONTINUOUS)
|
||||
if num_beams == 1:
|
||||
generate_output(engine='fp16-default',
|
||||
generate_output(engine=model_spec_obj.get_model_path(),
|
||||
num_beams=num_beams,
|
||||
input_name='input_tokens',
|
||||
output_name='output_tokens_fp16')
|
||||
generate_output(engine='fp16-plugin',
|
||||
input_name=input_name,
|
||||
model_spec_obj=model_spec_obj)
|
||||
model_spec_obj.use_gpt_plugin()
|
||||
generate_output(engine=model_spec_obj.get_model_path(),
|
||||
num_beams=num_beams,
|
||||
input_name='input_tokens',
|
||||
output_name='output_tokens_fp16_plugin')
|
||||
generate_output(engine='fp16-plugin-packed',
|
||||
input_name=input_name,
|
||||
model_spec_obj=model_spec_obj)
|
||||
model_spec_obj.use_packed_input()
|
||||
generate_output(engine=model_spec_obj.get_model_path(),
|
||||
num_beams=num_beams,
|
||||
input_name='input_tokens',
|
||||
output_name='output_tokens_fp16_plugin_packed')
|
||||
generate_output(engine='fp16-plugin-packed-paged-gather',
|
||||
input_name=input_name,
|
||||
model_spec_obj=model_spec_obj)
|
||||
model_spec_obj.set_kv_cache_type(model_spec.KVCacheType.PAGED)
|
||||
model_spec_obj.gather_logits()
|
||||
generate_output(engine=model_spec_obj.get_model_path(),
|
||||
num_beams=num_beams,
|
||||
input_name='input_tokens',
|
||||
output_name='output_tokens_fp16_plugin_packed_paged_gather',
|
||||
input_name=input_name,
|
||||
model_spec_obj=model_spec_obj,
|
||||
output_logits=True,
|
||||
output_log_probs=True,
|
||||
output_cum_log_probs=True)
|
||||
generate_output(engine='fp16-plugin-packed-paged',
|
||||
|
||||
model_spec_obj = model_spec.ModelSpec(input_name, _tb.DataType.HALF)
|
||||
model_spec_obj.use_gpt_plugin()
|
||||
model_spec_obj.set_kv_cache_type(model_spec.KVCacheType.PAGED)
|
||||
model_spec_obj.use_packed_input()
|
||||
generate_output(engine=model_spec_obj.get_model_path(),
|
||||
num_beams=num_beams,
|
||||
input_name='input_tokens',
|
||||
output_name='output_tokens_fp16_plugin_packed_paged',
|
||||
input_name=input_name,
|
||||
model_spec_obj=model_spec_obj,
|
||||
output_logits=False,
|
||||
output_log_probs=True,
|
||||
output_cum_log_probs=True)
|
||||
generate_output(engine='fp16-plugin-packed-paged',
|
||||
model_spec_obj.set_max_output_length(128)
|
||||
generate_output(engine=model_spec_obj.get_model_path(),
|
||||
num_beams=num_beams,
|
||||
input_name='input_tokens',
|
||||
output_name='output_tokens_long_fp16_plugin_packed_paged',
|
||||
input_name=input_name,
|
||||
model_spec_obj=model_spec_obj,
|
||||
output_logits=False,
|
||||
max_output_len=128)
|
||||
generate_output(
|
||||
engine='fp16-plugin-packed-paged',
|
||||
num_beams=num_beams,
|
||||
input_name='input_tokens_long',
|
||||
output_name='output_tokens_long_input_fp16_plugin_packed_paged',
|
||||
output_logits=False)
|
||||
generate_output(engine='fp16-plugin-packed-paged-sq',
|
||||
|
||||
model_spec_obj = model_spec.ModelSpec(input_name_long, _tb.DataType.HALF)
|
||||
model_spec_obj.use_gpt_plugin()
|
||||
model_spec_obj.use_packed_input()
|
||||
model_spec_obj.set_kv_cache_type(model_spec.KVCacheType.PAGED)
|
||||
generate_output(engine=model_spec_obj.get_model_path(),
|
||||
num_beams=num_beams,
|
||||
input_name='input_tokens',
|
||||
output_name='output_tokens_fp16_plugin_packed_paged_sq',
|
||||
input_name=input_name_long,
|
||||
model_spec_obj=model_spec_obj,
|
||||
output_logits=False)
|
||||
|
||||
model_spec_obj = model_spec.ModelSpec(input_name, _tb.DataType.HALF)
|
||||
model_spec_obj.use_gpt_plugin()
|
||||
model_spec_obj.use_packed_input()
|
||||
model_spec_obj.set_kv_cache_type(model_spec.KVCacheType.PAGED)
|
||||
model_spec_obj.set_quant_method(model_spec.QuantMethod.SMOOTH_QUANT)
|
||||
generate_output(engine=model_spec_obj.get_model_path(),
|
||||
num_beams=num_beams,
|
||||
input_name=input_name,
|
||||
model_spec_obj=model_spec_obj,
|
||||
output_logits=False)
|
||||
|
||||
|
||||
|
||||
@ -18,11 +18,19 @@ import argparse as _arg
|
||||
from pathlib import Path
|
||||
|
||||
import run
|
||||
from build_engines_utils import init_model_spec_module
|
||||
|
||||
init_model_spec_module()
|
||||
import os
|
||||
|
||||
import model_spec
|
||||
|
||||
import tensorrt_llm.bindings as _tb
|
||||
|
||||
|
||||
def generate_output(engine: str,
|
||||
num_beams: int,
|
||||
output_name: str,
|
||||
model_spec_obj: model_spec.ModelSpec,
|
||||
max_output_len: int = 4):
|
||||
|
||||
tp_size = 1
|
||||
@ -42,15 +50,15 @@ def generate_output(engine: str,
|
||||
else:
|
||||
output_dir = model_data_dir / ('beam_search_' + str(num_beams))
|
||||
|
||||
output_name += '_tp' + str(tp_size) + '_pp' + str(pp_size)
|
||||
base_output_name = os.path.splitext(model_spec_obj.get_results_file())[0]
|
||||
|
||||
args = run.parse_arguments([
|
||||
'--engine_dir',
|
||||
str(engine_dir), '--input_file',
|
||||
str(input_file), '--tokenizer_dir',
|
||||
str(hf_dir), '--output_npy',
|
||||
str(output_dir / (output_name + '.npy')), '--output_csv',
|
||||
str(output_dir / (output_name + '.csv')), '--max_output_len',
|
||||
str(output_dir / (base_output_name + '.npy')), '--output_csv',
|
||||
str(output_dir / (base_output_name + '.csv')), '--max_output_len',
|
||||
str(max_output_len), '--num_beams',
|
||||
str(num_beams), '--use_py_session'
|
||||
])
|
||||
@ -58,22 +66,35 @@ def generate_output(engine: str,
|
||||
|
||||
|
||||
def generate_outputs(only_fp8, num_beams):
|
||||
input_file = 'input_tokens.npy'
|
||||
if only_fp8 and num_beams == 1:
|
||||
model_spec_obj = model_spec.ModelSpec(input_file, _tb.DataType.FP8)
|
||||
model_spec_obj.use_gpt_plugin()
|
||||
model_spec_obj.set_kv_cache_type(model_spec.KVCacheType.PAGED)
|
||||
model_spec_obj.use_packed_input()
|
||||
|
||||
print('Generating GPT-J FP8-kv-cache outputs')
|
||||
generate_output(engine='fp8-plugin',
|
||||
generate_output(engine=model_spec_obj.get_model_path(),
|
||||
num_beams=num_beams,
|
||||
output_name='output_tokens_fp8_plugin')
|
||||
model_spec_obj=model_spec_obj)
|
||||
elif not only_fp8:
|
||||
print('Generating GPT-J FP16 outputs')
|
||||
generate_output(engine='fp16-plugin',
|
||||
model_spec_obj = model_spec.ModelSpec(input_file, _tb.DataType.HALF)
|
||||
model_spec_obj.use_gpt_plugin()
|
||||
model_spec_obj.set_kv_cache_type(model_spec.KVCacheType.CONTINUOUS)
|
||||
generate_output(engine=model_spec_obj.get_model_path(),
|
||||
num_beams=num_beams,
|
||||
output_name='output_tokens_fp16_plugin')
|
||||
generate_output(engine='fp16-plugin-packed',
|
||||
model_spec_obj=model_spec_obj)
|
||||
|
||||
model_spec_obj.use_packed_input()
|
||||
generate_output(engine=model_spec_obj.get_model_path(),
|
||||
num_beams=num_beams,
|
||||
output_name='output_tokens_fp16_plugin_packed')
|
||||
generate_output(engine='fp16-plugin-packed-paged',
|
||||
model_spec_obj=model_spec_obj)
|
||||
|
||||
model_spec_obj.set_kv_cache_type(model_spec.KVCacheType.PAGED)
|
||||
generate_output(engine=model_spec_obj.get_model_path(),
|
||||
num_beams=num_beams,
|
||||
output_name='output_tokens_fp16_plugin_packed_paged')
|
||||
model_spec_obj=model_spec_obj)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
@ -18,11 +18,19 @@ import argparse as _arg
|
||||
from pathlib import Path
|
||||
|
||||
import run
|
||||
from build_engines_utils import init_model_spec_module
|
||||
|
||||
init_model_spec_module()
|
||||
import os
|
||||
|
||||
import model_spec
|
||||
|
||||
import tensorrt_llm.bindings as _tb
|
||||
|
||||
|
||||
def generate_output(engine: str,
|
||||
num_beams: int,
|
||||
output_name: str,
|
||||
model_spec_obj: model_spec.ModelSpec,
|
||||
tp_size: int = 1,
|
||||
pp_size: int = 1,
|
||||
max_output_len: int = 8):
|
||||
@ -42,15 +50,15 @@ def generate_output(engine: str,
|
||||
else:
|
||||
output_dir = model_data_dir / ('beam_search_' + str(num_beams))
|
||||
|
||||
output_name += '_tp' + str(tp_size) + '_pp' + str(pp_size)
|
||||
base_output_name = os.path.splitext(model_spec_obj.get_results_file())[0]
|
||||
|
||||
args = run.parse_arguments([
|
||||
'--engine_dir',
|
||||
str(engine_dir), '--input_file',
|
||||
str(input_file), '--tokenizer_dir',
|
||||
str(hf_dir), '--output_npy',
|
||||
str(output_dir / (output_name + '.npy')), '--output_csv',
|
||||
str(output_dir / (output_name + '.csv')), '--max_output_len',
|
||||
str(output_dir / (base_output_name + '.npy')), '--output_csv',
|
||||
str(output_dir / (base_output_name + '.csv')), '--max_output_len',
|
||||
str(max_output_len), '--num_beams',
|
||||
str(num_beams), '--use_py_session'
|
||||
])
|
||||
@ -59,15 +67,22 @@ def generate_output(engine: str,
|
||||
|
||||
def generate_outputs(num_beams, only_multi_gpu=False):
|
||||
tp_pp_sizes = [(1, 1)] if not only_multi_gpu else [(4, 1), (2, 2), (1, 4)]
|
||||
model_spec_obj = model_spec.ModelSpec('input_tokens.npy', _tb.DataType.HALF)
|
||||
model_spec_obj.use_gpt_plugin()
|
||||
model_spec_obj.set_kv_cache_type(model_spec.KVCacheType.PAGED)
|
||||
model_spec_obj.use_packed_input()
|
||||
|
||||
for tp_size, pp_size in tp_pp_sizes:
|
||||
print(
|
||||
f'Generating outputs for Llama FP16 with TP={tp_size} and PP={pp_size}'
|
||||
)
|
||||
generate_output(engine='fp16-plugin-packed-paged',
|
||||
model_spec_obj.use_tensor_parallelism(tp_size)
|
||||
model_spec_obj.use_pipeline_parallelism(pp_size)
|
||||
generate_output(engine=model_spec_obj.get_model_path(),
|
||||
num_beams=num_beams,
|
||||
tp_size=tp_size,
|
||||
pp_size=pp_size,
|
||||
output_name='output_tokens_fp16_plugin_packed_paged')
|
||||
model_spec_obj=model_spec_obj)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
@ -17,12 +17,20 @@
|
||||
from pathlib import Path
|
||||
|
||||
import run
|
||||
from build_engines_utils import init_model_spec_module
|
||||
|
||||
init_model_spec_module()
|
||||
import os
|
||||
|
||||
import model_spec
|
||||
|
||||
import tensorrt_llm.bindings as _tb
|
||||
|
||||
|
||||
def generate_output(engine: str,
|
||||
num_beams: int,
|
||||
input_name: str,
|
||||
output_name: str,
|
||||
model_spec_obj: model_spec.ModelSpec,
|
||||
max_output_len: int = 8,
|
||||
output_logits: bool = False):
|
||||
tp_size = 1
|
||||
@ -34,26 +42,27 @@ def generate_output(engine: str,
|
||||
engine_dir = models_dir / 'rt_engine' / model / engine / tp_pp_dir
|
||||
|
||||
data_dir = resources_dir / 'data'
|
||||
input_file = data_dir / (input_name + '.npy')
|
||||
input_file = data_dir / input_name
|
||||
model_data_dir = data_dir / model
|
||||
if num_beams <= 1:
|
||||
output_dir = model_data_dir / 'sampling'
|
||||
else:
|
||||
output_dir = model_data_dir / ('beam_search_' + str(num_beams))
|
||||
|
||||
output_name += '_tp' + str(tp_size) + '_pp' + str(pp_size)
|
||||
base_output_name = os.path.splitext(model_spec_obj.get_results_file())[0]
|
||||
|
||||
output_logits_npy = None
|
||||
if output_logits:
|
||||
output_logits_npy = str(output_dir / (output_name + '_logits' + '.npy'))
|
||||
output_logits_npy = str(output_dir /
|
||||
(base_output_name + '_logits' + '.npy'))
|
||||
|
||||
args = run.parse_arguments([
|
||||
'--engine_dir',
|
||||
str(engine_dir), '--input_file',
|
||||
str(input_file), '--tokenizer_dir',
|
||||
str(models_dir / 'gpt-neox-20b'), '--output_npy',
|
||||
str(output_dir / (output_name + '.npy')), '--output_csv',
|
||||
str(output_dir / (output_name + '.csv')), '--max_output_len',
|
||||
str(output_dir / (base_output_name + '.npy')), '--output_csv',
|
||||
str(output_dir / (base_output_name + '.csv')), '--max_output_len',
|
||||
str(max_output_len), '--num_beams',
|
||||
str(num_beams), '--output_logits_npy',
|
||||
str(output_logits_npy), '--use_py_session'
|
||||
@ -63,25 +72,35 @@ def generate_output(engine: str,
|
||||
|
||||
def generate_outputs(num_beams):
|
||||
print('Generating Mamba FP16 outputs')
|
||||
generate_output(engine='fp16-default',
|
||||
input_name = 'input_tokens.npy'
|
||||
model_spec_obj = model_spec.ModelSpec(input_name, _tb.DataType.HALF)
|
||||
model_spec_obj.set_kv_cache_type(model_spec.KVCacheType.CONTINUOUS)
|
||||
|
||||
generate_output(engine=model_spec_obj.get_model_path(),
|
||||
num_beams=num_beams,
|
||||
input_name='input_tokens',
|
||||
output_name='output_tokens_fp16')
|
||||
input_name=input_name,
|
||||
model_spec_obj=model_spec_obj)
|
||||
|
||||
print('Generating Mamba FP16-plugin outputs')
|
||||
generate_output(engine='fp16-plugin',
|
||||
model_spec_obj.use_gpt_plugin()
|
||||
generate_output(engine=model_spec_obj.get_model_path(),
|
||||
num_beams=num_beams,
|
||||
input_name='input_tokens',
|
||||
output_name='output_tokens_fp16_plugin')
|
||||
input_name=input_name,
|
||||
model_spec_obj=model_spec_obj)
|
||||
|
||||
print('Generating Mamba FP16-plugin-packed outputs')
|
||||
generate_output(engine='fp16-plugin-packed',
|
||||
model_spec_obj.use_packed_input()
|
||||
generate_output(engine=model_spec_obj.get_model_path(),
|
||||
num_beams=num_beams,
|
||||
input_name='input_tokens',
|
||||
output_name='output_tokens_fp16_plugin_packed')
|
||||
input_name=input_name,
|
||||
model_spec_obj=model_spec_obj)
|
||||
|
||||
print('Generating Mamba FP16-plugin-packed-paged outputs')
|
||||
generate_output(engine='fp16-plugin-packed-paged',
|
||||
model_spec_obj.set_kv_cache_type(model_spec.KVCacheType.PAGED)
|
||||
generate_output(engine=model_spec_obj.get_model_path(),
|
||||
num_beams=num_beams,
|
||||
input_name='input_tokens',
|
||||
output_name='output_tokens_fp16_plugin_packed_paged')
|
||||
input_name=input_name,
|
||||
model_spec_obj=model_spec_obj)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
@ -18,9 +18,19 @@ import argparse as _arg
|
||||
from pathlib import Path
|
||||
|
||||
import run
|
||||
from build_engines_utils import init_model_spec_module
|
||||
|
||||
init_model_spec_module()
|
||||
import os
|
||||
|
||||
import model_spec
|
||||
|
||||
import tensorrt_llm.bindings as _tb
|
||||
|
||||
|
||||
def generate_output(engine: str, output_name: str, max_output_len: int = 8):
|
||||
def generate_output(engine: str,
|
||||
model_spec_obj: model_spec.ModelSpec,
|
||||
max_output_len: int = 8):
|
||||
|
||||
model = 'vicuna-7b-v1.3'
|
||||
resources_dir = Path(__file__).parent.resolve().parent
|
||||
@ -30,19 +40,19 @@ def generate_output(engine: str, output_name: str, max_output_len: int = 8):
|
||||
engine_dir = models_dir / 'rt_engine' / model / engine / tp_pp_dir
|
||||
|
||||
data_dir = resources_dir / 'data'
|
||||
input_file = data_dir / 'input_tokens.npy'
|
||||
input_file = data_dir / 'input_vicuna.npy'
|
||||
model_data_dir = data_dir / model
|
||||
output_dir = model_data_dir / 'sampling'
|
||||
|
||||
output_name += '_tp1_pp1'
|
||||
base_output_name = os.path.splitext(model_spec_obj.get_results_file())[0]
|
||||
|
||||
args = run.parse_arguments([
|
||||
'--engine_dir',
|
||||
str(engine_dir), '--input_file',
|
||||
str(input_file), '--tokenizer_dir',
|
||||
str(hf_dir), '--output_npy',
|
||||
str(output_dir / (output_name + '.npy')), '--output_csv',
|
||||
str(output_dir / (output_name + '.csv')), '--max_output_len',
|
||||
str(output_dir / (base_output_name + '.npy')), '--output_csv',
|
||||
str(output_dir / (base_output_name + '.csv')), '--max_output_len',
|
||||
str(max_output_len), '--use_py_session',
|
||||
'--medusa_choices=[[0], [0, 0], [1], [0, 1], [2], [0, 0, 0], [1, 0], [0, 2], [3], [0, 3], [4], [0, 4], [2, 0], [0, 5], [0, 0, 1], [5], [0, 6], [6], [0, 7], [0, 1, 0], [1, 1], [7], [0, 8], [0, 0, 2], [3, 0], [0, 9], [8], [9], [1, 0, 0], [0, 2, 0], [1, 2], [0, 0, 3], [4, 0], [2, 1], [0, 0, 4], [0, 0, 5], [0, 0, 0, 0], [0, 1, 1], [0, 0, 6], [0, 3, 0], [5, 0], [1, 3], [0, 0, 7], [0, 0, 8], [0, 0, 9], [6, 0], [0, 4, 0], [1, 4], [7, 0], [0, 1, 2], [2, 0, 0], [3, 1], [2, 2], [8, 0], [0, 5, 0], [1, 5], [1, 0, 1], [0, 2, 1], [9, 0], [0, 6, 0], [0, 0, 0, 1], [1, 6], [0, 7, 0]]',
|
||||
'--temperature', '1.0'
|
||||
@ -52,9 +62,18 @@ def generate_output(engine: str, output_name: str, max_output_len: int = 8):
|
||||
|
||||
def generate_outputs():
|
||||
print(f'Generating outputs for Medusa FP16')
|
||||
generate_output(engine='fp16-plugin-packed-paged',
|
||||
output_name='output_tokens_long_fp16_plugin_packed_paged',
|
||||
max_output_len=128)
|
||||
max_output_len = 128
|
||||
model_spec_obj = model_spec.ModelSpec('input_tokens_long.npy',
|
||||
_tb.DataType.HALF)
|
||||
model_spec_obj.use_gpt_plugin()
|
||||
model_spec_obj.set_max_output_length(max_output_len)
|
||||
model_spec_obj.use_packed_input()
|
||||
model_spec_obj.set_kv_cache_type(model_spec.KVCacheType.PAGED)
|
||||
model_spec_obj.use_medusa()
|
||||
|
||||
generate_output(engine=model_spec_obj.get_model_path(),
|
||||
model_spec_obj=model_spec_obj,
|
||||
max_output_len=max_output_len)
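The long --medusa_choices list passed above is the draft tree for the Medusa heads: each inner list is read here as one path through the tree, giving the top-k index taken at each successive head (this reading of the format is an assumption; the 63-entry tree itself is passed verbatim from the script). A much smaller tree would look like:

# Hypothetical, illustration only: top-2 of the first head plus two depth-2 paths.
small_choices = [[0], [1], [0, 0], [0, 1]]
arg = f'--medusa_choices={small_choices}'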
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
@ -17,12 +17,20 @@
|
||||
from pathlib import Path
|
||||
|
||||
import run
|
||||
from build_engines_utils import init_model_spec_module
|
||||
|
||||
init_model_spec_module()
|
||||
import os
|
||||
|
||||
import model_spec
|
||||
|
||||
import tensorrt_llm.bindings as _tb
|
||||
|
||||
|
||||
def generate_output(engine: str,
|
||||
num_beams: int,
|
||||
input_name: str,
|
||||
output_name: str,
|
||||
model_spec_obj: model_spec.ModelSpec,
|
||||
max_output_len: int = 8,
|
||||
output_logits: bool = False):
|
||||
tp_size = 1
|
||||
@ -34,26 +42,27 @@ def generate_output(engine: str,
|
||||
engine_dir = models_dir / 'rt_engine' / model / engine / tp_pp_dir
|
||||
|
||||
data_dir = resources_dir / 'data'
|
||||
input_file = data_dir / (input_name + '.npy')
|
||||
input_file = data_dir / input_name
|
||||
model_data_dir = data_dir / model
|
||||
if num_beams <= 1:
|
||||
output_dir = model_data_dir / 'sampling'
|
||||
else:
|
||||
output_dir = model_data_dir / ('beam_search_' + str(num_beams))
|
||||
|
||||
output_name += '_tp' + str(tp_size) + '_pp' + str(pp_size)
|
||||
base_output_name = os.path.splitext(model_spec_obj.get_results_file())[0]
|
||||
|
||||
output_logits_npy = None
|
||||
if output_logits:
|
||||
output_logits_npy = str(output_dir / (output_name + '_logits' + '.npy'))
|
||||
output_logits_npy = str(output_dir /
|
||||
(base_output_name + '_logits' + '.npy'))
|
||||
|
||||
args = run.parse_arguments([
|
||||
'--engine_dir',
|
||||
str(engine_dir), '--input_file',
|
||||
str(input_file), '--tokenizer_dir',
|
||||
str(models_dir / model), '--output_npy',
|
||||
str(output_dir / (output_name + '.npy')), '--output_csv',
|
||||
str(output_dir / (output_name + '.csv')), '--max_output_len',
|
||||
str(output_dir / (base_output_name + '.npy')), '--output_csv',
|
||||
str(output_dir / (base_output_name + '.csv')), '--max_output_len',
|
||||
str(max_output_len), '--num_beams',
|
||||
str(num_beams), '--output_logits_npy',
|
||||
str(output_logits_npy), '--use_py_session'
|
||||
@ -62,11 +71,17 @@ def generate_output(engine: str,
|
||||
|
||||
|
||||
def generate_outputs(num_beams):
|
||||
input_file = 'input_tokens.npy'
|
||||
model_spec_obj = model_spec.ModelSpec(input_file, _tb.DataType.HALF)
|
||||
model_spec_obj.use_gpt_plugin()
|
||||
model_spec_obj.set_kv_cache_type(model_spec.KVCacheType.PAGED)
|
||||
model_spec_obj.use_packed_input()
|
||||
|
||||
print('Generating RecurrentGemma FP16-plugin-packed-paged outputs')
|
||||
generate_output(engine='fp16-plugin-packed-paged',
|
||||
generate_output(engine=model_spec_obj.get_model_path(),
|
||||
num_beams=num_beams,
|
||||
input_name='input_tokens',
|
||||
output_name='output_tokens_fp16_plugin_packed_paged')
|
||||
input_name=input_file,
|
||||
model_spec_obj=model_spec_obj)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
Some files were not shown because too many files have changed in this diff.