mirror of https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-13 22:18:36 +08:00

Update TensorRT-LLM (#349)

* Update TensorRT-LLM

Co-authored-by: Shixiaowei02 <39303645+Shixiaowei02@users.noreply.github.com>

This commit is contained in:
parent cd6bbab0b3
commit b2fd493c16
@@ -8,6 +8,7 @@ TensorRT-LLM
[](https://www.python.org/downloads/release/python-31012/)
[](https://developer.nvidia.com/cuda-downloads)
[](https://developer.nvidia.com/tensorrt)
[](./setup.py)
[](./LICENSE)

[Architecture](./docs/source/architecture.md) | [Results](./docs/source/performance.md) | [Examples](./examples/) | [Documentation](./docs/source/)
@@ -18,7 +18,8 @@ import pynvml


def get_memory_info(handle):
    mem_info = pynvml.nvmlDeviceGetMemoryInfo(handle)
    mem_info = pynvml.nvmlDeviceGetMemoryInfo(handle,
                                              version=pynvml.nvmlMemory_v2)
    total = round(mem_info.total / 1024 / 1024 / 1024, 2)
    used = round(mem_info.used / 1024 / 1024 / 1024, 2)
    free = round(mem_info.free / 1024 / 1024 / 1024, 2)
@@ -22,6 +22,7 @@ include(CheckLanguage)
include(cmake/modules/set_ifndef.cmake)
include(cmake/modules/find_library_create_target.cmake)
include(cmake/modules/resolve_dirs.cmake)
include(cmake/modules/parse_make_options.cmake)

project(tensorrt_llm LANGUAGES CXX)

@@ -246,6 +247,22 @@ endif()
set(COMMON_HEADER_DIRS ${PROJECT_SOURCE_DIR} ${CUDAToolkit_INCLUDE_DIR})
message(STATUS "COMMON_HEADER_DIRS: ${COMMON_HEADER_DIRS}")

if(NOT WIN32 AND NOT DEFINED USE_CXX11_ABI)
  find_package(Python3 COMPONENTS Interpreter Development REQUIRED)
  execute_process(
    COMMAND ${Python3_EXECUTABLE} "-c"
            "import torch; print(torch.compiled_with_cxx11_abi(),end='');"
    RESULT_VARIABLE _PYTHON_SUCCESS
    OUTPUT_VARIABLE USE_CXX11_ABI)
  # Convert the bool variable to integer.
  if(USE_CXX11_ABI)
    set(USE_CXX11_ABI 1)
  else()
    set(USE_CXX11_ABI 0)
  endif()
  message(STATUS "USE_CXX11_ABI is set by python Torch to ${USE_CXX11_ABI}")
endif()

if(BUILD_PYT)
  # Build TORCH_CUDA_ARCH_LIST
  set(TORCH_CUDA_ARCH_LIST "")

@@ -304,27 +321,39 @@ print(os.path.dirname(torch.__file__),end='');"
  message(STATUS "TORCH_CXX_FLAGS: ${TORCH_CXX_FLAGS}")
  add_compile_options(${TORCH_CXX_FLAGS})
  add_compile_definitions(TORCH_CUDA=1)

  if(DEFINED USE_CXX11_ABI)
    parse_make_options(${TORCH_CXX_FLAGS} "TORCH_CXX_FLAGS")
    if(DEFINED TORCH_CXX_FLAGS__GLIBCXX_USE_CXX11_ABI
       AND NOT ${TORCH_CXX_FLAGS__GLIBCXX_USE_CXX11_ABI} EQUAL ${USE_CXX11_ABI})
      message(
        WARNING
          "The libtorch compilation options _GLIBCXX_USE_CXX11_ABI=${TORCH_CXX_FLAGS__GLIBCXX_USE_CXX11_ABI} "
          "found by CMake conflict with the project setting USE_CXX11_ABI=${USE_CXX11_ABI}, and the project "
          "setting will be discarded.")
    endif()
  endif()

elseif(NOT WIN32)
  if(NOT USE_CXX11_ABI)
    add_compile_options("-D_GLIBCXX_USE_CXX11_ABI=0")
  endif()
  message(STATUS "Build without PyTorch, USE_CXX11_ABI=${USE_CXX11_ABI}")
endif()

file(STRINGS "${TRT_INCLUDE_DIR}/NvInferVersion.h" VERSION_STRINGS
     REGEX "#define NV_TENSORRT_.*")
foreach(TYPE MAJOR MINOR PATCH BUILD)
  string(REGEX MATCH "NV_TENSORRT_${TYPE} [0-9]" TRT_TYPE_STRING
  string(REGEX MATCH "NV_TENSORRT_${TYPE} [0-9]+" TRT_TYPE_STRING
         ${VERSION_STRINGS})
  string(REGEX MATCH "[0-9]" TRT_${TYPE} ${TRT_TYPE_STRING})
endforeach(TYPE)

foreach(TYPE MAJOR MINOR PATCH)
  string(REGEX MATCH "NV_TENSORRT_SONAME_${TYPE} [0-9]" TRT_TYPE_STRING
         ${VERSION_STRINGS})
  string(REGEX MATCH "[0-9]" TRT_SO_${TYPE} ${TRT_TYPE_STRING})
  string(REGEX MATCH "[0-9]+" TRT_${TYPE} ${TRT_TYPE_STRING})
endforeach(TYPE)

set(TRT_VERSION
    "${TRT_MAJOR}.${TRT_MINOR}.${TRT_PATCH}"
    CACHE STRING "TensorRT project version")
set(TRT_SOVERSION
    "${TRT_SO_MAJOR}"
    "${TRT_MAJOR}"
    CACHE STRING "TensorRT library so version")
message(
  STATUS
cpp/cmake/modules/parse_make_options.cmake (new file, 28 lines)

@@ -0,0 +1,28 @@
#
# SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION &
# AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
#

function(parse_make_options options result)
  foreach(option ${options})
    string(REGEX REPLACE "(-D|-)" "" option ${option})
    string(REPLACE "=" ";" option ${option})
    list(GET option 0 option_name)
    list(GET option 1 option_value)
    set(${result}_${option_name}
        ${option_value}
        PARENT_SCOPE)
  endforeach()
endfunction()
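The new `parse_make_options` helper splits make-style definitions such as `-D_GLIBCXX_USE_CXX11_ABI=1` into `<result>_<name>` variables (for example `TORCH_CXX_FLAGS__GLIBCXX_USE_CXX11_ABI`), which the main CMakeLists.txt then compares against `USE_CXX11_ABI`. The standalone C++ sketch below illustrates the same parsing idea only; it is not code from this repository, and it simplifies the CMake regex by stripping just a leading `-D`/`-` prefix.

```cpp
// Illustrative sketch of the flag-parsing idea behind parse_make_options():
// strip a leading "-D"/"-", split on '=', and record <prefix>_<name> -> value.
#include <iostream>
#include <map>
#include <string>
#include <vector>

std::map<std::string, std::string> parseMakeOptions(
    const std::vector<std::string>& options, const std::string& prefix)
{
    std::map<std::string, std::string> result;
    for (std::string opt : options)
    {
        // Remove a leading "-D" or "-" (simplification of the CMake regex).
        if (opt.rfind("-D", 0) == 0)
            opt = opt.substr(2);
        else if (opt.rfind("-", 0) == 0)
            opt = opt.substr(1);
        const auto eq = opt.find('=');
        if (eq == std::string::npos)
            continue; // options without '=' carry no value to record
        result[prefix + "_" + opt.substr(0, eq)] = opt.substr(eq + 1);
    }
    return result;
}

int main()
{
    const auto vars = parseMakeOptions({"-D_GLIBCXX_USE_CXX11_ABI=1"}, "TORCH_CXX_FLAGS");
    // Prints: TORCH_CXX_FLAGS__GLIBCXX_USE_CXX11_ABI = 1
    for (const auto& [name, value] : vars)
        std::cout << name << " = " << value << "\n";
    return 0;
}
```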
@@ -86,6 +86,7 @@ private:
    std::shared_ptr<TrtGptModel> mTrtGptModel;
    SizeType mMaxInputLen;
    SizeType mMaxOutputLen;
    SizeType mMaxKvCacheLen;
    SizeType mMaxNumSequences;
    std::optional<uint64_t> mTerminateReqId;
@@ -29,14 +29,17 @@ class KvCacheConfig
public:
    using SizeType = tensorrt_llm::runtime::SizeType;

    explicit KvCacheConfig(
        std::optional<SizeType> maxTokens = std::nullopt, std::optional<float> freeGpuMemoryFraction = std::nullopt)
    explicit KvCacheConfig(std::optional<SizeType> maxTokens = std::nullopt,
        std::optional<SizeType> maxKvCacheLength = std::nullopt,
        std::optional<float> freeGpuMemoryFraction = std::nullopt)
        : maxTokens{maxTokens}
        , maxKvCacheLength{maxKvCacheLength}
        , freeGpuMemoryFraction{freeGpuMemoryFraction}
    {
    }

    std::optional<SizeType> maxTokens;
    std::optional<SizeType> maxKvCacheLength;
    std::optional<float> freeGpuMemoryFraction;

    static constexpr auto kDefaultGpuMemFraction = 0.85f;
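`KvCacheConfig` gains an optional `maxKvCacheLength` between `maxTokens` and `freeGpuMemoryFraction`, letting callers cap the per-sequence KV cache independently of the total token budget. Below is a self-contained stand-in that mirrors the class shown above (an illustration only; the real type lives in the batch-manager headers):

```cpp
// Stand-in mirroring the KvCacheConfig shown in the diff above (illustration only).
#include <cstdint>
#include <iostream>
#include <optional>

using SizeType = std::int32_t;

struct KvCacheConfig
{
    explicit KvCacheConfig(std::optional<SizeType> maxTokens = std::nullopt,
        std::optional<SizeType> maxKvCacheLength = std::nullopt,
        std::optional<float> freeGpuMemoryFraction = std::nullopt)
        : maxTokens{maxTokens}
        , maxKvCacheLength{maxKvCacheLength}
        , freeGpuMemoryFraction{freeGpuMemoryFraction}
    {
    }

    std::optional<SizeType> maxTokens;
    std::optional<SizeType> maxKvCacheLength;
    std::optional<float> freeGpuMemoryFraction;

    static constexpr auto kDefaultGpuMemFraction = 0.85f;
};

int main()
{
    // Cap the per-sequence KV cache at 2048 tokens; fall back to the default memory fraction.
    KvCacheConfig config{std::nullopt, 2048};
    const float fraction = config.freeGpuMemoryFraction.value_or(KvCacheConfig::kDefaultGpuMemFraction);
    std::cout << "maxKvCacheLength=" << config.maxKvCacheLength.value_or(-1)
              << " freeGpuMemoryFraction=" << fraction << "\n";
    return 0;
}
```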
@@ -217,7 +217,7 @@ public:

    KVCacheManager(SizeType numLayers, SizeType numHeads, SizeType numKvHeads, SizeType hiddenSize,
        SizeType tokensPerBlock, SizeType maxNumBlocks, SizeType maxBatchSize, SizeType maxBeamWidth,
        SizeType maxBlocksPerSeq, nvinfer1::DataType dtype, CudaStreamPtr stream);
        SizeType maxBlocksPerSeq, SizeType maxKvCacheLength, nvinfer1::DataType dtype, CudaStreamPtr stream);

    void startScheduling();

@@ -330,6 +330,9 @@ private:
    SizeType mMaxBeamWidth;
    // Maximum number of blocks per sequence
    SizeType mMaxBlocksPerSeq;
    // Maximum kv cache length per sequence
    // Enable cyclic kv cache when it exceeds
    SizeType mMaxKvCacheLength;
    // Pools
    std::vector<runtime::ITensor::SharedPtr> mPools;
    // Block manager
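`KVCacheManager` now also receives `maxKvCacheLength`, the per-sequence cap beyond which the cache becomes cyclic. Purely as illustrative arithmetic, and assuming the usual paged-cache relation (the divUp formula below is an assumption made for this sketch, not taken from the implementation), the cap bounds how many `tokensPerBlock`-sized blocks a single sequence can occupy:

```cpp
// Illustrative arithmetic for a paged KV cache capped at maxKvCacheLength tokens per sequence.
// The divUp-based relation is an assumption for this sketch, not the library's code.
#include <iostream>

int divUp(int a, int b)
{
    return (a + b - 1) / b;
}

int main()
{
    const int tokensPerBlock = 64;     // tokens stored per cache block (example value)
    const int maxKvCacheLength = 2048; // per-sequence cap; older tokens are overwritten cyclically
    const int maxBlocksPerSeq = divUp(maxKvCacheLength, tokensPerBlock);
    std::cout << "A sequence capped at " << maxKvCacheLength << " tokens needs at most "
              << maxBlocksPerSeq << " blocks of " << tokensPerBlock << " tokens.\n";
    return 0;
}
```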
@@ -29,9 +29,10 @@ class DecodingInput
public:
    using TensorPtr = std::shared_ptr<ITensor const>;

    DecodingInput(SizeType maxLength, SizeType batchSize, TensorPtr logits, TensorPtr endIds)
    DecodingInput(SizeType maxLength, SizeType maxKvCacheLength, SizeType batchSize, TensorPtr logits, TensorPtr endIds)
        : step{maxLength}
        , maxLength{maxLength}
        , maxKvCacheLength{maxKvCacheLength}
        , batchSize{batchSize}
        , logits{std::move(logits)}
        , endIds{std::move(endIds)}

@@ -43,6 +44,7 @@ public:
    // mandatory parameters
    SizeType step;
    SizeType maxLength;
    SizeType maxKvCacheLength;
    SizeType batchSize;
    TensorPtr logits; // [batchSize, beamWidth, vocabSizePadded], on gpu
    TensorPtr endIds; // [batchSize * beamWidth], on gpu
@@ -54,7 +54,7 @@ public:
    bool packed; // indicates if ids are packed or padded to maxInputLength

    // optional parameters
    TensorPtr embeddingBiasOpt; // [vocabSizePadded], on gpu
    TensorPtr embeddingBias; // [vocabSizePadded], on gpu
    TensorPtr badWordsList; // [2, badWordsLength] or [batchSize, 2, badWordsLength], on gpu
    TensorPtr stopWordsList; // [batchSize, 2, stopWordsLength], on gpu
    std::optional<SizeType> maxNewTokens; // max number of tokens to generate
@@ -44,8 +44,8 @@ public:
    GptDecoderBatch(std::size_t vocabSize, std::size_t vocabSizePadded, CudaStreamPtr stream);

    //! Setup the decoder before calling `forward()`
    void setup(
        SizeType maxBatchSize, SizeType maxBeamWidth, SizeType maxSequenceLength, nvinfer1::DataType dtype) override;
    void setup(SizeType maxBatchSize, SizeType maxBeamWidth, SizeType maxKvCacheLength, SizeType maxSequenceLength,
        nvinfer1::DataType dtype) override;

    //! @brief Initialize the decoder at `batchIdx` with a new `request`.
    void newRequest(

@@ -166,6 +166,7 @@ private:
    std::vector<SizeType> mMaxNewTokens;
    std::vector<SizeType> mBeamWidths;
    SizeType mMaxSequenceLength{};
    SizeType mMaxKvCacheLength{};
    SizeType mActualBatchSize{};
};
} // namespace tensorrt_llm::runtime
@@ -140,10 +140,10 @@ private:

    void createContexts(SizeType numBatchesCtx, SizeType numBatchesGen, bool useCudaGraphs);
    void createBuffers(SizeType numMicroBatches);
    void createDecoders(SizeType batchSize, SizeType beamWidth, SizeType maxSequenceLength,
    void createDecoders(SizeType batchSize, SizeType beamWidth, SizeType maxKvCacheLength, SizeType maxSequenceLength,
        nvinfer1::DataType logitsType, bool decoderPerRequest, SizeType numMicroBatches);
    void createKvCacheManager(
        SizeType batchSize, SizeType beamWidth, SizeType maxSequenceLength, KvCacheConfig const& config);
    void createKvCacheManager(SizeType batchSize, SizeType beamWidth, SizeType maxKvCacheLength,
        SizeType maxSequenceLength, KvCacheConfig const& config);
    void createCustomAllReduceWorkspace(SizeType batchSize, SizeType beamWidth, SizeType maxSequenceLength);

    void executeContextStep(std::vector<GenerationInput> const& microBatches,

@@ -258,6 +258,7 @@ private:
    std::vector<std::shared_ptr<IpcMemory>> mIpcMemoryHandles;

    SizeType mDecoderMaxSequenceLength{};
    SizeType mDecoderMaxKvCacheLength{};

    LoggerPtr mLogger;
    std::shared_ptr<TllmRuntime> mRuntime;
@@ -73,8 +73,8 @@ public:
    using TensorPtr = std::shared_ptr<ITensor>;

    //! Setup the decoder before calling `forward()`, also calls reshapeBuffers
    virtual void setup(
        SizeType maxBatchSize, SizeType maxBeamWidth, SizeType maxSequenceLength, nvinfer1::DataType dtype)
    virtual void setup(SizeType maxBatchSize, SizeType maxBeamWidth, SizeType maxKvCacheLength,
        SizeType maxSequenceLength, nvinfer1::DataType dtype)
        = 0;

    //! @brief Initialize the decoder with new batch of inputs.
@ -25,6 +25,7 @@
|
||||
#include <cstdint>
|
||||
#include <functional>
|
||||
#include <initializer_list>
|
||||
#include <limits>
|
||||
#include <memory>
|
||||
#include <numeric>
|
||||
#include <ostream>
|
||||
@ -48,6 +49,9 @@ public:
|
||||
using UniqueConstPtr = std::unique_ptr<ITensor const>;
|
||||
using SharedConstPtr = std::shared_ptr<ITensor const>;
|
||||
using Shape = nvinfer1::Dims;
|
||||
using DimType = std::remove_reference_t<decltype(Shape::d[0])>;
|
||||
|
||||
~ITensor() override = default;
|
||||
|
||||
//!
|
||||
//! \brief Returns the tensor dimensions.
|
||||
@ -59,7 +63,13 @@ public:
|
||||
//!
|
||||
virtual void reshape(Shape const& dims) = 0;
|
||||
|
||||
~ITensor() override = default;
|
||||
void resize(std::size_t newSize) override
|
||||
{
|
||||
if (newSize == getSize())
|
||||
return;
|
||||
|
||||
reshape(makeShape({castSize(newSize)}));
|
||||
}
|
||||
|
||||
//!
|
||||
//! \brief Not allowed to copy.
|
||||
@ -101,18 +111,7 @@ public:
|
||||
//! \param dim The dimension that should be removed ("squeezed").
|
||||
//! \return A new shape without the unit dimension.
|
||||
//!
|
||||
static Shape squeeze(Shape const& shape, SizeType dim)
|
||||
{
|
||||
TLLM_CHECK_WITH_INFO(shape.nbDims > 0, "Cannot squeeze 1-dimensional tensor");
|
||||
TLLM_CHECK_WITH_INFO(
|
||||
dim < shape.nbDims, common::fmtstr("Invalid index %d, tensor has %d dimensions", dim, shape.nbDims));
|
||||
TLLM_CHECK_WITH_INFO(shape.d[dim] == 1, "Can only squeeze dimension of size 1");
|
||||
|
||||
Shape newDims{shape.nbDims - 1};
|
||||
std::copy(shape.d, shape.d + dim, newDims.d);
|
||||
std::copy(shape.d + dim + 1, shape.d + shape.nbDims, newDims.d + dim);
|
||||
return newDims;
|
||||
}
|
||||
static Shape squeeze(Shape const& shape, SizeType dim);
|
||||
|
||||
//!
|
||||
//! \brief Add a *unit* dimension to `shape` at the specified position.
|
||||
@ -121,17 +120,7 @@ public:
|
||||
//! \param dim The dimension where unit dimension should be added.
|
||||
//! \return A new shape with the added unit dimension.
|
||||
//!
|
||||
static Shape unsqueeze(Shape const& shape, SizeType dim)
|
||||
{
|
||||
TLLM_CHECK_WITH_INFO(dim <= shape.nbDims && dim >= 0,
|
||||
common::fmtstr("Invalid dim %d, tensor has %d dimensions", dim, shape.nbDims));
|
||||
|
||||
Shape newDims{shape.nbDims + 1};
|
||||
std::copy(shape.d, shape.d + dim, newDims.d);
|
||||
newDims.d[dim] = 1;
|
||||
std::copy(shape.d + dim, shape.d + shape.nbDims, newDims.d + dim + 1);
|
||||
return newDims;
|
||||
}
|
||||
static Shape unsqueeze(Shape const& shape, SizeType dim);
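The header now only declares `squeeze()` and `unsqueeze()`; the inline bodies shown above move out of the interface. For reference, here is the same shape manipulation as a standalone sketch that uses a simplified `Shape` in place of `nvinfer1::Dims` and `assert` in place of `TLLM_CHECK_WITH_INFO`:

```cpp
// Standalone sketch of the squeeze()/unsqueeze() shape helpers shown above (illustration only).
#include <algorithm>
#include <cassert>
#include <iostream>

struct Shape
{
    int nbDims;
    int d[8];
};

Shape squeeze(Shape const& shape, int dim)
{
    // Drop a unit dimension at position `dim`.
    assert(shape.nbDims > 0 && dim < shape.nbDims && shape.d[dim] == 1);
    Shape newDims{shape.nbDims - 1, {}};
    std::copy(shape.d, shape.d + dim, newDims.d);
    std::copy(shape.d + dim + 1, shape.d + shape.nbDims, newDims.d + dim);
    return newDims;
}

Shape unsqueeze(Shape const& shape, int dim)
{
    // Insert a unit dimension at position `dim`.
    assert(dim >= 0 && dim <= shape.nbDims);
    Shape newDims{shape.nbDims + 1, {}};
    std::copy(shape.d, shape.d + dim, newDims.d);
    newDims.d[dim] = 1;
    std::copy(shape.d + dim, shape.d + shape.nbDims, newDims.d + dim + 1);
    return newDims;
}

int main()
{
    Shape s{3, {1, 3, 4}};
    Shape t = squeeze(s, 0);   // [1, 3, 4] -> [3, 4]
    Shape u = unsqueeze(t, 2); // [3, 4]    -> [3, 4, 1]
    std::cout << t.nbDims << " dims, " << u.nbDims << " dims\n";
    return 0;
}
```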
|
||||
|
||||
//!
|
||||
//! \brief Removes the given *unit* dimensions from this tensor.
|
||||
@ -251,6 +240,13 @@ public:
|
||||
|
||||
protected:
|
||||
ITensor() = default;
|
||||
|
||||
static DimType castSize(size_t newSize)
|
||||
{
|
||||
TLLM_CHECK_WITH_INFO(
|
||||
newSize <= std::numeric_limits<DimType>::max(), "New size is too large. Use reshape() instead.");
|
||||
return static_cast<DimType>(newSize);
|
||||
}
|
||||
};
|
||||
|
||||
//! \brief Utility function to print a shape.
|
||||
|
||||
@ -71,7 +71,7 @@ public:
|
||||
// Function assumes that the first numContextRequests requests in the batch are context requests
|
||||
void fillTasksTensor(TensorPtr tasksHost, const SizeType batchSize, const SizeType numContextRequests,
|
||||
const std::vector<SizeType>& reqBeamWidths, const std::vector<SizeType>& reqPromptLengths,
|
||||
BufferManager& manager, bool packedInput);
|
||||
BufferManager const& manager, bool packedInput);
|
||||
};
|
||||
|
||||
} // namespace tensorrt_llm::runtime
|
||||
|
||||
@ -84,14 +84,6 @@ if(BUILD_BATCH_MANAGER)
|
||||
else()
|
||||
add_library(${BATCH_MANAGER_TARGET} STATIC IMPORTED)
|
||||
if(NOT WIN32) # Linux
|
||||
execute_process(
|
||||
COMMAND ${Python3_EXECUTABLE} "-c"
|
||||
"import torch; print(torch.compiled_with_cxx11_abi(),end='');"
|
||||
RESULT_VARIABLE _PYTHON_SUCCESS
|
||||
OUTPUT_VARIABLE USE_CXX11_ABI)
|
||||
|
||||
message(STATUS "USE_CXX11_ABI: ${USE_CXX11_ABI}")
|
||||
|
||||
if(USE_CXX11_ABI)
|
||||
set(BATCH_MANAGER_LIB_LOC
|
||||
"${CMAKE_CURRENT_SOURCE_DIR}/batch_manager/${BATCH_MANAGER_TARGET_ARCH}/libtensorrt_llm_batch_manager_static.a"
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:f591dd181613b14f7ded3ba3e167d14073564254bc46db8c4bd9636d6d896b16
|
||||
size 1611436
|
||||
oid sha256:7a3ec9a8760d7b8ace53e420572aeb1b3607effc92fd56e13351fa4cbddbbb37
|
||||
size 1646420
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:21d17a9fa736d033ad77270a0fbcdd09c27dfab3f871d92a5ffa0cb744fa48fd
|
||||
size 1623126
|
||||
oid sha256:114348de9f6d1b3fa147f4fbccede10b7dbe13da6c5c86e968bb56bf05f9ec5a
|
||||
size 1657852
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
e1dc326c0c45864b9e7963b4d92d322f libtensorrt_llm_batch_manager_static.a
|
||||
d2e9d76efe6b4173270aa6b494dfe59c libtensorrt_llm_batch_manager_static.pre_cxx11.a
|
||||
07363ea7a6fdd6eeedc1670dedeeaedff7f9a848 commit
|
||||
0776a4d41c06192c4ca0409ad8b837de libtensorrt_llm_batch_manager_static.a
|
||||
c901725d5d278fd8d41f524f81fe5170 libtensorrt_llm_batch_manager_static.pre_cxx11.a
|
||||
b3330c65d9b23d4f20c2b8d5a7c24cd45c910cd4 commit
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:3fe444bf079ce35262b932302806b372ccb677182969e3bba45698343e5e350f
|
||||
size 1523444
|
||||
oid sha256:abdce9bc64cecddb39ed14809eefc8bcf7164524a6dd20ec7c8167229f3c22a3
|
||||
size 1557782
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:99641389fdf26f6324b7465df0b61b74946787a6a147d145de23b444261e6e5f
|
||||
size 1524188
|
||||
oid sha256:a9109b506e993a041ea238f992bec2a5064dffd9c0a7af10cca0d4d96c5047a9
|
||||
size 1557482
|
||||
|
||||
@ -1,2 +1,2 @@
|
||||
b10b0e00d0132b04969d779af45d73d0 libtensorrt_llm_batch_manager_static.a
|
||||
3ad06255afdaa8450c133d1d1bc486c4 libtensorrt_llm_batch_manager_static.pre_cxx11.a
|
||||
25d1ebdd5977208c25023329c621e970 libtensorrt_llm_batch_manager_static.a
|
||||
5cb1a7a13db34fcaee6b89fcdc1212ce libtensorrt_llm_batch_manager_static.pre_cxx11.a
|
||||
|
||||
@ -99,8 +99,8 @@ public:
|
||||
|
||||
~mhaImpl() {}
|
||||
|
||||
void setup(const int b, const int s, const int total_seqlen, const bool has_alibi, const bool scale_alibi,
|
||||
const int tp_size, const int tp_rank)
|
||||
void setup(const int b, const int s, const int sliding_window_size, const int total_seqlen, const bool has_alibi,
|
||||
const bool scale_alibi, const int tp_size, const int tp_rank)
|
||||
{
|
||||
const float inv_sqrt_scale = (1.f / (sqrtf(mHeadSize) * mQScaling));
|
||||
// Note that we apply scales and bias in the order of
|
||||
@ -170,11 +170,19 @@ public:
|
||||
launch_params.use_tma = true;
|
||||
}
|
||||
|
||||
// alibi.
|
||||
if (has_alibi)
|
||||
{
|
||||
params.has_alibi = true;
|
||||
params.alibi_params = AlibiParams(mNumHeads, s, tp_size, tp_rank, scale_after_alibi);
|
||||
}
|
||||
|
||||
// Sliding_window_causal mask.
|
||||
if (s > sliding_window_size && launch_params.attention_mask_type == ContextAttentionMaskType::CAUSAL)
|
||||
{
|
||||
params.sliding_window_size = sliding_window_size;
|
||||
launch_params.attention_mask_type = ContextAttentionMaskType::SLIDING_WINDOW_CAUSAL;
|
||||
}
|
||||
}
|
||||
|
||||
// NOTE: assume that heads_interleaved = false (b, s, 3, h, d), and sequences are padded/non-padded
|
||||
@ -271,8 +279,6 @@ public:
|
||||
{
|
||||
// BF16 FMHA only accumulates on FP32
|
||||
launch_params.force_fp32_acc = mDataType == DATA_TYPE_BF16 || force_fp32_acc;
|
||||
// sliding_window_causal is disabled temporally.
|
||||
// TODO (perkzz): It will be enabled when the sliding window attention is fully supported.
|
||||
launch_params.attention_mask_type
|
||||
= causal_mask ? ContextAttentionMaskType::CAUSAL : ContextAttentionMaskType::PADDING;
|
||||
params.h_kv = num_kv_heads;
|
||||
@ -360,10 +366,10 @@ FusedMHARunnerV2::FusedMHARunnerV2(
|
||||
|
||||
FusedMHARunnerV2::~FusedMHARunnerV2() = default;
|
||||
|
||||
void FusedMHARunnerV2::setup(const int b, const int s, const int total_seqlen, const bool has_alibi,
|
||||
const bool scale_alibi, const int tp_size, const int tp_rank)
|
||||
void FusedMHARunnerV2::setup(const int b, const int s, const int sliding_window_size, const int total_seqlen,
|
||||
const bool has_alibi, const bool scale_alibi, const int tp_size, const int tp_rank)
|
||||
{
|
||||
pimpl->setup(b, s, total_seqlen, has_alibi, scale_alibi, tp_size, tp_rank);
|
||||
pimpl->setup(b, s, sliding_window_size, total_seqlen, has_alibi, scale_alibi, tp_size, tp_rank);
|
||||
}
|
||||
|
||||
bool FusedMHARunnerV2::fmha_supported()
|
||||
|
||||
@ -47,8 +47,8 @@ public:
|
||||
|
||||
virtual ~MHARunner() = default;
|
||||
|
||||
virtual void setup(const int b, const int s, const int total_seqlen, const bool has_alibi = false,
|
||||
const bool scale_alibi = false, const int tp_size = 1, const int tp_rank = 0)
|
||||
virtual void setup(const int b, const int s, const int sliding_window_size, const int total_seqlen,
|
||||
const bool has_alibi = false, const bool scale_alibi = false, const int tp_size = 1, const int tp_rank = 0)
|
||||
= 0;
|
||||
|
||||
static bool fmha_supported(const int headSize, const int sm);
|
||||
@ -80,8 +80,9 @@ public:
|
||||
|
||||
~FusedMHARunnerV2(); // for pimpl
|
||||
|
||||
void setup(const int b, const int s, const int total_seqlen, const bool has_alibi = false,
|
||||
const bool scale_alibi = false, const int tp_size = 1, const int tp_rank = 0) override;
|
||||
void setup(const int b, const int s, const int sliding_window_size, const int total_seqlen,
|
||||
const bool has_alibi = false, const bool scale_alibi = false, const int tp_size = 1,
|
||||
const int tp_rank = 0) override;
|
||||
|
||||
bool fmha_supported() override;
|
||||
|
||||
|
||||
@ -49,8 +49,7 @@ enum class ContextAttentionMaskType
|
||||
{
|
||||
PADDING,
|
||||
CAUSAL,
|
||||
// The past attention length is limited.
|
||||
LIMITED_LENGTH_CAUSAL
|
||||
SLIDING_WINDOW_CAUSAL
|
||||
};
|
||||
|
||||
constexpr int32_t kSM_70 = 70;
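`LIMITED_LENGTH_CAUSAL` is renamed to `SLIDING_WINDOW_CAUSAL`, and `mhaImpl::setup()` above only switches to it when the sequence length exceeds the sliding window. A minimal standalone sketch of that selection rule (the enum values match the header; the `selectMask` helper and the concrete numbers are invented for the example):

```cpp
// Sketch of the mask selection used in mhaImpl::setup(): fall back from CAUSAL to
// SLIDING_WINDOW_CAUSAL only when the sequence is longer than the sliding window.
#include <iostream>

enum class ContextAttentionMaskType
{
    PADDING,
    CAUSAL,
    SLIDING_WINDOW_CAUSAL
};

ContextAttentionMaskType selectMask(int seqLen, int slidingWindowSize, ContextAttentionMaskType current)
{
    if (seqLen > slidingWindowSize && current == ContextAttentionMaskType::CAUSAL)
    {
        return ContextAttentionMaskType::SLIDING_WINDOW_CAUSAL;
    }
    return current;
}

int main()
{
    // 4096 tokens with a 2048-token window -> sliding-window causal attention.
    const auto mask = selectMask(4096, 2048, ContextAttentionMaskType::CAUSAL);
    std::cout << (mask == ContextAttentionMaskType::SLIDING_WINDOW_CAUSAL ? "sliding window causal"
                                                                          : "causal")
              << "\n";
    return 0;
}
```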
|
||||
|
||||
@ -282,6 +282,7 @@ public:
|
||||
|
||||
const auto& kernelMeta = mKernelMeta[findIter->second.mMetaInfoIndex];
|
||||
const CUfunction func = findIter->second.mDeviceFunction;
|
||||
|
||||
void* kernelParams[] = {¶ms, nullptr};
|
||||
|
||||
if (!forceUnroll)
|
||||
|
||||
@ -102,9 +102,12 @@ struct Multihead_attention_params_base
|
||||
int batch_size = 0;
|
||||
// The beam width
|
||||
int beam_width = 0;
|
||||
// The sequence length.
|
||||
// TODO: change name max_seq_len
|
||||
int memory_max_len = 0;
|
||||
// By default, max_kv_cache_length == cyclic_kv_cache_length
|
||||
// unless each layer has different cyclic kv cache length.
|
||||
// Max cache capacity (used to allocate KV cache)
|
||||
int max_kv_cache_length = 0;
|
||||
// Cyclic kv cache capacity (used to get the cyclic kv cache position for new tokens)
|
||||
int cyclic_kv_cache_length = 0;
|
||||
// The number of heads (H).
|
||||
int num_heads = 0;
|
||||
// Controls MHA/MQA/GQA
|
||||
@ -148,7 +151,7 @@ struct Multihead_attention_params_base
|
||||
bool fp8_kv_cache = false;
|
||||
|
||||
// Multi-block setups
|
||||
bool multi_block_mode = false;
|
||||
mutable bool multi_block_mode = false;
|
||||
|
||||
// Number of streaming processors on the device.
|
||||
// Tune block size to maximum occupancy.
|
||||
|
||||
@ -44,8 +44,8 @@ inline size_t smem_size_in_bytes(const Multihead_attention_params<T, DO_CROSS_AT
|
||||
using Tk = typename kernel_type_t<T>::Type;
|
||||
// The amount of shared memory needed to store the Q*K^T values in float.
|
||||
const int max_timesteps = DO_CROSS_ATTENTION
|
||||
? params.memory_max_len
|
||||
: min((DO_MULTI_BLOCK ? params.timesteps_per_block : params.timestep), params.memory_max_len);
|
||||
? params.cyclic_kv_cache_length
|
||||
: min((DO_MULTI_BLOCK ? params.timesteps_per_block : params.timestep), params.cyclic_kv_cache_length);
|
||||
const auto qk_elts = static_cast<std::size_t>(divUp(max_timesteps + 1, 4)); // explicit cast because of the sign
|
||||
const auto qk_sz = qk_elts * 16;
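The Q*K^T scratch is now sized from `cyclic_kv_cache_length` instead of `memory_max_len`, but the sizing rule itself is unchanged: logits for up to `max_timesteps + 1` positions are padded to groups of four floats, i.e. 16-byte chunks. A small arithmetic sketch of just that `qk_sz` term (example values; the real `smem_size_in_bytes()` adds further terms for the other buffers):

```cpp
// Arithmetic sketch of the qk scratch sizing above: divUp(max_timesteps + 1, 4) chunks of 16 bytes.
#include <algorithm>
#include <cstddef>
#include <iostream>

std::size_t divUp(std::size_t a, std::size_t b)
{
    return (a + b - 1) / b;
}

int main()
{
    const std::size_t cyclic_kv_cache_length = 2048; // example cyclic capacity
    const std::size_t timestep = 4096;               // example generation step
    // Cross-attention uses the full cyclic length; self-attention clamps to the timestep.
    const std::size_t max_timesteps = std::min(timestep, cyclic_kv_cache_length);
    const std::size_t qk_elts = divUp(max_timesteps + 1, 4);
    const std::size_t qk_sz = qk_elts * 16;
    std::cout << "qk scratch: " << qk_sz << " bytes\n"; // 8208 bytes for 2048 cached tokens
    return 0;
}
```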
|
||||
|
||||
@ -90,49 +90,34 @@ inline size_t smem_size_in_bytes(const Multihead_attention_params<T, DO_CROSS_AT
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename T, int Dh, bool DO_CROSS_ATTENTION>
|
||||
inline size_t multi_block_grid_setup(const Multihead_attention_params<T, DO_CROSS_ATTENTION>& params,
|
||||
int threads_per_block, int tlength, bool do_multi_block)
|
||||
inline void multi_block_grid_setup(dim3& grid, const Multihead_attention_params<T, DO_CROSS_ATTENTION>& params,
|
||||
int blocks_per_sm, int block_size, int tlength, bool do_multi_block)
|
||||
{
|
||||
if (!do_multi_block)
|
||||
{
|
||||
return 1;
|
||||
params.multi_block_mode = false;
|
||||
return;
|
||||
}
|
||||
|
||||
auto constexpr threads_per_value = mmha::threads_per_value<T>(mmha::dh_max(Dh));
|
||||
params.seq_len_tile
|
||||
= mmha::divUp(params.multi_processor_count * blocks_per_sm, params.batch_size * params.num_heads);
|
||||
|
||||
// Make sure: seq_len_tile * threads_per_value <= threads_per_block (for multi_block_mode)
|
||||
params.seq_len_tile = std::floor(threads_per_block / threads_per_value);
|
||||
const int threads_per_value = mmha::threads_per_value<T>(mmha::dh_max(Dh));
|
||||
// Make sure that each block at least processes one loop of kv (unroll size is default at 8).
|
||||
const int seq_len_per_kv_loop = mmha::divUp(block_size, threads_per_value) * 8;
|
||||
const int max_seq_len_tile = std::min(mmha::divUp(tlength + 1, seq_len_per_kv_loop), params.max_seq_len_tile);
|
||||
|
||||
assert(params.seq_len_tile <= params.max_seq_len_tile);
|
||||
params.seq_len_tile = std::min(params.seq_len_tile, max_seq_len_tile);
|
||||
|
||||
params.timesteps_per_block = mmha::divUp(tlength, params.seq_len_tile);
|
||||
// We should consider the new timestep.
|
||||
params.timesteps_per_block = mmha::divUp(tlength + 1, params.seq_len_tile);
|
||||
|
||||
#ifndef ENABLE_MULTI_BLOCK_OPTION
|
||||
do_multi_block = false;
|
||||
#endif
|
||||
params.multi_block_mode = (params.seq_len_tile > 1);
|
||||
|
||||
// Return the sequence length tile if using multi block modes.
|
||||
return params.seq_len_tile;
|
||||
grid.z = params.seq_len_tile;
|
||||
}
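The reworked `multi_block_grid_setup()` derives `seq_len_tile` from occupancy (`multi_processor_count * blocks_per_sm` spread over `batch_size * num_heads`), clamps it so each block still covers at least one unrolled KV loop, recomputes `timesteps_per_block` from `tlength + 1`, and clears `multi_block_mode` when only one tile remains. Below is a host-side sketch of that arithmetic with plain variables standing in for the params struct (the concrete numbers are example values):

```cpp
// Host-side sketch of the multi-block split computed by multi_block_grid_setup() above.
#include <algorithm>
#include <iostream>

int divUp(int a, int b)
{
    return (a + b - 1) / b;
}

int main()
{
    const int multi_processor_count = 108; // example SM count
    const int blocks_per_sm = 2;           // from cudaOccupancyMaxActiveBlocksPerMultiprocessor
    const int batch_size = 1, num_heads = 32;
    const int block_size = 256, threads_per_value = 16; // example values
    const int tlength = 8192;                           // past KV length
    const int max_seq_len_tile = 32;                    // capacity of the partial-result buffers

    // Spread the available blocks over batch_size * num_heads grid entries.
    int seq_len_tile = divUp(multi_processor_count * blocks_per_sm, batch_size * num_heads);
    // Each block must process at least one KV loop (unroll size defaults to 8).
    const int seq_len_per_kv_loop = divUp(block_size, threads_per_value) * 8;
    seq_len_tile = std::min(seq_len_tile, std::min(divUp(tlength + 1, seq_len_per_kv_loop), max_seq_len_tile));

    const int timesteps_per_block = divUp(tlength + 1, seq_len_tile);
    const bool multi_block_mode = seq_len_tile > 1;

    std::cout << "seq_len_tile=" << seq_len_tile << " timesteps_per_block=" << timesteps_per_block
              << " multi_block=" << multi_block_mode << "\n";
    return 0;
}
```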
|
||||
|
||||
#define MMHA_LAUNCH_CHECK(DYNAMIC_THDS_PER_BLOCK) \
|
||||
std::size_t const dynamic_smem_sz{ \
|
||||
mmha::smem_size_in_bytes<T, Dh, DO_MULTI_BLOCK>(params, DYNAMIC_THDS_PER_BLOCK)}; \
|
||||
/* Set 46KB threshold here because we have to take static/driver shared memory into consideration. */ \
|
||||
if (dynamic_smem_sz >= 46 * 1024) \
|
||||
{ \
|
||||
cudaError_t res = cudaFuncSetAttribute(mmha::masked_multihead_attention_kernel<T, T_cache, KVCacheBuffer, Dh, \
|
||||
DYNAMIC_THDS_PER_BLOCK, HAS_BEAMS, DO_MULTI_BLOCK>, \
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize, dynamic_smem_sz); \
|
||||
TLLM_CHECK_WITH_INFO( \
|
||||
res == cudaSuccess, "Sequence Length is too long for the MMHA kernel (not enough shared memory)."); \
|
||||
} \
|
||||
TLLM_CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&available_blocks, \
|
||||
mmha::masked_multihead_attention_kernel<T, T_cache, KVCacheBuffer, Dh, DYNAMIC_THDS_PER_BLOCK, HAS_BEAMS, \
|
||||
DO_MULTI_BLOCK>, \
|
||||
DYNAMIC_THDS_PER_BLOCK, dynamic_smem_sz));
|
||||
|
||||
#define MMHA_KERNEL(DYNAMIC_THDS_PER_BLOCK) \
|
||||
std::size_t const dynamic_smem_sz{ \
|
||||
mmha::smem_size_in_bytes<T, Dh, DO_MULTI_BLOCK>(params, DYNAMIC_THDS_PER_BLOCK)}; \
|
||||
/* Set 46KB threshold here because we have to take static/driver shared memory into consideration. */ \
|
||||
@ -145,34 +130,51 @@ inline size_t multi_block_grid_setup(const Multihead_attention_params<T, DO_CROS
|
||||
TLLM_CHECK_WITH_INFO( \
|
||||
res == cudaSuccess, "Sequence Length is too long for the MMHA kernel (not enough shared memory)."); \
|
||||
} \
|
||||
TLLM_CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&available_blocks, \
|
||||
mmha::masked_multihead_attention_kernel<T, T_cache, KVCacheBuffer, Dh, DYNAMIC_THDS_PER_BLOCK, \
|
||||
KernelParamsType::DO_CROSS_ATTENTION, HAS_BEAMS, DO_MULTI_BLOCK>, \
|
||||
DYNAMIC_THDS_PER_BLOCK, dynamic_smem_sz));
|
||||
|
||||
#define MMHA_KERNEL(DYNAMIC_THDS_PER_BLOCK, ENABLE_MULTI_BLOCK) \
|
||||
std::size_t const dynamic_smem_sz{ \
|
||||
mmha::smem_size_in_bytes<T, Dh, ENABLE_MULTI_BLOCK>(params, DYNAMIC_THDS_PER_BLOCK)}; \
|
||||
/* Set 46KB threshold here because we have to take static/driver shared memory into consideration. */ \
|
||||
if (dynamic_smem_sz >= 46 * 1024) \
|
||||
{ \
|
||||
cudaError_t res = cudaFuncSetAttribute( \
|
||||
mmha::masked_multihead_attention_kernel<T, T_cache, KVCacheBuffer, Dh, DYNAMIC_THDS_PER_BLOCK, \
|
||||
KernelParamsType::DO_CROSS_ATTENTION, HAS_BEAMS, ENABLE_MULTI_BLOCK>, \
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize, dynamic_smem_sz); \
|
||||
TLLM_CHECK_WITH_INFO( \
|
||||
res == cudaSuccess, "Sequence Length is too long for the MMHA kernel (not enough shared memory)."); \
|
||||
} \
|
||||
mmha::masked_multihead_attention_kernel<T, T_cache, KVCacheBuffer, Dh, DYNAMIC_THDS_PER_BLOCK, \
|
||||
KernelParamsType::DO_CROSS_ATTENTION, HAS_BEAMS, DO_MULTI_BLOCK> \
|
||||
KernelParamsType::DO_CROSS_ATTENTION, HAS_BEAMS, ENABLE_MULTI_BLOCK> \
|
||||
<<<grid, DYNAMIC_THDS_PER_BLOCK, dynamic_smem_sz, stream>>>(params, kv_cache_buffer);
|
||||
|
||||
// if resources are not enough to launch 512 threads per block, we will fallback to 256.
|
||||
#define MMHA_LAUNCH_512_BLOCKSIZE() \
|
||||
int available_blocks = -1; \
|
||||
#define MMHA_512_BLOCKSIZE_CHECK() \
|
||||
MMHA_LAUNCH_CHECK(512); \
|
||||
if (available_blocks <= 0) \
|
||||
{ \
|
||||
MMHA_KERNEL(256); \
|
||||
MMHA_LAUNCH_CHECK(256); \
|
||||
dynamic_block_size = 256; \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
MMHA_KERNEL(512); \
|
||||
dynamic_block_size = 512; \
|
||||
}
|
||||
|
||||
// if resources are not enough to launch 1024 threads per block, we will fallback to 512.
|
||||
#define MMHA_LAUNCH_1024_BLOCKSIZE() \
|
||||
int available_blocks = -1; \
|
||||
#define MMHA_1024_BLOCKSIZE_CHECK() \
|
||||
MMHA_LAUNCH_CHECK(1024); \
|
||||
if (available_blocks <= 0) \
|
||||
if (available_blocks > 0) \
|
||||
{ \
|
||||
MMHA_LAUNCH_512_BLOCKSIZE(); \
|
||||
dynamic_block_size = 1024; \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
MMHA_KERNEL(1024); \
|
||||
MMHA_512_BLOCKSIZE_CHECK(); \
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -182,55 +184,83 @@ template <typename T, typename T_cache, typename KVCacheBuffer, typename KernelP
|
||||
void mmha_launch_kernel_ex(
|
||||
const KernelParamsType& params, const KVCacheBuffer& kv_cache_buffer, const cudaStream_t& stream, int tlength)
|
||||
{
|
||||
std::size_t const seq_len_tile{mmha::multi_block_grid_setup<T, Dh, KernelParamsType::DO_CROSS_ATTENTION>(
|
||||
params, THDS_PER_BLOCK, tlength, DO_MULTI_BLOCK)};
|
||||
dim3 grid{static_cast<unsigned>(params.num_heads), static_cast<unsigned>(params.batch_size),
|
||||
static_cast<unsigned>(seq_len_tile)};
|
||||
dim3 grid{static_cast<unsigned>(params.num_heads), static_cast<unsigned>(params.batch_size), 1};
|
||||
|
||||
if (DO_MULTI_BLOCK)
|
||||
const int kernel_total_blocks = params.batch_size * params.num_heads;
|
||||
// Don't tune the block size if batchxhead is large enough.
|
||||
// The max number of warps we can launch per SM is 32 limited by registers.
|
||||
if (kernel_total_blocks >= params.multi_processor_count * 4)
|
||||
{
|
||||
MMHA_KERNEL(THDS_PER_BLOCK);
|
||||
MMHA_KERNEL(THDS_PER_BLOCK, false);
|
||||
return;
|
||||
}
|
||||
else
|
||||
|
||||
// Tune block size based on batchxhead to increase occupancy.
|
||||
int num_blocks_per_sm = -1;
|
||||
// Set 0 dynamic shared memory size as we need the number of available blocks limited by registers.
|
||||
// Dynamic shared memory is fixed for different block size.
|
||||
TLLM_CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks_per_sm,
|
||||
mmha::masked_multihead_attention_kernel<T, T_cache, KVCacheBuffer, Dh, THDS_PER_BLOCK,
|
||||
KernelParamsType::DO_CROSS_ATTENTION, HAS_BEAMS, DO_MULTI_BLOCK>,
|
||||
THDS_PER_BLOCK, 0));
|
||||
|
||||
int block_size_factor
|
||||
= min(mmha::divUp(params.multi_processor_count * num_blocks_per_sm, kernel_total_blocks), num_blocks_per_sm);
|
||||
// Max block size is 1024.
|
||||
int dynamic_block_size = min(THDS_PER_BLOCK * block_size_factor, 1024);
|
||||
|
||||
// Check if resources are enough for launch.
|
||||
int available_blocks = -1;
|
||||
if (dynamic_block_size < 512)
|
||||
{
|
||||
const int kernel_total_blocks = params.batch_size * params.num_heads;
|
||||
// Don't tune the block size if batchxhead is large enough.
|
||||
// The max number of warps we can launch per SM is 32 limited by registers.
|
||||
if (kernel_total_blocks >= params.multi_processor_count * 4)
|
||||
{
|
||||
MMHA_KERNEL(THDS_PER_BLOCK);
|
||||
return;
|
||||
}
|
||||
MMHA_LAUNCH_CHECK(256);
|
||||
dynamic_block_size = 256;
|
||||
}
|
||||
else if (dynamic_block_size < 1024)
|
||||
{
|
||||
MMHA_512_BLOCKSIZE_CHECK();
|
||||
}
|
||||
else if (dynamic_block_size == 1024)
|
||||
{
|
||||
MMHA_1024_BLOCKSIZE_CHECK();
|
||||
}
|
||||
|
||||
// Tune block size based on batchxhead to increase occupancy.
|
||||
int num_blocks_per_sm = -1;
|
||||
// Set 0 dynamic shared memory size as we need the number of available blocks limited by registers.
|
||||
// Dynamic shared memory is fixed for different block size.
|
||||
TLLM_CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks_per_sm,
|
||||
mmha::masked_multihead_attention_kernel<T, T_cache, KVCacheBuffer, Dh, THDS_PER_BLOCK, HAS_BEAMS,
|
||||
DO_MULTI_BLOCK>,
|
||||
THDS_PER_BLOCK, 0));
|
||||
// If blocks with larger block size already fill all SMs, then disable the multi blocks mode.
|
||||
mmha::multi_block_grid_setup<T, Dh>(grid, params, dynamic_block_size, available_blocks, tlength, DO_MULTI_BLOCK);
|
||||
|
||||
int block_size_factor = min(
|
||||
mmha::divUp(params.multi_processor_count * num_blocks_per_sm, kernel_total_blocks), num_blocks_per_sm);
|
||||
// Max block size is 1024.
|
||||
const int dynamic_block_size = min(THDS_PER_BLOCK * block_size_factor, 1024);
|
||||
|
||||
// Make sure number of threads per block is power of 2.
|
||||
if (dynamic_block_size <= 256)
|
||||
// Launch kernels based on the valid block size.
|
||||
switch (dynamic_block_size)
|
||||
{
|
||||
case 256:
|
||||
if (params.multi_block_mode)
|
||||
{
|
||||
MMHA_KERNEL(256);
|
||||
MMHA_KERNEL(256, true);
|
||||
}
|
||||
else if (dynamic_block_size <= 512)
|
||||
else
|
||||
{
|
||||
// Check if the kernel with new block size can be launched in terms of resources.
|
||||
MMHA_LAUNCH_512_BLOCKSIZE();
|
||||
MMHA_KERNEL(256, false);
|
||||
}
|
||||
else if (dynamic_block_size <= 1024)
|
||||
break;
|
||||
case 512:
|
||||
if (params.multi_block_mode)
|
||||
{
|
||||
// Check if the kernel with new block size can be launched in terms of resources.
|
||||
MMHA_LAUNCH_1024_BLOCKSIZE();
|
||||
MMHA_KERNEL(512, true);
|
||||
}
|
||||
else
|
||||
{
|
||||
MMHA_KERNEL(512, false);
|
||||
}
|
||||
break;
|
||||
case 1024:
|
||||
if (params.multi_block_mode)
|
||||
{
|
||||
MMHA_KERNEL(1024, true);
|
||||
}
|
||||
else
|
||||
{
|
||||
MMHA_KERNEL(1024, false);
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
@ -263,23 +293,15 @@ void mmha_launch_kernel_dispatch(
|
||||
const KernelParamsType& params, const KVCacheBuffer& kv_cache_buffer, const cudaStream_t& stream)
|
||||
{
|
||||
int const tlength = params.timestep;
|
||||
if (tlength < 1024)
|
||||
if (params.multi_block_mode)
|
||||
{
|
||||
mmha_launch_kernel_dispatch_8bits_kv_cache<T, KVCacheBuffer, KernelParamsType, Dh, 256, HAS_BEAMS, false>(
|
||||
mmha_launch_kernel_dispatch_8bits_kv_cache<T, KVCacheBuffer, KernelParamsType, Dh, 256, HAS_BEAMS, true>(
|
||||
params, kv_cache_buffer, stream, tlength);
|
||||
}
|
||||
else
|
||||
{
|
||||
if (params.multi_block_mode)
|
||||
{
|
||||
mmha_launch_kernel_dispatch_8bits_kv_cache<T, KVCacheBuffer, KernelParamsType, Dh, 256, HAS_BEAMS, true>(
|
||||
params, kv_cache_buffer, stream, tlength);
|
||||
}
|
||||
else
|
||||
{
|
||||
mmha_launch_kernel_dispatch_8bits_kv_cache<T, KVCacheBuffer, KernelParamsType, Dh, 256, HAS_BEAMS, false>(
|
||||
params, kv_cache_buffer, stream, tlength);
|
||||
}
|
||||
mmha_launch_kernel_dispatch_8bits_kv_cache<T, KVCacheBuffer, KernelParamsType, Dh, 256, HAS_BEAMS, false>(
|
||||
params, kv_cache_buffer, stream, tlength);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -1244,10 +1244,10 @@ template <
|
||||
// The number of threads per value.
|
||||
unsigned THREADS_PER_VALUE = threads_per_value<T>(dh_max(Dh)),
|
||||
// The unroll factor for loading from K cache.
|
||||
unsigned K_LOOP_UNROLL = 8,
|
||||
// The unroll factor for loading from V cache.
|
||||
// Set it default to 4 for higher occupancy (by reducing registers usage).
|
||||
unsigned V_LOOP_UNROLL = 4>
|
||||
unsigned K_LOOP_UNROLL = 4,
|
||||
// The unroll factor for loading from V cache.
|
||||
unsigned V_LOOP_UNROLL = 8>
|
||||
__global__ void masked_multihead_attention_kernel(
|
||||
Multihead_attention_params<T, DO_CROSS_ATTENTION> params, KVCacheBuffer kvCacheBuffer)
|
||||
{
|
||||
@ -1271,10 +1271,11 @@ __global__ void masked_multihead_attention_kernel(
|
||||
static_assert(Dh_MAX >= WARP_SIZE);
|
||||
static_assert(Dh_MAX >= Dh);
|
||||
|
||||
// The maximum sequence length in the kv_cache, i.e., an upper bound on L.
|
||||
// The maximum sequence length in the cyclic kv_cache, i.e., an upper bound on L.
|
||||
// Note that the maximum sequence length supported by the model might be greater than this.
|
||||
const auto max_seq_len = static_cast<unsigned>(params.memory_max_len);
|
||||
assert(max_seq_len > 0);
|
||||
// Note max_kv_cache_length is maximum of cyclic_kv_cache_length among all layers.
|
||||
// By default, you can assume that they are the same.
|
||||
const auto cyclic_kv_cache_len = static_cast<unsigned>(params.cyclic_kv_cache_length);
|
||||
// The current timestep (including paddings).
|
||||
// It is only used to calculate the smem stride.
|
||||
const auto timestep = static_cast<unsigned>(DO_MULTI_BLOCK ? params.timesteps_per_block : params.timestep);
|
||||
@ -1298,8 +1299,7 @@ __global__ void masked_multihead_attention_kernel(
|
||||
#ifndef MMHA_USE_FP32_ACCUM_FOR_LOGITS
|
||||
if (sizeof(Tk) != 4)
|
||||
{
|
||||
// TODO - change to tlength
|
||||
const auto max_timesteps = DO_CROSS_ATTENTION ? max_seq_len : min(timestep, max_seq_len);
|
||||
const auto max_timesteps = DO_CROSS_ATTENTION ? cyclic_kv_cache_len : min(timestep, cyclic_kv_cache_len);
|
||||
logits_smem_ += divUp(max_timesteps + 1, 4u) * 16;
|
||||
}
|
||||
Tk* logits_smem = reinterpret_cast<Tk*>(logits_smem_);
|
||||
@ -1345,6 +1345,7 @@ __global__ void masked_multihead_attention_kernel(
|
||||
// Use alignment for safely casting the shared buffers as Qk_vec_k and K_vec_k.
|
||||
// Shared memory to store Q inputs.
|
||||
__shared__ __align__(mmha::const_max(sizeof(Qk_vec_k), sizeof(K_vec_k))) Tk q_smem[Dh_MAX];
|
||||
__shared__ __align__(mmha::const_max(sizeof(Qk_vec_k), sizeof(K_vec_k))) Tk k_smem[Dh_MAX];
|
||||
|
||||
// Make sure the hidden dimension per head is a multiple of the number of threads per value.
|
||||
static_assert(Dh_MAX % THREADS_PER_VALUE == 0); // trivially satisfied since THREADS_PER_VALUE == Dh_MAX / p
|
||||
@ -1420,8 +1421,17 @@ __global__ void masked_multihead_attention_kernel(
|
||||
const int tlength = DO_CROSS_ATTENTION
|
||||
? params.memory_length_per_sample[batch_beam_idx] - 1
|
||||
: (params.length_per_sample ? (params.length_per_sample[batch_beam_idx] - 1) : static_cast<int>(timestep));
|
||||
// We will use cyclic kv cache when it exceeds the limit.
|
||||
// The length position for storing new key and value.
|
||||
const int cyclic_tlength = tlength % cyclic_kv_cache_len;
|
||||
// The actual kv cache length.
|
||||
// tlength is the past length actually.
|
||||
const int kv_loop_length = min(tlength, cyclic_kv_cache_len);
|
||||
// The context length for beam searching optimization (all points to beam 0).
|
||||
const int input_length = params.input_lengths[batch_beam_idx];
|
||||
// TODO: with cyclic kv cache, we set it 0 for now (will optimize in the future)
|
||||
// as context kv cache might be overwritten by the new kv cache
|
||||
const int beam0_context_length
|
||||
= HAS_BEAMS && tlength > cyclic_kv_cache_len ? 0 : params.input_lengths[batch_beam_idx];
|
||||
|
||||
// The offset in the Q and K buffer also accounts for the batch.
|
||||
const auto qk_vec_idx = tidx * QK_VEC_SIZE;
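With the cyclic KV cache, new keys and values are written at `tlength % cyclic_kv_cache_len`, the attention loops only visit `min(tlength, cyclic_kv_cache_len)` cached positions, and beam search stops reusing the shared context once the cache has wrapped (since that context may already be overwritten). A standalone host-side sketch of exactly that index arithmetic (example lengths):

```cpp
// Standalone sketch of the cyclic KV-cache indexing introduced above.
#include <algorithm>
#include <iostream>

int main()
{
    const int cyclic_kv_cache_len = 2048; // cyclic capacity for this layer (example)
    const int input_length = 1000;        // prompt (context) length for this sequence (example)

    for (const int tlength : {100, 2047, 2048, 5000}) // past length at the current step
    {
        // Slot that the new key/value pair overwrites.
        const int cyclic_tlength = tlength % cyclic_kv_cache_len;
        // Number of cached positions the QK/V loops actually iterate over.
        const int kv_loop_length = std::min(tlength, cyclic_kv_cache_len);
        // With beam search, context reuse is disabled once the cache has wrapped,
        // because the context K/V may already have been overwritten.
        const bool has_beams = true;
        const int beam0_context_length = has_beams && tlength > cyclic_kv_cache_len ? 0 : input_length;

        std::cout << "tlength=" << tlength << " write_slot=" << cyclic_tlength
                  << " kv_loop_length=" << kv_loop_length
                  << " beam0_context_length=" << beam0_context_length << "\n";
    }
    return 0;
}
```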
|
||||
@ -1474,8 +1484,8 @@ __global__ void masked_multihead_attention_kernel(
|
||||
if constexpr (DO_CROSS_ATTENTION)
|
||||
{
|
||||
const auto k_idx = QK_VEC_SIZE * tidx;
|
||||
const int inBlockIdx = kvCacheBuffer.getKVLocalIdx(tlength, hi, Dh, k_idx);
|
||||
Tcache* k_cache = reinterpret_cast<Tcache*>(kvCacheBuffer.getKBlockPtr(batch_beam_idx, tlength));
|
||||
const int inBlockIdx = kvCacheBuffer.getKVLocalIdx(cyclic_tlength, hi, Dh, k_idx);
|
||||
Tcache* k_cache = reinterpret_cast<Tcache*>(kvCacheBuffer.getKBlockPtr(batch_beam_idx, cyclic_tlength));
|
||||
|
||||
k = vec_conversion<Qk_vec_k, Qk_vec_m>(*reinterpret_cast<const Qk_vec_m*>(&k_cache[inBlockIdx]));
|
||||
}
|
||||
@ -1572,7 +1582,7 @@ __global__ void masked_multihead_attention_kernel(
|
||||
const bool do_rotary = is_valid_qk_vec && QK_VEC_SIZE * tidx < params.rotary_embedding_dim;
|
||||
|
||||
T* q_smem_ = reinterpret_cast<T*>(smem_);
|
||||
T* k_smem = q_smem_ + params.rotary_embedding_dim;
|
||||
T* k_smem_ = q_smem_ + params.rotary_embedding_dim;
|
||||
|
||||
const int half_rotary_dim = params.rotary_embedding_dim / 2;
|
||||
const int half_idx = qk_vec_idx / half_rotary_dim;
|
||||
@ -1586,7 +1596,7 @@ __global__ void masked_multihead_attention_kernel(
|
||||
*reinterpret_cast<Qk_vec_k*>(q_smem_ + half_idx * smem_pitch + intra_half_idx) = q;
|
||||
if (HANDLE_KV)
|
||||
{
|
||||
*reinterpret_cast<Qk_vec_k*>(k_smem + half_idx * smem_pitch + intra_half_idx) = k;
|
||||
*reinterpret_cast<Qk_vec_k*>(k_smem_ + half_idx * smem_pitch + intra_half_idx) = k;
|
||||
}
|
||||
}
|
||||
|
||||
@ -1599,12 +1609,12 @@ __global__ void masked_multihead_attention_kernel(
|
||||
mmha::vec_from_smem_transpose(q, q_smem_, transpose_idx, smem_pitch);
|
||||
if (HANDLE_KV)
|
||||
{
|
||||
mmha::vec_from_smem_transpose(k, k_smem, transpose_idx, smem_pitch);
|
||||
mmha::vec_from_smem_transpose(k, k_smem_, transpose_idx, smem_pitch);
|
||||
|
||||
mmha::apply_rotary_embedding(q, k, transpose_idx / tidx_factor, params.rotary_embedding_dim,
|
||||
rotary_embedding_base, rotary_embedding_scale, tlength);
|
||||
|
||||
mmha::write_smem_transpose(k, k_smem, transpose_idx, smem_pitch);
|
||||
mmha::write_smem_transpose(k, k_smem_, transpose_idx, smem_pitch);
|
||||
}
|
||||
else
|
||||
{
|
||||
@ -1621,7 +1631,7 @@ __global__ void masked_multihead_attention_kernel(
|
||||
q = *reinterpret_cast<Qk_vec_k*>(q_smem_ + half_idx * smem_pitch + intra_half_idx);
|
||||
if (HANDLE_KV)
|
||||
{
|
||||
k = *reinterpret_cast<Qk_vec_k*>(k_smem + half_idx * smem_pitch + intra_half_idx);
|
||||
k = *reinterpret_cast<Qk_vec_k*>(k_smem_ + half_idx * smem_pitch + intra_half_idx);
|
||||
}
|
||||
}
|
||||
|
||||
@ -1631,7 +1641,7 @@ __global__ void masked_multihead_attention_kernel(
|
||||
}
|
||||
|
||||
// For the same reason as HANDLE_KV, no compute needed in Cross-Attention's 1st step
|
||||
|
||||
// Store Q K vectors to shared memory, and calculate QK.
|
||||
if (qk_vec_idx < Dh_MAX)
|
||||
{
|
||||
|
||||
@ -1658,31 +1668,10 @@ __global__ void masked_multihead_attention_kernel(
|
||||
reinterpret_cast<Qk_vec_k*>(&q_smem[qk_vec_idx])[0] = is_valid_qk_vec ? q : zero_q;
|
||||
}
|
||||
|
||||
// Write the K values to the global memory cache.
|
||||
//
|
||||
// NOTE: The stores are uncoalesced as we have multiple chunks of 16B spread across the memory
|
||||
// system. We designed it this way as it allows much better memory loads (and there are many
|
||||
// more loads) + the stores are really "write and forget" since we won't need the ack before
|
||||
// the end of the kernel. There's plenty of time for the transactions to complete.
|
||||
|
||||
// For MQA/GQA mode, write only with the first Q head of each group per KV head.
|
||||
if (HANDLE_KV && hi == (hi_kv * qhead_per_kv) && (IS_Dh_MAX || is_valid_qk_vec))
|
||||
{
|
||||
// Trigger the stores to global memory.
|
||||
const auto k_idx = QK_VEC_SIZE * tidx;
|
||||
const int inBlockIdx = kvCacheBuffer.getKVLocalIdx(tlength, hi_kv, Dh, k_idx);
|
||||
// The base pointer for the value in the cache buffer.
|
||||
Tcache* k_cache = reinterpret_cast<Tcache*>(kvCacheBuffer.getKBlockPtr(batch_beam_idx, tlength));
|
||||
|
||||
if constexpr (ENABLE_8BITS_CACHE)
|
||||
{
|
||||
store_8bits_kv_cache_vec(reinterpret_cast<Tcache*>(k_cache), k, inBlockIdx, kv_scale_orig_quant);
|
||||
}
|
||||
else
|
||||
{
|
||||
*reinterpret_cast<Qk_vec_m*>(&k_cache[inBlockIdx]) = vec_conversion<Qk_vec_m, Qk_vec_k>(k);
|
||||
}
|
||||
}
|
||||
// Store the K values to shared memory.
|
||||
// We store K values from shared memory to global memory
|
||||
// when the target position of K cache in global memory has been accessed (in the case of cyclic kv cache)
|
||||
reinterpret_cast<Qk_vec_k*>(&k_smem[qk_vec_idx])[0] = k;
|
||||
|
||||
// Compute \sum_i Q[i] * K^T[i] for the current timestep.
|
||||
qk = dot<Qk_vec_accum, Qk_vec_k>(q, k);
|
||||
@ -1736,7 +1725,9 @@ __global__ void masked_multihead_attention_kernel(
|
||||
}
|
||||
else
|
||||
{
|
||||
qk_smem[tlength] = qk;
|
||||
// We need to store the qk result to the end of the qk_smem for cyclic kv cache (+ 1 for smem memory
|
||||
// allocation) because the previous cache will still write to the new_cache_pos of qk_smem.
|
||||
qk_smem[kv_loop_length] = qk;
|
||||
}
|
||||
}
|
||||
|
||||
@ -1778,17 +1769,22 @@ __global__ void masked_multihead_attention_kernel(
|
||||
|
||||
// Pick a number of keys to make sure all the threads of a warp enter (due to shfl_sync).
|
||||
// Take all previous cache as context when we have no beam searching in order to batch as many LDGs as possible.
|
||||
const int context_length = HAS_BEAMS ? input_length : tlength;
|
||||
const int context_length = HAS_BEAMS ? beam0_context_length : kv_loop_length;
|
||||
const auto context_ti_end = MULTI_BLOCK_FLAG
|
||||
? divUp(timesteps_per_block, UNROLLED_K_PER_WARP) * UNROLLED_K_PER_WARP
|
||||
: divUp(static_cast<unsigned>(context_length), UNROLLED_K_PER_WARP) * UNROLLED_K_PER_WARP;
|
||||
|
||||
// The generation ti_end.
|
||||
const auto generation_ti_end = MULTI_BLOCK_FLAG ? divUp(timesteps_per_block, K_PER_WARP) * K_PER_WARP
|
||||
: divUp(static_cast<unsigned>(tlength), K_PER_WARP) * K_PER_WARP;
|
||||
const auto generation_ti_end = MULTI_BLOCK_FLAG
|
||||
? divUp(timesteps_per_block, K_PER_WARP) * K_PER_WARP
|
||||
: divUp(static_cast<unsigned>(kv_loop_length), K_PER_WARP) * K_PER_WARP;
|
||||
|
||||
// Iterate over the keys/timesteps to compute the various (Q*K^T)_{ti} values.
|
||||
const auto bi_seq_len_offset = static_cast<std::size_t>(batch_beam_idx) * max_seq_len;
|
||||
// Note max_kv_cache_length is maximum of cyclic_kv_cache_length among all layers.
|
||||
// By default, you can assume that they are the same.
|
||||
const auto bi_seq_len_offset = static_cast<std::size_t>(batch_beam_idx) * params.max_kv_cache_length;
|
||||
// Beam indices are based on the max_kv_cache_length while each layer may have different cyclic_kv_cache_length
|
||||
// So we need to rebuild the beam_indices if max_kv_cache_length is not equal to cyclic_kv_cache_length.
|
||||
const int* beam_indices = HAS_BEAMS ? ¶ms.cache_indir[bi_seq_len_offset] : nullptr;
|
||||
|
||||
const auto c_tile_times_timesteps_per_block = c_tile * timesteps_per_block; // 0 if !MULTI_BLOCK_FLAG
|
||||
@ -1940,11 +1936,11 @@ __global__ void masked_multihead_attention_kernel(
|
||||
}
|
||||
|
||||
// Handle generation key cache with beam searching.
|
||||
// Note that it may be overlapped with the context key loop, but it won't impact the correctness.
|
||||
if (HAS_BEAMS && (!MULTI_BLOCK_FLAG || (c_tile + 1) * timesteps_per_block > input_length))
|
||||
// Note that it may be overlapped with the context key loop, but it won't impact the correctness.
|
||||
if (HAS_BEAMS && (!MULTI_BLOCK_FLAG || (c_tile + 1) * timesteps_per_block > beam0_context_length))
|
||||
{
|
||||
// The input length;
|
||||
const int input_length_ = MULTI_BLOCK_FLAG ? input_length % timesteps_per_block : input_length;
|
||||
const int input_length_ = MULTI_BLOCK_FLAG ? beam0_context_length % timesteps_per_block : beam0_context_length;
|
||||
// The beginning of the generation.
|
||||
const int generation_start_ti = k_idx.x + input_length_ / K_PER_WARP * K_PER_WARP;
|
||||
|
||||
@ -1960,7 +1956,7 @@ __global__ void masked_multihead_attention_kernel(
|
||||
for (int k_vec_i = 0; k_vec_i < K_VECS_PER_THREAD; ++k_vec_i)
|
||||
{
|
||||
const int jj = min(k_idx.y + k_vec_i * K_ELTS_PER_CHUNK, Dh - K_VEC_SIZE);
|
||||
const int valid_time_now = min(time_now, tlength - 1);
|
||||
const int valid_time_now = min(time_now, kv_loop_length - 1);
|
||||
int beam_offset = beam_indices[valid_time_now];
|
||||
const int seqIdx = batch_idx * beam_width + beam_offset;
|
||||
// Base pointer to k cache block for beam's batch, before offsetting with indirection buffer
|
||||
@ -1971,7 +1967,7 @@ __global__ void masked_multihead_attention_kernel(
|
||||
}
|
||||
|
||||
// Is it active?
|
||||
const bool is_active = time_now >= input_length && time_now < tlength;
|
||||
const bool is_active = time_now >= context_length && time_now < kv_loop_length;
|
||||
|
||||
if (implicit_rel_attn_bias)
|
||||
{
|
||||
@ -2092,6 +2088,34 @@ __global__ void masked_multihead_attention_kernel(
|
||||
// Make sure the products are in shared memory.
|
||||
__syncthreads();
|
||||
|
||||
// After the syncthreads, the target k position (cyclic kv cache) should also have been used by the k loop.
|
||||
// Write the K values to the global memory cache.
|
||||
//
|
||||
// NOTE: The stores are uncoalesced as we have multiple chunks of 16B spread across the memory
|
||||
// system. We designed it this way as it allows much better memory loads (and there are many
|
||||
// more loads) + the stores are really "write and forget" since we won't need the ack before
|
||||
// the end of the kernel. There's plenty of time for the transactions to complete.
|
||||
|
||||
// For MQA/GQA mode, write only with the first Q head of each group per KV head.
|
||||
if (HANDLE_KV && hi == (hi_kv * qhead_per_kv) && qk_vec_idx < Dh)
|
||||
{
|
||||
// Trigger the stores to global memory.
|
||||
Qk_vec_k k_vec = *reinterpret_cast<Qk_vec_k*>(&k_smem[qk_vec_idx]);
|
||||
const auto k_idx = QK_VEC_SIZE * tidx;
|
||||
const int inBlockIdx = kvCacheBuffer.getKVLocalIdx(cyclic_tlength, hi_kv, Dh, k_idx);
|
||||
// The base pointer for the value in the cache buffer.
|
||||
Tcache* k_cache = reinterpret_cast<Tcache*>(kvCacheBuffer.getKBlockPtr(batch_beam_idx, cyclic_tlength));
|
||||
|
||||
if constexpr (ENABLE_8BITS_CACHE)
|
||||
{
|
||||
store_8bits_kv_cache_vec(reinterpret_cast<Tcache*>(k_cache), k_vec, inBlockIdx, kv_scale_orig_quant);
|
||||
}
|
||||
else
|
||||
{
|
||||
*reinterpret_cast<Qk_vec_m*>(&k_cache[inBlockIdx]) = vec_conversion<Qk_vec_m, Qk_vec_k>(k_vec);
|
||||
}
|
||||
}
|
||||
|
||||
// The warps finalize the reduction.
|
||||
qk_max = lane < WARPS_PER_BLOCK ? red_smem[lane] : -FLT_MAX;
|
||||
#pragma unroll
|
||||
@ -2107,7 +2131,7 @@ __global__ void masked_multihead_attention_kernel(
|
||||
float sum = 0.f;
|
||||
|
||||
// Each thread will handle one float (either qk_smem/logit).
|
||||
const int logit_loop_end = MULTI_BLOCK_FLAG ? timesteps_per_block : tlength;
|
||||
const int logit_loop_end = MULTI_BLOCK_FLAG ? timesteps_per_block : kv_loop_length;
|
||||
for (int ti = tidx; ti <= logit_loop_end; ti += THREADS_PER_BLOCK)
|
||||
{
|
||||
|
||||
@ -2123,13 +2147,13 @@ __global__ void masked_multihead_attention_kernel(
|
||||
else
|
||||
{
|
||||
// Not supported yet: multi-block mode with FP8_MHA
|
||||
if (time_now < tlength && ti != timesteps_per_block)
|
||||
if (time_now < kv_loop_length && ti != timesteps_per_block)
|
||||
{
|
||||
float logit = __expf(qk_smem[ti] - qk_max);
|
||||
sum += logit;
|
||||
qk_smem[ti] = logit;
|
||||
}
|
||||
else if (time_now == tlength)
|
||||
else if (time_now == kv_loop_length)
|
||||
{
|
||||
float logit = __expf(qk_current_smem[0] - qk_max);
|
||||
sum += logit;
|
||||
@ -2149,7 +2173,7 @@ __global__ void masked_multihead_attention_kernel(
|
||||
#endif // MMHA_FP8_SCALE_P_INSTEAD_OF_V
|
||||
float inv_sum = __fdividef(logit_scale, sum + 1.e-6f);
|
||||
|
||||
const int normlization_loop_end = MULTI_BLOCK_FLAG ? timesteps_per_block : tlength;
|
||||
const int normlization_loop_end = MULTI_BLOCK_FLAG ? timesteps_per_block : kv_loop_length;
|
||||
for (int ti = tidx; ti <= normlization_loop_end; ti += THREADS_PER_BLOCK)
|
||||
{
|
||||
const int time_now = MULTI_BLOCK_FLAG ? ti + c_tile_times_timesteps_per_block : ti;
|
||||
@ -2161,11 +2185,11 @@ __global__ void masked_multihead_attention_kernel(
|
||||
else
|
||||
{
|
||||
// no scaling factor inv_sum applied here, will apply the scaling factor after all blocks finished
|
||||
if (time_now < tlength && ti != timesteps_per_block)
|
||||
if (time_now < kv_loop_length && ti != timesteps_per_block)
|
||||
{
|
||||
convert_from_float(&logits_smem[ti], qk_smem[ti]);
|
||||
}
|
||||
else if (time_now == tlength)
|
||||
else if (time_now == kv_loop_length)
|
||||
{
|
||||
convert_from_float(&logits_current_smem[0], qk_current_smem[0]);
|
||||
}
|
||||
@ -2198,7 +2222,7 @@ __global__ void masked_multihead_attention_kernel(
|
||||
V_vec_k v_bias;
|
||||
zero(v_bias);
|
||||
// if( vo == params.timestep % V_PER_ITER ) {
|
||||
if (is_valid_vi && HANDLE_KV && vo == tlength % V_PER_ITER)
|
||||
if (is_valid_vi && HANDLE_KV && vo == kv_loop_length % V_PER_ITER)
|
||||
{
|
||||
// Trigger the loads from the V bias buffer.
|
||||
if (params.v_bias != nullptr)
|
||||
@ -2236,9 +2260,9 @@ __global__ void masked_multihead_attention_kernel(
|
||||
// Handle both context and generation value cache without beam searching.
|
||||
// Explicit batching of LDGs (by V_LOOP_UNROLL) as it doesn't depend on indirection tables.
|
||||
// Take all previous cache as context when we have no beam searching in order to batch as many LDGs as possible.
|
||||
const int context_length = HAS_BEAMS ? input_length : tlength;
|
||||
const int context_length = HAS_BEAMS ? beam0_context_length : kv_loop_length;
|
||||
int context_v_loop_end = MULTI_BLOCK_FLAG ? timesteps_per_block : context_length;
|
||||
int generation_v_loop_end = MULTI_BLOCK_FLAG ? timesteps_per_block : tlength;
|
||||
int generation_v_loop_end = MULTI_BLOCK_FLAG ? timesteps_per_block : kv_loop_length;
|
||||
for (int ti = vo; ti < context_v_loop_end; ti += UNROLLED_V_PER_ITER)
|
||||
{
|
||||
V_vec_m v_vec_cache[V_LOOP_UNROLL];
|
||||
@ -2247,7 +2271,7 @@ __global__ void masked_multihead_attention_kernel(
|
||||
{
|
||||
// Fetch offset based on cache_indir when beam sampling
|
||||
int time_idx = ti + v_loop * V_PER_ITER + (MULTI_BLOCK_FLAG ? c_tile_times_timesteps_per_block : 0);
|
||||
time_idx = min(time_idx, tlength - 1);
|
||||
time_idx = min(time_idx, kv_loop_length - 1);
|
||||
int rowIdx = batch_idx * beam_width;
|
||||
|
||||
const int inBlockIdx = kvCacheBuffer.getKVLocalIdx(time_idx, hi_kv, Dh, vi);
|
||||
@ -2278,16 +2302,17 @@ __global__ void masked_multihead_attention_kernel(
|
||||
// Handle generation value cache with beam searching.
|
||||
if (HAS_BEAMS)
|
||||
{
|
||||
const auto generation_start_ti = MULTI_BLOCK_FLAG ? vo : (vo + (input_length / V_PER_ITER) * V_PER_ITER);
|
||||
const auto generation_start_ti
|
||||
= MULTI_BLOCK_FLAG ? vo : (vo + (beam0_context_length / V_PER_ITER) * V_PER_ITER);
|
||||
// Only the last few blocks need to handle the generation value cache.
|
||||
if (!MULTI_BLOCK_FLAG || (c_tile + 1) * timesteps_per_block > input_length)
|
||||
if (!MULTI_BLOCK_FLAG || (c_tile + 1) * timesteps_per_block > beam0_context_length)
|
||||
{
|
||||
for (int ti = generation_start_ti; ti < generation_v_loop_end; ti += V_PER_ITER)
|
||||
{
|
||||
// Fetch offset based on cache_indir when beam sampling
|
||||
int time_idx = ti + (MULTI_BLOCK_FLAG ? c_tile_times_timesteps_per_block : 0);
|
||||
int local_time_idx = ti;
|
||||
if (time_idx < input_length || (MULTI_BLOCK_FLAG && time_idx >= tlength))
|
||||
if (time_idx < beam0_context_length || (MULTI_BLOCK_FLAG && time_idx >= kv_loop_length))
|
||||
{
|
||||
continue;
|
||||
}
|
||||
@ -2307,13 +2332,16 @@ __global__ void masked_multihead_attention_kernel(
|
||||
}
|
||||
}
|
||||
|
||||
// Make sure we can overwrite the v cache if using cyclic kv cache.
|
||||
__syncthreads();
|
||||
|
||||
// Get the c_tile_id that handles the current timestep.
|
||||
const int ctile_idx = tlength / timesteps_per_block;
|
||||
|
||||
// One group of threads computes the product(s) for the current timestep.
|
||||
if (vo == tlength % V_PER_ITER && is_valid_vi && (!MULTI_BLOCK_FLAG || (c_tile == ctile_idx)))
|
||||
if (vo == kv_loop_length % V_PER_ITER && is_valid_vi && (!MULTI_BLOCK_FLAG || (c_tile == ctile_idx)))
|
||||
{
|
||||
const int tokenIdx = tlength;
|
||||
const int tokenIdx = cyclic_tlength;
|
||||
const int inBlockIdx = kvCacheBuffer.getKVLocalIdx(tokenIdx, hi_kv, Dh, vi);
|
||||
// The base pointer for the value in the cache buffer.
|
||||
Tcache* v_cache_base = reinterpret_cast<Tcache*>(kvCacheBuffer.getBlockPtr(v_cache_base_row_ptr, tokenIdx));
|
||||
@ -2380,7 +2408,7 @@ __global__ void masked_multihead_attention_kernel(
|
||||
// out = fma(logits_smem[params.timestep], cast_to_float(v), out);
|
||||
if (!MULTI_BLOCK_FLAG)
|
||||
{
|
||||
out = fma(logits_smem[tlength], cast_to_float(v), out);
|
||||
out = fma(logits_smem[kv_loop_length], cast_to_float(v), out);
|
||||
}
|
||||
else
|
||||
{
|
||||
@ -2390,7 +2418,7 @@ __global__ void masked_multihead_attention_kernel(
|
||||
// out = fma(logits_smem[params.timestep], v, out);
|
||||
if (!MULTI_BLOCK_FLAG)
|
||||
{
|
||||
out = fma(logits_smem[tlength], v, out);
|
||||
out = fma(logits_smem[kv_loop_length], v, out);
|
||||
}
|
||||
else
|
||||
{ // MULTI_BLOCK_FLAG // Not supported yet: multi-block mode with FP8_MHA
|
||||
|
||||
@ -133,8 +133,8 @@ __global__ void computePaddingOffsets(int* paddingOffsets, const int* seqOffsets
// This kernel computes the attention mask. We must compute this on-the-fly in the future.

template <typename AttentionMaskDataType>
__global__ void computeAttentionMask(
AttentionMaskDataType* attentionMask, const int* seqOffsets, int maxSeqLength, AttentionMaskType attentionMaskType)
__global__ void computeAttentionMask(AttentionMaskDataType* attentionMask, const int* seqOffsets, int maxSeqLength,
int maxKvCacheLength, AttentionMaskType attentionMaskType)
{
// The index of the sequence in the batch.
int batchIdx = blockIdx.y;
@ -173,12 +173,22 @@ __global__ void computeAttentionMask(
break;
case AttentionMaskType::CAUSAL:
isValid = rowIdx < seqLength && colIdx < seqLength && colIdx <= rowIdx;
// Sliding_window_causal when there are not enough kv cache.
isValid = isValid && colIdx >= max(0, rowIdx - maxKvCacheLength);
// seq_length==4, max_seq_len==5
// 1 0 0 0 0
// 1 1 0 0 0
// 1 1 1 0 0
// 1 1 1 1 0
// 0 0 0 0 0

// seq_length==6, max_seq_len==6, max_kv_cache_length = 2
// 1 0 0 0 0 0
// 1 1 0 0 0 0
// 1 1 1 0 0 0
// 0 1 1 1 0 0
// 0 0 1 1 1 0
// 0 0 0 1 1 1
break;
case AttentionMaskType::BIDIRECTIONAL:
// clang-format off
@ -222,8 +232,8 @@ void invokeBuildDecoderInfo(const BuildDecoderInfoParams<T>& params, cudaStream_
blocksPerSeq *= 2;
}
dim3 grid(blocksPerSeq, params.batchSize);
computeAttentionMask<<<grid, THREADS_PER_BLOCK, 0, stream>>>(
params.attentionMask, params.seqOffsets, params.maxSeqLength, params.attentionMaskType);
computeAttentionMask<<<grid, THREADS_PER_BLOCK, 0, stream>>>(params.attentionMask, params.seqOffsets,
params.maxSeqLength, params.maxKvCacheLength, params.attentionMaskType);
}
}

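The sliding-window behaviour illustrated by the two mask matrices above is easy to sanity-check with a small host-side reproduction. The sketch below is illustrative only; slidingWindowCausal is a hypothetical standalone helper, not part of the kernel, but it evaluates the same predicate per (row, col) pair and prints the second diagram.

// Host-side sanity check (hypothetical helper): same sliding-window causal predicate as the kernel.
#include <algorithm>
#include <cstdio>

static bool slidingWindowCausal(int rowIdx, int colIdx, int seqLength, int maxKvCacheLength)
{
    bool isValid = rowIdx < seqLength && colIdx < seqLength && colIdx <= rowIdx;
    // Sliding-window part: only the last maxKvCacheLength positions stay visible.
    return isValid && colIdx >= std::max(0, rowIdx - maxKvCacheLength);
}

int main()
{
    // Prints the second diagram: seq_length==6, max_kv_cache_length==2.
    const int seqLength = 6, maxKvCacheLength = 2;
    for (int r = 0; r < seqLength; ++r)
    {
        for (int c = 0; c < seqLength; ++c)
        {
            std::printf("%d ", slidingWindowCausal(r, c, seqLength, maxKvCacheLength) ? 1 : 0);
        }
        std::printf("\n");
    }
    return 0;
}

With seqLength == 4, a 5x5 sweep and a large maxKvCacheLength reproduces the first diagram as well.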
@ -73,6 +73,9 @@ struct BuildDecoderInfoParams
int batchSize;
// The maximum length of a sequence; it includes input and output.
int maxSeqLength;
// The kv cache capacity.
// We will apply the limited_length_causal mask when there are not enough kv cache.
int maxKvCacheLength;
// The number of tokens in total. It's \sum_{ii=0}^{batchSize} seqLengths[ii].
int numTokens;
// The type of attention.

@ -4,6 +4,24 @@ namespace tensorrt_llm
{
namespace kernels
{
namespace
{
template <typename T>
struct Vec2Type;

template <>
struct Vec2Type<half>
{
using type = half2;
};
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && defined(ENABLE_BF16))
template <>
struct Vec2Type<__nv_bfloat16>
{
using type = __nv_bfloat162;
};
#endif
}; // namespace

template <typename T, int kProcessRows, typename AccessType>
__global__ void apply_per_channel_scale(T* smoothed_act, const T* act, const T* per_channel_scale, int rows, int cols)
@ -21,13 +39,18 @@ __global__ void apply_per_channel_scale(T* smoothed_act, const T* act, const T*
for (int i = 0; i < kProcessRows; ++i)
{
*reinterpret_cast<AccessType*>(act_vec) = reinterpret_cast<const AccessType*>(act + i * cols)[col_offset];
if constexpr (std::is_same_v<T, half> && kElems % 2 == 0)
if constexpr ((std::is_same_v<T, half>
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && defined(ENABLE_BF16))
|| std::is_same_v<T, __nv_bfloat16>
#endif
) &&(kElems % 2 == 0))
{
using Vec2 = typename Vec2Type<T>::type;
#pragma unroll
for (int j = 0; j < kElems; j += 2)
{
*reinterpret_cast<half2*>(act_vec + j)
= __hmul2(*reinterpret_cast<half2*>(act_vec + j), *reinterpret_cast<half2*>(scale + j));
*reinterpret_cast<Vec2*>(act_vec + j)
= __hmul2(*reinterpret_cast<Vec2*>(act_vec + j), *reinterpret_cast<Vec2*>(scale + j));
}
}
else
@ -35,7 +58,7 @@ __global__ void apply_per_channel_scale(T* smoothed_act, const T* act, const T*
#pragma unroll
for (int j = 0; j < kElems; ++j)
{
act_vec[j] *= scale[j];
act_vec[j] = static_cast<T>(static_cast<float>(act_vec[j]) * static_cast<float>(scale[j]));
}
}
reinterpret_cast<AccessType*>(smoothed_act + i * cols)[col_offset] = *reinterpret_cast<AccessType*>(act_vec);
@ -85,6 +108,9 @@ void apply_per_channel_scale_kernel_launcher(
T * smoothed_act, const T* act, const T* per_channel_scale, int rows, int cols, cudaStream_t stream)

INSTANTIATE_PREQUANT_SCALE(half);
#if defined(ENABLE_BF16)
INSTANTIATE_PREQUANT_SCALE(__nv_bfloat16);
#endif

} // namespace kernels
} // namespace tensorrt_llm

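The kernel above multiplies each activation by the scale of its column (channel) in float and casts back to T; the vectorized AccessType loads and the half2/__nv_bfloat162 fast path through Vec2Type only change how many elements are handled per thread. A deliberately naive CUDA sketch of the same idea, with a hypothetical kernel name, for illustration only:

#include <cuda_fp16.h>

// Naive sketch: one thread scales one activation element by its column's per-channel factor.
template <typename T>
__global__ void applyPerChannelScaleNaive(
    T* smoothed_act, const T* act, const T* per_channel_scale, int rows, int cols)
{
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= rows * cols)
    {
        return;
    }
    const int col = idx % cols; // one scale per output channel (column)
    smoothed_act[idx] = static_cast<T>(static_cast<float>(act[idx]) * static_cast<float>(per_channel_scale[col]));
}

// Possible launch for half activations (assumed row-major [rows, cols] layout):
//   applyPerChannelScaleNaive<half><<<(rows * cols + 255) / 256, 256, 0, stream>>>(dst, src, scales, rows, cols);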
@ -20,6 +20,10 @@
|
||||
#include <cuda_runtime.h>
|
||||
#include <cuda_runtime_api.h>
|
||||
|
||||
#if defined(ENABLE_BF16)
|
||||
#include <cuda_bf16.h>
|
||||
#endif
|
||||
|
||||
#include <type_traits>
|
||||
#include <vector>
|
||||
|
||||
|
||||
@ -1629,7 +1629,7 @@ INSTANTIATE_TRANSPOSE_4D(half);

template <typename T, typename T_cache, typename KVCacheBuffer>
__global__ void transpose4dBatchMajorKVCache(const T* kSrc, const T* vSrc, KVCacheBuffer kvCacheBuffer,
const int headNum, const int sizePerHead, const int seqLen, const float* kvScaleOrigQuant,
const int headNum, const int sizePerHead, const int seqLen, const int maxKvCacheLen, const float* kvScaleOrigQuant,
const int* sequence_lengths)
{
// We allow only fp32/fp16/bf16 as input types
@ -1655,14 +1655,20 @@ __global__ void transpose4dBatchMajorKVCache(const T* kSrc, const T* vSrc, KVCac
}

// Get linear token index
const int tokenIdx = idx / sizePerHeadDivX;
int tokenIdx = idx / sizePerHeadDivX;
// Apply cyclic kv cache if tokenIdx >= max_kv_cache_length.
// which means we will drop the tokens in the beginning if seqLen > max_kv_cache_length.
const int tokenIdxLowerBound = max(sequence_lengths[batchIdx] - maxKvCacheLen, 0);
// Get channel index
const int channelIdx = idx % sizePerHeadDivX;
if (tokenIdx >= sequence_lengths[batchIdx])
if (tokenIdx >= sequence_lengths[batchIdx] || tokenIdx < tokenIdxLowerBound)
{
return;
}

// Apply cyclic kv cache if tokenIdx >= max_kv_cache_length.
tokenIdx = tokenIdx % maxKvCacheLen;

// Get pointer to the dst block given sequence, head and token ids
auto valDst = handle_k ? reinterpret_cast<T_dst*>(kvCacheBuffer.getKBlockPtr(batchIdx, tokenIdx))
: reinterpret_cast<T_dst*>(kvCacheBuffer.getVBlockPtr(batchIdx, tokenIdx));
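The new bounds check and the modulo can be summarised by a tiny host-side helper (hypothetical, for illustration only): tokens older than the last maxKvCacheLen positions are dropped, and the surviving tokens wrap into the cache.

#include <algorithm>

// Returns the cache slot for tokenIdx, or -1 if the token falls outside the kept
// window [seqLen - maxKvCacheLen, seqLen).
static int cyclicKvCacheSlot(int tokenIdx, int seqLen, int maxKvCacheLen)
{
    const int lowerBound = std::max(seqLen - maxKvCacheLen, 0);
    if (tokenIdx >= seqLen || tokenIdx < lowerBound)
    {
        return -1; // dropped: past the end, or evicted from the cyclic cache
    }
    return tokenIdx % maxKvCacheLen;
}

// Example: seqLen == 6 with maxKvCacheLen == 4 keeps tokens 2..5 in slots 2, 3, 0, 1.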
@ -1697,7 +1703,7 @@ __global__ void transpose4dBatchMajorKVCache(const T* kSrc, const T* vSrc, KVCac
|
||||
|
||||
template <typename T, typename KVCacheBuffer>
|
||||
void invokeTranspose4dBatchMajor(const T* kSrc, const T* vSrc, KVCacheBuffer& kvTable, const int localBatchSize,
|
||||
const int seqLen, const int maxSeqLen, const int sizePerHead, const int localHeadNum,
|
||||
const int seqLen, const int maxKvCacheLen, const int sizePerHead, const int localHeadNum,
|
||||
const KvCacheDataType cache_type, const float* kvScaleOrigQuant, const int* sequence_lengths, cudaStream_t stream)
|
||||
{
|
||||
// Block handles both K and V tile.
|
||||
@ -1710,25 +1716,25 @@ void invokeTranspose4dBatchMajor(const T* kSrc, const T* vSrc, KVCacheBuffer& kv
|
||||
if (cache_type == KvCacheDataType::INT8)
|
||||
{
|
||||
transpose4dBatchMajorKVCache<T, int8_t, KVCacheBuffer><<<gridSz, blockSz, 0, stream>>>(
|
||||
kSrc, vSrc, kvTable, localHeadNum, sizePerHead, seqLen, kvScaleOrigQuant, sequence_lengths);
|
||||
kSrc, vSrc, kvTable, localHeadNum, sizePerHead, seqLen, maxKvCacheLen, kvScaleOrigQuant, sequence_lengths);
|
||||
}
|
||||
#ifdef ENABLE_FP8
|
||||
else if (cache_type == KvCacheDataType::FP8)
|
||||
{
|
||||
transpose4dBatchMajorKVCache<T, __nv_fp8_e4m3, KVCacheBuffer><<<gridSz, blockSz, 0, stream>>>(
|
||||
kSrc, vSrc, kvTable, localHeadNum, sizePerHead, seqLen, kvScaleOrigQuant, sequence_lengths);
|
||||
kSrc, vSrc, kvTable, localHeadNum, sizePerHead, seqLen, maxKvCacheLen, kvScaleOrigQuant, sequence_lengths);
|
||||
}
|
||||
#endif // ENABLE_FP8
|
||||
else
|
||||
{
|
||||
transpose4dBatchMajorKVCache<T, T, KVCacheBuffer><<<gridSz, blockSz, 0, stream>>>(
|
||||
kSrc, vSrc, kvTable, localHeadNum, sizePerHead, seqLen, kvScaleOrigQuant, sequence_lengths);
|
||||
kSrc, vSrc, kvTable, localHeadNum, sizePerHead, seqLen, maxKvCacheLen, kvScaleOrigQuant, sequence_lengths);
|
||||
}
|
||||
}
|
||||
|
||||
#define INSTANTIATE_TRANSPOSE_4D_BATCH_MAJOR_KV_CACHE_TYPE(T, KVCacheBuffer) \
|
||||
template void invokeTranspose4dBatchMajor(const T* kSrc, const T* vSrc, KVCacheBuffer& kvTable, \
|
||||
const int localBatchSize, const int seqLen, const int maxSeqLen, const int sizePerHead, \
|
||||
const int localBatchSize, const int seqLen, const int maxKvCacheLen, const int sizePerHead, \
|
||||
const int localHeadNum, const KvCacheDataType cache_type, const float* kvScaleOrigQuant, \
|
||||
const int* sequence_lengths, cudaStream_t stream)
|
||||
|
||||
|
||||
@ -105,16 +105,17 @@ void invokeAddFusedQKVBiasTranspose(T* q_buf, T* k_buf, T* v_buf, T* QKV, const
|
||||
|
||||
template <typename T, typename KVCacheBuffer>
|
||||
void invokeTranspose4dBatchMajor(const T* k_src, const T* v_src, KVCacheBuffer& kvTable, const int local_batch_size,
|
||||
const int seq_len, const int max_seq_len, const int size_per_head, const int local_head_num,
|
||||
const int seq_len, const int max_kv_cache_len, const int size_per_head, const int local_head_num,
|
||||
const KvCacheDataType cache_type, const float* kvScaleOrigQuant, const int* sequence_lengths, cudaStream_t stream);
|
||||
|
||||
template <typename T, typename KVCacheBuffer>
|
||||
void invokeApplyBiasRopeUpdateKVCache(T* QKV, KVCacheBuffer& kvTable, const T* qkv_bias, const int* seq_lens,
|
||||
const int* padding_offset, const int batch_size, const int seq_len, const int token_num, const int head_num,
|
||||
const int kv_head_num, const int size_per_head, const int rotary_embedding_dim, const float rotary_embedding_base,
|
||||
const RotaryScalingType rotary_scale_type, const float rotary_embedding_scale,
|
||||
const int rotary_embedding_max_positions, const PositionEmbeddingType position_embedding_type, const float* scale,
|
||||
const int int8_mode, const KvCacheDataType cache_type, const float* kvScaleOrigQuant, cudaStream_t stream);
|
||||
const int* padding_offset, const int batch_size, const int seq_len, const int cyclic_kv_cache_len,
|
||||
const int token_num, const int head_num, const int kv_head_num, const int size_per_head,
|
||||
const int rotary_embedding_dim, const float rotary_embedding_base, const RotaryScalingType rotary_scale_type,
|
||||
const float rotary_embedding_scale, const int rotary_embedding_max_positions,
|
||||
const PositionEmbeddingType position_embedding_type, const float* scale, const int int8_mode,
|
||||
const KvCacheDataType cache_type, const float* kvScaleOrigQuant, cudaStream_t stream);
|
||||
|
||||
template <typename T, typename BT>
|
||||
void invokeAddRelativeAttentionBiasUnaligned(T* qk_buf, const BT* relative_attention_bias, const int batch_size,
|
||||
|
||||
@ -65,9 +65,9 @@ struct Vec_t<__nv_bfloat16>
|
||||
template <typename T, typename T_cache, bool ADD_BIAS, typename KVCacheBuffer>
|
||||
__global__ void applyBiasRopeUpdateKVCache(T* QKV, KVCacheBuffer kvCacheBuffer, const T* __restrict qkv_bias,
|
||||
const int* seq_lens, const int* padding_offset, const float* kvScaleOrigQuant, const int batch_size,
|
||||
const int seq_len, const int head_num, const int kv_head_num, const int size_per_head,
|
||||
const int rotary_embedding_dim, float rotary_embedding_base, RotaryScalingType const rotary_scale_type,
|
||||
float rotary_embedding_scale, const int rotary_embedding_max_positions,
|
||||
const int seq_len, const int cyclic_kv_cache_len, const int head_num, const int kv_head_num,
|
||||
const int size_per_head, const int rotary_embedding_dim, float rotary_embedding_base,
|
||||
RotaryScalingType const rotary_scale_type, float rotary_embedding_scale, const int rotary_embedding_max_positions,
|
||||
PositionEmbeddingType const position_embedding_type)
|
||||
{
|
||||
// This kernel add bias to QKV, which has shape [batch_size, seq_len, 3, head_num, size_per_head], and
|
||||
@ -222,9 +222,11 @@ __global__ void applyBiasRopeUpdateKVCache(T* QKV, KVCacheBuffer kvCacheBuffer,
}

const int channelIdx{tidx};
auto kDst = reinterpret_cast<T_dst*>(kvCacheBuffer.getKBlockPtr(batch_idx, token_idx_in_seq));
auto vDst = reinterpret_cast<T_dst*>(kvCacheBuffer.getVBlockPtr(batch_idx, token_idx_in_seq));
int inBlockIdx = kvCacheBuffer.getKVLocalIdx(token_idx_in_seq, kv_head_idx, sizePerHeadDivX, channelIdx);
const bool valid_kv_cache_pos = token_idx_in_seq >= (actual_seq_len - cyclic_kv_cache_len);
const int token_idx_in_kv_cache = token_idx_in_seq % cyclic_kv_cache_len;
auto kDst = reinterpret_cast<T_dst*>(kvCacheBuffer.getKBlockPtr(batch_idx, token_idx_in_kv_cache));
auto vDst = reinterpret_cast<T_dst*>(kvCacheBuffer.getVBlockPtr(batch_idx, token_idx_in_kv_cache));
int inBlockIdx = kvCacheBuffer.getKVLocalIdx(token_idx_in_kv_cache, kv_head_idx, sizePerHeadDivX, channelIdx);
if (!is_masked)
{
*reinterpret_cast<Vec_t*>(&QKV[src_q_idx]) = q;
@ -233,48 +235,24 @@ __global__ void applyBiasRopeUpdateKVCache(T* QKV, KVCacheBuffer kvCacheBuffer,
*reinterpret_cast<Vec_t*>(&QKV[src_k_idx]) = k;
*reinterpret_cast<Vec_t*>(&QKV[src_v_idx]) = v;

if (ENABLE_8BITS_CACHE)
if (valid_kv_cache_pos)
{
inBlockIdx = inBlockIdx * vec_size;
// Cast float scale to dst data type.
using T_scale = typename mmha::kv_cache_scale_type_t<T, T_cache>::Type;
T_scale scaleOrigQuant;
mmha::convert_from_float(&scaleOrigQuant, kvScaleOrigQuant[0]);
// Store 8bits kv cache.
mmha::store_8bits_kv_cache_vec(kDst, k, inBlockIdx, scaleOrigQuant);
mmha::store_8bits_kv_cache_vec(vDst, v, inBlockIdx, scaleOrigQuant);
}
else
{
reinterpret_cast<Vec_t*>(kDst)[inBlockIdx] = k;
reinterpret_cast<Vec_t*>(vDst)[inBlockIdx] = v;
}
}
}
else if (is_seq_masked && !is_head_size_masked)
{
// Set padding to zero in case of potential nan generated.
*reinterpret_cast<Vec_t*>(&QKV[src_q_idx]) = zero;
if ((head_num == kv_head_num) || (head_idx == (kv_head_idx * qheads_per_kv_head)))
{
*reinterpret_cast<Vec_t*>(&QKV[src_k_idx]) = zero;
*reinterpret_cast<Vec_t*>(&QKV[src_v_idx]) = zero;

if (ENABLE_8BITS_CACHE)
{
inBlockIdx = inBlockIdx * vec_size;
// Cast float scale to dst data type.
using T_scale = typename mmha::kv_cache_scale_type_t<T, T_cache>::Type;
T_scale scaleOrigQuant;
mmha::convert_from_float(&scaleOrigQuant, kvScaleOrigQuant[0]);
// Store 8bits kv cache.
mmha::store_8bits_kv_cache_vec(kDst, zero, inBlockIdx, scaleOrigQuant);
mmha::store_8bits_kv_cache_vec(vDst, zero, inBlockIdx, scaleOrigQuant);
}
else
{
reinterpret_cast<Vec_t*>(kDst)[inBlockIdx] = zero;
reinterpret_cast<Vec_t*>(vDst)[inBlockIdx] = zero;
if (ENABLE_8BITS_CACHE)
{
inBlockIdx = inBlockIdx * vec_size;
// Cast float scale to dst data type.
using T_scale = typename mmha::kv_cache_scale_type_t<T, T_cache>::Type;
T_scale scaleOrigQuant;
mmha::convert_from_float(&scaleOrigQuant, kvScaleOrigQuant[0]);
// Store 8bits kv cache.
mmha::store_8bits_kv_cache_vec(kDst, k, inBlockIdx, scaleOrigQuant);
mmha::store_8bits_kv_cache_vec(vDst, v, inBlockIdx, scaleOrigQuant);
}
else
{
reinterpret_cast<Vec_t*>(kDst)[inBlockIdx] = k;
reinterpret_cast<Vec_t*>(vDst)[inBlockIdx] = v;
}
}
}
}
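Both branches above funnel the 8-bit path through mmha::store_8bits_kv_cache_vec with a scale converted from kvScaleOrigQuant[0]. As a rough scalar illustration of that quantize-on-store idea (assumptions: symmetric int8 scaling with saturation; the exact rounding and clamping done by the mmha helpers may differ), the per-element operation looks like this:

#include <algorithm>
#include <cmath>
#include <cstdint>

// Scalar sketch: scale a K/V element into the int8 range before writing it to the cache.
// The real path does this on whole vectors inside the kernel.
static int8_t quantizeKvElementToInt8(float value, float scaleOrigQuant)
{
    const float scaled = value * scaleOrigQuant;                       // e.g. scale on the order of 127 / amax
    const float clamped = std::min(std::max(scaled, -128.0f), 127.0f); // saturate to the int8 range
    return static_cast<int8_t>(std::lround(clamped));
}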
@ -282,11 +260,12 @@ __global__ void applyBiasRopeUpdateKVCache(T* QKV, KVCacheBuffer kvCacheBuffer,
|
||||
|
||||
template <typename T, typename T_cache, typename KVCacheBuffer>
|
||||
void invokeApplyBiasRopeUpdateKVCacheDispatch(T* QKV, KVCacheBuffer& kvTable, const T* qkv_bias, const int* seq_lens,
|
||||
const int* padding_offset, const int batch_size, const int seq_len, const int token_num, const int head_num,
|
||||
const int kv_head_num, const int size_per_head, const int rotary_embedding_dim, const float rotary_embedding_base,
|
||||
const RotaryScalingType rotary_scale_type, const float rotary_embedding_scale,
|
||||
const int rotary_embedding_max_positions, const PositionEmbeddingType position_embedding_type, const float* scale,
|
||||
const float* kvScaleOrigQuant, const int int8_mode, cudaStream_t stream)
|
||||
const int* padding_offset, const int batch_size, const int seq_len, const int cyclic_kv_cache_len,
|
||||
const int token_num, const int head_num, const int kv_head_num, const int size_per_head,
|
||||
const int rotary_embedding_dim, const float rotary_embedding_base, const RotaryScalingType rotary_scale_type,
|
||||
const float rotary_embedding_scale, const int rotary_embedding_max_positions,
|
||||
const PositionEmbeddingType position_embedding_type, const float* scale, const float* kvScaleOrigQuant,
|
||||
const int int8_mode, cudaStream_t stream)
|
||||
{
|
||||
TLLM_CHECK_WITH_INFO(int8_mode != 2, "w8a8 not yet implemented with RoPE"); // TODO
|
||||
// To implement rotary embeddings, each thread processes two QKV elems:
|
||||
@ -298,26 +277,27 @@ void invokeApplyBiasRopeUpdateKVCacheDispatch(T* QKV, KVCacheBuffer& kvTable, co
|
||||
if (qkv_bias != nullptr)
|
||||
{
|
||||
applyBiasRopeUpdateKVCache<T, T_cache, true, KVCacheBuffer><<<grid, block, smem_size, stream>>>(QKV, kvTable,
|
||||
qkv_bias, seq_lens, padding_offset, kvScaleOrigQuant, batch_size, seq_len, head_num, kv_head_num,
|
||||
size_per_head, rotary_embedding_dim, rotary_embedding_base, rotary_scale_type, rotary_embedding_scale,
|
||||
rotary_embedding_max_positions, position_embedding_type);
|
||||
qkv_bias, seq_lens, padding_offset, kvScaleOrigQuant, batch_size, seq_len, cyclic_kv_cache_len, head_num,
|
||||
kv_head_num, size_per_head, rotary_embedding_dim, rotary_embedding_base, rotary_scale_type,
|
||||
rotary_embedding_scale, rotary_embedding_max_positions, position_embedding_type);
|
||||
}
|
||||
else
|
||||
{
|
||||
applyBiasRopeUpdateKVCache<T, T_cache, false, KVCacheBuffer><<<grid, block, smem_size, stream>>>(QKV, kvTable,
|
||||
qkv_bias, seq_lens, padding_offset, kvScaleOrigQuant, batch_size, seq_len, head_num, kv_head_num,
|
||||
size_per_head, rotary_embedding_dim, rotary_embedding_base, rotary_scale_type, rotary_embedding_scale,
|
||||
rotary_embedding_max_positions, position_embedding_type);
|
||||
qkv_bias, seq_lens, padding_offset, kvScaleOrigQuant, batch_size, seq_len, cyclic_kv_cache_len, head_num,
|
||||
kv_head_num, size_per_head, rotary_embedding_dim, rotary_embedding_base, rotary_scale_type,
|
||||
rotary_embedding_scale, rotary_embedding_max_positions, position_embedding_type);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename KVCacheBuffer>
|
||||
void invokeApplyBiasRopeUpdateKVCache(T* QKV, KVCacheBuffer& kvTable, const T* qkv_bias, const int* seq_lens,
|
||||
const int* padding_offset, const int batch_size, const int seq_len, const int token_num, const int head_num,
|
||||
const int kv_head_num, const int size_per_head, const int rotary_embedding_dim, const float rotary_embedding_base,
|
||||
const RotaryScalingType rotary_scale_type, const float rotary_embedding_scale,
|
||||
const int rotary_embedding_max_positions, const PositionEmbeddingType position_embedding_type, const float* scale,
|
||||
const int int8_mode, const KvCacheDataType cache_type, const float* kvScaleOrigQuant, cudaStream_t stream)
|
||||
const int* padding_offset, const int batch_size, const int seq_len, const int cyclic_kv_cache_len,
|
||||
const int token_num, const int head_num, const int kv_head_num, const int size_per_head,
|
||||
const int rotary_embedding_dim, const float rotary_embedding_base, const RotaryScalingType rotary_scale_type,
|
||||
const float rotary_embedding_scale, const int rotary_embedding_max_positions,
|
||||
const PositionEmbeddingType position_embedding_type, const float* scale, const int int8_mode,
|
||||
const KvCacheDataType cache_type, const float* kvScaleOrigQuant, cudaStream_t stream)
|
||||
{
|
||||
// Block handles both K and V tile.
|
||||
constexpr int x = (sizeof(T) == 4) ? 4 : 8;
|
||||
@ -326,36 +306,37 @@ void invokeApplyBiasRopeUpdateKVCache(T* QKV, KVCacheBuffer& kvTable, const T* q
|
||||
if (cache_type == KvCacheDataType::INT8)
|
||||
{
|
||||
invokeApplyBiasRopeUpdateKVCacheDispatch<T, int8_t, KVCacheBuffer>(QKV, kvTable, qkv_bias, seq_lens,
|
||||
padding_offset, batch_size, seq_len, token_num, head_num, kv_head_num, size_per_head, rotary_embedding_dim,
|
||||
rotary_embedding_base, rotary_scale_type, rotary_embedding_scale, rotary_embedding_max_positions,
|
||||
position_embedding_type, scale, kvScaleOrigQuant, int8_mode, stream);
|
||||
padding_offset, batch_size, seq_len, cyclic_kv_cache_len, token_num, head_num, kv_head_num, size_per_head,
|
||||
rotary_embedding_dim, rotary_embedding_base, rotary_scale_type, rotary_embedding_scale,
|
||||
rotary_embedding_max_positions, position_embedding_type, scale, kvScaleOrigQuant, int8_mode, stream);
|
||||
}
|
||||
#ifdef ENABLE_FP8
|
||||
else if (cache_type == KvCacheDataType::FP8)
|
||||
{
|
||||
invokeApplyBiasRopeUpdateKVCacheDispatch<T, __nv_fp8_e4m3, KVCacheBuffer>(QKV, kvTable, qkv_bias, seq_lens,
|
||||
padding_offset, batch_size, seq_len, token_num, head_num, kv_head_num, size_per_head, rotary_embedding_dim,
|
||||
rotary_embedding_base, rotary_scale_type, rotary_embedding_scale, rotary_embedding_max_positions,
|
||||
position_embedding_type, scale, kvScaleOrigQuant, int8_mode, stream);
|
||||
padding_offset, batch_size, seq_len, cyclic_kv_cache_len, token_num, head_num, kv_head_num, size_per_head,
|
||||
rotary_embedding_dim, rotary_embedding_base, rotary_scale_type, rotary_embedding_scale,
|
||||
rotary_embedding_max_positions, position_embedding_type, scale, kvScaleOrigQuant, int8_mode, stream);
|
||||
}
|
||||
#endif // ENABLE_FP8
|
||||
else
|
||||
{
|
||||
invokeApplyBiasRopeUpdateKVCacheDispatch<T, T, KVCacheBuffer>(QKV, kvTable, qkv_bias, seq_lens, padding_offset,
|
||||
batch_size, seq_len, token_num, head_num, kv_head_num, size_per_head, rotary_embedding_dim,
|
||||
rotary_embedding_base, rotary_scale_type, rotary_embedding_scale, rotary_embedding_max_positions,
|
||||
position_embedding_type, scale, kvScaleOrigQuant, int8_mode, stream);
|
||||
batch_size, seq_len, cyclic_kv_cache_len, token_num, head_num, kv_head_num, size_per_head,
|
||||
rotary_embedding_dim, rotary_embedding_base, rotary_scale_type, rotary_embedding_scale,
|
||||
rotary_embedding_max_positions, position_embedding_type, scale, kvScaleOrigQuant, int8_mode, stream);
|
||||
}
|
||||
}
|
||||
|
||||
#define INSTANTIATE_ADDFUSEDQKVBIAS_TRANSPOSE(T, KVCacheBuffer) \
|
||||
template void invokeApplyBiasRopeUpdateKVCache(T* QKV, KVCacheBuffer& kvTable, const T* qkv_bias, \
|
||||
const int* seq_lens, const int* padding_offset, const int batch_size, const int seq_len, const int token_num, \
|
||||
const int head_num, const int kv_head_num, const int size_per_head, const int rotary_embedding_dim, \
|
||||
const float rotary_embedding_base, const RotaryScalingType rotary_scale_type, \
|
||||
const float rotary_embedding_scale, const int rotary_embedding_max_positions, \
|
||||
const PositionEmbeddingType position_embedding_type, const float* scale, const int int8_mode, \
|
||||
const KvCacheDataType cache_type, const float* kvScaleOrigQuant, cudaStream_t stream)
|
||||
const int* seq_lens, const int* padding_offset, const int batch_size, const int seq_len, \
|
||||
const int cyclic_kv_cache_len, const int token_num, const int head_num, const int kv_head_num, \
|
||||
const int size_per_head, const int rotary_embedding_dim, const float rotary_embedding_base, \
|
||||
const RotaryScalingType rotary_scale_type, const float rotary_embedding_scale, \
|
||||
const int rotary_embedding_max_positions, const PositionEmbeddingType position_embedding_type, \
|
||||
const float* scale, const int int8_mode, const KvCacheDataType cache_type, const float* kvScaleOrigQuant, \
|
||||
cudaStream_t stream)
|
||||
|
||||
INSTANTIATE_ADDFUSEDQKVBIAS_TRANSPOSE(float, KVBlockArray);
|
||||
INSTANTIATE_ADDFUSEDQKVBIAS_TRANSPOSE(float, KVLinearBuffer);
|
||||
|
||||
@ -19,6 +19,9 @@
|
||||
#include <cmath>
|
||||
#include <cstdint>
|
||||
#include <cuda_fp16.h>
|
||||
#if defined(ENABLE_BF16)
|
||||
#include <cuda_bf16.h>
|
||||
#endif
|
||||
#include <cuda_runtime.h>
|
||||
#include <cuda_runtime_api.h>
|
||||
#include <iostream>
|
||||
@ -27,34 +30,6 @@ namespace tensorrt_llm
|
||||
{
|
||||
namespace kernels
|
||||
{
|
||||
struct WeightOnlyParams
|
||||
{
|
||||
const uint8_t* qweight;
|
||||
const half* scales;
|
||||
const half* zeros;
|
||||
const half* in;
|
||||
const half* bias;
|
||||
half* out;
|
||||
const int m;
|
||||
const int n;
|
||||
const int k;
|
||||
const int group_size;
|
||||
|
||||
WeightOnlyParams(const uint8_t* _qweight, const half* _scales, const half* _zeros, const half* _in,
|
||||
const half* _bias, half* _out, const int _m, const int _n, const int _k, const int _group_size)
|
||||
: qweight(_qweight)
|
||||
, scales(_scales)
|
||||
, zeros(_zeros)
|
||||
, in(_in)
|
||||
, bias(_bias)
|
||||
, out(_out)
|
||||
, m(_m)
|
||||
, n(_n)
|
||||
, k(_k)
|
||||
, group_size(_group_size)
|
||||
{
|
||||
}
|
||||
};
|
||||
enum class WeightOnlyQuantType
|
||||
{
|
||||
Int4b,
|
||||
@ -70,12 +45,61 @@ struct WeightOnlyPerChannel;
|
||||
template <int GS>
|
||||
struct WeightOnlyGroupWise;
|
||||
|
||||
enum class WeightOnlyActivationType
|
||||
enum class WeightOnlyActivationFunctionType
|
||||
{
|
||||
Gelu,
|
||||
Relu,
|
||||
Identity,
|
||||
InvalidType
|
||||
};
|
||||
|
||||
enum class WeightOnlyActivationType
|
||||
{
|
||||
FP16,
|
||||
BF16
|
||||
};
|
||||
|
||||
struct WeightOnlyParams
|
||||
{
|
||||
// ActType is fp16 or bf16
|
||||
using ActType = void;
|
||||
using WeiType = uint8_t;
|
||||
|
||||
const uint8_t* qweight;
|
||||
const ActType* scales;
|
||||
const ActType* zeros;
|
||||
const ActType* in;
|
||||
const ActType* bias;
|
||||
ActType* out;
|
||||
const int m;
|
||||
const int n;
|
||||
const int k;
|
||||
const int group_size;
|
||||
WeightOnlyQuantType quant_type;
|
||||
WeightOnlyType weight_only_type;
|
||||
WeightOnlyActivationFunctionType act_func_type;
|
||||
WeightOnlyActivationType act_type;
|
||||
|
||||
WeightOnlyParams(const uint8_t* _qweight, const ActType* _scales, const ActType* _zeros, const ActType* _in,
|
||||
const ActType* _bias, ActType* _out, const int _m, const int _n, const int _k, const int _group_size,
|
||||
const WeightOnlyQuantType _quant_type, const WeightOnlyType _weight_only_type,
|
||||
const WeightOnlyActivationFunctionType _act_func_type, const WeightOnlyActivationType _act_type)
|
||||
: qweight(_qweight)
|
||||
, scales(_scales)
|
||||
, zeros(_zeros)
|
||||
, in(_in)
|
||||
, bias(_bias)
|
||||
, out(_out)
|
||||
, m(_m)
|
||||
, n(_n)
|
||||
, k(_k)
|
||||
, group_size(_group_size)
|
||||
, quant_type(_quant_type)
|
||||
, weight_only_type(_weight_only_type)
|
||||
, act_func_type(_act_func_type)
|
||||
, act_type(_act_type)
|
||||
{
}
};
} // namespace kernels
} // namespace tensorrt_llm

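The parameter struct is now self-describing: quantization width, weight-only flavour, activation function and activation dtype all travel inside WeightOnlyParams instead of being passed as separate launcher arguments. A caller-side sketch (buffer pointers, shapes and the chosen group size below are placeholders; the launcher shown further down asserts the Identity activation function and dispatches small m values):

#include <cuda_fp16.h>
// ...plus the TensorRT-LLM header that declares WeightOnlyParams and the launcher.

void runInt4GroupwiseGemv(const uint8_t* qweight, const half* scales, const half* zeros, const half* in,
    const half* bias, half* out, int m, int n, int k, cudaStream_t stream)
{
    using namespace tensorrt_llm::kernels;
    // group_size 128 is one of the supported group-wise configurations (64 or 128).
    WeightOnlyParams params(qweight, scales, zeros, in, bias, out, m, n, k, /*group_size=*/128,
        WeightOnlyQuantType::Int4b, WeightOnlyType::GroupWise, WeightOnlyActivationFunctionType::Identity,
        WeightOnlyActivationType::FP16);
    weight_only_batched_gemv_launcher(params, stream);
}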
@ -22,11 +22,51 @@ namespace tensorrt_llm
|
||||
{
|
||||
namespace kernels
|
||||
{
|
||||
template <WeightOnlyQuantType QType>
|
||||
struct WeightLayoutDetails;
|
||||
template <typename ActType>
|
||||
struct ActTypeDetails;
|
||||
|
||||
template <>
|
||||
struct WeightLayoutDetails<WeightOnlyQuantType::Int4b>
|
||||
struct ActTypeDetails<half>
|
||||
{
|
||||
using CutlassType = cutlass::half_t;
|
||||
using Vec2 = half2;
|
||||
|
||||
__device__ __forceinline__ static Vec2 to_vec2(half v)
|
||||
{
|
||||
return __half2half2(v);
|
||||
}
|
||||
};
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && defined(ENABLE_BF16))
|
||||
template <>
|
||||
struct ActTypeDetails<__nv_bfloat16>
|
||||
{
|
||||
using CutlassType = cutlass::bfloat16_t;
|
||||
using Vec2 = __nv_bfloat162;
|
||||
|
||||
__device__ __forceinline__ static Vec2 to_vec2(__nv_bfloat16 v)
|
||||
{
|
||||
return __bfloat162bfloat162(v);
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
template <typename ActType, WeightOnlyQuantType QType>
|
||||
struct ConverterSelector
|
||||
{
|
||||
static_assert(QType == WeightOnlyQuantType::Int4b || QType == WeightOnlyQuantType::Int8b);
|
||||
|
||||
using WeiType = std::conditional_t<QType == WeightOnlyQuantType::Int4b, cutlass::uint4b_t, uint8_t>;
|
||||
static constexpr int kConvertCount = QType == WeightOnlyQuantType::Int4b ? 8 : 4;
|
||||
using Converter
|
||||
= cutlass::FastInterleavedAndBiasedNumericArrayConverter<typename ActTypeDetails<ActType>::CutlassType, WeiType,
|
||||
kConvertCount>;
|
||||
};
|
||||
|
||||
template <typename ActType, WeightOnlyQuantType QType>
|
||||
struct WeightOnlyDetails;
|
||||
|
||||
template <typename ActType>
|
||||
struct WeightOnlyDetails<ActType, WeightOnlyQuantType::Int4b>
|
||||
{
|
||||
// Every four rows of the original weights are interleaved into a row with stride of 64, so if each thread
|
||||
// processes 32 elements(for int4, we can use ldg.128 to load weights), then every group of two adjacent threads
|
||||
@ -49,16 +89,6 @@ struct WeightLayoutDetails<WeightOnlyQuantType::Int4b>
|
||||
static constexpr int kShuffleContinous = 4;
|
||||
static constexpr int kShuffleStrided = 4;
|
||||
|
||||
// The rearrangement here counteracts the effect of cutlass::add_bias_and_interleave_int4s_inplace
|
||||
// Input int8 data layout
|
||||
// [elt_7 elt_5 elt_3 elt_1 elt_6 elt_4 elt_2 elt_0] (each elt occupies 4 bits)
|
||||
//
|
||||
// Converted fp16 data layout
|
||||
// [elt_7 elt_6 elt_5 elt_4 elt_3 elt_2 elt_1 elt_0] (each elt occupies 16 bits)
|
||||
static constexpr int kConvertCount = 8;
|
||||
using Converter
|
||||
= cutlass::FastInterleavedAndBiasedNumericArrayConverter<cutlass::half_t, cutlass::uint4b_t, kConvertCount>;
|
||||
|
||||
// Each warp completes the internal reduce and writes the [Batch * NPerBlock * Interleave] results to the
|
||||
// corresponding address in shared memory
|
||||
template <int Num, int WarpSize>
|
||||
@ -85,8 +115,8 @@ struct WeightLayoutDetails<WeightOnlyQuantType::Int4b>
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct WeightLayoutDetails<WeightOnlyQuantType::Int8b>
|
||||
template <typename ActType>
|
||||
struct WeightOnlyDetails<ActType, WeightOnlyQuantType::Int8b>
|
||||
{
|
||||
// Every two rows of the original weights are interleaved into a row with stride of 64, so if each thread
|
||||
// processes 16 elements(for int8, we can use ldg.128 to load weights), then every group of four adjacent threads
|
||||
@ -109,15 +139,6 @@ struct WeightLayoutDetails<WeightOnlyQuantType::Int8b>
|
||||
static constexpr int kShuffleContinous = 2;
|
||||
static constexpr int kShuffleStrided = 4;
|
||||
|
||||
// The rearrangement here counteracts the effect of cutlass::add_bias_and_interleave_int8s_inplace
|
||||
// Input int8 data layout
|
||||
// [elt_3 elt_1 elt_2 elt_0] (each elt occupies 8 bits)
|
||||
//
|
||||
// Converted fp16 data layout
|
||||
// [elt_3 elt_2 elt_1 elt_0] (each elt occupies 16 bits)
|
||||
static constexpr int kConvertCount = 4;
|
||||
using Converter = cutlass::FastInterleavedAndBiasedNumericArrayConverter<cutlass::half_t, uint8_t, kConvertCount>;
|
||||
|
||||
// Each warp completes the internal reduce and writes the [Batch * NPerBlock * Interleave] results to the
|
||||
// corresponding address in shared memory
|
||||
template <int Num, int WarpSize>
|
||||
@ -145,10 +166,10 @@ struct WeightLayoutDetails<WeightOnlyQuantType::Int8b>
|
||||
}
|
||||
};
|
||||
|
||||
template <WeightOnlyQuantType QType>
|
||||
template <typename ActType, WeightOnlyQuantType QType>
|
||||
struct WeightOnlyKernelDetails
|
||||
{
|
||||
using Layout = WeightLayoutDetails<QType>;
|
||||
using Layout = WeightOnlyDetails<ActType, QType>;
|
||||
|
||||
static constexpr int kElemBits = Layout::kElemBits;
|
||||
static constexpr int kInterleave = Layout::kInterleave;
|
||||
@ -159,8 +180,20 @@ struct WeightOnlyKernelDetails
|
||||
static constexpr int kShuffleContinous = Layout::kShuffleContinous;
|
||||
static constexpr int kShuffleStrided = Layout::kShuffleStrided;
|
||||
|
||||
using Converter = typename Layout::Converter;
|
||||
static constexpr int kConvertCount = Layout::kConvertCount;
|
||||
// The rearrangement here counteracts the effect of cutlass::add_bias_and_interleave_int4/8s_inplace
|
||||
// Input int8 data layout
|
||||
// [elt_3 elt_1 elt_2 elt_0] (each elt occupies 8 bits)
|
||||
//
|
||||
// Converted fp16/bf16 data layout
|
||||
// [elt_3 elt_2 elt_1 elt_0] (each elt occupies 16 bits)
|
||||
|
||||
// Input int8 data layout
|
||||
// [elt_7 elt_5 elt_3 elt_1 elt_6 elt_4 elt_2 elt_0] (each elt occupies 4 bits)
|
||||
//
|
||||
// Converted fp16/bf16 data layout
|
||||
// [elt_7 elt_6 elt_5 elt_4 elt_3 elt_2 elt_1 elt_0] (each elt occupies 16 bits)
|
||||
static constexpr int kConvertCount = ConverterSelector<ActType, QType>::kConvertCount;
|
||||
using Converter = typename ConverterSelector<ActType, QType>::Converter;
|
||||
|
||||
// Use ldg128 load data from global memory
|
||||
static constexpr int kAccessSize = 128;
|
||||
@ -175,8 +208,8 @@ struct WeightOnlyKernelDetails
static constexpr int kConvertIters = kElemsPerThread / kConvertCount;

// Each thread loads 16(int8b)/32(int4b) quantized weight elements each time through ldg128
// So more times of ldg128 are needed to load the same number of fp16 activation elements.
static constexpr int kActivationElemNumPerAccess = kAccessSize / (sizeof(half) * 8);
// So more times of ldg128 are needed to load the same number of fp16/bf16 activation elements.
static constexpr int kActivationElemNumPerAccess = kAccessSize / (sizeof(ActType) * 8);
static constexpr int kActivationAccessNum = kElemsPerThread / kActivationElemNumPerAccess;
};

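Spelling out what those formulas evaluate to for 16-bit activations (a standalone arithmetic check, not part of the source; kAccessSize is the 128-bit ldg width defined just above):

constexpr int kAccessSizeBits = 128;
// One ldg.128 covers 8 fp16/bf16 activation elements (128 / 16 bits).
constexpr int kActivationElemNumPerAccess16b = kAccessSizeBits / 16;
static_assert(kActivationElemNumPerAccess16b == 8, "8 activations per 128-bit access");
// Int4 weights: 32 quantized elements per ldg.128 -> 32 / 8 = 4 activation accesses per weight access.
static_assert(32 / kActivationElemNumPerAccess16b == 4, "int4 case");
// Int8 weights: 16 quantized elements per ldg.128 -> 16 / 8 = 2 activation accesses per weight access.
static_assert(16 / kActivationElemNumPerAccess16b == 2, "int8 case");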
@ -197,11 +230,11 @@ struct WeightOnlyProperties<WeightOnlyGroupWise<GS>>
|
||||
static constexpr int kGroupSize = GS;
|
||||
};
|
||||
|
||||
template <WeightOnlyQuantType QType, typename WeightOnlyFlag, bool Zero, int BlockSize>
|
||||
template <typename ActType, WeightOnlyQuantType QType, typename WeightOnlyFlag, bool Zero, int BlockSize>
|
||||
struct WeightOnlyScaleLoader
|
||||
{
|
||||
using ElemType = half;
|
||||
using Details = WeightOnlyKernelDetails<QType>;
|
||||
using ElemType = ActType;
|
||||
using Details = WeightOnlyKernelDetails<ActType, QType>;
|
||||
static constexpr bool kIsFineGrained = WeightOnlyProperties<WeightOnlyFlag>::kIsFineGrained;
|
||||
static constexpr int kGroupSize = WeightOnlyProperties<WeightOnlyFlag>::kGroupSize;
|
||||
|
||||
@ -258,19 +291,20 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
template <WeightOnlyQuantType QType, typename WeightOnlyFlag, template <typename T> class ActOp, bool Zero, bool Bias,
|
||||
int NPerBlock, int Batch, int BlockSize>
|
||||
__global__ void weight_only_batched_gemv(const uint8_t* qweight, const half* scales, const half* zeros, const half* in,
|
||||
const half* bias, half* out, const int n, const int k)
|
||||
template <typename ActType, WeightOnlyQuantType QType, typename WeightOnlyFlag, template <typename T> class ActOp,
|
||||
bool Zero, bool Bias, int NPerBlock, int Batch, int BlockSize>
|
||||
__device__ void weight_only_batched_gemv(const uint8_t* qweight, const ActType* scales, const ActType* zeros,
|
||||
const ActType* in, const ActType* bias, ActType* out, const int n, const int k)
|
||||
{
|
||||
static_assert(NPerBlock == 1 || (NPerBlock % 2 == 0));
|
||||
using Details = WeightOnlyKernelDetails<QType>;
|
||||
using ActType2 = typename ActTypeDetails<ActType>::Vec2;
|
||||
using Details = WeightOnlyKernelDetails<ActType, QType>;
|
||||
|
||||
using Converter = typename Details::Converter;
|
||||
using AccType = typename Details::AccessType;
|
||||
using CvtSrcType = typename Converter::source_type;
|
||||
using CvtResType = typename Converter::result_type;
|
||||
using ScaleLoader = WeightOnlyScaleLoader<QType, WeightOnlyFlag, Zero, BlockSize>;
|
||||
using ScaleLoader = WeightOnlyScaleLoader<ActType, QType, WeightOnlyFlag, Zero, BlockSize>;
|
||||
extern __shared__ uint8_t shmem[];
|
||||
constexpr int Interleave = Details::kInterleave;
|
||||
constexpr int WarpSize = 32;
|
||||
@ -286,20 +320,20 @@ __global__ void weight_only_batched_gemv(const uint8_t* qweight, const half* sca
|
||||
|
||||
float(*sm)[Num * Interleave] = reinterpret_cast<float(*)[Num * Interleave]>(shmem);
|
||||
|
||||
// In order to take advantage of hfma2, we use fp16 for accumulation within threads and fp32 for accumulation
|
||||
// In order to take advantage of hfma2, we use fp16/bf16 for accumulation within threads and fp32 for accumulation
|
||||
// between threads.
|
||||
half accumulator[Num];
|
||||
ActType accumulator[Num];
|
||||
for (int i = 0; i < Num; ++i)
|
||||
{
|
||||
accumulator[i] = __float2half_rn(0.f);
|
||||
accumulator[i] = static_cast<ActType>(0.f);
|
||||
}
|
||||
|
||||
// Iteration in k dimensions
|
||||
for (int local_k = tid * Details::kElemsPerThread; local_k < k * Interleave;
|
||||
local_k += BlockSize * Details::kElemsPerThread)
|
||||
{
|
||||
half weights_f16[Details::kElemsPerThread * NPerBlock];
|
||||
half scale[NPerBlock], zero[NPerBlock];
|
||||
ActType weights_f16[Details::kElemsPerThread * NPerBlock];
|
||||
ActType scale[NPerBlock], zero[NPerBlock];
|
||||
#pragma unroll
|
||||
for (int idx = 0; idx < NPerBlock; ++idx)
|
||||
{
|
||||
@ -308,7 +342,7 @@ __global__ void weight_only_batched_gemv(const uint8_t* qweight, const half* sca
|
||||
load<AccType>(weights_quantized,
|
||||
qweight + idx * Interleave * k / Details::kElemsPerByte + local_k / Details::kElemsPerByte);
|
||||
scale_loader.load(scale[idx], zero[idx], idx);
|
||||
half weights_vec[Details::kElemsPerThread];
|
||||
ActType weights_vec[Details::kElemsPerThread];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < Details::kConvertIters; ++i)
|
||||
{
|
||||
@ -325,9 +359,10 @@ __global__ void weight_only_batched_gemv(const uint8_t* qweight, const half* sca
|
||||
{
|
||||
// Dequantize the weights and arrange the shuffled elements back to the correct order in the
|
||||
// register array
|
||||
half2 v = *reinterpret_cast<half2*>(weights_vec + i * Details::kShuffleBasicTile
|
||||
ActType2 v = *reinterpret_cast<ActType2*>(weights_vec + i * Details::kShuffleBasicTile
|
||||
+ j * Details::kShuffleContinous * Details::kShuffleBasicTile);
|
||||
v = __hfma2(v, __half2half2(scale[idx]), __half2half2(zero[idx]));
|
||||
v = __hfma2(
|
||||
v, ActTypeDetails<ActType>::to_vec2(scale[idx]), ActTypeDetails<ActType>::to_vec2(zero[idx]));
|
||||
weights_f16[(i * Details::kShuffleStrided * Details::kShuffleBasicTile
|
||||
+ j * Details::kShuffleBasicTile + 0)
|
||||
* NPerBlock
|
||||
@ -344,7 +379,7 @@ __global__ void weight_only_batched_gemv(const uint8_t* qweight, const half* sca
|
||||
#pragma unroll
|
||||
for (int b = 0; b < Batch; ++b)
|
||||
{
|
||||
half in_v[Details::kElemsPerThread];
|
||||
ActType in_v[Details::kElemsPerThread];
|
||||
#pragma unroll
|
||||
for (int idx = 0; idx < Details::kActivationAccessNum; ++idx)
|
||||
{
|
||||
@ -355,11 +390,12 @@ __global__ void weight_only_batched_gemv(const uint8_t* qweight, const half* sca
|
||||
// Perform vector inner product and accumulate
|
||||
if constexpr (NPerBlock == 1)
|
||||
{
|
||||
half2 v = __float2half2_rn(0.f);
|
||||
ActType2 v = ActTypeDetails<ActType>::to_vec2(static_cast<ActType>(0.f));
|
||||
#pragma unroll
|
||||
for (int y = 0; y < Details::kElemsPerThread; y += 2)
|
||||
{
|
||||
v = __hfma2(*reinterpret_cast<half2*>(weights_f16 + y), *reinterpret_cast<half2*>(in_v + y), v);
|
||||
v = __hfma2(
|
||||
*reinterpret_cast<ActType2*>(weights_f16 + y), *reinterpret_cast<ActType2*>(in_v + y), v);
|
||||
}
|
||||
accumulator[b] += __hadd(v.x, v.y);
|
||||
}
|
||||
@ -371,9 +407,10 @@ __global__ void weight_only_batched_gemv(const uint8_t* qweight, const half* sca
|
||||
#pragma unroll
|
||||
for (int y = 0; y < Details::kElemsPerThread; ++y)
|
||||
{
|
||||
*reinterpret_cast<half2*>(accumulator + b * NPerBlock + x * 2)
|
||||
= __hfma2(*reinterpret_cast<half2*>(weights_f16 + y * NPerBlock + x * 2),
|
||||
__half2half2(in_v[y]), *reinterpret_cast<half2*>(accumulator + b * NPerBlock + x * 2));
|
||||
*reinterpret_cast<ActType2*>(accumulator + b * NPerBlock + x * 2)
|
||||
= __hfma2(*reinterpret_cast<ActType2*>(weights_f16 + y * NPerBlock + x * 2),
|
||||
ActTypeDetails<ActType>::to_vec2(in_v[y]),
|
||||
*reinterpret_cast<ActType2*>(accumulator + b * NPerBlock + x * 2));
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -384,7 +421,7 @@ __global__ void weight_only_batched_gemv(const uint8_t* qweight, const half* sca
|
||||
#pragma unroll
|
||||
for (int i = 0; i < Num; ++i)
|
||||
{
|
||||
reses[i] = __half2float(accumulator[i]);
|
||||
reses[i] = static_cast<float>(accumulator[i]);
|
||||
}
|
||||
|
||||
// Each warp completes the internal reduce and writes the [Batch * NPerBlock * Interleave] results to the
|
||||
@ -403,27 +440,64 @@ __global__ void weight_only_batched_gemv(const uint8_t* qweight, const half* sca
|
||||
float bias_v = 0.f;
|
||||
if constexpr (Bias)
|
||||
{
|
||||
bias_v = __half2float(bias[n_start_id + nid]);
|
||||
bias_v = static_cast<float>(bias[n_start_id + nid]);
|
||||
}
|
||||
int b = i / NPerBlock / Interleave;
|
||||
out[b * n + n_start_id + nid] = __float2half_rn(ActOp<float>::apply(v + bias_v));
|
||||
out[b * n + n_start_id + nid] = static_cast<ActType>(ActOp<float>::apply(v + bias_v));
|
||||
}
|
||||
}
|
||||
|
||||
template <typename ActType, WeightOnlyQuantType QType, typename WeightOnlyFlag, template <typename T> class ActOp,
|
||||
bool Zero, bool Bias, int NPerBlock, int Batch, int BlockSize>
|
||||
__global__ void weight_only_batched_gemv_wrapper(const uint8_t* qweight, const ActType* scales, const ActType* zeros,
|
||||
const ActType* in, const ActType* bias, ActType* out, const int n, const int k)
|
||||
{
|
||||
if constexpr (std::is_same_v<ActType, half>)
|
||||
{
|
||||
weight_only_batched_gemv<ActType, QType, WeightOnlyFlag, ActOp, Zero, Bias, NPerBlock, Batch, BlockSize>(
|
||||
qweight, scales, zeros, in, bias, out, n, k);
|
||||
}
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && defined(ENABLE_BF16))
|
||||
else if (std::is_same_v<ActType, nv_bfloat16>)
|
||||
{
|
||||
weight_only_batched_gemv<ActType, QType, WeightOnlyFlag, ActOp, Zero, Bias, NPerBlock, Batch, BlockSize>(
|
||||
qweight, scales, zeros, in, bias, out, n, k);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
template <WeightOnlyQuantType QType, typename WeightOnlyFlag, template <typename T> class ActOp, bool Zero, bool Bias,
|
||||
int NPerBlock, int Batch, int BlockSize>
|
||||
struct WeightOnlyBatchedGemvKernelLauncher
|
||||
{
|
||||
static constexpr int kInterleave = WeightLayoutDetails<QType>::kInterleave;
|
||||
|
||||
static void run(const WeightOnlyParams& params, cudaStream_t stream)
|
||||
{
|
||||
dim3 grid(params.n / NPerBlock / kInterleave);
|
||||
dim3 block(BlockSize);
|
||||
int size = sizeof(float) * BlockSize / 32 * Batch * NPerBlock * kInterleave;
|
||||
weight_only_batched_gemv<QType, WeightOnlyFlag, ActOp, Zero, Bias, NPerBlock, Batch, BlockSize>
|
||||
<<<grid, block, size, stream>>>(
|
||||
params.qweight, params.scales, params.zeros, params.in, params.bias, params.out, params.n, params.k);
|
||||
if (params.act_type == WeightOnlyActivationType::FP16)
|
||||
{
|
||||
constexpr int kInterleave = WeightOnlyDetails<half, QType>::kInterleave;
|
||||
dim3 grid(params.n / NPerBlock / kInterleave);
|
||||
dim3 block(BlockSize);
|
||||
int size = sizeof(float) * BlockSize / 32 * Batch * NPerBlock * kInterleave;
|
||||
weight_only_batched_gemv_wrapper<half, QType, WeightOnlyFlag, ActOp, Zero, Bias, NPerBlock, Batch,
|
||||
BlockSize><<<grid, block, size, stream>>>(params.qweight, reinterpret_cast<const half*>(params.scales),
|
||||
reinterpret_cast<const half*>(params.zeros), reinterpret_cast<const half*>(params.in),
|
||||
reinterpret_cast<const half*>(params.bias), reinterpret_cast<half*>(params.out), params.n, params.k);
|
||||
}
|
||||
#if defined(ENABLE_BF16)
|
||||
else if (params.act_type == WeightOnlyActivationType::BF16)
|
||||
{
|
||||
constexpr int kInterleave = WeightOnlyDetails<nv_bfloat16, QType>::kInterleave;
|
||||
dim3 grid(params.n / NPerBlock / kInterleave);
|
||||
dim3 block(BlockSize);
|
||||
int size = sizeof(float) * BlockSize / 32 * Batch * NPerBlock * kInterleave;
|
||||
weight_only_batched_gemv_wrapper<__nv_bfloat16, QType, WeightOnlyFlag, ActOp, Zero, Bias, NPerBlock, Batch,
|
||||
BlockSize><<<grid, block, size, stream>>>(params.qweight,
|
||||
reinterpret_cast<const __nv_bfloat16*>(params.scales),
|
||||
reinterpret_cast<const __nv_bfloat16*>(params.zeros), reinterpret_cast<const __nv_bfloat16*>(params.in),
|
||||
reinterpret_cast<const __nv_bfloat16*>(params.bias), reinterpret_cast<__nv_bfloat16*>(params.out),
|
||||
params.n, params.k);
|
||||
}
|
||||
#endif
|
||||
}
};
} // namespace kernels

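The run() method above sizes the launch from its template parameters: grid.x = n / NPerBlock / kInterleave, and the dynamic shared memory holds one float per warp for each of the Batch * NPerBlock * kInterleave partial results (matching the per-warp reduce described earlier). A standalone restatement of that arithmetic, with local stand-in names; the concrete kInterleave value comes from WeightOnlyDetails and is not reproduced in this diff:

struct GemvLaunchConfig
{
    int gridX;     // n / NPerBlock / kInterleave
    int smemBytes; // (BlockSize / 32) warps * Batch * NPerBlock * kInterleave floats
};

constexpr GemvLaunchConfig makeGemvLaunchConfig(int n, int nPerBlock, int interleave, int batch, int blockSize)
{
    return GemvLaunchConfig{n / nPerBlock / interleave,
        static_cast<int>(sizeof(float)) * blockSize / 32 * batch * nPerBlock * interleave};
}

// Example with assumed values: n = 4096, NPerBlock = 2, Batch = 1, BlockSize = 256 and an
// interleave of 2 gives grid.x = 1024 and 128 bytes of dynamic shared memory.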
@ -55,21 +55,24 @@ void select_zero_bias(const WeightOnlyParams& params, cudaStream_t stream)
|
||||
}
|
||||
|
||||
template <WeightOnlyQuantType QType, typename WeightOnlyFlag, int N_PER_BLOCK, int BATCH, int BLOCK_SIZE>
|
||||
void select_activation(WeightOnlyActivationType atype, const WeightOnlyParams& params, cudaStream_t stream)
|
||||
void select_activation(const WeightOnlyParams& params, cudaStream_t stream)
|
||||
{
|
||||
switch (atype)
|
||||
switch (params.act_func_type)
|
||||
{
|
||||
case WeightOnlyActivationType::Gelu:
|
||||
// Currently, activation function is not called in the plugin
|
||||
#if 0
|
||||
case WeightOnlyActivationFunctionType::Gelu:
|
||||
{
|
||||
select_zero_bias<QType, WeightOnlyFlag, GeluActivation, N_PER_BLOCK, BATCH, BLOCK_SIZE>(params, stream);
|
||||
break;
|
||||
}
|
||||
case WeightOnlyActivationType::Relu:
|
||||
case WeightOnlyActivationFunctionType::Relu:
|
||||
{
|
||||
select_zero_bias<QType, WeightOnlyFlag, ReluActivation, N_PER_BLOCK, BATCH, BLOCK_SIZE>(params, stream);
|
||||
break;
|
||||
}
|
||||
case WeightOnlyActivationType::Identity:
|
||||
#endif
|
||||
case WeightOnlyActivationFunctionType::Identity:
|
||||
{
|
||||
select_zero_bias<QType, WeightOnlyFlag, IdentityActivation, N_PER_BLOCK, BATCH, BLOCK_SIZE>(params, stream);
|
||||
break;
|
||||
@ -83,18 +86,15 @@ void select_activation(WeightOnlyActivationType atype, const WeightOnlyParams& p
|
||||
}
|
||||
|
||||
template <typename WeightOnlyFlag, int N_PER_BLOCK, int BATCH, int BLOCK_SIZE>
|
||||
void select_quant_type(
|
||||
WeightOnlyQuantType qtype, WeightOnlyActivationType atype, const WeightOnlyParams& params, cudaStream_t stream)
|
||||
void select_quant_type(const WeightOnlyParams& params, cudaStream_t stream)
|
||||
{
|
||||
if (qtype == WeightOnlyQuantType::Int4b)
|
||||
if (params.quant_type == WeightOnlyQuantType::Int4b)
|
||||
{
|
||||
select_activation<WeightOnlyQuantType::Int4b, WeightOnlyFlag, N_PER_BLOCK, BATCH, BLOCK_SIZE>(
|
||||
atype, params, stream);
|
||||
select_activation<WeightOnlyQuantType::Int4b, WeightOnlyFlag, N_PER_BLOCK, BATCH, BLOCK_SIZE>(params, stream);
|
||||
}
|
||||
else if (qtype == WeightOnlyQuantType::Int8b)
|
||||
else if (params.quant_type == WeightOnlyQuantType::Int8b)
|
||||
{
|
||||
select_activation<WeightOnlyQuantType::Int8b, WeightOnlyFlag, N_PER_BLOCK, BATCH, BLOCK_SIZE>(
|
||||
atype, params, stream);
|
||||
select_activation<WeightOnlyQuantType::Int8b, WeightOnlyFlag, N_PER_BLOCK, BATCH, BLOCK_SIZE>(params, stream);
|
||||
}
|
||||
else
|
||||
{
|
||||
@ -103,16 +103,15 @@ void select_quant_type(
|
||||
}
|
||||
|
||||
template <int N_PER_BLOCK, int BATCH, int BLOCK_SIZE>
|
||||
void select_groupwise_weight_only(WeightOnlyQuantType qtype, WeightOnlyType wtype, WeightOnlyActivationType atype,
|
||||
const WeightOnlyParams& params, cudaStream_t stream)
|
||||
void select_groupwise_weight_only(const WeightOnlyParams& params, cudaStream_t stream)
|
||||
{
|
||||
if (wtype == WeightOnlyType::GroupWise && params.group_size == 64)
|
||||
if (params.weight_only_type == WeightOnlyType::GroupWise && params.group_size == 64)
|
||||
{
|
||||
select_quant_type<WeightOnlyGroupWise<64>, N_PER_BLOCK, BATCH, BLOCK_SIZE>(qtype, atype, params, stream);
|
||||
select_quant_type<WeightOnlyGroupWise<64>, N_PER_BLOCK, BATCH, BLOCK_SIZE>(params, stream);
|
||||
}
|
||||
else if (wtype == WeightOnlyType::GroupWise && params.group_size == 128)
|
||||
else if (params.weight_only_type == WeightOnlyType::GroupWise && params.group_size == 128)
|
||||
{
|
||||
select_quant_type<WeightOnlyGroupWise<128>, N_PER_BLOCK, BATCH, BLOCK_SIZE>(qtype, atype, params, stream);
|
||||
select_quant_type<WeightOnlyGroupWise<128>, N_PER_BLOCK, BATCH, BLOCK_SIZE>(params, stream);
|
||||
}
|
||||
else
|
||||
{
|
||||
@ -120,33 +119,40 @@ void select_groupwise_weight_only(WeightOnlyQuantType qtype, WeightOnlyType wtyp
|
||||
}
|
||||
}
|
||||
|
||||
void weight_only_batched_gemv_launcher(WeightOnlyQuantType qtype, WeightOnlyType wtype, WeightOnlyActivationType atype,
|
||||
const WeightOnlyParams& params, cudaStream_t stream)
|
||||
void weight_only_batched_gemv_launcher(const WeightOnlyParams& params, cudaStream_t stream)
|
||||
{
|
||||
if (wtype == WeightOnlyType::PerChannel)
|
||||
assert(params.act_func_type == WeightOnlyActivationFunctionType::Identity);
|
||||
assert(params.weight_only_type == WeightOnlyType::GroupWise
|
||||
|| (params.weight_only_type == WeightOnlyType::PerChannel && params.bias == nullptr
|
||||
&& params.zeros == nullptr));
|
||||
if (params.weight_only_type == WeightOnlyType::PerChannel)
|
||||
{
|
||||
if (qtype == WeightOnlyQuantType::Int4b)
|
||||
if (params.quant_type == WeightOnlyQuantType::Int4b)
|
||||
{
|
||||
switch (params.m)
|
||||
{
|
||||
case 1:
|
||||
{
|
||||
select_activation<WeightOnlyQuantType::Int4b, WeightOnlyPerChannel, 1, 1, 192>(atype, params, stream);
|
||||
WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyPerChannel,
|
||||
IdentityActivation, false, false, 1, 1, 192>::run(params, stream);
|
||||
break;
|
||||
}
|
||||
case 2:
|
||||
{
|
||||
select_activation<WeightOnlyQuantType::Int4b, WeightOnlyPerChannel, 2, 2, 128>(atype, params, stream);
|
||||
WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyPerChannel,
|
IdentityActivation, false, false, 2, 2, 128>::run(params, stream);
break;
}
case 3:
{
select_activation<WeightOnlyQuantType::Int4b, WeightOnlyPerChannel, 2, 3, 256>(atype, params, stream);
WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyPerChannel,
IdentityActivation, false, false, 2, 3, 256>::run(params, stream);
break;
}
case 4:
{
select_activation<WeightOnlyQuantType::Int4b, WeightOnlyPerChannel, 4, 4, 256>(atype, params, stream);
WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyPerChannel,
IdentityActivation, false, false, 4, 4, 256>::run(params, stream);
break;
}
default:
@@ -156,28 +162,32 @@ void weight_only_batched_gemv_launcher(WeightOnlyQuantType qtype, WeightOnlyType
}
}
}
else if (qtype == WeightOnlyQuantType::Int8b)
else if (params.quant_type == WeightOnlyQuantType::Int8b)
{
switch (params.m)
{
case 1:
{
select_activation<WeightOnlyQuantType::Int8b, WeightOnlyPerChannel, 2, 1, 256>(atype, params, stream);
WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyPerChannel,
IdentityActivation, false, false, 2, 1, 256>::run(params, stream);
break;
}
case 2:
{
select_activation<WeightOnlyQuantType::Int8b, WeightOnlyPerChannel, 2, 2, 256>(atype, params, stream);
WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyPerChannel,
IdentityActivation, false, false, 2, 2, 256>::run(params, stream);
break;
}
case 3:
{
select_activation<WeightOnlyQuantType::Int8b, WeightOnlyPerChannel, 2, 3, 256>(atype, params, stream);
WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyPerChannel,
IdentityActivation, false, false, 2, 3, 256>::run(params, stream);
break;
}
case 4:
{
select_activation<WeightOnlyQuantType::Int8b, WeightOnlyPerChannel, 2, 4, 256>(atype, params, stream);
WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyPerChannel,
IdentityActivation, false, false, 2, 4, 256>::run(params, stream);
break;
}
default:
@@ -188,28 +198,28 @@ void weight_only_batched_gemv_launcher(WeightOnlyQuantType qtype, WeightOnlyType
}
}
}
else if (wtype == WeightOnlyType::GroupWise)
else if (params.weight_only_type == WeightOnlyType::GroupWise)
{
switch (params.m)
{
case 1:
{
select_groupwise_weight_only<2, 1, 256>(qtype, wtype, atype, params, stream);
select_groupwise_weight_only<2, 1, 256>(params, stream);
break;
}
case 2:
{
select_groupwise_weight_only<2, 2, 256>(qtype, wtype, atype, params, stream);
select_groupwise_weight_only<2, 2, 256>(params, stream);
break;
}
case 3:
{
select_groupwise_weight_only<2, 3, 128>(qtype, wtype, atype, params, stream);
select_groupwise_weight_only<2, 3, 128>(params, stream);
break;
}
case 4:
{
select_groupwise_weight_only<2, 4, 128>(qtype, wtype, atype, params, stream);
select_groupwise_weight_only<2, 4, 128>(params, stream);
break;
}
default:

@@ -20,7 +20,6 @@ namespace tensorrt_llm
{
namespace kernels
{
void weight_only_batched_gemv_launcher(WeightOnlyQuantType qtype, WeightOnlyType wtype, WeightOnlyActivationType atype,
const WeightOnlyParams& params, cudaStream_t stream);
void weight_only_batched_gemv_launcher(const WeightOnlyParams& params, cudaStream_t stream);
}
} // namespace tensorrt_llm

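As a reading aid for the dispatch above, a minimal sketch (not TensorRT-LLM code) of the mapping the Int8b per-channel branch encodes; the meaning of the three trailing integers is inferred from the template arguments, and the names used below are illustrative only.

#include <cstdio>

// Illustrative only: mirrors the switch on params.m in the Int8b / per-channel branch above.
// The three integers correspond to the last three template arguments of
// WeightOnlyBatchedGemvKernelLauncher (their exact parameter names are not shown in this diff).
struct TileConfig
{
    int a, b, block_size; // illustrative field names
};

static TileConfig int8PerChannelTile(int m)
{
    switch (m)
    {
    case 1: return {2, 1, 256};
    case 2: return {2, 2, 256};
    case 3: return {2, 3, 256};
    case 4: return {2, 4, 256};
    default: return {0, 0, 0}; // the default branch is elided by the hunk above
    }
}

int main()
{
    const TileConfig cfg = int8PerChannelTile(3);
    std::printf("%d %d %d\n", cfg.a, cfg.b, cfg.block_size); // prints "2 3 256"
    return 0;
}
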
@ -21,46 +21,9 @@ namespace tensorrt_llm
|
||||
namespace kernels
|
||||
{
|
||||
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyPerChannel, GeluActivation,
|
||||
true, true, 1, 1, 192>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyPerChannel, GeluActivation,
|
||||
true, false, 1, 1, 192>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyPerChannel, GeluActivation,
|
||||
false, true, 1, 1, 192>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyPerChannel, GeluActivation,
|
||||
false, false, 1, 1, 192>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyPerChannel, ReluActivation,
|
||||
true, true, 1, 1, 192>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyPerChannel, ReluActivation,
|
||||
true, false, 1, 1, 192>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyPerChannel, ReluActivation,
|
||||
false, true, 1, 1, 192>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyPerChannel, ReluActivation,
|
||||
false, false, 1, 1, 192>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyPerChannel,
|
||||
IdentityActivation, true, true, 1, 1, 192>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyPerChannel,
|
||||
IdentityActivation, true, false, 1, 1, 192>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyPerChannel,
|
||||
IdentityActivation, false, true, 1, 1, 192>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyPerChannel,
|
||||
IdentityActivation, false, false, 1, 1, 192>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyGroupWise<64>, GeluActivation,
|
||||
true, true, 2, 1, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyGroupWise<64>, GeluActivation,
|
||||
true, false, 2, 1, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyGroupWise<64>, GeluActivation,
|
||||
false, true, 2, 1, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyGroupWise<64>, GeluActivation,
|
||||
false, false, 2, 1, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyGroupWise<64>, ReluActivation,
|
||||
true, true, 2, 1, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyGroupWise<64>, ReluActivation,
|
||||
true, false, 2, 1, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyGroupWise<64>, ReluActivation,
|
||||
false, true, 2, 1, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyGroupWise<64>, ReluActivation,
|
||||
false, false, 2, 1, 256>;
|
||||
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyGroupWise<64>,
|
||||
IdentityActivation, true, true, 2, 1, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyGroupWise<64>,
|
||||
@ -69,22 +32,7 @@ template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b,
|
||||
IdentityActivation, false, true, 2, 1, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyGroupWise<64>,
|
||||
IdentityActivation, false, false, 2, 1, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyGroupWise<128>,
|
||||
GeluActivation, true, true, 2, 1, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyGroupWise<128>,
|
||||
GeluActivation, true, false, 2, 1, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyGroupWise<128>,
|
||||
GeluActivation, false, true, 2, 1, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyGroupWise<128>,
|
||||
GeluActivation, false, false, 2, 1, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyGroupWise<128>,
|
||||
ReluActivation, true, true, 2, 1, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyGroupWise<128>,
|
||||
ReluActivation, true, false, 2, 1, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyGroupWise<128>,
|
||||
ReluActivation, false, true, 2, 1, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyGroupWise<128>,
|
||||
ReluActivation, false, false, 2, 1, 256>;
|
||||
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyGroupWise<128>,
|
||||
IdentityActivation, true, true, 2, 1, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyGroupWise<128>,
|
||||
|
||||
@ -21,46 +21,9 @@ namespace tensorrt_llm
|
||||
namespace kernels
|
||||
{
|
||||
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyPerChannel, GeluActivation,
|
||||
true, true, 2, 1, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyPerChannel, GeluActivation,
|
||||
true, false, 2, 1, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyPerChannel, GeluActivation,
|
||||
false, true, 2, 1, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyPerChannel, GeluActivation,
|
||||
false, false, 2, 1, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyPerChannel, ReluActivation,
|
||||
true, true, 2, 1, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyPerChannel, ReluActivation,
|
||||
true, false, 2, 1, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyPerChannel, ReluActivation,
|
||||
false, true, 2, 1, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyPerChannel, ReluActivation,
|
||||
false, false, 2, 1, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyPerChannel,
|
||||
IdentityActivation, true, true, 2, 1, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyPerChannel,
|
||||
IdentityActivation, true, false, 2, 1, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyPerChannel,
|
||||
IdentityActivation, false, true, 2, 1, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyPerChannel,
|
||||
IdentityActivation, false, false, 2, 1, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyGroupWise<64>, GeluActivation,
|
||||
true, true, 2, 1, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyGroupWise<64>, GeluActivation,
|
||||
true, false, 2, 1, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyGroupWise<64>, GeluActivation,
|
||||
false, true, 2, 1, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyGroupWise<64>, GeluActivation,
|
||||
false, false, 2, 1, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyGroupWise<64>, ReluActivation,
|
||||
true, true, 2, 1, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyGroupWise<64>, ReluActivation,
|
||||
true, false, 2, 1, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyGroupWise<64>, ReluActivation,
|
||||
false, true, 2, 1, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyGroupWise<64>, ReluActivation,
|
||||
false, false, 2, 1, 256>;
|
||||
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyGroupWise<64>,
|
||||
IdentityActivation, true, true, 2, 1, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyGroupWise<64>,
|
||||
@ -69,22 +32,7 @@ template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b,
|
||||
IdentityActivation, false, true, 2, 1, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyGroupWise<64>,
|
||||
IdentityActivation, false, false, 2, 1, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyGroupWise<128>,
|
||||
GeluActivation, true, true, 2, 1, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyGroupWise<128>,
|
||||
GeluActivation, true, false, 2, 1, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyGroupWise<128>,
|
||||
GeluActivation, false, true, 2, 1, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyGroupWise<128>,
|
||||
GeluActivation, false, false, 2, 1, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyGroupWise<128>,
|
||||
ReluActivation, true, true, 2, 1, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyGroupWise<128>,
|
||||
ReluActivation, true, false, 2, 1, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyGroupWise<128>,
|
||||
ReluActivation, false, true, 2, 1, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyGroupWise<128>,
|
||||
ReluActivation, false, false, 2, 1, 256>;
|
||||
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyGroupWise<128>,
|
||||
IdentityActivation, true, true, 2, 1, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyGroupWise<128>,
|
||||
|
||||
@ -20,46 +20,9 @@ namespace tensorrt_llm
|
||||
namespace kernels
|
||||
{
|
||||
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyPerChannel, GeluActivation,
|
||||
true, true, 2, 2, 128>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyPerChannel, GeluActivation,
|
||||
true, false, 2, 2, 128>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyPerChannel, GeluActivation,
|
||||
false, true, 2, 2, 128>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyPerChannel, GeluActivation,
|
||||
false, false, 2, 2, 128>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyPerChannel, ReluActivation,
|
||||
true, true, 2, 2, 128>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyPerChannel, ReluActivation,
|
||||
true, false, 2, 2, 128>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyPerChannel, ReluActivation,
|
||||
false, true, 2, 2, 128>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyPerChannel, ReluActivation,
|
||||
false, false, 2, 2, 128>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyPerChannel,
|
||||
IdentityActivation, true, true, 2, 2, 128>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyPerChannel,
|
||||
IdentityActivation, true, false, 2, 2, 128>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyPerChannel,
|
||||
IdentityActivation, false, true, 2, 2, 128>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyPerChannel,
|
||||
IdentityActivation, false, false, 2, 2, 128>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyGroupWise<64>, GeluActivation,
|
||||
true, true, 2, 2, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyGroupWise<64>, GeluActivation,
|
||||
true, false, 2, 2, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyGroupWise<64>, GeluActivation,
|
||||
false, true, 2, 2, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyGroupWise<64>, GeluActivation,
|
||||
false, false, 2, 2, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyGroupWise<64>, ReluActivation,
|
||||
true, true, 2, 2, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyGroupWise<64>, ReluActivation,
|
||||
true, false, 2, 2, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyGroupWise<64>, ReluActivation,
|
||||
false, true, 2, 2, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyGroupWise<64>, ReluActivation,
|
||||
false, false, 2, 2, 256>;
|
||||
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyGroupWise<64>,
|
||||
IdentityActivation, true, true, 2, 2, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyGroupWise<64>,
|
||||
@ -68,22 +31,7 @@ template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b,
|
||||
IdentityActivation, false, true, 2, 2, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyGroupWise<64>,
|
||||
IdentityActivation, false, false, 2, 2, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyGroupWise<128>,
|
||||
GeluActivation, true, true, 2, 2, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyGroupWise<128>,
|
||||
GeluActivation, true, false, 2, 2, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyGroupWise<128>,
|
||||
GeluActivation, false, true, 2, 2, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyGroupWise<128>,
|
||||
GeluActivation, false, false, 2, 2, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyGroupWise<128>,
|
||||
ReluActivation, true, true, 2, 2, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyGroupWise<128>,
|
||||
ReluActivation, true, false, 2, 2, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyGroupWise<128>,
|
||||
ReluActivation, false, true, 2, 2, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyGroupWise<128>,
|
||||
ReluActivation, false, false, 2, 2, 256>;
|
||||
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyGroupWise<128>,
|
||||
IdentityActivation, true, true, 2, 2, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyGroupWise<128>,
|
||||
|
||||
@ -20,46 +20,9 @@ namespace tensorrt_llm
|
||||
namespace kernels
|
||||
{
|
||||
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyPerChannel, GeluActivation,
|
||||
true, true, 2, 2, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyPerChannel, GeluActivation,
|
||||
true, false, 2, 2, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyPerChannel, GeluActivation,
|
||||
false, true, 2, 2, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyPerChannel, GeluActivation,
|
||||
false, false, 2, 2, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyPerChannel, ReluActivation,
|
||||
true, true, 2, 2, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyPerChannel, ReluActivation,
|
||||
true, false, 2, 2, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyPerChannel, ReluActivation,
|
||||
false, true, 2, 2, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyPerChannel, ReluActivation,
|
||||
false, false, 2, 2, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyPerChannel,
|
||||
IdentityActivation, true, true, 2, 2, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyPerChannel,
|
||||
IdentityActivation, true, false, 2, 2, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyPerChannel,
|
||||
IdentityActivation, false, true, 2, 2, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyPerChannel,
|
||||
IdentityActivation, false, false, 2, 2, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyGroupWise<64>, GeluActivation,
|
||||
true, true, 2, 2, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyGroupWise<64>, GeluActivation,
|
||||
true, false, 2, 2, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyGroupWise<64>, GeluActivation,
|
||||
false, true, 2, 2, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyGroupWise<64>, GeluActivation,
|
||||
false, false, 2, 2, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyGroupWise<64>, ReluActivation,
|
||||
true, true, 2, 2, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyGroupWise<64>, ReluActivation,
|
||||
true, false, 2, 2, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyGroupWise<64>, ReluActivation,
|
||||
false, true, 2, 2, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyGroupWise<64>, ReluActivation,
|
||||
false, false, 2, 2, 256>;
|
||||
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyGroupWise<64>,
|
||||
IdentityActivation, true, true, 2, 2, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyGroupWise<64>,
|
||||
@ -68,22 +31,7 @@ template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b,
|
||||
IdentityActivation, false, true, 2, 2, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyGroupWise<64>,
|
||||
IdentityActivation, false, false, 2, 2, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyGroupWise<128>,
|
||||
GeluActivation, true, true, 2, 2, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyGroupWise<128>,
|
||||
GeluActivation, true, false, 2, 2, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyGroupWise<128>,
|
||||
GeluActivation, false, true, 2, 2, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyGroupWise<128>,
|
||||
GeluActivation, false, false, 2, 2, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyGroupWise<128>,
|
||||
ReluActivation, true, true, 2, 2, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyGroupWise<128>,
|
||||
ReluActivation, true, false, 2, 2, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyGroupWise<128>,
|
||||
ReluActivation, false, true, 2, 2, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyGroupWise<128>,
|
||||
ReluActivation, false, false, 2, 2, 256>;
|
||||
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyGroupWise<128>,
|
||||
IdentityActivation, true, true, 2, 2, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyGroupWise<128>,
|
||||
|
||||
@ -21,46 +21,9 @@ namespace tensorrt_llm
|
||||
namespace kernels
|
||||
{
|
||||
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyPerChannel, GeluActivation,
|
||||
true, true, 2, 3, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyPerChannel, GeluActivation,
|
||||
true, false, 2, 3, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyPerChannel, GeluActivation,
|
||||
false, true, 2, 3, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyPerChannel, GeluActivation,
|
||||
false, false, 2, 3, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyPerChannel, ReluActivation,
|
||||
true, true, 2, 3, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyPerChannel, ReluActivation,
|
||||
true, false, 2, 3, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyPerChannel, ReluActivation,
|
||||
false, true, 2, 3, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyPerChannel, ReluActivation,
|
||||
false, false, 2, 3, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyPerChannel,
|
||||
IdentityActivation, true, true, 2, 3, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyPerChannel,
|
||||
IdentityActivation, true, false, 2, 3, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyPerChannel,
|
||||
IdentityActivation, false, true, 2, 3, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyPerChannel,
|
||||
IdentityActivation, false, false, 2, 3, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyGroupWise<64>, GeluActivation,
|
||||
true, true, 2, 3, 128>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyGroupWise<64>, GeluActivation,
|
||||
true, false, 2, 3, 128>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyGroupWise<64>, GeluActivation,
|
||||
false, true, 2, 3, 128>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyGroupWise<64>, GeluActivation,
|
||||
false, false, 2, 3, 128>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyGroupWise<64>, ReluActivation,
|
||||
true, true, 2, 3, 128>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyGroupWise<64>, ReluActivation,
|
||||
true, false, 2, 3, 128>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyGroupWise<64>, ReluActivation,
|
||||
false, true, 2, 3, 128>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyGroupWise<64>, ReluActivation,
|
||||
false, false, 2, 3, 128>;
|
||||
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyGroupWise<64>,
|
||||
IdentityActivation, true, true, 2, 3, 128>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyGroupWise<64>,
|
||||
@ -69,22 +32,7 @@ template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b,
|
||||
IdentityActivation, false, true, 2, 3, 128>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyGroupWise<64>,
|
||||
IdentityActivation, false, false, 2, 3, 128>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyGroupWise<128>,
|
||||
GeluActivation, true, true, 2, 3, 128>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyGroupWise<128>,
|
||||
GeluActivation, true, false, 2, 3, 128>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyGroupWise<128>,
|
||||
GeluActivation, false, true, 2, 3, 128>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyGroupWise<128>,
|
||||
GeluActivation, false, false, 2, 3, 128>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyGroupWise<128>,
|
||||
ReluActivation, true, true, 2, 3, 128>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyGroupWise<128>,
|
||||
ReluActivation, true, false, 2, 3, 128>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyGroupWise<128>,
|
||||
ReluActivation, false, true, 2, 3, 128>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyGroupWise<128>,
|
||||
ReluActivation, false, false, 2, 3, 128>;
|
||||
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyGroupWise<128>,
|
||||
IdentityActivation, true, true, 2, 3, 128>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyGroupWise<128>,
|
||||
|
||||
@ -21,46 +21,9 @@ namespace tensorrt_llm
|
||||
namespace kernels
|
||||
{
|
||||
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyPerChannel, GeluActivation,
|
||||
true, true, 2, 3, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyPerChannel, GeluActivation,
|
||||
true, false, 2, 3, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyPerChannel, GeluActivation,
|
||||
false, true, 2, 3, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyPerChannel, GeluActivation,
|
||||
false, false, 2, 3, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyPerChannel, ReluActivation,
|
||||
true, true, 2, 3, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyPerChannel, ReluActivation,
|
||||
true, false, 2, 3, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyPerChannel, ReluActivation,
|
||||
false, true, 2, 3, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyPerChannel, ReluActivation,
|
||||
false, false, 2, 3, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyPerChannel,
|
||||
IdentityActivation, true, true, 2, 3, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyPerChannel,
|
||||
IdentityActivation, true, false, 2, 3, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyPerChannel,
|
||||
IdentityActivation, false, true, 2, 3, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyPerChannel,
|
||||
IdentityActivation, false, false, 2, 3, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyGroupWise<64>, GeluActivation,
|
||||
true, true, 2, 3, 128>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyGroupWise<64>, GeluActivation,
|
||||
true, false, 2, 3, 128>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyGroupWise<64>, GeluActivation,
|
||||
false, true, 2, 3, 128>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyGroupWise<64>, GeluActivation,
|
||||
false, false, 2, 3, 128>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyGroupWise<64>, ReluActivation,
|
||||
true, true, 2, 3, 128>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyGroupWise<64>, ReluActivation,
|
||||
true, false, 2, 3, 128>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyGroupWise<64>, ReluActivation,
|
||||
false, true, 2, 3, 128>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyGroupWise<64>, ReluActivation,
|
||||
false, false, 2, 3, 128>;
|
||||
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyGroupWise<64>,
|
||||
IdentityActivation, true, true, 2, 3, 128>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyGroupWise<64>,
|
||||
@ -69,22 +32,7 @@ template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b,
|
||||
IdentityActivation, false, true, 2, 3, 128>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyGroupWise<64>,
|
||||
IdentityActivation, false, false, 2, 3, 128>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyGroupWise<128>,
|
||||
GeluActivation, true, true, 2, 3, 128>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyGroupWise<128>,
|
||||
GeluActivation, true, false, 2, 3, 128>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyGroupWise<128>,
|
||||
GeluActivation, false, true, 2, 3, 128>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyGroupWise<128>,
|
||||
GeluActivation, false, false, 2, 3, 128>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyGroupWise<128>,
|
||||
ReluActivation, true, true, 2, 3, 128>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyGroupWise<128>,
|
||||
ReluActivation, true, false, 2, 3, 128>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyGroupWise<128>,
|
||||
ReluActivation, false, true, 2, 3, 128>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyGroupWise<128>,
|
||||
ReluActivation, false, false, 2, 3, 128>;
|
||||
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyGroupWise<128>,
|
||||
IdentityActivation, true, true, 2, 3, 128>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyGroupWise<128>,
|
||||
|
||||
@ -20,46 +20,9 @@ namespace tensorrt_llm
|
||||
namespace kernels
|
||||
{
|
||||
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyPerChannel, GeluActivation,
|
||||
true, true, 4, 4, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyPerChannel, GeluActivation,
|
||||
true, false, 4, 4, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyPerChannel, GeluActivation,
|
||||
false, true, 4, 4, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyPerChannel, GeluActivation,
|
||||
false, false, 4, 4, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyPerChannel, ReluActivation,
|
||||
true, true, 4, 4, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyPerChannel, ReluActivation,
|
||||
true, false, 4, 4, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyPerChannel, ReluActivation,
|
||||
false, true, 4, 4, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyPerChannel, ReluActivation,
|
||||
false, false, 4, 4, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyPerChannel,
|
||||
IdentityActivation, true, true, 4, 4, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyPerChannel,
|
||||
IdentityActivation, true, false, 4, 4, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyPerChannel,
|
||||
IdentityActivation, false, true, 4, 4, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyPerChannel,
|
||||
IdentityActivation, false, false, 4, 4, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyGroupWise<64>, GeluActivation,
|
||||
true, true, 2, 4, 128>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyGroupWise<64>, GeluActivation,
|
||||
true, false, 2, 4, 128>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyGroupWise<64>, GeluActivation,
|
||||
false, true, 2, 4, 128>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyGroupWise<64>, GeluActivation,
|
||||
false, false, 2, 4, 128>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyGroupWise<64>, ReluActivation,
|
||||
true, true, 2, 4, 128>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyGroupWise<64>, ReluActivation,
|
||||
true, false, 2, 4, 128>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyGroupWise<64>, ReluActivation,
|
||||
false, true, 2, 4, 128>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyGroupWise<64>, ReluActivation,
|
||||
false, false, 2, 4, 128>;
|
||||
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyGroupWise<64>,
|
||||
IdentityActivation, true, true, 2, 4, 128>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyGroupWise<64>,
|
||||
@ -68,22 +31,7 @@ template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b,
|
||||
IdentityActivation, false, true, 2, 4, 128>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyGroupWise<64>,
|
||||
IdentityActivation, false, false, 2, 4, 128>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyGroupWise<128>,
|
||||
GeluActivation, true, true, 2, 4, 128>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyGroupWise<128>,
|
||||
GeluActivation, true, false, 2, 4, 128>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyGroupWise<128>,
|
||||
GeluActivation, false, true, 2, 4, 128>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyGroupWise<128>,
|
||||
GeluActivation, false, false, 2, 4, 128>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyGroupWise<128>,
|
||||
ReluActivation, true, true, 2, 4, 128>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyGroupWise<128>,
|
||||
ReluActivation, true, false, 2, 4, 128>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyGroupWise<128>,
|
||||
ReluActivation, false, true, 2, 4, 128>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyGroupWise<128>,
|
||||
ReluActivation, false, false, 2, 4, 128>;
|
||||
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyGroupWise<128>,
|
||||
IdentityActivation, true, true, 2, 4, 128>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyGroupWise<128>,
|
||||
|
||||
@ -21,46 +21,9 @@ namespace tensorrt_llm
|
||||
namespace kernels
|
||||
{
|
||||
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyPerChannel, GeluActivation,
|
||||
true, true, 2, 4, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyPerChannel, GeluActivation,
|
||||
true, false, 2, 4, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyPerChannel, GeluActivation,
|
||||
false, true, 2, 4, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyPerChannel, GeluActivation,
|
||||
false, false, 2, 4, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyPerChannel, ReluActivation,
|
||||
true, true, 2, 4, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyPerChannel, ReluActivation,
|
||||
true, false, 2, 4, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyPerChannel, ReluActivation,
|
||||
false, true, 2, 4, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyPerChannel, ReluActivation,
|
||||
false, false, 2, 4, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyPerChannel,
|
||||
IdentityActivation, true, true, 2, 4, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyPerChannel,
|
||||
IdentityActivation, true, false, 2, 4, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyPerChannel,
|
||||
IdentityActivation, false, true, 2, 4, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyPerChannel,
|
||||
IdentityActivation, false, false, 2, 4, 256>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyGroupWise<64>, GeluActivation,
|
||||
true, true, 2, 4, 128>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyGroupWise<64>, GeluActivation,
|
||||
true, false, 2, 4, 128>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyGroupWise<64>, GeluActivation,
|
||||
false, true, 2, 4, 128>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyGroupWise<64>, GeluActivation,
|
||||
false, false, 2, 4, 128>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyGroupWise<64>, ReluActivation,
|
||||
true, true, 2, 4, 128>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyGroupWise<64>, ReluActivation,
|
||||
true, false, 2, 4, 128>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyGroupWise<64>, ReluActivation,
|
||||
false, true, 2, 4, 128>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyGroupWise<64>, ReluActivation,
|
||||
false, false, 2, 4, 128>;
|
||||
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyGroupWise<64>,
|
||||
IdentityActivation, true, true, 2, 4, 128>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyGroupWise<64>,
|
||||
@ -69,22 +32,7 @@ template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b,
|
||||
IdentityActivation, false, true, 2, 4, 128>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyGroupWise<64>,
|
||||
IdentityActivation, false, false, 2, 4, 128>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyGroupWise<128>,
|
||||
GeluActivation, true, true, 2, 4, 128>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyGroupWise<128>,
|
||||
GeluActivation, true, false, 2, 4, 128>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyGroupWise<128>,
|
||||
GeluActivation, false, true, 2, 4, 128>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyGroupWise<128>,
|
||||
GeluActivation, false, false, 2, 4, 128>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyGroupWise<128>,
|
||||
ReluActivation, true, true, 2, 4, 128>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyGroupWise<128>,
|
||||
ReluActivation, true, false, 2, 4, 128>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyGroupWise<128>,
|
||||
ReluActivation, false, true, 2, 4, 128>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyGroupWise<128>,
|
||||
ReluActivation, false, false, 2, 4, 128>;
|
||||
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyGroupWise<128>,
|
||||
IdentityActivation, true, true, 2, 4, 128>;
|
||||
template struct WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyGroupWise<128>,
|
||||
|
||||
@@ -29,39 +29,40 @@ namespace layers

__global__ void update_indir_cache_kernel(int* tgt_indir_cache, const int* src_indir_cache, const int** parent_ids,
const bool* finished, const int* sequence_lengths, const int* input_lengths, int batch_dim, int local_batch_size,
int beam_width, int max_seq_len)
int beam_width, int max_kv_cache_length, int max_seq_len)
{
int time_step = threadIdx.x + blockIdx.x * blockDim.x;
int bb_id = threadIdx.y + blockIdx.y * blockDim.y;
const int current_step{sequence_lengths[bb_id] - 1}; // the sequence_lengths is updated, need to minus 1
const int input_length{input_lengths == nullptr ? 0 : input_lengths[bb_id]};
const int batch_id = bb_id / beam_width;
const int beam_id = bb_id % beam_width;
if (bb_id >= beam_width * local_batch_size || time_step < input_length || finished[bb_id])
if (bb_id >= beam_width * local_batch_size || time_step < (max_seq_len - max_kv_cache_length) || finished[bb_id])
{
return;
}
int time_step_circ = time_step % max_seq_len;
// FIXME: we will remove all paddings later (@boyang)
// Skip input paddings when updating the indir cache table.
int time_step_circ = time_step % max_kv_cache_length;

// for the parent_ids, we will still keep it for all past tokens (i.e. max_seq_len)
const int src_beam = parent_ids[batch_id][beam_id * max_seq_len + current_step];

const uint32_t tgt_offset = batch_id * beam_width * max_seq_len + beam_id * max_seq_len + time_step_circ;
const uint32_t src_offset = batch_id * beam_width * max_seq_len + src_beam * max_seq_len + time_step_circ;
// for the indir tables, we have the cyclic kv cache.
const uint tgt_offset
= batch_id * beam_width * max_kv_cache_length + beam_id * max_kv_cache_length + time_step_circ;
const uint src_offset
= batch_id * beam_width * max_kv_cache_length + src_beam * max_kv_cache_length + time_step_circ;

tgt_indir_cache[tgt_offset] = (time_step == current_step) ? beam_id : src_indir_cache[src_offset];
}

void update_indir_cache_kernelLauncher(int* tgt_indir_cache, const int* src_indir_cache, const int** parent_ids,
const bool* finished, const int* sequence_lengths, const int* input_lengths, int batch_dim, int local_batch_size,
int beam_width, int max_seq_len, cudaStream_t stream)
int beam_width, int max_seq_len, int max_kv_cache_length, cudaStream_t stream)
{
const dim3 block(32);
// Update indirections steps [input_length[bb_id], sequence_lengths[bb_id]], included
const dim3 grid((max_seq_len + block.x - 1) / block.x, local_batch_size * beam_width);
update_indir_cache_kernel<<<grid, block, 0, stream>>>(tgt_indir_cache, src_indir_cache, parent_ids, finished,
sequence_lengths, input_lengths, batch_dim, local_batch_size, beam_width, max_seq_len);
sequence_lengths, input_lengths, batch_dim, local_batch_size, beam_width, max_kv_cache_length, max_seq_len);
}
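The indexing in the kernel above can be restated as standalone code. A minimal sketch, assuming only what the diff shows: the indirection table is cyclic over the last max_kv_cache_length steps, while parent_ids still spans the full max_seq_len history; this is illustrative, not library code.

#include <cstdio>

static int indir_offset(int batch_id, int beam_id, int beam_width, int max_kv_cache_length, int time_step)
{
    // Same formula as tgt_offset / src_offset above: a cyclic slot per (batch, beam).
    const int time_step_circ = time_step % max_kv_cache_length;
    return batch_id * beam_width * max_kv_cache_length + beam_id * max_kv_cache_length + time_step_circ;
}

int main()
{
    // With beam_width = 4 and a cache of 8 slots, steps 3 and 11 map to the same slot.
    std::printf("%d %d\n", indir_offset(0, 1, 4, 8, 3), indir_offset(0, 1, 4, 8, 11)); // prints "11 11"
    return 0;
}
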

template <typename T>
@@ -201,7 +202,8 @@ void BaseBeamSearchLayer<T>::forward(BeamSearchOutputParams& outputs, ForwardPar
update_indir_cache_kernelLauncher(outputs.tgt_cache_indirection.template getPtr<int>(),
params.src_cache_indirection.template getPtr<const int>(),
outputs.parent_ids_ptr.template getPtr<const int*>(), outputs.finished->template getPtr<const bool>(),
sequence_length, input_lengths, batch_size, local_batch_size, beam_width, max_seq_len, stream_);
sequence_length, input_lengths, batch_size, local_batch_size, beam_width, max_seq_len,
params.max_kv_cache_length, stream_);
sync_check_cuda_error();
}
sync_check_cuda_error();

@@ -54,15 +54,17 @@ public:
class ForwardParams : public SoftmaxParams
{
public:
ForwardParams(
int step, int ite, tc::Tensor logits, tc::Tensor endIds, tc::Tensor src_cache_indirection, int max_seq_len)
ForwardParams(int step, int ite, tc::Tensor logits, tc::Tensor endIds, tc::Tensor src_cache_indirection,
int max_kv_cache_length, int max_seq_len)
: SoftmaxParams(step, ite, std::move(logits), std::move(endIds))
, src_cache_indirection{std::move(src_cache_indirection)}
, max_kv_cache_length{max_kv_cache_length}
, max_seq_len{max_seq_len}
{
}

// mandatory parameters
int max_kv_cache_length;
int max_seq_len;
tc::Tensor src_cache_indirection; // [local_batch_size, beam_width, max_seq_len]


@@ -295,7 +295,8 @@ void DynamicDecodeLayer<T>::forward(OutputParams& outputs, ForwardParams const&
auto const end_id_offset
= end_ids.slice({dynamic_decode_batch_size}, dynamic_ite * dynamic_decode_batch_size);
typename BaseBeamSearchLayer<T>::ForwardParams dynamic_decode_input_tensors{step, ite, logits_offset,
end_id_offset, *params.src_cache_indirection, static_cast<std::int32_t>(max_seq_len)};
end_id_offset, *params.src_cache_indirection, static_cast<std::int32_t>(params.max_kv_cache_length),
static_cast<std::int32_t>(max_seq_len)};

dynamic_decode_input_tensors.embedding_bias = params.embedding_bias;


@@ -79,10 +79,12 @@ public:
class ForwardParams
{
public:
ForwardParams(int step, int ite, int maxInputLength, int localBatchSize, tc::Tensor logits, tc::Tensor endIds)
ForwardParams(int step, int ite, int maxInputLength, int maxKvCacheLength, int localBatchSize,
tc::Tensor logits, tc::Tensor endIds)
: step{step}
, ite{ite}
, max_input_length{maxInputLength}
, max_kv_cache_length{maxKvCacheLength}
, local_batch_size{localBatchSize}
, logits{std::move(logits)}
, end_ids{std::move(endIds)}
@@ -93,6 +95,7 @@ public:
int step;
int ite;
int max_input_length;
int max_kv_cache_length;
int local_batch_size;
tc::Tensor logits; // [batch_size, beam_width, vocab_size_padded], on gpu
tc::Tensor end_ids; // [batch_size], on gpu

@@ -252,7 +252,7 @@ int BertAttentionPlugin::enqueueImpl(const nvinfer1::PluginTensorDesc* inputDesc
if (mEnableContextFMHA && !mRelativeAttention)
{
// b, max_seqlen, actual_total_seqlen
mFMHARunner->setup(request_batch_size, request_seq_len, request_batch_size * request_seq_len);
mFMHARunner->setup(request_batch_size, request_seq_len, request_seq_len, request_batch_size * request_seq_len);
mFMHARunner->run(const_cast<T*>(attention_input), cu_seqlens, context_buf_, stream);
}
else

@@ -83,7 +83,8 @@ struct FusedQKVMaskedAttentionDispatchParams
float rotary_embedding_scale;
int rotary_embedding_max_positions;
PositionEmbeddingType position_embedding_type;
int max_seq_len;
int max_kv_cache_length;
int cyclic_kv_cache_length;
const int* input_lengths;
int step;
float q_scaling;
@@ -157,7 +158,8 @@ void fusedQKV_masked_attention_dispatch(Multihead_attention_params<T_MMHA, CROSS
params.cache_indir = input_params.cache_indir;
params.batch_size = input_params.inference_batch_size;
params.beam_width = input_params.beam_width;
params.memory_max_len = input_params.max_seq_len;
params.max_kv_cache_length = input_params.max_kv_cache_length;
params.cyclic_kv_cache_length = input_params.cyclic_kv_cache_length;
params.length_per_sample = input_params.sequence_lengths; // max_input_length + current output length
// timestep for shared memory size calculation and rotary embedding computation
params.timestep = input_params.step - 1;
@@ -267,10 +269,13 @@ GPTAttentionPluginCommon::GPTAttentionPluginCommon(int num_heads, int num_kv_hea
, mCrossAttention(cross_attention)
, mMaxDistance(max_distance)
{
mEnableContextFMHA = mEnableContextFMHA && (mType == DataType::kHALF || mType == DataType::kBF16);
// pre-check whether FMHA is supported in order to save memory allocation
mEnableContextFMHA = mEnableContextFMHA && (mType == DataType::kHALF || mType == DataType::kBF16)
&& MHARunner::fmha_supported(getHeadSize(), mSM);

TLLM_CHECK(isRoPE() == (rotary_embedding_dim != 0));
TLLM_CHECK_WITH_INFO((tc::getSMVersion() >= 80) || (mType != DataType::kBF16),
"Unsupported data type, pre SM 80 GPUs do not support bfloat16");
TLLM_CHECK_WITH_INFO(
(mSM >= 80) || (mType != DataType::kBF16), "Unsupported data type, pre SM 80 GPUs do not support bfloat16");
}

const int GPTAttentionPluginCommon::getHeadSize(bool checkInit) const
@@ -318,8 +323,8 @@ GPTAttentionPluginCommon::GPTAttentionPluginCommon(const void* data, size_t leng
mKVCacheQuantMode = tc::QuantMode(kvCacheQuantMode);

TLLM_CHECK(d == a + length);
TLLM_CHECK_WITH_INFO((tc::getSMVersion() >= 80) || (mType != DataType::kBF16),
"Unsupported data type, pre SM 80 GPUs do not support bfloat16");
TLLM_CHECK_WITH_INFO(
(mSM >= 80) || (mType != DataType::kBF16), "Unsupported data type, pre SM 80 GPUs do not support bfloat16");
}

size_t GPTAttentionPluginCommon::getWorkspaceSizeForContext(
@@ -380,12 +385,12 @@ size_t GPTAttentionPluginCommon::getWorkspaceSizeForGeneration(DataType type, in
size_t generation_workspace_size = 0;

const int batch_beam = total_num_seq;
int32_t const maxSeqLenTile = getMaxSeqLenTile(size);
int32_t const maxSeqLenTile = getMaxNumSeqLenTile();

const size_t partial_out_size = mMultiBlockMode ? size * batch_beam * mNumHeads * mHeadSize * maxSeqLenTile : 0;
const size_t partial_sum_size = mMultiBlockMode ? sizeof(float) * batch_beam * mNumHeads * maxSeqLenTile : 0;
const size_t partial_max_size = mMultiBlockMode ? sizeof(float) * batch_beam * mNumHeads * maxSeqLenTile : 0;
const size_t block_counter_size = mMultiBlockMode ? sizeof(int) * batch_beam * mNumHeads : 0;
const size_t partial_out_size = size * batch_beam * mNumHeads * mHeadSize * maxSeqLenTile;
const size_t partial_sum_size = sizeof(float) * batch_beam * mNumHeads * maxSeqLenTile;
const size_t partial_max_size = sizeof(float) * batch_beam * mNumHeads * maxSeqLenTile;
const size_t block_counter_size = sizeof(int) * batch_beam * mNumHeads;

const int NUM_BUFFERS = 4;
size_t workspaces[NUM_BUFFERS];
@@ -397,18 +402,13 @@ size_t GPTAttentionPluginCommon::getWorkspaceSizeForGeneration(DataType type, in
return generation_workspace_size;
}

int GPTAttentionPluginCommon::getMaxSeqLenTile(int elemSize) const
int GPTAttentionPluginCommon::getMaxNumSeqLenTile(int batch_beam_size) const
{
if (mMultiBlockMode)
{
const int threads_per_value = pow2roundup(getHeadSize()) * elemSize / 16;

// max_seq_len_tile to make sure: seq_len_tile * threads_per_value <= threads_per_block (for
// multi_block_mode)
const int max_seq_len_tile
= 256 / threads_per_value; // for allocate partial output results memory. Regardless to THDS_PER_BLOCK
// (which may be smaller than 256 like being 128)
return max_seq_len_tile;
// And we allocate the buffer based on the maximum number of blocks per sequence (batch_beam_size = 1).
// Assume we can only have 1 block (large block size like 1024) in SM, and we only want one wave of blocks.
return tc::divUp(mMultiProcessorCount, batch_beam_size * mNumHeads);
}
return 0;
}
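The buffer sizing above can be checked with a small standalone calculation. A sketch under stated assumptions (not library code): the no-argument call to getMaxNumSeqLenTile() is assumed to size the buffers for batch_beam == 1, per the comment above, and the element size, head count, head size and SM count below are example values only.

#include <cstddef>
#include <cstdio>

static size_t divUp(size_t a, size_t b) { return (a + b - 1) / b; } // same rounding as tc::divUp

int main()
{
    const size_t elem_size = 2;  // example: fp16 partial outputs
    const size_t batch_beam = 1; // worst case used for allocation, per the comment above
    const size_t num_heads = 32, head_size = 128, num_sms = 108; // example GPU with 108 SMs

    // One wave of thread blocks across the SMs, split over (batch_beam * num_heads) sequences/heads.
    const size_t max_num_seq_len_tiles = divUp(num_sms, batch_beam * num_heads); // 4 here

    const size_t partial_out   = elem_size * batch_beam * num_heads * head_size * max_num_seq_len_tiles;
    const size_t partial_sum   = sizeof(float) * batch_beam * num_heads * max_num_seq_len_tiles;
    const size_t partial_max   = sizeof(float) * batch_beam * num_heads * max_num_seq_len_tiles;
    const size_t block_counter = sizeof(int) * batch_beam * num_heads;

    std::printf("%zu %zu %zu %zu\n", partial_out, partial_sum, partial_max, block_counter);
    return 0;
}
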
||||
@ -439,7 +439,8 @@ int GPTAttentionPluginCommon::enqueueContext(const EnqueueContextParams<T, KVCac
|
||||
{
|
||||
using BufferDataType = typename KVCacheBufferDataType<KVCacheBuffer>::Type;
|
||||
kv_cache_buffer = KVCacheBuffer(params.batch_size, 1,
|
||||
isCrossAttention() ? params.cross_qkv_length : params.max_seq_length, num_kv_heads * head_size * elem_size);
|
||||
isCrossAttention() ? params.cross_qkv_length : params.max_kv_cache_length,
|
||||
num_kv_heads * head_size * elem_size);
|
||||
kv_cache_buffer.data = reinterpret_cast<BufferDataType*>(params.key_value_cache);
|
||||
}
|
||||
|
||||
@ -524,6 +525,7 @@ int GPTAttentionPluginCommon::enqueueContext(const EnqueueContextParams<T, KVCac
|
||||
decoder_params.seqLengths = params.context_lengths;
|
||||
decoder_params.batchSize = params.batch_size;
|
||||
decoder_params.maxSeqLength = params.input_seq_length;
|
||||
decoder_params.maxKvCacheLength = params.cyclic_kv_cache_length;
|
||||
decoder_params.numTokens = params.num_tokens;
|
||||
decoder_params.attentionMaskType = mMaskType;
|
||||
invokeBuildDecoderInfo(decoder_params, stream);
|
||||
@ -563,12 +565,14 @@ int GPTAttentionPluginCommon::enqueueContext(const EnqueueContextParams<T, KVCac
|
||||
{
|
||||
invokeApplyBiasRopeUpdateKVCache(const_cast<T*>(params.attention_input), kv_cache_buffer,
|
||||
const_cast<T*>(params.qkv_bias), params.context_lengths, mRemovePadding ? padding_offset : nullptr,
|
||||
params.batch_size, params.input_seq_length, params.num_tokens, mNumHeads, mNumKVHeads, getHeadSize(),
|
||||
mRotaryEmbeddingDim, mRotaryEmbeddingBase, mRotaryEmbeddingScaleType, mRotaryEmbeddingScale,
|
||||
mRotaryEmbeddingMaxPositions, position_embedding_type, (float*) nullptr, 0, cache_type,
|
||||
params.kv_scale_orig_quant, stream);
|
||||
mFMHARunner->setup(params.batch_size, params.input_seq_length, params.num_tokens, isALiBi(), isAliBiWithScale(),
|
||||
mTpSize, mTpRank);
|
||||
params.batch_size, params.input_seq_length, params.cyclic_kv_cache_length, params.num_tokens, mNumHeads,
|
||||
mNumKVHeads, getHeadSize(), mRotaryEmbeddingDim, mRotaryEmbeddingBase, mRotaryEmbeddingScaleType,
|
||||
mRotaryEmbeddingScale, mRotaryEmbeddingMaxPositions, position_embedding_type, (float*) nullptr, 0,
|
||||
cache_type, params.kv_scale_orig_quant, stream);
|
||||
// We apply the limited_length_causal mask when max_past_length (cyclic_kv_cache_length) is set.
// Each token attends to the previous tokens starting from max(0, rowIdx - cyclic_kv_cache_length).
|
||||
mFMHARunner->setup(params.batch_size, params.input_seq_length, params.cyclic_kv_cache_length, params.num_tokens,
|
||||
isALiBi(), isAliBiWithScale(), mTpSize, mTpRank);
|
||||
mFMHARunner->run(const_cast<T*>(params.attention_input), cu_seqlens, params.context_buf, stream);
|
||||
}
|
||||
else
|
||||
@ -611,7 +615,7 @@ int GPTAttentionPluginCommon::enqueueContext(const EnqueueContextParams<T, KVCac
|
||||
// write KV to cache
|
||||
invokeTranspose4dBatchMajor(k_buf_2_, v_buf_2_, kv_cache_buffer, params.batch_size,
|
||||
isCrossAttention() ? params.cross_qkv_length : params.input_seq_length,
|
||||
isCrossAttention() ? params.cross_qkv_length : params.max_seq_length, getHeadSize(), mNumKVHeads,
|
||||
isCrossAttention() ? params.cross_qkv_length : params.cyclic_kv_cache_length, getHeadSize(), mNumKVHeads,
|
||||
cache_type, params.kv_scale_orig_quant,
|
||||
isCrossAttention() ? params.encoder_input_lengths : params.context_lengths, stream);
|
||||
sync_check_cuda_error();
|
||||
@ -704,8 +708,9 @@ int GPTAttentionPluginCommon::enqueueContext(const EnqueueContextParams<T, KVCac
|
||||
// In implicit mode, relative_attention_bias is relative_attention_table [num_heads, num_buckets], with
|
||||
// necessary params (max_distance, num_buckets) passed at the end
|
||||
invokeAddRelativeAttentionBiasUnaligned(qk_buf_float_, relative_attention_bias, params.batch_size,
|
||||
mNumHeads, attention_seq_len_1, isCrossAttention() ? params.cross_qkv_length : params.max_seq_length,
|
||||
stream, max_distance > 0, relative_attention_bias_stride, max_distance, true /* bidirectional */);
|
||||
mNumHeads, attention_seq_len_1,
|
||||
isCrossAttention() ? params.cross_qkv_length : params.cyclic_kv_cache_length, stream, max_distance > 0,
|
||||
relative_attention_bias_stride, max_distance, true /* bidirectional */);
|
||||
}
|
||||
|
||||
if (is_qk_buf_float_ == true)
|
||||
@ -808,6 +813,7 @@ int GPTAttentionPluginCommon::enqueueContext(const EnqueueContextParams<T, KVCac
|
||||
stream);
|
||||
}
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
@ -865,19 +871,23 @@ int GPTAttentionPluginCommon::enqueueGeneration(
|
||||
|
||||
int8_t* workspace_byte_ptr = reinterpret_cast<int8_t*>(params.workspace);
|
||||
size_t offset = 0;
|
||||
int32_t const maxSeqLenTile = getMaxSeqLenTile(sizeof(T));
|
||||
// Runtime check to see the actual number of blocks per sequence we need.
|
||||
int32_t const max_num_seq_len_tiles = getMaxNumSeqLenTile(batch_beam);
|
||||
const bool enable_multi_block = mMultiBlockMode && max_num_seq_len_tiles > 1;
|
||||
const size_t partial_out_size
|
||||
= mMultiBlockMode ? sizeof(T) * batch_beam * mNumHeads * mHeadSize * maxSeqLenTile : 0;
|
||||
const size_t partial_sum_size = mMultiBlockMode ? sizeof(float) * batch_beam * mNumHeads * maxSeqLenTile : 0;
|
||||
const size_t partial_max_size = mMultiBlockMode ? sizeof(float) * batch_beam * mNumHeads * maxSeqLenTile : 0;
|
||||
const size_t block_counter_size = mMultiBlockMode ? sizeof(int) * batch_beam * mNumHeads : 0;
|
||||
= enable_multi_block ? sizeof(T) * batch_beam * mNumHeads * mHeadSize * max_num_seq_len_tiles : 0;
|
||||
const size_t partial_sum_size
|
||||
= enable_multi_block ? sizeof(float) * batch_beam * mNumHeads * max_num_seq_len_tiles : 0;
|
||||
const size_t partial_max_size
|
||||
= enable_multi_block ? sizeof(float) * batch_beam * mNumHeads * max_num_seq_len_tiles : 0;
|
||||
const size_t block_counter_size = enable_multi_block ? sizeof(int) * batch_beam * mNumHeads : 0;
|
||||
|
||||
// Workspace pointer shift
|
||||
T* partial_out = reinterpret_cast<T*>(nextWorkspacePtr(workspace_byte_ptr, offset, partial_out_size));
|
||||
float* partial_sum = reinterpret_cast<float*>(nextWorkspacePtr(workspace_byte_ptr, offset, partial_sum_size));
|
||||
float* partial_max = reinterpret_cast<float*>(nextWorkspacePtr(workspace_byte_ptr, offset, partial_max_size));
|
||||
int* block_counter = reinterpret_cast<int*>(nextWorkspacePtr(workspace_byte_ptr, offset, block_counter_size));
|
||||
if (mMultiBlockMode)
|
||||
if (enable_multi_block)
|
||||
{
|
||||
TLLM_CUDA_CHECK(cudaMemsetAsync(block_counter, 0, block_counter_size, stream));
|
||||
}
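The "workspace pointer shift" above slices four sub-buffers out of the single workspace returned by getWorkspaceSizeForGeneration; when enable_multi_block is false every size is zero, so none of the pointers is ever dereferenced. A small sketch of that bump-pointer pattern (bumpPtr is a stand-in for illustration, not the plugin's nextWorkspacePtr, and alignment handling is omitted):

#include <cstddef>
#include <cstdint>

namespace sketch
{
// Advance a running offset into a preallocated workspace and hand back the carved region.
inline std::int8_t* bumpPtr(std::int8_t* base, std::size_t& offset, std::size_t bytes)
{
    std::int8_t* ptr = bytes > 0 ? base + offset : nullptr;
    offset += bytes;
    return ptr;
}

struct MultiBlockBuffers
{
    void* partialOut;
    float* partialSum;
    float* partialMax;
    int* blockCounter;
};

inline MultiBlockBuffers carveWorkspace(std::int8_t* workspace, std::size_t partialOutBytes,
    std::size_t partialSumBytes, std::size_t partialMaxBytes, std::size_t blockCounterBytes)
{
    std::size_t offset = 0;
    MultiBlockBuffers buffers{};
    buffers.partialOut = bumpPtr(workspace, offset, partialOutBytes);
    buffers.partialSum = reinterpret_cast<float*>(bumpPtr(workspace, offset, partialSumBytes));
    buffers.partialMax = reinterpret_cast<float*>(bumpPtr(workspace, offset, partialMaxBytes));
    buffers.blockCounter = reinterpret_cast<int*>(bumpPtr(workspace, offset, blockCounterBytes));
    return buffers;
}
} // namespace sketch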
@ -894,7 +904,8 @@ int GPTAttentionPluginCommon::enqueueGeneration(
|
||||
else
|
||||
{
|
||||
using BufferDataType = typename KVCacheBufferDataType<KVCacheBuffer>::Type;
|
||||
kv_cache_buffer = KVCacheBuffer(batch_beam, 1, params.max_seq_length, num_kv_heads * head_size * elem_size);
|
||||
kv_cache_buffer
|
||||
= KVCacheBuffer(batch_beam, 1, params.max_kv_cache_length, num_kv_heads * head_size * elem_size);
|
||||
kv_cache_buffer.data = reinterpret_cast<BufferDataType*>(params.key_value_cache);
|
||||
}
|
||||
sync_check_cuda_error();
|
||||
@ -919,7 +930,8 @@ int GPTAttentionPluginCommon::enqueueGeneration(
|
||||
dispatch_params.size_per_head = getHeadSize();
|
||||
dispatch_params.rotary_embedding_dim = mRotaryEmbeddingDim;
|
||||
dispatch_params.position_embedding_type = mPositionEmbeddingType;
|
||||
dispatch_params.max_seq_len = params.max_seq_length; // difference between max_seq_lengths and max_seq_length?
|
||||
dispatch_params.max_kv_cache_length = params.max_kv_cache_length;
|
||||
dispatch_params.cyclic_kv_cache_length = params.cyclic_kv_cache_length;
|
||||
dispatch_params.input_lengths = params.context_lengths;
|
||||
dispatch_params.step = step;
|
||||
dispatch_params.q_scaling = q_scaling;
|
||||
@ -931,8 +943,8 @@ int GPTAttentionPluginCommon::enqueueGeneration(
|
||||
dispatch_params.qkv_scale_out = qkv_scale_out;
|
||||
dispatch_params.attention_out_scale = attention_out_scale;
|
||||
dispatch_params.quant_option = quant_option;
|
||||
dispatch_params.multi_block_mode = mMultiBlockMode;
|
||||
dispatch_params.max_seq_len_tile = getMaxSeqLenTile(sizeof(T));
|
||||
dispatch_params.multi_block_mode = enable_multi_block;
|
||||
dispatch_params.max_seq_len_tile = max_num_seq_len_tiles;
|
||||
dispatch_params.partial_out = partial_out;
|
||||
dispatch_params.partial_sum = partial_sum;
|
||||
dispatch_params.partial_max = partial_max;
|
||||
@ -962,6 +974,7 @@ int GPTAttentionPluginCommon::enqueueGeneration(
|
||||
Cross_multihead_attention_params<DataType> mmhca_params;
|
||||
fusedQKV_masked_attention_dispatch(mmhca_params, dispatch_params, stream);
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
@ -73,7 +73,7 @@ public:
|
||||
const int getHeadSize(bool checkInit = true) const;
|
||||
|
||||
protected:
|
||||
int getMaxSeqLenTile(int elemSize) const;
|
||||
int getMaxNumSeqLenTile(int batch_beam_size = 1) const;
|
||||
size_t getWorkspaceSizeForContext(
|
||||
nvinfer1::DataType type, int32_t nbReq, int32_t max_input_length, int32_t cross_qkv_length = 0) const noexcept;
|
||||
// total_num_seq is the sum of beam_width for multiple requests
|
||||
@ -85,7 +85,12 @@ protected:
|
||||
T const* attention_input;
|
||||
T const* qkv_bias;
|
||||
int32_t input_seq_length; // padded input length
|
||||
int32_t max_seq_length; // cache capacity
|
||||
// By default, max_kv_cache_length == cyclic_kv_cache_length
|
||||
// unless each layer has different cyclic kv cache length.
|
||||
// Max cache capacity (used to allocate KV cache)
|
||||
int32_t max_kv_cache_length;
|
||||
// Cyclic kv cache capacity (used to get the cyclic kv cache position for new tokens)
|
||||
int32_t cyclic_kv_cache_length;
|
||||
int32_t const* context_lengths;
|
||||
float const* kv_scale_orig_quant;
|
||||
float const* kv_scale_quant_orig;
|
||||
@ -125,7 +130,12 @@ protected:
|
||||
T* context_buf;
|
||||
void* key_value_cache;
|
||||
void* block_pointers;
|
||||
int32_t max_seq_length; // cache capacity
|
||||
// By default, max_kv_cache_length == cyclic_kv_cache_length
|
||||
// unless each layer has different cyclic kv cache length.
|
||||
// Max cache capacity (used to allocate KV cache)
|
||||
int32_t max_kv_cache_length;
|
||||
// Cyclic kv cache capacity (used to get the cyclic kv cache position for new tokens)
|
||||
int32_t cyclic_kv_cache_length;
|
||||
int32_t num_requests;
|
||||
int32_t max_blocks_per_sequence;
|
||||
int32_t const* cache_indir;
@ -84,8 +84,8 @@ nvinfer1::DimsExprs GPTAttentionPlugin::getOutputDimensions(
|
||||
bool GPTAttentionPlugin::supportsFormatCombination(
|
||||
int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs, int nbOutputs) noexcept
|
||||
{
|
||||
if (pos == getSequenceLengthIdx() || pos == getHostPastKeyValueLengthsIdx() || pos == getContextLengthsIdx()
|
||||
|| pos == getCacheIndirIdx() || pos == getRequestTypesIdx())
|
||||
if (pos == getSequenceLengthIdx() || pos == getHostPastKeyValueLengthsIdx() || pos == getHostMaxKvCacheLengthIdx()
|
||||
|| pos == getContextLengthsIdx() || pos == getCacheIndirIdx() || pos == getRequestTypesIdx())
|
||||
{
|
||||
return inOut[pos].type == nvinfer1::DataType::kINT32;
|
||||
}
|
||||
@ -131,9 +131,6 @@ void GPTAttentionPlugin::configurePlugin(const nvinfer1::DynamicPluginTensorDesc
|
||||
const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) noexcept
|
||||
{
|
||||
TLLM_CHECK(mHeadSize > 0);
|
||||
|
||||
// pre-check whether FMHA is supported in order to save memory allocation
|
||||
mEnableContextFMHA = mEnableContextFMHA && MHARunner::fmha_supported(getHeadSize(), mSM);
|
||||
}
|
||||
|
||||
size_t GPTAttentionPlugin::getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, int nbInputs,
|
||||
@ -258,7 +255,15 @@ int GPTAttentionPlugin::enqueueSome(int32_t seqIdxBeg, int32_t localNbSeq, int32
|
||||
// -- max_encoder_context_len: len of encoder input (in cross attn). Also called encoder_input_seq_length
|
||||
|
||||
const int beamWidth = inputDesc[getCacheIndirIdx()].dims.d[1];
|
||||
const int maxSeqLen = isCrossAttention() ? max_encoder_context_len : inputDesc[getCacheIndirIdx()].dims.d[2];
|
||||
|
||||
// Commonly, the cyclic kv cache length and the max kv cache length are the same,
// unless each layer has a different max kv cache length.
// max_kv_cache_length is the kv cache capacity.
|
||||
const int max_kv_cache_length
|
||||
= isCrossAttention() ? max_encoder_context_len : inputDesc[getCacheIndirIdx()].dims.d[2];
|
||||
// The cyclic_kv_cache_length will determine the cyclic kv cache position of new tokens.
|
||||
// Note that this cyclic_kv_cache_length might be smaller than the actual kv cache capacity (max_kv_cache_length).
|
||||
const int cyclic_kv_cache_length = reinterpret_cast<const int*>(inputs[getHostMaxKvCacheLengthIdx()])[0];
|
||||
|
||||
const float* kv_scale_orig_quant = nullptr;
|
||||
const float* kv_scale_quant_orig = nullptr;
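The comments above separate the allocation size (max_kv_cache_length, the kv cache capacity) from cyclic_kv_cache_length, which decides where new tokens are written and how far back they can attend. A standalone sketch of that relationship (the helpers are hypothetical; the modulo write position is an assumption consistent with the comments, not code copied from the kernels):

#include <algorithm>
#include <cstdint>

namespace sketch
{
// Slot a new token's K/V would occupy in a cyclic cache of the given length.
inline std::int32_t cyclicWriteSlot(std::int32_t tokenPos, std::int32_t cyclicKvCacheLength)
{
    return tokenPos % cyclicKvCacheLength;
}

// First position a token at rowIdx may attend to under the limited_length_causal mask,
// i.e. max(0, rowIdx - cyclic_kv_cache_length) as stated in the context-phase comment.
inline std::int32_t attentionWindowStart(std::int32_t rowIdx, std::int32_t cyclicKvCacheLength)
{
    return std::max(0, rowIdx - cyclicKvCacheLength);
}
} // namespace sketch

// Example: with cyclic_kv_cache_length = 2048, the token at position 2050 overwrites slot 2 and
// attends to positions 2..2050; a larger max_kv_cache_length only enlarges the allocation.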
@ -308,9 +313,10 @@ int GPTAttentionPlugin::enqueueSome(int32_t seqIdxBeg, int32_t localNbSeq, int32
|
||||
num_encoder_tokens = inputDesc[getCrossQKVIdx()].dims.d[1];
|
||||
}
|
||||
|
||||
EnqueueContextParams<T, KVCacheBuffer> enqueue_params{attention_input, qkv_bias, max_context_len, maxSeqLen,
|
||||
context_lengths, kv_scale_orig_quant, kv_scale_quant_orig, alibi_slopes, context_buf_, key_value_cache,
|
||||
block_pointers, batch_size, localNbTokens, max_blocks_per_sequence, workspace};
|
||||
EnqueueContextParams<T, KVCacheBuffer> enqueue_params{attention_input, qkv_bias, max_context_len,
|
||||
max_kv_cache_length, cyclic_kv_cache_length, context_lengths, kv_scale_orig_quant, kv_scale_quant_orig,
|
||||
alibi_slopes, context_buf_, key_value_cache, block_pointers, batch_size, localNbTokens,
|
||||
max_blocks_per_sequence, workspace};
|
||||
if (isRelativePosition())
|
||||
{
|
||||
enqueue_params.relative_attention_bias = static_cast<const T*>(inputs[getRelativeAttentionBiasIdx()]);
|
||||
@ -340,8 +346,8 @@ int GPTAttentionPlugin::enqueueSome(int32_t seqIdxBeg, int32_t localNbSeq, int32
|
||||
int32_t const past_kv_len = *std::max_element(past_kv_len_list, past_kv_len_list + localNbSeq);
|
||||
EnqueueGenerationParams<T, KVCacheBuffer> enqueue_params{attention_input, qkv_bias, sequence_length,
|
||||
past_kv_len, beamWidth, context_lengths, kv_scale_orig_quant, kv_scale_quant_orig, alibi_slopes,
|
||||
context_buf_, key_value_cache, block_pointers, maxSeqLen, num_requests, max_blocks_per_sequence,
|
||||
cache_indir, workspace};
|
||||
context_buf_, key_value_cache, block_pointers, max_kv_cache_length, cyclic_kv_cache_length, num_requests,
|
||||
max_blocks_per_sequence, cache_indir, workspace};
|
||||
if (isRelativePosition())
|
||||
{
|
||||
enqueue_params.relative_attention_bias = static_cast<const T*>(inputs[getRelativeAttentionBiasIdx()]);
|
||||
|
||||
@ -43,9 +43,10 @@ namespace tensorrt_llm::plugins
|
||||
// enable_remove_input_padding
|
||||
// 1. sequence_length [batch_size]
|
||||
// 2. host_past_key_value_lengths [batch_size] (int32)
|
||||
// 3. context_lengths [batch_size]
|
||||
// 4. cache_indir [num_gen_requests, beam_width, memory_max_len] (required in beamsearch)
|
||||
// 5. host_request_types [batch_size] int32. 0: context; 1: generation: 2: none. When not in inflight-batching
|
||||
// 3. host_max_kv_cache_lengths [1] (int32)
|
||||
// 4. context_lengths [batch_size]
|
||||
// 5. cache_indir [num_gen_requests, beam_width, memory_max_len] (required in beamsearch)
|
||||
// 6. host_request_types [batch_size] int32. 0: context; 1: generation: 2: none. When not in inflight-batching
|
||||
// mode,
|
||||
// all elements must be identical.
|
||||
// 6. past_key_value_pool [batch_size, 2, local_num_kv_heads, max_seq_len, head_size] or
|
||||
@ -145,46 +146,51 @@ private:
|
||||
return 2;
|
||||
}
|
||||
|
||||
IndexType getContextLengthsIdx() const
|
||||
IndexType getHostMaxKvCacheLengthIdx() const
|
||||
{
|
||||
return 3;
|
||||
}
|
||||
|
||||
IndexType getCacheIndirIdx() const
|
||||
IndexType getContextLengthsIdx() const
|
||||
{
|
||||
return 4;
|
||||
}
|
||||
|
||||
IndexType getRequestTypesIdx() const
|
||||
IndexType getCacheIndirIdx() const
|
||||
{
|
||||
return 5;
|
||||
}
|
||||
|
||||
IndexType getRequestTypesIdx() const
|
||||
{
|
||||
return 6;
|
||||
}
|
||||
|
||||
IndexType getKVCacheBlockPointersIdx() const
|
||||
{
|
||||
// NOTE We either provide this tensor when mPagedKVCache is true or PastKeyValue otherwise
|
||||
return 6;
|
||||
return 7;
|
||||
}
|
||||
|
||||
IndexType getPastKeyValueIdx() const
|
||||
{
|
||||
// NOTE We either provide this tensor when mPagedKVCache is false or KVCacheBlockPointers otherwise
|
||||
return 6;
|
||||
return 7;
|
||||
}
|
||||
|
||||
IndexType getKVCacheQuantizationScaleIdx() const
|
||||
{
|
||||
return 7;
|
||||
return 8;
|
||||
}
|
||||
|
||||
IndexType getKVCacheDequantizationScaleIdx() const
|
||||
{
|
||||
return 8;
|
||||
return 9;
|
||||
}
|
||||
|
||||
IndexType getAlibiSlopesIdx() const
|
||||
{
|
||||
return (mKVCacheQuantMode.hasKvCacheQuant() ? 9 : 7);
|
||||
return (mKVCacheQuantMode.hasKvCacheQuant() ? 10 : 8);
|
||||
}
|
||||
|
||||
IndexType getRelativeAttentionBiasIdx() const
|
||||
@ -216,7 +222,7 @@ private:
|
||||
IndexType getQKVBiasTensorIdx() const
|
||||
{
|
||||
TLLM_CHECK(mQKVBiasEnabled);
|
||||
return (mKVCacheQuantMode.hasKvCacheQuant() ? 9 : 7) + (isALiBi() ? 1 : 0) + (mRemovePadding ? 1 : 0);
|
||||
return (mKVCacheQuantMode.hasKvCacheQuant() ? 10 : 8) + (isALiBi() ? 1 : 0) + (mRemovePadding ? 1 : 0);
|
||||
}
|
||||
};
@ -150,14 +150,32 @@ void WeightOnlyGroupwiseQuantMatmulPlugin::init(nvinfer1::DataType type, int qua
|
||||
= std::make_shared<tensorrt_llm::kernels::cutlass_kernels::CutlassFpAIntBGemmRunner<half,
|
||||
cutlass::uint4b_t, cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY>>();
|
||||
}
|
||||
mCudaKernelEnabled
|
||||
= tensorrt_llm::kernels::isWeightOnlyBatchedGemvEnabled(tensorrt_llm::kernels::WeightOnlyQuantType::Int4b);
|
||||
}
|
||||
#if defined(ENABLE_BF16)
|
||||
else if (mType == nvinfer1::DataType::kBF16)
|
||||
{
|
||||
if (quant_algo & ZERO)
|
||||
{
|
||||
// has zeros
|
||||
m_weightOnlyGroupwiseGemmRunner
|
||||
= std::make_shared<tensorrt_llm::kernels::cutlass_kernels::CutlassFpAIntBGemmRunner<__nv_bfloat16,
|
||||
cutlass::uint4b_t, cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS>>();
|
||||
}
|
||||
else
|
||||
{
|
||||
// no zeros
|
||||
m_weightOnlyGroupwiseGemmRunner
|
||||
= std::make_shared<tensorrt_llm::kernels::cutlass_kernels::CutlassFpAIntBGemmRunner<__nv_bfloat16,
|
||||
cutlass::uint4b_t, cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY>>();
|
||||
}
|
||||
}
|
||||
#endif
|
||||
else
|
||||
{
|
||||
TLLM_THROW("Unsupported data type");
|
||||
}
|
||||
|
||||
mCudaKernelEnabled
|
||||
= tensorrt_llm::kernels::isWeightOnlyBatchedGemvEnabled(tensorrt_llm::kernels::WeightOnlyQuantType::Int4b);
|
||||
mPluginProfiler->setQuantAlgo(mQuantAlgo);
|
||||
mPluginProfiler->setGroupSize(mGroupSize);
|
||||
|
||||
@ -295,27 +313,52 @@ int WeightOnlyGroupwiseQuantMatmulPlugin::enqueue(const nvinfer1::PluginTensorDe
|
||||
if (mQuantAlgo & PRE_QUANT_SCALE)
|
||||
{
|
||||
// Apply pre-quant per channel scale on activations
|
||||
tensorrt_llm::kernels::apply_per_channel_scale_kernel_launcher<half>(reinterpret_cast<half*>(workspace),
|
||||
reinterpret_cast<const half*>(inputs[0]), reinterpret_cast<const half*>(inputs[mPreQuantScaleInputIdx]), m,
|
||||
k, stream);
|
||||
if (mType == nvinfer1::DataType::kHALF)
|
||||
{
|
||||
tensorrt_llm::kernels::apply_per_channel_scale_kernel_launcher<half>(reinterpret_cast<half*>(workspace),
|
||||
reinterpret_cast<const half*>(inputs[0]), reinterpret_cast<const half*>(inputs[mPreQuantScaleInputIdx]),
|
||||
m, k, stream);
|
||||
}
|
||||
#if defined(ENABLE_BF16)
|
||||
else if (mType == nvinfer1::DataType::kBF16)
|
||||
{
|
||||
tensorrt_llm::kernels::apply_per_channel_scale_kernel_launcher<__nv_bfloat16>(
|
||||
reinterpret_cast<__nv_bfloat16*>(workspace), reinterpret_cast<const __nv_bfloat16*>(inputs[0]),
|
||||
reinterpret_cast<const __nv_bfloat16*>(inputs[mPreQuantScaleInputIdx]), m, k, stream);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
const half* zeros_ptr = (mQuantAlgo & ZERO) ? reinterpret_cast<const half*>(inputs[mZerosInputIdx]) : nullptr;
|
||||
const half* biases_ptr = (mQuantAlgo & BIAS) ? reinterpret_cast<const half*>(inputs[mBiasesInputIdx]) : nullptr;
|
||||
const half* act_ptr = reinterpret_cast<const half*>((mQuantAlgo & PRE_QUANT_SCALE) ? workspace : inputs[0]);
|
||||
|
||||
TLLM_CHECK_WITH_INFO(mType == nvinfer1::DataType::kHALF, "No valid weightOnlyGropwiseQuantMatmul configuration");
|
||||
TLLM_CHECK_WITH_INFO(mType == nvinfer1::DataType::kHALF
|
||||
#if defined(ENABLE_BF16)
|
||||
|| mType == nvinfer1::DataType::kBF16
|
||||
#endif
|
||||
,
|
||||
"No valid weightOnlyGropwiseQuantMatmul configuration");
|
||||
tensorrt_llm::kernels::WeightOnlyActivationType weight_only_act_type;
|
||||
int real_n = n * INT8_INT4_RATIO;
|
||||
if (mType == nvinfer1::DataType::kHALF)
|
||||
{
|
||||
weight_only_act_type = tensorrt_llm::kernels::WeightOnlyActivationType::FP16;
|
||||
}
|
||||
else if (mType == nvinfer1::DataType::kBF16)
|
||||
{
|
||||
weight_only_act_type = tensorrt_llm::kernels::WeightOnlyActivationType::BF16;
|
||||
}
|
||||
if (m < SMALL_M_FAST_PATH && mCudaKernelEnabled)
|
||||
{
|
||||
// Use CUDA kernels for small batch size
|
||||
// The CUDA kernel is designed for ColumnMajorTileInterleave weight layout used in fpAIntB cutlass kernel
|
||||
// when sm >= 75 and the preprocessing of cutlass on sm70 does not interleave the weights.
|
||||
tensorrt_llm::kernels::WeightOnlyParams params{reinterpret_cast<const uint8_t*>(inputs[mWeightInputIdx]),
|
||||
reinterpret_cast<const half*>(inputs[mScalesInputIdx]), zeros_ptr, act_ptr, biases_ptr,
|
||||
reinterpret_cast<half*>(outputs[0]), m, n * INT8_INT4_RATIO, k, mGroupSize};
|
||||
tensorrt_llm::kernels::weight_only_batched_gemv_launcher(tensorrt_llm::kernels::WeightOnlyQuantType::Int4b,
|
||||
tensorrt_llm::kernels::WeightOnlyType::GroupWise, tensorrt_llm::kernels::WeightOnlyActivationType::Identity,
|
||||
params, stream);
|
||||
inputs[mScalesInputIdx], zeros_ptr, act_ptr, biases_ptr, outputs[0], m, real_n, k, mGroupSize,
|
||||
tensorrt_llm::kernels::WeightOnlyQuantType::Int4b, tensorrt_llm::kernels::WeightOnlyType::GroupWise,
|
||||
tensorrt_llm::kernels::WeightOnlyActivationFunctionType::Identity, weight_only_act_type};
|
||||
tensorrt_llm::kernels::weight_only_batched_gemv_launcher(params, stream);
|
||||
}
|
||||
else
|
||||
{
|
||||
@ -326,9 +369,8 @@ int WeightOnlyGroupwiseQuantMatmulPlugin::enqueue(const nvinfer1::PluginTensorDe
|
||||
|
||||
const auto& bestTactic = mPluginProfiler->getBestConfig(m, mGemmId);
|
||||
TLLM_CHECK_WITH_INFO(bestTactic, "No valid weight only groupwise GEMM tactic");
|
||||
m_weightOnlyGroupwiseGemmRunner->gemm(act_ptr, reinterpret_cast<cutlass::uint4b_t*>(weight_ptr),
|
||||
reinterpret_cast<const half*>(inputs[mScalesInputIdx]), zeros_ptr, biases_ptr,
|
||||
reinterpret_cast<half*>(outputs[0]), m, n * INT8_INT4_RATIO, k, mGroupSize, *bestTactic,
|
||||
m_weightOnlyGroupwiseGemmRunner->gemm(act_ptr, weight_ptr, inputs[mScalesInputIdx], zeros_ptr, biases_ptr,
|
||||
outputs[0], m, real_n, k, mGroupSize, *bestTactic,
|
||||
reinterpret_cast<char*>(workspace) + m * k * sizeof(half), ws_bytes, stream);
|
||||
}
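Two details above are easy to miss: the packed int4 weight dimension n expands to real_n = n * INT8_INT4_RATIO for the logical GEMM shape, and m selects between the batched-GEMV fast path and the cutlass fpA-intB GEMM. A hedged standalone restatement (the ratio value of 2, the latency rationale, and smallMFastPath are editorial assumptions standing in for the plugin's constants):

#include <cstdint>

namespace sketch
{
// Two 4-bit weights are packed per int8 element, so the logical N dimension is the packed
// dimension times INT8_INT4_RATIO (assumed to be 2).
inline std::int64_t logicalN(std::int64_t packedN, int int8Int4Ratio = 2)
{
    return packedN * int8Int4Ratio;
}

// Few tokens (small m) are latency-bound and favor the hand-written batched-GEMV kernel;
// larger m amortizes the cutlass GEMM better. smallMFastPath mirrors SMALL_M_FAST_PATH.
inline bool useBatchedGemv(int m, bool cudaKernelEnabled, int smallMFastPath)
{
    return m < smallMFastPath && cudaKernelEnabled;
}
} // namespace sketch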
@ -108,17 +108,46 @@ void WeightOnlyQuantMatmulPlugin::init(nvinfer1::DataType type, WeightTypeId wei
|
||||
{
|
||||
mType = type;
|
||||
mWeightTypeId = weightTypeId;
|
||||
if (mType == nvinfer1::DataType::kHALF && mWeightTypeId == WeightTypeId::INT8)
|
||||
if (mWeightTypeId == WeightTypeId::INT8)
|
||||
{
|
||||
m_weightOnlyGemmRunner = std::make_shared<
|
||||
CutlassFpAIntBGemmRunner<half, uint8_t, cutlass::WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY>>();
|
||||
if (mType == nvinfer1::DataType::kHALF)
|
||||
{
|
||||
m_weightOnlyGemmRunner = std::make_shared<
|
||||
CutlassFpAIntBGemmRunner<half, uint8_t, cutlass::WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY>>();
|
||||
}
|
||||
#if defined(ENABLE_BF16)
|
||||
else if (mType == nvinfer1::DataType::kBF16)
|
||||
{
|
||||
m_weightOnlyGemmRunner = std::make_shared<
|
||||
CutlassFpAIntBGemmRunner<__nv_bfloat16, uint8_t, cutlass::WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY>>();
|
||||
}
|
||||
#endif
|
||||
else
|
||||
{
|
||||
TLLM_CHECK(false);
|
||||
}
|
||||
|
||||
mCudaKernelEnabled
|
||||
= tensorrt_llm::kernels::isWeightOnlyBatchedGemvEnabled(tensorrt_llm::kernels::WeightOnlyQuantType::Int8b);
|
||||
}
|
||||
else if (mType == nvinfer1::DataType::kHALF && mWeightTypeId == WeightTypeId::INT4)
|
||||
else if (mWeightTypeId == WeightTypeId::INT4)
|
||||
{
|
||||
m_weightOnlyGemmRunner = std::make_shared<
|
||||
CutlassFpAIntBGemmRunner<half, cutlass::uint4b_t, cutlass::WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY>>();
|
||||
if (mType == nvinfer1::DataType::kHALF)
|
||||
{
|
||||
m_weightOnlyGemmRunner = std::make_shared<
|
||||
CutlassFpAIntBGemmRunner<half, cutlass::uint4b_t, cutlass::WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY>>();
|
||||
}
|
||||
#if defined(ENABLE_BF16)
|
||||
else if (mType == nvinfer1::DataType::kBF16)
|
||||
{
|
||||
m_weightOnlyGemmRunner = std::make_shared<CutlassFpAIntBGemmRunner<__nv_bfloat16, cutlass::uint4b_t,
|
||||
cutlass::WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY>>();
|
||||
}
|
||||
#endif
|
||||
else
|
||||
{
|
||||
TLLM_CHECK(false);
|
||||
}
|
||||
mCudaKernelEnabled
|
||||
= tensorrt_llm::kernels::isWeightOnlyBatchedGemvEnabled(tensorrt_llm::kernels::WeightOnlyQuantType::Int4b);
|
||||
}
|
||||
@ -259,50 +288,49 @@ int WeightOnlyQuantMatmulPlugin::enqueue(const nvinfer1::PluginTensorDesc* input
|
||||
const int ws_size = m_weightOnlyGemmRunner->getWorkspaceSize(m, n, k);
|
||||
const auto& bestTactic = mPluginProfiler->getBestConfig(m, mGemmId);
|
||||
TLLM_CHECK_WITH_INFO(bestTactic, "No valid weight only groupwise GEMM tactic");
|
||||
TLLM_CHECK_WITH_INFO(mType == nvinfer1::DataType::kHALF, "No valid weightOnlyQuantMatmul configuration");
|
||||
if (mType == nvinfer1::DataType::kHALF && mWeightTypeId == WeightTypeId::INT8)
|
||||
TLLM_CHECK_WITH_INFO(mType == nvinfer1::DataType::kHALF ||
|
||||
#if defined(ENABLE_BF16)
|
||||
mType == nvinfer1::DataType::kBF16
|
||||
#endif
|
||||
,
|
||||
"No valid weightOnlyQuantMatmul configuration");
|
||||
|
||||
tensorrt_llm::kernels::WeightOnlyQuantType weight_only_quant_type;
|
||||
tensorrt_llm::kernels::WeightOnlyActivationType weight_only_act_type;
|
||||
int real_n;
|
||||
if (mType == nvinfer1::DataType::kHALF)
|
||||
{
|
||||
if (m < SMALL_M_FAST_PATH && mCudaKernelEnabled)
|
||||
{
|
||||
// Use CUDA kernels for small batch size
|
||||
// The CUDA kernel is designed for ColumnMajorTileInterleave weight layout used in fpAIntB cutlass kernel
|
||||
// when sm >= 75 and the preprocessing of cutlass on sm70 does not interleave the weights.
|
||||
tensorrt_llm::kernels::WeightOnlyParams params{reinterpret_cast<const uint8_t*>(inputs[1]),
|
||||
reinterpret_cast<const half*>(inputs[2]), nullptr, reinterpret_cast<const half*>(inputs[0]), nullptr,
|
||||
reinterpret_cast<half*>(outputs[0]), m, n, k, 0};
|
||||
tensorrt_llm::kernels::weight_only_batched_gemv_launcher(tensorrt_llm::kernels::WeightOnlyQuantType::Int8b,
|
||||
tensorrt_llm::kernels::WeightOnlyType::PerChannel,
|
||||
tensorrt_llm::kernels::WeightOnlyActivationType::Identity, params, stream);
|
||||
}
|
||||
else
|
||||
{
|
||||
m_weightOnlyGemmRunner->gemm(reinterpret_cast<const half*>(inputs[0]),
|
||||
reinterpret_cast<const int8_t*>(inputs[1]), reinterpret_cast<const half*>(inputs[2]),
|
||||
reinterpret_cast<half*>(outputs[0]), m, n, k, *bestTactic, reinterpret_cast<char*>(workspace), ws_size,
|
||||
stream);
|
||||
}
|
||||
weight_only_act_type = tensorrt_llm::kernels::WeightOnlyActivationType::FP16;
|
||||
}
|
||||
else if (mType == nvinfer1::DataType::kHALF && mWeightTypeId == WeightTypeId::INT4)
|
||||
else if (mType == nvinfer1::DataType::kBF16)
|
||||
{
|
||||
if (m < SMALL_M_FAST_PATH && mCudaKernelEnabled)
|
||||
{
|
||||
// Use CUDA kernels for small batch size
|
||||
// The CUDA kernel is designed for ColumnMajorTileInterleave weight layout used in fpAIntB cutlass kernel
|
||||
// when sm >= 75 and the preprocessing of cutlass on sm70 does not interleave the weights.
|
||||
tensorrt_llm::kernels::WeightOnlyParams params{reinterpret_cast<const uint8_t*>(inputs[1]),
|
||||
reinterpret_cast<const half*>(inputs[2]), nullptr, reinterpret_cast<const half*>(inputs[0]), nullptr,
|
||||
reinterpret_cast<half*>(outputs[0]), m, n * INT8_INT4_RATIO, k, 0};
|
||||
tensorrt_llm::kernels::weight_only_batched_gemv_launcher(tensorrt_llm::kernels::WeightOnlyQuantType::Int4b,
|
||||
tensorrt_llm::kernels::WeightOnlyType::PerChannel,
|
||||
tensorrt_llm::kernels::WeightOnlyActivationType::Identity, params, stream);
|
||||
}
|
||||
else
|
||||
{
|
||||
m_weightOnlyGemmRunner->gemm(reinterpret_cast<const half*>(inputs[0]),
|
||||
reinterpret_cast<const cutlass::uint4b_t*>(inputs[1]), reinterpret_cast<const half*>(inputs[2]),
|
||||
reinterpret_cast<half*>(outputs[0]), m, n * INT8_INT4_RATIO, k, *bestTactic,
|
||||
reinterpret_cast<char*>(workspace), ws_size, stream);
|
||||
}
|
||||
weight_only_act_type = tensorrt_llm::kernels::WeightOnlyActivationType::BF16;
|
||||
}
|
||||
if (mWeightTypeId == WeightTypeId::INT8)
|
||||
{
|
||||
weight_only_quant_type = tensorrt_llm::kernels::WeightOnlyQuantType::Int8b;
|
||||
real_n = n;
|
||||
}
|
||||
else if (mWeightTypeId == WeightTypeId::INT4)
|
||||
{
|
||||
weight_only_quant_type = tensorrt_llm::kernels::WeightOnlyQuantType::Int4b;
|
||||
real_n = n * INT8_INT4_RATIO;
|
||||
}
|
||||
if (m < SMALL_M_FAST_PATH && mCudaKernelEnabled)
|
||||
{
|
||||
// Use CUDA kernels for small batch size
|
||||
// The CUDA kernel is designed for ColumnMajorTileInterleave weight layout used in fpAIntB cutlass
|
||||
// kernel when sm >= 75 and the preprocessing of cutlass on sm70 does not interleave the weights.
|
||||
tensorrt_llm::kernels::WeightOnlyParams params{reinterpret_cast<const uint8_t*>(inputs[1]), inputs[2], nullptr,
|
||||
inputs[0], nullptr, outputs[0], m, real_n, k, 0, weight_only_quant_type,
|
||||
tensorrt_llm::kernels::WeightOnlyType::PerChannel,
|
||||
tensorrt_llm::kernels::WeightOnlyActivationFunctionType::Identity, weight_only_act_type};
|
||||
tensorrt_llm::kernels::weight_only_batched_gemv_launcher(params, stream);
|
||||
}
|
||||
else
|
||||
{
|
||||
m_weightOnlyGemmRunner->gemm(inputs[0], inputs[1], inputs[2], outputs[0], m, real_n, k, *bestTactic,
|
||||
reinterpret_cast<char*>(workspace), ws_size, stream);
|
||||
}
|
||||
|
||||
return 0;
@ -60,7 +60,7 @@ PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m)
|
||||
.def_readwrite("ids", &tpr::GenerationInput::ids)
|
||||
.def_readwrite("lengths", &tpr::GenerationInput::lengths)
|
||||
.def_readwrite("packed", &tpr::GenerationInput::packed)
|
||||
.def_readwrite("embedding_bias", &tpr::GenerationInput::embeddingBiasOpt)
|
||||
.def_readwrite("embedding_bias", &tpr::GenerationInput::embeddingBias)
|
||||
.def_readwrite("bad_words_list", &tpr::GenerationInput::badWordsList)
|
||||
.def_readwrite("stop_words_list", &tpr::GenerationInput::stopWordsList)
|
||||
.def_readwrite("max_new_tokens", &tpr::GenerationInput::maxNewTokens)
|
||||
@ -75,9 +75,11 @@ PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m)
|
||||
.def_readwrite("context_logits", &tpr::GenerationOutput::contextLogits);
|
||||
|
||||
py::class_<tb::kv_cache_manager::KvCacheConfig>(m, "KvCacheConfig")
|
||||
.def(py::init<std::optional<tr::SizeType>, std::optional<float>>(), py::arg("max_tokens") = py::none(),
|
||||
.def(py::init<std::optional<tr::SizeType>, std::optional<tr::SizeType>, std::optional<float>>(),
|
||||
py::arg("max_tokens") = py::none(), py::arg("max_kv_cache_length") = py::none(),
|
||||
py::arg("free_gpu_memory_fraction") = py::none())
|
||||
.def_readwrite("max_tokens", &tb::kv_cache_manager::KvCacheConfig::maxTokens)
|
||||
.def_readwrite("max_kv_cache_length", &tb::kv_cache_manager::KvCacheConfig::maxKvCacheLength)
|
||||
.def_readwrite("free_gpu_memory_fraction", &tb::kv_cache_manager::KvCacheConfig::freeGpuMemoryFraction);
|
||||
|
||||
py::class_<tr::GptSession::Config>(m, "GptSessionConfig")
@ -40,8 +40,8 @@ std::shared_ptr<tr::GenerationInput> GenerationInput::toTrtLlm() const
|
||||
{
|
||||
auto input = std::make_shared<tr::GenerationInput>(
|
||||
endId, padId, tr::TorchView::of(ids.value()), tr::TorchView::of(lengths.value()), packed);
|
||||
if (embeddingBiasOpt)
|
||||
input->embeddingBiasOpt = tr::TorchView::of(embeddingBiasOpt.value());
|
||||
if (embeddingBias)
|
||||
input->embeddingBias = tr::TorchView::of(embeddingBias.value());
|
||||
if (badWordsList)
|
||||
input->badWordsList = tr::TorchView::of(badWordsList.value());
|
||||
if (stopWordsList)
|
||||
|
||||
@ -90,8 +90,8 @@ typename tl::DynamicDecodeLayer<T>::ForwardParams prepareInputs(DecodingInput co
|
||||
TLLM_CHECK(input.logits->getDataType() == TRTDataType<T>::value);
|
||||
|
||||
auto constexpr ite = 0; // no pipeline parallelism
|
||||
typename tl::DynamicDecodeLayer<T>::ForwardParams forwardParams{input.step, ite, input.maxLength, input.batchSize,
|
||||
tcc::toTllmTensor(*input.logits), tcc::toTllmTensor(*input.endIds)};
|
||||
typename tl::DynamicDecodeLayer<T>::ForwardParams forwardParams{input.step, ite, input.maxLength,
|
||||
input.maxKvCacheLength, input.batchSize, tcc::toTllmTensor(*input.logits), tcc::toTllmTensor(*input.endIds)};
|
||||
|
||||
if (input.cacheIndirection)
|
||||
{
|
||||
|
||||
@ -83,7 +83,7 @@ GptDecoderBatch::GptDecoderBatch(
|
||||
auto& dInput = mJointDecodingInput;
|
||||
auto dummyLogits = mBufferManager.emptyTensor(MemoryType::kGPU, nvFloatType);
|
||||
auto endIds = mBufferManager.emptyTensor(MemoryType::kGPU, nvTokenIdType);
|
||||
dInput = std::make_unique<DecodingInput>(0, 0, std::move(dummyLogits), std::move(endIds));
|
||||
dInput = std::make_unique<DecodingInput>(0, 0, 0, std::move(dummyLogits), std::move(endIds));
|
||||
|
||||
dInput->sequenceLimitLength = mBufferManager.emptyTensor(MemoryType::kGPU, nvSizeType);
|
||||
dInput->lengths = mBufferManager.emptyTensor(MemoryType::kGPU, nvSizeType);
|
||||
@ -104,8 +104,8 @@ GptDecoderBatch::GptDecoderBatch(
|
||||
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
void GptDecoderBatch::setup(
|
||||
SizeType maxBatchSize, SizeType maxBeamWidth, SizeType maxSequenceLength, nvinfer1::DataType dtype)
|
||||
void GptDecoderBatch::setup(SizeType maxBatchSize, SizeType maxBeamWidth, SizeType maxKvCacheLength,
|
||||
SizeType maxSequenceLength, nvinfer1::DataType dtype)
|
||||
{
|
||||
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
|
||||
TLLM_CHECK(maxBatchSize > 0);
|
||||
@ -114,6 +114,7 @@ void GptDecoderBatch::setup(
|
||||
|
||||
mActualBatchSize = maxBatchSize;
|
||||
mMaxSequenceLength = maxSequenceLength;
|
||||
mMaxKvCacheLength = maxKvCacheLength;
|
||||
|
||||
auto const maxBatchSizeShape = ITensor::makeShape({maxBatchSize});
|
||||
auto const maxBatchSizeXmaxBeamWidth = ITensor::makeShape({maxBatchSize, maxBeamWidth});
|
||||
@ -211,7 +212,8 @@ void GptDecoderBatch::newRequest(
|
||||
|
||||
TensorPtr endIdTensorPtr{ITensor::slice(constPointerCast(dJointInput.endIds), batchIdx, localBatchSize)};
|
||||
kernels::invokeFill(*endIdTensorPtr, endId, *stream);
|
||||
dInput = std::make_unique<DecodingInput>(inputLength, localBatchSize, dJointInput.logits, endIdTensorPtr);
|
||||
dInput = std::make_unique<DecodingInput>(
|
||||
inputLength, mMaxKvCacheLength, localBatchSize, dJointInput.logits, endIdTensorPtr);
|
||||
|
||||
// Here, we need to add leading 1 dimension since decoderInput expects batchSize as leading dim
|
||||
// and decoder_batch::Request doesn't have batch dimension
|
||||
@ -458,17 +460,30 @@ void GptDecoderBatch::newBatch(GenerationInput const& inputs, SamplingConfig con
|
||||
}
|
||||
auto request = decoder_batch::Request{inputView, inputs.maxNewTokens, inputs.endId, inputs.padId};
|
||||
|
||||
if (inputs.embeddingBiasOpt)
|
||||
if (inputs.embeddingBias)
|
||||
{
|
||||
TLLM_THROW("newBatch doesn't support embeddingBias yet.");
|
||||
}
|
||||
if (inputs.badWordsList)
|
||||
{
|
||||
TLLM_THROW("newBatch doesn't support badWordsList yet.");
|
||||
auto const& shape = inputs.badWordsList->getShape();
|
||||
if (shape.nbDims == 2)
|
||||
{
|
||||
request.badWordsList = inputs.badWordsList;
|
||||
}
|
||||
else
|
||||
{
|
||||
assert(shape.nbDims == 3);
|
||||
TensorPtr badWordsListView = ITensor::slice(inputs.badWordsList, batchIdx, 1);
|
||||
badWordsListView->squeeze(0);
|
||||
request.badWordsList = badWordsListView;
|
||||
}
|
||||
}
|
||||
if (inputs.stopWordsList)
|
||||
{
|
||||
TLLM_THROW("newBatch doesn't support stopWordsList yet.");
|
||||
TensorPtr stopWordsListView = ITensor::slice(inputs.stopWordsList, batchIdx, 1);
|
||||
stopWordsListView->squeeze(0);
|
||||
request.stopWordsList = stopWordsListView;
|
||||
}
|
||||
newRequest(batchIdx, request, extractSamplingConfig(samplingConfig, batchIdx));
|
||||
}
|
||||
|
||||
@ -117,8 +117,8 @@ void GptSession::createBuffers(SizeType numMicroBatches)
|
||||
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
void GptSession::createDecoders(SizeType batchSize, SizeType beamWidth, SizeType maxSequenceLength,
|
||||
nvinfer1::DataType logitsType, bool decoderPerRequest, SizeType numMicroBatches)
|
||||
void GptSession::createDecoders(SizeType batchSize, SizeType beamWidth, SizeType maxKvCacheLength,
|
||||
SizeType maxSequenceLength, nvinfer1::DataType logitsType, bool decoderPerRequest, SizeType numMicroBatches)
|
||||
{
|
||||
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
|
||||
auto const vocabSize = mModelConfig.getVocabSize();
|
||||
@ -133,14 +133,14 @@ void GptSession::createDecoders(SizeType batchSize, SizeType beamWidth, SizeType
|
||||
mDecoders.emplace_back(std::make_shared<GptDecoderBatch>(vocabSize, vocabSizePadded, stream));
|
||||
else
|
||||
mDecoders.emplace_back(std::make_shared<StatefulGptDecoder>(vocabSize, vocabSizePadded, stream));
|
||||
mDecoders.back()->setup(batchSize, beamWidth, maxSequenceLength, logitsType);
|
||||
mDecoders.back()->setup(batchSize, beamWidth, maxKvCacheLength, maxSequenceLength, logitsType);
|
||||
}
|
||||
|
||||
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
void GptSession::createKvCacheManager(
|
||||
SizeType batchSize, SizeType beamWidth, SizeType maxSequenceLength, KvCacheConfig const& config)
|
||||
void GptSession::createKvCacheManager(SizeType batchSize, SizeType beamWidth, SizeType maxKvCacheLength,
|
||||
SizeType maxSequenceLength, KvCacheConfig const& config)
|
||||
{
|
||||
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
|
||||
auto const localNbLayers = mModelConfig.getNbLayers(mWorldConfig.getPipelineParallelism());
|
||||
@ -168,8 +168,9 @@ void GptSession::createKvCacheManager(
|
||||
auto const maxNumBlocks = tc::ceilDiv(maxNumTokens, tokensPerBlock);
|
||||
auto const maxBlocksPerSeq = tc::ceilDiv(maxSequenceLength, tokensPerBlock);
|
||||
|
||||
mKvCacheManager = std::make_shared<bmkv::KVCacheManager>(localNbLayers, nbHeads, nbKvHeads, hiddenSize,
|
||||
tokensPerBlock, maxNumBlocks, batchSize, beamWidth, maxBlocksPerSeq, kvDtype, mRuntime->getStreamPtr());
|
||||
mKvCacheManager
|
||||
= std::make_shared<bmkv::KVCacheManager>(localNbLayers, nbHeads, nbKvHeads, hiddenSize, tokensPerBlock,
|
||||
maxNumBlocks, batchSize, beamWidth, maxBlocksPerSeq, maxKvCacheLength, kvDtype, mRuntime->getStreamPtr());
|
||||
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
@ -232,6 +233,9 @@ void GptSession::setup(Config const& sessionConfig)
|
||||
auto const maxBatchSize = sessionConfig.maxBatchSize;
|
||||
auto const maxBeamWidth = sessionConfig.maxBeamWidth;
|
||||
auto const maxSequenceLength = sessionConfig.maxSequenceLength;
|
||||
auto const maxKvCacheLength = sessionConfig.kvCacheConfig.maxKvCacheLength.has_value()
|
||||
? std::min(sessionConfig.kvCacheConfig.maxKvCacheLength.value(), maxSequenceLength)
|
||||
: maxSequenceLength;
|
||||
|
||||
mMicroBatchConfig = MicroBatchConfig(maxBatchSize, mWorldConfig.getPipelineParallelism(),
|
||||
sessionConfig.genMicroBatchSize, sessionConfig.ctxMicroBatchSize);
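The ternary above is the only place the session clamps the configured KV cache length, so a value larger than maxSequenceLength is reduced and an unset value falls back to the full sequence length. The same logic as a standalone helper with a worked example (SizeType is a stand-in alias, not the runtime typedef):

#include <algorithm>
#include <cstdint>
#include <optional>

namespace sketch
{
using SizeType = std::int32_t; // stand-in for tensorrt_llm::runtime::SizeType

inline SizeType resolveMaxKvCacheLength(std::optional<SizeType> configured, SizeType maxSequenceLength)
{
    // A configured value may shrink the KV cache window but never extend it past the
    // session's maximum sequence length.
    return configured.has_value() ? std::min(*configured, maxSequenceLength) : maxSequenceLength;
}
} // namespace sketch

// resolveMaxKvCacheLength(2048, 4096) == 2048; resolveMaxKvCacheLength(std::nullopt, 4096) == 4096.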
@ -244,16 +248,18 @@ void GptSession::setup(Config const& sessionConfig)
|
||||
// gptDecoderBatch does not resize buffers, but allows smaller batchSize and beamWidth.
|
||||
// TODO refactor batch manager to remove dependency on maxSequenceLength.
|
||||
mDecoderMaxSequenceLength = maxSequenceLength;
|
||||
mDecoderMaxKvCacheLength = maxKvCacheLength;
|
||||
|
||||
if (mModelConfig.usePagedKvCache())
|
||||
{
|
||||
createKvCacheManager(maxBatchSize, maxBeamWidth, maxSequenceLength, sessionConfig.kvCacheConfig);
|
||||
createKvCacheManager(
|
||||
maxBatchSize, maxBeamWidth, maxKvCacheLength, maxSequenceLength, sessionConfig.kvCacheConfig);
|
||||
}
|
||||
|
||||
if (mWorldConfig.isLastPipelineParallelRank())
|
||||
{
|
||||
auto const logitsType = mRuntime->getEngine().getTensorDataType("logits");
|
||||
createDecoders(mMicroBatchConfig.genBatchSize, maxBeamWidth, maxSequenceLength, logitsType,
|
||||
createDecoders(mMicroBatchConfig.genBatchSize, maxBeamWidth, maxKvCacheLength, maxSequenceLength, logitsType,
|
||||
sessionConfig.decoderPerRequest, mMicroBatchConfig.numGenBatches);
|
||||
}
|
||||
|
||||
@ -272,8 +278,8 @@ void GptSession::setup(Config const& sessionConfig)
|
||||
for (auto& buffers : mBuffers)
|
||||
{
|
||||
// we don't know maxInputLength yet and ignore it for pre-allocation
|
||||
buffers->generationConfig
|
||||
= RuntimeBuffers::GenerationConfig{mMicroBatchConfig.genBatchSize, maxBeamWidth, 0, maxSequenceLength};
|
||||
buffers->generationConfig = RuntimeBuffers::GenerationConfig{
|
||||
mMicroBatchConfig.genBatchSize, maxBeamWidth, 0, maxKvCacheLength, maxSequenceLength};
|
||||
buffers->reshape(mModelConfig, mWorldConfig);
|
||||
}
|
||||
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
|
||||
@ -403,8 +409,9 @@ std::vector<GenerationInput> splitInputs(GenerationInput const& inputs, SizeType
|
||||
auto const offset = microBatchOffsets[batchId];
|
||||
auto const batchSize = microBatchOffsets[batchId + 1] - offset;
|
||||
|
||||
if (inputs.embeddingBiasOpt)
|
||||
batch.embeddingBiasOpt = inputs.embeddingBiasOpt;
|
||||
if (inputs.embeddingBias)
|
||||
batch.embeddingBias = inputs.embeddingBias;
|
||||
|
||||
if (inputs.badWordsList)
|
||||
{
|
||||
auto const& shape = inputs.badWordsList->getShape();
|
||||
@ -414,7 +421,7 @@ std::vector<GenerationInput> splitInputs(GenerationInput const& inputs, SizeType
|
||||
}
|
||||
else
|
||||
{
|
||||
assert(nbDims == 3);
|
||||
assert(shape.nbDims == 3);
|
||||
batch.badWordsList = ITensor::slice(inputs.badWordsList, offset, batchSize);
|
||||
}
|
||||
}
|
||||
@ -524,7 +531,7 @@ void GptSession::generateBatched(
|
||||
auto const& microBatchInputs = microBatches.at(microBatchId);
|
||||
auto& buffers = *mBuffers.at(microBatchId);
|
||||
buffers.initFromInput(*microBatchInputs.ids, microBatchInputs.lengths, microBatchInputs.packed, beamWidth,
|
||||
mDecoderMaxSequenceLength, manager);
|
||||
mDecoderMaxKvCacheLength, mDecoderMaxSequenceLength, manager);
|
||||
buffers.reshape(mModelConfig, mWorldConfig);
|
||||
buffers.reset(manager);
|
||||
}
|
||||
|
||||
@ -22,7 +22,6 @@
|
||||
#include "tensorrt_llm/runtime/tensorView.h"
|
||||
|
||||
#include <initializer_list>
|
||||
#include <limits>
|
||||
#include <memory>
|
||||
|
||||
using namespace tensorrt_llm::runtime;
|
||||
@ -44,11 +43,8 @@ nvinfer1::Dims ITensor::makeShape(std::initializer_list<SizeType> const& dims)
|
||||
{
|
||||
TLLM_CHECK_WITH_INFO(dims.size() <= nvinfer1::Dims::MAX_DIMS, "Number of dimensions is too large");
|
||||
nvinfer1::Dims shape{};
|
||||
shape.nbDims = dims.size();
|
||||
for (std::size_t i = 0; i < dims.size(); ++i)
|
||||
{
|
||||
shape.d[i] = std::data(dims)[i];
|
||||
}
|
||||
shape.nbDims = static_cast<decltype(Shape::nbDims)>(dims.size());
|
||||
std::copy(dims.begin(), dims.end(), shape.d);
|
||||
return shape;
|
||||
}
|
||||
|
||||
@ -97,6 +93,32 @@ ITensor::UniquePtr ITensor::wrap(void* data, nvinfer1::DataType type, nvinfer1::
|
||||
return result;
|
||||
}
|
||||
|
||||
ITensor::Shape ITensor::squeeze(Shape const& shape, SizeType dim)
|
||||
{
|
||||
TLLM_CHECK_WITH_INFO(0 < shape.nbDims, "Cannot squeeze 1-dimensional tensor");
|
||||
TLLM_CHECK_WITH_INFO(
|
||||
dim < shape.nbDims, tc::fmtstr("Invalid index %d, tensor has %d dimensions", dim, shape.nbDims));
|
||||
TLLM_CHECK_WITH_INFO(shape.d[dim] == 1, "Can only squeeze dimension of size 1");
|
||||
|
||||
Shape newDims{shape.nbDims - 1};
|
||||
std::copy(shape.d, shape.d + dim, newDims.d);
|
||||
std::copy(shape.d + dim + 1, shape.d + shape.nbDims, newDims.d + dim);
|
||||
return newDims;
|
||||
}
|
||||
|
||||
ITensor::Shape ITensor::unsqueeze(Shape const& shape, SizeType dim)
|
||||
{
|
||||
TLLM_CHECK_WITH_INFO(shape.nbDims < Shape::MAX_DIMS, "Too many dimensions to unsqueeze");
|
||||
TLLM_CHECK_WITH_INFO(
|
||||
0 <= dim && dim <= shape.nbDims, common::fmtstr("Invalid dim %d, tensor has %d dimensions", dim, shape.nbDims));
|
||||
|
||||
Shape newDims{shape.nbDims + 1};
|
||||
std::copy(shape.d, shape.d + dim, newDims.d);
|
||||
newDims.d[dim] = 1;
|
||||
std::copy(shape.d + dim, shape.d + shape.nbDims, newDims.d + dim + 1);
|
||||
return newDims;
|
||||
}
|
|
||||
{
|
||||
template <typename T>
|
||||
|
||||
@ -21,7 +21,7 @@ namespace tensorrt_llm::runtime
|
||||
|
||||
void PromptTuningParams::fillTasksTensor(TensorPtr tasksHost, const SizeType batchSize,
|
||||
const SizeType numContextRequests, const std::vector<SizeType>& reqBeamWidths,
|
||||
const std::vector<SizeType>& reqPromptLengths, BufferManager& manager, bool packedInput)
|
||||
const std::vector<SizeType>& reqPromptLengths, BufferManager const& manager, bool packedInput)
|
||||
{
|
||||
auto const& tasksHostShape = tasksHost->getShape();
|
||||
TLLM_CHECK_WITH_INFO(tasksHostShape.nbDims == 1, "tasksHost expected to have dimension [batchSize]");
|
||||
|
||||
@ -29,7 +29,8 @@ using namespace tensorrt_llm::runtime;
|
||||
namespace tc = tensorrt_llm::common;
|
||||
|
||||
RuntimeBuffers::GenerationConfig RuntimeBuffers::GenerationConfig::fromInput(ITensor const& inputIds,
|
||||
ITensor const& inputLengthsHost, bool const inputPacked, SizeType const beamWidth, SizeType const maxSequenceLength)
|
||||
ITensor const& inputLengthsHost, bool const inputPacked, SizeType const beamWidth, SizeType const maxKvCacheLength,
|
||||
SizeType const maxSequenceLength)
|
||||
{
|
||||
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
|
||||
auto const batchSize = static_cast<SizeType>(inputLengthsHost.getSize());
|
||||
@ -57,7 +58,7 @@ RuntimeBuffers::GenerationConfig RuntimeBuffers::GenerationConfig::fromInput(ITe
|
||||
"generated.");
|
||||
|
||||
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
|
||||
return GenerationConfig{batchSize, beamWidth, maxInputLength, maxSequenceLength, inputLengthSum};
|
||||
return GenerationConfig{batchSize, beamWidth, maxInputLength, maxKvCacheLength, maxSequenceLength, inputLengthSum};
|
||||
}
|
||||
|
||||
void RuntimeBuffers::clear()
|
||||
@ -154,6 +155,10 @@ void RuntimeBuffers::create(TllmRuntime& runtime, GptModelConfig const& modelCon
|
||||
if (modelConfig.useGptAttentionPlugin())
|
||||
{
|
||||
pastKeyValueLengths = manager.emptyTensor(MemoryType::kCPU, nvinfer1::DataType::kINT32);
|
||||
for (SizeType i = 0; i < modelConfig.getNbLayers(); ++i)
|
||||
{
|
||||
maxKvCacheLengths.emplace_back(manager.emptyTensor(MemoryType::kCPU, nvinfer1::DataType::kINT32));
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
@ -179,7 +184,7 @@ void RuntimeBuffers::create(TllmRuntime& runtime, GptModelConfig const& modelCon
|
||||
}
|
||||
|
||||
void RuntimeBuffers::initFromInput(ITensor const& inputIds, TensorPtr const& inputLengths, bool inputPacked,
|
||||
SizeType beamWidth, SizeType maxSequenceLength, BufferManager& manager)
|
||||
SizeType beamWidth, SizeType maxKvCacheLength, SizeType maxSequenceLength, BufferManager& manager)
|
||||
{
|
||||
contextLengthsDevice = inputLengths;
|
||||
contextLengthsHost->reshape(inputLengths->getShape());
|
||||
@ -187,7 +192,7 @@ void RuntimeBuffers::initFromInput(ITensor const& inputIds, TensorPtr const& inp
|
||||
manager.getStream().synchronize(); // wait for context lengths to be copied to host
|
||||
|
||||
generationConfig = RuntimeBuffers::GenerationConfig::fromInput(
|
||||
inputIds, *contextLengthsHost, inputPacked, beamWidth, maxSequenceLength);
|
||||
inputIds, *contextLengthsHost, inputPacked, beamWidth, maxKvCacheLength, maxSequenceLength);
|
||||
}
|
||||
|
||||
void RuntimeBuffers::reshape(GptModelConfig const& modelConfig, WorldConfig const& worldConfig)
|
||||
@ -197,7 +202,7 @@ void RuntimeBuffers::reshape(GptModelConfig const& modelConfig, WorldConfig cons
|
||||
auto const batchSize = generationConfig.batchSize;
|
||||
auto const beamWidth = generationConfig.beamWidth;
|
||||
auto const maxInputLength = generationConfig.maxInputLength;
|
||||
auto const maxSeqLength = generationConfig.maxSeqLength;
|
||||
auto const maxKvCacheLength = generationConfig.maxKvCacheLength;
|
||||
|
||||
if (worldConfig.isLastPipelineParallelRank() && !modelConfig.computeContextLogits())
|
||||
{
|
||||
@ -207,15 +212,15 @@ void RuntimeBuffers::reshape(GptModelConfig const& modelConfig, WorldConfig cons
|
||||
|
||||
lastTokenIds->reshape(ITensor::makeShape({batchSize}));
|
||||
|
||||
auto kvCacheReserve
|
||||
= ITensor::makeShape({batchSize, 2, modelConfig.getNbKvHeads(), maxSeqLength, modelConfig.getSizePerHead()});
|
||||
auto kvCacheReserve = ITensor::makeShape(
|
||||
{batchSize, 2, modelConfig.getNbKvHeads(), maxKvCacheLength, modelConfig.getSizePerHead()});
|
||||
auto kvCacheShape
|
||||
= ITensor::makeShape({batchSize, 2, modelConfig.getNbKvHeads(), maxInputLength, modelConfig.getSizePerHead()});
|
||||
if (modelConfig.usePagedKvCache())
|
||||
{
|
||||
auto const localNbLayers = modelConfig.getNbLayers(worldConfig.getPipelineParallelism());
|
||||
auto const tokensPerBlock = modelConfig.getTokensPerBlock();
|
||||
auto const maxBlocksPerSeq = (maxSeqLength + tokensPerBlock - 1) / tokensPerBlock;
|
||||
auto const maxBlocksPerSeq = (maxKvCacheLength + tokensPerBlock - 1) / tokensPerBlock;
|
||||
|
||||
// reserve batchSize * beamWidth and resize to batchSize
|
||||
auto cacheBlockPointersShape = ITensor::makeShape({localNbLayers, batchSize * beamWidth, 2, maxBlocksPerSeq});
|
||||
@ -233,6 +238,10 @@ void RuntimeBuffers::reshape(GptModelConfig const& modelConfig, WorldConfig cons
|
||||
if (modelConfig.useGptAttentionPlugin())
|
||||
{
|
||||
pastKeyValueLengths->reshape(ITensor::makeShape({batchSize}));
|
||||
for (SizeType i = 0; i < modelConfig.getNbLayers(); ++i)
|
||||
{
|
||||
maxKvCacheLengths[i]->reshape(ITensor::makeShape({1}));
|
||||
}
|
||||
requestTypes->reshape(ITensor::makeShape({batchSize}));
|
||||
}
|
||||
else
|
||||
@ -243,7 +252,7 @@ void RuntimeBuffers::reshape(GptModelConfig const& modelConfig, WorldConfig cons
|
||||
utils::reshapeBufferVector(presentKeysVals, kvCacheShape);
|
||||
}
|
||||
|
||||
auto const cacheIndirShape = ITensor::makeShape({batchSize, beamWidth, maxSeqLength});
|
||||
auto const cacheIndirShape = ITensor::makeShape({batchSize, beamWidth, maxKvCacheLength});
|
||||
cacheIndirectionDecoderInput->reshape(cacheIndirShape);
|
||||
cacheIndirectionDecoderOutput->reshape(cacheIndirShape);
|
||||
|
||||
@ -327,6 +336,7 @@ std::vector<RuntimeBuffers> RuntimeBuffers::split(
|
||||
if (modelConfig.useGptAttentionPlugin())
|
||||
{
|
||||
buffers.pastKeyValueLengths = ITensor::slice(pastKeyValueLengths, offset, batchSize);
|
||||
buffers.maxKvCacheLengths = maxKvCacheLengths;
|
||||
buffers.requestTypes = ITensor::slice(requestTypes, offset, batchSize);
|
||||
}
|
||||
else
|
||||
@ -523,6 +533,12 @@ void RuntimeBuffers::prepareContextStep(TensorPtr const& inputIds, TokenIdType c
|
||||
TLLM_CHECK(requestTypes->getSize() == static_cast<std::size_t>(batchSize));
|
||||
std::fill_n(RequestTypesPtr, batchSize, 0);
|
||||
|
||||
// Set maxKvCacheLengths buffer to the same value currently.
|
||||
for (auto layer = 0; layer < modelConfig.getNbLayers(); ++layer)
|
||||
{
|
||||
bufferCast<SizeType>(*maxKvCacheLengths[layer])[0] = generationConfig.maxKvCacheLength;
|
||||
}
|
||||
|
||||
auto const& inputShape = inputIds->getShape();
|
||||
auto const contextLengthsHostPtr = bufferCast<SizeType const>(*contextLengthsHost);
|
||||
auto const modelVariant = modelConfig.getModelVariant();
|
||||
@ -788,6 +804,12 @@ void RuntimeBuffers::getRuntimeBuffers(TensorMap& inputBuffers, TensorMap& outpu
|
||||
inputBuffers.insert_or_assign("host_request_types", requestTypes);
|
||||
inputBuffers.insert_or_assign("sequence_length", sequenceLengths);
|
||||
|
||||
for (SizeType i = 0; i < modelConfig.getNbLayers(); ++i)
|
||||
{
|
||||
std::string name = "host_max_kv_cache_length_" + std::to_string(i);
|
||||
inputBuffers.insert_or_assign(name, maxKvCacheLengths[i]);
|
||||
}
|
|
||||
{
|
||||
inputBuffers.insert_or_assign("host_context_lengths", contextLengthsHost);
|
||||
|
||||
@ -49,10 +49,11 @@ public:
|
||||
GenerationConfig() = default;
|
||||
|
||||
explicit GenerationConfig(SizeType batchSize, SizeType beamWidth, SizeType maxInputLength,
|
||||
SizeType maxSeqLength, SizeType inputLengthSum = SizeType(0))
|
||||
SizeType maxKvCacheLength, SizeType maxSeqLength, SizeType inputLengthSum = SizeType(0))
|
||||
: batchSize{batchSize}
|
||||
, beamWidth{beamWidth}
|
||||
, maxInputLength{maxInputLength}
|
||||
, maxKvCacheLength{maxKvCacheLength}
|
||||
, maxSeqLength{maxSeqLength}
|
||||
, inputLengthSum{inputLengthSum}
|
||||
{
|
||||
@ -61,11 +62,12 @@ public:
|
||||
SizeType batchSize{};
|
||||
SizeType beamWidth{};
|
||||
SizeType maxInputLength{};
|
SizeType maxKvCacheLength{};
SizeType maxSeqLength{};
SizeType inputLengthSum{}; // Initialized only if inputPacked is set to true in fromInput.

static GenerationConfig fromInput(ITensor const& inputIds, ITensor const& inputLengths, bool inputPacked,
SizeType beamWidth, SizeType maxSequenceLength);
SizeType beamWidth, SizeType maxKvCacheLength, SizeType maxSequenceLength);
};

public:
@ -88,6 +90,7 @@ public:

std::vector<TensorPtr> presentKeysVals;
std::vector<TensorPtr> presentKeysValsAlt; // without attention plugin
std::vector<TensorPtr> maxKvCacheLengths; // with attention plugin, host tensor
TensorPtr kvCacheBlockPointersHost; // [numLayers, batchSize * beamWidth, 2, maxBlocksPerSeq * 2]
TensorPtr kvCacheBlockPointersDevice; // [numLayers, batchSize * beamWidth, 2, maxBlocksPerSeq * 2]

@ -119,7 +122,7 @@ public:
void create(TllmRuntime& runtime, GptModelConfig const& modelConfig, WorldConfig const& worldConfig);

void initFromInput(ITensor const& inputIds, TensorPtr const& inputLengths, bool inputPacked, SizeType beamWidth,
SizeType maxSequenceLength, BufferManager& manager);
SizeType maxKvCacheLength, SizeType maxSequenceLength, BufferManager& manager);

//! \brief Reshape buffers based on current GenerationConfig
void reshape(GptModelConfig const& modelConfig, WorldConfig const& worldConfig);

@ -41,7 +41,7 @@ StatefulGptDecoder::StatefulGptDecoder(std::size_t vocabSize, std::size_t vocabS
auto& dInput = mDecodingInput;
auto dummyLogits = mBufferManager.emptyTensor(MemoryType::kGPU, nvFloatType);
auto endIds = mBufferManager.emptyTensor(MemoryType::kGPU, nvTokenIdType);
dInput = std::make_unique<DecodingInput>(0, 0, std::move(dummyLogits), std::move(endIds));
dInput = std::make_unique<DecodingInput>(0, 0, 0, std::move(dummyLogits), std::move(endIds));

dInput->sequenceLimitLength = mBufferManager.emptyTensor(MemoryType::kGPU, nvSizeType);
dInput->lengths = mBufferManager.emptyTensor(MemoryType::kGPU, nvSizeType);
@ -61,17 +61,18 @@ StatefulGptDecoder::StatefulGptDecoder(std::size_t vocabSize, std::size_t vocabS
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
}

void StatefulGptDecoder::setup(
SizeType maxBatchSize, SizeType maxBeamWidth, SizeType maxSequenceLength, nvinfer1::DataType dtype)
void StatefulGptDecoder::setup(SizeType maxBatchSize, SizeType maxBeamWidth, SizeType maxKvCacheLength,
SizeType maxSequenceLength, nvinfer1::DataType dtype)
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
mDecoder = IGptDecoder::create(dtype, mVocabSize, mVocabSizePadded, mStream);

reshapeBuffers(maxBatchSize, maxBeamWidth, maxSequenceLength);
reshapeBuffers(maxBatchSize, maxBeamWidth, maxKvCacheLength, maxSequenceLength);
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
}

void StatefulGptDecoder::reshapeBuffers(SizeType batchSize, SizeType beamWidth, SizeType maxSequenceLength)
void StatefulGptDecoder::reshapeBuffers(
SizeType batchSize, SizeType beamWidth, SizeType maxKvCacheLength, SizeType maxSequenceLength)
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
TLLM_CHECK(batchSize > 0);
@ -79,6 +80,7 @@ void StatefulGptDecoder::reshapeBuffers(SizeType batchSize, SizeType beamWidth,
TLLM_CHECK(maxSequenceLength > 0);

mMaxSequenceLength = maxSequenceLength;
mMaxKvCacheLength = maxKvCacheLength;

auto const batchSizeShape = ITensor::makeShape({batchSize});
auto const batchSizeXbeamWidth = ITensor::makeShape({batchSize, beamWidth});
@ -129,7 +131,7 @@ void StatefulGptDecoder::newBatch(GenerationInput const& inputs, SamplingConfig
auto const batchSize = inputLengthsShape.d[0];
auto const beamWidth = samplingConfig.beamWidth;

reshapeBuffers(batchSize, beamWidth, mMaxSequenceLength);
reshapeBuffers(batchSize, beamWidth, mMaxKvCacheLength, mMaxSequenceLength);
mDecoder->setup(samplingConfig, batchSize);

// sanity checks, should always be true after reshape
@ -159,9 +161,10 @@ void StatefulGptDecoder::newBatch(GenerationInput const& inputs, SamplingConfig
// inputs
auto& dInput = *mDecodingInput;
dInput.maxLength = maxInputLength;
dInput.maxKvCacheLength = mMaxKvCacheLength;
dInput.batchSize = batchSize;
kernels::invokeFill(const_cast<ITensor&>(*dInput.endIds), endId, *stream);
dInput.embeddingBias = inputs.embeddingBiasOpt;
dInput.embeddingBias = inputs.embeddingBias;
dInput.badWordsList = inputs.badWordsList;
dInput.stopWordsList = inputs.stopWordsList;
auto inputLengthsView = ITensor::view(dInput.lengths, ITensor::makeShape({batchSize * beamWidth}));

@ -39,8 +39,8 @@ public:
StatefulGptDecoder(std::size_t vocabSize, std::size_t vocabSizePadded, CudaStreamPtr stream);

//! Setup the decoder before calling `forward()`
void setup(
SizeType maxBatchSize, SizeType maxBeamWidth, SizeType maxSequenceLength, nvinfer1::DataType dtype) override;
void setup(SizeType maxBatchSize, SizeType maxBeamWidth, SizeType maxKvCacheLength, SizeType maxSequenceLength,
nvinfer1::DataType dtype) override;

//! @brief Initialize the decoder with new batch of inputs.
void newBatch(GenerationInput const& input, SamplingConfig const& samplingConfig) override;
@ -72,7 +72,7 @@ public:
}

private:
void reshapeBuffers(SizeType batchSize, SizeType beamWidth, SizeType maxSequenceLength);
void reshapeBuffers(SizeType batchSize, SizeType beamWidth, SizeType mMaxKvCacheLength, SizeType maxSequenceLength);

private:
std::size_t const mVocabSize;
@ -90,5 +90,6 @@ private:

SizeType mNbSteps;
SizeType mMaxSequenceLength{};
SizeType mMaxKvCacheLength{};
};
} // namespace tensorrt_llm::runtime

@ -64,15 +64,7 @@ public:

void resize(std::size_t newSize) override
{
if (newSize != getSize())
{
using dimType = std::remove_reference_t<decltype(mDims.d[0])>;
auto constexpr max_size = std::numeric_limits<dimType>::max();
TLLM_CHECK_WITH_INFO(newSize <= max_size, "New size is too large. Use reshape() instead.");
Base::resize(newSize);
mDims.nbDims = 1;
mDims.d[0] = static_cast<dimType>(newSize);
}
ITensor::resize(newSize);
}

void release() override

@ -101,7 +101,7 @@ public:
TLLM_CHECK_WITH_INFO(static_cast<bool>(mCudaStream), "Undefined CUDA stream");
}

CudaStreamPtr getCudaStream() const
[[nodiscard]] CudaStreamPtr getCudaStream() const
{
return mCudaStream;
}
@ -236,13 +236,14 @@ public:
//!
//! \brief Construct an empty buffer.
//!
explicit GenericBuffer(nvinfer1::DataType type, TAllocator allocator = {})
explicit GenericBuffer(nvinfer1::DataType type, TAllocator allocator = {}) // NOLINT(*-pro-type-member-init)
: GenericBuffer{0, type, std::move(allocator)} {};

//!
//! \brief Construct a buffer with the specified allocation size in number of elements.
//!
explicit GenericBuffer(std::size_t size, nvinfer1::DataType type, TAllocator allocator = {})
explicit GenericBuffer( // NOLINT(*-pro-type-member-init)
std::size_t size, nvinfer1::DataType type, TAllocator allocator = {})
: GenericBuffer{size, size, type, std::move(allocator)} {};

GenericBuffer(GenericBuffer&& buf) noexcept
@ -280,21 +281,21 @@ public:
//!
void* data() override
{
return mBuffer;
return TLLM_LIKELY(mSize > 0) ? mBuffer : nullptr;
}

//!
//! \brief Returns pointer to underlying array.
//!
const void* data() const override
[[nodiscard]] void const* data() const override
{
return mBuffer;
return TLLM_LIKELY(mSize > 0) ? mBuffer : nullptr;
}

//!
//! \brief Returns the size (in number of elements) of the buffer.
//!
std::size_t getSize() const override
[[nodiscard]] std::size_t getSize() const override
{
return mSize;
}
@ -302,7 +303,7 @@ public:
//!
//! \brief Returns the capacity of the buffer.
//!
std::size_t getCapacity() const override
[[nodiscard]] std::size_t getCapacity() const override
{
return mCapacity;
}
@ -310,7 +311,7 @@ public:
//!
//! \brief Returns the type of the buffer.
//!
nvinfer1::DataType getDataType() const override
[[nodiscard]] nvinfer1::DataType getDataType() const override
{
return mType;
}
@ -318,7 +319,7 @@ public:
//!
//! \brief Returns the memory type of the buffer.
//!
MemoryType getMemoryType() const override
[[nodiscard]] MemoryType getMemoryType() const override
{
return mAllocator.getMemoryType();
}
@ -328,11 +329,7 @@ public:
//!
void resize(std::size_t newSize) override
{
if (newSize == 0)
{
release();
}
else if (mCapacity < newSize)
if (mCapacity < newSize)
{
mAllocator.deallocate(mBuffer, toBytes(mCapacity));
mBuffer = mAllocator.allocate(toBytes(newSize));
@ -444,7 +441,7 @@ public:
return *this;
}

nvinfer1::Dims const& getShape() const override
[[nodiscard]] nvinfer1::Dims const& getShape() const override
{
return mDims;
}
@ -457,15 +454,7 @@ public:

void resize(std::size_t newSize) override
{
if (newSize != getSize())
{
using dimType = std::remove_reference_t<decltype(mDims.d[0])>;
auto constexpr max_size = std::numeric_limits<dimType>::max();
TLLM_CHECK_WITH_INFO(newSize <= max_size, "New size is too large. Use reshape() instead.");
Base::resize(newSize);
mDims.nbDims = 1;
mDims.d[0] = static_cast<dimType>(newSize);
}
ITensor::resize(newSize);
}

void release() override

@ -42,16 +42,12 @@ public:

void* data() override
{
if (getSize() == 0)
return nullptr;
return mTensor.data_ptr();
return TLLM_LIKELY(getSize() > 0) ? mTensor.data_ptr() : nullptr;
}

[[nodiscard]] void const* data() const override
{
if (getSize() == 0)
return nullptr;
return mTensor.data_ptr();
return TLLM_LIKELY(getSize() > 0) ? mTensor.data_ptr() : nullptr;
}

[[nodiscard]] size_t getSize() const override
@ -76,17 +72,7 @@ public:

void resize(std::size_t newSize) override
{
TLLM_CHECK(newSize <= getCapacity());

if (newSize != getSize())
{
using dimType = std::remove_reference_t<decltype(mDims.d[0])>;
auto constexpr max_size = std::numeric_limits<dimType>::max();
TLLM_CHECK_WITH_INFO(newSize <= max_size, "New size is too large. Use reshape() instead.");
mTensor.resize_({static_cast<at::IntArrayRef::value_type>(newSize)});
mDims.nbDims = 1;
mDims.d[0] = static_cast<dimType>(newSize);
}
ITensor::resize(newSize);
}

void release() override

@ -138,7 +138,7 @@ void FtDynamicDecode<T>::setup(size_t batch_size, size_t beam_width, th::optiona

template <typename T>
void FtDynamicDecode<T>::forward(th::Tensor& logits, // (batch_size, beam_width, hidden_size)
int step, int max_input_length, uint64_t ite, int local_batch_size, th::Tensor end_id,
int step, int max_input_length, int max_kv_cache_length, uint64_t ite, int local_batch_size, th::Tensor end_id,
th::optional<th::Tensor> embedding_bias_opt, th::optional<th::Tensor> input_lengths_opt,
th::optional<th::Tensor> sequence_limit_length_opt, th::optional<th::Tensor> stop_words_list_opt,
th::optional<th::Tensor> bad_words_list_opt, th::optional<th::Tensor> no_repeat_ngram_size_opt,
@ -156,8 +156,8 @@ void FtDynamicDecode<T>::forward(th::Tensor& logits, // (batch_size, beam_width,
{
auto const& logits_converted = convert_tensor<float>(logits);
auto const& end_ids_converted = convert_tensor<int>(end_id);
typename tensorrt_llm::layers::DynamicDecodeLayer<T>::ForwardParams forwardParams{
step, static_cast<int>(ite), max_input_length, local_batch_size, logits_converted, end_ids_converted};
typename tensorrt_llm::layers::DynamicDecodeLayer<T>::ForwardParams forwardParams{step, static_cast<int>(ite),
max_input_length, max_kv_cache_length, local_batch_size, logits_converted, end_ids_converted};

safeUpdate<int>(src_cache_indirection_opt, forwardParams.src_cache_indirection);
safeUpdate<int>(sequence_limit_length_opt, forwardParams.sequence_limit_length);
@ -272,8 +272,9 @@ void DynamicDecodeOp::setup(int64_t batch_size, int64_t beam_width, th::optional
top_p_reset_ids_opt);
}

th::Tensor DynamicDecodeOp::forward(th::Tensor logits, int64_t step, int64_t max_input_length, int64_t ite,
int64_t local_batch_size, th::Tensor end_id, th::optional<th::Tensor> embedding_bias_opt,
th::Tensor DynamicDecodeOp::forward(th::Tensor logits, int64_t step, int64_t max_input_length,
int64_t max_kv_cache_length, int64_t ite, int64_t local_batch_size, th::Tensor end_id,
th::optional<th::Tensor> embedding_bias_opt,
th::optional<th::Tensor> input_lengths_opt, // length of input contexts.
th::optional<th::Tensor> sequence_limit_length_opt, th::optional<th::Tensor> stop_words_list_opt,
th::optional<th::Tensor> bad_words_list_opt, th::optional<th::Tensor> no_repeat_ngram_size_opt,
@ -339,9 +340,10 @@ th::Tensor DynamicDecodeOp::forward(th::Tensor logits, int64_t step, int64_t max

dynamic_decode_->forward(
// Inputs
logits, static_cast<int>(step), static_cast<int>(max_input_length), static_cast<uint32_t>(ite),
static_cast<int>(local_batch_size), end_id, embedding_bias_opt, input_lengths_opt, sequence_limit_length_opt,
stop_words_list_opt, bad_words_list_opt, no_repeat_ngram_size_opt, src_cache_indirection_opt,
logits, static_cast<int>(step), static_cast<int>(max_input_length), static_cast<int>(max_kv_cache_length),
static_cast<uint32_t>(ite), static_cast<int>(local_batch_size), end_id, embedding_bias_opt, input_lengths_opt,
sequence_limit_length_opt, stop_words_list_opt, bad_words_list_opt, no_repeat_ngram_size_opt,
src_cache_indirection_opt,
// Outputs
output_token_ids, newTokens, should_stop, finished_opt, seuqence_lengths_opt, cum_log_probs_opt,
output_log_probs_opt, parent_ids_opt, tgt_cache_indirection_opt, beam_hyps_output_ids_tgt_opt,

@ -39,7 +39,7 @@ public:
= 0;

virtual void forward(th::Tensor& logits, // (batch_size, beam_width, hidden_size)
int step, int max_input_length, uint64_t ite, int local_batch_size, th::Tensor end_id,
int step, int max_input_length, int max_kv_cache_length, uint64_t ite, int local_batch_size, th::Tensor end_id,
th::optional<th::Tensor> embedding_bias_opt, th::optional<th::Tensor> input_lengths_opt,
th::optional<th::Tensor> sequence_limit_length_opt, th::optional<th::Tensor> stop_words_list_opt,
th::optional<th::Tensor> bad_words_list_opt, th::optional<th::Tensor> no_repeat_ngram_size_opt,
@ -77,7 +77,7 @@ public:
th::optional<th::Tensor> top_p_reset_ids_opt) override;

void forward(th::Tensor& logits, // (batch_size, beam_width, hidden_size)
int step, int max_input_length, uint64_t ite, int local_batch_size, th::Tensor end_id,
int step, int max_input_length, int max_kv_cache_length, uint64_t ite, int local_batch_size, th::Tensor end_id,
th::optional<th::Tensor> embedding_bias_opt, th::optional<th::Tensor> input_lengths_opt,
th::optional<th::Tensor> sequence_limit_length_opt, th::optional<th::Tensor> stop_words_list_opt,
th::optional<th::Tensor> bad_words_list_opt, th::optional<th::Tensor> no_repeat_ngram_size_opt,
@ -121,8 +121,8 @@ public:
th::optional<th::Tensor> top_p_reset_ids_opt);

th::Tensor forward(th::Tensor logits, // (batch_size, beam_width, vocab_size)
int64_t step, int64_t max_input_length, int64_t ite, int64_t local_batch_size, th::Tensor end_id,
th::optional<th::Tensor> embedding_bias_opt,
int64_t step, int64_t max_input_length, int64_t max_kv_cache_length, int64_t ite, int64_t local_batch_size,
th::Tensor end_id, th::optional<th::Tensor> embedding_bias_opt,
th::optional<th::Tensor> input_lengths_opt, // length of input contexts.
th::optional<th::Tensor> sequence_limit_length_opt, th::optional<th::Tensor> stop_words_list_opt,
th::optional<th::Tensor> bad_words_list_opt, th::optional<th::Tensor> no_repeat_ngram_size_opt,

@ -75,7 +75,10 @@ add_gtest(bufferManagerTest runtime/bufferManagerTest.cpp)
add_gtest(runtimeKernelTest runtime/runtimeKernelTest.cpp)
add_gtest(samplingTest runtime/samplingTest.cpp)
add_gtest(iTensorTest runtime/iTensorTest.cpp)
add_gtest(torchTest runtime/torchTest.cpp)
if(${BUILD_PYT})
add_gtest(torchTest runtime/torchTest.cpp)
target_link_libraries(torchTest PUBLIC ${TORCH_LIBRARIES})
endif()
set(SAMPLING_KERNEL_TEST_SRC
kernels/sampling/samplingTest.cpp
kernels/sampling/samplingTopKTest.cpp
@ -83,7 +86,7 @@ set(SAMPLING_KERNEL_TEST_SRC
kernels/sampling/samplingPenaltyTest.cpp
kernels/sampling/samplingUtilsTest.cu)
add_gtest(samplingKernelsTest "${SAMPLING_KERNEL_TEST_SRC}")
target_link_libraries(torchTest PUBLIC ${TORCH_LIBRARIES})
add_gtest(weightOnlyKernelTest kernels/weightOnly/weightOnlyKernelTest.cpp)

if(BUILD_BATCH_MANAGER)
add_subdirectory(batch_manager)
cpp/tests/kernels/weightOnly/weightOnlyKernelTest.cpp (new file, 429 lines)
@ -0,0 +1,429 @@
#include <cuda_profiler_api.h>
#include <cuda_runtime.h>
#include <gtest/gtest.h>

#include "cutlass/numeric_types.h"
#include "tensorrt_llm/common/quantization.h"
#include "tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm.h"
#include "tensorrt_llm/kernels/weightOnlyBatchedGemv/enabled.h"
#include "tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelLauncher.h"

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <ctime>
#include <functional>
#include <iostream>
#include <random>
#include <set>
#include <type_traits>
#include <vector>

using tensorrt_llm::kernels::WeightOnlyParams;
using tensorrt_llm::kernels::WeightOnlyType;
using tensorrt_llm::kernels::WeightOnlyQuantType;
using tensorrt_llm::kernels::WeightOnlyActivationType;
using tensorrt_llm::kernels::WeightOnlyActivationFunctionType;
template <WeightOnlyActivationType T>
struct AType;

template <>
struct AType<WeightOnlyActivationType::FP16>
{
using CudaKernelAType = half;
using CutlassKernelAType = half;
};
#if defined(ENABLE_BF16)
template <>
struct AType<WeightOnlyActivationType::BF16>
{
using CudaKernelAType = __nv_bfloat16;
using CutlassKernelAType = __nv_bfloat16;
};
#endif
template <WeightOnlyQuantType T>
struct BType;

template <>
struct BType<WeightOnlyQuantType::Int4b>
{
using CudaKernelBType = uint8_t;
using CutlassKernelBType = cutlass::uint4b_t;
static constexpr int elemsPerByte = 2;
};

template <>
struct BType<WeightOnlyQuantType::Int8b>
{
using CudaKernelBType = uint8_t;
using CutlassKernelBType = uint8_t;
static constexpr int elemsPerByte = 1;
};
struct CutlassKernel;
struct CudaKernel;

template <typename KernelFlag, WeightOnlyActivationType AFlag, WeightOnlyQuantType BFlag>
float benchmark_perchannel(void* act, void* weight, void* scales, void* zeros, void* bias, void* out, int m, int n,
int k, int group_size, int warmup, int iter)
{
assert(zeros == nullptr && bias == nullptr && group_size == 0);
cudaStream_t s;
cudaStreamCreate(&s);
cudaEvent_t begin, end;
cudaEventCreate(&begin);
cudaEventCreate(&end);
if constexpr (std::is_same_v<KernelFlag, CudaKernel>)
{
WeightOnlyParams params{reinterpret_cast<uint8_t*>(weight), scales, zeros, act, bias, out, m, n, k, group_size,
BFlag, WeightOnlyType::PerChannel, WeightOnlyActivationFunctionType::Identity, AFlag};
for (int i = 0; i < warmup; ++i)
{
tensorrt_llm::kernels::weight_only_batched_gemv_launcher(params, s);
}
cudaEventRecord(begin, s);
for (int i = 0; i < iter; ++i)
{
tensorrt_llm::kernels::weight_only_batched_gemv_launcher(params, s);
}
}
else if (std::is_same_v<KernelFlag, CutlassKernel>)
{
tensorrt_llm::kernels::cutlass_kernels::CutlassFpAIntBGemmRunner<typename AType<AFlag>::CutlassKernelAType,
typename BType<BFlag>::CutlassKernelBType, cutlass::WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY>
gemm;
auto configs = gemm.getConfigs();
int ws_bytes = gemm.getWorkspaceSize(m, n, k);
char* ws_ptr = nullptr;
if (ws_bytes)
cudaMalloc(&ws_ptr, ws_bytes);
float fast_time = 1e8;
auto best_config = configs[0];
for (auto& config : configs)
{
for (int i = 0; i < 2; ++i)
{
gemm.gemm(act, weight, scales, out, m, n, k, config, ws_ptr, ws_bytes, s);
}
cudaEventRecord(begin, s);
for (int i = 0; i < 5; ++i)
{
gemm.gemm(act, weight, scales, out, m, n, k, config, ws_ptr, ws_bytes, s);
}
cudaEventRecord(end, s);
cudaEventSynchronize(end);
float time;
cudaEventElapsedTime(&time, begin, end);
fast_time = std::min(fast_time, time);
if (time < fast_time)
{
fast_time = time;
best_config = config;
}
}

for (int i = 0; i < warmup; ++i)
{
gemm.gemm(act, weight, scales, out, m, n, k, best_config, ws_ptr, ws_bytes, s);
}
cudaProfilerStart();
cudaEventRecord(begin, s);
for (int i = 0; i < iter; ++i)
{
gemm.gemm(act, weight, scales, out, m, n, k, best_config, ws_ptr, ws_bytes, s);
}
if (ws_ptr)
cudaFree(ws_ptr);
}

cudaEventRecord(end, s);
cudaEventSynchronize(end);
float time;
cudaEventElapsedTime(&time, begin, end);
cudaEventDestroy(begin);
cudaEventDestroy(end);
cudaStreamDestroy(s);
return time / iter;
}

template <typename KernelFlag, WeightOnlyActivationType AFlag, WeightOnlyQuantType BFlag>
float benchmark_groupwise(void* act, void* weight, void* scales, void* zeros, void* bias, void* out, int m, int n,
int k, int group_size, int warmup, int iter)
{
assert(zeros && bias && (group_size == 64 || group_size == 128));
cudaStream_t s;
cudaStreamCreate(&s);
cudaEvent_t begin, end;
cudaEventCreate(&begin);
cudaEventCreate(&end);
if constexpr (std::is_same_v<KernelFlag, CudaKernel>)
{
WeightOnlyParams params{reinterpret_cast<uint8_t*>(weight), scales, zeros, act, bias, out, m, n, k, group_size,
BFlag, WeightOnlyType::GroupWise, WeightOnlyActivationFunctionType::Identity, AFlag};
for (int i = 0; i < warmup; ++i)
{
tensorrt_llm::kernels::weight_only_batched_gemv_launcher(params, s);
}
cudaEventRecord(begin, s);
for (int i = 0; i < iter; ++i)
{
tensorrt_llm::kernels::weight_only_batched_gemv_launcher(params, s);
}
}
else if (std::is_same_v<KernelFlag, CutlassKernel>)
{
tensorrt_llm::kernels::cutlass_kernels::CutlassFpAIntBGemmRunner<typename AType<AFlag>::CutlassKernelAType,
typename BType<BFlag>::CutlassKernelBType, cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS>
gemm;
auto configs = gemm.getConfigs();
int ws_bytes = gemm.getWorkspaceSize(m, n, k);
char* ws_ptr = nullptr;
if (ws_bytes)
cudaMalloc(&ws_ptr, ws_bytes);
float fast_time = 1e8;
auto best_config = configs[0];
for (auto& config : configs)
{
for (int i = 0; i < 2; ++i)
{
gemm.gemm(act, weight, scales, zeros, bias, out, m, n, k, group_size, config, ws_ptr, ws_bytes, s);
}
cudaEventRecord(begin, s);
for (int i = 0; i < 5; ++i)
{
gemm.gemm(act, weight, scales, zeros, bias, out, m, n, k, group_size, config, ws_ptr, ws_bytes, s);
}
cudaEventRecord(end, s);
cudaEventSynchronize(end);
float time;
cudaEventElapsedTime(&time, begin, end);
fast_time = std::min(fast_time, time);
if (time < fast_time)
{
fast_time = time;
best_config = config;
}
}

for (int i = 0; i < warmup; ++i)
{
gemm.gemm(act, weight, scales, zeros, bias, out, m, n, k, group_size, best_config, ws_ptr, ws_bytes, s);
}
cudaProfilerStart();
cudaEventRecord(begin, s);
for (int i = 0; i < iter; ++i)
{
gemm.gemm(act, weight, scales, zeros, bias, out, m, n, k, group_size, best_config, ws_ptr, ws_bytes, s);
}
if (ws_ptr)
cudaFree(ws_ptr);
}

cudaEventRecord(end, s);
cudaEventSynchronize(end);
float time;
cudaEventElapsedTime(&time, begin, end);
cudaEventDestroy(begin);
cudaEventDestroy(end);
cudaStreamDestroy(s);
return time / iter;
}

struct CudaBuffer
{
void* _data;
int _size;

CudaBuffer(int size_in_bytes)
: _size(size_in_bytes)
{
cudaMalloc(&_data, _size);
}

template <typename T = void>
T* data()
{
return reinterpret_cast<T*>(_data);
}

void copy_to(void* dst)
{
cudaMemcpy(dst, _data, _size, cudaMemcpyDeviceToHost);
}

void copy_from(void* src)
{
cudaMemcpy(_data, src, _size, cudaMemcpyHostToDevice);
}

~CudaBuffer()
{
cudaFree(_data);
}
};

template <typename T>
float compare(void* _pa, void* _pb, int size, float scale)
{
auto pa = reinterpret_cast<T*>(_pa);
auto pb = reinterpret_cast<T*>(_pb);
float max_diff = 0.f, tot_diff = 0.f;
float max_val = 0.f;
int diff_cnt = 0;
float threshold = 1e-7;
for (int n = 0; n < size; ++n)
{
float va = static_cast<float>(pa[n]);
float vb = static_cast<float>(pb[n]);
max_val = std::max(max_val, vb);
float diff = std::abs(va - vb);
if (diff > threshold)
{
max_diff = std::max(max_diff, diff);
tot_diff += diff;
++diff_cnt;
}
}
float diff_thres = max_val * scale;
#if defined(ENABLE_BF16)
if constexpr (std::is_same_v<T, __nv_bfloat16>)
{
// bfloat16 has fewer mantissa digits than float16, so the cumulative error will be larger.
diff_thres *= 2.f;
}
else
#endif
{
diff_thres *= 1.5f;
}
printf("max diff %f (diff threshold %f), avg diff %f, diff cnt %d/%d\n", max_diff, diff_thres, tot_diff / diff_cnt,
diff_cnt, size);
return max_diff <= diff_thres;
}

template <typename T1, typename T2>
void random_fill(std::vector<T1>& vec, T2 minv, T2 maxv)
{
std::random_device rd;
std::mt19937 gen(rd());
std::uniform_real_distribution<float> dis(static_cast<float>(minv), static_cast<float>(maxv));
for (auto& v : vec)
{
v = static_cast<T1>(dis(gen));
}
}

template <WeightOnlyActivationType AFlag, WeightOnlyQuantType BFlag>
bool benchmark(int m, int n, int k, int group_size, int warmup, int iter)
{
printf("benchmark mnk (%d, %d, %d) ", m, n, k);
if (AFlag == WeightOnlyActivationType::FP16)
{
printf("FP16 Activation ");
}
else
{
printf("BF16 Activation ");
}
if (BFlag == WeightOnlyQuantType::Int8b)
{
printf("Int8b ");
}
else
{
printf("Int4b ");
}
if (group_size == 0)
{
printf("PerChannel Weight Only\n");
}
else
{
printf("GroupWise%d Weight Only\n", group_size);
}
using AT = typename AType<AFlag>::CudaKernelAType;
using BT = typename BType<BFlag>::CudaKernelBType;
constexpr int elem_per_byte = BType<BFlag>::elemsPerByte;
CudaBuffer d_act(m * k * sizeof(AT));
CudaBuffer d_weight(k * n * sizeof(uint8_t) / elem_per_byte);
CudaBuffer d_scales(n * k * sizeof(AT));
CudaBuffer d_zeros(n * k * sizeof(AT));
CudaBuffer d_bias(n * sizeof(AT));
CudaBuffer d_out(m * n * sizeof(AT));
std::vector<AT> h_act(m * k);
std::vector<uint8_t> h_weight(k * n);
std::vector<AT> h_scales(n * k), h_zeros(n * k), h_bias(n);
std::vector<AT> h_out1(m * n), h_out2(m * n);

random_fill(h_act, -1.f, 1.f);
random_fill(h_scales, -1.f, 1.f);

for (uint8_t& v : h_weight)
{
v = rand() % 256;
}

d_act.copy_from(h_act.data());
d_weight.copy_from(h_weight.data());
d_scales.copy_from(h_scales.data());
d_zeros.copy_from(h_zeros.data());
d_bias.copy_from(h_bias.data());

void* p_zeros = nullptr;
void* p_bias = nullptr;
if (group_size == 64 || group_size == 128)
{
p_zeros = d_zeros.data();
p_bias = d_bias.data();
}

float time1, time2;
time1 = benchmark_perchannel<CudaKernel, AFlag, BFlag>(
d_act.data(), d_weight.data(), d_scales.data(), p_zeros, p_bias, d_out.data(), m, n, k, 0, warmup, iter);
d_out.copy_to(h_out1.data());
time2 = benchmark_perchannel<CutlassKernel, AFlag, BFlag>(
d_act.data(), d_weight.data(), d_scales.data(), p_zeros, p_bias, d_out.data(), m, n, k, 0, warmup, iter);
d_out.copy_to(h_out2.data());
float quant_scale = 1.f / (1 << (8 / elem_per_byte - 1));
bool pass = compare<AT>(h_out1.data(), h_out2.data(), m * n, quant_scale);
printf(
"cuda kernel cost time %.6f, cutlass kernel cost time %.6f, cuda speedup %.3f\n", time1, time2, time2 / time1);
return pass;
}

TEST(Kernel, WeightOnly)
{
bool pass;
int warmup = 10, iter = 30;
std::vector<int> ms{1, 2, 4};
std::vector<int> ns{512, 1024, 2048, 4096};
std::vector<int> ks{512, 1024, 2048, 4096};
std::vector<int> gss{0, 64, 128};
for (auto m : ms)
{
for (auto n : ns)
{
for (auto k : ks)
{
for (auto gs : gss)
{
pass = benchmark<WeightOnlyActivationType::FP16, WeightOnlyQuantType::Int8b>(
m, n, k, gs, warmup, iter);
EXPECT_TRUE(pass);
pass = benchmark<WeightOnlyActivationType::FP16, WeightOnlyQuantType::Int4b>(
m, n, k, gs, warmup, iter);
EXPECT_TRUE(pass);
#if defined(ENABLE_BF16)
pass = benchmark<WeightOnlyActivationType::BF16, WeightOnlyQuantType::Int8b>(
m, n, k, gs, warmup, iter);
EXPECT_TRUE(pass);
pass = benchmark<WeightOnlyActivationType::BF16, WeightOnlyQuantType::Int4b>(
m, n, k, gs, warmup, iter);
EXPECT_TRUE(pass);
#endif
}
}
}
}
}
@ -120,8 +120,10 @@ void testDecoder(nvinfer1::DataType const dtype, std::vector<SamplingConfig> con
SizeType constexpr maxInputLength{8};
SizeType constexpr maxNewTokens{2};
auto constexpr maxSeqLength = maxInputLength + maxNewTokens;
// We set maxKvCacheLength = maxSeqLength, but it can be smaller than maxSeqLength (cyclic kv cache).
auto const maxKvCacheLength = maxSeqLength;

decoder.setup(batchSize, maxBeamWidth, maxSeqLength, modelConfig.getDataType());
decoder.setup(batchSize, maxBeamWidth, maxSeqLength, maxKvCacheLength, modelConfig.getDataType());

std::vector<SizeType> const inputLengths{4, 5, 6, 7};
std::vector<SizeType> tiledInputLengths;
@ -240,8 +242,10 @@ void testDecoderWavefront(
SizeType constexpr maxInputLength{8};
SizeType constexpr maxNewTokens{8};
auto constexpr maxSeqLength = maxInputLength + maxNewTokens;
// We set maxKvCacheLength = maxSeqLength, but it can be smaller than maxSeqLength (cyclic kv cache).
auto const maxKvCacheLength = maxSeqLength;

decoder.setup(batchSize, maxBeamWidth, maxSeqLength, modelConfig.getDataType());
decoder.setup(batchSize, maxBeamWidth, maxSeqLength, maxKvCacheLength, modelConfig.getDataType());

std::vector<SizeType> const inputLengths{4, 5, 6, 7};
std::vector<SizeType> tiledInputLengths;

@ -71,7 +71,7 @@ void testDecoder(nvinfer1::DataType const dtype, SamplingConfig const& samplingC
auto endIds
= std::shared_ptr(manager.copyFrom(endIdsVec, ITensor::makeShape({batchSize, beamWidth}), MemoryType::kGPU));

DecodingInput inputs{maxInputLength, batchSize, logits, endIds};
DecodingInput inputs{maxInputLength, maxSeqLength, batchSize, logits, endIds};
std::vector<std::int32_t> sequenceLimitLengthsVec(batchSize, maxSeqLength);
inputs.sequenceLimitLength
= manager.copyFrom(sequenceLimitLengthsVec, ITensor::makeShape({batchSize}), MemoryType::kGPU);

@ -739,6 +739,7 @@ void testChatGlmSession(fs::path const& modelPath, std::string const& modelName,
samplingConfig.randomSeed = std::vector{1ull};
samplingConfig.topK = std::vector{1};
samplingConfig.topP = std::vector{1.0f};
samplingConfig.lengthPenalty = std::vector{1.0f};

auto const padId = modelIds.padId;
auto const endId = modelIds.endId;

@ -21,12 +21,30 @@
#include "tensorrt_llm/runtime/iTensor.h"

using namespace tensorrt_llm::runtime;
using namespace ::testing;
namespace tc = tensorrt_llm::common;

namespace
TEST(ITensorTest, SqueezeTensor)
{
auto dims = ITensor::makeShape({16, 1, 4});
auto constexpr dataType = nvinfer1::DataType::kFLOAT;
ITensor::SharedPtr tensor{BufferManager::cpu(dims, dataType)};

TEST(iTensorTest, UnsqueezeShape)
auto squeezeDim = 0;
EXPECT_THROW(tensor->squeeze(squeezeDim), std::runtime_error);
squeezeDim = 1;
auto squeezed = ITensor::view(tensor, ITensor::squeeze(dims, squeezeDim));

EXPECT_EQ(squeezed->getSize(), tensor->getSize());
EXPECT_EQ(squeezed->getShape().nbDims, tensor->getShape().nbDims - 1);
EXPECT_EQ(squeezed->getShape().d[0], tensor->getShape().d[0]);
EXPECT_EQ(squeezed->getShape().d[1], tensor->getShape().d[2]);

EXPECT_NO_THROW(squeezed->release());
EXPECT_EQ(squeezed->data(), nullptr);
EXPECT_NE(tensor->data(), nullptr);
}

TEST(ITensorTest, UnsqueezeShape)
{
auto oldShape = ITensor::makeShape({2, 3, 4, 5});
{
@ -66,7 +84,7 @@ TEST(iTensorTest, UnsqueezeShape)
{
try
{
auto shape = ITensor::unsqueeze(oldShape, invalidDim);
ITensor::unsqueeze(oldShape, invalidDim);
FAIL() << "Expected failure";
}
catch (tensorrt_llm::common::TllmException const& e)
@ -80,13 +98,12 @@ TEST(iTensorTest, UnsqueezeShape)
}
}

TEST(iTensorTest, UnsqueezeTensor)
TEST(ITensorTest, UnsqueezeTensor)
{
auto oldShape = ITensor::makeShape({2, 3, 4, 5});
BufferManager manager(std::make_shared<CudaStream>());

{
auto tensor = manager.cpu(oldShape, nvinfer1::DataType::kINT32);
auto tensor = BufferManager::cpu(oldShape, nvinfer1::DataType::kINT32);
tensor->unsqueeze(0);
auto shape = tensor->getShape();

@ -98,7 +115,7 @@ TEST(iTensorTest, UnsqueezeTensor)
EXPECT_EQ(shape.d[4], 5);
}
{
auto tensor = manager.cpu(oldShape, nvinfer1::DataType::kINT32);
auto tensor = BufferManager::cpu(oldShape, nvinfer1::DataType::kINT32);
tensor->unsqueeze(1);
auto shape = tensor->getShape();

@ -111,7 +128,7 @@ TEST(iTensorTest, UnsqueezeTensor)
}

{
auto tensor = manager.cpu(oldShape, nvinfer1::DataType::kINT32);
auto tensor = BufferManager::cpu(oldShape, nvinfer1::DataType::kINT32);
tensor->unsqueeze(4);
auto shape = tensor->getShape();

@ -128,7 +145,7 @@ TEST(iTensorTest, UnsqueezeTensor)
{
try
{
auto tensor = manager.cpu(oldShape, nvinfer1::DataType::kINT32);
auto tensor = BufferManager::cpu(oldShape, nvinfer1::DataType::kINT32);
tensor->unsqueeze(invalidDim);
FAIL() << "Expected failure";
}
@ -143,4 +160,61 @@ TEST(iTensorTest, UnsqueezeTensor)
}
}

} // namespace
TEST(ITensorTest, TensorView)
{
auto const dims = ITensor::makeShape({16, 1, 4});
auto constexpr dataType = nvinfer1::DataType::kFLOAT;
ITensor::SharedPtr tensor = BufferManager::cpu(dims, dataType);

auto const viewDims = ITensor::makeShape({16, 1, 2});

auto view = ITensor::view(tensor, viewDims);
EXPECT_EQ(view->getSize(), tensor->getSize() / 2);
EXPECT_EQ(view->getShape().nbDims, tensor->getShape().nbDims);
EXPECT_EQ(view->getShape().d[2], tensor->getShape().d[2] / 2);

EXPECT_NO_THROW(view->release());
EXPECT_EQ(view->data(), nullptr);
EXPECT_NE(tensor->data(), nullptr);
}

TEST(ITensorTest, TensorSlice)
{
auto dims = ITensor::makeShape({16, 8, 4});
auto constexpr dataType = nvinfer1::DataType::kFLOAT;
ITensor::SharedPtr tensor{BufferManager::cpu(dims, dataType)};
auto offset = dims.d[0] / 4;
auto slice = ITensor::slice(tensor, offset);
auto const sizeSlice = 3 * tensor->getSize() / 4;
EXPECT_EQ(slice->getShape().d[0], dims.d[0] - offset);
EXPECT_EQ(slice->getSize(), sizeSlice);
EXPECT_EQ(slice->getCapacity(), sizeSlice);
EXPECT_EQ(static_cast<std::uint8_t*>(slice->data()) - static_cast<std::uint8_t*>(tensor->data()),
offset * ITensor::volume(dims) / dims.d[0] * BufferDataType(dataType).getSize());

auto dimsNew = ITensor::makeShape({12, 32});
EXPECT_EQ(ITensor::volume(dimsNew), sizeSlice);
EXPECT_NO_THROW(slice->reshape(dimsNew));
EXPECT_EQ(slice->getShape().d[1], dimsNew.d[1]);
dimsNew.d[0] = 6;
EXPECT_LT(ITensor::volume(dimsNew), sizeSlice);
EXPECT_NO_THROW(slice->reshape(dimsNew));
EXPECT_EQ(slice->getShape().d[0], dimsNew.d[0]);
dimsNew.d[0] = 16;
EXPECT_GT(ITensor::volume(dimsNew), sizeSlice);
EXPECT_THROW(slice->reshape(dimsNew), std::runtime_error);

EXPECT_NO_THROW(slice->resize(sizeSlice));
EXPECT_NO_THROW(slice->resize(sizeSlice / 2));
EXPECT_EQ(slice->getShape().d[0], sizeSlice / 2);
EXPECT_THROW(slice->resize(sizeSlice * 2), std::runtime_error);
EXPECT_NO_THROW(slice->release());
EXPECT_EQ(slice->data(), nullptr);
EXPECT_NE(tensor->data(), nullptr);

std::shared_ptr<ITensor const> constTensor{tensor};
auto constSlice = ITensor::slice(constTensor, offset);
EXPECT_EQ(constSlice->getShape().d[0], dims.d[0] - offset);
auto uniqueSlice = ITensor::slice(std::move(constSlice), 1);
EXPECT_EQ(uniqueSlice->getShape().d[0], dims.d[0] - offset - 1);
}

@ -104,7 +104,7 @@ typename tl::DynamicDecodeLayer<float>::OutputParams dynamicDecodeTest(BufferMan
ddLayer.setup(batchSize, beamWidth, setupParams);

typename tl::DynamicDecodeLayer<float>::ForwardParams forwardParams(
step, ite, maxInputLength, localBatchSize, logits, endIds);
step, ite, maxInputLength, static_cast<int>(maxSeqLength), localBatchSize, logits, endIds);
forwardParams.no_repeat_ngram_size = noRepeatNgramSize;

typename tl::DynamicDecodeLayer<float>::OutputParams outputParams(outputIds);

@ -36,9 +36,6 @@ protected:
void SetUp() override
{
mDeviceCount = tc::getDeviceCount();

if (mDeviceCount == 0)
GTEST_SKIP();
}

void TearDown() override {}
@ -48,6 +45,9 @@ protected:

TEST_F(TllmBuffersTest, Stream)
{
if (mDeviceCount == 0)
GTEST_SKIP();

CudaStream stream{};
EXPECT_NE(stream.get(), nullptr);
auto ptr = std::make_shared<CudaStream>();
@ -109,6 +109,9 @@ TEST_F(TllmBuffersTest, HostAllocator)

TEST_F(TllmBuffersTest, CudaAllocatorAsync)
{
if (mDeviceCount == 0)
GTEST_SKIP();

auto streamPtr = std::make_shared<CudaStream>();
auto constexpr size = 1024;
CudaAllocatorAsync allocator{streamPtr};
@ -171,6 +174,9 @@ void testBuffer(IBuffer& buffer, std::int32_t typeSize)

TEST_F(TllmBuffersTest, DeviceBuffer)
{
if (mDeviceCount == 0)
GTEST_SKIP();

auto streamPtr = std::make_shared<CudaStream>();
auto constexpr size = 1024;
CudaAllocatorAsync allocator{streamPtr};
@ -186,6 +192,9 @@ TEST_F(TllmBuffersTest, DeviceBuffer)

TEST_F(TllmBuffersTest, DeviceTensor)
{
if (mDeviceCount == 0)
GTEST_SKIP();

auto streamPtr = std::make_shared<CudaStream>();
nvinfer1::Dims constexpr dims{3, 16, 8, 4};
CudaAllocatorAsync allocator{streamPtr};
@ -228,91 +237,11 @@ TEST_F(TllmBuffersTest, BufferSlice)
EXPECT_EQ(uniqueSlice->getSize(), sizeSlice - 1);
}

TEST_F(TllmBuffersTest, TensorSlice)
{
auto dims = ITensor::makeShape({16, 8, 4});
HostAllocator allocator{};
auto constexpr dataType = nvinfer1::DataType::kFLOAT;
auto tensor = std::make_shared<HostTensor>(dims, dataType, allocator);
auto offset = dims.d[0] / 4;
auto slice = ITensor::slice(tensor, offset);
auto const sizeSlice = 3 * tensor->getSize() / 4;
EXPECT_EQ(slice->getShape().d[0], dims.d[0] - offset);
EXPECT_EQ(slice->getSize(), sizeSlice);
EXPECT_EQ(slice->getCapacity(), sizeSlice);
EXPECT_EQ(static_cast<std::uint8_t*>(slice->data()) - static_cast<std::uint8_t*>(tensor->data()),
offset * ITensor::volume(dims) / dims.d[0] * BufferDataType(dataType).getSize());

auto dimsNew = ITensor::makeShape({12, 32});
EXPECT_EQ(ITensor::volume(dimsNew), sizeSlice);
EXPECT_NO_THROW(slice->reshape(dimsNew));
EXPECT_EQ(slice->getShape().d[1], dimsNew.d[1]);
dimsNew.d[0] = 6;
EXPECT_LT(ITensor::volume(dimsNew), sizeSlice);
EXPECT_NO_THROW(slice->reshape(dimsNew));
EXPECT_EQ(slice->getShape().d[0], dimsNew.d[0]);
dimsNew.d[0] = 16;
EXPECT_GT(ITensor::volume(dimsNew), sizeSlice);
EXPECT_THROW(slice->reshape(dimsNew), std::runtime_error);

EXPECT_NO_THROW(slice->resize(sizeSlice));
EXPECT_NO_THROW(slice->resize(sizeSlice / 2));
EXPECT_EQ(slice->getShape().d[0], sizeSlice / 2);
EXPECT_THROW(slice->resize(sizeSlice * 2), std::runtime_error);
EXPECT_NO_THROW(slice->release());
EXPECT_EQ(slice->data(), nullptr);
EXPECT_NE(tensor->data(), nullptr);

std::shared_ptr<HostTensor const> constTensor{tensor};
auto constSlice = ITensor::slice(constTensor, offset);
EXPECT_EQ(constSlice->getShape().d[0], dims.d[0] - offset);
auto uniqueSlice = ITensor::slice(std::move(constSlice), 1);
EXPECT_EQ(uniqueSlice->getShape().d[0], dims.d[0] - offset - 1);
}

TEST_F(TllmBuffersTest, TensorSqueeze)
{
auto dims = ITensor::makeShape({16, 1, 4});
HostAllocator allocator{};
auto constexpr dataType = nvinfer1::DataType::kFLOAT;
auto tensor = std::make_shared<HostTensor>(dims, dataType, allocator);

auto squeezeDim = 0;
EXPECT_THROW(tensor->squeeze(squeezeDim), std::runtime_error);
squeezeDim = 1;
auto squeezed = ITensor::view(tensor, ITensor::squeeze(dims, squeezeDim));

EXPECT_EQ(squeezed->getSize(), tensor->getSize());
EXPECT_EQ(squeezed->getShape().nbDims, tensor->getShape().nbDims - 1);
EXPECT_EQ(squeezed->getShape().d[0], tensor->getShape().d[0]);
EXPECT_EQ(squeezed->getShape().d[1], tensor->getShape().d[2]);

EXPECT_NO_THROW(squeezed->release());
EXPECT_EQ(squeezed->data(), nullptr);
EXPECT_NE(tensor->data(), nullptr);
}

TEST_F(TllmBuffersTest, TensorView)
{
auto const dims = ITensor::makeShape({16, 1, 4});
HostAllocator allocator{};
auto constexpr dataType = nvinfer1::DataType::kFLOAT;
auto tensor = std::make_shared<HostTensor>(dims, dataType, allocator);

auto const viewDims = ITensor::makeShape({16, 1, 2});

auto view = ITensor::view(tensor, viewDims);
EXPECT_EQ(view->getSize(), tensor->getSize() / 2);
EXPECT_EQ(view->getShape().nbDims, tensor->getShape().nbDims);
EXPECT_EQ(view->getShape().d[2], tensor->getShape().d[2] / 2);

EXPECT_NO_THROW(view->release());
EXPECT_EQ(view->data(), nullptr);
EXPECT_NE(tensor->data(), nullptr);
}

TEST_F(TllmBuffersTest, BufferOutput)
{
if (mDeviceCount == 0)
GTEST_SKIP();

auto streamPtr = std::make_shared<CudaStream>();
CudaAllocatorAsync allocator{streamPtr};
for (std::size_t size : {0, 16})
@ -331,6 +260,9 @@ TEST_F(TllmBuffersTest, BufferOutput)

TEST_F(TllmBuffersTest, TensorOutput)
{
if (mDeviceCount == 0)
GTEST_SKIP();

auto streamPtr = std::make_shared<CudaStream>();
nvinfer1::Dims constexpr dims{3, 16, 8, 4};
CudaAllocatorAsync allocator{streamPtr};
|
||||
the `kv_quant_orig_scale` tensor. That tensor contains a single value (per
|
||||
tensor scaling).
|
||||
|
||||
|
||||
## Sliding Window Attention, Cyclic (Rolling Buffer) KV Cache
|
||||
|
||||
TensorRT-LLM has a feature called `Cyclic KV Cache`, which treats the kv cache
|
||||
as a circular buffer. This means that it only stores the kv cache for the last N
|
||||
tokens, where N is determined by the `max_kv_cache_length` parameter in
|
||||
`GenerationSession.setup`. You can see examples of this in the `run.py` or
|
||||
`summarize.py` files. When the cache is full, new tokens’ kv cache will
|
||||
overwrite the "least recently used" caches.
|
||||
|
||||
In the context phase, if the input length surpasses the `max_kv_cache_length`,
|
||||
`Sliding Window Attention` will be activated. This serves the same function as
|
||||
the `sliding window_size`.
|
||||
|
||||
This feature helps to reduce the memory footprint of the kv cache when
|
||||
dealing with very long sequences.
|
||||
|
||||
_Note that when using beam search, cyclic kv cache may not perform as well as
|
||||
full kv cache when the current step exceeds `max_kv_cache_length`.
|
||||
This issue will be addressed in future releases._
|
||||
|
||||
_The experimental feature, which allows different `max_kv_cache_length` values
|
||||
for each layer, is also supported. To utilize this feature, simply provide an
|
||||
`int32 torch.Tensor` with a shape of `[num_layers]` to the `GenerationSession.setup`.
|
||||
This tensor will serve as the buffer for `max_kv_cache_length`,
|
||||
setting unique values for each layer. However, it’s important to note that the
|
||||
memory allocation for the kv cache still relies on the buffer’s maximum value._
|

## Beam-Search

The GPT attention operator supports beam-search. In the context phase, a single

@ -44,7 +44,8 @@ optional object to log information, warnings and errors:

using namespace tensorrt_llm::runtime;

GptSession session(modelConfig, // Description of the model,
GptSession session(sessionConfig, // Configuration of the session,
modelConfig, // Description of the model,
worldConfig, // Description of the environment,
engineBuffer, // The compiled TensorRT engine (const void*),
engineSize, // The size in bytes of the TensorRT engine (size_t),
@ -56,6 +57,35 @@ associated size (in bytes) of that buffer. There exist other overloaded
versions that take `std::vector<uint8_t>` or `std::string` arguments to
encapsulate the engine.

#### Session Configuration

The session configuration is an instance of the
[`GptSession::Config`](source:cpp/include/tensorrt_llm/runtime/gptSession.h) class.
The constructor of this class requires three arguments:

* `maxBatchSize`, the maximum number of sequences in a batch,
* `maxBeamWidth`, the maximum width of the beams in beam-search,
* `maxSequenceLength`, the length of the longest input sequence,

Additionally, the class encapsulates the following optional parameters
(they are declared as public member variables and can be accessed directly):

* `decoderPerRequest`, whether the session will use a different decoder per
request. It must be set to `true` when running in-flight batching,
* `cudaGraphMode`, whether the session will use CUDA graphs for the engine
execution in generation phase,
* `kvCacheConfig` encapsulates parameters to configure paged KV cache, when the paged KV cache is enabled in the engine:
* `maxTokens`, the maximum number of tokens that will have to be
stored in the paged KV cache,
* `freeGpuMemoryFraction`, the fraction of free GPU memory that will be
reserved for paged KV cache,
* `ctxMicroBatchSize`, the micro batch size to be used in context phase.
Batches entered in `GptSession::generation` will be split into smaller
micro batches of this size,
* `genMicroBatchSize`, the micro batch size to be used in generation phase,
Batches entered in `GptSession::generation` will be split into smaller
micro batches of this size.

#### Model Configuration

The model configuration is an instance of the
@ -152,7 +182,7 @@ MPI_Comm_rank(MPI_COMM_WORLD, &rank);
tensorrt_llm::runtime::WorldConfig worldConfig(tensorParallelism, pipelineParallelism, rank);

// Create the GPT session (as shown above).
tensorrt_llm::runtime::GptSession session(modelConfig, worldConfig, ...);
tensorrt_llm::runtime::GptSession session(sessionConfig, modelConfig, worldConfig, ...);
```

For simplicity, TensorRT-LLM provides users with the following simplified API:
@ -169,22 +199,6 @@ installed on the system (talk to your system administrator if needed):
mpirun -n 2 ...
```

### Setup

***GptSession***

The `GptSession::setup` member function must be called to prepare the runtime
to execute the inference on a batch of input sequences. That member function
takes four arguments:

* `batchSize`, the number of sequences in the batch,
* `beamWidth`, the width of the beams in beam-search,
* `maxSequenceLength`, the length of the longest input sequence,
* `decoderPerRequest`, is the session asked to use a different decoder per
request. It must be set to `true` when running in-flight batching,
* `maxTokensInPagedKvCache`, the maximum number of tokens that will have to be
stored in the KV cache when the paged KV cache is enabled.

### Generation

The `GptSession::generate` member function performs the generation loop. Given
@ -230,10 +244,10 @@ populates an instance of the
sequences). It can be set to the same value as `endId`,
* `ids`, is the tensor of input IDs. That tensor must be allocated on the GPU.
When the input tensor is padded, the shape of `ids` is `[batchSize,
maxInputLength]`, where `batchSize` and `maxInputLength` correspond to the
arguments passed to the `GptSession::setup` member function. When the input
is packed, the shape of `ids` is `[numTokens]`, where `numTokens` is the sum
of the lengths of the different sequences in the batch,
maxInputLength]`, where `batchSize` and `maxInputLength` must respect the
maximum sizes in `sessionConfig` passed to the `GptSession` constructor.
When the input is packed, the shape of `ids` is `[numTokens]`, where
`numTokens` is the sum of the lengths of the different sequences in the batch,
* `lengths`, is the tensor of input sequence lengths. That tensor must be
allocated on the GPU and contain `batchSize` values,
* `packed`, indicates if the `ids` tensor is packed or padded. In this
|
||||
type=str,
|
||||
default='float16',
|
||||
choices=['float32', 'bfloat16', 'float16'])
|
||||
parser.add_argument('--logits_dtype',
|
||||
type=str,
|
||||
default='float32',
|
||||
choices=['float16', 'float32'])
|
||||
|
||||
parser.add_argument(
|
||||
'--timing_cache',
|
||||
type=str,
|
||||
@ -163,6 +168,14 @@ def parse_arguments():
|
||||
parser.add_argument('--enable_context_fmha_fp32_acc',
|
||||
default=False,
|
||||
action='store_true')
|
||||
parser.add_argument(
|
||||
'--multi_block_mode',
|
||||
default=False,
|
||||
action='store_true',
|
||||
help=
|
||||
'Split long kv sequence into multiple blocks (applied to generation MHA kernels). \
|
||||
It is beneifical when batchxnum_heads cannot fully utilize GPU.'
|
||||
)
|
||||
parser.add_argument('--parallel_build', default=False, action='store_true')
|
||||
parser.add_argument('--visualize', default=False, action='store_true')
|
||||
parser.add_argument('--enable_debug_output',
|
||||
@ -252,6 +265,14 @@ def parse_arguments():
|
||||
default=None,
|
||||
help='Define the max number of tokens supported by the engine')
|
||||
|
||||
parser.add_argument(
|
||||
'--strongly_typed',
|
||||
default=False,
|
||||
action="store_true",
|
||||
help=
|
||||
'This option is introduced with trt 9.1.0.1+ and will reduce the building time significantly for fp8.'
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
assert not (
|
||||
@ -378,7 +399,8 @@ def build_rank_engine(builder: Builder,
|
||||
dtype=dtype,
|
||||
mlp_hidden_size=args.inter_size,
|
||||
mapping=mapping,
|
||||
quant_mode=args.quant_mode)
|
||||
quant_mode=args.quant_mode,
|
||||
logits_dtype=args.logits_dtype)
|
||||
if args.use_smooth_quant or args.use_weight_only:
|
||||
tensorrt_llm_baichuan = quantize_model(tensorrt_llm_baichuan,
|
||||
args.quant_mode)
|
||||
@ -408,6 +430,7 @@ def build_rank_engine(builder: Builder,
|
||||
elif args.bin_model_dir is not None:
|
||||
load_from_binary(tensorrt_llm_baichuan,
|
||||
args.bin_model_dir,
|
||||
args.model_version,
|
||||
mapping,
|
||||
fp16=(args.dtype == 'float16'),
|
||||
multi_query_mode=False)
|
||||
@ -432,6 +455,8 @@ def build_rank_engine(builder: Builder,
|
||||
if args.enable_context_fmha_fp32_acc:
|
||||
network.plugin_config.set_context_fmha(
|
||||
ContextFMHAType.enabled_with_fp32_acc)
|
||||
if args.multi_block_mode:
|
||||
network.plugin_config.enable_mmha_multi_block_mode()
|
||||
if args.use_weight_only:
|
||||
network.plugin_config.set_weight_only_quant_matmul_plugin(
|
||||
dtype='float16')
|
||||
@ -514,7 +539,8 @@ def build(rank, args):
|
||||
max_output_len=args.max_output_len,
|
||||
max_num_tokens=args.max_num_tokens,
|
||||
int8=int8_trt_flag,
|
||||
quant_mode=args.quant_mode)
|
||||
quant_mode=args.quant_mode,
|
||||
strongly_typed=args.strongly_typed)
|
||||
engine_name = get_engine_name(model_name, args.dtype, args.world_size,
|
||||
cur_rank)
|
||||
engine = build_rank_engine(builder, builder_config, engine_name,
|
||||
|
||||
@ -104,6 +104,12 @@ def parse_input(input_text: str, input_file: str, tokenizer, end_id: int,
|
||||
def parse_arguments():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--max_output_len', type=int, required=True)
|
||||
parser.add_argument('--max_kv_cache_len',
|
||||
type=int,
|
||||
default=None,
|
||||
help='The max kv cache length. \
|
||||
If the final sequence length exceeds the kv cache length, we will enable cyclic kv cache. \
|
||||
If it is set to None, we will use the max sequence length.')
|
||||
parser.add_argument('--log_level', type=str, default='error')
|
||||
parser.add_argument('--model_version',
|
||||
type=str,
|
||||
@ -179,6 +185,7 @@ def generate(
|
||||
output_csv: str = None,
|
||||
output_npy: str = None,
|
||||
tokenizer_dir: str = None,
|
||||
max_kv_cache_len: int = None,
|
||||
num_beams: int = 1,
|
||||
):
|
||||
tensorrt_llm.logger.set_level(log_level)
|
||||
@ -231,7 +238,8 @@ def generate(
|
||||
decoder.setup(input_lengths.size(0),
|
||||
max_input_length,
|
||||
max_output_len,
|
||||
beam_width=num_beams)
|
||||
beam_width=num_beams,
|
||||
max_kv_cache_length=max_kv_cache_len)
|
||||
|
||||
output_ids = decoder.decode(input_ids, input_lengths, sampling_config)
|
||||
torch.cuda.synchronize()
|
||||
|
||||

@ -182,10 +182,12 @@ def main(args):
        end_id=end_id, pad_id=pad_id, top_k=top_k, num_beams=num_beams)

    with torch.no_grad():
        tensorrt_llm_baichuan.setup(batch_size,
                                    max_context_length=max_length,
                                    max_new_tokens=output_len,
                                    beam_width=num_beams)
        tensorrt_llm_baichuan.setup(
            batch_size,
            max_context_length=max_length,
            max_new_tokens=output_len,
            beam_width=num_beams,
            max_kv_cache_length=args.max_kv_cache_len)
        if tensorrt_llm_baichuan.remove_input_padding:
            output_ids = tensorrt_llm_baichuan.decode_batch(
                line_encoded, sampling_config)

@ -381,6 +383,12 @@ if __name__ == '__main__':
    parser.add_argument('--engine_dir', type=str, default='baichuan_outputs')
    parser.add_argument('--batch_size', type=int, default=1)
    parser.add_argument('--max_ite', type=int, default=20)
    parser.add_argument('--max_kv_cache_len',
                        type=int,
                        default=None,
                        help='The max kv cache length. \
If the final sequence length exceeds the kv cache length, we will enable cyclic kv cache. \
If it is set to None, we will use the max sequence length.')
    parser.add_argument('--check_accuracy', action='store_true')
    parser.add_argument('--tensorrt_llm_rouge1_threshold',
                        type=float,

@ -206,6 +206,7 @@ def gen_suffix(rank, use_smooth_quant, quant_per_channel):

def load_from_binary(tensorrt_llm_baichuan: BaichuanForCausalLM,
                     dir_path,
                     model_version,
                     mapping=Mapping(),
                     fp16=False,
                     multi_query_mode=False):

@ -313,6 +314,12 @@ def load_from_binary(tensorrt_llm_baichuan: BaichuanForCausalLM,
    # share input embedding
    lm_head_weight = fromfile(dir_path, 'lm_head.weight.bin',
                              [vocab_size, n_embd])
    if model_version.startswith('v2'):
        # baichuan v2 models use NormHead
        tensorrt_llm.logger.info(
            f'Normalizing lm_head.weight for {model_version}')
        lm_head_weight = lm_head_weight / np.linalg.norm(
            lm_head_weight, axis=1, keepdims=True)

    if vocab_size % mapping.tp_size != 0:
        # padding
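
The NormHead branch above divides each lm_head row by its L2 norm. A small numpy sketch of that same operation, using random weights and toy sizes (the real values come from the checkpoint):

import numpy as np

# Toy stand-ins for vocab_size x n_embd.
vocab_size, n_embd = 16, 4
lm_head_weight = np.random.randn(vocab_size, n_embd).astype(np.float32)

# Same normalization as the Baichuan v2 NormHead branch:
# every output row is rescaled to unit L2 norm.
lm_head_weight = lm_head_weight / np.linalg.norm(
    lm_head_weight, axis=1, keepdims=True)

assert np.allclose(np.linalg.norm(lm_head_weight, axis=1), 1.0, atol=1e-5)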

@ -158,6 +158,14 @@ def parse_arguments():
    parser.add_argument('--enable_context_fmha_fp32_acc',
                        default=False,
                        action='store_true')
    parser.add_argument(
        '--multi_block_mode',
        default=False,
        action='store_true',
        help=
        'Split long kv sequence into multiple blocks (applied to generation MHA kernels). \
It is beneficial when batch x num_heads cannot fully utilize GPU.'
    )
    parser.add_argument(
        '--use_layernorm_plugin',
        nargs='?',

@ -261,6 +269,13 @@ def parse_arguments():
        default=False,
        choices=['float16', 'float32', 'bfloat16'],
        help="Activates the lookup plugin which enables embedding sharing.")
    parser.add_argument(
        '--strongly_typed',
        default=False,
        action="store_true",
        help=
        'This option is introduced with trt 9.1.0.1+ and will reduce the building time significantly for fp8.'
    )

    args = parser.parse_args()
    logger.set_level(args.log_level)

@ -395,6 +410,8 @@ def build_rank_engine(builder: Builder,
    if args.enable_context_fmha_fp32_acc:
        network.plugin_config.set_context_fmha(
            ContextFMHAType.enabled_with_fp32_acc)
    if args.multi_block_mode:
        network.plugin_config.enable_mmha_multi_block_mode()
    # Quantization plugins.
    if args.use_smooth_quant:
        network.plugin_config.set_smooth_quant_gemm_plugin(dtype=args.dtype)

@ -476,7 +493,8 @@ def build(rank, args):
        max_input_len=args.max_input_len,
        max_output_len=args.max_output_len,
        int8=int8_trt_flag,
        quant_mode=args.quant_mode)
        quant_mode=args.quant_mode,
        strongly_typed=args.strongly_typed)
    builder_config.trt_builder_config.builder_optimization_level = 1
    engine_name = get_engine_name(MODEL_NAME, args.dtype, args.world_size,
                                  cur_rank)
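
The --multi_block_mode help text says the option pays off when batch x num_heads cannot fully utilize the GPU. Below is a rough, assumed rule of thumb for that statement, treating one attention work unit per (sequence, head) pair; the SM count and threshold are our illustration, not logic read out of the TensorRT-LLM kernels.

# Rough illustration of the guidance in the --multi_block_mode help text.
def multi_block_mode_worth_trying(batch_size: int, num_heads: int,
                                  num_sms: int = 108) -> bool:
    """True when batch x heads alone cannot keep every SM busy (A100 has 108 SMs)."""
    return batch_size * num_heads < num_sms


if __name__ == '__main__':
    print(multi_block_mode_worth_trying(batch_size=1, num_heads=32))  # True: 32 < 108
    print(multi_block_mode_worth_trying(batch_size=8, num_heads=32))  # False: 256 >= 108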

@ -31,6 +31,12 @@ PAD_TOKEN = 3
def parse_arguments():
    parser = argparse.ArgumentParser()
    parser.add_argument('--max_output_len', type=int, required=True)
    parser.add_argument('--max_kv_cache_len',
                        type=int,
                        default=None,
                        help='The max kv cache length. \
If the final sequence length exceeds the kv cache length, we will enable cyclic kv cache. \
If it is set to None, we will use the max sequence length.')
    parser.add_argument('--log_level', type=str, default='error')
    parser.add_argument('--engine_dir', type=str, default='bloom_outputs')
    parser.add_argument('--tokenizer_dir',

@ -91,7 +97,8 @@ if __name__ == '__main__':
                             runtime_mapping)
    decoder.setup(input_ids.size(0),
                  max_context_length=input_ids.size(1),
                  max_new_tokens=args.max_output_len)
                  max_new_tokens=args.max_output_len,
                  max_kv_cache_length=args.max_kv_cache_len)
    output_ids = decoder.decode(input_ids, input_lengths, sampling_config)
    torch.cuda.synchronize()

@ -170,7 +170,8 @@ def main(args):
    tensorrt_llm_bloom.setup(line_encoded.size(0),
                             max_context_length=line_encoded.size(1),
                             max_new_tokens=output_len,
                             beam_width=num_beams)
                             beam_width=num_beams,
                             max_kv_cache_length=args.max_kv_cache_len)

    output_ids = tensorrt_llm_bloom.decode(
        line_encoded,

@ -358,6 +359,12 @@ if __name__ == '__main__':
    parser.add_argument('--engine_dir', type=str, default='bloom_outputs')
    parser.add_argument('--batch_size', type=int, default=1)
    parser.add_argument('--max_ite', type=int, default=20)
    parser.add_argument('--max_kv_cache_len',
                        type=int,
                        default=None,
                        help='The max kv cache length. \
If the final sequence length exceeds the kv cache length, we will enable cyclic kv cache. \
If it is set to None, we will use the max sequence length.')
    parser.add_argument('--check_accuracy', action='store_true')
    parser.add_argument('--tensorrt_llm_rouge1_threshold',
                        type=float,

@ -76,6 +76,10 @@ def parse_arguments(args):
                        type=str,
                        default='float16',
                        choices=['float32', 'float16', 'bfloat16'])
    parser.add_argument('--logits_dtype',
                        type=str,
                        default='float32',
                        choices=['float16', 'float32'])
    parser.add_argument(
        '--timing_cache',
        type=str,

@ -141,6 +145,14 @@ def parse_arguments(args):
    parser.add_argument('--enable_context_fmha_fp32_acc',
                        default=False,
                        action='store_true')
    parser.add_argument(
        '--multi_block_mode',
        default=False,
        action='store_true',
        help=
        'Split long kv sequence into multiple blocks (applied to generation MHA kernels). \
It is beneficial when batch x num_heads cannot fully utilize GPU.'
    )
    parser.add_argument('--gpus_per_node', type=int, default=8)
    parser.add_argument('--builder_opt', type=int, default=None)
    parser.add_argument(

@ -413,6 +425,8 @@ def build_rank_engine(builder: Builder,
    if args.enable_context_fmha_fp32_acc:
        network.plugin_config.set_context_fmha(
            ContextFMHAType.enabled_with_fp32_acc)
    if args.multi_block_mode:
        network.plugin_config.enable_mmha_multi_block_mode()
    if args.remove_input_padding:
        network.plugin_config.enable_remove_input_padding()
    if args.paged_kv_cache:
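
For context on the remove_input_padding plugin toggled above: instead of padding every request to the longest prompt, the inputs can be packed into one flat token stream plus a per-request lengths list. A minimal, framework-free sketch of that packing (names and list-based layout are our own illustration):

from typing import List, Tuple

def pack_without_padding(batch: List[List[int]]) -> Tuple[List[int], List[int]]:
    """Concatenate variable-length token lists and record each original length."""
    packed, lengths = [], []
    for tokens in batch:
        packed.extend(tokens)
        lengths.append(len(tokens))
    return packed, lengths


if __name__ == '__main__':
    prompts = [[11, 12, 13], [21, 22], [31, 32, 33, 34]]
    packed, lengths = pack_without_padding(prompts)
    print(packed)   # [11, 12, 13, 21, 22, 31, 32, 33, 34]
    print(lengths)  # [3, 2, 4]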

@ -209,7 +209,8 @@ def main(args):
    tensorrt_llm_gpt.setup(batch_size,
                           max_context_length=max_length,
                           max_new_tokens=output_len,
                           beam_width=num_beams)
                           beam_width=num_beams,
                           max_kv_cache_length=args.max_kv_cache_len)

    if tensorrt_llm_gpt.remove_input_padding:
        output_ids = tensorrt_llm_gpt.decode_batch(

@ -439,6 +440,12 @@ if __name__ == '__main__':
    parser.add_argument('--batch_size', type=int, default=1)
    parser.add_argument('--max_ite', type=int, default=20)
    parser.add_argument('--output_len', type=int, default=100)
    parser.add_argument('--max_kv_cache_len',
                        type=int,
                        default=None,
                        help='The max kv cache length. \
If the final sequence length exceeds the kv cache length, we will enable cyclic kv cache. \
If it is set to None, we will use the max sequence length.')
    parser.add_argument('--check_accuracy', action='store_true', default=True)
    parser.add_argument('--tensorrt_llm_rouge1_threshold',
                        type=float,

@ -130,6 +130,14 @@ def parse_arguments(args, component):
        choices=['float16', 'float32', 'bfloat16'],
        help="Activates the lookup plugin which enables embedding sharing.")

    parser.add_argument(
        '--strongly_typed',
        default=False,
        action="store_true",
        help=
        'This option is introduced with trt 9.1.0.1+ and will reduce the building time significantly for fp8.'
    )

    args = parser.parse_args(args)
    logger.set_level(args.log_level)

@ -325,7 +333,7 @@ def build(rank, args):
        cross_attention=(args.component == 'decoder'),
        has_position_embedding=args.has_position_embedding,
        has_token_type_embedding=args.has_token_type_embedding,
    )
        strongly_typed=args.strongly_typed)

    engine_name = get_engine_name(MODEL_NAME, args.dtype, world_size,
                                  cur_rank)

@ -33,6 +33,7 @@ from tensorrt_llm.mapping import Mapping
from tensorrt_llm.models import quantize_model
from tensorrt_llm.network import net_guard
from tensorrt_llm.plugin.plugin import ContextFMHAType
from tensorrt_llm.profiler import check_gpt_mem_usage
from tensorrt_llm.quantization import QuantMode

from weight import get_scaling_factors  # isort:skip

@ -219,6 +220,14 @@ def parse_arguments():
    parser.add_argument('--enable_context_fmha_fp32_acc',
                        default=False,
                        action='store_true')
    parser.add_argument(
        '--multi_block_mode',
        default=False,
        action='store_true',
        help=
        'Split long kv sequence into multiple blocks (applied to generation MHA kernels). \
It is beneficial when batch x num_heads cannot fully utilize GPU.'
    )
    parser.add_argument('--visualize', default=False, action='store_true')
    parser.add_argument('--load_by_shard',
                        action='store_true',

@ -458,6 +467,8 @@ def build_rank_engine(builder: Builder,
    if args.enable_context_fmha_fp32_acc:
        network.plugin_config.set_context_fmha(
            ContextFMHAType.enabled_with_fp32_acc)
    if args.multi_block_mode:
        network.plugin_config.enable_mmha_multi_block_mode()

    if args.world_size > 1:
        network.plugin_config.set_nccl_plugin(args.dtype,

@ -545,6 +556,26 @@ def build(rank, args):
        assert engine is not None, \
            f'Failed to build engine for rank {cur_rank}'

        local_num_kv_heads = (args.n_kv_head + args.world_size -
                              1) // args.world_size
        kv_dtype = str_dtype_to_trt(args.dtype)
        if args.quant_mode.has_int8_kv_cache():
            kv_dtype = str_dtype_to_trt('int8')
        elif args.quant_mode.has_fp8_kv_cache():
            kv_dtype = str_dtype_to_trt('fp8')
        check_gpt_mem_usage(
            engine=engine,
            kv_dtype=kv_dtype,
            use_gpt_attention_plugin=args.use_gpt_attention_plugin,
            paged_kv_cache=args.paged_kv_cache,
            max_batch_size=args.max_batch_size,
            max_beam_width=args.max_beam_width,
            max_input_len=args.max_input_len,
            max_output_len=args.max_output_len,
            local_num_kv_heads=local_num_kv_heads,
            head_size=args.n_embd / args.n_head,
            num_layers=args.n_layer)

        if cur_rank == 0:
            # Use in-memory timing cache for multiple builder passes.
            if not args.parallel_build:
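
The memory-check block above ceil-divides the kv heads across tensor-parallel ranks and picks the kv cache dtype from the quantization mode. The sketch below redoes that arithmetic in isolation to estimate a per-rank kv cache size; the bytes-per-element table and the size formula are our assumptions for illustration, not the actual logic inside check_gpt_mem_usage.

def estimate_kv_cache_gib(n_kv_head: int, world_size: int, head_size: int,
                          num_layers: int, max_batch_size: int,
                          max_beam_width: int, max_input_len: int,
                          max_output_len: int,
                          kv_dtype: str = 'float16') -> float:
    """Rough per-rank kv cache size from the same quantities the diff passes along."""
    bytes_per_elem = {'float32': 4, 'float16': 2, 'bfloat16': 2,
                      'int8': 1, 'fp8': 1}[kv_dtype]
    # Same ceil division as the diff: kv heads are spread across ranks.
    local_num_kv_heads = (n_kv_head + world_size - 1) // world_size
    max_seq_len = max_input_len + max_output_len
    elems = (2  # keys and values
             * max_batch_size * max_beam_width * num_layers
             * local_num_kv_heads * head_size * max_seq_len)
    return elems * bytes_per_elem / 1024 ** 3


if __name__ == '__main__':
    # Hypothetical 7B-like shapes on 2 ranks with an int8 kv cache.
    print(round(estimate_kv_cache_gib(32, 2, 128, 32, 8, 1, 1024, 1024, 'int8'), 2))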

@ -31,6 +31,12 @@ from build import get_engine_name # isort:skip
def parse_arguments():
    parser = argparse.ArgumentParser()
    parser.add_argument('--max_output_len', type=int, required=True)
    parser.add_argument('--max_kv_cache_len',
                        type=int,
                        default=None,
                        help='The max kv cache length. \
If the final sequence length exceeds the kv cache length, we will enable cyclic kv cache. \
If it is set to None, we will use the max sequence length.')
    parser.add_argument('--log_level', type=str, default='error')
    parser.add_argument('--engine_dir', type=str, default='falcon_outputs')
    parser.add_argument('--tokenizer_dir',

@ -216,7 +222,8 @@ def main():
    decoder.setup(input_ids.size(0),
                  max_context_length=input_ids.size(1),
                  max_new_tokens=args.max_output_len,
                  beam_width=args.num_beams)
                  beam_width=args.num_beams,
                  max_kv_cache_length=args.max_kv_cache_len)
    output_ids = decoder.decode(input_ids, input_lengths, sampling_config)
    torch.cuda.synchronize()

@ -208,7 +208,8 @@ def main(args):
    tensorrt_llm_falcon.setup(batch_size,
                              max_context_length=max_length,
                              max_new_tokens=output_len,
                              beam_width=num_beams)
                              beam_width=num_beams,
                              max_kv_cache_length=args.max_kv_cache_len)

    if tensorrt_llm_falcon.remove_input_padding:
        output_ids = tensorrt_llm_falcon.decode_batch(

@ -413,6 +414,12 @@ if __name__ == '__main__':
    parser.add_argument('--engine_dir', type=str, default='falcon_outputs')
    parser.add_argument('--batch_size', type=int, default=1)
    parser.add_argument('--max_ite', type=int, default=20)
    parser.add_argument('--max_kv_cache_len',
                        type=int,
                        default=None,
                        help='The max kv cache length. \
If the final sequence length exceeds the kv cache length, we will enable cyclic kv cache. \
If it is set to None, we will use the max sequence length.')
    parser.add_argument('--check_accuracy', action='store_true')
    parser.add_argument('--tensorrt_llm_rouge1_threshold',
                        type=float,

@ -29,6 +29,7 @@ from tensorrt_llm.mapping import Mapping
from tensorrt_llm.models import quantize_model
from tensorrt_llm.network import net_guard
from tensorrt_llm.plugin.plugin import ContextFMHAType
from tensorrt_llm.profiler import check_gpt_mem_usage
from tensorrt_llm.quantization import QuantMode

from weight import load_from_ft, parse_ft_config, check_embedding_share # isort:skip

@ -136,6 +137,14 @@ def parse_arguments(args):
    parser.add_argument('--enable_context_fmha_fp32_acc',
                        default=False,
                        action='store_true')
    parser.add_argument(
        '--multi_block_mode',
        default=False,
        action='store_true',
        help=
        'Split long kv sequence into multiple blocks (applied to generation MHA kernels). \
It is beneficial when batch x num_heads cannot fully utilize GPU.'
    )
    parser.add_argument('--gpus_per_node', type=int, default=8)
    parser.add_argument('--builder_opt', type=int, default=None)
    parser.add_argument(

@ -482,6 +491,8 @@ def build_rank_engine(builder: Builder,
    if args.enable_context_fmha_fp32_acc:
        network.plugin_config.set_context_fmha(
            ContextFMHAType.enabled_with_fp32_acc)
    if args.multi_block_mode:
        network.plugin_config.enable_mmha_multi_block_mode()
    if args.remove_input_padding:
        network.plugin_config.enable_remove_input_padding()
    if args.paged_kv_cache:

@ -555,6 +566,7 @@ def build(rank, args):
    int8_trt_flag = args.quant_mode.has_act_or_weight_quant() or (
        args.paged_kv_cache == False
        and args.quant_mode.has_int8_kv_cache())
    num_kv_heads = 1 if args.multi_query_mode else args.n_head
    builder_config = builder.create_builder_config(
        name=MODEL_NAME,
        precision=args.dtype,

@ -563,7 +575,7 @@ def build(rank, args):
        parallel_build=args.parallel_build,
        num_layers=args.n_layer,
        num_heads=args.n_head,
        num_kv_heads=1 if args.multi_query_mode else args.n_head,
        num_kv_heads=num_kv_heads,
        hidden_size=args.n_embd,
        vocab_size=args.vocab_size,
        hidden_act=args.hidden_act,

@ -589,6 +601,26 @@ def build(rank, args):
                                   cur_rank, args)
        assert engine is not None, f'Failed to build engine for rank {cur_rank}'

        local_num_kv_heads = (num_kv_heads + args.world_size -
                              1) // args.world_size
        kv_dtype = str_dtype_to_trt(args.dtype)
        if args.quant_mode.has_int8_kv_cache():
            kv_dtype = str_dtype_to_trt('int8')
        elif args.quant_mode.has_fp8_kv_cache():
            kv_dtype = str_dtype_to_trt('fp8')
        check_gpt_mem_usage(
            engine=engine,
            kv_dtype=kv_dtype,
            use_gpt_attention_plugin=args.use_gpt_attention_plugin,
            paged_kv_cache=args.paged_kv_cache,
            max_batch_size=args.max_batch_size,
            max_beam_width=args.max_beam_width,
            max_input_len=args.max_input_len,
            max_output_len=args.max_output_len,
            local_num_kv_heads=local_num_kv_heads,
            head_size=args.n_embd / args.n_head,
            num_layers=args.n_layer)

        if cur_rank == 0:
            # Use in-memory timing cache for multiple builder passes.
            if not args.parallel_build:

@ -184,6 +184,12 @@ def print_output(output_ids, input_lengths, sequence_lengths, tokenizer,
def parse_arguments():
    parser = argparse.ArgumentParser()
    parser.add_argument('--max_output_len', type=int, required=True)
    parser.add_argument('--max_kv_cache_len',
                        type=int,
                        default=None,
                        help='The max kv cache length. \
If the final sequence length exceeds the kv cache length, we will enable cyclic kv cache. \
If it is set to None, we will use the max sequence length.')
    parser.add_argument('--log_level', type=str, default='error')
    parser.add_argument('--engine_dir', type=str, default='gpt_outputs')
    parser.add_argument('--input_text',

@ -234,6 +240,7 @@ def generate(
    output_npy: str = None,
    tokenizer_path: str = 'gpt2',
    vocab_file=None,
    max_kv_cache_len: int = None,
    num_beams: int = 1,
    prompt_table: Path = None,
    tasks: str = None,

@ -290,7 +297,8 @@ def generate(
    decoder.setup(input_lengths.size(0),
                  max_input_length,
                  max_output_len,
                  beam_width=num_beams)
                  beam_width=num_beams,
                  max_kv_cache_length=max_kv_cache_len)

    ptuning_args = [] if model_config.max_prompt_embedding_table_size == 0 else ptuning_setup(
        prompt_table, dtype, model_config.hidden_size, tasks, input_ids,

@ -154,6 +154,8 @@ def main(args):
        model.cuda()
        if args.data_type == 'fp16':
            model.half()
        elif args.data_type == 'bf16':
            model.bfloat16()

    def eval_tensorrt_llm(datapoint, eval_type='summarize'):
        batch_size = len(datapoint)

@ -207,7 +209,8 @@ def main(args):
    tensorrt_llm_gpt.setup(batch_size,
                           max_context_length=max_length,
                           max_new_tokens=output_len,
                           beam_width=num_beams)
                           beam_width=num_beams,
                           max_kv_cache_length=args.max_kv_cache_len)

    if tensorrt_llm_gpt.remove_input_padding:
        outputs = tensorrt_llm_gpt.decode_batch(

@ -503,7 +506,7 @@ if __name__ == '__main__':
    parser.add_argument('--test_trt_llm', action='store_true')
    parser.add_argument('--data_type',
                        type=str,
                        choices=['fp32', 'fp16'],
                        choices=['fp32', 'fp16', 'bf16'],
                        default='fp32')
    parser.add_argument('--dataset_path', type=str, default='')
    parser.add_argument('--log_level', type=str, default='info')

@ -511,6 +514,12 @@ if __name__ == '__main__':
    parser.add_argument('--batch_size', type=int, default=1)
    parser.add_argument('--max_ite', type=int, default=20)
    parser.add_argument('--output_len', type=int, default=100)
    parser.add_argument('--max_kv_cache_len',
                        type=int,
                        default=None,
                        help='The max kv cache length. \
If the final sequence length exceeds the kv cache length, we will enable cyclic kv cache. \
If it is set to None, we will use the max sequence length.')
    parser.add_argument('--check_accuracy', action='store_true')
    parser.add_argument('--tensorrt_llm_rouge1_threshold',
                        type=float,
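
The summarize script above now accepts bf16 and casts the Hugging Face reference model accordingly. A tiny self-contained illustration of the same dtype dispatch, using a toy module and an assumed helper name:

import torch

def cast_model(model: torch.nn.Module, data_type: str) -> torch.nn.Module:
    """Mirrors the fp16/bf16 branches added above; fp32 leaves the model unchanged."""
    if data_type == 'fp16':
        return model.half()
    elif data_type == 'bf16':
        return model.bfloat16()
    return model


if __name__ == '__main__':
    for data_type in ('fp32', 'fp16', 'bf16'):
        model = cast_model(torch.nn.Linear(4, 4), data_type)
        print(data_type, next(model.parameters()).dtype)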
Some files were not shown because too many files have changed in this diff.